public abstract class ModelBuilder<M extends Model<M,P,O>,P extends Model.Parameters,O extends Model.Output> extends Job<M>
Modifier and Type | Class and Description |
---|---|
static class | ModelBuilder.BuilderVisibility: Visibility of this algorithm, i.e. whether it is always visible, beta (always visible, but with a note in the UI), or experimental (hidden by default; visible in the UI only if the user passes an "experimental" flag at startup). |
Nested classes inherited from Job: Job.JobCancelledException, Job.JobState, Job.Progress, Job.ProgressUpdate, Job.ValidationMessage
Nested classes inherited from Keyed: Keyed.BinarySerializer<X extends Keyed>
Modifier and Type | Field and Description |
---|---|
protected Vec | _fold |
protected int | _nclass |
protected Vec | _offset |
P | _parms: All the parameters required to build the model. |
protected Vec | _response |
protected Frame | _train |
protected Frame | _valid |
protected Vec | _vresponse |
protected Vec | _weights |
Fields inherited from Job: _description, _dest, _end_time, _exception, _messages, _progressKey, _start_time, _state, LIST
Fields inherited from Keyed: _key, EMPTY_KEY_LIST
Constructor and Description |
---|
ModelBuilder(Key dest, java.lang.String desc, P parms): Default constructor, given all arguments. |
ModelBuilder(P ignore): Constructor called from an HTTP request; every subclass MUST define its own. |
ModelBuilder(java.lang.String desc, P parms): Constructor that makes a default destination key. |
Modifier and Type | Method and Description |
---|---|
abstract ModelBuilder.BuilderVisibility | builderVisibility(): Visibility of this algorithm, i.e. whether it is always visible, beta (always visible, but with a note in the UI), or experimental (hidden by default; visible in the UI only if the user passes an "experimental" flag at startup). |
abstract hex.ModelCategory[] | can_build(): List containing the categories of models that this builder can build. |
protected boolean | canBeDone(): Whether the Job is done after building the model itself, or whether there is extra work to be done. Override the Job's behavior here: N-fold CV jobs should not mark the job as finished; that is done explicitly in computeCrossValidation(). |
void | cancel(): Signal cancellation of this job. |
void | checkDistributions() |
protected void | checkMemoryFootPrint(): Override this method to call error() if the model is not expected to fit in memory, and say why. |
void | clearInitState(): Clear whatever was done by init() so it can be run again. |
Job<M> | computeCrossValidation(): Default naive (serial) implementation of N-fold cross-validation. |
protected boolean | computePriorClassDistribution() |
static ModelBuilder | createModelBuilder(java.lang.String algo): Factory method to create a ModelBuilder instance of the correct class given the algo name. |
protected boolean | deleteProgressKey() |
int | error_count() |
java.lang.String | getAlgo() |
static java.lang.String | getAlgo(java.lang.Class<? extends ModelBuilder> clz) |
static java.lang.String | getAlgo(Model model): Get the algo name for the given Model. |
static java.lang.String | getAlgoFullName(java.lang.String algo): Get the algo full name for the given algo. |
static java.lang.Class<? extends ModelBuilder> | getModelBuilder(java.lang.String name): Get the ModelBuilder class for the given algo name. |
static java.util.Map<java.lang.String,java.lang.Class<? extends ModelBuilder>> | getModelBuilders(): Get a Map of all algo names to their ModelBuilder classes. |
static java.lang.Class<? extends Model> | getModelClass(java.lang.String name): Get the Model class for the given algo name. |
boolean | hasFoldCol() |
boolean | hasOffsetCol() |
boolean | hasWeightCol() |
protected void | ignoreBadColumns(int npredictors, boolean expensive): Ignore constant columns, columns with all NAs, and string columns. |
protected boolean | ignoreConstColumns() |
protected boolean | ignoreStringColumns() |
void | init(boolean expensive): Initialize the ModelBuilder, validating all arguments and preparing the training frame. |
boolean | isClassifier() |
boolean | isSupervised() |
void | modifyParmsForCrossValidationMainModel(int N, Key<Model>[] cvModelKeys): Override for model-specific checks / modifications to _parms for the main model during N-fold cross-validation. |
void | modifyParmsForCrossValidationSplits(int i, int N, Key<Model> model_id): Override with model-specific checks / modifications to _parms for N-fold cross-validation splits. |
int | nclasses() |
boolean | nFoldCV(): Whether n-fold cross-validation is done. |
int | numSpecialCols() |
protected abstract long | progressUnits() |
static void | registerModelBuilder(java.lang.String name, java.lang.String full_name, java.lang.Class<? extends ModelBuilder> clz): Register a ModelBuilder, assigning it an algo name. |
Vec | response(): Training response vector. |
protected double | responseMean(): Compute the (weighted) mean of the response (subtracting possible offset terms). |
abstract ModelBuilderSchema | schema(): Externally visible default schema. TODO: this is in the wrong layer; the internals should not know anything about the schemas! This puts a reverse edge into the dependency graph. |
protected int | separateFeatureVecs(): Find and set the response/weights/offset/fold columns and put them all at the end. |
Frame | train(): Training frame: derived from the parameter's training frame, excluding all ignored columns and all constant and bad columns, perhaps converting the response column to a Categorical, etc. |
Job<M> | trainModel(): Method to launch training of a Model, based on its parameters. |
protected abstract Job<M> | trainModelImpl(long progressUnits, boolean restartTimer): Model-specific implementation of model training. |
protected void | updateModelOutput(): Temporary HACK to store the ModelBuilder's state and start/end/run time in the model's output. This won't be necessary once both the ModelBuilder and the Model point to a shared Job(State) object in the DKV. |
Frame | valid(): Validation frame: derived from the parameter's validation frame, excluding all ignored columns and all constant and bad columns, perhaps converting the response column to a Categorical, etc. |
Vec | vresponse(): Validation response vector. |
Methods inherited from Job: block, cancel, checksum_impl, clearValidationErrors, createProgressKey, dest, done, done, error_count_or_uninitialized, error, failed, get, hide, info, isCancelledOrCrashed, isDone, isRunning, isRunning, isStopped, jobKey, jobs, message, msec, progress_msg, progress, remove_impl, start, update, update, update, update, updateValidationMessages, validationErrors, warn
Methods inherited from Keyed: checksum, getBinarySerializer, getPublishedKeys, remove, remove, remove, remove
Methods inherited from Iced: clone, frozenType, read_impl, read, readExternal, readJSON_impl, readJSON, toJsonString, write_impl, write, writeExternal, writeJSON_impl, writeJSON
public P _parms
protected transient Frame _train
protected transient Frame _valid
protected transient Vec _response
protected transient Vec _vresponse
protected transient Vec _offset
protected transient Vec _weights
protected transient Vec _fold
protected int _nclass
public ModelBuilder(P ignore)
public ModelBuilder(java.lang.String desc, P parms)
public final Frame train()
public final Frame valid()
public Vec response()
public Vec vresponse()
protected double responseMean()
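responseMean() is documented as the (weighted) mean of the response with possible offset terms subtracted, i.e. sum(w_i * (y_i - o_i)) / sum(w_i). A minimal stand-alone sketch of that arithmetic over plain arrays, assuming a weight of 1.0 per row when there is no weights column and an offset of 0.0 when there is no offset column. The class and method names are hypothetical; H2O itself works on distributed Vecs, not arrays.

```java
/** Weighted mean of (response - offset): a stand-alone sketch, not H2O's Vec-based code. */
final class ResponseMeanSketch {
    /**
     * @param y       response values
     * @param weights per-row weights, or null to give every row weight 1.0
     * @param offset  per-row offsets, or null to treat every offset as 0.0
     */
    static double responseMean(double[] y, double[] weights, double[] offset) {
        double weightedSum = 0.0;
        double weightTotal = 0.0;
        for (int i = 0; i < y.length; i++) {
            double w = (weights == null) ? 1.0 : weights[i];
            double o = (offset == null) ? 0.0 : offset[i];
            weightedSum += w * (y[i] - o); // subtract the offset before averaging
            weightTotal += w;
        }
        return weightedSum / weightTotal; // NaN or infinite if the total weight is zero
    }

    public static void main(String[] args) {
        double[] y = {1.0, 2.0, 3.0, 4.0};
        System.out.println(responseMean(y, null, null));                     // 2.5
        System.out.println(responseMean(y, new double[]{1, 1, 1, 5}, null)); // 3.25
        System.out.println(responseMean(y, null, new double[]{1, 1, 1, 1})); // 1.5
    }
}
```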
public static void registerModelBuilder(java.lang.String name, java.lang.String full_name, java.lang.Class<? extends ModelBuilder> clz)
public static java.util.Map<java.lang.String,java.lang.Class<? extends ModelBuilder>> getModelBuilders()
public static java.lang.Class<? extends ModelBuilder> getModelBuilder(java.lang.String name)
public static java.lang.Class<? extends Model> getModelClass(java.lang.String name)
public static java.lang.String getAlgo(Model model)
public static java.lang.String getAlgoFullName(java.lang.String algo)
public java.lang.String getAlgo()
public static java.lang.String getAlgo(java.lang.Class<? extends ModelBuilder> clz)
public abstract ModelBuilderSchema schema()
public static ModelBuilder createModelBuilder(java.lang.String algo)
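Together, registerModelBuilder, getModelBuilders, getModelBuilder, getAlgoFullName and createModelBuilder describe a registry keyed by algo name plus a factory. A minimal sketch of that pattern, assuming a reflective no-arg constructor. ToyBuilder, ToyMeanBuilder and BuilderRegistry are hypothetical names; real ModelBuilder subclasses are constructed with parameters (see ModelBuilder(P ignore)), so H2O's own factory may instantiate them differently.

```java
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;

/** Hypothetical stand-in for the ModelBuilder base class. */
abstract class ToyBuilder {
    abstract String describe();
}

/** Hypothetical builder; this sketch requires a public no-arg constructor for reflection. */
class ToyMeanBuilder extends ToyBuilder {
    public ToyMeanBuilder() {}
    @Override String describe() { return "predicts the response mean"; }
}

/** Name-to-class registry plus a reflective factory, in the spirit of registerModelBuilder(). */
final class BuilderRegistry {
    private static final Map<String, Class<? extends ToyBuilder>> BUILDERS = new LinkedHashMap<>();
    private static final Map<String, String> FULL_NAMES = new LinkedHashMap<>();

    static synchronized void register(String name, String fullName, Class<? extends ToyBuilder> clz) {
        BUILDERS.put(name, clz);
        FULL_NAMES.put(name, fullName);
    }

    /** Read-only snapshot of all algo names and their builder classes. */
    static synchronized Map<String, Class<? extends ToyBuilder>> builders() {
        return Collections.unmodifiableMap(new LinkedHashMap<>(BUILDERS));
    }

    static synchronized String fullName(String algo) { return FULL_NAMES.get(algo); }

    /** Factory: look up the class for an algo name and instantiate it. */
    static ToyBuilder create(String algo) {
        Class<? extends ToyBuilder> clz;
        synchronized (BuilderRegistry.class) { clz = BUILDERS.get(algo); }
        if (clz == null) throw new IllegalArgumentException("Unknown algo: " + algo);
        try {
            return clz.getDeclaredConstructor().newInstance();
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("Cannot instantiate " + clz.getName(), e);
        }
    }

    public static void main(String[] args) {
        register("toymean", "Toy Mean Model", ToyMeanBuilder.class);
        System.out.println(fullName("toymean"));          // Toy Mean Model
        System.out.println(create("toymean").describe()); // predicts the response mean
    }
}
```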
protected void updateModelOutput()
public final Job<M> trainModel()
protected abstract Job<M> trainModelImpl(long progressUnits, boolean restartTimer)
Parameters:
progressUnits - Number of progress units (each advances the Job's progress bar by a bit)
restartTimer
protected abstract long progressUnits()
protected boolean canBeDone()
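trainModel() is final, while trainModelImpl(long progressUnits, boolean restartTimer) and progressUnits() are abstract. This suggests a template method: the base class owns the job bookkeeping and subclasses supply the training. A single-threaded sketch under that reading, where all class names are hypothetical and a String stands in for the returned Job<M>. The assumption that trainModel() passes progressUnits() straight to trainModelImpl() is not confirmed by this page; canBeDone() is modeled as the hook deciding whether the job is marked finished.

```java
/** Hypothetical, single-threaded stand-in for the trainModel()/trainModelImpl() split. */
abstract class ToyJobBuilder {
    private long totalUnits;
    private long doneUnits;
    private boolean done;

    /** Final entry point: subclasses cannot bypass the shared bookkeeping. */
    final String trainModel() {
        totalUnits = progressUnits();                    // subclass sizes the progress bar
        String model = trainModelImpl(totalUnits, true);
        if (canBeDone()) done = true;                    // a CV driver could return false here
        return model;
    }

    /** Model-specific training; calls update() as work completes. */
    protected abstract String trainModelImpl(long progressUnits, boolean restartTimer);
    protected abstract long progressUnits();
    protected boolean canBeDone() { return true; }

    protected final void update(long units) { doneUnits += units; }
    double progress() { return totalUnits == 0 ? 0.0 : (double) doneUnits / totalUnits; }
    boolean isDone() { return done; }
}

/** Hypothetical concrete builder that "trains" for a fixed number of iterations. */
class ToyIterativeBuilder extends ToyJobBuilder {
    private final int iterations;
    ToyIterativeBuilder(int iterations) { this.iterations = iterations; }

    @Override protected long progressUnits() { return iterations; }

    @Override protected String trainModelImpl(long progressUnits, boolean restartTimer) {
        // restartTimer is ignored in this sketch
        for (int i = 0; i < iterations; i++) update(1); // one progress unit per iteration
        return "model after " + iterations + " iterations";
    }

    public static void main(String[] args) {
        ToyIterativeBuilder b = new ToyIterativeBuilder(4);
        System.out.println(b.trainModel()); // model after 4 iterations
        System.out.println(b.progress());   // 1.0
        System.out.println(b.isDone());     // true
    }
}
```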
public void cancel()
Signal cancellation of this job: the job will be switched to state Job.JobState.CANCELLED, which signals that the job was cancelled by a user.
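Cancellation is described as switching the job to Job.JobState.CANCELLED. A sketch of such a state transition using an atomic compare-and-set, so that a job that has already finished is not relabeled as cancelled. Apart from CANCELLED, the state names and the ToyCancellableJob class are hypothetical, not H2O's Job API.

```java
import java.util.concurrent.atomic.AtomicReference;

/** Hypothetical job states; only CANCELLED is named in the documentation above. */
enum ToyJobState { RUNNING, DONE, CANCELLED, FAILED }

final class ToyCancellableJob {
    private final AtomicReference<ToyJobState> state = new AtomicReference<>(ToyJobState.RUNNING);

    /** Move RUNNING -> CANCELLED atomically; returns false if the job was no longer running. */
    boolean cancel() {
        return state.compareAndSet(ToyJobState.RUNNING, ToyJobState.CANCELLED);
    }

    /** Worker loops can poll this and stop early once cancellation is requested. */
    boolean isCancelled() { return state.get() == ToyJobState.CANCELLED; }

    /** Normal completion; has no effect if the job was already cancelled. */
    void finish() { state.compareAndSet(ToyJobState.RUNNING, ToyJobState.DONE); }

    ToyJobState state() { return state.get(); }

    public static void main(String[] args) {
        ToyCancellableJob job = new ToyCancellableJob();
        System.out.println(job.cancel()); // true
        job.finish();                     // no effect: already cancelled
        System.out.println(job.state());  // CANCELLED
    }
}
```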
public Job<M> computeCrossValidation()
public void modifyParmsForCrossValidationSplits(int i, int N, Key<Model> model_id)
Parameters:
i - which model index [0...N-1]
N - Total number of cross-validation folds
public void modifyParmsForCrossValidationMainModel(int N, Key<Model>[] cvModelKeys)
Parameters:
N - Total number of cross-validation folds
protected boolean deleteProgressKey()
public boolean nFoldCV()
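computeCrossValidation() is described as a naive, serial N-fold cross-validation: one model per fold index i in [0...N-1] (modifyParmsForCrossValidationSplits is the per-fold parameter hook), then the main model (modifyParmsForCrossValidationMainModel). A stand-alone sketch of the serial per-fold loop, assuming modulo fold assignment (row r goes to fold r % N) and a trivial predict-the-mean model scored by holdout mean squared error. The class name is hypothetical, and H2O's actual fold assignment, models and metrics are not shown.

```java
import java.util.ArrayList;
import java.util.List;

/** Serial N-fold cross-validation over a toy "predict the training mean" model. */
final class NaiveCrossValidation {
    /** Hypothetical fold assignment: row r goes to fold r % n. */
    static int foldOf(int row, int n) { return row % n; }

    /** Returns the holdout mean squared error of each of the n fold models. */
    static List<Double> holdoutMse(double[] y, int n) {
        if (n < 2 || n > y.length) throw new IllegalArgumentException("need 2 <= n <= rows");
        List<Double> mses = new ArrayList<>();
        for (int i = 0; i < n; i++) {                 // one model per fold, built serially
            double sum = 0;
            int count = 0;
            for (int r = 0; r < y.length; r++) {      // "train" on every row not in fold i
                if (foldOf(r, n) != i) { sum += y[r]; count++; }
            }
            double prediction = sum / count;
            double se = 0;
            int held = 0;
            for (int r = 0; r < y.length; r++) {      // score on the held-out fold i
                if (foldOf(r, n) == i) {
                    double err = y[r] - prediction;
                    se += err * err;
                    held++;
                }
            }
            mses.add(se / held);
        }
        return mses;
    }

    public static void main(String[] args) {
        System.out.println(holdoutMse(new double[]{1, 2, 3, 4}, 2)); // [2.0, 2.0]
    }
}
```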
public abstract hex.ModelCategory[] can_build()
public abstract ModelBuilder.BuilderVisibility builderVisibility()
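builderVisibility() reports one of three levels: always visible, beta (visible with a note in the UI), or experimental (hidden unless an "experimental" flag is given at startup). A sketch of how a UI listing might apply those rules. The enum constants, algo names and the experimentalFlag parameter are hypothetical and are not the actual members of ModelBuilder.BuilderVisibility.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical constants for the three documented visibility levels. */
enum Visibility { STABLE, BETA, EXPERIMENTAL }

final class VisibilityFilter {
    /** Algos a UI would list: experimental ones only when the startup flag is set. */
    static List<String> visibleAlgos(Map<String, Visibility> algos, boolean experimentalFlag) {
        List<String> shown = new ArrayList<>();
        for (Map.Entry<String, Visibility> e : algos.entrySet()) {
            if (e.getValue() != Visibility.EXPERIMENTAL || experimentalFlag) shown.add(e.getKey());
        }
        return shown;
    }

    /** Beta algos stay visible but carry a note. */
    static String label(String algo, Visibility v) {
        return v == Visibility.BETA ? algo + " (beta)" : algo;
    }

    public static void main(String[] args) {
        Map<String, Visibility> algos = new LinkedHashMap<>();
        algos.put("algoA", Visibility.STABLE);
        algos.put("algoB", Visibility.BETA);
        algos.put("algoC", Visibility.EXPERIMENTAL);
        System.out.println(visibleAlgos(algos, false));         // [algoA, algoB]
        System.out.println(visibleAlgos(algos, true));          // [algoA, algoB, algoC]
        System.out.println(label("algoB", Visibility.BETA));    // algoB (beta)
    }
}
```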
public void clearInitState()
public boolean isSupervised()
public boolean hasOffsetCol()
public boolean hasWeightCol()
public boolean hasFoldCol()
public int numSpecialCols()
public int nclasses()
public final boolean isClassifier()
protected int separateFeatureVecs()
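separateFeatureVecs() finds the response/weights/offset/fold columns and moves them all to the end, and numSpecialCols() counts special columns. A sketch over column names only, under two assumptions not stated on this page: the moved columns end up in the order weights, offset, fold, response, and numSpecialCols() counts just the three optional columns that have hasWeightCol()/hasOffsetCol()/hasFoldCol() predicates. The class name is hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Column-reordering sketch; works on names only, not on H2O Frames. */
final class SpecialColumns {
    /**
     * Moves the weights, offset, fold and response columns (null = absent) to the end,
     * in that assumed order, keeping the predictors' relative order.
     */
    static List<String> moveSpecialToEnd(List<String> cols, String response, String weights,
                                         String offset, String fold) {
        List<String> specials = new ArrayList<>();
        for (String s : new String[]{weights, offset, fold, response}) {
            if (s != null && cols.contains(s)) specials.add(s);
        }
        List<String> out = new ArrayList<>();
        for (String c : cols) {
            if (!specials.contains(c)) out.add(c); // predictors first
        }
        out.addAll(specials);
        return out;
    }

    /** One count per optional column that is present (the response is not counted here). */
    static int numSpecialCols(String weights, String offset, String fold) {
        return (weights != null ? 1 : 0) + (offset != null ? 1 : 0) + (fold != null ? 1 : 0);
    }

    public static void main(String[] args) {
        List<String> cols = Arrays.asList("x1", "w", "y", "x2", "f");
        System.out.println(moveSpecialToEnd(cols, "y", "w", null, "f")); // [x1, x2, w, f, y]
        System.out.println(numSpecialCols("w", null, "f"));             // 2
    }
}
```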
protected boolean ignoreStringColumns()
protected boolean ignoreConstColumns()
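ignoreStringColumns() and ignoreConstColumns() are boolean hooks, and ignoreBadColumns(...) is documented as ignoring constant columns, all-NA columns and string columns. A sketch of that filter, assuming the two hooks gate the string and constant checks while all-NA columns are dropped unconditionally. The column representation (NaN marks NA) and all names are hypothetical.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

final class BadColumnFilter {
    /** Hypothetical column: numeric values with NaN marking NA, or a string column. */
    static final class Col {
        final boolean isString;
        final double[] values;
        Col(boolean isString, double[] values) { this.isString = isString; this.values = values; }
    }

    /** Names of the columns to ignore, given the two boolean hooks. */
    static List<String> badColumns(Map<String, Col> cols, boolean ignoreConst, boolean ignoreStrings) {
        List<String> bad = new ArrayList<>();
        for (Map.Entry<String, Col> e : cols.entrySet()) {
            Col c = e.getValue();
            if (c.isString) {
                if (ignoreStrings) bad.add(e.getKey());
                continue;
            }
            double min = Double.POSITIVE_INFINITY, max = Double.NEGATIVE_INFINITY;
            int nonNa = 0;
            for (double v : c.values) {
                if (Double.isNaN(v)) continue; // skip NAs
                nonNa++;
                min = Math.min(min, v);
                max = Math.max(max, v);
            }
            boolean allNa = nonNa == 0;
            boolean constant = nonNa > 0 && min == max;
            if (allNa || (ignoreConst && constant)) bad.add(e.getKey()); // all-NA is always dropped
        }
        return bad;
    }

    public static void main(String[] args) {
        Map<String, Col> cols = new LinkedHashMap<>();
        cols.put("ok", new Col(false, new double[]{1, 2, Double.NaN}));
        cols.put("const", new Col(false, new double[]{5, 5, Double.NaN}));
        cols.put("allNa", new Col(false, new double[]{Double.NaN, Double.NaN}));
        cols.put("text", new Col(true, null));
        System.out.println(badColumns(cols, true, true));   // [const, allNa, text]
        System.out.println(badColumns(cols, false, false)); // [allNa]
    }
}
```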
protected void ignoreBadColumns(int npredictors, boolean expensive)
Parameters:
npredictors
expensive
protected void checkMemoryFootPrint()
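checkMemoryFootPrint() is meant to be overridden to call error() when the model is not expected to fit in memory, and to say why. A sketch of such an override with a hypothetical size estimate (a dense rows x cols matrix of doubles) compared against an explicit byte budget; in practice the budget might be derived from Runtime.getRuntime().maxMemory(). The error(field, message) method here is a stand-in that only records messages, not H2O's implementation.

```java
import java.util.ArrayList;
import java.util.List;

/** Sketch of a memory check that records a validation error instead of throwing. */
final class MemoryFootprintCheck {
    final List<String> errors = new ArrayList<>();

    /** Stand-in for the builder's error(): accumulate a validation message. */
    void error(String field, String message) { errors.add(field + ": " + message); }

    /** Hypothetical estimate: a dense rows x cols matrix of doubles. */
    void checkMemoryFootPrint(long rows, long cols, long budgetBytes) {
        // multiplyExact throws ArithmeticException instead of silently overflowing
        long needed = Math.multiplyExact(Math.multiplyExact(rows, cols), (long) Double.BYTES);
        if (needed > budgetBytes) {
            error("_train", "model needs about " + needed + " bytes but only "
                    + budgetBytes + " are available; reduce the number of rows or columns");
        }
    }

    public static void main(String[] args) {
        MemoryFootprintCheck small = new MemoryFootprintCheck();
        small.checkMemoryFootPrint(1_000, 100, 1_000_000);     // 800,000 bytes: fits
        MemoryFootprintCheck large = new MemoryFootprintCheck();
        large.checkMemoryFootPrint(1_000_000, 100, 1_000_000); // 800,000,000 bytes: too big
        System.out.println(small.errors.size() + " " + large.errors.size()); // 0 1
    }
}
```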
protected boolean computePriorClassDistribution()
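computePriorClassDistribution() returns a boolean, so it appears to be a hook deciding whether the builder computes a prior class distribution. For reference, such a distribution is the (weighted) fraction of training rows in each class. A sketch with labels encoded as 0..nclasses-1; the class and method names are hypothetical.

```java
import java.util.Arrays;

final class PriorClassDistribution {
    /** Weighted fraction of rows per class; labels must lie in 0..nclasses-1. */
    static double[] prior(int[] labels, double[] weights, int nclasses) {
        double[] dist = new double[nclasses];
        double total = 0;
        for (int i = 0; i < labels.length; i++) {
            double w = weights == null ? 1.0 : weights[i]; // unweighted rows count once
            dist[labels[i]] += w;
            total += w;
        }
        for (int k = 0; k < nclasses; k++) dist[k] /= total; // normalize to sum to 1
        return dist;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(prior(new int[]{0, 1, 1, 1}, null, 2)));         // [0.25, 0.75]
        System.out.println(Arrays.toString(prior(new int[]{0, 1}, new double[]{3, 1}, 2))); // [0.75, 0.25]
    }
}
```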
public void init(boolean expensive)
Initialize the ModelBuilder, validating all arguments and preparing the training frame. When this method is called with expensive set to false, it will be called once again at the start of model building, in trainModel(), with expensive set to true.
The incoming training frame (and validation frame) will have ignored columns dropped out, plus whatever work the parent init did.
NOTE: The front end initially calls this through the parameters validation endpoint with no training_frame, so each subclass's init() method has to work correctly with the training_frame missing.
See also: Job.updateValidationMessages()
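init(boolean expensive) runs in two phases: a cheap pass, which per the note above may run without any training_frame, and an expensive pass at the start of trainModel(). A sketch of that contract with a hypothetical learn_rate parameter and an error() stand-in that accumulates validation messages; clearInitState() resets them so init() can run again. None of these fields are ModelBuilder's actual members.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical builder showing the cheap/expensive split of init(boolean). */
class ToyInitBuilder {
    Integer trainingRows;   // null models a missing training_frame
    double learnRate;       // hypothetical parameter
    final List<String> errors = new ArrayList<>();

    void error(String field, String msg) { errors.add(field + ": " + msg); }
    int errorCount() { return errors.size(); }
    void clearInitState() { errors.clear(); }

    void init(boolean expensive) {
        // Cheap checks: always run, and must not assume a training frame exists.
        if (learnRate <= 0 || learnRate > 1) error("learn_rate", "must be in (0, 1]");
        if (trainingRows == null) {
            if (expensive) error("training_frame", "is required");
            return; // nothing more can be validated without data
        }
        // Expensive checks: only when training is about to start.
        if (expensive && trainingRows == 0) error("training_frame", "has no rows");
    }

    public static void main(String[] args) {
        ToyInitBuilder b = new ToyInitBuilder();
        b.learnRate = 0.1;
        b.init(false);                        // validation endpoint: no frame yet, no errors
        b.init(true);                         // training start: the frame is now required
        System.out.println(b.errors);         // [training_frame: is required]
    }
}
```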
public void checkDistributions()