# Source code for h2o.utils.shared_utils

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
#
# Copyright 2016 H2O.ai;  Apache License Version 2.0 (see LICENSE for details)
#
"""Shared utilities used by various classes, all placed here to avoid circular imports.

This file INTENTIONALLY keeps its module dependencies to a minimum (see the imports below)!
"""
from __future__ import absolute_import, division, print_function, unicode_literals
from .compatibility import *  # NOQA

import csv
import io
import itertools
import os
import re
import shutil
import string
import subprocess
import sys
import tempfile
import zipfile

from h2o.backend.server import H2OLocalServer
from h2o.exceptions import H2OValueError
from h2o.utils.typechecks import assert_is_type, is_type, numeric

# Counter behind the temporary keys: incremented by `_py_tmp_key()` and exposed read-only by `temp_ctr()`.
_id_ctr = 0

# The set of characters allowed in frame IDs. Since frame ids are used within REST API urls, they may
# only contain characters allowed within the "segment" part of the URL (see RFC 3986). Additionally, we
# forbid all characters that are declared as "illegal" in Key.java.
_id_allowed_characters = set("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~")

# Only these two names are exported by a star-import of this module; everything else must be imported by name.
__all__ = ('mojo_predict_csv', 'mojo_predict_pandas')


def _py_tmp_key(append):
    """
    Generate a new temporary key of the form ``py_<counter><append>``.

    Every call bumps the module-level counter ``_id_ctr`` by one before building the key.

    :param append: string suffix to attach after the counter value.
    :return: the newly generated key.
    """
    global _id_ctr
    _id_ctr = _id_ctr + 1
    return "".join(("py_", str(_id_ctr), append))


def check_frame_id(frame_id):
    """Validate ``frame_id`` with :func:`check_id`, reporting errors for the "H2OFrame" type."""
    check_id(frame_id, type="H2OFrame")


def check_id(id, type):
    """
    Check that the provided id is valid in Rapids language.

    ``None`` is accepted silently. Otherwise the id must be non-blank, must consist only of characters from
    ``_id_allowed_characters`` (a single leading ``$`` is tolerated), and must not start with a digit
    (optionally preceded by a minus sign).

    :param id: the identifier to validate, or None.
    :param type: name of the kind of object being identified; only used in error messages.
    :raises H2OValueError: if any of the rules above is violated.
    """
    if id is None:
        return
    if id.strip() == "":
        raise H2OValueError("%s id cannot be an empty string: %r" % (type, id))
    for position, symbol in enumerate(id):
        # '$' is only acceptable as the very first character
        leading_dollar = position == 0 and symbol == "$"
        if not leading_dollar and symbol not in _id_allowed_characters:
            raise H2OValueError("Character '%s' is illegal in %s id: %s" % (symbol, type, id))
    if re.match(r"-?[0-9]", id) is not None:
        raise H2OValueError("%s id cannot start with a number: %s" % (type, id))


def temp_ctr():
    """Return the current value of the temporary-key counter ``_id_ctr`` (without modifying it)."""
    current = _id_ctr
    return current


def is_module_available(mod):
    """
    Tell whether module ``mod`` is already loaded or can be located by the import system.

    :param mod: (possibly dotted) name of the module to look for.
    :return: True if the module is loaded or a loader/spec for it can be found, False otherwise.
    """
    if mod in sys.modules and sys.modules[mod] is not None:  # fast track + safer in unusual environments
        return True
    if PY2:
        import imp
        try:
            imp.find_module(mod)
            return True
        except ImportError:
            return False

    import importlib.util
    try:
        return importlib.util.find_spec(mod) is not None
    except (ImportError, AttributeError, ValueError):
        # find_spec() imports the parent package of a dotted name and raises (ModuleNotFoundError, or
        # AttributeError on Python < 3.7) when that parent is missing or is not a package; ValueError is
        # raised for modules with a broken `__spec__`. All of these simply mean "not available" here.
        return False


def can_use_pandas():
    """Return True if :func:`is_module_available` reports that ``pandas`` is available."""
    pandas_found = is_module_available("pandas")
    return pandas_found


def can_use_numpy():
    """Return True if :func:`is_module_available` reports that ``numpy`` is available."""
    numpy_found = is_module_available("numpy")
    return numpy_found


