# -*- encoding: utf-8 -*-
from __future__ import absolute_import, division, print_function, unicode_literals
from h2o.utils.shared_utils import can_use_pandas
from h2o.utils.compatibility import * # NOQA
from .model_base import ModelBase
from .metrics_base import * # NOQA
import h2o
[docs]class H2ODimReductionModel(ModelBase):
"""
Dimension reduction model, such as PCA or GLRM.
"""
[docs] def varimp(self, use_pandas=False):
"""
Return the Importance of components associcated with a pca model.
use_pandas: ``bool`` (default: ``False``).
"""
model = self._model_json["output"]
if "importance" in list(model.keys()) and model["importance"]:
vals = model["importance"].cell_values
header = model["importance"].col_header
if use_pandas and can_use_pandas():
import pandas
return pandas.DataFrame(vals, columns=header)
else:
return vals
else:
print("Warning: This model doesn't have importances of components.")
[docs] def num_iterations(self):
"""Get the number of iterations that it took to converge or reach max iterations."""
o = self._model_json["output"]
return o["model_summary"]["number_of_iterations"][0]
[docs] def objective(self):
"""Get the final value of the objective function."""
o = self._model_json["output"]
return o["model_summary"]["final_objective_value"][0]
[docs] def final_step(self):
"""Get the final step size for the model."""
o = self._model_json["output"]
return o["model_summary"]["final_step_size"][0]
[docs] def archetypes(self):
"""The archetypes (Y) of the GLRM model."""
o = self._model_json["output"]
yvals = o["archetypes"].cell_values
archetypes = []
for yidx, yval in enumerate(yvals):
archetypes.append(list(yvals[yidx])[1:])
return archetypes
[docs] def reconstruct(self, test_data, reverse_transform=False):
"""
Reconstruct the training data from the model and impute all missing values.
:param H2OFrame test_data: The dataset upon which the model was trained.
:param bool reverse_transform: Whether the transformation of the training data during model-building
should be reversed on the reconstructed frame.
:returns: the approximate reconstruction of the training data.
"""
if test_data is None or test_data.nrow == 0: raise ValueError("Must specify test data")
j = h2o.api("POST /3/Predictions/models/%s/frames/%s" % (self.model_id, test_data.frame_id),
data={"reconstruct_train": True, "reverse_transform": reverse_transform})
return h2o.get_frame(j["model_metrics"][0]["predictions"]["frame_id"]["name"])
[docs] def proj_archetypes(self, test_data, reverse_transform=False):
"""
Convert archetypes of the model into original feature space.
:param H2OFrame test_data: The dataset upon which the model was trained.
:param bool reverse_transform: Whether the transformation of the training data during model-building
should be reversed on the projected archetypes.
:returns: model archetypes projected back into the original training data's feature space.
"""
if test_data is None or test_data.nrow == 0: raise ValueError("Must specify test data")
j = h2o.api("POST /3/Predictions/models/%s/frames/%s" % (self.model_id, test_data.frame_id),
data={"project_archetypes": True, "reverse_transform": reverse_transform})
return h2o.get_frame(j["model_metrics"][0]["predictions"]["frame_id"]["name"])
[docs] def screeplot(self, type="barplot", **kwargs):
"""
Produce the scree plot.
Library ``matplotlib`` is required for this function.
:param str type: either ``"barplot"`` or ``"lines"``.
"""
# check for matplotlib. exit if absent.
is_server = kwargs.pop("server")
if kwargs:
raise ValueError("Unknown arguments %s to screeplot()" % ", ".join(kwargs.keys()))
try:
import matplotlib
if is_server: matplotlib.use('Agg', warn=False)
import matplotlib.pyplot as plt
except ImportError:
print("matplotlib is required for this function!")
return
variances = [s ** 2 for s in self._model_json['output']['importance'].cell_values[0][1:]]
plt.xlabel('Components')
plt.ylabel('Variances')
plt.title('Scree Plot')
plt.xticks(list(range(1, len(variances) + 1)))
if type == "barplot":
plt.bar(list(range(1, len(variances) + 1)), variances)
elif type == "lines":
plt.plot(list(range(1, len(variances) + 1)), variances, 'b--')
if not is_server: plt.show()