public abstract class SharedTreeModelBuilder<TM extends DTree.TreeModel> extends Job.ValidatedJob
Used for both Gradient Boosted Method (see GBM) and Random Forest (see DRF), and really could be used for any decision-tree builder.
While this is a wholly H2O design, we later found these papers, which describe our design fairly well:
Note that our dynamic histogram technique is different (surely faster, and probably less mathematically clean). I'm sure a host of other smaller details differ as well, but in the big picture the papers and our algorithm are similar.
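The histogram-based split search mentioned above, which the nbins and min_rows fields control, can be illustrated with a deliberately simplified sketch. It uses fixed equal-width bins over one numeric column rather than H2O's dynamic histograms, and it scores each candidate split by the reduction in squared error of the response. HistogramSplitSketch and its members are hypothetical names, not part of the H2O API.

```java
import java.util.Arrays;

/** Hypothetical, simplified fixed-width histogram split search (not H2O's DHistogram). */
class HistogramSplitSketch {

    /** Result of a split search. */
    static final class Split {
        final double threshold; // boundary between the left bins and the right bins
        final double gain;      // reduction in sum of squared errors of the response
        Split(double threshold, double gain) { this.threshold = threshold; this.gain = gain; }
    }

    /**
     * Bins column x into nbins equal-width bins, then scans the bin boundaries for the
     * split that maximizes the reduction in squared error of y. Returns null when no
     * boundary leaves at least minRows rows on each side.
     */
    static Split bestSplit(double[] x, double[] y, int nbins, int minRows) {
        if (x.length == 0) return null;
        double min = Arrays.stream(x).min().getAsDouble();
        double max = Arrays.stream(x).max().getAsDouble();
        if (max == min) return null; // constant column: nothing to split
        double step = (max - min) / nbins;

        long[] n = new long[nbins];       // rows per bin
        double[] sum = new double[nbins]; // sum of y per bin
        double[] ssq = new double[nbins]; // sum of y^2 per bin
        for (int i = 0; i < x.length; i++) {
            int b = (int) ((x[i] - min) / step);
            if (b == nbins) b--;          // x == max belongs to the last bin
            n[b]++;
            sum[b] += y[i];
            ssq[b] += y[i] * y[i];
        }

        long nTot = x.length;
        double sTot = Arrays.stream(sum).sum();
        double qTot = Arrays.stream(ssq).sum();
        double sseParent = qTot - sTot * sTot / nTot;

        Split best = null;
        long nL = 0;
        double sL = 0, qL = 0;
        for (int b = 0; b < nbins - 1; b++) { // candidate split between bin b and bin b+1
            nL += n[b]; sL += sum[b]; qL += ssq[b];
            long nR = nTot - nL;
            if (nL < minRows || nR < minRows) continue; // enforce min_rows per leaf
            double sR = sTot - sL, qR = qTot - qL;
            double sse = (qL - sL * sL / nL) + (qR - sR * sR / nR);
            double gain = sseParent - sse;
            if (best == null || gain > best.gain) best = new Split(min + (b + 1) * step, gain);
        }
        return best;
    }

    public static void main(String[] args) {
        double[] x = {1, 2, 3, 4, 5, 6, 7, 8};
        double[] y = {0, 0, 0, 0, 10, 10, 10, 10};
        Split s = bestSplit(x, y, 4, 2);
        System.out.println("threshold=" + s.threshold + " gain=" + s.gain); // threshold=4.5 gain=200.0
    }
}
```

Because the per-bin counts and sums are accumulated in one pass over the rows, every candidate split is then scored in O(nbins) time, so the bin count rather than the row count bounds the cost of the split scan.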
Modifier and Type | Class and Description |
---|---|
class | SharedTreeModelBuilder.Score |
class | SharedTreeModelBuilder.ScoreBuildHistogram |
Inherited nested classes/interfaces:
Job.ValidatedJob.Response2CMAdaptor
Job.ChunkProgress, Job.ChunkProgressJob, Job.ColumnsJob, Job.ColumnsResJob, Job.Fail, Job.FrameJob, Job.JobCancelledException, Job.JobHandle, Job.JobState, Job.List, Job.ModelJob, Job.ModelJobWithoutClassificationField, Job.Progress, Job.ProgressMonitor, Job.ValidatedJob
Request2.ColumnSelect, Request2.Dependent, Request2.DoClassBoolean, Request2.DRFCopyDataBoolean, Request2.MultiVecSelect, Request2.MultiVecSelectType, Request2.SpecialVecSelect, Request2.TypeaheadKey, Request2.VecClassSelect, Request2.VecSelect
Request.API, Request.Default, Request.Filter, Request.Validator<V>
RequestBuilders.ArrayBuilder, RequestBuilders.ArrayHeaderRowBuilder, RequestBuilders.ArrayRowBuilder, RequestBuilders.ArrayRowElementBuilder, RequestBuilders.ArrayRowSingleColBuilder, RequestBuilders.BooleanStringBuilder, RequestBuilders.Builder, RequestBuilders.ElementBuilder, RequestBuilders.HideBuilder, RequestBuilders.KeyCellBuilder, RequestBuilders.KeyElementBuilder, RequestBuilders.KeyLinkElementBuilder, RequestBuilders.KeyMinAvgMaxBuilder, RequestBuilders.NoCaptionObjectBuilder, RequestBuilders.ObjectBuilder, RequestBuilders.PaginatedTable, RequestBuilders.PreFormattedBuilder, RequestBuilders.Response, RequestBuilders.ResponseInfo, RequestBuilders.WarningCellBuilder
RequestArguments.Argument<T>, RequestArguments.Bool, RequestArguments.ClassifyBool, RequestArguments.DRFCopyDataBool, RequestArguments.EnumArgument<T extends java.lang.Enum<T>>, RequestArguments.ExistingFile, RequestArguments.FrameClassVec, RequestArguments.FrameKeyMultiVec, RequestArguments.FrameKeyVec, RequestArguments.GeneralFile, RequestArguments.H2OExistingKey, RequestArguments.H2OIllegalArgumentException, RequestArguments.H2OKey, RequestArguments.H2OKey2, RequestArguments.InputCheckBox, RequestArguments.InputSelect<T>, RequestArguments.InputText<T>, RequestArguments.Int, RequestArguments.LongInt, RequestArguments.MultipleSelect<T>, RequestArguments.MultipleText<T>, RequestArguments.NumberSequence, RequestArguments.NumberSequenceFloat, RequestArguments.Real, RequestArguments.Record<T>, RequestArguments.RSeq, RequestArguments.RSeqFloat, RequestArguments.Str, RequestArguments.StringList, RequestArguments.TypeaheadInputText<T>
RequestStatics.RequestType
Constants.Extensions, Constants.Schemes, Constants.Suffixes
Modifier and Type | Field and Description |
---|---|
protected long[] | _distribution |
protected float[] | _modelClassDist |
protected int | _nclass |
protected int | _ncols |
protected long | _nrows |
protected int | _ntreesFromCheckpoint |
protected float[] | _priorClassDist |
boolean | balance_classes: For imbalanced data, balance training data class counts via over/under-sampling. |
Key | checkpoint |
float[] | class_sampling_factors: Desired over/under-sampling ratios per class (lexicographic order). |
static int | DECIDED_ROW: Marker for an already decided row. |
static DocGen.FieldDoc[] | DOC_FIELDS |
protected boolean | importance |
float | max_after_balance_size: When classes are balanced, limit the resulting dataset size to the specified multiple of the original dataset size. |
int | max_depth |
static int | MAX_SUPPORTED_LEVELS: Maximum number of supported levels in the response. |
int | min_rows |
int | nbins |
int | ntrees |
static int | OUT_OF_BAG: Marker for sampled-out rows. |
boolean | overwrite_checkpoint |
boolean | score_each_iteration |
Inherited fields:
_cmDomain, _cv_count, _names, _responseName, _sourceResponseDomain, _train, _valid, _validResponse, _validResponseDomain, holdout_fraction, keep_cross_validation_splits, n_folds, validation, xval_models
classification
response
cols, ignored_cols, ignored_cols_by_name
source
_cv, _fjtask, description, destination_key, end_time, exception, job_key, LIST, start_time, state
_parms, response_info
_requestHelp, SUPPORTS_ONLY_V1, SUPPORTS_ONLY_V2, SUPPORTS_V1_V2
ARRAY_BUILDER, ARRAY_HEADER_ROW_BUILDER, ARRAY_ROW_BUILDER, ARRAY_ROW_ELEMENT_BUILDER, ARRAY_ROW_SINGLECOL_BUILDER, ELEMENT_BUILDER, GSON_BUILDER, OBJECT_BUILDER, ROOT_OBJECT
_queryHtml
_arguments
ALPHA, ARGUMENTS, AUC, BASE, BEST_THRESHOLD, BETA_EPS, BIN_LIMIT, BROWSE, BUCKET, BUILT_IN_KEY_JOBS, CANCELLED, CARDINALITY, CASE, CASE_MODE, CHUNK, CLASS, CLOUD_HEALTH, CLOUD_NAME, CLOUD_SIZE, CLOUD_UPTIME_MILLIS, CLUSTERS, COEFFICIENTS, COL_INDEX, COLS, COLUMN_NAME, COLUMNS_DISPLAY, CONSENSUS, CONTENTS, COUNT, DATA_KEY, DEPTH, DESCRIPTION, DEST_KEY, DTHRESHOLDS, ELAPSED, END_TIME, ENUM_DOMAIN_SIZE, ERROR, ESCAPE_NAN, EXCLUSIVE_SPLIT_LIMIT, EXPRESSION, FAILED, FAMILY, FEATURES, FILE, FILES, FILTER, FIRST_CHUNK, FJ_QUEUE_HI, FJ_QUEUE_LO, FJ_THREADS_HI, FJ_THREADS_LO, FREE_DISK, FREE_MEM, GFLOPS, HEADER, HEIGHT, HELP, IGNORE, ITEMS, ITERATIVE_CM, JOB, JOB_KEY, JOBS, JSON_H2O, KEY, KEYS, LAMBDA, LAST_CONTACT, LIMIT, LINK, LOCKED, MAX, MAX_DISK, MAX_ITER, MAX_MEM, MAX_ROWS, MEAN, MEM_BW, MIN, MODEL_KEY, MODELS, MORE, MTRY, MTRY_NODES, NAME, NEG_X, NO_CM, NODE, NODE_HEALTH, NODE_NAME, NODES, NORMALIZE, NUM_COLS, NUM_CPUS, NUM_FAILED, NUM_KEYS, NUM_MISSING_VALUES, NUM_ROWS, NUM_SUCCEEDED, NUM_TREES, OBJECT, OFFSET, OOBEE, PARALLEL, PARSER_TYPE, PATH, PREVIEW, PREVIOUS_MODEL_KEY, PRIOR, PROGRESS, PROGRESS_KEY, PROGRESS_TOTAL, REDIRECT, REDIRECT_ARGS, REPLICATION_FACTOR, REQUEST_TIME, RESPONSE, RHO, ROW, ROW_SIZE, ROWS, RPCS, SAMPLE, SAMPLING_STRATEGY, SCALE, SEED, SENT_ROWS, SEPARATOR, SIZE, SOURCE_KEY, STACK_TRACES, START_TIME, STAT_TYPE, STATUS, STEP, STRATA_SAMPLES, SUCCEEDED, SYSTEM_LOAD, TASK_KEY, TCPS_ACTIVE, TCPS_DUTY, TIME, TO_ENUM, TOT_MEM, TREE_COUNT, TREE_DEPTH, TREE_LEAVES, TREE_NUM, TREES, TWEEDIE_POWER, TYPE, URL, USE_NON_LOCAL_DATA, VALUE, VALUE_SIZE, VALUE_TYPE, VARIANCE, VERSION, VIEW, WARNINGS, WEIGHT, WEIGHTS, WIDTH, X, XVAL, Y
Constructor and Description |
---|
SharedTreeModelBuilder() |
Modifier and Type | Method and Description |
---|---|
protected DHistogram[][][] | buildLayer(Frame fr, DTree[] ktrees, int[] leafs, DHistogram[][][] hcs, boolean subset, boolean build_tree_one_node) |
void | buildModel(long seed) |
protected abstract TM | buildModel(TM initialModel, Frame trainFr, java.lang.String[] names, java.lang.String[][] domains, Timer t_build): Builds the model. |
protected Chunk | chk_nids(Chunk[] chks, int t) |
protected Chunk | chk_oobt(Chunk[] chks) |
protected Chunk | chk_resp(Chunk[] chks) |
protected Chunk | chk_tree(Chunk[] chks, int c) |
protected Chunk | chk_work(Chunk[] chks, int c) |
protected void | cleanUp(Frame fr, Timer t_build) |
static java.util.Random | createRNG(long seed) |
protected double[] | data_row(Chunk[] chks, int row, double[] data) |
protected void | debugPrintTreeColumns(Frame fr) |
protected Key | defaultDestKey() |
protected TM | doScoring(TM model, Frame fTrain, DTree[] ktrees, int tid, DTree.TreeModel.TreeStats tstats, boolean finalScoring, boolean oob, boolean build_tree_one_node) |
protected abstract VarImp | doVarImpCalc(TM model, DTree[] ktrees, int tid, Frame validationFrame, boolean scale) |
protected boolean | inBagRow(Chunk[] chks, int row) |
protected void | init(): Invoked before job runs. |
protected abstract void | initAlgo(TM initialModel): Initialize the algorithm, e.g., allocate algorithm-specific data structures. |
protected abstract void | initWorkFrame(TM initialModel, Frame fr): Initialize the working frame. |
protected boolean | isClassification() |
static boolean | isDecidedRow(int nid) |
static boolean | isOOBRow(int nid) |
protected abstract Log.Tag.Sys | logTag(): Returns a log tag for a particular model builder (e.g., DRF, GBM). |
protected AUCData | makeAUC(ConfusionMatrix[] cms, float[] threshold) |
protected abstract DTree.DecidedNode | makeDecided(DTree.UndecidedNode udn, DHistogram[] hs) |
protected abstract TM | makeModel(Key outputKey, Key dataKey, Key testKey, int ntrees, java.lang.String[] names, java.lang.String[][] domains, java.lang.String[] cmDomain, float[] priorClassDist, float[] classDist) |
protected abstract TM | makeModel(TM model, double err, ConfusionMatrix cm, VarImp varimp, AUCData validAUC) |
protected abstract TM | makeModel(TM model, DTree[] ktrees, DTree.TreeModel.TreeStats tstats) |
static int | nid2Oob(int nid) |
static int | oob2Nid(int oobNid) |
protected static void | printGenerateTrees(DTree[] trees) |
float | progress(): Return progress of this job. |
protected abstract float | score1(Chunk[] chks, float[] fs, int row) |
java.lang.String | speedDescription(): Description of a speed criterion: msecs/frob. |
long | speedValue(): Value of the described speed criterion: msecs/frob. |
boolean | supportsBagging() |
protected abstract TM | updateModel(TM model, TM checkpoint, boolean overwriteCheckpoint) |
protected Vec | vec_nids(Frame fr, int t) |
protected Vec | vec_resp(Frame fr, int t) |
protected Vec | vec_tree(Frame fr, int c) |
Inherited methods:
crossValidate, cv_progress, genericCrossValidation, getCMDomain, getOrigValidation, getValidAdaptor, getValidation, getVectorDomain, hasValidation, prepareValidationWithModel, queryArgumentValueSet, registered, toJSON
selectFrame, selectVecs
all, cancel, cancel, cancel, checkIdx, defaultJobKey, dest, findJob, findJobByDest, fork, get, getState, gridParallelism, hygiene, hygiene, invoke, isCancelledOrCrashed, isCrashed, isDone, isEnded, isRunning, isRunning, onCancelled, redirect, remove, runTimeMs, self, serve, start, waitUntilJobEnded, waitUntilJobEnded
cleanup, emptyLTrash, exec, execImpl, gtrash, gtrash, ltrash, ltrash
create, fillResponseInfo, filterNaCols, find, input, logStart, makeJsonBox, serveGrid, servePublic, set, split, superServeGrid, supportedVersions, toJSON, toString
addToNavbar, addToNavbar, addToNavbar, DocExampleFail, DocExampleSucc, href, href, hrefType, HTMLHelp, htmlTemplate, initializeNavBar, log, mapTypeahead, ReSTHelp, serve, serveJava, serveResponse, toDocGET, toHTML, toJava, wrap, wrap, wrap, writeJSONFields
build, buildJSONResponseBox, buildResponseHeader, name
buildQuery, checkArguments
arguments, argumentsToJson, frameColumnNameToIndex
checkJsonName, encodeRedirectArgs, JSON2HTML, jsonError, requestName, Str2JSON
clone, frozenType, init, newInstance, read, toDocField, write, writeJSON
public static DocGen.FieldDoc[] DOC_FIELDS
@Request.API(help="Number of trees. Grid Search, comma sep values:50,100,150,200", filter=Request.Default.class, lmin=1L, lmax=1000000L, json=true, importance=CRITICAL) public int ntrees
@Request.API(help="Maximum tree depth. Grid Search, comma sep values:5,7", filter=Request.Default.class, lmin=1L, lmax=10000L, json=true, importance=CRITICAL) public int max_depth
@Request.API(help="Fewest allowed observations in a leaf (in R called \'nodesize\'). Grid Search, comma sep values", filter=Request.Default.class, lmin=1L, json=true, importance=SECONDARY) public int min_rows
@Request.API(help="Build a histogram of this many bins, then split at the best point", filter=Request.Default.class, lmin=2L, lmax=10000L, json=true, importance=SECONDARY) public int nbins
@Request.API(help="Perform scoring after each iteration (can be slow)", filter=Request.Default.class, json=true) public boolean score_each_iteration
@Request.API(help="Compute variable importance (true/false).", filter=Request.Default.class) protected boolean importance
@Request.API(help="Balance training data class counts via over/under-sampling (for imbalanced data)", filter=Request.Default.class, json=true, importance=EXPERT) public boolean balance_classes
@Request.API(help="Desired over/under-sampling ratios per class (lexicographic order).", filter=Request.Default.class, dmin=0.0, json=true, importance=SECONDARY) public float[] class_sampling_factors
@Request.API(help="Maximum relative size of the training data after balancing class counts (can be less than 1.0)", filter=Request.Default.class, json=true, dmin=0.001, importance=EXPERT) public float max_after_balance_size
@Request.API(help="Model checkpoint to start building a new model from", filter=Request.Default.class, json=true, required=false) public Key checkpoint
@Request.API(help="Overwrite checkpoint", filter=Request.Default.class, json=true, required=false) public boolean overwrite_checkpoint
protected int _ncols
protected long _nrows
protected int _nclass
@Request.API(help="Class distribution") protected long[] _distribution
protected float[] _priorClassDist
protected float[] _modelClassDist
protected int _ntreesFromCheckpoint
public static final int MAX_SUPPORTED_LEVELS
public static final int DECIDED_ROW
public static final int OUT_OF_BAG
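The balance_classes, class_sampling_factors and max_after_balance_size fields interact roughly as sketched below. This is a hedged illustration under stated assumptions, not H2O's sampling code: when no factors are given it assumes every class is oversampled up to the majority-class count, and it enforces the size cap by scaling all ratios down uniformly. ClassBalanceSketch and targetCounts are hypothetical names.

```java
import java.util.Arrays;

/** Hypothetical sketch of per-class target counts for class balancing (not H2O's code). */
class ClassBalanceSketch {

    /**
     * @param dist                rows per class, in lexicographic class order
     * @param factors             desired sampling ratio per class, or null to equalize classes
     * @param maxAfterBalanceSize cap on the balanced size, as a multiple of the original size
     * @return expected number of rows per class after over/under-sampling
     */
    static long[] targetCounts(long[] dist, float[] factors, float maxAfterBalanceSize) {
        int k = dist.length;
        double[] ratio = new double[k];
        if (factors == null) {
            // Assumption: oversample every class up to the majority-class count.
            long majority = Arrays.stream(dist).max().orElse(0);
            for (int i = 0; i < k; i++) ratio[i] = dist[i] == 0 ? 0 : (double) majority / dist[i];
        } else {
            for (int i = 0; i < k; i++) ratio[i] = factors[i];
        }

        long original = Arrays.stream(dist).sum();
        double expected = 0;
        for (int i = 0; i < k; i++) expected += ratio[i] * dist[i];

        // Assumption: scale all ratios down uniformly if the result would exceed the cap.
        double cap = (double) maxAfterBalanceSize * original;
        double scale = expected > cap ? cap / expected : 1.0;

        long[] target = new long[k];
        for (int i = 0; i < k; i++) target[i] = Math.round(ratio[i] * scale * dist[i]);
        return target;
    }

    public static void main(String[] args) {
        long[] dist = {900, 100}; // imbalanced two-class data
        System.out.println(Arrays.toString(targetCounts(dist, null, 5.0f))); // [900, 900]
        System.out.println(Arrays.toString(targetCounts(dist, null, 1.0f))); // [500, 500]
    }
}
```

With a cap of 1.0 the balanced data keeps the original row count, so the majority class ends up undersampled (from 900 to 500 rows in the example) while the minority class is oversampled.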
public float progress()
Return progress of this job. (Description copied from class Job.)
protected void init()
Invoked before job runs. (Description copied from class Job.)
Overrides init in class Job.ValidatedJob.
protected Key defaultDestKey()
Overrides defaultDestKey in class Job.
public void buildModel(long seed)
protected TM doScoring(TM model, Frame fTrain, DTree[] ktrees, int tid, DTree.TreeModel.TreeStats tstats, boolean finalScoring, boolean oob, boolean build_tree_one_node)
protected abstract VarImp doVarImpCalc(TM model, DTree[] ktrees, int tid, Frame validationFrame, boolean scale)
public boolean supportsBagging()
protected double[] data_row(Chunk[] chks, int row, double[] data)
protected DHistogram[][][] buildLayer(Frame fr, DTree[] ktrees, int[] leafs, DHistogram[][][] hcs, boolean subset, boolean build_tree_one_node)
protected abstract DTree.DecidedNode makeDecided(DTree.UndecidedNode udn, DHistogram[] hs)
protected abstract float score1(Chunk[] chks, float[] fs, int row)
public java.lang.String speedDescription()
Description of a speed criterion: msecs/frob. (Description copied from class Job.)
Overrides speedDescription in class Job.
public long speedValue()
Value of the described speed criterion: msecs/frob. (Description copied from class Job.)
Overrides speedValue in class Job.
protected abstract Log.Tag.Sys logTag()
protected abstract TM buildModel(TM initialModel, Frame trainFr, java.lang.String[] names, java.lang.String[][] domains, Timer t_build)
Parameters:
initialModel - initial model created by makeModel() method.
trainFr - training dataset which can contain additional temporary vectors prepared by buildModel() method.
names - names of columns in trainFr used for model training
domains - domains of columns in trainFr used for model training
t_build - timer to measure model building process
protected abstract void initAlgo(TM initialModel)
Parameters:
initialModel -
protected abstract void initWorkFrame(TM initialModel, Frame fr)
Parameters:
initialModel - initial model
fr - working frame which contains train data and additional columns prepared by this builder.
protected abstract TM makeModel(Key outputKey, Key dataKey, Key testKey, int ntrees, java.lang.String[] names, java.lang.String[][] domains, java.lang.String[] cmDomain, float[] priorClassDist, float[] classDist)
protected abstract TM makeModel(TM model, double err, ConfusionMatrix cm, VarImp varimp, AUCData validAUC)
protected abstract TM makeModel(TM model, DTree[] ktrees, DTree.TreeModel.TreeStats tstats)
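The parameter descriptions for buildModel() and initWorkFrame() (an initial model created by makeModel(), a working frame holding extra columns prepared by the builder, and an abstract buildModel() that grows the trees) point to a template-method design. The sketch below shows that pattern with stand-in types; MiniModel, MiniTreeBuilder, ToyBuilder and the run() driver are hypothetical, and the relative order of initAlgo() and initWorkFrame() is an assumption.

```java
import java.util.ArrayList;
import java.util.List;

/** Stand-in for a tree model (hypothetical). */
class MiniModel {
    final List<String> trees = new ArrayList<>();
}

/** Hypothetical template-method skeleton mirroring the abstract hooks listed above. */
abstract class MiniTreeBuilder<TM extends MiniModel> {
    protected final int ntrees;
    protected final List<String> calls = new ArrayList<>(); // records the hook order

    protected MiniTreeBuilder(int ntrees) { this.ntrees = ntrees; }

    protected abstract TM makeModel();
    protected abstract void initAlgo(TM initialModel);
    protected abstract void initWorkFrame(TM initialModel, double[][] fr);
    protected abstract TM buildModel(TM initialModel, double[][] trainFr);

    /** Fixed driver: subclasses customize the hooks, not their (assumed) order. */
    public final TM run(double[][] train) {
        TM model = makeModel();
        initAlgo(model);
        initWorkFrame(model, train);
        return buildModel(model, train);
    }
}

/** Toy subclass that "grows" ntrees placeholder trees. */
class ToyBuilder extends MiniTreeBuilder<MiniModel> {
    ToyBuilder(int ntrees) { super(ntrees); }

    @Override protected MiniModel makeModel() { calls.add("makeModel"); return new MiniModel(); }
    @Override protected void initAlgo(MiniModel m) { calls.add("initAlgo"); }
    @Override protected void initWorkFrame(MiniModel m, double[][] fr) { calls.add("initWorkFrame"); }
    @Override protected MiniModel buildModel(MiniModel m, double[][] trainFr) {
        calls.add("buildModel");
        for (int t = 0; t < ntrees; t++) m.trees.add("tree" + t);
        return m;
    }

    public static void main(String[] args) {
        ToyBuilder b = new ToyBuilder(3);
        MiniModel m = b.run(new double[][] {{1, 2}, {3, 4}});
        System.out.println(b.calls + " " + m.trees);
        // prints [makeModel, initAlgo, initWorkFrame, buildModel] [tree0, tree1, tree2]
    }
}
```

Keeping run() final lets the shared base class own the control flow, while GBM-style or DRF-style subclasses only supply the hooks.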
protected AUCData makeAUC(ConfusionMatrix[] cms, float[] threshold)
protected boolean inBagRow(Chunk[] chks, int row)
protected final boolean isClassification()
public static final boolean isOOBRow(int nid)
public static final boolean isDecidedRow(int nid)
public static final int oob2Nid(int oobNid)
public static final int nid2Oob(int nid)
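The static helpers isOOBRow, isDecidedRow, nid2Oob and oob2Nid suggest that each row's node-id column also records whether the row is out of bag for the current tree. A minimal sketch of such an encoding follows. It assumes DECIDED_ROW and OUT_OF_BAG are negative sentinels with the values -1 and -2 (the documentation above does not state their values), and NodeIdEncodingSketch is a hypothetical class.

```java
/** Hypothetical node-id encoding with negative sentinels; the constant values are assumed. */
class NodeIdEncodingSketch {
    static final int DECIDED_ROW = -1; // assumed value: row has already reached a decided node
    static final int OUT_OF_BAG = -2;  // assumed value: row is sampled out of this tree

    /** True if the encoded id marks an out-of-bag row. */
    static boolean isOOBRow(int nid) { return nid <= OUT_OF_BAG; }

    /** True if the row carries the decided-row marker. */
    static boolean isDecidedRow(int nid) { return nid == DECIDED_ROW; }

    /** Folds a non-negative node id into the out-of-bag range (0 -> -2, 1 -> -3, ...). */
    static int nid2Oob(int nid) { return -nid + OUT_OF_BAG; }

    /** Inverse of nid2Oob: recovers the original node id. */
    static int oob2Nid(int oobNid) { return -oobNid + OUT_OF_BAG; }

    public static void main(String[] args) {
        for (int nid = 0; nid < 4; nid++) {
            int enc = nid2Oob(nid);
            System.out.println(nid + " -> " + enc + " -> " + oob2Nid(enc) + " oob=" + isOOBRow(enc));
        }
    }
}
```

Under this encoding nid2Oob and oob2Nid are the same involution, so an out-of-bag row keeps its node id without needing an extra column, and both sentinels stay distinct from every valid (non-negative) node id.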
public static java.util.Random createRNG(long seed)
protected static void printGenerateTrees(DTree[] trees)
protected final void debugPrintTreeColumns(Frame fr)