sample_rate

  • Available in: GBM, DRF, XGBoost
  • Hyperparameter: yes

Description

This option specifies the row sampling rate (x-axis), which is the fraction of training rows sampled when building each tree. The range is 0.0 to 1.0. Row and column sampling (sample_rate and col_sample_rate) can improve generalization and lead to lower validation and test set errors. For large datasets, values of around 0.7 to 0.8 (sampling 70-80 percent of the data) for both parameters are a good general starting point. Higher values generally improve training accuracy, but that gain does not always carry over to validation and test data.
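The following minimal base-R sketch makes per-tree row sampling concrete. It assumes that each row is kept independently with probability sample_rate. The helper sample_rows() is hypothetical and does not reproduce H2O's internal sampling code.

```r
# Hypothetical helper (not part of H2O): pick the rows one tree would train on,
# assuming each row is kept independently with probability `sample_rate`.
sample_rows <- function(n_rows, sample_rate, seed) {
  # A rate of 0 would leave no rows to train on, so require (0, 1].
  stopifnot(sample_rate > 0, sample_rate <= 1)
  set.seed(seed)                       # note: resets the global RNG state
  which(runif(n_rows) < sample_rate)   # indices of the sampled rows
}

n <- 10000
# Different seeds stand in for different trees in the ensemble.
tree1 <- sample_rows(n, 0.7, seed = 1)
tree2 <- sample_rows(n, 0.7, seed = 2)

length(tree1) / n                      # expected to be near 0.7
length(intersect(tree1, tree2)) / n    # expected to be near 0.7 * 0.7 = 0.49
```

Under this independence assumption, with sample_rate = 0.7 any two trees share only about half of the rows. This diversity between trees is what row sampling adds to the ensemble.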

For highly imbalanced classification datasets, stratified row sampling based on response class membership can help improve predictive accuracy. Configure this with sample_rate_per_class, an array of sampling ratios with one entry per response class, given in the lexicographic order of the class labels.

Note: If sample_rate_per_class is specified, then sample_rate will be ignored.
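Stratified row sampling in the spirit of sample_rate_per_class can be sketched in base R as shown below. The helper sample_rows_per_class() is hypothetical. It assumes that each row is kept independently with the rate of its own class, and that the rate vector follows the lexicographic order of the class labels as described above. As in the note, the per-class rates take the place of a single sample_rate.

```r
# Hypothetical helper (not part of H2O): stratified row sampling driven by a
# per-class rate vector given in lexicographic order of the class labels.
sample_rows_per_class <- function(y, rates, seed) {
  classes <- sort(unique(as.character(y)))   # lexicographic order of labels
  stopifnot(length(rates) == length(classes))
  names(rates) <- classes
  set.seed(seed)                             # note: resets the global RNG state
  row_rate <- rates[as.character(y)]         # each row gets its class's rate
  which(runif(length(y)) < row_rate)         # indices of the sampled rows
}

# Imbalanced toy response: 9,000 "NO" rows and 1,000 "YES" rows.
y <- c(rep("NO", 9000), rep("YES", 1000))

# Classes sort as "NO", "YES": keep about 30% of the majority class
# and every row of the minority class.
kept <- sample_rows_per_class(y, rates = c(0.3, 1.0), seed = 42)
table(y[kept])
```

A rate of 1.0 keeps every row of that class, so only the majority class is thinned out in this example.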

Example

library(h2o)
h2o.init()
# import the airlines dataset:
# This dataset is used to classify whether a flight will be delayed ('YES') or not ('NO')
# original data can be found at http://www.transtats.bts.gov/
airlines <- h2o.importFile("http://s3.amazonaws.com/h2o-public-test-data/smalldata/airlines/allyears2k_headers.zip")

# convert columns to factors
airlines["Year"] <- as.factor(airlines["Year"])
airlines["Month"] <- as.factor(airlines["Month"])
airlines["DayOfWeek"] <- as.factor(airlines["DayOfWeek"])
airlines["Cancelled"] <- as.factor(airlines["Cancelled"])
airlines["FlightNum"] <- as.factor(airlines["FlightNum"])

# set the predictor names and the response column name
predictors <- c("Origin", "Dest", "Year", "UniqueCarrier", "DayOfWeek", "Month", "Distance", "FlightNum")
response <- "IsDepDelayed"

# split into train and validation
airlines.splits <- h2o.splitFrame(data = airlines, ratios = 0.8, seed = 1234)
train <- airlines.splits[[1]]
valid <- airlines.splits[[2]]

# try using the `sample_rate` parameter:
airlines.gbm <- h2o.gbm(x = predictors, y = response, training_frame = train,
                        validation_frame = valid, sample_rate = 0.7,
                        seed = 1234)

# print the AUC for the validation data
print(h2o.auc(airlines.gbm, valid = TRUE))


# Example of values to grid over for `sample_rate`
hyper_params <- list(sample_rate = c(0.7, 0.8, 1))

# This example uses a Cartesian grid search because the search space is small
# and we want to see the performance of all models. For a larger search space, use
# a random grid search instead: search_criteria = list(strategy = "RandomDiscrete")
# Each GBM stops early once the validation AUC fails to improve by at least 0.01%
# (stopping_tolerance = 1e-4) over 5 consecutive scoring events (stopping_rounds = 5)
grid <- h2o.grid(x = predictors, y = response, training_frame = train, validation_frame = valid,
                 algorithm = "gbm", grid_id = "air_grid", hyper_params = hyper_params,
                 stopping_rounds = 5, stopping_tolerance = 1e-4, stopping_metric = "AUC",
                 search_criteria = list(strategy = "Cartesian"), seed = 1234)

# Sort the grid models by AUC, highest first
sortedGrid <- h2o.getGrid("air_grid", sort_by = "auc", decreasing = TRUE)
sortedGrid
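The early-stopping comment in the grid example combines a relative tolerance of 0.01% with a window of 5 scoring events. The base-R sketch below shows the arithmetic behind that combination. It is a simplified, hypothetical rule and is not H2O's exact implementation (H2O compares moving averages of the stopping metric). Here the helper should_stop() stops when none of the last k scoring events beats the best AUC seen before them by a relative margin of at least tol.

```r
# Simplified, hypothetical early-stopping check (not H2O's exact logic):
# stop when the best AUC in the last `k` scoring events improves on the best
# AUC seen before them by less than `tol`, measured relatively.
should_stop <- function(auc_history, k = 5, tol = 1e-4) {
  n <- length(auc_history)
  if (n <= k) return(FALSE)                      # not enough scoring events yet
  best_before <- max(auc_history[seq_len(n - k)])
  recent_best <- max(auc_history[(n - k + 1):n])
  (recent_best - best_before) / best_before < tol
}

improving <- c(0.70, 0.72, 0.74, 0.75, 0.76, 0.77, 0.78)
plateaued <- c(0.70, 0.72, 0.74, 0.7400, 0.74001, 0.7400, 0.7399, 0.7400)

should_stop(improving)   # FALSE: 0.78 is about 8% better than 0.72
should_stop(plateaued)   # TRUE: 0.74001 is only about 0.001% better than 0.74
```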