Ben Chuanlong Du's Blog

It is never too late to learn.

Visualization of GBDT in scikit-learn

Things on this page are fragmentary and immature notes/thoughts of the author. Please read with your own judgement!

In [ ]:
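import matplotlib.pyplot as plt
from sklearn import tree

# classifier3: a fitted Pipeline whose last step is a DecisionTreeClassifier (built earlier, not shown)
fig, ax = plt.subplots(figsize=(50, 5))
tree.plot_tree(classifier3["decisiontreeclassifier"],
               # names of the features produced by the preprocessing steps
               feature_names=classifier3[:-1].get_feature_names_out(),
               class_names=["Good", "Bad"],  # must follow the order of classes_
               rounded=True,
               precision=1,
               filled=True,
               impurity=False,
               fontsize=10,
               ax=ax)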
In [ ]:
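import matplotlib.pyplot as plt
from sklearn import tree

gbdt = classifier["gradientboostingclassifier"]
feature_names = classifier[:-1].get_feature_names_out()

# estimators_ has shape (n_estimators, 1) for binary classification, so
# stage[0] is the single DecisionTreeRegressor fit at each boosting stage.
# This draws one figure per stage; slice estimators_ to plot fewer trees.
for stage in gbdt.estimators_:
    dt = stage[0]
    fig, ax = plt.subplots(figsize=(50, 5))
    # These are regression trees, so class_names does not apply; with the default
    # log_loss, leaf values are updates to the raw log-odds score (scaled by learning_rate).
    tree.plot_tree(dt,
                   feature_names=feature_names,
                   rounded=True,
                   precision=1,
                   filled=True,
                   impurity=False,
                   fontsize=10,
                   ax=ax)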

References

How to visualize an sklearn GradientBoostingClassifier? https://stackoverflow.com/questions/44974360/how-to-visualize-an-sklearn-gradientboostingclassifier
