#!/usr/bin/env python
# -*- encoding: utf-8 -*-
#
# Copyright 2016 H2O.ai; Apache License Version 2.0 (see LICENSE for details)
#
"""Shared utilities used by various classes, all placed here to avoid circular imports.
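This file INTENTIONALLY keeps its module dependencies to a strict minimum (see the imports below), so that it can be imported from anywhere without creating circular imports!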
TODO: clean up this file that turned into a waste bin over the years:
- split this into more specific modules.
- utility modules should have a specific name to limit the scope of the garbage we put in (like waste sorting).
- utility modules should, if possible, be placed under the appropriate parent module
(e.g. model/mojo related utility functions should go under h2o.model)
- utility functions used ONLY in tests should go to test utilities! no reason to export those to end users!
- same for model_utils.py nearby
"""
from .compatibility import * # NOQA
import csv
import contextlib
import io
import itertools
import os
import re
import shutil
import string
import subprocess
import sys
import tempfile
import zipfile
try:
    from contextlib import AbstractContextManager
except ImportError:
    # contextlib.AbstractContextManager only exists since Python 3.6: provide a minimal ABC that
    # approximates the stdlib structural check (a class qualifies when both __enter__ and __exit__
    # are defined somewhere in its MRO).
    import abc

    class AbstractContextManager(metaclass=abc.ABCMeta):
        """Fallback ABC recognising any class that defines both ``__enter__`` and ``__exit__``."""

        @classmethod
        def __subclasshook__(cls, C):
            if cls is not AbstractContextManager:
                return NotImplemented
            required_methods = ("__enter__", "__exit__")
            return all(
                any(method_name in base.__dict__ for base in C.__mro__)
                for method_name in required_methods
            )
from h2o.backend.server import H2OLocalServer
from h2o.exceptions import H2OValueError
from h2o.utils.typechecks import assert_is_type, is_type, numeric
from h2o.utils.threading import local_env
# Module-level counter bumped by _py_tmp_key to make temporary keys unique.
_id_ctr = 0
# The set of characters allowed in frame IDs. Since frame ids are used within REST API urls, they may
# only contain characters allowed within the "segment" part of the URL (see RFC 3986). Additionally, we
# forbid all characters that are declared as "illegal" in Key.java.
# (Letters, digits and the RFC 3986 unreserved punctuation "-._~".)
_id_allowed_characters = set(string.ascii_uppercase + string.ascii_lowercase + string.digits + "-._~")
# Only the MOJO scoring helpers are part of the public star-import API of this module.
__all__ = ('mojo_predict_csv', 'mojo_predict_pandas')
class List(list):
    """A plain ``list`` subclass whose instances, unlike ``list`` itself, accept arbitrary attributes."""
class LookupSeq(tuple):
    """
    An immutable sequence implementation (actually a tuple) optimized for fast lookups.

    Some code needs both random/indexed access on large lists and many lookups (``elem in my_list``, e.g. in a
    loop); use this class in that case to avoid forgetting to build or use a set every time a lookup is needed.
    The sequence is read-only, so the backing set used for lookups never has to be synchronized with the
    elements. All elements must be hashable, because the backing set is built at construction time.
    """

    def __new__(cls, seq=()):
        """Implementing __new__ is required to extend tuple (not necessary for list)."""
        return super(LookupSeq, cls).__new__(cls, seq)

    def __init__(self, seq=()):
        # ``self`` already holds the elements (populated in __new__); ``seq`` is accepted for signature parity.
        self.__set = frozenset(self)  # lookup functions backed by a set

    def __contains__(self, item):
        try:
            return item in self.__set
        except TypeError:
            # Unhashable probe (e.g. a list): a set cannot answer, so fall back to tuple's linear equality
            # scan rather than raising, keeping ordinary sequence semantics for ``in``.
            return tuple.__contains__(self, item)

    def set(self):
        """
        Return the backing frozenset; use it for set arithmetic on the elements to avoid confusion.

        We still want this to behave like a list for the most part,
        and this is slightly faster than building a set from the list itself.
        """
        return self.__set
def _py_tmp_key(append):
    """Return a new temporary key of the form ``"py_<counter><append>"``.

    Every call increments the module-level ``_id_ctr`` first, so repeated calls with the same suffix yield
    different keys.
    """
    global _id_ctr
    _id_ctr += 1
    key = "py_" + str(_id_ctr)
    return key + append
def check_frame_id(frame_id):
    """Validate ``frame_id`` as an H2OFrame id by delegating to ``check_id`` with the "H2OFrame" label."""
    return check_id(frame_id, "H2OFrame")
def check_id(id, type):
    """Check that the provided id is valid in Rapids language.

    :param id: the id to validate; ``None`` is accepted and skipped.
    :param type: label of the object kind the id belongs to (only used in error messages).
    :raises H2OValueError: if the id is blank, contains a character outside ``_id_allowed_characters``
        (a leading ``'$'`` is tolerated), or starts with a digit (optionally preceded by ``'-'``).
    """
    if id is None:
        return
    if id.strip() == "":
        raise H2OValueError("%s id cannot be an empty string: %r" % (type, id))
    # '$' has a special meaning as the very first character, so it is exempt there; anywhere else it is
    # rejected because it is not part of _id_allowed_characters.
    bad_char = next((char for pos, char in enumerate(id)
                     if not (pos == 0 and char == "$") and char not in _id_allowed_characters), None)
    if bad_char is not None:
        raise H2OValueError("Character '%s' is illegal in %s id: %s" % (bad_char, type, id))
    if re.match(r"-?[0-9]", id):
        raise H2OValueError("%s id cannot start with a number: %s" % (type, id))
def temp_ctr():
    """Return the current value of the module-level temporary-key counter (incremented by ``_py_tmp_key``)."""
    global _id_ctr
    return _id_ctr
def is_module_available(mod):
    """Return True if module ``mod`` can be used, preferably without importing it.

    :param str mod: module name, e.g. ``'pandas'``.
    :return: False when the module is explicitly disabled (``local_env(mod + "_disabled")`` is truthy);
        True when it is already loaded in ``sys.modules``; otherwise whether ``importlib`` can locate it.
        For a dotted name, ``importlib.util.find_spec`` imports the parent package, and a missing parent
        is reported as "not available" instead of raising.
    """
    if local_env(mod+"_disabled"):  # fast track if module is explicitly disabled
        return False
    if mod in sys.modules and sys.modules[mod] is not None:  # fast track + safer in unusual environments
        return True
    import importlib.util
    try:
        return importlib.util.find_spec(mod) is not None
    except (ImportError, ValueError):
        # e.g. ModuleNotFoundError (an ImportError) when the parent package of a dotted name is missing:
        # either way the module cannot be used, which is what callers ask.
        return False
def can_use_pandas():
    """Return True if the ``pandas`` module is available (see ``is_module_available``)."""
    available = is_module_available('pandas')
    return available
def can_use_datatable():
    """Return True if ``datatable`` is available and the interpreter is Python 3.x with x <= 9."""
    if not is_module_available('datatable'):
        return False
    return sys.version_info.major == 3 and sys.version_info.minor <= 9
def can_use_polars():
    """Return True if ``polars`` is available and the interpreter is Python 3.x with x > 9."""
    if not is_module_available('polars'):
        return False
    return sys.version_info.major == 3 and sys.version_info.minor > 9
def can_use_pyarrow():
    """Return True if pandas and pyarrow are both available on Python 3.10+ and pandas is at least 1.x."""
    if not (can_use_pandas() and sys.version_info.minor > 9):
        return False
    import pandas
    if not is_module_available('pyarrow'):
        return False
    # Only the first character of the version string is inspected: this effectively requires a pandas
    # major version >= 1 and assumes a single-digit major version.
    return sys.version_info.major == 3 and float(pandas.__version__[0]) >= 1
def can_use_numpy():
    """Return True if the ``numpy`` module is available (see ``is_module_available``)."""
    available = is_module_available('numpy')
    return available
# RFC 3986 "unreserved" characters: the only ones left unescaped by url_encode.
_url_safe_chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~"
# Lookup table indexed by byte value: unreserved characters map to themselves, every other byte to its
# percent-encoded form "%XX" (upper-case hex).
_url_chars_map = ["%%%02X" % code if chr(code) not in _url_safe_chars else chr(code) for code in range(256)]
def url_encode(s):
    """Percent-encode ``s`` byte by byte using the precomputed ``_url_chars_map`` table.

    The byte values come from ``bytes_iterator`` (presumably provided by the ``.compatibility`` star import);
    how a ``str`` is turned into bytes is decided there — presumably UTF-8, TODO confirm.
    """
    return "".join(map(_url_chars_map.__getitem__, bytes_iterator(s)))
def quote(s):
    """Alias of ``url_encode``: percent-encode everything except letters, digits and ``-_.~``."""
    encoded = url_encode(s)
    return encoded
def clamp(x, xmin, xmax):
    """Return the value of x, clamped from below by `xmin` and from above by `xmax`.

    The upper bound is applied first, so if ``xmin > xmax`` the result is ``xmin``.
    """
    capped = min(x, xmax)
    return max(xmin, capped)
def _gen_header(cols):
    """Return default column names ``["C1", "C2", ..., "C<cols>"]`` (empty for ``cols <= 0``)."""
    names = []
    for idx in range(cols):
        names.append("C" + str(idx + 1))
    return names
def stringify_dict(d):
    """Render ``d`` as a list of ``{'key': "<k>", 'value': <v>}`` entries.

    Keys go through ``_quoted``; values are inserted with plain ``str`` formatting.
    """
    entries = []
    for key, value in d.items():
        entries.append("{'key': %s, 'value': %s}" % (_quoted(key), value))
    return stringify_list(entries)
def stringify_dict_as_map(d):
    """Render ``d`` as ``{"k1": v1,"k2": v2}``, recursing into nested dicts and lists.

    Keys go through ``_quoted``; values go through ``stringify_object`` with this same map rendering.
    """
    pairs = ("%s: %s" % (_quoted(key), stringify_object(value, stringify_dict_as_map))
             for key, value in d.items())
    return "{" + ",".join(pairs) + "}"
def stringify_list(arr):
    """Render ``arr`` as ``[a,b,...]``; nested lists recurse and other items go through ``_str``."""
    parts = []
    for elem in arr:
        parts.append(stringify_list(elem) if isinstance(elem, list) else _str(elem))
    return "[" + ",".join(parts) + "]"
def stringify_object(o, dict_function=stringify_dict):
    """Render ``o``: dicts via ``dict_function``, lists via ``stringify_list``, anything else via ``_str``."""
    if isinstance(o, dict):
        return dict_function(o)
    if isinstance(o, list):
        return stringify_list(o)
    return _str(o)
def _str(item):
    """Render ``item`` with ``str``, except tuples, which are rendered by ``_str_tuple``."""
    if isinstance(item, tuple):
        return _str_tuple(item)
    return str(item)
def _str_tuple(t):
    """Render tuple ``t`` as ``{a: t0,b: t1,...}``, labelling fields with consecutive lowercase letters.

    Elements are rendered recursively with ``_str``. Only the first 26 elements are kept, because the labels
    run out after ``z``.
    """
    fields = ("%s: %s" % (letter, _str(value)) for letter, value in zip(string.ascii_lowercase, t))
    return "{" + ",".join(fields) + "}"
def _is_list(l):
    """Return True if ``l`` is a list or a tuple (including subclasses)."""
    return isinstance(l, list) or isinstance(l, tuple)
def _is_str_list(l):
    """Return True if ``l`` matches the ``[str]`` type spec of ``is_type`` (presumably a list of strings)."""
    spec = [str]
    return is_type(l, spec)
def _is_num_list(l):
    """Return True if ``l`` matches the ``[numeric]`` type spec of ``is_type`` (presumably a list of numbers)."""
    spec = [numeric]
    return is_type(l, spec)
def _is_list_of_lists(o):
    """Return True if AT LEAST ONE element of ``o`` is a list or tuple (not necessarily all of them)."""
    for element in o:
        if isinstance(element, (tuple, list)):
            return True
    return False
def _is_fr(o):
    """Return True if ``o``'s class is named "H2OFrame".

    The class name is compared instead of using ``isinstance``, to avoid a circular import (a deliberate hack);
    as a consequence any class named "H2OFrame" matches.
    """
    class_name = o.__class__.__name__
    return class_name == "H2OFrame"
def _quoted(key):
    """Wrap ``key`` in double quotes unless it already looks quoted; ``None`` becomes ``""``.

    NOTE(review): "already quoted" means that ANY double-quoted substring with at least one character occurs
    somewhere in ``key`` (e.g. ``'a "b" c'`` is returned unchanged), not only that the whole key is wrapped.
    Inner quotes are not escaped. Unlike R, "%" and "&" are deliberately NOT replaced with ".".
    """
    if key is None:
        return "\"\""
    if re.search(r'\"(.+?)\"', key):
        return key
    return '"{}"'.format(key)
def _locate(path):
    """Search for a relative path and turn it into an absolute path.

    ``path`` is joined onto the current working directory and then onto each of its parent directories, up
    to and including the filesystem root; the first combination that exists is returned. This is handy when
    hunting for data files to be passed into h2o and used by import file.

    Note: This function is for unit testing purposes only.

    :param str path: path to search for. An absolute path is simply checked for existence.
    :return: the first existing path found, as an absolute (but not normalized) path.
    :raises ValueError: if ``path`` exists neither under the current directory nor under any of its parents.
    """
    search_dir = os.path.realpath(os.getcwd())
    while True:
        candidate = os.path.join(search_dir, path)
        if os.path.exists(candidate):
            return candidate
        parent_dir = os.path.dirname(search_dir)
        if parent_dir == search_dir:  # dirname() of the root is the root itself: nothing left to search
            raise ValueError("File not found: " + path)
        search_dir = parent_dir
def _colmean(column):
    """Return the mean of a single-column frame.

    Calls ``column.mean(return_frame=True)`` and flattens the result (presumably to a scalar — confirm).
    The single-column requirement is only enforced by an ``assert``.
    """
    assert column.ncols == 1
    mean_frame = column.mean(return_frame=True)
    return mean_frame.flatten()
def get_human_readable_bytes(size):
    """
    Convert given number of bytes into a human readable representation, i.e. add prefix such as kb, Mb, Gb,
    etc. The `size` argument must be a non-negative integer.

    Units are binary (each step is a factor of 1024 = 2**10). The suffixes produced are "Pb", "Tb", "Gb",
    "Mb", "kb", or just "b" for values below 1024. Examples: 5 -> "5 b", 1024 -> "1 kb", 1536 -> "1.500 kb".

    :param size: integer representing byte size of something
    :return: string representation of the size, in human-readable form ("0" for 0, "" for None)
    """
    if size == 0: return "0"
    if size is None: return ""
    assert_is_type(size, int)
    assert size >= 0, "`size` cannot be negative, got %d" % size
    suffixes = "PTGMk"
    maxl = len(suffixes)
    # Try units from the largest (P, shift 50) down to plain bytes (i == maxl, shift 0); the first unit in
    # which the value is at least 1 is used.
    for i in range(maxl + 1):
        shift = (maxl - i) * 10
        if size >> shift == 0: continue
        # Number of decimals: 3 if the value is below 8 units, 2 below 64, 1 below 512, otherwise 0, which
        # keeps the output to roughly four digits.
        ndigits = 0
        for nd in [3, 2, 1]:
            if size >> (shift + 12 - nd * 3) == 0:
                ndigits = nd
                break
        # Exact multiples of the unit, and values of 512+ units, are printed as integers; note that this
        # integer path truncates (floor) rather than rounds.
        if ndigits == 0 or size == (size >> shift) << shift:
            rounded_val = str(size >> shift)
        else:
            rounded_val = "%.*f" % (ndigits, size / (1 << shift))
        return "%s %sb" % (rounded_val, suffixes[i] if i < maxl else "")
def get_human_readable_time(time_ms):
    """
    Convert given duration in milliseconds into a human-readable representation, i.e. hours, minutes, seconds,
    etc. More specifically, the returned string may look like following:

        1 day 3 hours 12 mins
        3 days 0 hours 0 mins
        8 hours 12 mins
        34 mins 02 secs
        13 secs
        541 ms

    In particular, the following rules are applied:
        * milliseconds are printed only if the duration is less than a second;
        * seconds are printed only if the duration is at least a second and less than an hour;
        * for durations greater than 1 hour we print days, hours and minutes keeping zeros in the middle (i.e. we
          return "4 days 0 hours 12 mins" instead of "4 days 12 mins").

    :param time_ms: duration, as a number of elapsed milliseconds.
    :return: human-readable string representation of the provided duration.
    """
    millis = time_ms % 1000
    secs = (time_ms // 1000) % 60
    mins = (time_ms // 60000) % 60
    hours = (time_ms // 3600000) % 24
    days = time_ms // 86400000

    res = ""
    if days > 1:
        res += "%d days" % days
    elif days == 1:
        res += "1 day"
    # Once a larger unit has been printed, zero-valued smaller units are kept ("3 days 0 hours 0 mins").
    if hours > 1 or (hours == 0 and res):
        res += " %d hours" % hours
    elif hours == 1:
        res += " 1 hour"
    if mins > 1 or (mins == 0 and res):
        res += " %d mins" % mins
    elif mins == 1:
        res += " 1 min"
    # Seconds only for durations under an hour AND of at least one second; without the second condition
    # the millisecond fallback below could never be reached (541 ms used to render as "00 secs").
    if days == 0 and hours == 0 and time_ms >= 1000:
        res += " %02d secs" % secs
    if not res:
        res = "%d ms" % millis
    return res.strip()
def normalize_slice(s, total):
    """
    Return a "canonical" version of slice ``s``.

    Missing bounds become ``0`` / ``total``; negative bounds are counted from the end and clipped at 0;
    positive bounds are clipped at ``total``; a missing step becomes 1. The step itself is not validated
    (negative or zero steps are passed through unchanged).

    :param slice s: the original slice expression
    :param int total: total number of elements in the collection sliced by ``s``
    :return slice: a slice equivalent to ``s`` but without negative bounds or Nones.
    """
    def _canonical(index, default):
        # Map one slice bound onto the [0, total] range.
        if index is None:
            return default
        if index < 0:
            return max(0, index + total)
        return min(index, total)

    start = _canonical(s.start, 0)
    stop = _canonical(s.stop, total)
    step = 1 if s.step is None else s.step
    return slice(start, stop, step)
def slice_is_normalized(s):
    """Return True if slice ``s`` is in "normalized" form: start, stop and step are all set and start <= stop."""
    if any(bound is None for bound in (s.start, s.stop, s.step)):
        return False
    return s.start <= s.stop
# Public (non-underscore) aliases of the private helpers defined above.
gen_header, py_tmp_key, locate, quoted = _gen_header, _py_tmp_key, _locate, _quoted
is_list, is_fr, is_list_of_lists = _is_list, _is_fr, _is_list_of_lists
is_num_list, is_str_list = _is_num_list, _is_str_list
# Default genmodel jar name (looked up next to the MOJO zip by mojo_predict_csv) and the Java main class
# it runs to score CSV files.
gen_model_file_name = "h2o-genmodel.jar"
h2o_predictor_class = "hex.genmodel.tools.PredictCsv"
def mojo_predict_pandas(dataframe, mojo_zip_path, genmodel_jar_path=None, classpath=None, java_options=None,
                        verbose=False, setInvNumNA=False, predict_contributions=False, predict_calibrated=False):
    """
    MOJO scoring function to take a Pandas frame and use MOJO model as zip file to score.

    The frame is written to a CSV file in a fresh temporary directory, scored with ``mojo_predict_csv`` and the
    predictions are read back; the temporary directory is always removed afterwards.

    :param dataframe: Pandas frame to score.
    :param mojo_zip_path: Path to MOJO zip downloaded from H2O.
    :param genmodel_jar_path: Optional, path to genmodel jar file. If None (default) then the h2o-genmodel.jar in the same
        folder as the MOJO zip will be used.
    :param classpath: Optional, specifies custom user defined classpath which will be used when scoring. If None
        (default) then the default classpath for this MOJO model will be used.
    :param java_options: Optional, custom user defined options for Java. If None (default), the defaults of
        ``mojo_predict_csv`` (``-Xmx4g -XX:ReservedCodeCacheSize=256m``) are used.
    :param verbose: Optional, if True, then additional debug information will be printed. False by default.
    :param setInvNumNA: Optional, forwarded to ``mojo_predict_csv``, which then adds ``--setConvertInvalidNum``
        to the scorer's command line. False by default.
    :param predict_contributions: if True, then return prediction contributions instead of regular predictions
        (only for tree-based models).
    :param predict_calibrated: if true, then return calibrated probabilities in addition to the predicted probabilities.
    :return: Pandas frame with predictions
    :raises RuntimeError: if pandas cannot be imported.
    """
    # Validate before creating the temporary directory, so nothing needs cleaning up on a bad call.
    if not can_use_pandas():
        raise RuntimeError('Cannot import pandas')
    import pandas
    assert_is_type(dataframe, pandas.DataFrame)
    tmp_dir = tempfile.mkdtemp()
    try:
        input_csv_path = os.path.join(tmp_dir, 'input.csv')
        prediction_csv_path = os.path.join(tmp_dir, 'prediction.csv')
        # NOTE(review): to_csv() also writes the frame index as an unnamed leading column; presumably the
        # scorer ignores it as an unknown column — confirm before relying on it.
        dataframe.to_csv(input_csv_path)
        mojo_predict_csv(input_csv_path=input_csv_path, mojo_zip_path=mojo_zip_path,
                         output_csv_path=prediction_csv_path, genmodel_jar_path=genmodel_jar_path,
                         classpath=classpath, java_options=java_options, verbose=verbose, setInvNumNA=setInvNumNA,
                         predict_contributions=predict_contributions, predict_calibrated=predict_calibrated)
        return pandas.read_csv(prediction_csv_path)
    finally:
        shutil.rmtree(tmp_dir)
def mojo_predict_csv(input_csv_path, mojo_zip_path, output_csv_path=None, genmodel_jar_path=None, classpath=None,
                     java_options=None, verbose=False, setInvNumNA=False,
                     predict_contributions=False, predict_calibrated=False,
                     extra_cmd_args=None):
    """
    MOJO scoring function to take a CSV file and use MOJO model as zip file to score.

    Scoring runs in a separate Java process (main class ``h2o_predictor_class`` from the genmodel jar) that
    writes its predictions to a CSV file, which is then read back.

    :param input_csv_path: Path to input CSV file.
    :param mojo_zip_path: Path to MOJO zip downloaded from H2O.
    :param output_csv_path: Optional, name of the output CSV file with computed predictions. If None (default), then
        predictions will be saved as prediction.csv in the same folder as the MOJO zip.
    :param genmodel_jar_path: Optional, path to genmodel jar file. If None (default) then the h2o-genmodel.jar in the same
        folder as the MOJO zip will be used.
    :param classpath: Optional, specifies custom user defined classpath which will be used when scoring. If None
        (default) then the genmodel jar alone is used as the classpath.
    :param java_options: Optional, custom user defined options for Java, separated by whitespace (no shell-style
        quoting is supported). By default ``-Xmx4g -XX:ReservedCodeCacheSize=256m`` is used.
    :param verbose: Optional, if True, then additional debug information will be printed. False by default.
    :param setInvNumNA: Optional, if True then ``--setConvertInvalidNum`` is passed to the scorer (presumably
        converting invalid numeric values to NA — confirm against the genmodel documentation). False by default.
    :param predict_contributions: if True, then return prediction contributions instead of regular predictions
        (only for tree-based models).
    :param predict_calibrated: if true, then return calibrated probabilities in addition to the predicted probabilities.
    :param extra_cmd_args: Optional, a list of additional arguments to append to genmodel.jar's command line.
    :return: List of computed predictions: one ``dict`` per row of the output CSV, keyed by its header.
    :raises RuntimeError: if the input CSV, the MOJO zip or the genmodel jar cannot be found.
    :raises subprocess.CalledProcessError: if the Java scoring process exits with a non-zero status.
    """
    default_java_options = '-Xmx4g -XX:ReservedCodeCacheSize=256m'
    prediction_output_file = 'prediction.csv'

    # Checking java
    java = H2OLocalServer._find_java()
    H2OLocalServer._check_java(java=java, verbose=verbose)

    # Ensure input_csv exists
    if verbose:
        print("input_csv:\t%s" % input_csv_path)
    if not os.path.isfile(input_csv_path):
        raise RuntimeError("Input csv cannot be found at %s" % input_csv_path)

    # Ensure mojo_zip exists
    mojo_zip_path = os.path.abspath(mojo_zip_path)
    if verbose:
        print("mojo_zip:\t%s" % mojo_zip_path)
    if not os.path.isfile(mojo_zip_path):
        raise RuntimeError("MOJO zip cannot be found at %s" % mojo_zip_path)
    parent_dir = os.path.dirname(mojo_zip_path)

    # Output CSV defaults to prediction.csv next to the MOJO zip
    if output_csv_path is None:
        output_csv_path = os.path.join(parent_dir, prediction_output_file)

    # Set path to h2o-genmodel.jar if necessary and check it's valid
    if genmodel_jar_path is None:
        genmodel_jar_path = os.path.join(parent_dir, gen_model_file_name)
    if verbose:
        print("genmodel_jar:\t%s" % genmodel_jar_path)
    if not os.path.isfile(genmodel_jar_path):
        raise RuntimeError("Genmodel jar cannot be found at %s" % genmodel_jar_path)

    if verbose:
        print("output_csv:\t%s" % output_csv_path)

    # Set classpath if necessary
    if classpath is None:
        classpath = genmodel_jar_path
    if verbose:
        print("classpath:\t%s" % classpath)

    # Set java_options if necessary
    if java_options is None:
        java_options = default_java_options
    if verbose:
        print("java_options:\t%s" % java_options)

    # Construct command to invoke java. split() (rather than split(' ')) ignores repeated, leading and
    # trailing whitespace, which would otherwise reach java as stray empty-string arguments.
    cmd = [java] + java_options.split()
    cmd += ["-cp", classpath, h2o_predictor_class, "--mojo", mojo_zip_path, "--input", input_csv_path,
            '--output', output_csv_path, '--decimal']
    if setInvNumNA:
        cmd.append('--setConvertInvalidNum')
    if predict_contributions:
        cmd.append('--predictContributions')
    if predict_calibrated:
        cmd.append('--predictCalibrated')
    if extra_cmd_args:
        cmd += extra_cmd_args
    if verbose:
        cmd_str = " ".join(cmd)
        print("java cmd:\t%s" % cmd_str)

    # invoke the command (argument list, no shell involved)
    subprocess.check_call(cmd, shell=False)

    # load predictions in form of a list of dicts; the csv module requires newline='' on the file object
    with open(output_csv_path, newline='') as csv_file:
        result = list(csv.DictReader(csv_file))
    return result
class InMemoryZipArch(object):
    """Build a zip archive in memory and optionally write it to disk.

    When used as a context manager with a ``file_name``, the archive is written to that file when the
    ``with`` block exits normally. If the block raises, nothing is written and the exception propagates.
    """

    def __init__(self, file_name=None, compression=zipfile.ZIP_DEFLATED):
        """
        :param file_name: optional target path, written on successful exit from a ``with`` block.
        :param compression: zipfile compression method used for the entries (``ZIP_DEFLATED`` by default).
        """
        self._data = io.BytesIO()
        # 4th positional argument is allowZip64=False: archives needing ZIP64 extensions raise instead.
        self._arch = zipfile.ZipFile(self._data, "w", compression, False)
        self._file_name = file_name

    def append(self, filename_in_zip, file_contents):
        """Add an entry ``filename_in_zip`` holding ``file_contents`` (str or bytes); returns self for chaining."""
        self._arch.writestr(filename_in_zip, file_contents)
        return self

    def write_to_file(self, filename):
        """Close the archive and write its bytes to ``filename``; no entries can be appended afterwards."""
        # Mark the files as having been created on Windows so that
        # Unix permissions are not inferred as 0000
        for zfile in self._arch.filelist:
            zfile.create_system = 0
        self._arch.close()
        with open(filename, 'wb') as f:
            f.write(self._data.getvalue())

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Do not persist a partially built archive when the with-block failed; returning None lets the
        # exception propagate. The archive is left open when there is no target file, so that
        # write_to_file() can still be called afterwards.
        if self._file_name is None or exc_type is not None:
            return
        self.write_to_file(self._file_name)
@contextlib.contextmanager
def as_resource(o):
    """Use ``o`` in a ``with`` statement whether or not it is a context manager.

    Context managers are entered, and whatever their ``__enter__`` returns is yielded. Any other object is
    yielded as-is.
    """
    if not isinstance(o, AbstractContextManager):
        yield o
        return
    with o as entered:
        yield entered