Uploaded Test files

This commit is contained in:
Batuhan Berk Başoğlu 2020-11-12 11:05:57 -05:00
parent f584ad9d97
commit 2e81cb7d99
16627 changed files with 2065359 additions and 102444 deletions

View file

@ -0,0 +1,16 @@
"""
The :mod:`sklearn.tree` module includes decision tree-based models for
classification and regression.
"""
from ._classes import BaseDecisionTree
from ._classes import DecisionTreeClassifier
from ._classes import DecisionTreeRegressor
from ._classes import ExtraTreeClassifier
from ._classes import ExtraTreeRegressor
from ._export import export_graphviz, plot_tree, export_text
__all__ = ["BaseDecisionTree",
"DecisionTreeClassifier", "DecisionTreeRegressor",
"ExtraTreeClassifier", "ExtraTreeRegressor", "export_graphviz",
"plot_tree", "export_text"]

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,77 @@
# Authors: Gilles Louppe <g.louppe@gmail.com>
# Peter Prettenhofer <peter.prettenhofer@gmail.com>
# Brian Holt <bdholt1@gmail.com>
# Joel Nothman <joel.nothman@gmail.com>
# Arnaud Joly <arnaud.v.joly@gmail.com>
# Jacob Schreiber <jmschreiber91@gmail.com>
#
# License: BSD 3 clause
# See _criterion.pyx for implementation details.
import numpy as np
cimport numpy as np
from ._tree cimport DTYPE_t # Type of X
from ._tree cimport DOUBLE_t # Type of y, sample_weight
from ._tree cimport SIZE_t # Type for indices and counters
from ._tree cimport INT32_t # Signed 32 bit integer
from ._tree cimport UINT32_t # Unsigned 32 bit integer
cdef class Criterion:
# The criterion computes the impurity of a node and the reduction of
# impurity of a split on that node. It also computes the output statistics
# such as the mean in regression and class probabilities in classification.
# Internal structures
cdef const DOUBLE_t[:, ::1] y # Values of y
cdef DOUBLE_t* sample_weight # Sample weights
cdef SIZE_t* samples # Sample indices in X, y
cdef SIZE_t start # samples[start:pos] are the samples in the left node
cdef SIZE_t pos # samples[pos:end] are the samples in the right node
cdef SIZE_t end
cdef SIZE_t n_outputs # Number of outputs
cdef SIZE_t n_samples # Number of samples
cdef SIZE_t n_node_samples # Number of samples in the node (end-start)
cdef double weighted_n_samples # Weighted number of samples (in total)
cdef double weighted_n_node_samples # Weighted number of samples in the node
cdef double weighted_n_left # Weighted number of samples in the left node
cdef double weighted_n_right # Weighted number of samples in the right node
cdef double* sum_total # For classification criteria, the sum of the
# weighted count of each label. For regression,
# the sum of w*y. sum_total[k] is equal to
# sum_{i=start}^{end-1} w[samples[i]]*y[samples[i], k],
# where k is output index.
cdef double* sum_left # Same as above, but for the left side of the split
cdef double* sum_right # same as above, but for the right side of the split
# The criterion object is maintained such that left and right collected
# statistics correspond to samples[start:pos] and samples[pos:end].
# Methods
cdef int init(self, const DOUBLE_t[:, ::1] y, DOUBLE_t* sample_weight,
double weighted_n_samples, SIZE_t* samples, SIZE_t start,
SIZE_t end) nogil except -1
cdef int reset(self) nogil except -1
cdef int reverse_reset(self) nogil except -1
cdef int update(self, SIZE_t new_pos) nogil except -1
cdef double node_impurity(self) nogil
cdef void children_impurity(self, double* impurity_left,
double* impurity_right) nogil
cdef void node_value(self, double* dest) nogil
cdef double impurity_improvement(self, double impurity) nogil
cdef double proxy_impurity_improvement(self) nogil
cdef class ClassificationCriterion(Criterion):
"""Abstract criterion for classification."""
cdef SIZE_t* n_classes
cdef SIZE_t sum_stride
cdef class RegressionCriterion(Criterion):
"""Abstract regression criterion."""
cdef double sq_sum_total

View file

