Source code for h2o.model.confusion_matrix

# -*- encoding: utf-8 -*-
from __future__ import absolute_import, division, print_function, unicode_literals

from h2o.two_dim_table import H2OTwoDimTable
from h2o.utils.compatibility import *  # NOQA
from h2o.utils.typechecks import assert_is_type


class ConfusionMatrix(object):
    """
    Confusion matrix stored as an ``H2OTwoDimTable``.

    The table has one row per class followed by a "Total" row. Each row holds
    a label, one count per class, the error rate (as a string) and an
    ``" (errors/total)"`` summary string. The default table header is
    "Confusion Matrix (Act/Pred)".
    """

    ROUND = 4  # number of decimal places kept when rounding errors / total

    def __init__(self, cm, domains=None, table_header=None):
        """
        Build the confusion-matrix table from raw counts.

        :param cm: list of count sequences. Table row ``i`` is assembled from
            ``v[i]`` for every ``v`` in ``cm``. A ``cm`` with exactly two
            elements is transposed first.
        :param domains: optional sequence of class labels used for the row
            and column headers; when omitted, labels are "0", "1", ...
            NOTE(review): presumably one label per class -- the code does not
            verify the length.
        :param table_header: optional table title; defaults to
            "Confusion Matrix (Act/Pred)".
        """
        assert_is_type(cm, list)
        if len(cm) == 2:
            cm = list(zip(*cm))  # transpose if 2x2

        cell_values = self._cell_rows(cm, self.ROUND)
        row_header, col_header = self._headers(len(cm), domains)
        if table_header is None:
            table_header = "Confusion Matrix (Act/Pred)"

        # The first cell of every row is its label.
        for i, label in enumerate(row_header):
            cell_values[i].insert(0, label)

        self.table = H2OTwoDimTable(row_header=row_header, col_header=col_header,
                                    table_header=table_header, cell_values=cell_values)

    @staticmethod
    def _cell_rows(cm, ndigits):
        """Return the per-class rows plus the totals row (without labels) for ``cm``."""
        nan = float("nan")
        col_totals = [sum(col) for col in cm]
        rows = []
        grand_total = 0
        total_errs = 0
        for i in range(len(cm)):
            counts = [col[i] for col in cm]
            row_total = sum(counts)
            # Everything in this row that is not on the diagonal is an error.
            row_errs = sum(col[i] for j, col in enumerate(cm) if j != i)
            total_errs += row_errs
            grand_total += row_total
            rate = nan if row_total == 0 else round(row_errs / row_total, ndigits)
            rows.append(counts + [str(rate), " (" + str(row_errs) + "/" + str(row_total) + ")"])

        overall = nan if grand_total == 0 else round(total_errs / grand_total, ndigits)
        rows.append(col_totals + [str(overall), " (" + str(total_errs) + "/" + str(grand_total) + ")"])
        return rows

    @staticmethod
    def _headers(nclass, domains):
        """Return ``(row_header, col_header)`` label lists for an ``nclass``-class table."""
        if domains is None:
            row_header = [str(i) for i in range(nclass)]
            col_labels = [str(i) for i in range(nclass)]
        else:
            import copy
            # Coerce to list so that tuple domains can be extended below.
            row_header = list(copy.deepcopy(domains))
            col_labels = list(copy.deepcopy(domains))
        row_header.append("Total")
        # The leading "" is the (unlabelled) column that holds the row labels.
        col_header = [""] + col_labels + ["Error", "Rate"]
        return row_header, col_header

    def show(self):
        """Print the confusion matrix into the console."""
        self.table.show()

    def __repr__(self):
        # Displays the table as a side effect and returns an empty string.
        self.show()
        return ""

    def to_list(self):
        """
        Convert this confusion matrix into a 2x2 plain list of values.

        Only the first two rows and the first two count columns of the table
        are read, each converted with ``int``.
        """
        cells = self.table.cell_values
        return [[int(cells[r][c]) for c in (1, 2)] for r in (0, 1)]

    @staticmethod
    def read_cms(cms=None, domains=None):
        """
        Build one ``ConfusionMatrix`` per element of ``cms``.

        :param cms: list of lists, each passed as ``cm`` to the constructor.
        :param domains: optional class labels shared by all matrices.
        :returns: list of ``ConfusionMatrix`` objects.
        """
        assert_is_type(cms, [list])
        return [ConfusionMatrix(cm, domains) for cm in cms]