# Source code for id3.export

from sklearn.externals import six
from .splitter import SplitRecord, CalcRecord
import numpy as np


class DotTree():
    """In-memory sink for DOT output.

    Mimics the minimal file interface used by ``export_graphviz``
    (``write`` / ``close``) while accumulating everything in a string
    that can later be retrieved with :meth:`to_string`.
    """

    def __init__(self):
        # Accumulated DOT text; writes are ignored once closed.
        self.dot_tree = ""
        self.closed = False

    def write(self, content):
        """Append *content* unless the tree has been closed."""
        if self.closed:
            return
        self.dot_tree += content

    def close(self):
        """Mark the tree as closed; subsequent writes become no-ops."""
        self.closed = True

    def to_string(self):
        """Return the full DOT document accumulated so far."""
        return self.dot_tree


def _extract_class_count(node):
    if node.item_count is not None:
        items, counts = node.item_count
        if counts.size == 0:
            return "(0)"
        elif counts.shape[0] == 1:
            return "({})".format(counts[0])
        else:
            max_count = np.max(counts)
            incorrect_count = np.sum(counts) - max_count
            return "({}/{})".format(max_count, incorrect_count)
    else:
        return ""


def _extract_edge_value(tree, edge):
    """Build the human-readable label for an edge of the tree.

    Numeric splits are labelled with the pivot comparison
    (``">p"`` / ``"<=p"``); nominal splits are decoded back to their
    original value through the tree's feature encoders when available,
    otherwise the raw encoded value is returned.
    """
    record = edge.calc_record
    encoded = edge.value_encoded
    if record.split_type is CalcRecord.NUM:
        # Numeric split: label carries the comparison against the pivot.
        operator = ">" if encoded == SplitRecord.GREATER else "<="
        return "{}{:.2f}".format(operator, record.pivot)
    if tree.X_encoders is not None:
        decoded = (tree.X_encoders[record.feature_idx]
                   .single_inv_transform(encoded))
        if isinstance(decoded, np.bytes_):
            return decoded.decode('UTF-8')
        return decoded
    return encoded


def export_text(decision_tree, feature_names=None):
    """Export a decision tree in WEKA like string format.

    Parameters
    ----------
    decision_tree : decision tree classifier

    feature_names : list of strings, optional (default=None)
        Names of each of the features.

    Returns
    -------
    ret : string
    """
    # Hard recursion cap: stop descending past this depth.
    max_depth = 500

    def build_string(node, indent, depth):
        # Recursively render *node* and its subtree, one '| ' per level.
        ret = ''
        if node is None or depth > max_depth:
            return ''
        if node.is_feature:
            ret += '\n'
            template = '| ' * indent
            if feature_names is None:
                template += str(node.details.feature_idx)
            else:
                template += feature_names[node.details.feature_idx]
            template += ' {}'
            for child in node.children:
                # child is a (node, edge) pair; the edge carries the
                # split value that labels this branch.
                edge_value = _extract_edge_value(decision_tree, child[1])
                ret += template.format(edge_value)
                ret += build_string(child[0], indent + 1, depth + 1)
        else:
            # Leaf: decode the class label and append its count summary.
            value = decision_tree.y_encoder.single_inv_transform(node.value)
            if isinstance(value, np.bytes_):
                value = value.decode('UTF-8')
            ret += ': {} {} \n'.format(value, _extract_class_count(node))
        return ret

    return build_string(decision_tree.root, 0, 0)
def export_graphviz(decision_tree, out_file=None,
                    feature_names=None, extensive=False):
    """Export a decision tree in DOT format.

    This function generates a GraphViz representation of the decision
    tree, which is then written into `out_file`. Once exported, graphical
    renderings can be generated using, for example:

        $ dot -Tpdf tree.dot -o tree.pdf    (PDF format)
        $ dot -Tpng tree.dot -o tree.png    (PNG format)

    Parameters
    ----------
    decision_tree : decision tree classifier
        The decision tree to be exported to GraphViz.

    out_file : DotTree or string, optional (default=None)
        Name of the output file, or a ``DotTree`` to collect the output
        in memory. When ``None`` a fresh ``DotTree`` is created.

    feature_names : list of strings, optional (default=None)
        Names of each of the features.

    extensive : displays aditional information, optional (default=False)

    Returns
    -------
    dot_data : string
        String representation of the input tree in GraphViz dot format.
    """
    # NOTE(fix): the default used to be ``out_file=DotTree()``, a mutable
    # default evaluated once at import time — after the first call the
    # shared instance was closed and later calls silently produced no
    # output. A None sentinel gives every call its own DotTree.
    if out_file is None:
        out_file = DotTree()
    ranks = {}
    node_ids = []
    max_depth = 500

    def _recurse_tree(node, node_id=0, edge=None, parent=None, depth=0):
        # Depth-first walk; each node gets a fresh sequential id.
        depth += 1
        if depth > max_depth:
            return
        node_ids.append(_get_next_id())
        out_file.write(_node_to_dot(node, node_id, parent, edge, depth))
        for child, edge in node.children:
            _recurse_tree(child, _get_next_id(), edge, node_id, depth)

    def _get_next_id():
        # Sequential id allocator backed by the node_ids list.
        if len(node_ids) == 0:
            return 0
        else:
            return node_ids[-1] + 1

    def _node_to_dot(node, n_id=0, parent=None, edge=None, depth=0):
        """Get a Node objects representation in dot format."""
        node_repr = []
        # Group nodes by depth so they can be ranked on the same level.
        if str(depth) not in ranks:
            ranks[str(depth)] = []
        ranks[str(depth)].append(str(n_id))
        node_repr.append(('\"{}\" [shape=box, style=filled, label=\"{}\", '
                          'weight={}]\n')
                         .format(n_id, _extract_node_info(node), depth))
        if parent is not None:
            node_repr.append(('{} -> {} [ label = "{}"];\n')
                             .format(parent, n_id,
                                     _extract_edge_value(decision_tree,
                                                         edge)))
        res = "".join(node_repr)
        return res

    def _extract_node_info(node):
        # Build the multi-line label shown inside a node's box.
        result = ""
        value = ""
        if feature_names is not None and node.is_feature:
            value = str(feature_names[node.value])
        elif not node.is_feature:
            value = (decision_tree.y_encoder
                     .single_inv_transform(node.value))
        else:
            value = node.value
        if isinstance(value, np.bytes_):
            value = value.decode('UTF-8')
        result += str(value) + "\n"
        if node.is_feature and extensive:
            class_counts = node.details.class_counts
            dominant_class = class_counts[np.argmax(class_counts[:, 1]), :]
            result += ("Info: {0:.2f}\n"
                       .format(node.details.info))
            result += ("Entropy: {0:.2f}\n"
                       .format(node.details.entropy))
            result += "Dominant class: {}\n".format(dominant_class)
        if not node.is_feature:
            result += _extract_class_count(node) + "\n"
        return result

    # A string out_file is treated as a path and opened for writing.
    if not isinstance(out_file, DotTree) and six.PY3:
        out_file = open(out_file, 'w', encoding='utf8')
    elif not isinstance(out_file, DotTree):
        out_file = open(out_file, 'wb')
    out_file.write('digraph ID3_Tree {\n')
    _recurse_tree(decision_tree.root)
    for rank in sorted(ranks):
        out_file.write("{rank=same; ")
        for r in ranks[rank]:
            out_file.write(str(r) + ";")
        out_file.write("};\n")
    out_file.write("}")
    out_file.close()
    return out_file