Uploaded Test files
This commit is contained in:
parent
f584ad9d97
commit
2e81cb7d99
16627 changed files with 2065359 additions and 102444 deletions
0
venv/Lib/site-packages/sklearn/tree/tests/__init__.py
Normal file
0
venv/Lib/site-packages/sklearn/tree/tests/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
469
venv/Lib/site-packages/sklearn/tree/tests/test_export.py
Normal file
469
venv/Lib/site-packages/sklearn/tree/tests/test_export.py
Normal file
|
@ -0,0 +1,469 @@
|
|||
"""
|
||||
Testing for export functions of decision trees (sklearn.tree.export).
|
||||
"""
|
||||
from re import finditer, search
|
||||
from textwrap import dedent
|
||||
|
||||
from numpy.random import RandomState
|
||||
import pytest
|
||||
|
||||
from sklearn.base import is_classifier
|
||||
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
|
||||
from sklearn.ensemble import GradientBoostingClassifier
|
||||
from sklearn.tree import export_graphviz, plot_tree, export_text
|
||||
from io import StringIO
|
||||
from sklearn.exceptions import NotFittedError
|
||||
|
||||
# Toy dataset shared by the tests in this module.
# Six two-feature samples; the first three and the last three form two groups.
X = [
    [-2, -1], [-1, -1], [-1, -2],
    [1, 1], [1, 2], [2, 1],
]
# Single-output targets: one label per sample.
y = [-1, -1, -1, 1, 1, 1]
# Two-output targets: a pair of labels per sample.
y2 = [
    [-1, 1], [-1, 1], [-1, 1],
    [1, 2], [1, 2], [1, 3],
]
# Per-sample weights: the last three samples count half.
w = [1, 1, 1, 0.5, 0.5, 0.5]
# Degenerate targets: every sample carries the same label.
y_degraded = [1, 1, 1, 1, 1, 1]
|
||||
|
||||
|
||||
def test_graphviz_toy():
    """Compare ``export_graphviz`` output with hand-written DOT source.

    Each section fits a small tree on the module-level toy data, exports it
    with a particular set of options and requires the returned string to be
    exactly equal to the expected DOT text.
    """
    clf = DecisionTreeClassifier(max_depth=3,
                                 min_samples_split=2,
                                 criterion="gini",
                                 random_state=2)
    clf.fit(X, y)

    # Default export
    actual = export_graphviz(clf, out_file=None)
    expected = (
        'digraph Tree {\n'
        'node [shape=box] ;\n'
        '0 [label="X[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
        'value = [3, 3]"] ;\n'
        '1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]"] ;\n'
        '0 -> 1 [labeldistance=2.5, labelangle=45, '
        'headlabel="True"] ;\n'
        '2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]"] ;\n'
        '0 -> 2 [labeldistance=2.5, labelangle=-45, '
        'headlabel="False"] ;\n'
        '}'
    )
    assert actual == expected

    # Export with feature_names
    actual = export_graphviz(clf, feature_names=["feature0", "feature1"],
                             out_file=None)
    expected = (
        'digraph Tree {\n'
        'node [shape=box] ;\n'
        '0 [label="feature0 <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
        'value = [3, 3]"] ;\n'
        '1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]"] ;\n'
        '0 -> 1 [labeldistance=2.5, labelangle=45, '
        'headlabel="True"] ;\n'
        '2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]"] ;\n'
        '0 -> 2 [labeldistance=2.5, labelangle=-45, '
        'headlabel="False"] ;\n'
        '}'
    )
    assert actual == expected

    # Export with class_names
    actual = export_graphviz(clf, class_names=["yes", "no"], out_file=None)
    expected = (
        'digraph Tree {\n'
        'node [shape=box] ;\n'
        '0 [label="X[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
        'value = [3, 3]\\nclass = yes"] ;\n'
        '1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]\\n'
        'class = yes"] ;\n'
        '0 -> 1 [labeldistance=2.5, labelangle=45, '
        'headlabel="True"] ;\n'
        '2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]\\n'
        'class = no"] ;\n'
        '0 -> 2 [labeldistance=2.5, labelangle=-45, '
        'headlabel="False"] ;\n'
        '}'
    )
    assert actual == expected

    # Export with plot options (HTML-like labels)
    actual = export_graphviz(clf, filled=True, impurity=False,
                             proportion=True, special_characters=True,
                             rounded=True, out_file=None)
    # NOTE(review): the comparison sign is spelled as the HTML entity
    # "&le;" here; confirm against the exporter's special_characters output.
    expected = (
        'digraph Tree {\n'
        'node [shape=box, style="filled, rounded", color="black", '
        'fontname=helvetica] ;\n'
        'edge [fontname=helvetica] ;\n'
        '0 [label=<X<SUB>0</SUB> &le; 0.0<br/>samples = 100.0%<br/>'
        'value = [0.5, 0.5]>, fillcolor="#ffffff"] ;\n'
        '1 [label=<samples = 50.0%<br/>value = [1.0, 0.0]>, '
        'fillcolor="#e58139"] ;\n'
        '0 -> 1 [labeldistance=2.5, labelangle=45, '
        'headlabel="True"] ;\n'
        '2 [label=<samples = 50.0%<br/>value = [0.0, 1.0]>, '
        'fillcolor="#399de5"] ;\n'
        '0 -> 2 [labeldistance=2.5, labelangle=-45, '
        'headlabel="False"] ;\n'
        '}'
    )
    assert actual == expected

    # Export with max_depth: children of the root become "(...)" stubs
    actual = export_graphviz(clf, max_depth=0,
                             class_names=True, out_file=None)
    expected = (
        'digraph Tree {\n'
        'node [shape=box] ;\n'
        '0 [label="X[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
        'value = [3, 3]\\nclass = y[0]"] ;\n'
        '1 [label="(...)"] ;\n'
        '0 -> 1 ;\n'
        '2 [label="(...)"] ;\n'
        '0 -> 2 ;\n'
        '}'
    )
    assert actual == expected

    # Export with max_depth combined with plot options
    actual = export_graphviz(clf, max_depth=0, filled=True,
                             out_file=None, node_ids=True)
    expected = (
        'digraph Tree {\n'
        'node [shape=box, style="filled", color="black"] ;\n'
        '0 [label="node #0\\nX[0] <= 0.0\\ngini = 0.5\\n'
        'samples = 6\\nvalue = [3, 3]", fillcolor="#ffffff"] ;\n'
        '1 [label="(...)", fillcolor="#C0C0C0"] ;\n'
        '0 -> 1 ;\n'
        '2 [label="(...)", fillcolor="#C0C0C0"] ;\n'
        '0 -> 2 ;\n'
        '}'
    )
    assert actual == expected

    # Multi-output classifier fitted with weighted samples
    clf = DecisionTreeClassifier(max_depth=2,
                                 min_samples_split=2,
                                 criterion="gini",
                                 random_state=2)
    clf = clf.fit(X, y2, sample_weight=w)

    actual = export_graphviz(clf, filled=True,
                             impurity=False, out_file=None)
    expected = (
        'digraph Tree {\n'
        'node [shape=box, style="filled", color="black"] ;\n'
        '0 [label="X[0] <= 0.0\\nsamples = 6\\n'
        'value = [[3.0, 1.5, 0.0]\\n'
        '[3.0, 1.0, 0.5]]", fillcolor="#ffffff"] ;\n'
        '1 [label="samples = 3\\nvalue = [[3, 0, 0]\\n'
        '[3, 0, 0]]", fillcolor="#e58139"] ;\n'
        '0 -> 1 [labeldistance=2.5, labelangle=45, '
        'headlabel="True"] ;\n'
        '2 [label="X[0] <= 1.5\\nsamples = 3\\n'
        'value = [[0.0, 1.5, 0.0]\\n'
        '[0.0, 1.0, 0.5]]", fillcolor="#f1bd97"] ;\n'
        '0 -> 2 [labeldistance=2.5, labelangle=-45, '
        'headlabel="False"] ;\n'
        '3 [label="samples = 2\\nvalue = [[0, 1, 0]\\n'
        '[0, 1, 0]]", fillcolor="#e58139"] ;\n'
        '2 -> 3 ;\n'
        '4 [label="samples = 1\\nvalue = [[0.0, 0.5, 0.0]\\n'
        '[0.0, 0.0, 0.5]]", fillcolor="#e58139"] ;\n'
        '2 -> 4 ;\n'
        '}'
    )
    assert actual == expected

    # Regressor exported with plot options
    clf = DecisionTreeRegressor(max_depth=3,
                                min_samples_split=2,
                                criterion="mse",
                                random_state=2)
    clf.fit(X, y)

    actual = export_graphviz(clf, filled=True, leaves_parallel=True,
                             out_file=None, rotate=True, rounded=True)
    expected = (
        'digraph Tree {\n'
        'node [shape=box, style="filled, rounded", color="black", '
        'fontname=helvetica] ;\n'
        'graph [ranksep=equally, splines=polyline] ;\n'
        'edge [fontname=helvetica] ;\n'
        'rankdir=LR ;\n'
        '0 [label="X[0] <= 0.0\\nmse = 1.0\\nsamples = 6\\n'
        'value = 0.0", fillcolor="#f2c09c"] ;\n'
        '1 [label="mse = 0.0\\nsamples = 3\\nvalue = -1.0", '
        'fillcolor="#ffffff"] ;\n'
        '0 -> 1 [labeldistance=2.5, labelangle=-45, '
        'headlabel="True"] ;\n'
        '2 [label="mse = 0.0\\nsamples = 3\\nvalue = 1.0", '
        'fillcolor="#e58139"] ;\n'
        '0 -> 2 [labeldistance=2.5, labelangle=45, '
        'headlabel="False"] ;\n'
        '{rank=same ; 0} ;\n'
        '{rank=same ; 1; 2} ;\n'
        '}'
    )
    assert actual == expected

    # Classifier fitted on a degraded learning set (a single label)
    clf = DecisionTreeClassifier(max_depth=3)
    clf.fit(X, y_degraded)

    actual = export_graphviz(clf, filled=True, out_file=None)
    expected = (
        'digraph Tree {\n'
        'node [shape=box, style="filled", color="black"] ;\n'
        '0 [label="gini = 0.0\\nsamples = 6\\nvalue = 6.0", '
        'fillcolor="#ffffff"] ;\n'
        '}'
    )
    # This comparison was previously missing, so the degraded case only
    # verified that the export did not raise.
    assert actual == expected