@ -0,0 +1,967 @@
"""
This module defines export functions for decision trees.
"""
# Authors: Gilles Louppe <g.louppe@gmail.com>
# Peter Prettenhofer <peter.prettenhofer@gmail.com>
# Brian Holt <bdholt1@gmail.com>
# Noel Dawe <noel@dawe.me>
# Satrajit Gosh <satrajit.ghosh@gmail.com>
# Trevor Stephens <trev.stephens@gmail.com>
# Li Li <aiki.nogard@gmail.com>
# Giuseppe Vettigli <vettigli@gmail.com>
# License: BSD 3 clause
from io import StringIO
from numbers import Integral
import numpy as np
from ..utils.validation import check_is_fitted
from ..utils.validation import _deprecate_positional_args
from ..base import is_classifier
from . import _criterion
from . import _tree
from ._reingold_tilford import buchheim, Tree
from . import DecisionTreeClassifier
import warnings
def _color_brew(n):
"""Generate n colors with equally spaced hues.
Parameters
----------
n : int
The number of colors required.
Returns
-------
color_list : list, length n
List of n tuples of form (R, G, B) being the components of each color.
"""
color_list = []
# Initialize saturation & value; calculate chroma & value shift
s, v = 0.75, 0.9
c = s * v
m = v - c
for h in np.arange(25, 385, 360. / n).astype(int):
# Calculate some intermediate values
h_bar = h / 60.
x = c * (1 - abs((h_bar % 2) - 1))
# Initialize RGB with same hue & chroma as our color
rgb = [(c, x, 0),
(x, c, 0),
(0, c, x),
(0, x, c),
(x, 0, c),
(c, 0, x),
(c, x, 0)]
r, g, b = rgb[int(h_bar)]
# Shift the initial RGB values to match value and store
rgb = [(int(255 * (r + m))),
(int(255 * (g + m))),
(int(255 * (b + m)))]
color_list.append(rgb)
return color_list
class Sentinel:
def __repr__(self):
return '"tree.dot"'
SENTINEL = Sentinel()
@_deprecate_positional_args
def plot_tree(decision_tree, *, max_depth=None, feature_names=None,
class_names=None, label='all', filled=False,
impurity=True, node_ids=False,
proportion=False, rotate='deprecated', rounded=False,
precision=3, ax=None, fontsize=None):
"""Plot a decision tree.
The sample counts that are shown are weighted with any sample_weights that
might be present.
The visualization is fit automatically to the size of the axis.
Use the ``figsize`` or ``dpi`` arguments of ``plt.figure`` to control
the size of the rendering.
Read more in the :ref:`User Guide <tree>`.
.. versionadded:: 0.21
Parameters
----------
decision_tree : decision tree regressor or classifier
The decision tree to be plotted.
max_depth : int, optional (default=None)
The maximum depth of the representation. If None, the tree is fully
generated.
feature_names : list of strings, optional (default=None)
Names of each of the features.
class_names : list of strings, bool or None, optional (default=None)
Names of each of the target classes in ascending numerical order.
Only relevant for classification and not supported for multi-output.
If ``True``, shows a symbolic representation of the class name.
label : {'all', 'root', 'none'}, optional (default='all')
Whether to show informative labels for impurity, etc.
Options include 'all' to show at every node, 'root' to show only at
the top root node, or 'none' to not show at any node.
filled : bool, optional (default=False)
When set to ``True``, paint nodes to indicate majority class for
classification, extremity of values for regression, or purity of node
for multi-output.
impurity : bool, optional (default=True)
When set to ``True``, show the impurity at each node.
node_ids : bool, optional (default=False)
When set to ``True``, show the ID number on each node.
proportion : bool, optional (default=False)
When set to ``True``, change the display of 'values' and/or 'samples'
to be proportions and percentages respectively.
rotate : bool, optional (default=False)
This parameter has no effect on the matplotlib tree visualisation and
it is kept here for backward compatibility.
.. deprecated:: 0.23
``rotate`` is deprecated in 0.23 and will be removed in 0.25.
rounded : bool, optional (default=False)
When set to ``True``, draw node boxes with rounded corners and use
Helvetica fonts instead of Times-Roman.
precision : int, optional (default=3)
Number of digits of precision for floating point in the values of
impurity, threshold and value attributes of each node.
ax : matplotlib axis, optional (default=None)
Axes to plot to. If None, use current axis. Any previous content
is cleared.
fontsize : int, optional (default=None)
Size of text font. If None, determined automatically to fit figure.
Returns
-------
annotations : list of artists
List containing the artists for the annotation boxes making up the
tree.
Examples
--------
>>> from sklearn.datasets import load_iris
>>> from sklearn import tree
>>> clf = tree.DecisionTreeClassifier(random_state=0)
>>> iris = load_iris()
>>> clf = clf.fit(iris.data, iris.target)
>>> tree.plot_tree(clf) # doctest: +SKIP
[Text(251.5,345.217,'X[3] <= 0.8...
"""
check_is_fitted(decision_tree)
if rotate != 'deprecated':
warnings.warn(("'rotate' has no effect and is deprecated in 0.23. "
"It will be removed in 0.25."),
FutureWarning)
exporter = _MPLTreeExporter(
max_depth=max_depth, feature_names=feature_names,
class_names=class_names, label=label, filled=filled,
impurity=impurity, node_ids=node_ids,
proportion=proportion, rotate=rotate, rounded=rounded,
precision=precision, fontsize=fontsize)
return exporter.export(decision_tree, ax=ax)
class _BaseTreeExporter:
def __init__(self, max_depth=None, feature_names=None,
class_names=None, label='all', filled=False,
impurity=True, node_ids=False,
proportion=False, rotate=False, rounded=False,
precision=3, fontsize=None):
self.max_depth = max_depth
self.feature_names = feature_names
self.class_names = class_names
self.label = label
self.filled = filled
self.impurity = impurity
self.node_ids = node_ids
self.proportion = proportion
self.rotate = rotate
self.rounded = rounded
self.precision = precision
self.fontsize = fontsize
def get_color(self, value):
# Find the appropriate color & intensity for a node
if self.colors['bounds'] is None:
# Classification tree
color = list(self.colors['rgb'][np.argmax(value)])
sorted_values = sorted(value, reverse=True)
if len(sorted_values) == 1:
alpha = 0
else:
alpha = ((sorted_values[0] - sorted_values[1])
/ (1 - sorted_values[1]))
else:
# Regression tree or multi-output
color = list(self.colors['rgb'][0])
alpha = ((value - self.colors['bounds'][0]) /
(self.colors['bounds'][1] - self.colors['bounds'][0]))
# unpack numpy scalars
alpha = float(alpha)
# compute the color as alpha against white
color = [int(round(alpha * c + (1 - alpha) * 255, 0)) for c in color]
# Return html color code in #RRGGBB format
return '#%2x%2x%2x' % tuple(color)
def get_fill_color(self, tree, node_id):
# Fetch appropriate color for node
if 'rgb' not in self.colors:
# Initialize colors and bounds if required
self.colors['rgb'] = _color_brew(tree.n_classes[0])
if tree.n_outputs != 1:
# Find max and min impurities for multi-output
self.colors['bounds'] = (np.min(-tree.impurity),
np.max(-tree.impurity))
elif (tree.n_classes[0] == 1 and
len(np.unique(tree.value)) != 1):
# Find max and min values in leaf nodes for regression
self.colors['bounds'] = (np.min(tree.value),
np.max(tree.value))
if tree.n_outputs == 1:
node_val = (tree.value[node_id][0, :] /
tree.weighted_n_node_samples[node_id])
if tree.n_classes[0] == 1:
# Regression
node_val = tree.value[node_id][0, :]
else:
# If multi-output color node by impurity
node_val = -tree.impurity[node_id]
return self.get_color(node_val)
def node_to_str(self, tree, node_id, criterion):
# Generate the node content string
if tree.n_outputs == 1:
value = tree.value[node_id][0, :]
else:
value = tree.value[node_id]
# Should labels be shown?
labels = (self.label == 'root' and node_id == 0) or self.label == 'all'
characters = self.characters
node_string = characters[-1]
# Write node ID
if self.node_ids:
if labels:
node_string += 'node '
node_string += characters[0] + str(node_id) + characters[4]
# Write decision criteria
if tree.children_left[node_id] != _tree.TREE_LEAF:
# Always write node decision criteria, except for leaves
if self.feature_names is not None:
feature = self.feature_names[tree.feature[node_id]]
else:
feature = "X%s%s%s" % (characters[1],
tree.feature[node_id],
characters[2])
node_string += '%s %s %s%s' % (feature,
characters[3],
round(tree.threshold[node_id],
self.precision),
characters[4])
# Write impurity
if self.impurity:
if isinstance(criterion, _criterion.FriedmanMSE):
criterion = "friedman_mse"
elif not isinstance(criterion, str):
criterion = "impurity"
if labels:
node_string += '%s = ' % criterion
node_string += (str(round(tree.impurity[node_id], self.precision))
+ characters[4])
# Write node sample count
if labels:
node_string += 'samples = '
if self.proportion:
percent = (100. * tree.n_node_samples[node_id] /
float(tree.n_node_samples[0]))
node_string += (str(round(percent, 1)) + '%' +
characters[4])
else:
node_string += (str(tree.n_node_samples[node_id]) +
characters[4])
# Write node class distribution / regression value
if self.proportion and tree.n_classes[0] != 1:
# For classification this will show the proportion of samples
value = value / tree.weighted_n_node_samples[node_id]
if labels:
node_string += 'value = '
if tree.n_classes[0] == 1:
# Regression
value_text = np.around(value, self.precision)
elif self.proportion:
# Classification
value_text = np.around(value, self.precision)
elif np.all(np.equal(np.mod(value, 1), 0)):
# Classification without floating-point weights
value_text = value.astype(int)
else:
# Classification with floating-point weights
value_text = np.around(value, self.precision)
# Strip whitespace
value_text = str(value_text.astype('S32')).replace("b'", "'")
value_text = value_text.replace("' '", ", ").replace("'", "")
if tree.n_classes[0] == 1 and tree.n_outputs == 1:
value_text = value_text.replace("[", "").replace("]", "")
value_text = value_text.replace("\n ", characters[4])
node_string += value_text + characters[4]
# Write node majority class
if (self.class_names is not None and
tree.n_classes[0] != 1 and
tree.n_outputs == 1):
# Only done for single-output classification trees
if labels:
node_string += 'class = '
if self.class_names is not True:
class_name = self.class_names[np.argmax(value)]
else:
class_name = "y%s%s%s" % (characters[1],
np.argmax(value),
characters[2])
node_string += class_name
# Clean up any trailing newlines
if node_string.endswith(characters[4]):
node_string = node_string[:-len(characters[4])]
return node_string + characters[5]
class _DOTTreeExporter(_BaseTreeExporter):
def __init__(self, out_file=SENTINEL, max_depth=None,
feature_names=None, class_names=None, label='all',
filled=False, leaves_parallel=False, impurity=True,
node_ids=False, proportion=False, rotate=False, rounded=False,
special_characters=False, precision=3):
super().__init__(
max_depth=max_depth, feature_names=feature_names,
class_names=class_names, label=label, filled=filled,
impurity=impurity,
node_ids=node_ids, proportion=proportion, rotate=rotate,
rounded=rounded,
precision=precision)
self.leaves_parallel = leaves_parallel
self.out_file = out_file
self.special_characters = special_characters
# PostScript compatibility for special characters
if special_characters:
self.characters = ['&#35;', '<SUB>', '</SUB>', '&le;', '<br/>',
'>', '<']
else:
self.characters = ['#', '[', ']', '<=', '\\n', '"', '"']
# validate
if isinstance(precision, Integral):
if precision < 0:
raise ValueError("'precision' should be greater or equal to 0."
" Got {} instead.".format(precision))
else:
raise ValueError("'precision' should be an integer. Got {}"
" instead.".format(type(precision)))
# The depth of each node for plotting with 'leaf' option
self.ranks = {'leaves': []}
# The colors to render each node with
self.colors = {'bounds': None}
def export(self, decision_tree):
# Check length of feature_names before getting into the tree node
# Raise error if length of feature_names does not match
# n_features_ in the decision_tree
if self.feature_names is not None:
if len(self.feature_names) != decision_tree.n_features_:
raise ValueError("Length of feature_names, %d "
"does not match number of features, %d"
% (len(self.feature_names),
decision_tree.n_features_))
# each part writes to out_file
self.head()
# Now recurse the tree and add node & edge attributes
if isinstance(decision_tree, _tree.Tree):
self.recurse(decision_tree, 0, criterion="impurity")
else:
self.recurse(decision_tree.tree_, 0,
criterion=decision_tree.criterion)
self.tail()
def tail(self):
# If required, draw leaf nodes at same depth as each other
if self.leaves_parallel:
for rank in sorted(self.ranks):
self.out_file.write(
"{rank=same ; " +
"; ".join(r for r in self.ranks[rank]) + "} ;\n")
self.out_file.write("}")
def head(self):
self.out_file.write('digraph Tree {\n')
# Specify node aesthetics
self.out_file.write('node [shape=box')
rounded_filled = []
if self.filled:
rounded_filled.append('filled')
if self.rounded:
rounded_filled.append('rounded')
if len(rounded_filled) > 0:
self.out_file.write(
', style="%s", color="black"'
% ", ".join(rounded_filled))
if self.rounded:
self.out_file.write(', fontname=helvetica')
self.out_file.write('] ;\n')
# Specify graph & edge aesthetics
if self.leaves_parallel:
self.out_file.write(
'graph [ranksep=equally, splines=polyline] ;\n')
if self.rounded:
self.out_file.write('edge [fontname=helvetica] ;\n')
if self.rotate:
self.out_file.write('rankdir=LR ;\n')
def recurse(self, tree, node_id, criterion, parent=None, depth=0):
if node_id == _tree.TREE_LEAF:
raise ValueError("Invalid node_id %s" % _tree.TREE_LEAF)
left_child = tree.children_left[node_id]
right_child = tree.children_right[node_id]
# Add node with description
if self.max_depth is None or depth <= self.max_depth:
# Collect ranks for 'leaf' option in plot_options
if left_child == _tree.TREE_LEAF:
self.ranks['leaves'].append(str(node_id))
elif str(depth) not in self.ranks:
self.ranks[str(depth)] = [str(node_id)]
else:
self.ranks[str(depth)].append(str(node_id))
self.out_file.write(
'%d [label=%s' % (node_id, self.node_to_str(tree, node_id,
criterion)))
if self.filled:
self.out_file.write(', fillcolor="%s"'
% self.get_fill_color(tree, node_id))
self.out_file.write('] ;\n')
if parent is not None:
# Add edge to parent
self.out_file.write('%d -> %d' % (parent, node_id))
if parent == 0:
# Draw True/False labels if parent is root node
angles = np.array([45, -45]) * ((self.rotate - .5) * -2)
self.out_file.write(' [labeldistance=2.5, labelangle=')
if node_id == 1:
self.out_file.write('%d, headlabel="True"]' %
angles[0])
else:
self.out_file.write('%d, headlabel="False"]' %
angles[1])
self.out_file.write(' ;\n')
if left_child != _tree.TREE_LEAF:
self.recurse(tree, left_child, criterion=criterion,
parent=node_id, depth=depth + 1)
self.recurse(tree, right_child, criterion=criterion,
parent=node_id, depth=depth + 1)
else:
self.ranks['leaves'].append(str(node_id))
self.out_file.write('%d [label="(...)"' % node_id)
if self.filled:
# color cropped nodes grey
self.out_file.write(', fillcolor="#C0C0C0"')
self.out_file.write('] ;\n' % node_id)
if parent is not None:
# Add edge to parent
self.out_file.write('%d -> %d ;\n' % (parent, node_id))
class _MPLTreeExporter(_BaseTreeExporter):
def __init__(self, max_depth=None, feature_names=None,
class_names=None, label='all', filled=False,
impurity=True, node_ids=False,
proportion=False, rotate=False, rounded=False,
precision=3, fontsize=None):
super().__init__(
max_depth=max_depth, feature_names=feature_names,
class_names=class_names, label=label, filled=filled,
impurity=impurity, node_ids=node_ids, proportion=proportion,
rotate=rotate, rounded=rounded, precision=precision)
self.fontsize = fontsize
# validate
if isinstance(precision, Integral):
if precision < 0:
raise ValueError("'precision' should be greater or equal to 0."
" Got {} instead.".format(precision))
else:
raise ValueError("'precision' should be an integer. Got {}"
" instead.".format(type(precision)))
# The depth of each node for plotting with 'leaf' option
self.ranks = {'leaves': []}
# The colors to render each node with
self.colors = {'bounds': None}
self.characters = ['#', '[', ']', '<=', '\n', '', '']
self.bbox_args = dict(fc='w')
if self.rounded:
self.bbox_args['boxstyle'] = "round"
self.arrow_args = dict(arrowstyle="<-")
def _make_tree(self, node_id, et, criterion, depth=0):
# traverses _tree.Tree recursively, builds intermediate
# "_reingold_tilford.Tree" object
name = self.node_to_str(et, node_id, criterion=criterion)
if (et.children_left[node_id] != _tree.TREE_LEAF
and (self.max_depth is None or depth <= self.max_depth)):
children = [self._make_tree(et.children_left[node_id], et,
criterion, depth=depth + 1),
self._make_tree(et.children_right[node_id], et,
criterion, depth=depth + 1)]
else:
return Tree(name, node_id)
return Tree(name, node_id, *children)
def export(self, decision_tree, ax=None):
import matplotlib.pyplot as plt
from matplotlib.text import Annotation
if ax is None:
ax = plt.gca()
ax.clear()
ax.set_axis_off()
my_tree = self._make_tree(0, decision_tree.tree_,
decision_tree.criterion)
draw_tree = buchheim(my_tree)
# important to make sure we're still
# inside the axis after drawing the box
# this makes sense because the width of a box
# is about the same as the distance between boxes
max_x, max_y = draw_tree.max_extents() + 1
ax_width = ax.get_window_extent().width
ax_height = ax.get_window_extent().height
scale_x = ax_width / max_x
scale_y = ax_height / max_y
self.recurse(draw_tree, decision_tree.tree_, ax,
scale_x, scale_y, ax_height)
anns = [ann for ann in ax.get_children()
if isinstance(ann, Annotation)]
# update sizes of all bboxes
renderer = ax.figure.canvas.get_renderer()
for ann in anns:
ann.update_bbox_position_size(renderer)
if self.fontsize is None:
# get figure to data transform
# adjust fontsize to avoid overlap
# get max box width and height
extents = [ann.get_bbox_patch().get_window_extent()
for ann in anns]
max_width = max([extent.width for extent in extents])
max_height = max([extent.height for extent in extents])
# width should be around scale_x in axis coordinates
size = anns[0].get_fontsize() * min(scale_x / max_width,
scale_y / max_height)
for ann in anns:
ann.set_fontsize(size)
return anns
def recurse(self, node, tree, ax, scale_x, scale_y, height, depth=0):
kwargs = dict(bbox=self.bbox_args, ha='center', va='center',
zorder=100 - 10 * depth, xycoords='axes pixels')
if self.fontsize is not None:
kwargs['fontsize'] = self.fontsize
# offset things by .5 to center them in plot
xy = ((node.x + .5) * scale_x, height - (node.y + .5) * scale_y)
if self.max_depth is None or depth <= self.max_depth:
if self.filled:
kwargs['bbox']['fc'] = self.get_fill_color(tree,
node.tree.node_id)
if node.parent is None:
# root
ax.annotate(node.tree.label, xy, **kwargs)
else:
xy_parent = ((node.parent.x + .5) * scale_x,
height - (node.parent.y + .5) * scale_y)
kwargs["arrowprops"] = self.arrow_args
ax.annotate(node.tree.label, xy_parent, xy, **kwargs)
for child in node.children:
self.recurse(child, tree, ax, scale_x, scale_y, height,
depth=depth + 1)
else:
xy_parent = ((node.parent.x + .5) * scale_x,
height - (node.parent.y + .5) * scale_y)
kwargs["arrowprops"] = self.arrow_args
kwargs['bbox']['fc'] = 'grey'
ax.annotate("\n (...) \n", xy_parent, xy, **kwargs)
@_deprecate_positional_args
def export_graphviz(decision_tree, out_file=None, *, max_depth=None,
feature_names=None, class_names=None, label='all',
filled=False, leaves_parallel=False, impurity=True,
node_ids=False, proportion=False, rotate=False,
rounded=False, special_characters=False, precision=3):
"""Export a decision tree in DOT format.
This function generates a GraphViz representation of the decision tree,
which is then written into `out_file`. Once exported, graphical renderings
can be generated using, for example::
$ dot -Tps tree.dot -o tree.ps (PostScript format)
$ dot -Tpng tree.dot -o tree.png (PNG format)
The sample counts that are shown are weighted with any sample_weights that
might be present.
Read more in the :ref:`User Guide <tree>`.
Parameters
----------
decision_tree : decision tree classifier
The decision tree to be exported to GraphViz.
out_file : file object or string, optional (default=None)
Handle or name of the output file. If ``None``, the result is
returned as a string.
.. versionchanged:: 0.20
Default of out_file changed from "tree.dot" to None.
max_depth : int, optional (default=None)
The maximum depth of the representation. If None, the tree is fully
generated.
feature_names : list of strings, optional (default=None)
Names of each of the features.
class_names : list of strings, bool or None, optional (default=None)
Names of each of the target classes in ascending numerical order.
Only relevant for classification and not supported for multi-output.
If ``True``, shows a symbolic representation of the class name.
label : {'all', 'root', 'none'}, optional (default='all')
Whether to show informative labels for impurity, etc.
Options include 'all' to show at every node, 'root' to show only at
the top root node, or 'none' to not show at any node.
filled : bool, optional (default=False)
When set to ``True``, paint nodes to indicate majority class for
classification, extremity of values for regression, or purity of node
for multi-output.
leaves_parallel : bool, optional (default=False)
When set to ``True``, draw all leaf nodes at the bottom of the tree.
impurity : bool, optional (default=True)
When set to ``True``, show the impurity at each node.
node_ids : bool, optional (default=False)
When set to ``True``, show the ID number on each node.
proportion : bool, optional (default=False)
When set to ``True``, change the display of 'values' and/or 'samples'
to be proportions and percentages respectively.
rotate : bool, optional (default=False)
When set to ``True``, orient tree left to right rather than top-down.
rounded : bool, optional (default=False)
When set to ``True``, draw node boxes with rounded corners and use
Helvetica fonts instead of Times-Roman.
special_characters : bool, optional (default=False)
When set to ``False``, ignore special characters for PostScript
compatibility.
precision : int, optional (default=3)
Number of digits of precision for floating point in the values of
impurity, threshold and value attributes of each node.
Returns
-------
dot_data : string
String representation of the input tree in GraphViz dot format.
Only returned if ``out_file`` is None.
.. versionadded:: 0.18
Examples
--------
>>> from sklearn.datasets import load_iris
>>> from sklearn import tree
>>> clf = tree.DecisionTreeClassifier()
>>> iris = load_iris()
>>> clf = clf.fit(iris.data, iris.target)
>>> tree.export_graphviz(clf)
'digraph Tree {...
"""
check_is_fitted(decision_tree)
own_file = False
return_string = False
try:
if isinstance(out_file, str):
out_file = open(out_file, "w", encoding="utf-8")
own_file = True
if out_file is None:
return_string = True
out_file = StringIO()
exporter = _DOTTreeExporter(
out_file=out_file, max_depth=max_depth,
feature_names=feature_names, class_names=class_names, label=label,
filled=filled, leaves_parallel=leaves_parallel, impurity=impurity,
node_ids=node_ids, proportion=proportion, rotate=rotate,
rounded=rounded, special_characters=special_characters,
precision=precision)
exporter.export(decision_tree)
if return_string:
return exporter.out_file.getvalue()
finally:
if own_file:
out_file.close()
def _compute_depth(tree, node):
"""
Returns the depth of the subtree rooted in node.
"""
def compute_depth_(current_node, current_depth,
children_left, children_right, depths):
depths += [current_depth]
left = children_left[current_node]
right = children_right[current_node]
if left != -1 and right != -1:
compute_depth_(left, current_depth+1,
children_left, children_right, depths)
compute_depth_(right, current_depth+1,
children_left, children_right, depths)
depths = []
compute_depth_(node, 1, tree.children_left, tree.children_right, depths)
return max(depths)
@_deprecate_positional_args
def export_text(decision_tree, *, feature_names=None, max_depth=10,
spacing=3, decimals=2, show_weights=False):
"""Build a text report showing the rules of a decision tree.
Note that backwards compatibility may not be supported.
Parameters
----------
decision_tree : object
The decision tree estimator to be exported.
It can be an instance of
DecisionTreeClassifier or DecisionTreeRegressor.
feature_names : list, optional (default=None)
A list of length n_features containing the feature names.
If None generic names will be used ("feature_0", "feature_1", ...).
max_depth : int, optional (default=10)
Only the first max_depth levels of the tree are exported.
Truncated branches will be marked with "...".
spacing : int, optional (default=3)
Number of spaces between edges. The higher it is, the wider the result.
decimals : int, optional (default=2)
Number of decimal digits to display.
show_weights : bool, optional (default=False)
If true the classification weights will be exported on each leaf.
The classification weights are the number of samples each class.
Returns
-------
report : string
Text summary of all the rules in the decision tree.
Examples
--------
>>> from sklearn.datasets import load_iris
>>> from sklearn.tree import DecisionTreeClassifier
>>> from sklearn.tree import export_text
>>> iris = load_iris()
>>> X = iris['data']
>>> y = iris['target']
>>> decision_tree = DecisionTreeClassifier(random_state=0, max_depth=2)
>>> decision_tree = decision_tree.fit(X, y)
>>> r = export_text(decision_tree, feature_names=iris['feature_names'])
>>> print(r)
|--- petal width (cm) <= 0.80
| |--- class: 0
|--- petal width (cm) > 0.80
| |--- petal width (cm) <= 1.75
| | |--- class: 1
| |--- petal width (cm) > 1.75
| | |--- class: 2
"""
check_is_fitted(decision_tree)
tree_ = decision_tree.tree_
if is_classifier(decision_tree):
class_names = decision_tree.classes_
right_child_fmt = "{} {} <= {}\n"
left_child_fmt = "{} {} > {}\n"
truncation_fmt = "{} {}\n"
if max_depth < 0:
raise ValueError("max_depth bust be >= 0, given %d" % max_depth)
if (feature_names is not None and
len(feature_names) != tree_.n_features):
raise ValueError("feature_names must contain "
"%d elements, got %d" % (tree_.n_features,
len(feature_names)))
if spacing <= 0:
raise ValueError("spacing must be > 0, given %d" % spacing)
if decimals < 0:
raise ValueError("decimals must be >= 0, given %d" % decimals)
if isinstance(decision_tree, DecisionTreeClassifier):
value_fmt = "{}{} weights: {}\n"
if not show_weights:
value_fmt = "{}{}{}\n"
else:
value_fmt = "{}{} value: {}\n"
if feature_names:
feature_names_ = [feature_names[i] if i != _tree.TREE_UNDEFINED
else None for i in tree_.feature]
else:
feature_names_ = ["feature_{}".format(i) for i in tree_.feature]
export_text.report = ""
def _add_leaf(value, class_name, indent):
val = ''
is_classification = isinstance(decision_tree,
DecisionTreeClassifier)
if show_weights or not is_classification:
val = ["{1:.{0}f}, ".format(decimals, v) for v in value]
val = '['+''.join(val)[:-2]+']'
if is_classification:
val += ' class: ' + str(class_name)
export_text.report += value_fmt.format(indent, '', val)
def print_tree_recurse(node, depth):
indent = ("|" + (" " * spacing)) * depth
indent = indent[:-spacing] + "-" * spacing
value = None
if tree_.n_outputs == 1:
value = tree_.value[node][0]
else:
value = tree_.value[node].T[0]
class_name = np.argmax(value)
if (tree_.n_classes[0] != 1 and
tree_.n_outputs == 1):
class_name = class_names[class_name]
if depth <= max_depth+1:
info_fmt = ""
info_fmt_left = info_fmt
info_fmt_right = info_fmt
if tree_.feature[node] != _tree.TREE_UNDEFINED:
name = feature_names_[node]
threshold = tree_.threshold[node]
threshold = "{1:.{0}f}".format(decimals, threshold)
export_text.report += right_child_fmt.format(indent,
name,
threshold)
export_text.report += info_fmt_left
print_tree_recurse(tree_.children_left[node], depth+1)
export_text.report += left_child_fmt.format(indent,
name,
threshold)
export_text.report += info_fmt_right
print_tree_recurse(tree_.children_right[node], depth+1)
else: # leaf
_add_leaf(value, class_name, indent)
else:
subtree_depth = _compute_depth(tree_, node)
if subtree_depth == 1:
_add_leaf(value, class_name, indent)
else:
trunc_report = 'truncated branch of depth %d' % subtree_depth
export_text.report += truncation_fmt.format(indent,
trunc_report)
print_tree_recurse(0, 1)
return export_text.report

