Model evaluation functions in the sklearn.metrics module
2022-07-24 15:17:00 【qq_27390023】
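sklearn offers three different APIs for evaluating the quality of a model's predictions.
Estimator score method: every estimator has a score method that provides a default evaluation criterion for the problem it is designed to solve.
Scoring parameter: model evaluation tools that use cross-validation (such as model_selection.cross_val_score and model_selection.GridSearchCV) rely on an internal scoring strategy.
Metric functions: the sklearn.metrics module implements functions that assess prediction error for specific purposes.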
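As a rough illustration of how the three APIs relate, the sketch below scores the same kind of model in three ways. The iris dataset, `LogisticRegression`, the train/test split, and all variable names are choices made for this example only and do not come from the original post.

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import cross_val_score, train_test_split

# Illustrative data and model (assumptions made for this sketch)
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)

# 1. Estimator score method: for classifiers the default criterion is accuracy
score_from_estimator = clf.score(X_test, y_test)

# 2. Scoring parameter: cross-validation tools accept a scorer name
cv_scores = cross_val_score(
    LogisticRegression(max_iter=1000), X, y, cv=5, scoring="accuracy"
)

# 3. Metric function: called directly on true and predicted labels
score_from_metric = accuracy_score(y_test, clf.predict(X_test))

print(score_from_estimator, score_from_metric)
print(cv_scores)
```

For a classifier, the first and third values should agree, since the default score of a classifier is its accuracy on the given data.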
Examples of model evaluation functions:
from sklearn import metrics
# List the names defined in the module
dir(metrics)
### 1.Accuracy score
import numpy as np
from sklearn.metrics import accuracy_score
y_pred = [0, 2, 1, 3]
y_true = [0, 1, 2, 3]
print(accuracy_score(y_true, y_pred))
print(accuracy_score(y_true, y_pred, normalize=False)) # number of correct predictions
### 2.Top-k accuracy score
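# top_k_accuracy_score is a generalization of accuracy_score.
# The difference is that a prediction counts as correct as long as the true label
# is among the k highest predicted scores.
# accuracy_score is the special case k=1.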
import numpy as np
from sklearn.metrics import top_k_accuracy_score
y_true = np.array([0, 1, 2, 2])
y_score = np.array([[0.5, 0.2, 0.2],
[0.3, 0.4, 0.2],
[0.2, 0.4, 0.3],
[0.7, 0.2, 0.1]])
top_k_accuracy_score(y_true, y_score, k=2)
# Not normalizing gives the number of "correctly" classified samples
top_k_accuracy_score(y_true, y_score, k=2, normalize=False)
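The rule described above can be sketched by hand with NumPy. This is only an illustration, not scikit-learn's implementation: the helper name `top_k_accuracy_manual` is hypothetical, it assumes the labels are the integers 0..n_classes-1 (so a label equals its column index in `y_score`), and it does not treat tied scores specially.

```python
import numpy as np

def top_k_accuracy_manual(y_true, y_score, k):
    """Fraction of samples whose true label is among the k highest scores."""
    y_true = np.asarray(y_true)
    # Column indices of the k largest scores in each row
    top_k = np.argsort(y_score, axis=1)[:, -k:]
    # A sample is a hit if its true label appears among those indices
    hits = (top_k == y_true[:, None]).any(axis=1)
    return hits.mean()

y_true = np.array([0, 1, 2, 2])
y_score = np.array([[0.5, 0.2, 0.2],
                    [0.3, 0.4, 0.2],
                    [0.2, 0.4, 0.3],
                    [0.7, 0.2, 0.1]])

top2 = top_k_accuracy_manual(y_true, y_score, k=2)
top1 = top_k_accuracy_manual(y_true, y_score, k=1)
print(top2)  # 0.75: only the last sample misses
print(top1)  # 0.5: same as plain accuracy on the argmax predictions
```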
### 3.confusion_matrix
from sklearn import datasets
from sklearn.svm import LinearSVC
from sklearn.model_selection import cross_validate
from sklearn.metrics import confusion_matrix,ConfusionMatrixDisplay
import matplotlib.pyplot as plt
y_true = [2, 0, 2, 2, 0, 1]
y_pred = [0, 0, 2, 2, 0, 2]
cm=confusion_matrix(y_true, y_pred)
print(confusion_matrix(y_true, y_pred))
print(confusion_matrix(y_true, y_pred,normalize='all'))
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot()
plt.show()
# Binary classification
y_true = [0, 0, 0, 1, 1, 1, 1, 1]
y_pred = [0, 1, 0, 1, 0, 1, 0, 1]
tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
print(tn, fp, fn, tp)
# A sample toy binary classification dataset
X, y = datasets.make_classification(n_classes=2, random_state=0)
svm = LinearSVC(random_state=0)
def confusion_matrix_scorer(clf, X, y):
    y_pred = clf.predict(X)
    cm = confusion_matrix(y, y_pred)
    return {'tn': cm[0, 0], 'fp': cm[0, 1],
            'fn': cm[1, 0], 'tp': cm[1, 1]}
cv_results = cross_validate(svm, X, y, cv=5,
scoring=confusion_matrix_scorer)
# Getting the test set true positive scores
print(cv_results['test_tp'])
# Getting the test set false negative scores
print(cv_results['test_fn'])
print(cv_results['test_tn'])
print(cv_results['test_fp'])
### 4.classification_report
from sklearn.metrics import classification_report
y_true = [0, 1, 2, 2, 0]
y_pred = [0, 0, 2, 1, 0]
target_names = ['class 0', 'class 1', 'class 2']
print(classification_report(y_true, y_pred, target_names=target_names))
### 5. hamming_loss
from sklearn.metrics import hamming_loss
y_pred = [1, 2, 3, 4]
y_true = [2, 2, 3, 4]
hamming_loss(y_true, y_pred)
### 6. Precision, recall and F-measures
from sklearn import metrics
y_pred = [0, 1, 0, 0]
y_true = [0, 1, 0, 1]
metrics.precision_score(y_true, y_pred)
metrics.recall_score(y_true, y_pred)
metrics.f1_score(y_true, y_pred)
metrics.fbeta_score(y_true, y_pred, beta=0.5)
metrics.fbeta_score(y_true, y_pred, beta=1)
metrics.fbeta_score(y_true, y_pred, beta=2)
metrics.precision_recall_fscore_support(y_true, y_pred, beta=0.5)
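To make the relationship between these scores concrete, the sketch below recomputes precision, recall, and F-beta from raw counts for the same binary example. The helper name `fbeta_from_counts` is hypothetical, and the sketch does not handle division by zero, which scikit-learn controls through its `zero_division` parameter.

```python
def fbeta_from_counts(tp, fp, fn, beta=1.0):
    """F-beta from binary counts; beta > 1 weights recall more, beta < 1 precision."""
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    b2 = beta ** 2
    return (1 + b2) * precision * recall / (b2 * precision + recall)

y_pred = [0, 1, 0, 0]
y_true = [0, 1, 0, 1]

# Count outcomes for the positive class (label 1)
tp = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))
fp = sum(t == 0 and p == 1 for t, p in zip(y_true, y_pred))
fn = sum(t == 1 and p == 0 for t, p in zip(y_true, y_pred))

precision = tp / (tp + fp)   # 1.0
recall = tp / (tp + fn)      # 0.5
f_half = fbeta_from_counts(tp, fp, fn, beta=0.5)
f_one = fbeta_from_counts(tp, fp, fn, beta=1)
f_two = fbeta_from_counts(tp, fp, fn, beta=2)
print(precision, recall, f_half, f_one, f_two)
```

With precision at 1.0 and recall at 0.5, the score falls as beta grows, because a larger beta gives more weight to the weaker recall.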
import numpy as np
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import average_precision_score
y_true = np.array([0, 0, 1, 1])
y_scores = np.array([0.1, 0.4, 0.35, 0.8])
precision, recall, threshold = precision_recall_curve(y_true, y_scores)
print(precision)
print(recall)
print(threshold)
print( average_precision_score(y_true, y_scores))
## Multiclass
from sklearn import metrics
y_true = [0, 1, 2, 0, 1, 2]
y_pred = [0, 2, 1, 0, 0, 1]
print(metrics.precision_score(y_true, y_pred, average='macro'))
print(metrics.recall_score(y_true, y_pred, average='micro'))
print(metrics.f1_score(y_true, y_pred, average='weighted'))
print(metrics.fbeta_score(y_true, y_pred, average='macro', beta=0.5))
print(metrics.precision_recall_fscore_support(y_true, y_pred, beta=0.5, average=None))
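### 7. r2_score for regression
# r2_score computes the coefficient of determination, usually denoted R².
# It represents the proportion of the variance of y that is explained by the
# independent variables in the model. It gives an indication of goodness of fit,
# and therefore a measure of how well unseen samples are likely to be predicted
# by the model, through the proportion of explained variance.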
from sklearn.metrics import r2_score
y_true = [3, -0.5, 2, 7]
y_pred = [2.5, 0.0, 2, 8]
r2_score(y_true, y_pred)
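The "proportion of explained variance" can be checked by hand. The sketch below applies the usual definition R² = 1 - SS_res / SS_tot to the same data; the variable names are illustrative.

```python
import numpy as np

y_true = np.array([3, -0.5, 2, 7])
y_pred = np.array([2.5, 0.0, 2, 8])

# Residual sum of squares: what the model fails to explain
ss_res = np.sum((y_true - y_pred) ** 2)
# Total sum of squares: variation of y_true around its mean
ss_tot = np.sum((y_true - y_true.mean()) ** 2)

r2_manual = 1 - ss_res / ss_tot
print(ss_res, ss_tot, r2_manual)  # 1.5, 29.1875, about 0.9486
```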
### 8. mean_absolute_error for regression
from sklearn.metrics import mean_absolute_error
y_true = [3, -0.5, 2, 7]
y_pred = [2.5, 0.0, 2, 8]
mean_absolute_error(y_true, y_pred)
### 9. mean_squared_error for regression
from sklearn.metrics import mean_squared_error
y_true = [3, -0.5, 2, 7]
y_pred = [2.5, 0.0, 2, 8]
mean_squared_error(y_true, y_pred)
### 10. mean_squared_log_error for regression
from sklearn.metrics import mean_squared_log_error
y_true = [3, 5, 2.5, 7]
y_pred = [2.5, 5, 4, 8]
mean_squared_log_error(y_true, y_pred)
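For reference, the mean squared logarithmic error is the mean of the squared differences between log(1 + y_true) and log(1 + y_pred). A minimal NumPy sketch on the same data, with illustrative variable names:

```python
import numpy as np

y_true = np.array([3, 5, 2.5, 7])
y_pred = np.array([2.5, 5, 4, 8])

# log1p(x) computes log(1 + x), so a target of zero is still valid
log_diff = np.log1p(y_true) - np.log1p(y_pred)
msle_manual = np.mean(log_diff ** 2)
print(msle_manual)  # about 0.0397
```

Because the error is measured on a log scale, it reflects relative rather than absolute differences, and it is only defined for non-negative values.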
### 11. Silhouette Coefficient for unsupervised clustering
from sklearn import metrics
from sklearn import datasets
import numpy as np
X, y = datasets.load_iris(return_X_y=True)
from sklearn.cluster import KMeans
kmeans_model = KMeans(n_clusters=3, random_state=1).fit(X)
labels = kmeans_model.labels_
metrics.silhouette_score(X, labels, metric='euclidean')
References:
https://scikit-learn.org/stable/modules/model_evaluation.html#multilabel-ranking-metrics
https://scikit-learn.org/stable/modules/clustering.html#clustering-evaluation