|
||||
|
||||
|
||||
def test_graphviz_errors():
    """Check that ``export_graphviz`` rejects invalid input.

    Covers an unfitted estimator, a ``feature_names`` list of the wrong
    length, a non-estimator argument, an empty ``class_names`` list and
    invalid ``precision`` values.
    """
    tree_clf = DecisionTreeClassifier(max_depth=3, min_samples_split=2)

    # An estimator that has not been fitted yet
    with pytest.raises(NotFittedError):
        export_graphviz(tree_clf, StringIO())

    tree_clf.fit(X, y)

    # feature_names whose length differs from the number of features (2):
    # once too short, once too long
    for names in (["a"], ["a", "b", "c"]):
        expected_msg = ("Length of feature_names, "
                        "%d does not match number of features, 2"
                        % len(names))
        with pytest.raises(ValueError, match=expected_msg):
            export_graphviz(tree_clf, None, feature_names=names)

    # The low-level tree_ attribute is passed instead of the estimator
    with pytest.raises(TypeError, match="is not an estimator instance"):
        export_graphviz(tree_clf.fit(X, y).tree_)

    # Empty class_names
    with pytest.raises(IndexError):
        export_graphviz(tree_clf, StringIO(), class_names=[])

    # Invalid precision: negative, then non-integer (same output buffer)
    sink = StringIO()
    with pytest.raises(ValueError, match="should be greater or equal"):
        export_graphviz(tree_clf, sink, precision=-1)
    with pytest.raises(ValueError, match="should be an integer"):
        export_graphviz(tree_clf, sink, precision="1")