View file

@ -0,0 +1,188 @@
# Authors: William Mill (bill@billmill.org)
# License: BSD 3 clause
import numpy as np
class DrawTree:
def __init__(self, tree, parent=None, depth=0, number=1):
self.x = -1.
self.y = depth
self.tree = tree
self.children = [DrawTree(c, self, depth + 1, i + 1)
for i, c
in enumerate(tree.children)]
self.parent = parent
self.thread = None
self.mod = 0
self.ancestor = self
self.change = self.shift = 0
self._lmost_sibling = None
# this is the number of the node in its group of siblings 1..n
self.number = number
def left(self):
return self.thread or len(self.children) and self.children[0]
def right(self):
return self.thread or len(self.children) and self.children[-1]
def lbrother(self):
n = None
if self.parent:
for node in self.parent.children:
if node == self:
return n
else:
n = node
return n
def get_lmost_sibling(self):
if not self._lmost_sibling and self.parent and self != \
self.parent.children[0]:
self._lmost_sibling = self.parent.children[0]
return self._lmost_sibling
lmost_sibling = property(get_lmost_sibling)
def __str__(self):
return "%s: x=%s mod=%s" % (self.tree, self.x, self.mod)
def __repr__(self):
return self.__str__()
def max_extents(self):
extents = [c.max_extents() for c in self. children]
extents.append((self.x, self.y))
return np.max(extents, axis=0)
def buchheim(tree):
dt = first_walk(DrawTree(tree))
min = second_walk(dt)
if min < 0:
third_walk(dt, -min)
return dt
def third_walk(tree, n):
tree.x += n
for c in tree.children:
third_walk(c, n)
def first_walk(v, distance=1.):
if len(v.children) == 0:
if v.lmost_sibling:
v.x = v.lbrother().x + distance
else:
v.x = 0.
else:
default_ancestor = v.children[0]
for w in v.children:
first_walk(w)
default_ancestor = apportion(w, default_ancestor, distance)
# print("finished v =", v.tree, "children")
execute_shifts(v)
midpoint = (v.children[0].x + v.children[-1].x) / 2
w = v.lbrother()
if w:
v.x = w.x + distance
v.mod = v.x - midpoint
else:
v.x = midpoint
return v
def apportion(v, default_ancestor, distance):
w = v.lbrother()
if w is not None:
# in buchheim notation:
# i == inner; o == outer; r == right; l == left; r = +; l = -
vir = vor = v
vil = w
vol = v.lmost_sibling
sir = sor = v.mod
sil = vil.mod
sol = vol.mod
while vil.right() and vir.left():
vil = vil.right()
vir = vir.left()
vol = vol.left()
vor = vor.right()
vor.ancestor = v
shift = (vil.x + sil) - (vir.x + sir) + distance
if shift > 0:
move_subtree(ancestor(vil, v, default_ancestor), v, shift)
sir = sir + shift
sor = sor + shift
sil += vil.mod
sir += vir.mod
sol += vol.mod
sor += vor.mod
if vil.right() and not vor.right():
vor.thread = vil.right()
vor.mod += sil - sor
else:
if vir.left() and not vol.left():
vol.thread = vir.left()
vol.mod += sir - sol
default_ancestor = v
return default_ancestor
def move_subtree(wl, wr, shift):
subtrees = wr.number - wl.number
# print(wl.tree, "is conflicted with", wr.tree, 'moving', subtrees,
# 'shift', shift)
# print wl, wr, wr.number, wl.number, shift, subtrees, shift/subtrees
wr.change -= shift / subtrees
wr.shift += shift
wl.change += shift / subtrees
wr.x += shift
wr.mod += shift
def execute_shifts(v):
shift = change = 0
for w in v.children[::-1]:
# print("shift:", w, shift, w.change)
w.x += shift
w.mod += shift
change += w.change
shift += w.shift + change
def ancestor(vil, v, default_ancestor):
# the relevant text is at the bottom of page 7 of
# "Improving Walker's Algorithm to Run in Linear Time" by Buchheim et al,
# (2002)
# http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.16.8757&rep=rep1&type=pdf
if vil.ancestor in v.parent.children:
return vil.ancestor
else:
return default_ancestor
def second_walk(v, m=0, depth=0, min=None):
v.x += m
v.y = depth
if min is None or v.x < min:
min = v.x
for w in v.children:
min = second_walk(w, m + v.mod, depth + 1, min)
return min
class Tree:
def __init__(self, label="", node_id=-1, *children):
self.label = label
self.node_id = node_id
if children:
self.children = children
else:
self.children = []