# Characters that are copied verbatim into an encoded URL; any other character is percent-escaped.
_url_safe_chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~"
# Lookup table indexed by character code (0..255): the character itself when it is safe, otherwise
# its "%XX" escape with two upper-case hex digits.
_url_chars_map = ["%%%02X" % i if chr(i) not in _url_safe_chars else chr(i) for i in range(256)]


def url_encode(s):
    """
    Percent-encode ``s`` for use inside a URL.

    Each item produced by ``bytes_iterator(s)`` is used as an index into ``_url_chars_map`` and the
    looked-up pieces are concatenated.
    NOTE(review): presumably ``bytes_iterator`` yields integer byte values in 0..255 -- confirm in the
    compatibility module.

    :param s: the value to encode.
    :return: the encoded string.
    """
    encoded_pieces = [_url_chars_map[code] for code in bytes_iterator(s)]
    return "".join(encoded_pieces)


def quote(s):
    """Alias of :func:`url_encode`: return the percent-encoded form of ``s``."""
    encoded = url_encode(s)
    return encoded


def clamp(x, xmin, xmax):
    """
    Restrict ``x`` to the range ``[xmin, xmax]``.

    The upper bound is applied first and the lower bound second, so ``xmin`` wins if the bounds are
    given in the wrong order.
    """
    capped = xmax if xmax < x else x
    return capped if capped > xmin else xmin


def _gen_header(cols):
    """Return the default column names ``["C1", "C2", ..., "C<cols>"]`` for ``cols`` columns."""
    return ["C%d" % index for index in range(1, cols + 1)]


def _check_lists_of_lists(python_obj):
    """
    Verify that ``python_obj`` is a collection of flat lists/tuples and measure its widest row.

    :param python_obj: iterable whose items are all expected to be flat lists or tuples.
    :return: length of the longest row, but never less than 1.
    :raises ValueError: if an item is not a list/tuple, or if a row itself contains a list/tuple.
    """
    widest = 1
    for row in python_obj:
        # every item must itself be a row ...
        if not isinstance(row, (tuple, list)):
            raise ValueError("`python_obj` is a mixture of nested lists and other types.")
        if len(row) > widest:
            widest = len(row)
        # ... and no row may contain a further level of nesting
        if any(isinstance(cell, (tuple, list)) for cell in row):
            raise ValueError("`python_obj` is not a list of flat lists!")
    return widest


def _handle_python_lists(python_obj, check_header):
    """
    Convert a scalar, a flat list/tuple or a list of flat lists into a ``(header, rows)`` pair.

    * a list of lists is validated with ``_check_lists_of_lists`` and used as the rows as-is;
    * a flat list/tuple becomes a single column, i.e. each element turns into a one-element row;
    * anything else is treated as a scalar and becomes a single 1x1 row.

    :param python_obj: the data to convert.
    :param check_header: if equal to 1, the first row is taken as the header and removed from the data;
        otherwise a default header ("C1", "C2", ...) is generated.
    :return: tuple ``(header, rows)``.
    :raises ValueError: if ``python_obj`` mixes nested lists with other items, or nests deeper than
        one level.
    """
    try:
        nested = _is_list_of_lists(python_obj)  # do we have a list of lists: [[...], ..., [...]] ?
    except TypeError:
        # Non-iterable scalars (numbers, None, ...) cannot be scanned for nested lists; without this
        # guard they would never reach the scalar branch below.
        nested = False

    if nested:
        ncols = _check_lists_of_lists(python_obj)  # must be a list of flat lists, raise ValueError if not
    elif isinstance(python_obj, (list, tuple)):  # single list
        ncols = 1
        python_obj = [[e] for e in python_obj]
    else:  # scalar
        python_obj = [[python_obj]]
        ncols = 1

    # create the header
    if check_header == 1:
        header = python_obj[0]
        python_obj = python_obj[1:]
    else:
        header = _gen_header(ncols)
    return header, python_obj


def stringify_dict(d):
    """
    Render dictionary ``d`` as a bracketed list of ``{'key': <quoted key>, 'value': <value>}`` entries.

    Keys are passed through ``_quoted``; values are interpolated with ``%s``.
    """
    entries = []
    for key, value in d.items():
        entries.append("{'key': %s, 'value': %s}" % (_quoted(key), value))
    return stringify_list(entries)


def stringify_dict_as_map(d):
    """
    Render dictionary ``d`` as ``{<quoted key>: <value>,...}``.

    Values are rendered with :func:`stringify_object`, so nested dictionaries use this map form too.
    """
    pairs = []
    for key, value in d.items():
        quoted_key = _quoted(key)
        rendered_value = stringify_object(value, stringify_dict_as_map)
        pairs.append("%s: %s" % (quoted_key, rendered_value))
    return "{%s}" % ",".join(pairs)