|
||||
|
||||
|
||||
def test_friedman_mse_in_graphviz():
    """Check that exported node labels mention ``friedman_mse``.

    A regressor using the ``friedman_mse`` criterion and the first tree of
    each gradient-boosting stage are exported into one shared buffer; every
    bracketed chunk containing ``samples`` must contain ``friedman_mse``.
    """
    buffer = StringIO()

    regressor = DecisionTreeRegressor(criterion="friedman_mse",
                                      random_state=0)
    regressor.fit(X, y)
    export_graphviz(regressor, out_file=buffer)

    booster = GradientBoostingClassifier(n_estimators=2, random_state=0)
    booster.fit(X, y)
    for stage in booster.estimators_:
        export_graphviz(stage[0], out_file=buffer)

    dot_source = buffer.getvalue()
    for node_match in finditer(r"\[.*?samples.*?\]", dot_source):
        node_text = node_match.group()
        assert "friedman_mse" in node_text
|
||||
|
||||
|
||||
def test_precision():
    """Check that ``export_graphviz`` honours the ``precision`` argument.

    A depth-1 regressor and a depth-1 classifier are fitted on random data
    and exported with ``precision`` 4 and 3. Impurities and thresholds must
    show exactly ``precision`` decimals; plain ``value`` entries may show
    fewer.
    """
    rng_reg = RandomState(2)
    rng_clf = RandomState(8)

    # Each generator is drawn from in the same order as before
    # (features first, then targets), so the fitted data is unchanged.
    reg_features = rng_reg.random_sample((5, 2))
    clf_features = rng_clf.random_sample((1000, 4))
    reg_targets = rng_reg.random_sample((5, ))
    clf_targets = rng_clf.randint(2, size=(1000, ))

    cases = [
        (reg_features, reg_targets,
         DecisionTreeRegressor(criterion="friedman_mse", random_state=0,
                               max_depth=1)),
        (clf_features, clf_targets,
         DecisionTreeClassifier(max_depth=1, random_state=0)),
    ]

    def fraction_width(text):
        # Length of the first ".digits" run in ``text``, dot included.
        return len(search(r"\.\d+", text).group())

    for features, targets, estimator in cases:
        estimator.fit(features, targets)

        if is_classifier(estimator):
            impurity_pattern = r"gini = \d+\.\d+"
        else:
            impurity_pattern = r"friedman_mse = \d+\.\d+"

        for precision in (4, 3):
            dot_data = export_graphviz(estimator, out_file=None,
                                       precision=precision, proportion=True)

            # With the current random state, the impurity and the threshold
            # are printed with exactly ``precision`` decimals, so they are
            # checked with strict equality. The reported value may carry
            # fewer decimals, hence only an upper bound is checked for it.
            width = precision + 1  # the decimals plus the leading dot

            # check value
            for found in finditer(r"value = \d+\.\d+", dot_data):
                assert fraction_width(found.group()) <= width

            # check impurity
            for found in finditer(impurity_pattern, dot_data):
                assert fraction_width(found.group()) == width

            # check threshold
            for found in finditer(r"<= \d+\.\d+", dot_data):
                assert fraction_width(found.group()) == width
