Decision Tree

Inspector functionality specific to tree models
import sklearn.datasets
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

from model_inspector import get_inspector


_TreeInspector.plot_tree

 _TreeInspector.plot_tree (ax:Optional[matplotlib.axes._axes.Axes]=None,
                           max_depth=None, feature_names=None,
                           class_names=None, label='all', filled=False,
                           impurity=True, node_ids=False,
                           proportion=False, rounded=False, precision=3,
                           fontsize=None)

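Show the fitted decision tree.

The remaining parameters (everything except ax) are passed to sklearn.tree._export.plot_tree, which scikit-learn exposes publicly as sklearn.tree.plot_tree.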
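The wrapper described above can be approximated with plain scikit-learn and matplotlib. The sketch below is hypothetical: plot_tree_sketch is not part of model_inspector. It assumes the wrapper creates an Axes when none is supplied, defaults feature_names to the DataFrame's columns (an assumption, not confirmed by the source), forwards every other keyword to sklearn.tree.plot_tree, and returns the Axes.

```python
from typing import Optional

import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import pandas as pd
import sklearn.datasets
from matplotlib.axes import Axes
from sklearn.tree import DecisionTreeClassifier, plot_tree


def plot_tree_sketch(model, X: pd.DataFrame, ax: Optional[Axes] = None, **kwargs) -> Axes:
    """Hypothetical stand-in for _TreeInspector.plot_tree (not model_inspector's code)."""
    if ax is None:
        # Create a figure only when the caller did not supply an Axes
        _, ax = plt.subplots(figsize=(12, 6))
    # Assumption: label splits with the DataFrame's column names unless overridden
    kwargs.setdefault("feature_names", list(X.columns))
    # Forward all remaining options (max_depth, filled, precision, ...) unchanged
    plot_tree(model, ax=ax, **kwargs)
    return ax


X, y = sklearn.datasets.load_iris(return_X_y=True, as_frame=True)
model = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)
ax = plot_tree_sketch(model, X, filled=True, rounded=True)
```

Returning the Axes mirrors the `ax = inspector.plot_tree()` pattern in the examples, so callers can keep customising the figure afterwards.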

Regression Example

X, y = sklearn.datasets.load_diabetes(return_X_y=True, as_frame=True)
inspector = get_inspector(DecisionTreeRegressor(max_depth=3).fit(X, y), X, y)
ax = inspector.plot_tree()
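In a regression tree plot, each node's "value" is the tree's prediction for that node. With scikit-learn's default squared-error criterion, a leaf's value is the mean target of the training rows that reach it. The sketch below checks this directly with scikit-learn (random_state=0 is added here only for reproducibility).

```python
import numpy as np
import sklearn.datasets
from sklearn.tree import DecisionTreeRegressor

X, y = sklearn.datasets.load_diabetes(return_X_y=True, as_frame=True)
reg = DecisionTreeRegressor(max_depth=3, random_state=0).fit(X, y)

leaf_ids = reg.apply(X)                 # leaf index reached by each training row
leaf_values = reg.tree_.value[:, 0, 0]  # per-node "value" drawn in the plot
y_arr = y.to_numpy()

# Compare each leaf's stored value with the mean target of the rows in that leaf
leaves = np.unique(leaf_ids)
leaf_means = np.array([y_arr[leaf_ids == leaf].mean() for leaf in leaves])
max_gap = np.max(np.abs(leaf_values[leaves] - leaf_means))
print(f"{len(leaves)} leaves, largest gap between plotted value and leaf mean: {max_gap:.2e}")
```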

Binary Classification Example

X, y = sklearn.datasets.load_breast_cancer(return_X_y=True, as_frame=True)
inspector = get_inspector(DecisionTreeClassifier(max_depth=3).fit(X, y), X, y)
ax = inspector.plot_tree()
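Options from the signature above, such as class_names, filled, and proportion, make classification plots easier to read. Because they are forwarded to scikit-learn, inspector.plot_tree(class_names=..., filled=True) should presumably behave the same way. The sketch below calls sklearn.tree.plot_tree directly so it runs without model_inspector installed.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend
import matplotlib.pyplot as plt
import sklearn.datasets
from sklearn.tree import DecisionTreeClassifier, plot_tree

data = sklearn.datasets.load_breast_cancer(as_frame=True)
X, y = data.data, data.target
clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)

fig, ax = plt.subplots(figsize=(16, 8))
annotations = plot_tree(
    clf,
    ax=ax,
    feature_names=list(X.columns),
    class_names=list(data.target_names),  # "malignant" for label 0, "benign" for label 1
    filled=True,      # colour each node by its majority class
    proportion=True,  # show samples as percentages and values as class fractions
)
node_text = [a.get_text() for a in annotations]
```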

Multiclass Example

X, y = sklearn.datasets.load_iris(return_X_y=True, as_frame=True)
dtc = DecisionTreeClassifier(max_depth=3).fit(X, y)
inspector = get_inspector(dtc, X, y)
ax = inspector.plot_tree()
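With three classes, each node's "value" holds one entry per class. A fitted classifier's predict_proba returns the class distribution of the leaf a row lands in, so the plotted leaf values, once normalised, match the predicted probabilities. The sketch below normalises explicitly because scikit-learn versions differ in whether tree_.value stores counts or fractions. It assumes no sample weights were used.

```python
import numpy as np
import sklearn.datasets
from sklearn.tree import DecisionTreeClassifier

X, y = sklearn.datasets.load_iris(return_X_y=True, as_frame=True)
dtc = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)

leaf_ids = dtc.apply(X)            # leaf reached by each row
values = dtc.tree_.value[:, 0, :]  # one column per class for every node
# Normalise each node's row so it sums to 1, whatever the stored scale
leaf_fractions = values / values.sum(axis=1, keepdims=True)

proba = dtc.predict_proba(X)
matches = np.allclose(proba, leaf_fractions[leaf_ids])
print("predict_proba equals leaf class fractions:", matches)
```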