当前位置:网站首页>决策树、GBDT、XGBOOST树的可视化

决策树、GBDT、XGBOOST树的可视化

2022-08-03 21:51:00 `AllureLove

决策树

# 决策树单颗树的可视化
from sklearn import tree
import pydotplus

# 假设已经训练好的决策树模型对象为decision_tree_model
# 决策树可视化存储
dot_data = tree.export_graphviz(decision_tree_model, out_file=None, filled=True, rounded=True, special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)
graph.write_png("decision_tree.png")

GBDT

from sklearn import tree
import pydotplus


def plot_tree(model, n_tree=0, feature_name_list):
    """ 绘制第n颗树结构 @param model: 模型 @param n_tree: 第几棵树,起始编号为0 @param feature_name_list: 用到的特征名称列表 @return: """
    # 从存储树对象的数列中获取第n颗树
    # sklearn中的其他集成策略应该也可以通过此方式来获取子树
    sub_tree = model.estimators_[n_tree, 0]
    dot_data = tree.export_graphviz(sub_tree, out_file=None, filled=True, rounded=True, special_characters=True, feature_names=feature_name_list)
    graph = pydotplus.graph_from_dot_data(dot_data)
    graph.write_png("{}_tree.png".format(n_tree))

# 绘制树,假设已经训练好的GBDT模型为gbdt_model,假设特征名称列表为feature_name_list
plot_tree(gbdt_model, n_tree=1, feature_name_list)

XGBOOST

import xgboost as xgb
import matplotlib.pyplot as plt

def create_feature_map(feature_name_list, save_path=None):
    """ 创建特征映射 @parameter feature_name_list: 特征名称列表 @return: """
    save_path = "/".join([save_path, "xgb.fmap"])
    with open(save_path, "w") as f:
        i = 0
        for feat in feature_name_list:
            f.write('{0}\t{1}\tq\n'.format(i, feat))
            i = i + 1
    return save_path
    
def plot_tree(model_path, num_trees=0, feature_name_list)):
    """ 绘制树结构 num_trees表示树的序号 @return: """
    # 绘制树结构
    model = xgb.Booster({
    'nthread': 4})
    model.load_model(model_path)
    save_prefix = "/".join(model_path.split("/")[:-1])
    fmap_save_path = create_feature_map(feature_name_list), save_path=save_prefix)
    xgb.plot_tree(model, num_trees=num_trees, fmap=fmap_save_path, rankdir="LR")
    fig = plt.gcf()
    fig.set_size_inches(150, 150)
    tree_save_path = "/".join([save_prefix, "tree.tiff"])
    plt.savefig(tree_save_path)

# 假设已经训练好的模型路径为model_path,假设模型特征名称列表为feature_name_list
plot_tree(model_path, num_trees, feature_name_list)
原网站

版权声明
本文为[`AllureLove]所创,转载请带上原文链接,感谢
https://blog.csdn.net/weixin_36488653/article/details/126129401