|
||||
|
||||
|
||||
def test_export_text_errors():
    """Check that ``export_text`` raises ``ValueError`` on bad arguments."""
    clf = DecisionTreeClassifier(max_depth=2, random_state=0)
    clf.fit(X, y)

    # (keyword arguments, expected error message), checked in this order.
    # NOTE(review): "bust" is kept verbatim; it presumably mirrors the text
    # raised by the library, so verify there before correcting the spelling.
    invalid_calls = [
        ({"max_depth": -1},
         "max_depth bust be >= 0, given -1"),
        ({"feature_names": ['a']},
         "feature_names must contain 2 elements, got 1"),
        ({"decimals": -1},
         "decimals must be >= 0, given -1"),
        ({"spacing": 0},
         "spacing must be > 0, given 0"),
    ]

    for kwargs, err_msg in invalid_calls:
        with pytest.raises(ValueError, match=err_msg):
            export_text(clf, **kwargs)
|
||||
|
||||
|
||||
def test_export_text():
    """Compare ``export_text`` reports with hand-written expectations.

    Covers the default report, custom feature names, ``show_weights``,
    a custom ``spacing``, truncation through ``max_depth`` and multi-output
    regressors with ``decimals=1``.

    NOTE(review): the column spacing inside the expected reports (three
    characters between a ``|`` and the next ``|---`` at the default spacing,
    and two spaces after ``>``) was restored from the documented
    ``export_text`` output format; confirm against the library version in
    use.
    """
    clf = DecisionTreeClassifier(max_depth=2, random_state=0)
    clf.fit(X, y)

    expected_report = dedent("""
    |--- feature_1 <= 0.00
    |   |--- class: -1
    |--- feature_1 >  0.00
    |   |--- class: 1
    """).lstrip()

    assert export_text(clf) == expected_report
    # leaves directly below the root are still printed with max_depth=0
    assert export_text(clf, max_depth=0) == expected_report
    # a max_depth larger than the tree changes nothing
    assert export_text(clf, max_depth=10) == expected_report

    # custom feature names replace the default "feature_<i>" labels
    expected_report = dedent("""
    |--- b <= 0.00
    |   |--- class: -1
    |--- b >  0.00
    |   |--- class: 1
    """).lstrip()
    assert export_text(clf, feature_names=['a', 'b']) == expected_report

    # show_weights adds the per-class weights to each leaf
    expected_report = dedent("""
    |--- feature_1 <= 0.00
    |   |--- weights: [3.00, 0.00] class: -1
    |--- feature_1 >  0.00
    |   |--- weights: [0.00, 3.00] class: 1
    """).lstrip()
    assert export_text(clf, show_weights=True) == expected_report

    # spacing=1 shortens both the indentation and the branch markers
    expected_report = dedent("""
    |- feature_1 <= 0.00
    | |- class: -1
    |- feature_1 >  0.00
    | |- class: 1
    """).lstrip()
    assert export_text(clf, spacing=1) == expected_report

    # a deeper tree: with max_depth=0 the non-leaf branch is summarised
    X_l = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1], [-1, 1]]
    y_l = [-1, -1, -1, 1, 1, 1, 2]
    clf = DecisionTreeClassifier(max_depth=4, random_state=0)
    clf.fit(X_l, y_l)
    expected_report = dedent("""
    |--- feature_1 <= 0.00
    |   |--- class: -1
    |--- feature_1 >  0.00
    |   |--- truncated branch of depth 2
    """).lstrip()
    assert export_text(clf, max_depth=0) == expected_report

    # multi-output regressor on two features
    X_mo = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]
    y_mo = [[-1, -1], [-1, -1], [-1, -1], [1, 1], [1, 1], [1, 1]]

    reg = DecisionTreeRegressor(max_depth=2, random_state=0)
    reg.fit(X_mo, y_mo)

    expected_report = dedent("""
    |--- feature_1 <= 0.0
    |   |--- value: [-1.0, -1.0]
    |--- feature_1 >  0.0
    |   |--- value: [1.0, 1.0]
    """).lstrip()
    assert export_text(reg, decimals=1) == expected_report
    # show_weights is expected to leave a regressor's report unchanged
    assert export_text(reg, decimals=1, show_weights=True) == expected_report

    # multi-output regressor on a single, explicitly named feature
    X_single = [[-2], [-1], [-1], [1], [1], [2]]
    reg = DecisionTreeRegressor(max_depth=2, random_state=0)
    reg.fit(X_single, y_mo)

    expected_report = dedent("""
    |--- first <= 0.0
    |   |--- value: [-1.0, -1.0]
    |--- first >  0.0
    |   |--- value: [1.0, 1.0]
    """).lstrip()
    assert export_text(reg, decimals=1,
                       feature_names=['first']) == expected_report
    assert export_text(reg, decimals=1, show_weights=True,
                       feature_names=['first']) == expected_report