def stringify_list(arr):
    """
    Render the items of ``arr`` as a comma-separated list enclosed in square brackets.

    Nested lists are rendered recursively; every other item is converted with ``_str``.
    """
    rendered = []
    for element in arr:
        if isinstance(element, list):
            rendered.append(stringify_list(element))
        else:
            rendered.append(_str(element))
    return "[%s]" % ",".join(rendered)


def stringify_object(o, dict_function=stringify_dict):
    """
    Convert ``o`` into its string representation.

    :param o: the object to render.
    :param dict_function: function used to render ``o`` when it is a dictionary.
    :return: result of ``dict_function`` for dicts, of ``stringify_list`` for lists, and of ``_str``
        for everything else.
    """
    if isinstance(o, dict):
        return dict_function(o)
    if isinstance(o, list):
        return stringify_list(o)
    return _str(o)


def _str(item):
    """Stringify ``item``: tuples are rendered by ``_str_tuple``, anything else by ``str()``."""
    if isinstance(item, tuple):
        return _str_tuple(item)
    return str(item)


def _str_tuple(t):
    """
    Render tuple ``t`` as ``{a: <item0>,b: <item1>,...}``.

    Items are labelled with consecutive lowercase ASCII letters and rendered with ``_str``; because the
    labels are zipped with the items, anything beyond the 26th item is dropped.
    """
    labelled = []
    for letter, element in zip(string.ascii_lowercase, t):
        labelled.append("%s: %s" % (letter, _str(element)))
    return "{%s}" % ",".join(labelled)


def _is_list(l):
    """Return True if ``l`` is a list or a tuple (or an instance of a subclass of either)."""
    sequence_types = (tuple, list)
    return isinstance(l, sequence_types)


def _is_str_list(l):
    """Return the result of checking ``l`` against the type spec ``[str]`` with ``is_type``."""
    list_of_strings = [str]
    return is_type(l, list_of_strings)


def _is_num_list(l):
    """Return the result of checking ``l`` against the type spec ``[numeric]`` with ``is_type``."""
    list_of_numbers = [numeric]
    return is_type(l, list_of_numbers)


def _is_list_of_lists(o):
    """Return True if at least one item of the iterable ``o`` is a list or a tuple."""
    for element in o:
        if isinstance(element, (tuple, list)):
            return True
    return False


def _handle_numpy_array(python_obj, header):
    """
    Convert ``python_obj`` to plain Python lists via its ``tolist()`` method and hand the result to
    ``_handle_python_lists``.

    :param python_obj: object providing ``tolist()`` (a numpy array, judging by the function name).
    :param header: forwarded as the ``check_header`` argument of ``_handle_python_lists``.
    :return: the ``(header, rows)`` pair produced by ``_handle_python_lists``.
    """
    as_lists = python_obj.tolist()
    return _handle_python_lists(as_lists, header)


def _handle_pandas_data_frame(python_obj, header):
    """
    Extract column names and row data from ``python_obj``.

    The rows come from ``python_obj.values.tolist()`` processed by ``_handle_python_lists`` without
    header detection; the header is ``python_obj.columns`` with every entry converted to ``str``.

    :param python_obj: object exposing ``values`` and ``columns`` (a pandas DataFrame, judging by the
        function name).
    :param header: not used by this function.
    :return: tuple ``(column_names, rows)``.
    """
    rows = python_obj.values.tolist()
    data = _handle_python_lists(rows, -1)[1]
    column_names = [str(name) for name in python_obj.columns]
    return column_names, data