View file

@ -0,0 +1,94 @@
# Authors: Gilles Louppe <g.louppe@gmail.com>
# Peter Prettenhofer <peter.prettenhofer@gmail.com>
# Brian Holt <bdholt1@gmail.com>
# Joel Nothman <joel.nothman@gmail.com>
# Arnaud Joly <arnaud.v.joly@gmail.com>
# Jacob Schreiber <jmschreiber91@gmail.com>
#
# License: BSD 3 clause
# See _splitter.pyx for details.
import numpy as np
cimport numpy as np
from ._criterion cimport Criterion
from ._tree cimport DTYPE_t # Type of X
from ._tree cimport DOUBLE_t # Type of y, sample_weight
from ._tree cimport SIZE_t # Type for indices and counters
from ._tree cimport INT32_t # Signed 32 bit integer
from ._tree cimport UINT32_t # Unsigned 32 bit integer
cdef struct SplitRecord:
# Data to track sample split
SIZE_t feature # Which feature to split on.
SIZE_t pos # Split samples array at the given position,
# i.e. count of samples below threshold for feature.
# pos is >= end if the node is a leaf.
double threshold # Threshold to split at.
double improvement # Impurity improvement given parent node.
double impurity_left # Impurity of the left split.
double impurity_right # Impurity of the right split.
cdef class Splitter:
# The splitter searches in the input space for a feature and a threshold
# to split the samples samples[start:end].
#
# The impurity computations are delegated to a criterion object.
# Internal structures
cdef public Criterion criterion # Impurity criterion
cdef public SIZE_t max_features # Number of features to test
cdef public SIZE_t min_samples_leaf # Min samples in a leaf
cdef public double min_weight_leaf # Minimum weight in a leaf
cdef object random_state # Random state
cdef UINT32_t rand_r_state # sklearn_rand_r random number state
cdef SIZE_t* samples # Sample indices in X, y
cdef SIZE_t n_samples # X.shape[0]
cdef double weighted_n_samples # Weighted number of samples
cdef SIZE_t* features # Feature indices in X
cdef SIZE_t* constant_features # Constant features indices
cdef SIZE_t n_features # X.shape[1]
cdef DTYPE_t* feature_values # temp. array holding feature values
cdef SIZE_t start # Start position for the current node
cdef SIZE_t end # End position for the current node
cdef const DOUBLE_t[:, ::1] y
cdef DOUBLE_t* sample_weight
# The samples vector `samples` is maintained by the Splitter object such
# that the samples contained in a node are contiguous. With this setting,
# `node_split` reorganizes the node samples `samples[start:end]` in two
# subsets `samples[start:pos]` and `samples[pos:end]`.
# The 1-d `features` array of size n_features contains the features
# indices and allows fast sampling without replacement of features.
# The 1-d `constant_features` array of size n_features holds in
# `constant_features[:n_constant_features]` the feature ids with
# constant values for all the samples that reached a specific node.
# The value `n_constant_features` is given by the parent node to its
# child nodes. The content of the range `[n_constant_features:]` is left
# undefined, but preallocated for performance reasons
# This allows optimization with depth-based tree building.
# Methods
cdef int init(self, object X, const DOUBLE_t[:, ::1] y,
DOUBLE_t* sample_weight,
np.ndarray X_idx_sorted=*) except -1
cdef int node_reset(self, SIZE_t start, SIZE_t end,
double* weighted_n_node_samples) nogil except -1
cdef int node_split(self,
double impurity, # Impurity of the node
SplitRecord* split,
SIZE_t* n_constant_features) nogil except -1
cdef void node_value(self, double* dest) nogil
cdef double node_impurity(self) nogil

