当前位置:网站首页>机器学习--决策树(sklearn)
机器学习--决策树(sklearn)
2022-07-06 09:16:00 【想成为风筝】
机器学习–决策树(sklearn)
决策树是基于树结构来对实例进行决策的一种基本的分类和回归的机器学习方法。决策树由结点和有向边组成,结点分为内部结点(表示一个特征的划分)和叶子结点(表示一个类别或输出)。
决策树学习,训练数据集
𝐷 = ( 𝐱 1 , 𝑦 1 ) , ( 𝐱 2 , 𝑦 2 ) , ⋯ , ( 𝐱 𝑖 , 𝑦 𝑖 ) , ⋯ , ( 𝐱 𝑁 , 𝑦 𝑁 ) 𝐷 = {(𝐱1, 𝑦1) , (𝐱2, 𝑦2) , ⋯ , (𝐱𝑖, 𝑦𝑖) , ⋯ , (𝐱𝑁, 𝑦𝑁)} D=(x1,y1),(x2,y2),⋯,(xi,yi),⋯,(xN,yN)
其中, x i xi xi为第 个特征向量(实例),
𝐱 𝑖 = ( 𝑥 ( 𝑖 1 ) , 𝑥 ( 𝑖 2 ) , … , 𝑥 ( 𝑖 𝑗 ) , … , 𝑥 ( 𝑖 𝑛 ) ) 𝑇 𝐱𝑖 = (𝑥( 𝑖1), 𝑥( 𝑖2), … , 𝑥( 𝑖𝑗), … , 𝑥( 𝑖𝑛))^𝑇 xi=(x(i1),x(i2),…,x(ij),…,x(in))T , y i yi yi为 x i xi xi的类别标记,
𝑦 𝑖 ∈ 1 , 2 , ⋯ , 𝐾 𝑦𝑖 ∈ {1, 2, ⋯ , 𝐾} yi∈1,2,⋯,K。
学习的⽬标数是根据给定的训练数据集构建⼀个决策树模型,使得可对实例进⾏正确的分类或回归。
决策树学习包括3个步骤:特征选择、决策树⽣成、决策树修剪。
一、特征选择
特征选择在于选取对于训练数据具有分类能力的特征。
熵表示随机变量不确定性的度量。(更详细的概念理解,查阅信息论、通信原理基础。)
设 X X X是一个取有限个值的离散随机变量,其概率分布为
P ( X = x i ) = p i , i = 1 , 2 , , . . . , n P(X=xi)=pi,i=1,2,,...,n P(X=xi)=pi,i=1,2,,...,n
则随机变量 X X X的熵
H ( X ) = H ( p ) = − ∑ p i l o g 𝑝 𝑖 H(X)=H(p)= − ∑pi log 𝑝𝑖 H(X)=H(p)=−∑pilogpi
H ( p ) H(p) H(p)的取值范围: 0 < H ( p ) < l o g n 0<H(p)<log n 0<H(p)<logn,当 p = 0.5 p=0.5 p=0.5时,取得最大值。
信息增益算法如下图所示:
以信息增益为特征选择标准偏向取值较多的特征。当特征的取值较多时,根据此特征划分更容易得到纯度更高的子集,因此划分之后的熵更低。由于划分前的熵是一定的,因此信息增益更大。
二、决策树生成
ID3算法:
输入:训练数据集 D D D,特征集合 A A A,阈值 ϵ \epsilon ϵ
输出:决策树 T T T
C4.5算法:
输入:训练数据集 D D D,特征集合 A A A,阈值 ϵ \epsilon ϵ
输出:决策树 T T T
三、决策树剪枝
决策树的剪枝通过极小化决策树的整体损失函数或代价函数来实现。
设树 T T T的叶结点个数为| T T T|, t t t是树 T T T的叶节点,该叶结点有 N t N_t Nt个, k = 1 , 2 , . . . , K k=1,2,...,K k=1,2,...,K,
H t ( T ) H_t(T) Ht(T)为叶节点 t t t上的经验熵,则决策树的损失函数:
其中, C ( T ) C(T) C(T)表示模型对训练数据的预测误差,即模型与训练数据的拟合程度, ∣ T ∣ |T| ∣T∣表示模型复杂度,参数 α > = 0 α>=0 α>=0控制两者之间的影响。
树的剪枝算法:
输入:决策树 T T T,参数 α α α
输出:修建后的子树 T α T_α Tα
- 计算每个结点的经验熵
- 递归地从树的叶结点向上回缩。设一组叶结点回缩到其父结点之前与之后的整体树分别为 T B 与 T A T_B与T_A TB与TA,其对于的损失函数值分别是 C α ( T B ) 与 C α ( T A ) C_α(T_B)与C_α(T_A) Cα(TB)与Cα(TA),如果 C α ( T A ) < = C α ( T B ) C_α(T_A) <= C_α(T_B) Cα(TA)<=Cα(TB)。则进行剪枝,即父结点变为新的叶结点。
- 返回2.,直到不能继续为止,得到损失函数最小的子树 T α T_α Tα。
四、分类与回归树CART
4.1回归树的生成
4.2分类树的生成
4.3CART树剪枝
对整体树 T 0 T_0 T0任意内部结点 t t t,以 t t t为单结点树的损失函数
C α ( t ) = C ( t ) + α C_α(t) = C(t)+α Cα(t)=C(t)+α
以 t t t为根结点的子树 T t T_t Tt的损失函数
C α ( T t ) = C ( T t ) + α ∣ T t ∣ C_α(T_t) = C(T_t)+α|T_t| Cα(Tt)=C(Tt)+α∣Tt∣
决策树的实现-sklearn库调用方式及参数解释
决策树(分类树)
from sklearn.tree import DecisionTreeClassifier
DecisionTreeClassifier(criterion='gini',splitter='best',max_depth=None,
min_samples_split=2,min_samples_leaf=1,
min_weight_fraction_leaf=0.0,max_features=None,
random_state=None,max_leaf_nodes=None,
min_impurity_decrease=0.0,min_impurity_split=1e-07,
class_weight=None, presort=False)***
参数:
- criterion : 一个字符串,指定切分质量的评价标准。可以为
‘gini’ :表示切分标准是Gini系数。切分时选取基尼系数小的属性切分。
‘entropy’ : 表示切分标准是熵。 - splitter : 一个字符串,指定切分原则,可以为:
best : 表示选择最优的切分。
random : 表示随机切分。
默认的"best"适合样本量不大的时候,而如果样本数据量非常大,此时决策树构建推荐"random"。 - max_features : 可以为整数、浮点、字符或者None,指定寻找best split时考虑的特征数量。
如果是整数,则每次切分只考虑max_features个特征。
如果是浮点数,每次切分只考虑max_features*n_features个特征(max_features指定百分比)。
如果是字符串‘auto’,则max_features等于n_features。
如果是字符串‘sqrt’,则max_features等于sqrt(n_features)。
如果是字符串‘log2’,则max_features等于log2(n_features)。
如果是字符串None,则max_features等于n_features。 - max_depth : 可以为整数或者None,指定树的最大深度,防止过拟合
如果为None,表示树的深度不限(知道每个叶子都是纯的,即叶子结点中的所有样本点都属于一个类,
或者叶子中包含小于min_sanples_split个样本点)。
如果max_leaf_nodes参数非None,则忽略此项。 - min_samples_split : 为整数,指定每个内部节点(非叶子节点)包含的最少的样本数。
- min_samples_leaf : 为整数,指定每个叶子结点包含的最少的样本数。
- min_weight_fraction_leaf : 为浮点数,叶子节点中样本的最小权重系数。
- max_leaf_nodes : 为整数或None,指定叶子结点的最大数量。
如果为None,此时叶子节点数不限。如果非None,则max_depth被忽略。 - min_impurity_decrease=0.0 如果该分裂导致不纯度的减少大于或等于该值,则将分裂节点。
- min_impurity_split=1e-07, 限制决策树的增长,
- class_weight : 一个字典、字典的列表、字符串‘balanced’或者None,他指定了分类的权重。
权重形式为:{class_label:weight} 如果为None,则每个分类权重都为1.
字符串‘balanced’表示每个分类的权重是各分类在样本出现的频率的反比。 - random_state : 一个整数或者一个RandomState实例,或者None。
如果为整数,则它指定了随机数生成器的种子。
如果为RandomState实例,则指定了随机数生成器。
如果为None,则使用默认的随机数生成器。 - presort : 一个布尔值,指定了是否要提前排序数据从而加速寻找最优切分的过程。
设置为True时,对于大数据集会减慢总体训练过程,但对于小数据集或者设定了最大深度的情况下,则会加速训练过程。
属性:
- classes_ : 分类的标签值。
- feature_importances_ : 给出了特征的重要程度。该值越高,则特征越重要(也称为Gini importance)。
- max_features_ : max_feature的推断值。
- n_classes_ : 给出了分类的数量。
- n_features_ : 当执行fit后,特征的数量。
- n_outputs_ : 当执行fit后,输出的数量。
- tree_ : 一个Tree对象,即底层的决策树。
方法: - fit(X,y) : 训练模型。
- predict(X) : 用模型预测,返回预测值。
- predict_log_proba(X) : 返回一个数组,数组元素依次为X预测为各个类别的概率值的对数值。
- predict_proba(X) : 返回一个数组,数组元素依次为X预测为各个类别的概率
- score(X,y) : 返回在(X,y)上预测的准确率(accuracy)。
决策树(回归树)
from sklearn.tree import DecisionTreeRegressor
DecisionTreeRegressor(criterion='mse',splitter='best',max_depth=None,min_samples
_split=2,min_samples_leaf=1,min_weight_fraction_leaf=0.0,max_features=None,rando
m_state=None,max_leaf_nodes=None,min_impurity_split=1e-07,presort=False)
参数:
- criterion : 一个字符串,指定切分质量的评价标准。默认为‘mse’,且只支持该字符串,表示均方误差。
- splitter : 一个字符串,指定切分原则,可以为:best : 表示选择最优的切分。random : 表示随机切分。
- max_features : 可以为整数、浮点、字符或者None,指定寻找best split时考虑的特征数量。如果是整数,则每次切分只考虑max_features个特征。如果是浮点数,则每次切分只考虑max_features*n_features个特征(max_features指定了百分比)。
如果是字符串‘auto’,则max_features等于n_features。
如果是字符串‘sqrt’,则max_features等于sqrt(n_features)。
如果是字符串‘log2’,则max_features等于log2(n_features)。
如果是字符串None,则max_features等于n_features。 - max_depth : 可以为整数或者None,指定树的最大深度。
如果为None,表示树的深度不限(知道每个叶子都是纯的,即叶子结点中的所有样本点都属于一个类,或者叶子中包含小于min_sanples_split个样本点)。如果max_leaf_nodes参数非None,则忽略此项。 - min_samples_split : 为整数,指定每个内部节点(非叶子节点)包含的最少的样本数。
- min_samples_leaf : 为整数,指定每个叶子结点包含的最少的样本数。
- min_weight_fraction_leaf : 为浮点数,叶子节点中样本的最小权重系数。
- max_leaf_nodes : 为整数或None,指定叶子结点的最大数量。如果为None,此时叶子节点数不限。如果非None,则max_depth被忽略。
- class_weight : 一个字典、字典的列表、字符串‘balanced’或者None,它指定了分类的权重。权重形式为:{class_label:weight} 如果为None,则每个分类权重都为1.字符串‘balanced’表示每个分类的权重是各分类在样本出现的频率的反比。
- random_state : 一个整数或者一个RandomState实例,或者None。如果为整数,则它指定了随机数生成器的种子。如果为RandomState实例,则指定了随机数生成器。如果为None,则使用默认的随机数生成器。
- presort : 一个布尔值,指定了是否要提前排序数据从而加速寻找最优切分的过程。
设置为True时,对于大数据集会减慢总体训练过程,但对于小数据集或者设定了最大深度的情况下,则会加速训练过程。
属性:
- feature_importances_ : 给出了特征的重要程度。该值越高,则特征越重要(也称为Gini importance)。
- max_features_ : max_feature的推断值。
- n_features_ : 当执行fit后,特征的数量。
- n_outputs_ : 当执行fit后,输出的数量。
- tree_ : 一个Tree对象,即底层的决策树。
方法:
- fit(X,y) : 训练模型。
- predict(X) : 用模型预测,返回预测值。
- score(X,y) : 返回性能得分
鸢尾花分类(决策树案例,直接调用sklearn)
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_graphviz
import graphviz
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
def creat_data():
iris = load_iris()
df = pd.DataFrame(iris.data,columns=iris.feature_names)
df['label'] = iris.target
df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']
data = np.array(df.iloc[:100,[0,1,-1]])
return data[:,:2],data[:,-1]
X,y = creat_data()
X_train,X_test ,y_train,y_test = train_test_split(X,y)
clf = DecisionTreeClassifier()
clf.fit(X_train,y_train)
print(clf.score(X_test,y_test))
print(clf.feature_importances_) #k查看特征重要性程度
#其他属性均可查看
tree_pic = export_graphviz(clf,out_file='Tree.pdf')
with open("Tree.pdf") as f:
dot_graph = f.read()
graphviz.Source(dot_graph)
画出下图需要使用graphviz。安装比较麻烦!
边栏推荐
- [Blue Bridge Cup 2017 preliminary] buns make up
- Contiki source code + principle + function + programming + transplantation + drive + network (turn)
- Mysql的索引实现之B树和B+树
- Those commonly used tool classes and methods in hutool
- vs2019 第一个MFC应用程序
- Integration test practice (1) theoretical basis
- 4. Install and deploy spark (spark on Yan mode)
- DICOM: Overview
- {one week summary} take you into the ocean of JS knowledge
- 【yarn】CDP集群 Yarn配置capacity调度器批量分配
猜你喜欢
How to build a new project for keil5mdk (with super detailed drawings)
wangeditor富文本引用、表格使用问题
About string immutability
Stage 4 MySQL database
{一周总结}带你走进js知识的海洋
vs2019 第一个MFC应用程序
Solve the problem of installing failed building wheel for pilot
Word排版(小计)
FTP文件上传文件实现,定时扫描文件夹上传指定格式文件文件到服务器,C语言实现FTP文件上传详解及代码案例实现
[Flink] Flink learning
随机推荐
[yarn] CDP cluster yarn configuration capacity scheduler batch allocation
Word排版(小計)
AcWing 1294. Cherry Blossom explanation
[Bluebridge cup 2020 preliminary] horizontal segmentation
Aborted connection 1055898 to db:
互联网协议详解
Word排版(小计)
SQL时间注入
【flink】flink学习
MySQL与c语言连接(vs2019版)
[AGC009D]Uninity
Contiki源码+原理+功能+编程+移植+驱动+网络(转)
MySQL and C language connection (vs2019 version)
C语言读取BMP文件
Gallery之图片浏览、组件学习
Composition des mots (sous - total)
天梯赛练习集题解LV1(all)
Reading BMP file with C language
[蓝桥杯2021初赛] 砝码称重
Solve the problem of installing failed building wheel for pilot