当前位置:网站首页>【机器学习】用特征量重要度(feature importance)解释模型靠谱么?怎么才能算出更靠谱的重要度?
【机器学习】用特征量重要度(feature importance)解释模型靠谱么?怎么才能算出更靠谱的重要度?
2022-07-31 08:43:00 【hongxu000】
【机器学习】用特征量重要度(feature importance)解释模型靠谱么?怎么才能算出更靠谱的重要度?
我们用机器学习解决商业问题的时候,不仅需要训练一个高精度高泛化性的模型,往往还需要解释哪些因素或特征影响了预测结果。比如预测保险客户是否会解约(customer churn),我们不仅要知道谁会解约,我们还想知道他为啥解约。最常用的方式,便是查看特征量的重要度。
(*这里的特征量指用来预测结果的要因feature/predictor/independent variable,比如用客户这周的搜索关键词预测下周会不会买游戏机,搜索关键词就是特征量,下周会不会买游戏机就是预测结果)
特征量重要度的计算一般取决于用什么算法,如果是以决定树为基础(tree-based)的集成算法,比如随机森林,lightGBM之类的,一般都是取impurity(gini)平均下降幅度最大的一些特征,python的 Scikit-learn里面就有命令可以计算。
像这样:
from sklearn.ensemble import RandomForestClassifier
rf = RandomForestClassifier(
n_estimators=100,
n_jobs=-1,
min_samples_leaf = 5,
oob_score=True,
random_state = 42)
rf.fit(X_train, y_train)
feat_importances = pd.Series(rf.feature_importances_, index = X_train.columns).sort_values(ascending = True)
结果例:
但这个特征量重要度一定靠谱么?
1. 如果模型精度不好,或者泛化性不好(overfit),那特征量重要度不能轻信。
模型精度不好,当然有很多原因,有可能学习不够,数据不够。也有可能是现有的特征量无法很好的预测结果,有更重要的特征量缺失了。这时候看特征量的重要度就好像是矬子里拔大个儿,没啥太大意思。
泛化性不好,一般是模型过度学习了训练数据的一些噪音,把噪音当规律了,这时候一些本来无关紧要的噪音特征量也会登上重要度榜首,所以这时单纯相信特征量就很容易跑偏,错过重要的特征。
另外,这个用impurity计算的重要度会有bias,会倾向于把重要度给数字型特征(numerical),而忽视类别特征(categorical or binary) 论文:https://bmcbioinformatics.biomedcentral.com/articles/10.1186/1471-2105-8-25
那咋办呢?
当然最有用的是把你的模型调教好,不要有泛化性不好等问题,那就可以相信特征量重要度的结果了。
但对于impurity计算的重要度的bias,还有一个其他的方法,可以计算更靠谱的特征量。
单纯的想一下,我想知道一个特征是不是重要特征,只要把它删掉之后重新训练模型,看精度有没有下降,下降多少,就能知道它有多重要,和分手之后知道才知道她有多重要是一个原理不是= =
这完全可以自己写python代码来实现,但SikitLearn提供了一个类似的算法:permutation_importance。区别是,permutation_importance不会删掉特征,而是重新排列(shuffule)这个列的值之后再训练模型,效果上应该差不太多。历遍所有特征量,你就得到了一份更加靠谱的重要度。
计算permutation_importance
from sklearn.inspection import permutation_importance
#calculate permutation importance for test data
result_test = permutation_importance(
rf, X_test, y_test, n_repeats=20, random_state=42, n_jobs=2
)
sorted_importances_idx_test = result_test.importances_mean.argsort()
importances_test = pd.DataFrame(
result_test.importances[sorted_importances_idx_test].T,
columns=X.columns[sorted_importances_idx_test],
)
算完是这样
也许有人会问,那这种情况就不会因为overfit而导致特征量重要度不靠谱了么?
我个人的理解是,因为这个结果不是某一个模型的结果,而是训练了和特征量相同数量模型的结果,而且比较了有噪音数据和无噪音数据的情况,一定程度上排除了overfit的影响。
2. 如果模型里存在2个以上有线性关系(co-linear)的特征量,需要注意!
在线性回归linear regression里,我们会强调,不能有两个特征量有线性关系,需要删掉一个。
但在一个tree base的ensemble模型,貌似没这个要求,我们一般会把所有有关系的特征量全部丢进去,然后再通过特征量重要度选择最终模型需要的特征。
那是不是tree-base模型不用考虑共线问题了?(共线co-linear指有2个或以上特征量有线性关系)
曾经我也这样认为,但其实在理解特征量重要度的时候是需要考虑的共线问题的,因为重要度会被有线性关系的几个特征量瓜分,它们可能一人分一些重要度,那么在重要度榜单上可能就排后面了。你可能以为它们没那么重要,但事实并不如此。
那咋办呢
建议合并这些有线性关系的特征量,或者如果其中一个特征量可以完全被其他代表dependency=1,也可以删掉这个特征量。
以下是如何确认线性关系:
from rfpimp import plot_corr_heatmap
viz = plot_corr_heatmap(X_train, figsize=(7,5))
viz.view()
以下是如何确认特征之间的dependency:
dependence_matrix = feature_dependence_matrix(X_train,
rfrmodel=RandomForestRegressor(n_estimators=50, oob_score=True),
rfcmodel=RandomForestClassifier(n_estimators=50, oob_score=True),
cat_count=20,
zero=0.001,
sort_by_dependence=False,
n_samples=5000)
dependence_matrix
plot_dependence_heatmap(dependence_matrix)
总之关于特征量重要度,不是算出来就完事了。要保持清醒和谨慎,才能得到真理~
就写到这里~ 欢迎围观讨论指导~
边栏推荐
- 【C#】判断字符串中是否包含指定字符或字符串(Contains/IndexOf)
- How on one machine (Windows) to install two MYSQL database
- 刷题《剑指Offer》day06
- SQLAlchemy使用教程
- 基于golang的swagger超贴心、超详细使用指南【有很多坑】
- 【Unity】编辑器扩展-04-拓展Scene视图
- A brief introduction to the SSM framework
- Cloud server deployment web project
- 【Unity】编辑器扩展-01-拓展Project视图
- 编译器R8问题Multidex
猜你喜欢
How to Install MySQL on Linux
【MySQL功法】第2话 · 数据库与数据表的基本操作
[MySQL exercises] Chapter 2 Basic operations of databases and data tables
科目三:左转弯
【RISC-V】risc-v架构学习笔记(架构初学)
ONES 入选 CFS 财经峰会「2022数字化创新引领奖」
How to restore data using mysql binlog
SQL 入门之第一讲——MySQL 8.0.29安装教程(windows 64位)
状态机动态规划之股票问题总结
Docker-compose安装mysql
随机推荐
【MySQL功法】第4话 · 和kiko一起探索MySQL中的运算符
【pytorch记录】pytorch的分布式 torch.distributed.launch 命令在做什么呢
2019 NeurIPS | Graph Convolutional Policy Network for Goal-Directed Molecular Graph Generation
免安装版的Mysql安装与配置——详细教程
JSP application对象简介说明
【小程序项目开发-- 京东商城】uni-app之自定义搜索组件(下) -- 搜索历史
如何使用mysql binlog 恢复数据
SSM integration case study (detailed)
SQL 入门之第一讲——MySQL 8.0.29安装教程(windows 64位)
云服务器部署 Web 项目
模块化规范
服务器上解压文件时提示“gzip: stdin: not in gzip format,tar: Child returned status 1,tar: Error is not recovera“
【小程序项目开发--京东商城】uni-app之自定义搜索组件(上)-- 组件UI
剑指offer-解决面试题的思路
The torch distributed training
蚂蚁核心科技产品亮相数字中国建设峰会 持续助力企业数字化转型
MySQL安装常见报错处理大全
文件管理:目录管理
C# 正则表达式汇总
51单片机-----外部中断