View file

@ -0,0 +1,105 @@
# Authors: Gilles Louppe <g.louppe@gmail.com>
# Peter Prettenhofer <peter.prettenhofer@gmail.com>
# Brian Holt <bdholt1@gmail.com>
# Joel Nothman <joel.nothman@gmail.com>
# Arnaud Joly <arnaud.v.joly@gmail.com>
# Jacob Schreiber <jmschreiber91@gmail.com>
# Nelson Liu <nelson@nelsonliu.me>
#
# License: BSD 3 clause
# See _tree.pyx for details.
import numpy as np
cimport numpy as np
ctypedef np.npy_float32 DTYPE_t # Type of X
ctypedef np.npy_float64 DOUBLE_t # Type of y, sample_weight
ctypedef np.npy_intp SIZE_t # Type for indices and counters
ctypedef np.npy_int32 INT32_t # Signed 32 bit integer
ctypedef np.npy_uint32 UINT32_t # Unsigned 32 bit integer
from ._splitter cimport Splitter
from ._splitter cimport SplitRecord
cdef struct Node:
# Base storage structure for the nodes in a Tree object
SIZE_t left_child # id of the left child of the node
SIZE_t right_child # id of the right child of the node
SIZE_t feature # Feature used for splitting the node
DOUBLE_t threshold # Threshold value at the node
DOUBLE_t impurity # Impurity of the node (i.e., the value of the criterion)
SIZE_t n_node_samples # Number of samples at the node
DOUBLE_t weighted_n_node_samples # Weighted number of samples at the node
cdef class Tree:
# The Tree object is a binary tree structure constructed by the
# TreeBuilder. The tree structure is used for predictions and
# feature importances.
# Input/Output layout
cdef public SIZE_t n_features # Number of features in X
cdef SIZE_t* n_classes # Number of classes in y[:, k]
cdef public SIZE_t n_outputs # Number of outputs in y
cdef public SIZE_t max_n_classes # max(n_classes)
# Inner structures: values are stored separately from node structure,
# since size is determined at runtime.
cdef public SIZE_t max_depth # Max depth of the tree
cdef public SIZE_t node_count # Counter for node IDs
cdef public SIZE_t capacity # Capacity of tree, in terms of nodes
cdef Node* nodes # Array of nodes
cdef double* value # (capacity, n_outputs, max_n_classes) array of values
cdef SIZE_t value_stride # = n_outputs * max_n_classes
# Methods
cdef SIZE_t _add_node(self, SIZE_t parent, bint is_left, bint is_leaf,
SIZE_t feature, double threshold, double impurity,
SIZE_t n_node_samples,
double weighted_n_samples) nogil except -1
cdef int _resize(self, SIZE_t capacity) nogil except -1
cdef int _resize_c(self, SIZE_t capacity=*) nogil except -1
cdef np.ndarray _get_value_ndarray(self)
cdef np.ndarray _get_node_ndarray(self)
cpdef np.ndarray predict(self, object X)
cpdef np.ndarray apply(self, object X)
cdef np.ndarray _apply_dense(self, object X)
cdef np.ndarray _apply_sparse_csr(self, object X)
cpdef object decision_path(self, object X)
cdef object _decision_path_dense(self, object X)
cdef object _decision_path_sparse_csr(self, object X)
cpdef compute_feature_importances(self, normalize=*)
# =============================================================================
# Tree builder
# =============================================================================
cdef class TreeBuilder:
# The TreeBuilder recursively builds a Tree object from training samples,
# using a Splitter object for splitting internal nodes and assigning
# values to leaves.
#
# This class controls the various stopping criteria and the node splitting
# evaluation order, e.g. depth-first or best-first.
cdef Splitter splitter # Splitting algorithm
cdef SIZE_t min_samples_split # Minimum number of samples in an internal node
cdef SIZE_t min_samples_leaf # Minimum number of samples in a leaf
cdef double min_weight_leaf # Minimum weight in a leaf
cdef SIZE_t max_depth # Maximal tree depth
cdef double min_impurity_split
cdef double min_impurity_decrease # Impurity threshold for early stopping
cpdef build(self, Tree tree, object X, np.ndarray y,
np.ndarray sample_weight=*,
np.ndarray X_idx_sorted=*)
cdef _check_input(self, object X, np.ndarray y, np.ndarray sample_weight)