def _handle_python_dicts(python_obj, check_header):
    """
    Convert a ``{column_name: values}`` dictionary into a header and a list of row dictionaries.

    Each value must be either a flat list/tuple (one column of data) or a single string/number (which
    becomes a one-element column). Columns of unequal length are padded with ``None`` when
    ``zip_longest`` is available. The dictionary passed in is left unmodified.

    :param python_obj: dictionary mapping column names to column data.
    :param check_header: not used by this function.
    :return: tuple ``(header, rows)`` where ``header`` is the list of column names and each row is a
        ``{column_name: value}`` dictionary.
    :raises ValueError: if a column name does not match ``^[a-zA-Z_][a-zA-Z0-9_.]*$``, a value is a
        nested list, or a value is of an unsupported type.
    """
    header = list(python_obj.keys()) if python_obj else _gen_header(1)
    is_valid = all(re.match(r"^[a-zA-Z_][a-zA-Z0-9_.]*$", col) for col in header)  # is this a valid header?
    if not is_valid:
        raise ValueError(
            "Did not get a valid set of column names! Must match the regular expression: ^[a-zA-Z_][a-zA-Z0-9_.]*$ ")

    # Collect the columns in key order into a local list, so that scalars can be wrapped into
    # one-element columns without writing back into the caller's dictionary.
    columns = []
    for k in python_obj:  # check that each value entry is a flat list/tuple or single int, float, or string
        v = python_obj[k]
        if isinstance(v, (tuple, list)):  # if value is a tuple/list, then it must be flat
            if _is_list_of_lists(v):
                raise ValueError("Values in the dictionary must be flattened!")
            columns.append(v)
        elif is_type(v, str, numeric):
            columns.append([v])
        else:
            raise ValueError("Encountered invalid dictionary value when constructing H2OFrame. Got: {0}".format(v))

    # zip_longest (Python 3) / izip_longest (Python 2) pad short columns; plain zip is the last resort
    zipper = getattr(itertools, "zip_longest", None) or getattr(itertools, "izip_longest", None) or zip
    rows = [list(row) for row in zipper(*columns)]
    data_to_write = [dict(zip(header, row)) for row in rows]
    return header, data_to_write


def _is_fr(o):
    """Return True if the class of ``o`` is named "H2OFrame" (a name check avoids a circular import)."""
    class_name = o.__class__.__name__
    return class_name == "H2OFrame"


def _quoted(key):
    """
    Return ``key`` wrapped in double quotes.

    ``None`` yields an empty pair of quotes. A key that already contains a double-quoted, non-empty
    fragment anywhere inside it is returned unchanged.
    """
    if key is None:
        return '""'
    already_quoted = re.search(r'\"(.+?)\"', key) is not None
    if already_quoted:
        return key
    return '"' + key + '"'


def _locate(path):
    """Search for a relative path and turn it into an absolute path.
    This is handy when hunting for data files to be passed into h2o and used by import file.
    Note: This function is for unit testing purposes only.

    The search starts in the (resolved) current working directory and then tries each of its
    ancestors in turn, up to the filesystem root.

    Parameters
    ----------
    path : str
      Path to search for

    :return: Absolute path of the first match.
    :raises ValueError: if ``path`` does not exist relative to any of the directories tried.
    """
    directory = os.path.realpath(os.getcwd())
    while True:
        candidate = os.path.join(directory, path)
        if os.path.exists(candidate):
            return candidate

        parent = os.path.dirname(directory)
        if parent == directory:  # reached the root: there is nowhere left to look
            raise ValueError("File not found: " + path)
        directory = parent


def _colmean(column):
    """
    Return the mean of a single-column frame.

    :param column: frame-like object; it is asserted to have exactly one column (``ncols == 1``).
    :return: result of ``column.mean(return_frame=True).flatten()``.
    """
    assert column.ncols == 1
    mean_frame = column.mean(return_frame=True)
    return mean_frame.flatten()


def get_human_readable_bytes(size):
    """
    Render a byte count in human-readable form, using the binary prefixes k, M, G, T and P.

    The largest prefix for which the value is at least 1 is chosen. Exact multiples are printed as
    integers; otherwise up to 3 decimals are shown, fewer as the integer part grows.

    :param size: non-negative integer number of bytes; 0 yields "0" and None yields "".
    :return: string representation of the size, in human-readable form
    """
    if size == 0:
        return "0"
    if size is None:
        return ""
    assert_is_type(size, int)
    assert size >= 0, "`size` cannot be negative, got %d" % size
    suffixes = "PTGMk"
    nsuffixes = len(suffixes)
    # walk the units from the largest (2**50) down to plain bytes (shift 0)
    for index, shift in enumerate(range(nsuffixes * 10, -1, -10)):
        whole = size >> shift
        if whole == 0:
            continue
        # 3 decimals while the integer part is below 8, 2 below 64, 1 below 512, none from there on
        ndigits = next((nd for nd in (3, 2, 1) if size >> (shift + 12 - nd * 3) == 0), 0)
        if ndigits == 0 or size == whole << shift:
            rounded_val = str(whole)
        else:
            rounded_val = "%.*f" % (ndigits, size / (1 << shift))
        unit = suffixes[index] if index < nsuffixes else ""
        return "%s %sb" % (rounded_val, unit)