|
||||
|
||||
|
||||
def test_plot_tree_entropy(pyplot):
    """Smoke-test ``plot_tree`` on a classifier using the entropy criterion.

    The toy tree must yield three annotations whose texts describe the root
    split and its two leaves.
    """
    tree_clf = DecisionTreeClassifier(max_depth=3,
                                      min_samples_split=2,
                                      criterion="entropy",
                                      random_state=2)
    tree_clf.fit(X, y)

    annotations = plot_tree(tree_clf,
                            feature_names=['first feat', 'sepal_width'])
    assert len(annotations) == 3

    root, first_leaf, second_leaf = annotations
    assert root.get_text() == ("first feat <= 0.0\nentropy = 1.0\n"
                               "samples = 6\nvalue = [3, 3]")
    assert first_leaf.get_text() == (
        "entropy = 0.0\nsamples = 3\nvalue = [3, 0]")
    assert second_leaf.get_text() == (
        "entropy = 0.0\nsamples = 3\nvalue = [0, 3]")
|
||||
|
||||
|
||||
def test_plot_tree_gini(pyplot):
    """Smoke-test ``plot_tree`` on a classifier using the gini criterion.

    The toy tree must yield three annotations whose texts describe the root
    split and its two leaves.
    """
    tree_clf = DecisionTreeClassifier(max_depth=3,
                                      min_samples_split=2,
                                      criterion="gini",
                                      random_state=2)
    tree_clf.fit(X, y)

    annotations = plot_tree(tree_clf,
                            feature_names=['first feat', 'sepal_width'])
    assert len(annotations) == 3

    root, first_leaf, second_leaf = annotations
    assert root.get_text() == ("first feat <= 0.0\ngini = 0.5\n"
                               "samples = 6\nvalue = [3, 3]")
    assert first_leaf.get_text() == (
        "gini = 0.0\nsamples = 3\nvalue = [3, 0]")
    assert second_leaf.get_text() == (
        "gini = 0.0\nsamples = 3\nvalue = [0, 3]")