View file

@ -0,0 +1,170 @@
# Authors: Gilles Louppe <g.louppe@gmail.com>
# Peter Prettenhofer <peter.prettenhofer@gmail.com>
# Arnaud Joly <arnaud.v.joly@gmail.com>
# Jacob Schreiber <jmschreiber91@gmail.com>
# Nelson Liu <nelson@nelsonliu.me>
#
# License: BSD 3 clause
# See _utils.pyx for details.
import numpy as np
cimport numpy as np
from ._tree cimport Node
from ..neighbors._quad_tree cimport Cell
ctypedef np.npy_float32 DTYPE_t # Type of X
ctypedef np.npy_float64 DOUBLE_t # Type of y, sample_weight
ctypedef np.npy_intp SIZE_t # Type for indices and counters
ctypedef np.npy_int32 INT32_t # Signed 32 bit integer
ctypedef np.npy_uint32 UINT32_t # Unsigned 32 bit integer
cdef enum:
# Max value for our rand_r replacement (near the bottom).
# We don't use RAND_MAX because it's different across platforms and
# particularly tiny on Windows/MSVC.
RAND_R_MAX = 0x7FFFFFFF
# safe_realloc(&p, n) resizes the allocation of p to n * sizeof(*p) bytes or
# raises a MemoryError. It never calls free, since that's __dealloc__'s job.
# cdef DTYPE_t *p = NULL
# safe_realloc(&p, n)
# is equivalent to p = malloc(n * sizeof(*p)) with error checking.
ctypedef fused realloc_ptr:
# Add pointer types here as needed.
(DTYPE_t*)
(SIZE_t*)
(unsigned char*)
(WeightedPQueueRecord*)
(DOUBLE_t*)
(DOUBLE_t**)
(Node*)
(Cell*)
(Node**)
(StackRecord*)
(PriorityHeapRecord*)
cdef realloc_ptr safe_realloc(realloc_ptr* p, size_t nelems) nogil except *
cdef np.ndarray sizet_ptr_to_ndarray(SIZE_t* data, SIZE_t size)
cdef SIZE_t rand_int(SIZE_t low, SIZE_t high,
UINT32_t* random_state) nogil
cdef double rand_uniform(double low, double high,
UINT32_t* random_state) nogil
cdef double log(double x) nogil
# =============================================================================
# Stack data structure
# =============================================================================
# A record on the stack for depth-first tree growing
cdef struct StackRecord:
SIZE_t start
SIZE_t end
SIZE_t depth
SIZE_t parent
bint is_left
double impurity
SIZE_t n_constant_features
cdef class Stack:
cdef SIZE_t capacity
cdef SIZE_t top
cdef StackRecord* stack_
cdef bint is_empty(self) nogil
cdef int push(self, SIZE_t start, SIZE_t end, SIZE_t depth, SIZE_t parent,
bint is_left, double impurity,
SIZE_t n_constant_features) nogil except -1
cdef int pop(self, StackRecord* res) nogil
# =============================================================================
# PriorityHeap data structure
# =============================================================================
# A record on the frontier for best-first tree growing
cdef struct PriorityHeapRecord:
SIZE_t node_id
SIZE_t start
SIZE_t end
SIZE_t pos
SIZE_t depth
bint is_leaf
double impurity
double impurity_left
double impurity_right
double improvement
cdef class PriorityHeap:
cdef SIZE_t capacity
cdef SIZE_t heap_ptr
cdef PriorityHeapRecord* heap_
cdef bint is_empty(self) nogil
cdef void heapify_up(self, PriorityHeapRecord* heap, SIZE_t pos) nogil
cdef void heapify_down(self, PriorityHeapRecord* heap, SIZE_t pos, SIZE_t heap_length) nogil
cdef int push(self, SIZE_t node_id, SIZE_t start, SIZE_t end, SIZE_t pos,
SIZE_t depth, bint is_leaf, double improvement,
double impurity, double impurity_left,
double impurity_right) nogil except -1
cdef int pop(self, PriorityHeapRecord* res) nogil
# =============================================================================
# WeightedPQueue data structure
# =============================================================================
# A record stored in the WeightedPQueue
cdef struct WeightedPQueueRecord:
DOUBLE_t data
DOUBLE_t weight
cdef class WeightedPQueue:
cdef SIZE_t capacity
cdef SIZE_t array_ptr
cdef WeightedPQueueRecord* array_
cdef bint is_empty(self) nogil
cdef int reset(self) nogil except -1
cdef SIZE_t size(self) nogil
cdef int push(self, DOUBLE_t data, DOUBLE_t weight) nogil except -1
cdef int remove(self, DOUBLE_t data, DOUBLE_t weight) nogil
cdef int pop(self, DOUBLE_t* data, DOUBLE_t* weight) nogil
cdef int peek(self, DOUBLE_t* data, DOUBLE_t* weight) nogil
cdef DOUBLE_t get_weight_from_index(self, SIZE_t index) nogil
cdef DOUBLE_t get_value_from_index(self, SIZE_t index) nogil
# =============================================================================
# WeightedMedianCalculator data structure
# =============================================================================
cdef class WeightedMedianCalculator:
cdef SIZE_t initial_capacity
cdef WeightedPQueue samples
cdef DOUBLE_t total_weight
cdef SIZE_t k
cdef DOUBLE_t sum_w_0_k # represents sum(weights[0:k])
# = w[0] + w[1] + ... + w[k-1]
cdef SIZE_t size(self) nogil
cdef int push(self, DOUBLE_t data, DOUBLE_t weight) nogil except -1
cdef int reset(self) nogil except -1
cdef int update_median_parameters_post_push(
self, DOUBLE_t data, DOUBLE_t weight,
DOUBLE_t original_median) nogil
cdef int remove(self, DOUBLE_t data, DOUBLE_t weight) nogil
cdef int pop(self, DOUBLE_t* data, DOUBLE_t* weight) nogil
cdef int update_median_parameters_post_remove(
self, DOUBLE_t data, DOUBLE_t weight,
DOUBLE_t original_median) nogil
cdef DOUBLE_t get_median(self) nogil

View file

@ -0,0 +1,18 @@
# THIS FILE WAS AUTOMATICALLY GENERATED BY deprecated_modules.py
import sys
# mypy error: Module X has no attribute y (typically for C extensions)
from . import _export # type: ignore
from ..externals._pep562 import Pep562
from ..utils.deprecation import _raise_dep_warning_if_not_pytest
deprecated_path = 'sklearn.tree.export'
correct_import_path = 'sklearn.tree'
_raise_dep_warning_if_not_pytest(deprecated_path, correct_import_path)
def __getattr__(name):
return getattr(_export, name)
if not sys.version_info >= (3, 7):
Pep562(__name__)

View file

@ -0,0 +1,39 @@
import os
import numpy
from numpy.distutils.misc_util import Configuration
def configuration(parent_package="", top_path=None):
config = Configuration("tree", parent_package, top_path)
libraries = []
if os.name == 'posix':
libraries.append('m')
config.add_extension("_tree",
sources=["_tree.pyx"],
include_dirs=[numpy.get_include()],
libraries=libraries,
extra_compile_args=["-O3"])
config.add_extension("_splitter",
sources=["_splitter.pyx"],
include_dirs=[numpy.get_include()],
libraries=libraries,
extra_compile_args=["-O3"])
config.add_extension("_criterion",
sources=["_criterion.pyx"],
include_dirs=[numpy.get_include()],
libraries=libraries,
extra_compile_args=["-O3"])
config.add_extension("_utils",
sources=["_utils.pyx"],
include_dirs=[numpy.get_include()],
libraries=libraries,
extra_compile_args=["-O3"])
config.add_subpackage("tests")
return config
if __name__ == "__main__":
from numpy.distutils.core import setup
setup(**configuration().todict())

View file

