Train Stacked Ensemble Model in Sparkling Water
Stacked Ensemble is a supervised machine learning algorithm that finds an optimal combination of a collection of prediction algorithms (base models). For further details about the algorithm and its parameters, see the H2O-3 documentation.
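The idea of "finding an optimal combination" of base models can be sketched in plain Scala, independently of Sparkling Water. This sketch assumes two hypothetical base models whose out-of-fold (cross-validated) predictions are already available, and fits a linear meta-learner by ordinary least squares without an intercept. H2O-3 uses a GLM as its default meta-learner; plain least squares is swapped in here only to keep the example short. The arrays, `fitWeights`, and `combine` are illustrative names, not part of any H2O API.

```scala
// Toy stacking sketch in plain Scala (NOT the H2O implementation).
// Hypothetical out-of-fold predictions of two base models on five training rows.
val drfOof = Array(0.9, 0.2, 0.7, 0.1, 0.6)
val gbmOof = Array(0.8, 0.3, 0.9, 0.2, 0.4)
val labels = Array(1.0, 0.0, 1.0, 0.0, 1.0)

def dot(u: Array[Double], v: Array[Double]): Double =
  u.zip(v).map { case (p, q) => p * q }.sum

// Meta-learner: weights (w1, w2) minimizing sum((w1 * a + w2 * b - y)^2),
// solved in closed form from the 2x2 normal equations.
def fitWeights(a: Array[Double], b: Array[Double], y: Array[Double]): (Double, Double) = {
  val (aa, ab, bb) = (dot(a, a), dot(a, b), dot(b, b))
  val (ay, by) = (dot(a, y), dot(b, y))
  val det = aa * bb - ab * ab
  require(math.abs(det) > 1e-12, "base-model predictions are collinear")
  ((bb * ay - ab * by) / det, (aa * by - ab * ay) / det)
}

// At scoring time, the base models' predictions for a new row are combined with the learned weights.
def combine(w: (Double, Double), drfPred: Double, gbmPred: Double): Double =
  w._1 * drfPred + w._2 * gbmPred

val weights = fitWeights(drfOof, gbmOof, labels)
println(s"weights = $weights, ensemble prediction = ${combine(weights, 0.75, 0.65)}")
```

The key point the sketch tries to convey is that the meta-learner is trained on predictions the base models made for rows they did not see during training, which is why the cross-validated predictions must be kept.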
Sparkling Water provides an API for Stacked Ensemble in Scala and Python. The following sections describe how to use Stacked Ensemble in both languages. See also Parameters of H2OStackedEnsemble.
First, start the Sparkling Shell:
./bin/sparkling-shell
Start an H2O cluster inside the Spark environment:
import ai.h2o.sparkling._
import java.net.URI
val hc = H2OContext.getOrCreate()
Load the data into a Spark DataFrame and cast the CAPSULE label column to a string so that it is treated as a categorical label:
import org.apache.spark.SparkFiles
spark.sparkContext.addFile("https://raw.githubusercontent.com/h2oai/sparkling-water/master/examples/smalldata/prostate/prostate.csv")
val rawSparkDF = spark.read.option("header", "true").option("inferSchema", "true").csv(SparkFiles.get("prostate.csv"))
val dataset = rawSparkDF.withColumn("CAPSULE", $"CAPSULE" cast "string")
Set up the algorithms that the Stacked Ensemble will combine. H2OStackedEnsemble trains the corresponding base models automatically and passes them to the H2O backend when needed. There are currently two ways in which the meta-learner can combine the base models: it uses either cross-validated predictions or a blending frame. In the former case, all base models must use the same folding (here, the same setNfolds and setFoldAssignment values), and setKeepCrossValidationPredictions must be set to true, because the meta-learner is trained on the cross-validated predictions. Furthermore, because the Stacked Ensemble combines the base models inside the H2O backend, the base models must be available there as well, so setKeepBinaryModels must also be set to true.
import ai.h2o.sparkling.ml.algos.{H2ODRF, H2OGBM, H2OStackedEnsemble}
val drf = new H2ODRF()
  .setLabelCol("CAPSULE")
  .setNfolds(5)
  .setFoldAssignment("Modulo")
  .setKeepBinaryModels(true)
  .setKeepCrossValidationPredictions(true)
val gbm = new H2OGBM()
  .setLabelCol("CAPSULE")
  .setNfolds(5)
  .setFoldAssignment("Modulo")
  .setKeepBinaryModels(true)
  .setKeepCrossValidationPredictions(true)
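The reason both base models use setFoldAssignment("Modulo") with the same setNfolds(5) can be illustrated with a small plain-Scala sketch. It assumes that Modulo assignment places row i in fold i % nfolds, a deterministic rule that depends only on row order and the number of folds, so every base model holds out the same rows and their cross-validated predictions line up row by row for the meta-learner. `moduloFolds` and `randomFolds` are hypothetical helpers for illustration, not Sparkling Water API.

```scala
// Toy sketch (plain Scala, not H2O code). Assumption: "Modulo" fold assignment
// places row i in fold i % nfolds, so it depends only on row order and nfolds.
def moduloFolds(nRows: Int, nfolds: Int): Vector[Int] =
  Vector.tabulate(nRows)(i => i % nfolds)

// Hypothetical seeded random assignment, for contrast: reproducible only with the same seed.
def randomFolds(nRows: Int, nfolds: Int, seed: Long): Vector[Int] = {
  val rng = new scala.util.Random(seed)
  Vector.fill(nRows)(rng.nextInt(nfolds))
}

val drfFolds = moduloFolds(10, 5)
val gbmFolds = moduloFolds(10, 5)
println(drfFolds == gbmFolds) // true: both base models hold out the same rows in every fold

// Rows held out in fold 0: their cross-validated predictions come from models trained without them.
val fold0 = drfFolds.indices.filter(i => drfFolds(i) == 0).toList
println(fold0) // List(0, 5)
```

With a seeded random assignment, the folds would match only if every base model used the same seed, which is easier to get wrong than the deterministic Modulo rule.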
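Then split the data into training and testing sets, pass the base algorithms to H2OStackedEnsemble, and train it on the training set:
val Array(trainingDF, testingDF) = dataset.randomSplit(Array(0.8, 0.2))
val ensemble = new H2OStackedEnsemble()
  .setBaseAlgorithms(Array(drf, gbm))
  .setLabelCol("CAPSULE")
// Keep the fitted model; it is used below for model details and predictions
val ensembleModel = ensemble.fit(trainingDF)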
You can also get the raw model details by calling the getModelDetails() method on the model:
ensembleModel.getModelDetails()
Run predictions on the testing data:
ensembleModel.transform(testingDF).show(false)