|
||||
|
||||
|
||||
# FIXME: to be removed in 0.25
def test_plot_tree_rotate_deprecation(pyplot):
    """Check that passing ``rotate`` to ``plot_tree`` warns about deprecation."""
    fitted_tree = DecisionTreeClassifier()
    fitted_tree.fit(X, y)

    expected_warning = ("'rotate' has no effect and is deprecated in 0.23. "
                        "It will be removed in 0.25.")
    # a FutureWarning carrying this message must be emitted
    with pytest.warns(FutureWarning, match=expected_warning):
        plot_tree(fitted_tree, rotate=True)
|
||||
|
||||
|
||||
def test_not_fitted_tree(pyplot):
    """Check that ``plot_tree`` raises ``NotFittedError`` on an unfitted tree."""
    unfitted = DecisionTreeRegressor()
    with pytest.raises(NotFittedError):
        plot_tree(unfitted)
|
|
@ -0,0 +1,52 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
from sklearn.tree._reingold_tilford import buchheim, Tree
|
||||
|
||||
# Fixture trees for the layout test. Each Tree(...) call below receives an
# empty string, an integer, and then any child Tree objects.

# Root 0 with two childless nodes, 1 and 2 (3 nodes in total).
simple_tree = Tree("", 0, Tree("", 1), Tree("", 2))

# Root 0 -> (1, 2); 1 -> (3, 4); 4 -> (7, 8); 2 -> (5, 6) (9 nodes in total).
bigger_tree = Tree(
    "", 0,
    Tree(
        "", 1,
        Tree("", 3),
        Tree("", 4, Tree("", 7), Tree("", 8)),
    ),
    Tree("", 2, Tree("", 5), Tree("", 6)),
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tree, n_nodes", [(simple_tree, 3), (bigger_tree, 9)])
def test_buchheim(tree, n_nodes):
    """Check basic geometric properties of the ``buchheim`` layout.

    Every child must sit exactly one level below its parent, every parent
    must be centred over its first two children, the layout must contain
    ``n_nodes`` nodes, and no two nodes on one level may share an x value.
    """
    def collect_positions(node):
        # Return (x, y) for ``node`` and all its descendants, checking the
        # parent/child constraints along the way.
        positions = [(node.x, node.y)]
        for kid in node.children:
            # each child is one level below its parent
            assert kid.y == node.y + 1
            positions.extend(collect_positions(kid))
        if len(node.children) > 0:
            # these trees are always binary, so the parent must lie midway
            # between its two children
            left, right = node.children[0], node.children[1]
            assert node.x == (left.x + right.x) / 2
        return positions

    coordinates = collect_positions(buchheim(tree))
    assert len(coordinates) == n_nodes

    # x values must be unique within each depth level; walk the levels
    # from the root until one turns out to be empty
    depth = 0
    xs_on_level = [x for x, y in coordinates if y == depth]
    while xs_on_level:
        assert len(np.unique(xs_on_level)) == len(xs_on_level)
        depth += 1
        xs_on_level = [x for x, y in coordinates if y == depth]
|
1966
venv/Lib/site-packages/sklearn/tree/tests/test_tree.py
Normal file
1966
venv/Lib/site-packages/sklearn/tree/tests/test_tree.py
Normal file
File diff suppressed because it is too large
Load diff
Loading…
Add table
Add a link
Reference in a new issue