@ -0,0 +1,469 @@
"""
Testing for export functions of decision trees (sklearn.tree.export).
"""
from re import finditer, search
from textwrap import dedent
from numpy.random import RandomState
import pytest
from sklearn.base import is_classifier
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.tree import export_graphviz, plot_tree, export_text
from io import StringIO
from sklearn.exceptions import NotFittedError
# toy sample
X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]
y = [-1, -1, -1, 1, 1, 1]
y2 = [[-1, 1], [-1, 1], [-1, 1], [1, 2], [1, 2], [1, 3]]
w = [1, 1, 1, .5, .5, .5]
y_degraded = [1, 1, 1, 1, 1, 1]
def test_graphviz_toy():
# Check correctness of export_graphviz
clf = DecisionTreeClassifier(max_depth=3,
min_samples_split=2,
criterion="gini",
random_state=2)
clf.fit(X, y)
# Test export code
contents1 = export_graphviz(clf, out_file=None)
contents2 = 'digraph Tree {\n' \
'node [shape=box] ;\n' \
'0 [label="X[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n' \
'value = [3, 3]"] ;\n' \
'1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]"] ;\n' \
'0 -> 1 [labeldistance=2.5, labelangle=45, ' \
'headlabel="True"] ;\n' \
'2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]"] ;\n' \
'0 -> 2 [labeldistance=2.5, labelangle=-45, ' \
'headlabel="False"] ;\n' \
'}'
assert contents1 == contents2
# Test with feature_names
contents1 = export_graphviz(clf, feature_names=["feature0", "feature1"],
out_file=None)
contents2 = 'digraph Tree {\n' \
'node [shape=box] ;\n' \
'0 [label="feature0 <= 0.0\\ngini = 0.5\\nsamples = 6\\n' \
'value = [3, 3]"] ;\n' \
'1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]"] ;\n' \
'0 -> 1 [labeldistance=2.5, labelangle=45, ' \
'headlabel="True"] ;\n' \
'2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]"] ;\n' \
'0 -> 2 [labeldistance=2.5, labelangle=-45, ' \
'headlabel="False"] ;\n' \
'}'
assert contents1 == contents2
# Test with class_names
contents1 = export_graphviz(clf, class_names=["yes", "no"], out_file=None)
contents2 = 'digraph Tree {\n' \
'node [shape=box] ;\n' \
'0 [label="X[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n' \
'value = [3, 3]\\nclass = yes"] ;\n' \
'1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]\\n' \
'class = yes"] ;\n' \
'0 -> 1 [labeldistance=2.5, labelangle=45, ' \
'headlabel="True"] ;\n' \
'2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]\\n' \
'class = no"] ;\n' \
'0 -> 2 [labeldistance=2.5, labelangle=-45, ' \
'headlabel="False"] ;\n' \
'}'
assert contents1 == contents2
# Test plot_options
contents1 = export_graphviz(clf, filled=True, impurity=False,
proportion=True, special_characters=True,
rounded=True, out_file=None)
contents2 = 'digraph Tree {\n' \
'node [shape=box, style="filled, rounded", color="black", ' \
'fontname=helvetica] ;\n' \
'edge [fontname=helvetica] ;\n' \
'0 [label=<X<SUB>0</SUB> &le; 0.0<br/>samples = 100.0%<br/>' \
'value = [0.5, 0.5]>, fillcolor="#ffffff"] ;\n' \
'1 [label=<samples = 50.0%<br/>value = [1.0, 0.0]>, ' \
'fillcolor="#e58139"] ;\n' \
'0 -> 1 [labeldistance=2.5, labelangle=45, ' \
'headlabel="True"] ;\n' \
'2 [label=<samples = 50.0%<br/>value = [0.0, 1.0]>, ' \
'fillcolor="#399de5"] ;\n' \
'0 -> 2 [labeldistance=2.5, labelangle=-45, ' \
'headlabel="False"] ;\n' \
'}'
assert contents1 == contents2
# Test max_depth
contents1 = export_graphviz(clf, max_depth=0,
class_names=True, out_file=None)
contents2 = 'digraph Tree {\n' \
'node [shape=box] ;\n' \
'0 [label="X[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n' \
'value = [3, 3]\\nclass = y[0]"] ;\n' \
'1 [label="(...)"] ;\n' \
'0 -> 1 ;\n' \
'2 [label="(...)"] ;\n' \
'0 -> 2 ;\n' \
'}'
assert contents1 == contents2
# Test max_depth with plot_options
contents1 = export_graphviz(clf, max_depth=0, filled=True,
out_file=None, node_ids=True)
contents2 = 'digraph Tree {\n' \
'node [shape=box, style="filled", color="black"] ;\n' \
'0 [label="node #0\\nX[0] <= 0.0\\ngini = 0.5\\n' \
'samples = 6\\nvalue = [3, 3]", fillcolor="#ffffff"] ;\n' \
'1 [label="(...)", fillcolor="#C0C0C0"] ;\n' \
'0 -> 1 ;\n' \
'2 [label="(...)", fillcolor="#C0C0C0"] ;\n' \
'0 -> 2 ;\n' \
'}'
assert contents1 == contents2
# Test multi-output with weighted samples
clf = DecisionTreeClassifier(max_depth=2,
min_samples_split=2,
criterion="gini",
random_state=2)
clf = clf.fit(X, y2, sample_weight=w)
contents1 = export_graphviz(clf, filled=True,
impurity=False, out_file=None)
contents2 = 'digraph Tree {\n' \
'node [shape=box, style="filled", color="black"] ;\n' \
'0 [label="X[0] <= 0.0\\nsamples = 6\\n' \
'value = [[3.0, 1.5, 0.0]\\n' \
'[3.0, 1.0, 0.5]]", fillcolor="#ffffff"] ;\n' \
'1 [label="samples = 3\\nvalue = [[3, 0, 0]\\n' \
'[3, 0, 0]]", fillcolor="#e58139"] ;\n' \
'0 -> 1 [labeldistance=2.5, labelangle=45, ' \
'headlabel="True"] ;\n' \
'2 [label="X[0] <= 1.5\\nsamples = 3\\n' \
'value = [[0.0, 1.5, 0.0]\\n' \
'[0.0, 1.0, 0.5]]", fillcolor="#f1bd97"] ;\n' \
'0 -> 2 [labeldistance=2.5, labelangle=-45, ' \
'headlabel="False"] ;\n' \
'3 [label="samples = 2\\nvalue = [[0, 1, 0]\\n' \
'[0, 1, 0]]", fillcolor="#e58139"] ;\n' \
'2 -> 3 ;\n' \
'4 [label="samples = 1\\nvalue = [[0.0, 0.5, 0.0]\\n' \
'[0.0, 0.0, 0.5]]", fillcolor="#e58139"] ;\n' \
'2 -> 4 ;\n' \
'}'
assert contents1 == contents2
# Test regression output with plot_options
clf = DecisionTreeRegressor(max_depth=3,
min_samples_split=2,
criterion="mse",
random_state=2)
clf.fit(X, y)
contents1 = export_graphviz(clf, filled=True, leaves_parallel=True,
out_file=None, rotate=True, rounded=True)
contents2 = 'digraph Tree {\n' \
'node [shape=box, style="filled, rounded", color="black", ' \
'fontname=helvetica] ;\n' \
'graph [ranksep=equally, splines=polyline] ;\n' \
'edge [fontname=helvetica] ;\n' \
'rankdir=LR ;\n' \
'0 [label="X[0] <= 0.0\\nmse = 1.0\\nsamples = 6\\n' \
'value = 0.0", fillcolor="#f2c09c"] ;\n' \
'1 [label="mse = 0.0\\nsamples = 3\\nvalue = -1.0", ' \
'fillcolor="#ffffff"] ;\n' \
'0 -> 1 [labeldistance=2.5, labelangle=-45, ' \
'headlabel="True"] ;\n' \
'2 [label="mse = 0.0\\nsamples = 3\\nvalue = 1.0", ' \
'fillcolor="#e58139"] ;\n' \
'0 -> 2 [labeldistance=2.5, labelangle=45, ' \
'headlabel="False"] ;\n' \
'{rank=same ; 0} ;\n' \
'{rank=same ; 1; 2} ;\n' \
'}'
assert contents1 == contents2
# Test classifier with degraded learning set
clf = DecisionTreeClassifier(max_depth=3)
clf.fit(X, y_degraded)
contents1 = export_graphviz(clf, filled=True, out_file=None)
contents2 = 'digraph Tree {\n' \
'node [shape=box, style="filled", color="black"] ;\n' \
'0 [label="gini = 0.0\\nsamples = 6\\nvalue = 6.0", ' \
'fillcolor="#ffffff"] ;\n' \
'}'
def test_graphviz_errors():
# Check for errors of export_graphviz
clf = DecisionTreeClassifier(max_depth=3, min_samples_split=2)
# Check not-fitted decision tree error
out = StringIO()
with pytest.raises(NotFittedError):
export_graphviz(clf, out)
clf.fit(X, y)
# Check if it errors when length of feature_names
# mismatches with number of features
message = ("Length of feature_names, "
"1 does not match number of features, 2")
with pytest.raises(ValueError, match=message):
export_graphviz(clf, None, feature_names=["a"])
message = ("Length of feature_names, "
"3 does not match number of features, 2")
with pytest.raises(ValueError, match=message):
export_graphviz(clf, None, feature_names=["a", "b", "c"])
# Check error when argument is not an estimator
message = "is not an estimator instance"
with pytest.raises(TypeError, match=message):
export_graphviz(clf.fit(X, y).tree_)
# Check class_names error
out = StringIO()
with pytest.raises(IndexError):
export_graphviz(clf, out, class_names=[])
# Check precision error
out = StringIO()
with pytest.raises(ValueError, match="should be greater or equal"):
export_graphviz(clf, out, precision=-1)
with pytest.raises(ValueError, match="should be an integer"):
export_graphviz(clf, out, precision="1")
def test_friedman_mse_in_graphviz():
clf = DecisionTreeRegressor(criterion="friedman_mse", random_state=0)
clf.fit(X, y)
dot_data = StringIO()
export_graphviz(clf, out_file=dot_data)
clf = GradientBoostingClassifier(n_estimators=2, random_state=0)
clf.fit(X, y)
for estimator in clf.estimators_:
export_graphviz(estimator[0], out_file=dot_data)
for finding in finditer(r"\[.*?samples.*?\]", dot_data.getvalue()):
assert "friedman_mse" in finding.group()
def test_precision():
rng_reg = RandomState(2)
rng_clf = RandomState(8)
for X, y, clf in zip(
(rng_reg.random_sample((5, 2)),
rng_clf.random_sample((1000, 4))),
(rng_reg.random_sample((5, )),
rng_clf.randint(2, size=(1000, ))),
(DecisionTreeRegressor(criterion="friedman_mse", random_state=0,
max_depth=1),
DecisionTreeClassifier(max_depth=1, random_state=0))):
clf.fit(X, y)
for precision in (4, 3):
dot_data = export_graphviz(clf, out_file=None, precision=precision,
proportion=True)
# With the current random state, the impurity and the threshold
# will have the number of precision set in the export_graphviz
# function. We will check the number of precision with a strict
# equality. The value reported will have only 2 precision and
# therefore, only a less equal comparison will be done.
# check value
for finding in finditer(r"value = \d+\.\d+", dot_data):
assert (
len(search(r"\.\d+", finding.group()).group()) <=
precision + 1)
# check impurity
if is_classifier(clf):
pattern = r"gini = \d+\.\d+"
else:
pattern = r"friedman_mse = \d+\.\d+"
# check impurity
for finding in finditer(pattern, dot_data):
assert (len(search(r"\.\d+", finding.group()).group()) ==
precision + 1)
# check threshold
for finding in finditer(r"<= \d+\.\d+", dot_data):
assert (len(search(r"\.\d+", finding.group()).group()) ==
precision + 1)
def test_export_text_errors():
clf = DecisionTreeClassifier(max_depth=2, random_state=0)
clf.fit(X, y)
err_msg = "max_depth bust be >= 0, given -1"
with pytest.raises(ValueError, match=err_msg):
export_text(clf, max_depth=-1)
err_msg = "feature_names must contain 2 elements, got 1"
with pytest.raises(ValueError, match=err_msg):
export_text(clf, feature_names=['a'])
err_msg = "decimals must be >= 0, given -1"
with pytest.raises(ValueError, match=err_msg):
export_text(clf, decimals=-1)
err_msg = "spacing must be > 0, given 0"
with pytest.raises(ValueError, match=err_msg):
export_text(clf, spacing=0)
def test_export_text():
clf = DecisionTreeClassifier(max_depth=2, random_state=0)
clf.fit(X, y)
expected_report = dedent("""
|--- feature_1 <= 0.00
| |--- class: -1
|--- feature_1 > 0.00
| |--- class: 1
""").lstrip()
assert export_text(clf) == expected_report
# testing that leaves at level 1 are not truncated
assert export_text(clf, max_depth=0) == expected_report
# testing that the rest of the tree is truncated
assert export_text(clf, max_depth=10) == expected_report
expected_report = dedent("""
|--- b <= 0.00
| |--- class: -1
|--- b > 0.00
| |--- class: 1
""").lstrip()
assert export_text(clf, feature_names=['a', 'b']) == expected_report
expected_report = dedent("""
|--- feature_1 <= 0.00
| |--- weights: [3.00, 0.00] class: -1
|--- feature_1 > 0.00
| |--- weights: [0.00, 3.00] class: 1
""").lstrip()
assert export_text(clf, show_weights=True) == expected_report
expected_report = dedent("""
|- feature_1 <= 0.00
| |- class: -1
|- feature_1 > 0.00
| |- class: 1
""").lstrip()
assert export_text(clf, spacing=1) == expected_report
X_l = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1], [-1, 1]]
y_l = [-1, -1, -1, 1, 1, 1, 2]
clf = DecisionTreeClassifier(max_depth=4, random_state=0)
clf.fit(X_l, y_l)
expected_report = dedent("""
|--- feature_1 <= 0.00
| |--- class: -1
|--- feature_1 > 0.00
| |--- truncated branch of depth 2
""").lstrip()
assert export_text(clf, max_depth=0) == expected_report
X_mo = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]
y_mo = [[-1, -1], [-1, -1], [-1, -1], [1, 1], [1, 1], [1, 1]]
reg = DecisionTreeRegressor(max_depth=2, random_state=0)
reg.fit(X_mo, y_mo)
expected_report = dedent("""
|--- feature_1 <= 0.0
| |--- value: [-1.0, -1.0]
|--- feature_1 > 0.0
| |--- value: [1.0, 1.0]
""").lstrip()
assert export_text(reg, decimals=1) == expected_report
assert export_text(reg, decimals=1, show_weights=True) == expected_report
X_single = [[-2], [-1], [-1], [1], [1], [2]]
reg = DecisionTreeRegressor(max_depth=2, random_state=0)
reg.fit(X_single, y_mo)
expected_report = dedent("""
|--- first <= 0.0
| |--- value: [-1.0, -1.0]
|--- first > 0.0
| |--- value: [1.0, 1.0]
""").lstrip()
assert export_text(reg, decimals=1,
feature_names=['first']) == expected_report
assert export_text(reg, decimals=1, show_weights=True,
feature_names=['first']) == expected_report
def test_plot_tree_entropy(pyplot):
# mostly smoke tests
# Check correctness of export_graphviz for criterion = entropy
clf = DecisionTreeClassifier(max_depth=3,
min_samples_split=2,
criterion="entropy",
random_state=2)
clf.fit(X, y)
# Test export code
feature_names = ['first feat', 'sepal_width']
nodes = plot_tree(clf, feature_names=feature_names)
assert len(nodes) == 3
assert nodes[0].get_text() == ("first feat <= 0.0\nentropy = 1.0\n"
"samples = 6\nvalue = [3, 3]")
assert nodes[1].get_text() == "entropy = 0.0\nsamples = 3\nvalue = [3, 0]"
assert nodes[2].get_text() == "entropy = 0.0\nsamples = 3\nvalue = [0, 3]"
def test_plot_tree_gini(pyplot):
# mostly smoke tests
# Check correctness of export_graphviz for criterion = gini
clf = DecisionTreeClassifier(max_depth=3,
min_samples_split=2,
criterion="gini",
random_state=2)
clf.fit(X, y)
# Test export code
feature_names = ['first feat', 'sepal_width']
nodes = plot_tree(clf, feature_names=feature_names)
assert len(nodes) == 3
assert nodes[0].get_text() == ("first feat <= 0.0\ngini = 0.5\n"
"samples = 6\nvalue = [3, 3]")
assert nodes[1].get_text() == "gini = 0.0\nsamples = 3\nvalue = [3, 0]"
assert nodes[2].get_text() == "gini = 0.0\nsamples = 3\nvalue = [0, 3]"
# FIXME: to be removed in 0.25
def test_plot_tree_rotate_deprecation(pyplot):
tree = DecisionTreeClassifier()
tree.fit(X, y)
# test that a warning is raised when rotate is used.
match = ("'rotate' has no effect and is deprecated in 0.23. "
"It will be removed in 0.25.")
with pytest.warns(FutureWarning, match=match):
plot_tree(tree, rotate=True)
def test_not_fitted_tree(pyplot):
# Testing if not fitted tree throws the correct error
clf = DecisionTreeRegressor()
with pytest.raises(NotFittedError):
plot_tree(clf)