def get_human_readable_time(time_ms):
    """
    Convert given duration in milliseconds into a human-readable representation, i.e. hours, minutes, seconds,
    etc. More specifically, the returned string may look like following:
        1 day 3 hours 12 mins
        3 days 0 hours 0 mins
        8 hours 12 mins
        34 mins 02 secs
        13 secs
        541 ms
    In particular, the following rules are applied:
        * milliseconds are printed only if the duration is less than a second;
        * seconds are printed only if the duration is less than an hour;
        * for durations greater than 1 hour we print days, hours and minutes keeping zeros in the middle (i.e. we
          return "4 days 0 hours 12 mins" instead of "4 days 12 mins").

    :param time_ms: duration, as a number of elapsed milliseconds.
    :return: human-readable string representation of the provided duration.
    """
    millis = time_ms % 1000
    secs = (time_ms // 1000) % 60
    mins = (time_ms // 60000) % 60
    hours = (time_ms // 3600000) % 24
    days = (time_ms // 86400000)

    # Durations below one second are reported in milliseconds. This has to be decided up front:
    # the seconds part below is appended whenever days == hours == 0, which would otherwise turn
    # e.g. 541 ms into "00 secs".
    if days == 0 and hours == 0 and mins == 0 and secs == 0:
        return "%d ms" % millis

    res = ""
    if days > 1:
        res += "%d days" % days
    elif days == 1:
        res += "1 day"

    # a zero component is kept only when a larger unit has already been printed
    if hours > 1 or (hours == 0 and res):
        res += " %d hours" % hours
    elif hours == 1:
        res += " 1 hour"

    if mins > 1 or (mins == 0 and res):
        res += " %d mins" % mins
    elif mins == 1:
        res += " 1 min"

    if days == 0 and hours == 0:
        res += " %02d secs" % secs
    if not res:
        # only reachable for out-of-range (negative) inputs; kept as a last-resort fallback
        res = " %d ms" % millis

    return res.strip()


def print2(msg, flush=False, end="\n"):
    """
    Print ``msg`` followed by ``end``, optionally flushing standard output afterwards.

    This function exists here ONLY because Sphinx.ext.autodoc gets into a bad state when seeing the print()
    function. When in that state, autodoc doesn't display any errors or warnings, but instead completely
    ignores the "bysource" member-order option.

    :param msg: the object to print.
    :param flush: if True, ``sys.stdout`` is flushed after printing.
    :param end: string appended after ``msg``.
    """
    print(msg, end=end)
    if flush:
        sys.stdout.flush()


def normalize_slice(s, total):
    """
    Return a "canonical" version of slice ``s``.

    Missing bounds are filled in (start 0, stop ``total``, step 1), negative bounds are counted from
    the end, and both bounds are clipped to the range ``[0, total]``. The step is passed through
    unchanged when given.

    :param slice s: the original slice expression
    :param total int: total number of elements in the collection sliced by ``s``
    :return slice: a slice with the start and stop of ``s`` resolved to values within ``[0, total]``.
    """
    def _resolve(bound, default):
        # A missing bound takes the default; a negative one is an offset from the end.
        if bound is None:
            return default
        if bound < 0:
            return max(0, bound + total)
        return min(bound, total)

    newstart = _resolve(s.start, 0)
    newstop = _resolve(s.stop, total)
    newstep = s.step if s.step is not None else 1
    return slice(newstart, newstop, newstep)


def slice_is_normalized(s):
    """
    Return True if slice ``s`` in "normalized" form.

    That is: start, stop and step are all given (not None) and start does not exceed stop.
    """
    if s.start is None or s.stop is None or s.step is None:
        return False
    return s.start <= s.stop


# Public (underscore-free) aliases for the private helpers defined above.
gen_header = _gen_header
py_tmp_key = _py_tmp_key
locate = _locate
quoted = _quoted
is_list = _is_list
is_fr = _is_fr
handle_python_dicts = _handle_python_dicts
handle_pandas_data_frame = _handle_pandas_data_frame
handle_numpy_array = _handle_numpy_array
is_list_of_lists = _is_list_of_lists
is_num_list = _is_num_list
is_str_list = _is_str_list
handle_python_lists = _handle_python_lists
check_lists_of_lists = _check_lists_of_lists

# File name of the genmodel jar that `mojo_predict_csv` looks for next to the MOJO zip when no
# explicit `genmodel_jar_path` is given.
gen_model_file_name = "h2o-genmodel.jar"
# Java class that `mojo_predict_csv` passes to the `java` command to perform the scoring.
h2o_predictor_class = "hex.genmodel.tools.PredictCsv"


def mojo_predict_pandas(dataframe, mojo_zip_path, genmodel_jar_path=None, classpath=None, java_options=None,
                        verbose=False, setInvNumNA=False, predict_contributions=False):
    """
    MOJO scoring function to take a Pandas frame and use MOJO model as zip file to score.

    The frame is written to a CSV file inside a fresh temporary directory, scored with
    :func:`mojo_predict_csv`, and the resulting predictions CSV is loaded back with pandas. The
    temporary directory is removed before returning, whether or not scoring succeeded.

    :param dataframe: Pandas frame to score.
    :param mojo_zip_path: Path to MOJO zip downloaded from H2O.
    :param genmodel_jar_path: Optional, path to genmodel jar file. If None (default) then the h2o-genmodel.jar in the
        same folder as the MOJO zip will be used.
    :param classpath: Optional, specifies custom user defined classpath which will be used when scoring. If None
        (default) then the default classpath for this MOJO model will be used.
    :param java_options: Optional, custom user defined options for Java. If None (default), the default options of
        :func:`mojo_predict_csv` are used.
    :param verbose: Optional, if True, then additional debug information will be printed. False by default.
    :param setInvNumNA: Optional, forwarded unchanged to :func:`mojo_predict_csv`. False by default.
    :param predict_contributions: if True, then return prediction contributions instead of regular predictions
        (only for tree-based models).
    :return: Pandas frame with predictions
    :raises RuntimeError: if pandas is not available.
    """
    tmp_dir = tempfile.mkdtemp()
    try:
        if not can_use_pandas():
            raise RuntimeError('Cannot import pandas')
        import pandas
        assert_is_type(dataframe, pandas.DataFrame)

        # both the scorer's input and its output live in the temporary directory
        input_csv_path = os.path.join(tmp_dir, 'input.csv')
        prediction_csv_path = os.path.join(tmp_dir, 'prediction.csv')
        dataframe.to_csv(input_csv_path)

        mojo_predict_csv(
            input_csv_path=input_csv_path,
            mojo_zip_path=mojo_zip_path,
            output_csv_path=prediction_csv_path,
            genmodel_jar_path=genmodel_jar_path,
            classpath=classpath,
            java_options=java_options,
            verbose=verbose,
            setInvNumNA=setInvNumNA,
            predict_contributions=predict_contributions,
        )
        return pandas.read_csv(prediction_csv_path)
    finally:
        shutil.rmtree(tmp_dir)
def mojo_predict_csv(input_csv_path, mojo_zip_path, output_csv_path=None, genmodel_jar_path=None, classpath=None,
                     java_options=None, verbose=False, setInvNumNA=False, predict_contributions=False):
    """
    MOJO scoring function to take a CSV file and use MOJO model as zip file to score.

    The scoring itself is done by running the Java class named by ``h2o_predictor_class`` in a
    subprocess; the predictions CSV it writes is then read back.

    :param input_csv_path: Path to input CSV file.
    :param mojo_zip_path: Path to MOJO zip downloaded from H2O.
    :param output_csv_path: Optional, name of the output CSV file with computed predictions. If None (default), then
        predictions will be saved as prediction.csv in the same folder as the MOJO zip.
    :param genmodel_jar_path: Optional, path to genmodel jar file. If None (default) then the h2o-genmodel.jar in the
        same folder as the MOJO zip will be used.
    :param classpath: Optional, specifies custom user defined classpath which will be used when scoring. If None
        (default) then the default classpath for this MOJO model will be used.
    :param java_options: Optional, custom user defined options for Java, separated by whitespace. By default
        ``-Xmx4g -XX:ReservedCodeCacheSize=256m`` is used.
    :param verbose: Optional, if True, then additional debug information will be printed. False by default.
    :param setInvNumNA: Optional, if True, the ``--setConvertInvalidNum`` flag is passed to the scorer.
        NOTE(review): presumably this makes the scorer treat invalid numbers as NA -- confirm against PredictCsv.
    :param predict_contributions: if True, then return prediction contributions instead of regular predictions
        (only for tree-based models).
    :return: List of computed predictions, one dictionary per row of the output CSV file.
    :raises RuntimeError: if the input CSV, the MOJO zip or the genmodel jar cannot be found.
    :raises subprocess.CalledProcessError: if the java process exits with a non-zero status.
    """
    default_java_options = '-Xmx4g -XX:ReservedCodeCacheSize=256m'
    prediction_output_file = 'prediction.csv'

    # Checking java
    java = H2OLocalServer._find_java()
    H2OLocalServer._check_java(java=java, verbose=verbose)

    # Ensure input_csv exists
    if verbose:
        print("input_csv:\t%s" % input_csv_path)
    if not os.path.isfile(input_csv_path):
        raise RuntimeError("Input csv cannot be found at %s" % input_csv_path)

    # Ensure mojo_zip exists
    mojo_zip_path = os.path.abspath(mojo_zip_path)
    if verbose:
        print("mojo_zip:\t%s" % mojo_zip_path)
    if not os.path.isfile(mojo_zip_path):
        raise RuntimeError("MOJO zip cannot be found at %s" % mojo_zip_path)

    parent_dir = os.path.dirname(mojo_zip_path)

    # Set output_csv if necessary
    if output_csv_path is None:
        output_csv_path = os.path.join(parent_dir, prediction_output_file)

    # Set path to h2o-genmodel.jar if necessary and check it's valid
    if genmodel_jar_path is None:
        genmodel_jar_path = os.path.join(parent_dir, gen_model_file_name)
    if verbose:
        print("genmodel_jar:\t%s" % genmodel_jar_path)
    if not os.path.isfile(genmodel_jar_path):
        raise RuntimeError("Genmodel jar cannot be found at %s" % genmodel_jar_path)

    if verbose:
        print("output_csv:\t%s" % output_csv_path)

    # Set classpath if necessary
    if classpath is None:
        classpath = genmodel_jar_path
    if verbose:
        print("classpath:\t%s" % classpath)

    # Set java_options if necessary
    if java_options is None:
        java_options = default_java_options
    if verbose:
        print("java_options:\t%s" % java_options)

    # Construct command to invoke java. Splitting on any run of whitespace (rather than on single
    # spaces) keeps empty strings out of the argument list when the options are empty, padded or
    # separated by more than one space; java would otherwise see "" as an argument.
    cmd = [java]
    cmd += java_options.split()
    cmd += ["-cp", classpath, h2o_predictor_class, "--mojo", mojo_zip_path, "--input", input_csv_path,
            '--output', output_csv_path, '--decimal']
    if setInvNumNA:
        cmd.append('--setConvertInvalidNum')
    if predict_contributions:
        cmd.append('--predictContributions')
    if verbose:
        cmd_str = " ".join(cmd)
        print("java cmd:\t%s" % cmd_str)

    # invoke the command
    subprocess.check_call(cmd, shell=False)

    # load predictions in form of a dict
    with open(output_csv_path) as csv_file:
        result = list(csv.DictReader(csv_file))
    return result
class InMemoryZipArch(object):
    """
    Zip archive that is assembled in memory and written to disk on request.

    Can be used as a context manager: if a ``file_name`` was given to the constructor, the archive is
    written to that file when the ``with`` block is left.
    """

    def __init__(self, file_name = None, compression = zipfile.ZIP_DEFLATED):
        """
        :param file_name: optional path the archive is written to on leaving a ``with`` block.
        :param compression: compression method handed to ``zipfile.ZipFile``.
        """
        self._data = io.BytesIO()
        self._arch = zipfile.ZipFile(self._data, mode="w", compression=compression, allowZip64=False)
        self._file_name = file_name

    def append(self, filename_in_zip, file_contents):
        """Add an entry named ``filename_in_zip`` with the given contents; returns ``self`` for chaining."""
        self._arch.writestr(filename_in_zip, file_contents)
        return self

    def write_to_file(self, filename):
        """Finalize the archive and write its bytes to ``filename``."""
        # Mark the files as having been created on Windows so that
        # Unix permissions are not inferred as 0000
        for entry in self._arch.filelist:
            entry.create_system = 0
        self._arch.close()
        archive_bytes = self._data.getvalue()
        with open(filename, 'wb') as out:
            out.write(archive_bytes)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Without a target file name there is nothing to persist.
        if self._file_name is not None:
            self.write_to_file(self._file_name)