Uploaded Test files
This commit is contained in:
parent
f584ad9d97
commit
2e81cb7d99
16627 changed files with 2065359 additions and 102444 deletions
16
venv/Lib/site-packages/sklearn/tree/__init__.py
Normal file
16
venv/Lib/site-packages/sklearn/tree/__init__.py
Normal file
|
@ -0,0 +1,16 @@
|
|||
"""
|
||||
The :mod:`sklearn.tree` module includes decision tree-based models for
|
||||
classification and regression.
|
||||
"""
|
||||
|
||||
from ._classes import BaseDecisionTree
|
||||
from ._classes import DecisionTreeClassifier
|
||||
from ._classes import DecisionTreeRegressor
|
||||
from ._classes import ExtraTreeClassifier
|
||||
from ._classes import ExtraTreeRegressor
|
||||
from ._export import export_graphviz, plot_tree, export_text
|
||||
|
||||
__all__ = ["BaseDecisionTree",
|
||||
"DecisionTreeClassifier", "DecisionTreeRegressor",
|
||||
"ExtraTreeClassifier", "ExtraTreeRegressor", "export_graphviz",
|
||||
"plot_tree", "export_text"]
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
1750
venv/Lib/site-packages/sklearn/tree/_classes.py
Normal file
1750
venv/Lib/site-packages/sklearn/tree/_classes.py
Normal file
File diff suppressed because it is too large
Load diff
BIN
venv/Lib/site-packages/sklearn/tree/_criterion.cp36-win32.pyd
Normal file
BIN
venv/Lib/site-packages/sklearn/tree/_criterion.cp36-win32.pyd
Normal file
Binary file not shown.
77
venv/Lib/site-packages/sklearn/tree/_criterion.pxd
Normal file
77
venv/Lib/site-packages/sklearn/tree/_criterion.pxd
Normal file
|
@ -0,0 +1,77 @@
|
|||
# Authors: Gilles Louppe <g.louppe@gmail.com>
|
||||
# Peter Prettenhofer <peter.prettenhofer@gmail.com>
|
||||
# Brian Holt <bdholt1@gmail.com>
|
||||
# Joel Nothman <joel.nothman@gmail.com>
|
||||
# Arnaud Joly <arnaud.v.joly@gmail.com>
|
||||
# Jacob Schreiber <jmschreiber91@gmail.com>
|
||||
#
|
||||
# License: BSD 3 clause
|
||||
|
||||
# See _criterion.pyx for implementation details.
|
||||
|
||||
import numpy as np
|
||||
cimport numpy as np
|
||||
|
||||
from ._tree cimport DTYPE_t # Type of X
|
||||
from ._tree cimport DOUBLE_t # Type of y, sample_weight
|
||||
from ._tree cimport SIZE_t # Type for indices and counters
|
||||
from ._tree cimport INT32_t # Signed 32 bit integer
|
||||
from ._tree cimport UINT32_t # Unsigned 32 bit integer
|
||||
|
||||
cdef class Criterion:
|
||||
# The criterion computes the impurity of a node and the reduction of
|
||||
# impurity of a split on that node. It also computes the output statistics
|
||||
# such as the mean in regression and class probabilities in classification.
|
||||
|
||||
# Internal structures
|
||||
cdef const DOUBLE_t[:, ::1] y # Values of y
|
||||
cdef DOUBLE_t* sample_weight # Sample weights
|
||||
|
||||
cdef SIZE_t* samples # Sample indices in X, y
|
||||
cdef SIZE_t start # samples[start:pos] are the samples in the left node
|
||||
cdef SIZE_t pos # samples[pos:end] are the samples in the right node
|
||||
cdef SIZE_t end
|
||||
|
||||
cdef SIZE_t n_outputs # Number of outputs
|
||||
cdef SIZE_t n_samples # Number of samples
|
||||
cdef SIZE_t n_node_samples # Number of samples in the node (end-start)
|
||||
cdef double weighted_n_samples # Weighted number of samples (in total)
|
||||
cdef double weighted_n_node_samples # Weighted number of samples in the node
|
||||
cdef double weighted_n_left # Weighted number of samples in the left node
|
||||
cdef double weighted_n_right # Weighted number of samples in the right node
|
||||
|
||||
cdef double* sum_total # For classification criteria, the sum of the
|
||||
# weighted count of each label. For regression,
|
||||
# the sum of w*y. sum_total[k] is equal to
|
||||
# sum_{i=start}^{end-1} w[samples[i]]*y[samples[i], k],
|
||||
# where k is output index.
|
||||
cdef double* sum_left # Same as above, but for the left side of the split
|
||||
cdef double* sum_right # same as above, but for the right side of the split
|
||||
|
||||
# The criterion object is maintained such that left and right collected
|
||||
# statistics correspond to samples[start:pos] and samples[pos:end].
|
||||
|
||||
# Methods
|
||||
cdef int init(self, const DOUBLE_t[:, ::1] y, DOUBLE_t* sample_weight,
|
||||
double weighted_n_samples, SIZE_t* samples, SIZE_t start,
|
||||
SIZE_t end) nogil except -1
|
||||
cdef int reset(self) nogil except -1
|
||||
cdef int reverse_reset(self) nogil except -1
|
||||
cdef int update(self, SIZE_t new_pos) nogil except -1
|
||||
cdef double node_impurity(self) nogil
|
||||
cdef void children_impurity(self, double* impurity_left,
|
||||
double* impurity_right) nogil
|
||||
cdef void node_value(self, double* dest) nogil
|
||||
cdef double impurity_improvement(self, double impurity) nogil
|
||||
cdef double proxy_impurity_improvement(self) nogil
|
||||
|
||||
cdef class ClassificationCriterion(Criterion):
|
||||
"""Abstract criterion for classification."""
|
||||
|
||||
cdef SIZE_t* n_classes
|
||||
cdef SIZE_t sum_stride
|
||||
|
||||
cdef class RegressionCriterion(Criterion):
|
||||
"""Abstract regression criterion."""
|
||||
|
||||
cdef double sq_sum_total
|
967
venv/Lib/site-packages/sklearn/tree/_export.py
Normal file
967
venv/Lib/site-packages/sklearn/tree/_export.py
Normal file
|
@ -0,0 +1,967 @@
|
|||
"""
|
||||
This module defines export functions for decision trees.
|
||||
"""
|
||||
|
||||
# Authors: Gilles Louppe <g.louppe@gmail.com>
|
||||
# Peter Prettenhofer <peter.prettenhofer@gmail.com>
|
||||
# Brian Holt <bdholt1@gmail.com>
|
||||
# Noel Dawe <noel@dawe.me>
|
||||
# Satrajit Gosh <satrajit.ghosh@gmail.com>
|
||||
# Trevor Stephens <trev.stephens@gmail.com>
|
||||
# Li Li <aiki.nogard@gmail.com>
|
||||
# Giuseppe Vettigli <vettigli@gmail.com>
|
||||
# License: BSD 3 clause
|
||||
from io import StringIO
|
||||
from numbers import Integral
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..utils.validation import check_is_fitted
|
||||
from ..utils.validation import _deprecate_positional_args
|
||||
from ..base import is_classifier
|
||||
|
||||
from . import _criterion
|
||||
from . import _tree
|
||||
from ._reingold_tilford import buchheim, Tree
|
||||
from . import DecisionTreeClassifier
|
||||
|
||||
import warnings
|
||||
|
||||
|
||||
def _color_brew(n):
|
||||
"""Generate n colors with equally spaced hues.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
n : int
|
||||
The number of colors required.
|
||||
|
||||
Returns
|
||||
-------
|
||||
color_list : list, length n
|
||||
List of n tuples of form (R, G, B) being the components of each color.
|
||||
"""
|
||||
color_list = []
|
||||
|
||||
# Initialize saturation & value; calculate chroma & value shift
|
||||
s, v = 0.75, 0.9
|
||||
c = s * v
|
||||
m = v - c
|
||||
|
||||
for h in np.arange(25, 385, 360. / n).astype(int):
|
||||
# Calculate some intermediate values
|
||||
h_bar = h / 60.
|
||||
x = c * (1 - abs((h_bar % 2) - 1))
|
||||
# Initialize RGB with same hue & chroma as our color
|
||||
rgb = [(c, x, 0),
|
||||
(x, c, 0),
|
||||
(0, c, x),
|
||||
(0, x, c),
|
||||
(x, 0, c),
|
||||
(c, 0, x),
|
||||
(c, x, 0)]
|
||||
r, g, b = rgb[int(h_bar)]
|
||||
# Shift the initial RGB values to match value and store
|
||||
rgb = [(int(255 * (r + m))),
|
||||
(int(255 * (g + m))),
|
||||
(int(255 * (b + m)))]
|
||||
color_list.append(rgb)
|
||||
|
||||
return color_list
|
||||
|
||||
|
||||
class Sentinel:
|
||||
def __repr__(self):
|
||||
return '"tree.dot"'
|
||||
|
||||
|
||||
SENTINEL = Sentinel()
|
||||
|
||||
|
||||
@_deprecate_positional_args
|
||||
def plot_tree(decision_tree, *, max_depth=None, feature_names=None,
|
||||
class_names=None, label='all', filled=False,
|
||||
impurity=True, node_ids=False,
|
||||
proportion=False, rotate='deprecated', rounded=False,
|
||||
precision=3, ax=None, fontsize=None):
|
||||
"""Plot a decision tree.
|
||||
|
||||
The sample counts that are shown are weighted with any sample_weights that
|
||||
might be present.
|
||||
|
||||
The visualization is fit automatically to the size of the axis.
|
||||
Use the ``figsize`` or ``dpi`` arguments of ``plt.figure`` to control
|
||||
the size of the rendering.
|
||||
|
||||
Read more in the :ref:`User Guide <tree>`.
|
||||
|
||||
.. versionadded:: 0.21
|
||||
|
||||
Parameters
|
||||
----------
|
||||
decision_tree : decision tree regressor or classifier
|
||||
The decision tree to be plotted.
|
||||
|
||||
max_depth : int, optional (default=None)
|
||||
The maximum depth of the representation. If None, the tree is fully
|
||||
generated.
|
||||
|
||||
feature_names : list of strings, optional (default=None)
|
||||
Names of each of the features.
|
||||
|
||||
class_names : list of strings, bool or None, optional (default=None)
|
||||
Names of each of the target classes in ascending numerical order.
|
||||
Only relevant for classification and not supported for multi-output.
|
||||
If ``True``, shows a symbolic representation of the class name.
|
||||
|
||||
label : {'all', 'root', 'none'}, optional (default='all')
|
||||
Whether to show informative labels for impurity, etc.
|
||||
Options include 'all' to show at every node, 'root' to show only at
|
||||
the top root node, or 'none' to not show at any node.
|
||||
|
||||
filled : bool, optional (default=False)
|
||||
When set to ``True``, paint nodes to indicate majority class for
|
||||
classification, extremity of values for regression, or purity of node
|
||||
for multi-output.
|
||||
|
||||
impurity : bool, optional (default=True)
|
||||
When set to ``True``, show the impurity at each node.
|
||||
|
||||
node_ids : bool, optional (default=False)
|
||||
When set to ``True``, show the ID number on each node.
|
||||
|
||||
proportion : bool, optional (default=False)
|
||||
When set to ``True``, change the display of 'values' and/or 'samples'
|
||||
to be proportions and percentages respectively.
|
||||
|
||||
rotate : bool, optional (default=False)
|
||||
This parameter has no effect on the matplotlib tree visualisation and
|
||||
it is kept here for backward compatibility.
|
||||
|
||||
.. deprecated:: 0.23
|
||||
``rotate`` is deprecated in 0.23 and will be removed in 0.25.
|
||||
|
||||
|
||||
rounded : bool, optional (default=False)
|
||||
When set to ``True``, draw node boxes with rounded corners and use
|
||||
Helvetica fonts instead of Times-Roman.
|
||||
|
||||
precision : int, optional (default=3)
|
||||
Number of digits of precision for floating point in the values of
|
||||
impurity, threshold and value attributes of each node.
|
||||
|
||||
ax : matplotlib axis, optional (default=None)
|
||||
Axes to plot to. If None, use current axis. Any previous content
|
||||
is cleared.
|
||||
|
||||
fontsize : int, optional (default=None)
|
||||
Size of text font. If None, determined automatically to fit figure.
|
||||
|
||||
Returns
|
||||
-------
|
||||
annotations : list of artists
|
||||
List containing the artists for the annotation boxes making up the
|
||||
tree.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> from sklearn.datasets import load_iris
|
||||
>>> from sklearn import tree
|
||||
|
||||
>>> clf = tree.DecisionTreeClassifier(random_state=0)
|
||||
>>> iris = load_iris()
|
||||
|
||||
>>> clf = clf.fit(iris.data, iris.target)
|
||||
>>> tree.plot_tree(clf) # doctest: +SKIP
|
||||
[Text(251.5,345.217,'X[3] <= 0.8...
|
||||
|
||||
"""
|
||||
|
||||
check_is_fitted(decision_tree)
|
||||
|
||||
if rotate != 'deprecated':
|
||||
warnings.warn(("'rotate' has no effect and is deprecated in 0.23. "
|
||||
"It will be removed in 0.25."),
|
||||
FutureWarning)
|
||||
|
||||
exporter = _MPLTreeExporter(
|
||||
max_depth=max_depth, feature_names=feature_names,
|
||||
class_names=class_names, label=label, filled=filled,
|
||||
impurity=impurity, node_ids=node_ids,
|
||||
proportion=proportion, rotate=rotate, rounded=rounded,
|
||||
precision=precision, fontsize=fontsize)
|
||||
return exporter.export(decision_tree, ax=ax)
|
||||
|
||||
|
||||
class _BaseTreeExporter:
|
||||
def __init__(self, max_depth=None, feature_names=None,
|
||||
class_names=None, label='all', filled=False,
|
||||
impurity=True, node_ids=False,
|
||||
proportion=False, rotate=False, rounded=False,
|
||||
precision=3, fontsize=None):
|
||||
self.max_depth = max_depth
|
||||
self.feature_names = feature_names
|
||||
self.class_names = class_names
|
||||
self.label = label
|
||||
self.filled = filled
|
||||
self.impurity = impurity
|
||||
self.node_ids = node_ids
|
||||
self.proportion = proportion
|
||||
self.rotate = rotate
|
||||
self.rounded = rounded
|
||||
self.precision = precision
|
||||
self.fontsize = fontsize
|
||||
|
||||
def get_color(self, value):
|
||||
# Find the appropriate color & intensity for a node
|
||||
if self.colors['bounds'] is None:
|
||||
# Classification tree
|
||||
color = list(self.colors['rgb'][np.argmax(value)])
|
||||
sorted_values = sorted(value, reverse=True)
|
||||
if len(sorted_values) == 1:
|
||||
alpha = 0
|
||||
else:
|
||||
alpha = ((sorted_values[0] - sorted_values[1])
|
||||
/ (1 - sorted_values[1]))
|
||||
else:
|
||||
# Regression tree or multi-output
|
||||
color = list(self.colors['rgb'][0])
|
||||
alpha = ((value - self.colors['bounds'][0]) /
|
||||
(self.colors['bounds'][1] - self.colors['bounds'][0]))
|
||||
# unpack numpy scalars
|
||||
alpha = float(alpha)
|
||||
# compute the color as alpha against white
|
||||
color = [int(round(alpha * c + (1 - alpha) * 255, 0)) for c in color]
|
||||
# Return html color code in #RRGGBB format
|
||||
return '#%2x%2x%2x' % tuple(color)
|
||||
|
||||
def get_fill_color(self, tree, node_id):
|
||||
# Fetch appropriate color for node
|
||||
if 'rgb' not in self.colors:
|
||||
# Initialize colors and bounds if required
|
||||
self.colors['rgb'] = _color_brew(tree.n_classes[0])
|
||||
if tree.n_outputs != 1:
|
||||
# Find max and min impurities for multi-output
|
||||
self.colors['bounds'] = (np.min(-tree.impurity),
|
||||
np.max(-tree.impurity))
|
||||
elif (tree.n_classes[0] == 1 and
|
||||
len(np.unique(tree.value)) != 1):
|
||||
# Find max and min values in leaf nodes for regression
|
||||
self.colors['bounds'] = (np.min(tree.value),
|
||||
np.max(tree.value))
|
||||
if tree.n_outputs == 1:
|
||||
node_val = (tree.value[node_id][0, :] /
|
||||
tree.weighted_n_node_samples[node_id])
|
||||
if tree.n_classes[0] == 1:
|
||||
# Regression
|
||||
node_val = tree.value[node_id][0, :]
|
||||
else:
|
||||
# If multi-output color node by impurity
|
||||
node_val = -tree.impurity[node_id]
|
||||
return self.get_color(node_val)
|
||||
|
||||
def node_to_str(self, tree, node_id, criterion):
|
||||
# Generate the node content string
|
||||
if tree.n_outputs == 1:
|
||||
value = tree.value[node_id][0, :]
|
||||
else:
|
||||
value = tree.value[node_id]
|
||||
|
||||
# Should labels be shown?
|
||||
labels = (self.label == 'root' and node_id == 0) or self.label == 'all'
|
||||
|
||||
characters = self.characters
|
||||
node_string = characters[-1]
|
||||
|
||||
# Write node ID
|
||||
if self.node_ids:
|
||||
if labels:
|
||||
node_string += 'node '
|
||||
node_string += characters[0] + str(node_id) + characters[4]
|
||||
|
||||
# Write decision criteria
|
||||
if tree.children_left[node_id] != _tree.TREE_LEAF:
|
||||
# Always write node decision criteria, except for leaves
|
||||
if self.feature_names is not None:
|
||||
feature = self.feature_names[tree.feature[node_id]]
|
||||
else:
|
||||
feature = "X%s%s%s" % (characters[1],
|
||||
tree.feature[node_id],
|
||||
characters[2])
|
||||
node_string += '%s %s %s%s' % (feature,
|
||||
characters[3],
|
||||
round(tree.threshold[node_id],
|
||||
self.precision),
|
||||
characters[4])
|
||||
|
||||
# Write impurity
|
||||
if self.impurity:
|
||||
if isinstance(criterion, _criterion.FriedmanMSE):
|
||||
criterion = "friedman_mse"
|
||||
elif not isinstance(criterion, str):
|
||||
criterion = "impurity"
|
||||
if labels:
|
||||
node_string += '%s = ' % criterion
|
||||
node_string += (str(round(tree.impurity[node_id], self.precision))
|
||||
+ characters[4])
|
||||
|
||||
# Write node sample count
|
||||
if labels:
|
||||
node_string += 'samples = '
|
||||
if self.proportion:
|
||||
percent = (100. * tree.n_node_samples[node_id] /
|
||||
float(tree.n_node_samples[0]))
|
||||
node_string += (str(round(percent, 1)) + '%' +
|
||||
characters[4])
|
||||
else:
|
||||
node_string += (str(tree.n_node_samples[node_id]) +
|
||||
characters[4])
|
||||
|
||||
# Write node class distribution / regression value
|
||||
if self.proportion and tree.n_classes[0] != 1:
|
||||
# For classification this will show the proportion of samples
|
||||
value = value / tree.weighted_n_node_samples[node_id]
|
||||
if labels:
|
||||
node_string += 'value = '
|
||||
if tree.n_classes[0] == 1:
|
||||
# Regression
|
||||
value_text = np.around(value, self.precision)
|
||||
elif self.proportion:
|
||||
# Classification
|
||||
value_text = np.around(value, self.precision)
|
||||
elif np.all(np.equal(np.mod(value, 1), 0)):
|
||||
# Classification without floating-point weights
|
||||
value_text = value.astype(int)
|
||||
else:
|
||||
# Classification with floating-point weights
|
||||
value_text = np.around(value, self.precision)
|
||||
# Strip whitespace
|
||||
value_text = str(value_text.astype('S32')).replace("b'", "'")
|
||||
value_text = value_text.replace("' '", ", ").replace("'", "")
|
||||
if tree.n_classes[0] == 1 and tree.n_outputs == 1:
|
||||
value_text = value_text.replace("[", "").replace("]", "")
|
||||
value_text = value_text.replace("\n ", characters[4])
|
||||
node_string += value_text + characters[4]
|
||||
|
||||
# Write node majority class
|
||||
if (self.class_names is not None and
|
||||
tree.n_classes[0] != 1 and
|
||||
tree.n_outputs == 1):
|
||||
# Only done for single-output classification trees
|
||||
if labels:
|
||||
node_string += 'class = '
|
||||
if self.class_names is not True:
|
||||
class_name = self.class_names[np.argmax(value)]
|
||||
else:
|
||||
class_name = "y%s%s%s" % (characters[1],
|
||||
np.argmax(value),
|
||||
characters[2])
|
||||
node_string += class_name
|
||||
|
||||
# Clean up any trailing newlines
|
||||
if node_string.endswith(characters[4]):
|
||||
node_string = node_string[:-len(characters[4])]
|
||||
|
||||
return node_string + characters[5]
|
||||
|
||||
|
||||
class _DOTTreeExporter(_BaseTreeExporter):
|
||||
def __init__(self, out_file=SENTINEL, max_depth=None,
|
||||
feature_names=None, class_names=None, label='all',
|
||||
filled=False, leaves_parallel=False, impurity=True,
|
||||
node_ids=False, proportion=False, rotate=False, rounded=False,
|
||||
special_characters=False, precision=3):
|
||||
|
||||
super().__init__(
|
||||
max_depth=max_depth, feature_names=feature_names,
|
||||
class_names=class_names, label=label, filled=filled,
|
||||
impurity=impurity,
|
||||
node_ids=node_ids, proportion=proportion, rotate=rotate,
|
||||
rounded=rounded,
|
||||
precision=precision)
|
||||
self.leaves_parallel = leaves_parallel
|
||||
self.out_file = out_file
|
||||
self.special_characters = special_characters
|
||||
|
||||
# PostScript compatibility for special characters
|
||||
if special_characters:
|
||||
self.characters = ['#', '<SUB>', '</SUB>', '≤', '<br/>',
|
||||
'>', '<']
|
||||
else:
|
||||
self.characters = ['#', '[', ']', '<=', '\\n', '"', '"']
|
||||
|
||||
# validate
|
||||
if isinstance(precision, Integral):
|
||||
if precision < 0:
|
||||
raise ValueError("'precision' should be greater or equal to 0."
|
||||
" Got {} instead.".format(precision))
|
||||
else:
|
||||
raise ValueError("'precision' should be an integer. Got {}"
|
||||
" instead.".format(type(precision)))
|
||||
|
||||
# The depth of each node for plotting with 'leaf' option
|
||||
self.ranks = {'leaves': []}
|
||||
# The colors to render each node with
|
||||
self.colors = {'bounds': None}
|
||||
|
||||
def export(self, decision_tree):
|
||||
# Check length of feature_names before getting into the tree node
|
||||
# Raise error if length of feature_names does not match
|
||||
# n_features_ in the decision_tree
|
||||
if self.feature_names is not None:
|
||||
if len(self.feature_names) != decision_tree.n_features_:
|
||||
raise ValueError("Length of feature_names, %d "
|
||||
"does not match number of features, %d"
|
||||
% (len(self.feature_names),
|
||||
decision_tree.n_features_))
|
||||
# each part writes to out_file
|
||||
self.head()
|
||||
# Now recurse the tree and add node & edge attributes
|
||||
if isinstance(decision_tree, _tree.Tree):
|
||||
self.recurse(decision_tree, 0, criterion="impurity")
|
||||
else:
|
||||
self.recurse(decision_tree.tree_, 0,
|
||||
criterion=decision_tree.criterion)
|
||||
|
||||
self.tail()
|
||||
|
||||
def tail(self):
|
||||
# If required, draw leaf nodes at same depth as each other
|
||||
if self.leaves_parallel:
|
||||
for rank in sorted(self.ranks):
|
||||
self.out_file.write(
|
||||
"{rank=same ; " +
|
||||
"; ".join(r for r in self.ranks[rank]) + "} ;\n")
|
||||
self.out_file.write("}")
|
||||
|
||||
def head(self):
|
||||
self.out_file.write('digraph Tree {\n')
|
||||
|
||||
# Specify node aesthetics
|
||||
self.out_file.write('node [shape=box')
|
||||
rounded_filled = []
|
||||
if self.filled:
|
||||
rounded_filled.append('filled')
|
||||
if self.rounded:
|
||||
rounded_filled.append('rounded')
|
||||
if len(rounded_filled) > 0:
|
||||
self.out_file.write(
|
||||
', style="%s", color="black"'
|
||||
% ", ".join(rounded_filled))
|
||||
if self.rounded:
|
||||
self.out_file.write(', fontname=helvetica')
|
||||
self.out_file.write('] ;\n')
|
||||
|
||||
# Specify graph & edge aesthetics
|
||||
if self.leaves_parallel:
|
||||
self.out_file.write(
|
||||
'graph [ranksep=equally, splines=polyline] ;\n')
|
||||
if self.rounded:
|
||||
self.out_file.write('edge [fontname=helvetica] ;\n')
|
||||
if self.rotate:
|
||||
self.out_file.write('rankdir=LR ;\n')
|
||||
|
||||
def recurse(self, tree, node_id, criterion, parent=None, depth=0):
|
||||
if node_id == _tree.TREE_LEAF:
|
||||
raise ValueError("Invalid node_id %s" % _tree.TREE_LEAF)
|
||||
|
||||
left_child = tree.children_left[node_id]
|
||||
right_child = tree.children_right[node_id]
|
||||
|
||||
# Add node with description
|
||||
if self.max_depth is None or depth <= self.max_depth:
|
||||
|
||||
# Collect ranks for 'leaf' option in plot_options
|
||||
if left_child == _tree.TREE_LEAF:
|
||||
self.ranks['leaves'].append(str(node_id))
|
||||
elif str(depth) not in self.ranks:
|
||||
self.ranks[str(depth)] = [str(node_id)]
|
||||
else:
|
||||
self.ranks[str(depth)].append(str(node_id))
|
||||
|
||||
self.out_file.write(
|
||||
'%d [label=%s' % (node_id, self.node_to_str(tree, node_id,
|
||||
criterion)))
|
||||
|
||||
if self.filled:
|
||||
self.out_file.write(', fillcolor="%s"'
|
||||
% self.get_fill_color(tree, node_id))
|
||||
self.out_file.write('] ;\n')
|
||||
|
||||
if parent is not None:
|
||||
# Add edge to parent
|
||||
self.out_file.write('%d -> %d' % (parent, node_id))
|
||||
if parent == 0:
|
||||
# Draw True/False labels if parent is root node
|
||||
angles = np.array([45, -45]) * ((self.rotate - .5) * -2)
|
||||
self.out_file.write(' [labeldistance=2.5, labelangle=')
|
||||
if node_id == 1:
|
||||
self.out_file.write('%d, headlabel="True"]' %
|
||||
angles[0])
|
||||
else:
|
||||
self.out_file.write('%d, headlabel="False"]' %
|
||||
angles[1])
|
||||
self.out_file.write(' ;\n')
|
||||
|
||||
if left_child != _tree.TREE_LEAF:
|
||||
self.recurse(tree, left_child, criterion=criterion,
|
||||
parent=node_id, depth=depth + 1)
|
||||
self.recurse(tree, right_child, criterion=criterion,
|
||||
parent=node_id, depth=depth + 1)
|
||||
|
||||
else:
|
||||
self.ranks['leaves'].append(str(node_id))
|
||||
|
||||
self.out_file.write('%d [label="(...)"' % node_id)
|
||||
if self.filled:
|
||||
# color cropped nodes grey
|
||||
self.out_file.write(', fillcolor="#C0C0C0"')
|
||||
self.out_file.write('] ;\n' % node_id)
|
||||
|
||||
if parent is not None:
|
||||
# Add edge to parent
|
||||
self.out_file.write('%d -> %d ;\n' % (parent, node_id))
|
||||
|
||||
|
||||
class _MPLTreeExporter(_BaseTreeExporter):
|
||||
def __init__(self, max_depth=None, feature_names=None,
|
||||
class_names=None, label='all', filled=False,
|
||||
impurity=True, node_ids=False,
|
||||
proportion=False, rotate=False, rounded=False,
|
||||
precision=3, fontsize=None):
|
||||
|
||||
super().__init__(
|
||||
max_depth=max_depth, feature_names=feature_names,
|
||||
class_names=class_names, label=label, filled=filled,
|
||||
impurity=impurity, node_ids=node_ids, proportion=proportion,
|
||||
rotate=rotate, rounded=rounded, precision=precision)
|
||||
self.fontsize = fontsize
|
||||
|
||||
# validate
|
||||
if isinstance(precision, Integral):
|
||||
if precision < 0:
|
||||
raise ValueError("'precision' should be greater or equal to 0."
|
||||
" Got {} instead.".format(precision))
|
||||
else:
|
||||
raise ValueError("'precision' should be an integer. Got {}"
|
||||
" instead.".format(type(precision)))
|
||||
|
||||
# The depth of each node for plotting with 'leaf' option
|
||||
self.ranks = {'leaves': []}
|
||||
# The colors to render each node with
|
||||
self.colors = {'bounds': None}
|
||||
|
||||
self.characters = ['#', '[', ']', '<=', '\n', '', '']
|
||||
|
||||
self.bbox_args = dict(fc='w')
|
||||
if self.rounded:
|
||||
self.bbox_args['boxstyle'] = "round"
|
||||
|
||||
self.arrow_args = dict(arrowstyle="<-")
|
||||
|
||||
def _make_tree(self, node_id, et, criterion, depth=0):
|
||||
# traverses _tree.Tree recursively, builds intermediate
|
||||
# "_reingold_tilford.Tree" object
|
||||
name = self.node_to_str(et, node_id, criterion=criterion)
|
||||
if (et.children_left[node_id] != _tree.TREE_LEAF
|
||||
and (self.max_depth is None or depth <= self.max_depth)):
|
||||
children = [self._make_tree(et.children_left[node_id], et,
|
||||
criterion, depth=depth + 1),
|
||||
self._make_tree(et.children_right[node_id], et,
|
||||
criterion, depth=depth + 1)]
|
||||
else:
|
||||
return Tree(name, node_id)
|
||||
return Tree(name, node_id, *children)
|
||||
|
||||
def export(self, decision_tree, ax=None):
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.text import Annotation
|
||||
|
||||
if ax is None:
|
||||
ax = plt.gca()
|
||||
ax.clear()
|
||||
ax.set_axis_off()
|
||||
my_tree = self._make_tree(0, decision_tree.tree_,
|
||||
decision_tree.criterion)
|
||||
draw_tree = buchheim(my_tree)
|
||||
|
||||
# important to make sure we're still
|
||||
# inside the axis after drawing the box
|
||||
# this makes sense because the width of a box
|
||||
# is about the same as the distance between boxes
|
||||
max_x, max_y = draw_tree.max_extents() + 1
|
||||
ax_width = ax.get_window_extent().width
|
||||
ax_height = ax.get_window_extent().height
|
||||
|
||||
scale_x = ax_width / max_x
|
||||
scale_y = ax_height / max_y
|
||||
|
||||
self.recurse(draw_tree, decision_tree.tree_, ax,
|
||||
scale_x, scale_y, ax_height)
|
||||
|
||||
anns = [ann for ann in ax.get_children()
|
||||
if isinstance(ann, Annotation)]
|
||||
|
||||
# update sizes of all bboxes
|
||||
renderer = ax.figure.canvas.get_renderer()
|
||||
|
||||
for ann in anns:
|
||||
ann.update_bbox_position_size(renderer)
|
||||
|
||||
if self.fontsize is None:
|
||||
# get figure to data transform
|
||||
# adjust fontsize to avoid overlap
|
||||
# get max box width and height
|
||||
extents = [ann.get_bbox_patch().get_window_extent()
|
||||
for ann in anns]
|
||||
max_width = max([extent.width for extent in extents])
|
||||
max_height = max([extent.height for extent in extents])
|
||||
# width should be around scale_x in axis coordinates
|
||||
size = anns[0].get_fontsize() * min(scale_x / max_width,
|
||||
scale_y / max_height)
|
||||
for ann in anns:
|
||||
ann.set_fontsize(size)
|
||||
|
||||
return anns
|
||||
|
||||
def recurse(self, node, tree, ax, scale_x, scale_y, height, depth=0):
|
||||
kwargs = dict(bbox=self.bbox_args, ha='center', va='center',
|
||||
zorder=100 - 10 * depth, xycoords='axes pixels')
|
||||
|
||||
if self.fontsize is not None:
|
||||
kwargs['fontsize'] = self.fontsize
|
||||
|
||||
# offset things by .5 to center them in plot
|
||||
xy = ((node.x + .5) * scale_x, height - (node.y + .5) * scale_y)
|
||||
|
||||
if self.max_depth is None or depth <= self.max_depth:
|
||||
if self.filled:
|
||||
kwargs['bbox']['fc'] = self.get_fill_color(tree,
|
||||
node.tree.node_id)
|
||||
if node.parent is None:
|
||||
# root
|
||||
ax.annotate(node.tree.label, xy, **kwargs)
|
||||
else:
|
||||
xy_parent = ((node.parent.x + .5) * scale_x,
|
||||
height - (node.parent.y + .5) * scale_y)
|
||||
kwargs["arrowprops"] = self.arrow_args
|
||||
ax.annotate(node.tree.label, xy_parent, xy, **kwargs)
|
||||
for child in node.children:
|
||||
self.recurse(child, tree, ax, scale_x, scale_y, height,
|
||||
depth=depth + 1)
|
||||
|
||||
else:
|
||||
xy_parent = ((node.parent.x + .5) * scale_x,
|
||||
height - (node.parent.y + .5) * scale_y)
|
||||
kwargs["arrowprops"] = self.arrow_args
|
||||
kwargs['bbox']['fc'] = 'grey'
|
||||
ax.annotate("\n (...) \n", xy_parent, xy, **kwargs)
|
||||
|
||||
|
||||
@_deprecate_positional_args
|
||||
def export_graphviz(decision_tree, out_file=None, *, max_depth=None,
|
||||
feature_names=None, class_names=None, label='all',
|
||||
filled=False, leaves_parallel=False, impurity=True,
|
||||
node_ids=False, proportion=False, rotate=False,
|
||||
rounded=False, special_characters=False, precision=3):
|
||||
"""Export a decision tree in DOT format.
|
||||
|
||||
This function generates a GraphViz representation of the decision tree,
|
||||
which is then written into `out_file`. Once exported, graphical renderings
|
||||
can be generated using, for example::
|
||||
|
||||
$ dot -Tps tree.dot -o tree.ps (PostScript format)
|
||||
$ dot -Tpng tree.dot -o tree.png (PNG format)
|
||||
|
||||
The sample counts that are shown are weighted with any sample_weights that
|
||||
might be present.
|
||||
|
||||
Read more in the :ref:`User Guide <tree>`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
decision_tree : decision tree classifier
|
||||
The decision tree to be exported to GraphViz.
|
||||
|
||||
out_file : file object or string, optional (default=None)
|
||||
Handle or name of the output file. If ``None``, the result is
|
||||
returned as a string.
|
||||
|
||||
.. versionchanged:: 0.20
|
||||
Default of out_file changed from "tree.dot" to None.
|
||||
|
||||
max_depth : int, optional (default=None)
|
||||
The maximum depth of the representation. If None, the tree is fully
|
||||
generated.
|
||||
|
||||
feature_names : list of strings, optional (default=None)
|
||||
Names of each of the features.
|
||||
|
||||
class_names : list of strings, bool or None, optional (default=None)
|
||||
Names of each of the target classes in ascending numerical order.
|
||||
Only relevant for classification and not supported for multi-output.
|
||||
If ``True``, shows a symbolic representation of the class name.
|
||||
|
||||
label : {'all', 'root', 'none'}, optional (default='all')
|
||||
Whether to show informative labels for impurity, etc.
|
||||
Options include 'all' to show at every node, 'root' to show only at
|
||||
the top root node, or 'none' to not show at any node.
|
||||
|
||||
filled : bool, optional (default=False)
|
||||
When set to ``True``, paint nodes to indicate majority class for
|
||||
classification, extremity of values for regression, or purity of node
|
||||
for multi-output.
|
||||
|
||||
leaves_parallel : bool, optional (default=False)
|
||||
When set to ``True``, draw all leaf nodes at the bottom of the tree.
|
||||
|
||||
impurity : bool, optional (default=True)
|
||||
When set to ``True``, show the impurity at each node.
|
||||
|
||||
node_ids : bool, optional (default=False)
|
||||
When set to ``True``, show the ID number on each node.
|
||||
|
||||
proportion : bool, optional (default=False)
|
||||
When set to ``True``, change the display of 'values' and/or 'samples'
|
||||
to be proportions and percentages respectively.
|
||||
|
||||
rotate : bool, optional (default=False)
|
||||
When set to ``True``, orient tree left to right rather than top-down.
|
||||
|
||||
rounded : bool, optional (default=False)
|
||||
When set to ``True``, draw node boxes with rounded corners and use
|
||||
Helvetica fonts instead of Times-Roman.
|
||||
|
||||
special_characters : bool, optional (default=False)
|
||||
When set to ``False``, ignore special characters for PostScript
|
||||
compatibility.
|
||||
|
||||
precision : int, optional (default=3)
|
||||
Number of digits of precision for floating point in the values of
|
||||
impurity, threshold and value attributes of each node.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dot_data : string
|
||||
String representation of the input tree in GraphViz dot format.
|
||||
Only returned if ``out_file`` is None.
|
||||
|
||||
.. versionadded:: 0.18
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> from sklearn.datasets import load_iris
|
||||
>>> from sklearn import tree
|
||||
|
||||
>>> clf = tree.DecisionTreeClassifier()
|
||||
>>> iris = load_iris()
|
||||
|
||||
>>> clf = clf.fit(iris.data, iris.target)
|
||||
>>> tree.export_graphviz(clf)
|
||||
'digraph Tree {...
|
||||
"""
|
||||
|
||||
check_is_fitted(decision_tree)
|
||||
own_file = False
|
||||
return_string = False
|
||||
try:
|
||||
if isinstance(out_file, str):
|
||||
out_file = open(out_file, "w", encoding="utf-8")
|
||||
own_file = True
|
||||
|
||||
if out_file is None:
|
||||
return_string = True
|
||||
out_file = StringIO()
|
||||
|
||||
exporter = _DOTTreeExporter(
|
||||
out_file=out_file, max_depth=max_depth,
|
||||
feature_names=feature_names, class_names=class_names, label=label,
|
||||
filled=filled, leaves_parallel=leaves_parallel, impurity=impurity,
|
||||
node_ids=node_ids, proportion=proportion, rotate=rotate,
|
||||
rounded=rounded, special_characters=special_characters,
|
||||
precision=precision)
|
||||
exporter.export(decision_tree)
|
||||
|
||||
if return_string:
|
||||
return exporter.out_file.getvalue()
|
||||
|
||||
finally:
|
||||
if own_file:
|
||||
out_file.close()
|
||||
|
||||
|
||||
def _compute_depth(tree, node):
|
||||
"""
|
||||
Returns the depth of the subtree rooted in node.
|
||||
"""
|
||||
def compute_depth_(current_node, current_depth,
|
||||
children_left, children_right, depths):
|
||||
depths += [current_depth]
|
||||
left = children_left[current_node]
|
||||
right = children_right[current_node]
|
||||
if left != -1 and right != -1:
|
||||
compute_depth_(left, current_depth+1,
|
||||
children_left, children_right, depths)
|
||||
compute_depth_(right, current_depth+1,
|
||||
children_left, children_right, depths)
|
||||
|
||||
depths = []
|
||||
compute_depth_(node, 1, tree.children_left, tree.children_right, depths)
|
||||
return max(depths)
|
||||
|
||||
|
||||
@_deprecate_positional_args
|
||||
def export_text(decision_tree, *, feature_names=None, max_depth=10,
|
||||
spacing=3, decimals=2, show_weights=False):
|
||||
"""Build a text report showing the rules of a decision tree.
|
||||
|
||||
Note that backwards compatibility may not be supported.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
decision_tree : object
|
||||
The decision tree estimator to be exported.
|
||||
It can be an instance of
|
||||
DecisionTreeClassifier or DecisionTreeRegressor.
|
||||
|
||||
feature_names : list, optional (default=None)
|
||||
A list of length n_features containing the feature names.
|
||||
If None generic names will be used ("feature_0", "feature_1", ...).
|
||||
|
||||
max_depth : int, optional (default=10)
|
||||
Only the first max_depth levels of the tree are exported.
|
||||
Truncated branches will be marked with "...".
|
||||
|
||||
spacing : int, optional (default=3)
|
||||
Number of spaces between edges. The higher it is, the wider the result.
|
||||
|
||||
decimals : int, optional (default=2)
|
||||
Number of decimal digits to display.
|
||||
|
||||
show_weights : bool, optional (default=False)
|
||||
If true the classification weights will be exported on each leaf.
|
||||
The classification weights are the number of samples each class.
|
||||
|
||||
Returns
|
||||
-------
|
||||
report : string
|
||||
Text summary of all the rules in the decision tree.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
>>> from sklearn.datasets import load_iris
|
||||
>>> from sklearn.tree import DecisionTreeClassifier
|
||||
>>> from sklearn.tree import export_text
|
||||
>>> iris = load_iris()
|
||||
>>> X = iris['data']
|
||||
>>> y = iris['target']
|
||||
>>> decision_tree = DecisionTreeClassifier(random_state=0, max_depth=2)
|
||||
>>> decision_tree = decision_tree.fit(X, y)
|
||||
>>> r = export_text(decision_tree, feature_names=iris['feature_names'])
|
||||
>>> print(r)
|
||||
|--- petal width (cm) <= 0.80
|
||||
| |--- class: 0
|
||||
|--- petal width (cm) > 0.80
|
||||
| |--- petal width (cm) <= 1.75
|
||||
| | |--- class: 1
|
||||
| |--- petal width (cm) > 1.75
|
||||
| | |--- class: 2
|
||||
"""
|
||||
check_is_fitted(decision_tree)
|
||||
tree_ = decision_tree.tree_
|
||||
if is_classifier(decision_tree):
|
||||
class_names = decision_tree.classes_
|
||||
right_child_fmt = "{} {} <= {}\n"
|
||||
left_child_fmt = "{} {} > {}\n"
|
||||
truncation_fmt = "{} {}\n"
|
||||
|
||||
if max_depth < 0:
|
||||
raise ValueError("max_depth bust be >= 0, given %d" % max_depth)
|
||||
|
||||
if (feature_names is not None and
|
||||
len(feature_names) != tree_.n_features):
|
||||
raise ValueError("feature_names must contain "
|
||||
"%d elements, got %d" % (tree_.n_features,
|
||||
len(feature_names)))
|
||||
|
||||
if spacing <= 0:
|
||||
raise ValueError("spacing must be > 0, given %d" % spacing)
|
||||
|
||||
if decimals < 0:
|
||||
raise ValueError("decimals must be >= 0, given %d" % decimals)
|
||||
|
||||
if isinstance(decision_tree, DecisionTreeClassifier):
|
||||
value_fmt = "{}{} weights: {}\n"
|
||||
if not show_weights:
|
||||
value_fmt = "{}{}{}\n"
|
||||
else:
|
||||
value_fmt = "{}{} value: {}\n"
|
||||
|
||||
if feature_names:
|
||||
feature_names_ = [feature_names[i] if i != _tree.TREE_UNDEFINED
|
||||
else None for i in tree_.feature]
|
||||
else:
|
||||
feature_names_ = ["feature_{}".format(i) for i in tree_.feature]
|
||||
|
||||
export_text.report = ""
|
||||
|
||||
def _add_leaf(value, class_name, indent):
|
||||
val = ''
|
||||
is_classification = isinstance(decision_tree,
|
||||
DecisionTreeClassifier)
|
||||
if show_weights or not is_classification:
|
||||
val = ["{1:.{0}f}, ".format(decimals, v) for v in value]
|
||||
val = '['+''.join(val)[:-2]+']'
|
||||
if is_classification:
|
||||
val += ' class: ' + str(class_name)
|
||||
export_text.report += value_fmt.format(indent, '', val)
|
||||
|
||||
def print_tree_recurse(node, depth):
|
||||
indent = ("|" + (" " * spacing)) * depth
|
||||
indent = indent[:-spacing] + "-" * spacing
|
||||
|
||||
value = None
|
||||
if tree_.n_outputs == 1:
|
||||
value = tree_.value[node][0]
|
||||
else:
|
||||
value = tree_.value[node].T[0]
|
||||
class_name = np.argmax(value)
|
||||
|
||||
if (tree_.n_classes[0] != 1 and
|
||||
tree_.n_outputs == 1):
|
||||
class_name = class_names[class_name]
|
||||
|
||||
if depth <= max_depth+1:
|
||||
info_fmt = ""
|
||||
info_fmt_left = info_fmt
|
||||
info_fmt_right = info_fmt
|
||||
|
||||
if tree_.feature[node] != _tree.TREE_UNDEFINED:
|
||||
name = feature_names_[node]
|
||||
threshold = tree_.threshold[node]
|
||||
threshold = "{1:.{0}f}".format(decimals, threshold)
|
||||
export_text.report += right_child_fmt.format(indent,
|
||||
name,
|
||||
threshold)
|
||||
export_text.report += info_fmt_left
|
||||
print_tree_recurse(tree_.children_left[node], depth+1)
|
||||
|
||||
export_text.report += left_child_fmt.format(indent,
|
||||
name,
|
||||
threshold)
|
||||
export_text.report += info_fmt_right
|
||||
print_tree_recurse(tree_.children_right[node], depth+1)
|
||||
else: # leaf
|
||||
_add_leaf(value, class_name, indent)
|
||||
else:
|
||||
subtree_depth = _compute_depth(tree_, node)
|
||||
if subtree_depth == 1:
|
||||
_add_leaf(value, class_name, indent)
|
||||
else:
|
||||
trunc_report = 'truncated branch of depth %d' % subtree_depth
|
||||
export_text.report += truncation_fmt.format(indent,
|
||||
trunc_report)
|
||||
|
||||
print_tree_recurse(0, 1)
|
||||
return export_text.report
|
188
venv/Lib/site-packages/sklearn/tree/_reingold_tilford.py
Normal file
188
venv/Lib/site-packages/sklearn/tree/_reingold_tilford.py
Normal file
|
@ -0,0 +1,188 @@
|
|||
# Authors: William Mill (bill@billmill.org)
|
||||
# License: BSD 3 clause
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class DrawTree:
|
||||
def __init__(self, tree, parent=None, depth=0, number=1):
|
||||
self.x = -1.
|
||||
self.y = depth
|
||||
self.tree = tree
|
||||
self.children = [DrawTree(c, self, depth + 1, i + 1)
|
||||
for i, c
|
||||
in enumerate(tree.children)]
|
||||
self.parent = parent
|
||||
self.thread = None
|
||||
self.mod = 0
|
||||
self.ancestor = self
|
||||
self.change = self.shift = 0
|
||||
self._lmost_sibling = None
|
||||
# this is the number of the node in its group of siblings 1..n
|
||||
self.number = number
|
||||
|
||||
def left(self):
|
||||
return self.thread or len(self.children) and self.children[0]
|
||||
|
||||
def right(self):
|
||||
return self.thread or len(self.children) and self.children[-1]
|
||||
|
||||
def lbrother(self):
|
||||
n = None
|
||||
if self.parent:
|
||||
for node in self.parent.children:
|
||||
if node == self:
|
||||
return n
|
||||
else:
|
||||
n = node
|
||||
return n
|
||||
|
||||
def get_lmost_sibling(self):
|
||||
if not self._lmost_sibling and self.parent and self != \
|
||||
self.parent.children[0]:
|
||||
self._lmost_sibling = self.parent.children[0]
|
||||
return self._lmost_sibling
|
||||
lmost_sibling = property(get_lmost_sibling)
|
||||
|
||||
def __str__(self):
|
||||
return "%s: x=%s mod=%s" % (self.tree, self.x, self.mod)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
def max_extents(self):
|
||||
extents = [c.max_extents() for c in self. children]
|
||||
extents.append((self.x, self.y))
|
||||
return np.max(extents, axis=0)
|
||||
|
||||
|
||||
def buchheim(tree):
|
||||
dt = first_walk(DrawTree(tree))
|
||||
min = second_walk(dt)
|
||||
if min < 0:
|
||||
third_walk(dt, -min)
|
||||
return dt
|
||||
|
||||
|
||||
def third_walk(tree, n):
|
||||
tree.x += n
|
||||
for c in tree.children:
|
||||
third_walk(c, n)
|
||||
|
||||
|
||||
def first_walk(v, distance=1.):
|
||||
if len(v.children) == 0:
|
||||
if v.lmost_sibling:
|
||||
v.x = v.lbrother().x + distance
|
||||
else:
|
||||
v.x = 0.
|
||||
else:
|
||||
default_ancestor = v.children[0]
|
||||
for w in v.children:
|
||||
first_walk(w)
|
||||
default_ancestor = apportion(w, default_ancestor, distance)
|
||||
# print("finished v =", v.tree, "children")
|
||||
execute_shifts(v)
|
||||
|
||||
midpoint = (v.children[0].x + v.children[-1].x) / 2
|
||||
|
||||
w = v.lbrother()
|
||||
if w:
|
||||
v.x = w.x + distance
|
||||
v.mod = v.x - midpoint
|
||||
else:
|
||||
v.x = midpoint
|
||||
return v
|
||||
|
||||
|
||||
def apportion(v, default_ancestor, distance):
|
||||
w = v.lbrother()
|
||||
if w is not None:
|
||||
# in buchheim notation:
|
||||
# i == inner; o == outer; r == right; l == left; r = +; l = -
|
||||
vir = vor = v
|
||||
vil = w
|
||||
vol = v.lmost_sibling
|
||||
sir = sor = v.mod
|
||||
sil = vil.mod
|
||||
sol = vol.mod
|
||||
while vil.right() and vir.left():
|
||||
vil = vil.right()
|
||||
vir = vir.left()
|
||||
vol = vol.left()
|
||||
vor = vor.right()
|
||||
vor.ancestor = v
|
||||
shift = (vil.x + sil) - (vir.x + sir) + distance
|
||||
if shift > 0:
|
||||
move_subtree(ancestor(vil, v, default_ancestor), v, shift)
|
||||
sir = sir + shift
|
||||
sor = sor + shift
|
||||
sil += vil.mod
|
||||
sir += vir.mod
|
||||
sol += vol.mod
|
||||
sor += vor.mod
|
||||
if vil.right() and not vor.right():
|
||||
vor.thread = vil.right()
|
||||
vor.mod += sil - sor
|
||||
else:
|
||||
if vir.left() and not vol.left():
|
||||
vol.thread = vir.left()
|
||||
vol.mod += sir - sol
|
||||
default_ancestor = v
|
||||
return default_ancestor
|
||||
|
||||
|
||||
def move_subtree(wl, wr, shift):
|
||||
subtrees = wr.number - wl.number
|
||||
# print(wl.tree, "is conflicted with", wr.tree, 'moving', subtrees,
|
||||
# 'shift', shift)
|
||||
# print wl, wr, wr.number, wl.number, shift, subtrees, shift/subtrees
|
||||
wr.change -= shift / subtrees
|
||||
wr.shift += shift
|
||||
wl.change += shift / subtrees
|
||||
wr.x += shift
|
||||
wr.mod += shift
|
||||
|
||||
|
||||
def execute_shifts(v):
|
||||
shift = change = 0
|
||||
for w in v.children[::-1]:
|
||||
# print("shift:", w, shift, w.change)
|
||||
w.x += shift
|
||||
w.mod += shift
|
||||
change += w.change
|
||||
shift += w.shift + change
|
||||
|
||||
|
||||
def ancestor(vil, v, default_ancestor):
|
||||
# the relevant text is at the bottom of page 7 of
|
||||
# "Improving Walker's Algorithm to Run in Linear Time" by Buchheim et al,
|
||||
# (2002)
|
||||
# http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.16.8757&rep=rep1&type=pdf
|
||||
if vil.ancestor in v.parent.children:
|
||||
return vil.ancestor
|
||||
else:
|
||||
return default_ancestor
|
||||
|
||||
|
||||
def second_walk(v, m=0, depth=0, min=None):
|
||||
v.x += m
|
||||
v.y = depth
|
||||
|
||||
if min is None or v.x < min:
|
||||
min = v.x
|
||||
|
||||
for w in v.children:
|
||||
min = second_walk(w, m + v.mod, depth + 1, min)
|
||||
|
||||
return min
|
||||
|
||||
|
||||
class Tree:
|
||||
def __init__(self, label="", node_id=-1, *children):
|
||||
self.label = label
|
||||
self.node_id = node_id
|
||||
if children:
|
||||
self.children = children
|
||||
else:
|
||||
self.children = []
|
BIN
venv/Lib/site-packages/sklearn/tree/_splitter.cp36-win32.pyd
Normal file
BIN
venv/Lib/site-packages/sklearn/tree/_splitter.cp36-win32.pyd
Normal file
Binary file not shown.
94
venv/Lib/site-packages/sklearn/tree/_splitter.pxd
Normal file
94
venv/Lib/site-packages/sklearn/tree/_splitter.pxd
Normal file
|
@ -0,0 +1,94 @@
|
|||
# Authors: Gilles Louppe <g.louppe@gmail.com>
|
||||
# Peter Prettenhofer <peter.prettenhofer@gmail.com>
|
||||
# Brian Holt <bdholt1@gmail.com>
|
||||
# Joel Nothman <joel.nothman@gmail.com>
|
||||
# Arnaud Joly <arnaud.v.joly@gmail.com>
|
||||
# Jacob Schreiber <jmschreiber91@gmail.com>
|
||||
#
|
||||
# License: BSD 3 clause
|
||||
|
||||
# See _splitter.pyx for details.
|
||||
|
||||
import numpy as np
|
||||
cimport numpy as np
|
||||
|
||||
from ._criterion cimport Criterion
|
||||
|
||||
from ._tree cimport DTYPE_t # Type of X
|
||||
from ._tree cimport DOUBLE_t # Type of y, sample_weight
|
||||
from ._tree cimport SIZE_t # Type for indices and counters
|
||||
from ._tree cimport INT32_t # Signed 32 bit integer
|
||||
from ._tree cimport UINT32_t # Unsigned 32 bit integer
|
||||
|
||||
cdef struct SplitRecord:
|
||||
# Data to track sample split
|
||||
SIZE_t feature # Which feature to split on.
|
||||
SIZE_t pos # Split samples array at the given position,
|
||||
# i.e. count of samples below threshold for feature.
|
||||
# pos is >= end if the node is a leaf.
|
||||
double threshold # Threshold to split at.
|
||||
double improvement # Impurity improvement given parent node.
|
||||
double impurity_left # Impurity of the left split.
|
||||
double impurity_right # Impurity of the right split.
|
||||
|
||||
cdef class Splitter:
|
||||
# The splitter searches in the input space for a feature and a threshold
|
||||
# to split the samples samples[start:end].
|
||||
#
|
||||
# The impurity computations are delegated to a criterion object.
|
||||
|
||||
# Internal structures
|
||||
cdef public Criterion criterion # Impurity criterion
|
||||
cdef public SIZE_t max_features # Number of features to test
|
||||
cdef public SIZE_t min_samples_leaf # Min samples in a leaf
|
||||
cdef public double min_weight_leaf # Minimum weight in a leaf
|
||||
|
||||
cdef object random_state # Random state
|
||||
cdef UINT32_t rand_r_state # sklearn_rand_r random number state
|
||||
|
||||
cdef SIZE_t* samples # Sample indices in X, y
|
||||
cdef SIZE_t n_samples # X.shape[0]
|
||||
cdef double weighted_n_samples # Weighted number of samples
|
||||
cdef SIZE_t* features # Feature indices in X
|
||||
cdef SIZE_t* constant_features # Constant features indices
|
||||
cdef SIZE_t n_features # X.shape[1]
|
||||
cdef DTYPE_t* feature_values # temp. array holding feature values
|
||||
|
||||
cdef SIZE_t start # Start position for the current node
|
||||
cdef SIZE_t end # End position for the current node
|
||||
|
||||
cdef const DOUBLE_t[:, ::1] y
|
||||
cdef DOUBLE_t* sample_weight
|
||||
|
||||
# The samples vector `samples` is maintained by the Splitter object such
|
||||
# that the samples contained in a node are contiguous. With this setting,
|
||||
# `node_split` reorganizes the node samples `samples[start:end]` in two
|
||||
# subsets `samples[start:pos]` and `samples[pos:end]`.
|
||||
|
||||
# The 1-d `features` array of size n_features contains the features
|
||||
# indices and allows fast sampling without replacement of features.
|
||||
|
||||
# The 1-d `constant_features` array of size n_features holds in
|
||||
# `constant_features[:n_constant_features]` the feature ids with
|
||||
# constant values for all the samples that reached a specific node.
|
||||
# The value `n_constant_features` is given by the parent node to its
|
||||
# child nodes. The content of the range `[n_constant_features:]` is left
|
||||
# undefined, but preallocated for performance reasons
|
||||
# This allows optimization with depth-based tree building.
|
||||
|
||||
# Methods
|
||||
cdef int init(self, object X, const DOUBLE_t[:, ::1] y,
|
||||
DOUBLE_t* sample_weight,
|
||||
np.ndarray X_idx_sorted=*) except -1
|
||||
|
||||
cdef int node_reset(self, SIZE_t start, SIZE_t end,
|
||||
double* weighted_n_node_samples) nogil except -1
|
||||
|
||||
cdef int node_split(self,
|
||||
double impurity, # Impurity of the node
|
||||
SplitRecord* split,
|
||||
SIZE_t* n_constant_features) nogil except -1
|
||||
|
||||
cdef void node_value(self, double* dest) nogil
|
||||
|
||||
cdef double node_impurity(self) nogil
|
BIN
venv/Lib/site-packages/sklearn/tree/_tree.cp36-win32.pyd
Normal file
BIN
venv/Lib/site-packages/sklearn/tree/_tree.cp36-win32.pyd
Normal file
Binary file not shown.
105
venv/Lib/site-packages/sklearn/tree/_tree.pxd
Normal file
105
venv/Lib/site-packages/sklearn/tree/_tree.pxd
Normal file
|
@ -0,0 +1,105 @@
|
|||
# Authors: Gilles Louppe <g.louppe@gmail.com>
|
||||
# Peter Prettenhofer <peter.prettenhofer@gmail.com>
|
||||
# Brian Holt <bdholt1@gmail.com>
|
||||
# Joel Nothman <joel.nothman@gmail.com>
|
||||
# Arnaud Joly <arnaud.v.joly@gmail.com>
|
||||
# Jacob Schreiber <jmschreiber91@gmail.com>
|
||||
# Nelson Liu <nelson@nelsonliu.me>
|
||||
#
|
||||
# License: BSD 3 clause
|
||||
|
||||
# See _tree.pyx for details.
|
||||
|
||||
import numpy as np
|
||||
cimport numpy as np
|
||||
|
||||
ctypedef np.npy_float32 DTYPE_t # Type of X
|
||||
ctypedef np.npy_float64 DOUBLE_t # Type of y, sample_weight
|
||||
ctypedef np.npy_intp SIZE_t # Type for indices and counters
|
||||
ctypedef np.npy_int32 INT32_t # Signed 32 bit integer
|
||||
ctypedef np.npy_uint32 UINT32_t # Unsigned 32 bit integer
|
||||
|
||||
from ._splitter cimport Splitter
|
||||
from ._splitter cimport SplitRecord
|
||||
|
||||
cdef struct Node:
|
||||
# Base storage structure for the nodes in a Tree object
|
||||
|
||||
SIZE_t left_child # id of the left child of the node
|
||||
SIZE_t right_child # id of the right child of the node
|
||||
SIZE_t feature # Feature used for splitting the node
|
||||
DOUBLE_t threshold # Threshold value at the node
|
||||
DOUBLE_t impurity # Impurity of the node (i.e., the value of the criterion)
|
||||
SIZE_t n_node_samples # Number of samples at the node
|
||||
DOUBLE_t weighted_n_node_samples # Weighted number of samples at the node
|
||||
|
||||
|
||||
cdef class Tree:
|
||||
# The Tree object is a binary tree structure constructed by the
|
||||
# TreeBuilder. The tree structure is used for predictions and
|
||||
# feature importances.
|
||||
|
||||
# Input/Output layout
|
||||
cdef public SIZE_t n_features # Number of features in X
|
||||
cdef SIZE_t* n_classes # Number of classes in y[:, k]
|
||||
cdef public SIZE_t n_outputs # Number of outputs in y
|
||||
cdef public SIZE_t max_n_classes # max(n_classes)
|
||||
|
||||
# Inner structures: values are stored separately from node structure,
|
||||
# since size is determined at runtime.
|
||||
cdef public SIZE_t max_depth # Max depth of the tree
|
||||
cdef public SIZE_t node_count # Counter for node IDs
|
||||
cdef public SIZE_t capacity # Capacity of tree, in terms of nodes
|
||||
cdef Node* nodes # Array of nodes
|
||||
cdef double* value # (capacity, n_outputs, max_n_classes) array of values
|
||||
cdef SIZE_t value_stride # = n_outputs * max_n_classes
|
||||
|
||||
# Methods
|
||||
cdef SIZE_t _add_node(self, SIZE_t parent, bint is_left, bint is_leaf,
|
||||
SIZE_t feature, double threshold, double impurity,
|
||||
SIZE_t n_node_samples,
|
||||
double weighted_n_samples) nogil except -1
|
||||
cdef int _resize(self, SIZE_t capacity) nogil except -1
|
||||
cdef int _resize_c(self, SIZE_t capacity=*) nogil except -1
|
||||
|
||||
cdef np.ndarray _get_value_ndarray(self)
|
||||
cdef np.ndarray _get_node_ndarray(self)
|
||||
|
||||
cpdef np.ndarray predict(self, object X)
|
||||
|
||||
cpdef np.ndarray apply(self, object X)
|
||||
cdef np.ndarray _apply_dense(self, object X)
|
||||
cdef np.ndarray _apply_sparse_csr(self, object X)
|
||||
|
||||
cpdef object decision_path(self, object X)
|
||||
cdef object _decision_path_dense(self, object X)
|
||||
cdef object _decision_path_sparse_csr(self, object X)
|
||||
|
||||
cpdef compute_feature_importances(self, normalize=*)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tree builder
|
||||
# =============================================================================
|
||||
|
||||
cdef class TreeBuilder:
|
||||
# The TreeBuilder recursively builds a Tree object from training samples,
|
||||
# using a Splitter object for splitting internal nodes and assigning
|
||||
# values to leaves.
|
||||
#
|
||||
# This class controls the various stopping criteria and the node splitting
|
||||
# evaluation order, e.g. depth-first or best-first.
|
||||
|
||||
cdef Splitter splitter # Splitting algorithm
|
||||
|
||||
cdef SIZE_t min_samples_split # Minimum number of samples in an internal node
|
||||
cdef SIZE_t min_samples_leaf # Minimum number of samples in a leaf
|
||||
cdef double min_weight_leaf # Minimum weight in a leaf
|
||||
cdef SIZE_t max_depth # Maximal tree depth
|
||||
cdef double min_impurity_split
|
||||
cdef double min_impurity_decrease # Impurity threshold for early stopping
|
||||
|
||||
cpdef build(self, Tree tree, object X, np.ndarray y,
|
||||
np.ndarray sample_weight=*,
|
||||
np.ndarray X_idx_sorted=*)
|
||||
cdef _check_input(self, object X, np.ndarray y, np.ndarray sample_weight)
|
BIN
venv/Lib/site-packages/sklearn/tree/_utils.cp36-win32.pyd
Normal file
BIN
venv/Lib/site-packages/sklearn/tree/_utils.cp36-win32.pyd
Normal file
Binary file not shown.
170
venv/Lib/site-packages/sklearn/tree/_utils.pxd
Normal file
170
venv/Lib/site-packages/sklearn/tree/_utils.pxd
Normal file
|
@ -0,0 +1,170 @@
|
|||
# Authors: Gilles Louppe <g.louppe@gmail.com>
#          Peter Prettenhofer <peter.prettenhofer@gmail.com>
#          Arnaud Joly <arnaud.v.joly@gmail.com>
#          Jacob Schreiber <jmschreiber91@gmail.com>
#          Nelson Liu <nelson@nelsonliu.me>
#
# License: BSD 3 clause

# See _utils.pyx for details.

import numpy as np
cimport numpy as np
from ._tree cimport Node
from ..neighbors._quad_tree cimport Cell

ctypedef np.npy_float32 DTYPE_t          # Type of X
ctypedef np.npy_float64 DOUBLE_t         # Type of y, sample_weight
ctypedef np.npy_intp SIZE_t              # Type for indices and counters
ctypedef np.npy_int32 INT32_t            # Signed 32 bit integer
ctypedef np.npy_uint32 UINT32_t          # Unsigned 32 bit integer


cdef enum:
    # Max value for our rand_r replacement (near the bottom).
    # We don't use RAND_MAX because it's different across platforms and
    # particularly tiny on Windows/MSVC.
    RAND_R_MAX = 0x7FFFFFFF


# safe_realloc(&p, n) resizes the allocation of p to n * sizeof(*p) bytes or
# raises a MemoryError. It never calls free, since that's __dealloc__'s job.
#   cdef DTYPE_t *p = NULL
#   safe_realloc(&p, n)
# is equivalent to p = malloc(n * sizeof(*p)) with error checking.
ctypedef fused realloc_ptr:
    # Add pointer types here as needed.
    (DTYPE_t*)
    (SIZE_t*)
    (unsigned char*)
    (WeightedPQueueRecord*)
    (DOUBLE_t*)
    (DOUBLE_t**)
    (Node*)
    (Cell*)
    (Node**)
    (StackRecord*)
    (PriorityHeapRecord*)

cdef realloc_ptr safe_realloc(realloc_ptr* p, size_t nelems) nogil except *


cdef np.ndarray sizet_ptr_to_ndarray(SIZE_t* data, SIZE_t size)


cdef SIZE_t rand_int(SIZE_t low, SIZE_t high,
                     UINT32_t* random_state) nogil


cdef double rand_uniform(double low, double high,
                         UINT32_t* random_state) nogil


cdef double log(double x) nogil

# =============================================================================
# Stack data structure
# =============================================================================

# A record on the stack for depth-first tree growing
cdef struct StackRecord:
    SIZE_t start
    SIZE_t end
    SIZE_t depth
    SIZE_t parent
    bint is_left
    double impurity
    SIZE_t n_constant_features

cdef class Stack:
    cdef SIZE_t capacity
    cdef SIZE_t top
    cdef StackRecord* stack_

    cdef bint is_empty(self) nogil
    cdef int push(self, SIZE_t start, SIZE_t end, SIZE_t depth, SIZE_t parent,
                  bint is_left, double impurity,
                  SIZE_t n_constant_features) nogil except -1
    cdef int pop(self, StackRecord* res) nogil


# =============================================================================
# PriorityHeap data structure
# =============================================================================

# A record on the frontier for best-first tree growing
cdef struct PriorityHeapRecord:
    SIZE_t node_id
    SIZE_t start
    SIZE_t end
    SIZE_t pos
    SIZE_t depth
    bint is_leaf
    double impurity
    double impurity_left
    double impurity_right
    double improvement

cdef class PriorityHeap:
    cdef SIZE_t capacity
    cdef SIZE_t heap_ptr
    cdef PriorityHeapRecord* heap_

    cdef bint is_empty(self) nogil
    cdef void heapify_up(self, PriorityHeapRecord* heap, SIZE_t pos) nogil
    cdef void heapify_down(self, PriorityHeapRecord* heap, SIZE_t pos, SIZE_t heap_length) nogil
    cdef int push(self, SIZE_t node_id, SIZE_t start, SIZE_t end, SIZE_t pos,
                  SIZE_t depth, bint is_leaf, double improvement,
                  double impurity, double impurity_left,
                  double impurity_right) nogil except -1
    cdef int pop(self, PriorityHeapRecord* res) nogil

# =============================================================================
# WeightedPQueue data structure
# =============================================================================

# A record stored in the WeightedPQueue
cdef struct WeightedPQueueRecord:
    DOUBLE_t data
    DOUBLE_t weight

cdef class WeightedPQueue:
    cdef SIZE_t capacity
    cdef SIZE_t array_ptr
    cdef WeightedPQueueRecord* array_

    cdef bint is_empty(self) nogil
    cdef int reset(self) nogil except -1
    cdef SIZE_t size(self) nogil
    cdef int push(self, DOUBLE_t data, DOUBLE_t weight) nogil except -1
    cdef int remove(self, DOUBLE_t data, DOUBLE_t weight) nogil
    cdef int pop(self, DOUBLE_t* data, DOUBLE_t* weight) nogil
    cdef int peek(self, DOUBLE_t* data, DOUBLE_t* weight) nogil
    cdef DOUBLE_t get_weight_from_index(self, SIZE_t index) nogil
    cdef DOUBLE_t get_value_from_index(self, SIZE_t index) nogil


# =============================================================================
# WeightedMedianCalculator data structure
# =============================================================================

cdef class WeightedMedianCalculator:
    cdef SIZE_t initial_capacity
    cdef WeightedPQueue samples
    cdef DOUBLE_t total_weight
    cdef SIZE_t k
    cdef DOUBLE_t sum_w_0_k            # represents sum(weights[0:k])
                                       # = w[0] + w[1] + ... + w[k-1]

    cdef SIZE_t size(self) nogil
    cdef int push(self, DOUBLE_t data, DOUBLE_t weight) nogil except -1
    cdef int reset(self) nogil except -1
    cdef int update_median_parameters_post_push(
        self, DOUBLE_t data, DOUBLE_t weight,
        DOUBLE_t original_median) nogil
    cdef int remove(self, DOUBLE_t data, DOUBLE_t weight) nogil
    cdef int pop(self, DOUBLE_t* data, DOUBLE_t* weight) nogil
    cdef int update_median_parameters_post_remove(
        self, DOUBLE_t data, DOUBLE_t weight,
        DOUBLE_t original_median) nogil
    cdef DOUBLE_t get_median(self) nogil
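WeightedMedianCalculator above keeps the samples ordered inside a WeightedPQueue and tracks sum_w_0_k = w[0] + ... + w[k-1], so the weighted median can be updated incrementally as samples are pushed and removed (it is used by the MAE criterion). A rough pure-Python sketch of the quantity it computes, not the incremental Cython implementation:

# Pure-Python sketch of the weighted-median definition tracked above
# (sum_w_0_k is the cumulative weight up to index k); illustrative only.
def weighted_median(values, weights):
    pairs = sorted(zip(values, weights))          # keep samples ordered by value
    total = sum(w for _, w in pairs)
    cum = 0.0
    for value, weight in pairs:
        cum += weight                             # cumulative weight, i.e. sum_w_0_k
        if cum >= total / 2.0:                    # first value where the cumulative
            return value                          # weight reaches half the total
    return pairs[-1][0]

print(weighted_median([1.0, 2.0, 3.0, 10.0], [1.0, 1.0, 1.0, 0.5]))  # -> 2.0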
18
venv/Lib/site-packages/sklearn/tree/export.py
Normal file
@@ -0,0 +1,18 @@

# THIS FILE WAS AUTOMATICALLY GENERATED BY deprecated_modules.py
import sys
# mypy error: Module X has no attribute y (typically for C extensions)
from . import _export  # type: ignore
from ..externals._pep562 import Pep562
from ..utils.deprecation import _raise_dep_warning_if_not_pytest

deprecated_path = 'sklearn.tree.export'
correct_import_path = 'sklearn.tree'

_raise_dep_warning_if_not_pytest(deprecated_path, correct_import_path)

def __getattr__(name):
    return getattr(_export, name)

if not sys.version_info >= (3, 7):
    Pep562(__name__)
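export.py above is a thin deprecation shim: importing the old path warns (outside pytest) and every attribute lookup is forwarded to sklearn.tree._export through a module-level __getattr__ (PEP 562). A minimal standalone sketch of the same pattern, with hypothetical module names rather than sklearn's:

# Minimal sketch of the PEP 562 deprecation-shim pattern used above.
# `old_api` / `_new_api` are hypothetical module names, not sklearn modules.
# Save as old_api.py next to _new_api.py; requires Python >= 3.7.
import warnings

import _new_api  # the real implementation lives here

warnings.warn("old_api is deprecated; import _new_api instead", FutureWarning)


def __getattr__(name):
    # Called for any attribute not defined in this module (PEP 562),
    # so old_api.something transparently resolves to _new_api.something.
    return getattr(_new_api, name)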
39
venv/Lib/site-packages/sklearn/tree/setup.py
Normal file
@@ -0,0 +1,39 @@
import os

import numpy
from numpy.distutils.misc_util import Configuration


def configuration(parent_package="", top_path=None):
    config = Configuration("tree", parent_package, top_path)
    libraries = []
    if os.name == 'posix':
        libraries.append('m')
    config.add_extension("_tree",
                         sources=["_tree.pyx"],
                         include_dirs=[numpy.get_include()],
                         libraries=libraries,
                         extra_compile_args=["-O3"])
    config.add_extension("_splitter",
                         sources=["_splitter.pyx"],
                         include_dirs=[numpy.get_include()],
                         libraries=libraries,
                         extra_compile_args=["-O3"])
    config.add_extension("_criterion",
                         sources=["_criterion.pyx"],
                         include_dirs=[numpy.get_include()],
                         libraries=libraries,
                         extra_compile_args=["-O3"])
    config.add_extension("_utils",
                         sources=["_utils.pyx"],
                         include_dirs=[numpy.get_include()],
                         libraries=libraries,
                         extra_compile_args=["-O3"])

    config.add_subpackage("tests")

    return config

if __name__ == "__main__":
    from numpy.distutils.core import setup
    setup(**configuration().todict())
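setup.py registers the four Cython extensions with numpy.distutils and exposes them through configuration(); a parent package picks this subpackage up with add_subpackage, which locates tree/setup.py and calls its configuration(). A hedged sketch of such a parent script, using the hypothetical package name mypkg:

# Sketch of how a parent package would pick up the "tree" configuration above
# via numpy.distutils; "mypkg" is a hypothetical parent package name.
from numpy.distutils.misc_util import Configuration


def configuration(parent_package="", top_path=None):
    config = Configuration("mypkg", parent_package, top_path)
    config.add_subpackage("tree")   # finds tree/setup.py and calls its configuration()
    return config


if __name__ == "__main__":
    from numpy.distutils.core import setup
    setup(**configuration(top_path="").todict())

Building in place would then be the usual `python setup.py build_ext --inplace`, provided Cython and a C compiler are available.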
0
venv/Lib/site-packages/sklearn/tree/tests/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
469
venv/Lib/site-packages/sklearn/tree/tests/test_export.py
Normal file
@@ -0,0 +1,469 @@
"""
|
||||
Testing for export functions of decision trees (sklearn.tree.export).
|
||||
"""
|
||||
from re import finditer, search
|
||||
from textwrap import dedent
|
||||
|
||||
from numpy.random import RandomState
|
||||
import pytest
|
||||
|
||||
from sklearn.base import is_classifier
|
||||
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
|
||||
from sklearn.ensemble import GradientBoostingClassifier
|
||||
from sklearn.tree import export_graphviz, plot_tree, export_text
|
||||
from io import StringIO
|
||||
from sklearn.exceptions import NotFittedError
|
||||
|
||||
# toy sample
|
||||
X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]
|
||||
y = [-1, -1, -1, 1, 1, 1]
|
||||
y2 = [[-1, 1], [-1, 1], [-1, 1], [1, 2], [1, 2], [1, 3]]
|
||||
w = [1, 1, 1, .5, .5, .5]
|
||||
y_degraded = [1, 1, 1, 1, 1, 1]
|
||||
|
||||
|
||||
def test_graphviz_toy():
|
||||
# Check correctness of export_graphviz
|
||||
clf = DecisionTreeClassifier(max_depth=3,
|
||||
min_samples_split=2,
|
||||
criterion="gini",
|
||||
random_state=2)
|
||||
clf.fit(X, y)
|
||||
|
||||
# Test export code
|
||||
contents1 = export_graphviz(clf, out_file=None)
|
||||
contents2 = 'digraph Tree {\n' \
|
||||
'node [shape=box] ;\n' \
|
||||
'0 [label="X[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n' \
|
||||
'value = [3, 3]"] ;\n' \
|
||||
'1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]"] ;\n' \
|
||||
'0 -> 1 [labeldistance=2.5, labelangle=45, ' \
|
||||
'headlabel="True"] ;\n' \
|
||||
'2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]"] ;\n' \
|
||||
'0 -> 2 [labeldistance=2.5, labelangle=-45, ' \
|
||||
'headlabel="False"] ;\n' \
|
||||
'}'
|
||||
|
||||
assert contents1 == contents2
|
||||
|
||||
# Test with feature_names
|
||||
contents1 = export_graphviz(clf, feature_names=["feature0", "feature1"],
|
||||
out_file=None)
|
||||
contents2 = 'digraph Tree {\n' \
|
||||
'node [shape=box] ;\n' \
|
||||
'0 [label="feature0 <= 0.0\\ngini = 0.5\\nsamples = 6\\n' \
|
||||
'value = [3, 3]"] ;\n' \
|
||||
'1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]"] ;\n' \
|
||||
'0 -> 1 [labeldistance=2.5, labelangle=45, ' \
|
||||
'headlabel="True"] ;\n' \
|
||||
'2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]"] ;\n' \
|
||||
'0 -> 2 [labeldistance=2.5, labelangle=-45, ' \
|
||||
'headlabel="False"] ;\n' \
|
||||
'}'
|
||||
|
||||
assert contents1 == contents2
|
||||
|
||||
# Test with class_names
|
||||
contents1 = export_graphviz(clf, class_names=["yes", "no"], out_file=None)
|
||||
contents2 = 'digraph Tree {\n' \
|
||||
'node [shape=box] ;\n' \
|
||||
'0 [label="X[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n' \
|
||||
'value = [3, 3]\\nclass = yes"] ;\n' \
|
||||
'1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]\\n' \
|
||||
'class = yes"] ;\n' \
|
||||
'0 -> 1 [labeldistance=2.5, labelangle=45, ' \
|
||||
'headlabel="True"] ;\n' \
|
||||
'2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]\\n' \
|
||||
'class = no"] ;\n' \
|
||||
'0 -> 2 [labeldistance=2.5, labelangle=-45, ' \
|
||||
'headlabel="False"] ;\n' \
|
||||
'}'
|
||||
|
||||
assert contents1 == contents2
|
||||
|
||||
# Test plot_options
|
||||
contents1 = export_graphviz(clf, filled=True, impurity=False,
|
||||
proportion=True, special_characters=True,
|
||||
rounded=True, out_file=None)
|
||||
contents2 = 'digraph Tree {\n' \
|
||||
'node [shape=box, style="filled, rounded", color="black", ' \
|
||||
'fontname=helvetica] ;\n' \
|
||||
'edge [fontname=helvetica] ;\n' \
|
||||
'0 [label=<X<SUB>0</SUB> ≤ 0.0<br/>samples = 100.0%<br/>' \
|
||||
'value = [0.5, 0.5]>, fillcolor="#ffffff"] ;\n' \
|
||||
'1 [label=<samples = 50.0%<br/>value = [1.0, 0.0]>, ' \
|
||||
'fillcolor="#e58139"] ;\n' \
|
||||
'0 -> 1 [labeldistance=2.5, labelangle=45, ' \
|
||||
'headlabel="True"] ;\n' \
|
||||
'2 [label=<samples = 50.0%<br/>value = [0.0, 1.0]>, ' \
|
||||
'fillcolor="#399de5"] ;\n' \
|
||||
'0 -> 2 [labeldistance=2.5, labelangle=-45, ' \
|
||||
'headlabel="False"] ;\n' \
|
||||
'}'
|
||||
|
||||
assert contents1 == contents2
|
||||
|
||||
# Test max_depth
|
||||
contents1 = export_graphviz(clf, max_depth=0,
|
||||
class_names=True, out_file=None)
|
||||
contents2 = 'digraph Tree {\n' \
|
||||
'node [shape=box] ;\n' \
|
||||
'0 [label="X[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n' \
|
||||
'value = [3, 3]\\nclass = y[0]"] ;\n' \
|
||||
'1 [label="(...)"] ;\n' \
|
||||
'0 -> 1 ;\n' \
|
||||
'2 [label="(...)"] ;\n' \
|
||||
'0 -> 2 ;\n' \
|
||||
'}'
|
||||
|
||||
assert contents1 == contents2
|
||||
|
||||
# Test max_depth with plot_options
|
||||
contents1 = export_graphviz(clf, max_depth=0, filled=True,
|
||||
out_file=None, node_ids=True)
|
||||
contents2 = 'digraph Tree {\n' \
|
||||
'node [shape=box, style="filled", color="black"] ;\n' \
|
||||
'0 [label="node #0\\nX[0] <= 0.0\\ngini = 0.5\\n' \
|
||||
'samples = 6\\nvalue = [3, 3]", fillcolor="#ffffff"] ;\n' \
|
||||
'1 [label="(...)", fillcolor="#C0C0C0"] ;\n' \
|
||||
'0 -> 1 ;\n' \
|
||||
'2 [label="(...)", fillcolor="#C0C0C0"] ;\n' \
|
||||
'0 -> 2 ;\n' \
|
||||
'}'
|
||||
|
||||
assert contents1 == contents2
|
||||
|
||||
# Test multi-output with weighted samples
|
||||
clf = DecisionTreeClassifier(max_depth=2,
|
||||
min_samples_split=2,
|
||||
criterion="gini",
|
||||
random_state=2)
|
||||
clf = clf.fit(X, y2, sample_weight=w)
|
||||
|
||||
contents1 = export_graphviz(clf, filled=True,
|
||||
impurity=False, out_file=None)
|
||||
contents2 = 'digraph Tree {\n' \
|
||||
'node [shape=box, style="filled", color="black"] ;\n' \
|
||||
'0 [label="X[0] <= 0.0\\nsamples = 6\\n' \
|
||||
'value = [[3.0, 1.5, 0.0]\\n' \
|
||||
'[3.0, 1.0, 0.5]]", fillcolor="#ffffff"] ;\n' \
|
||||
'1 [label="samples = 3\\nvalue = [[3, 0, 0]\\n' \
|
||||
'[3, 0, 0]]", fillcolor="#e58139"] ;\n' \
|
||||
'0 -> 1 [labeldistance=2.5, labelangle=45, ' \
|
||||
'headlabel="True"] ;\n' \
|
||||
'2 [label="X[0] <= 1.5\\nsamples = 3\\n' \
|
||||
'value = [[0.0, 1.5, 0.0]\\n' \
|
||||
'[0.0, 1.0, 0.5]]", fillcolor="#f1bd97"] ;\n' \
|
||||
'0 -> 2 [labeldistance=2.5, labelangle=-45, ' \
|
||||
'headlabel="False"] ;\n' \
|
||||
'3 [label="samples = 2\\nvalue = [[0, 1, 0]\\n' \
|
||||
'[0, 1, 0]]", fillcolor="#e58139"] ;\n' \
|
||||
'2 -> 3 ;\n' \
|
||||
'4 [label="samples = 1\\nvalue = [[0.0, 0.5, 0.0]\\n' \
|
||||
'[0.0, 0.0, 0.5]]", fillcolor="#e58139"] ;\n' \
|
||||
'2 -> 4 ;\n' \
|
||||
'}'
|
||||
|
||||
assert contents1 == contents2
|
||||
|
||||
# Test regression output with plot_options
|
||||
clf = DecisionTreeRegressor(max_depth=3,
|
||||
min_samples_split=2,
|
||||
criterion="mse",
|
||||
random_state=2)
|
||||
clf.fit(X, y)
|
||||
|
||||
contents1 = export_graphviz(clf, filled=True, leaves_parallel=True,
|
||||
out_file=None, rotate=True, rounded=True)
|
||||
contents2 = 'digraph Tree {\n' \
|
||||
'node [shape=box, style="filled, rounded", color="black", ' \
|
||||
'fontname=helvetica] ;\n' \
|
||||
'graph [ranksep=equally, splines=polyline] ;\n' \
|
||||
'edge [fontname=helvetica] ;\n' \
|
||||
'rankdir=LR ;\n' \
|
||||
'0 [label="X[0] <= 0.0\\nmse = 1.0\\nsamples = 6\\n' \
|
||||
'value = 0.0", fillcolor="#f2c09c"] ;\n' \
|
||||
'1 [label="mse = 0.0\\nsamples = 3\\nvalue = -1.0", ' \
|
||||
'fillcolor="#ffffff"] ;\n' \
|
||||
'0 -> 1 [labeldistance=2.5, labelangle=-45, ' \
|
||||
'headlabel="True"] ;\n' \
|
||||
'2 [label="mse = 0.0\\nsamples = 3\\nvalue = 1.0", ' \
|
||||
'fillcolor="#e58139"] ;\n' \
|
||||
'0 -> 2 [labeldistance=2.5, labelangle=45, ' \
|
||||
'headlabel="False"] ;\n' \
|
||||
'{rank=same ; 0} ;\n' \
|
||||
'{rank=same ; 1; 2} ;\n' \
|
||||
'}'
|
||||
|
||||
assert contents1 == contents2
|
||||
|
||||
# Test classifier with degraded learning set
|
||||
clf = DecisionTreeClassifier(max_depth=3)
|
||||
clf.fit(X, y_degraded)
|
||||
|
||||
contents1 = export_graphviz(clf, filled=True, out_file=None)
|
||||
contents2 = 'digraph Tree {\n' \
|
||||
'node [shape=box, style="filled", color="black"] ;\n' \
|
||||
'0 [label="gini = 0.0\\nsamples = 6\\nvalue = 6.0", ' \
|
||||
'fillcolor="#ffffff"] ;\n' \
|
||||
'}'
|
||||
|
||||
|
||||
def test_graphviz_errors():
|
||||
# Check for errors of export_graphviz
|
||||
clf = DecisionTreeClassifier(max_depth=3, min_samples_split=2)
|
||||
|
||||
# Check not-fitted decision tree error
|
||||
out = StringIO()
|
||||
with pytest.raises(NotFittedError):
|
||||
export_graphviz(clf, out)
|
||||
|
||||
clf.fit(X, y)
|
||||
|
||||
# Check if it errors when length of feature_names
|
||||
# mismatches with number of features
|
||||
message = ("Length of feature_names, "
|
||||
"1 does not match number of features, 2")
|
||||
with pytest.raises(ValueError, match=message):
|
||||
export_graphviz(clf, None, feature_names=["a"])
|
||||
|
||||
message = ("Length of feature_names, "
|
||||
"3 does not match number of features, 2")
|
||||
with pytest.raises(ValueError, match=message):
|
||||
export_graphviz(clf, None, feature_names=["a", "b", "c"])
|
||||
|
||||
# Check error when argument is not an estimator
|
||||
message = "is not an estimator instance"
|
||||
with pytest.raises(TypeError, match=message):
|
||||
export_graphviz(clf.fit(X, y).tree_)
|
||||
|
||||
# Check class_names error
|
||||
out = StringIO()
|
||||
with pytest.raises(IndexError):
|
||||
export_graphviz(clf, out, class_names=[])
|
||||
|
||||
# Check precision error
|
||||
out = StringIO()
|
||||
with pytest.raises(ValueError, match="should be greater or equal"):
|
||||
export_graphviz(clf, out, precision=-1)
|
||||
with pytest.raises(ValueError, match="should be an integer"):
|
||||
export_graphviz(clf, out, precision="1")
|
||||
|
||||
|
||||
def test_friedman_mse_in_graphviz():
|
||||
clf = DecisionTreeRegressor(criterion="friedman_mse", random_state=0)
|
||||
clf.fit(X, y)
|
||||
dot_data = StringIO()
|
||||
export_graphviz(clf, out_file=dot_data)
|
||||
|
||||
clf = GradientBoostingClassifier(n_estimators=2, random_state=0)
|
||||
clf.fit(X, y)
|
||||
for estimator in clf.estimators_:
|
||||
export_graphviz(estimator[0], out_file=dot_data)
|
||||
|
||||
for finding in finditer(r"\[.*?samples.*?\]", dot_data.getvalue()):
|
||||
assert "friedman_mse" in finding.group()
|
||||
|
||||
|
||||
def test_precision():
|
||||
|
||||
rng_reg = RandomState(2)
|
||||
rng_clf = RandomState(8)
|
||||
for X, y, clf in zip(
|
||||
(rng_reg.random_sample((5, 2)),
|
||||
rng_clf.random_sample((1000, 4))),
|
||||
(rng_reg.random_sample((5, )),
|
||||
rng_clf.randint(2, size=(1000, ))),
|
||||
(DecisionTreeRegressor(criterion="friedman_mse", random_state=0,
|
||||
max_depth=1),
|
||||
DecisionTreeClassifier(max_depth=1, random_state=0))):
|
||||
|
||||
clf.fit(X, y)
|
||||
for precision in (4, 3):
|
||||
dot_data = export_graphviz(clf, out_file=None, precision=precision,
|
||||
proportion=True)
|
||||
|
||||
# With the current random state, the impurity and the threshold
|
||||
# will have the number of precision set in the export_graphviz
|
||||
# function. We will check the number of precision with a strict
|
||||
# equality. The value reported will have only 2 precision and
|
||||
# therefore, only a less equal comparison will be done.
|
||||
|
||||
# check value
|
||||
for finding in finditer(r"value = \d+\.\d+", dot_data):
|
||||
assert (
|
||||
len(search(r"\.\d+", finding.group()).group()) <=
|
||||
precision + 1)
|
||||
# check impurity
|
||||
if is_classifier(clf):
|
||||
pattern = r"gini = \d+\.\d+"
|
||||
else:
|
||||
pattern = r"friedman_mse = \d+\.\d+"
|
||||
|
||||
# check impurity
|
||||
for finding in finditer(pattern, dot_data):
|
||||
assert (len(search(r"\.\d+", finding.group()).group()) ==
|
||||
precision + 1)
|
||||
# check threshold
|
||||
for finding in finditer(r"<= \d+\.\d+", dot_data):
|
||||
assert (len(search(r"\.\d+", finding.group()).group()) ==
|
||||
precision + 1)
|
||||
|
||||
|
||||
def test_export_text_errors():
|
||||
clf = DecisionTreeClassifier(max_depth=2, random_state=0)
|
||||
clf.fit(X, y)
|
||||
|
||||
err_msg = "max_depth bust be >= 0, given -1"
|
||||
with pytest.raises(ValueError, match=err_msg):
|
||||
export_text(clf, max_depth=-1)
|
||||
err_msg = "feature_names must contain 2 elements, got 1"
|
||||
with pytest.raises(ValueError, match=err_msg):
|
||||
export_text(clf, feature_names=['a'])
|
||||
err_msg = "decimals must be >= 0, given -1"
|
||||
with pytest.raises(ValueError, match=err_msg):
|
||||
export_text(clf, decimals=-1)
|
||||
err_msg = "spacing must be > 0, given 0"
|
||||
with pytest.raises(ValueError, match=err_msg):
|
||||
export_text(clf, spacing=0)
|
||||
|
||||
|
||||
def test_export_text():
|
||||
clf = DecisionTreeClassifier(max_depth=2, random_state=0)
|
||||
clf.fit(X, y)
|
||||
|
||||
expected_report = dedent("""
|
||||
|--- feature_1 <= 0.00
|
||||
| |--- class: -1
|
||||
|--- feature_1 > 0.00
|
||||
| |--- class: 1
|
||||
""").lstrip()
|
||||
|
||||
assert export_text(clf) == expected_report
|
||||
# testing that leaves at level 1 are not truncated
|
||||
assert export_text(clf, max_depth=0) == expected_report
|
||||
# testing that the rest of the tree is truncated
|
||||
assert export_text(clf, max_depth=10) == expected_report
|
||||
|
||||
expected_report = dedent("""
|
||||
|--- b <= 0.00
|
||||
| |--- class: -1
|
||||
|--- b > 0.00
|
||||
| |--- class: 1
|
||||
""").lstrip()
|
||||
assert export_text(clf, feature_names=['a', 'b']) == expected_report
|
||||
|
||||
expected_report = dedent("""
|
||||
|--- feature_1 <= 0.00
|
||||
| |--- weights: [3.00, 0.00] class: -1
|
||||
|--- feature_1 > 0.00
|
||||
| |--- weights: [0.00, 3.00] class: 1
|
||||
""").lstrip()
|
||||
assert export_text(clf, show_weights=True) == expected_report
|
||||
|
||||
expected_report = dedent("""
|
||||
|- feature_1 <= 0.00
|
||||
| |- class: -1
|
||||
|- feature_1 > 0.00
|
||||
| |- class: 1
|
||||
""").lstrip()
|
||||
assert export_text(clf, spacing=1) == expected_report
|
||||
|
||||
X_l = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1], [-1, 1]]
|
||||
y_l = [-1, -1, -1, 1, 1, 1, 2]
|
||||
clf = DecisionTreeClassifier(max_depth=4, random_state=0)
|
||||
clf.fit(X_l, y_l)
|
||||
expected_report = dedent("""
|
||||
|--- feature_1 <= 0.00
|
||||
| |--- class: -1
|
||||
|--- feature_1 > 0.00
|
||||
| |--- truncated branch of depth 2
|
||||
""").lstrip()
|
||||
assert export_text(clf, max_depth=0) == expected_report
|
||||
|
||||
X_mo = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]
|
||||
y_mo = [[-1, -1], [-1, -1], [-1, -1], [1, 1], [1, 1], [1, 1]]
|
||||
|
||||
reg = DecisionTreeRegressor(max_depth=2, random_state=0)
|
||||
reg.fit(X_mo, y_mo)
|
||||
|
||||
expected_report = dedent("""
|
||||
|--- feature_1 <= 0.0
|
||||
| |--- value: [-1.0, -1.0]
|
||||
|--- feature_1 > 0.0
|
||||
| |--- value: [1.0, 1.0]
|
||||
""").lstrip()
|
||||
assert export_text(reg, decimals=1) == expected_report
|
||||
assert export_text(reg, decimals=1, show_weights=True) == expected_report
|
||||
|
||||
X_single = [[-2], [-1], [-1], [1], [1], [2]]
|
||||
reg = DecisionTreeRegressor(max_depth=2, random_state=0)
|
||||
reg.fit(X_single, y_mo)
|
||||
|
||||
expected_report = dedent("""
|
||||
|--- first <= 0.0
|
||||
| |--- value: [-1.0, -1.0]
|
||||
|--- first > 0.0
|
||||
| |--- value: [1.0, 1.0]
|
||||
""").lstrip()
|
||||
assert export_text(reg, decimals=1,
|
||||
feature_names=['first']) == expected_report
|
||||
assert export_text(reg, decimals=1, show_weights=True,
|
||||
feature_names=['first']) == expected_report
|
||||
|
||||
|
||||
def test_plot_tree_entropy(pyplot):
|
||||
# mostly smoke tests
|
||||
# Check correctness of export_graphviz for criterion = entropy
|
||||
clf = DecisionTreeClassifier(max_depth=3,
|
||||
min_samples_split=2,
|
||||
criterion="entropy",
|
||||
random_state=2)
|
||||
clf.fit(X, y)
|
||||
|
||||
# Test export code
|
||||
feature_names = ['first feat', 'sepal_width']
|
||||
nodes = plot_tree(clf, feature_names=feature_names)
|
||||
assert len(nodes) == 3
|
||||
assert nodes[0].get_text() == ("first feat <= 0.0\nentropy = 1.0\n"
|
||||
"samples = 6\nvalue = [3, 3]")
|
||||
assert nodes[1].get_text() == "entropy = 0.0\nsamples = 3\nvalue = [3, 0]"
|
||||
assert nodes[2].get_text() == "entropy = 0.0\nsamples = 3\nvalue = [0, 3]"
|
||||
|
||||
|
||||
def test_plot_tree_gini(pyplot):
|
||||
# mostly smoke tests
|
||||
# Check correctness of export_graphviz for criterion = gini
|
||||
clf = DecisionTreeClassifier(max_depth=3,
|
||||
min_samples_split=2,
|
||||
criterion="gini",
|
||||
random_state=2)
|
||||
clf.fit(X, y)
|
||||
|
||||
# Test export code
|
||||
feature_names = ['first feat', 'sepal_width']
|
||||
nodes = plot_tree(clf, feature_names=feature_names)
|
||||
assert len(nodes) == 3
|
||||
assert nodes[0].get_text() == ("first feat <= 0.0\ngini = 0.5\n"
|
||||
"samples = 6\nvalue = [3, 3]")
|
||||
assert nodes[1].get_text() == "gini = 0.0\nsamples = 3\nvalue = [3, 0]"
|
||||
assert nodes[2].get_text() == "gini = 0.0\nsamples = 3\nvalue = [0, 3]"
|
||||
|
||||
|
||||
# FIXME: to be removed in 0.25
|
||||
def test_plot_tree_rotate_deprecation(pyplot):
|
||||
tree = DecisionTreeClassifier()
|
||||
tree.fit(X, y)
|
||||
# test that a warning is raised when rotate is used.
|
||||
match = ("'rotate' has no effect and is deprecated in 0.23. "
|
||||
"It will be removed in 0.25.")
|
||||
with pytest.warns(FutureWarning, match=match):
|
||||
plot_tree(tree, rotate=True)
|
||||
|
||||
|
||||
def test_not_fitted_tree(pyplot):
|
||||
|
||||
# Testing if not fitted tree throws the correct error
|
||||
clf = DecisionTreeRegressor()
|
||||
with pytest.raises(NotFittedError):
|
||||
plot_tree(clf)
|
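The assertions above compare raw DOT strings; to actually render one of these trees, the string returned by export_graphviz(out_file=None) can be handed to the graphviz Python package. A usage sketch, assuming both the graphviz package and the Graphviz binaries are installed:

# Hedged usage sketch (not part of the test suite): render export_graphviz
# output with the `graphviz` package.
import graphviz
from sklearn.tree import DecisionTreeClassifier, export_graphviz

X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]
y = [-1, -1, -1, 1, 1, 1]
clf = DecisionTreeClassifier(max_depth=3, random_state=2).fit(X, y)

dot_data = export_graphviz(clf, out_file=None, filled=True, rounded=True)
graph = graphviz.Source(dot_data)         # parse the DOT source
graph.render("toy_tree", format="png")    # writes toy_tree.png (hypothetical filename)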
52
venv/Lib/site-packages/sklearn/tree/tests/test_reingold_tilford.py
Normal file
@@ -0,0 +1,52 @@
import numpy as np
import pytest
from sklearn.tree._reingold_tilford import buchheim, Tree

simple_tree = Tree("", 0,
                   Tree("", 1),
                   Tree("", 2))

bigger_tree = Tree("", 0,
                   Tree("", 1,
                        Tree("", 3),
                        Tree("", 4,
                             Tree("", 7),
                             Tree("", 8)
                             ),
                        ),
                   Tree("", 2,
                        Tree("", 5),
                        Tree("", 6)
                        )
                   )


@pytest.mark.parametrize("tree, n_nodes", [(simple_tree, 3), (bigger_tree, 9)])
def test_buchheim(tree, n_nodes):
    def walk_tree(draw_tree):
        res = [(draw_tree.x, draw_tree.y)]
        for child in draw_tree.children:
            # parents higher than children:
            assert child.y == draw_tree.y + 1
            res.extend(walk_tree(child))
        if len(draw_tree.children):
            # these trees are always binary
            # parents are centered above children
            assert draw_tree.x == (draw_tree.children[0].x
                                   + draw_tree.children[1].x) / 2
        return res

    layout = buchheim(tree)
    coordinates = walk_tree(layout)
    assert len(coordinates) == n_nodes
    # test that x values are unique per depth / level
    # we could also do it quicker using defaultdicts..
    depth = 0
    while True:
        x_at_this_depth = [node[0] for node in coordinates
                           if node[1] == depth]
        if not x_at_this_depth:
            # reached all leafs
            break
        assert len(np.unique(x_at_this_depth)) == len(x_at_this_depth)
        depth += 1
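buchheim computes the Reingold-Tilford layout that plot_tree uses to position node boxes; the test checks its two key invariants (children sit one level below their parent, and a parent is centered over its children). A quick hedged check of the same invariants using the private helpers imported above, illustrative only:

# Quick check of the layout invariants asserted in the test above; uses the
# same private helpers the test imports, so treat it as illustrative only.
from sklearn.tree._reingold_tilford import buchheim, Tree

layout = buchheim(Tree("root", 0, Tree("left", 1), Tree("right", 2)))
left, right = layout.children
assert left.y == layout.y + 1                    # children one level below the parent
assert layout.x == (left.x + right.x) / 2        # parent centered over its children
print((layout.x, layout.y), (left.x, left.y), (right.x, right.y))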
1966
venv/Lib/site-packages/sklearn/tree/tests/test_tree.py
Normal file
File diff suppressed because it is too large
18
venv/Lib/site-packages/sklearn/tree/tree.py
Normal file
@@ -0,0 +1,18 @@

# THIS FILE WAS AUTOMATICALLY GENERATED BY deprecated_modules.py
import sys
# mypy error: Module X has no attribute y (typically for C extensions)
from . import _classes  # type: ignore
from ..externals._pep562 import Pep562
from ..utils.deprecation import _raise_dep_warning_if_not_pytest

deprecated_path = 'sklearn.tree.tree'
correct_import_path = 'sklearn.tree'

_raise_dep_warning_if_not_pytest(deprecated_path, correct_import_path)

def __getattr__(name):
    return getattr(_classes, name)

if not sys.version_info >= (3, 7):
    Pep562(__name__)
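tree.py is the same shim as export.py above, pointing the deprecated sklearn.tree.tree path at sklearn.tree._classes. A hedged sketch of what importing the old path looks like at runtime (the warning is suppressed when running under pytest, and only fires on the first import of the module):

# Illustrative only: the deprecated module still resolves, but importing it
# outside pytest emits a FutureWarning pointing at sklearn.tree.
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    from sklearn.tree import tree as deprecated_tree  # deprecated alias module

print(deprecated_tree.DecisionTreeClassifier)   # forwarded to _classes by __getattr__
print([str(w.message) for w in caught])         # deprecation message, if any was raised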