View file

@ -0,0 +1,52 @@
import numpy as np
import pytest
from sklearn.tree._reingold_tilford import buchheim, Tree
simple_tree = Tree("", 0,
Tree("", 1),
Tree("", 2))
bigger_tree = Tree("", 0,
Tree("", 1,
Tree("", 3),
Tree("", 4,
Tree("", 7),
Tree("", 8)
),
),
Tree("", 2,
Tree("", 5),
Tree("", 6)
)
)
@pytest.mark.parametrize("tree, n_nodes", [(simple_tree, 3), (bigger_tree, 9)])
def test_buchheim(tree, n_nodes):
def walk_tree(draw_tree):
res = [(draw_tree.x, draw_tree.y)]
for child in draw_tree.children:
# parents higher than children:
assert child.y == draw_tree.y + 1
res.extend(walk_tree(child))
if len(draw_tree.children):
# these trees are always binary
# parents are centered above children
assert draw_tree.x == (draw_tree.children[0].x
+ draw_tree.children[1].x) / 2
return res
layout = buchheim(tree)
coordinates = walk_tree(layout)
assert len(coordinates) == n_nodes
# test that x values are unique per depth / level
# we could also do it quicker using defaultdicts..
depth = 0
while True:
x_at_this_depth = [node[0] for node in coordinates
if node[1] == depth]
if not x_at_this_depth:
# reached all leafs
break
assert len(np.unique(x_at_this_depth)) == len(x_at_this_depth)
depth += 1

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,18 @@
# THIS FILE WAS AUTOMATICALLY GENERATED BY deprecated_modules.py
import sys
# mypy error: Module X has no attribute y (typically for C extensions)
from . import _classes # type: ignore
from ..externals._pep562 import Pep562
from ..utils.deprecation import _raise_dep_warning_if_not_pytest
deprecated_path = 'sklearn.tree.tree'
correct_import_path = 'sklearn.tree'
_raise_dep_warning_if_not_pytest(deprecated_path, correct_import_path)
def __getattr__(name):
return getattr(_classes, name)
if not sys.version_info >= (3, 7):
Pep562(__name__)