当前位置:网站首页>ResNet50+k折交叉验证+数据增强+画图(准确率、召回率、F值)
ResNet50+k折交叉验证+数据增强+画图(准确率、召回率、F值)
2022-07-29 01:33:00 【落难Coder】
k折交叉验证
def kfold(data, k=5, random_state=None):
    """Yield ``(train_loader, valid_loader)`` pairs for k-fold cross-validation.

    The dataset is split into ``k`` shuffled folds. For each fold, the
    samples of that fold become the validation set and the remaining samples
    become the training set.

    Samples are wrapped in ``torch.utils.data.Subset`` instead of being
    copied into Python lists, so items are fetched lazily by the loaders
    rather than all being held in memory at once.

    Args:
        data: An indexable dataset supporting ``len()`` and ``data[i]``.
        k: Number of folds.
        random_state: Optional seed forwarded to ``KFold`` to make the split
            reproducible. ``None`` keeps the split random.

    Yields:
        Tuples ``(train_data, valid_data)`` of ``DataLoader`` objects. Both
        use the module-level ``batch_size``; only the training loader
        shuffles.
    """
    indices = np.arange(len(data))
    splitter = KFold(n_splits=k, shuffle=True, random_state=random_state)
    for train_idxs, valid_idxs in splitter.split(indices):
        # Subset keeps a reference to the dataset plus the chosen indices.
        train_subset = torch.utils.data.Subset(data, train_idxs.tolist())
        valid_subset = torch.utils.data.Subset(data, valid_idxs.tolist())
        train_data = torch.utils.data.DataLoader(
            train_subset, shuffle=True, batch_size=batch_size)
        valid_data = torch.utils.data.DataLoader(
            valid_subset, batch_size=batch_size)
        yield train_data, valid_data
完整代码
from glob import glob
import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import shutil
from torchvision import transforms
from torchvision import models
import torch
from torch.autograd import Variable
import torch.nn as nn
from torch.optim import lr_scheduler
from torch import optim
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid
from sklearn.model_selection import KFold
from mpl_toolkits.axes_grid1 import host_subplot
import time
%matplotlib inline
def try_gpu(i=0):
    """Return the ``cuda:i`` device if it exists, otherwise the CPU device.

    Args:
        i: Zero-based index of the requested CUDA device.

    Returns:
        ``torch.device('cuda:<i>')`` when at least ``i + 1`` CUDA devices are
        visible, else ``torch.device('cpu')``.
    """
    if torch.cuda.device_count() >= i + 1:
        # The replacement field must stay on one line: a line break inside
        # the braces of a single-quoted f-string is a SyntaxError before
        # Python 3.12.
        return torch.device(f'cuda:{i}')
    return torch.device('cpu')
# Count the PNG files found one directory level below the train folder.
path = '../data/data-8/train/'
files = glob(os.path.join(path, '*/*.png'))
# The f-string field is kept on a single line; splitting it across lines
# is a SyntaxError before Python 3.12.
print(f'Total train of images {len(files)}')

# Same count for the test folder (``path`` and ``files`` are reused).
path = '../data/data-8/test/'
files = glob(os.path.join(path, '*/*.png'))
print(f'Total valid of images {len(files)}')
# Side length (pixels) every image is resized to, and the loader batch size.
imag_size = 224
batch_size = 16

# Training-time pipeline: resize, then random augmentation, then tensor
# conversion.
transform = transforms.Compose([
    transforms.Resize(size=(imag_size, imag_size)),
    # Random horizontal flip.
    transforms.RandomHorizontalFlip(),
    # Random vertical flip.
    transforms.RandomVerticalFlip(),
    # Random rotation by an angle drawn from [-45, 45] degrees.
    transforms.RandomRotation(degrees=45),
    # Random crop; the crop window has the same size as the resized image.
    transforms.RandomCrop(size=(imag_size, imag_size)),
    transforms.ToTensor(),
    # Normalisation is intentionally left disabled:
    # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Test-time pipeline: deterministic resize and tensor conversion only.
transform_test = transforms.Compose([
    transforms.Resize(size=(imag_size, imag_size)),
    transforms.ToTensor(),
])

# One sub-directory per class is expected by ImageFolder.
train_imgs = ImageFolder(root='../data/data-8/train', transform=transform)
test_imgs = ImageFolder(root='../data/data-8/test', transform=transform_test)
def kfold(data, k=5, random_state=None):
    """Yield ``(train_loader, valid_loader)`` pairs for k-fold cross-validation.

    The dataset is split into ``k`` shuffled folds. For each fold, the
    samples of that fold become the validation set and the remaining samples
    become the training set.

    Samples are wrapped in ``torch.utils.data.Subset`` instead of being
    copied into Python lists, so items are fetched lazily by the loaders
    rather than all being held in memory at once.

    Args:
        data: An indexable dataset supporting ``len()`` and ``data[i]``.
        k: Number of folds.
        random_state: Optional seed forwarded to ``KFold`` to make the split
            reproducible. ``None`` keeps the split random.

    Yields:
        Tuples ``(train_data, valid_data)`` of ``DataLoader`` objects. Both
        use the module-level ``batch_size``; only the training loader
        shuffles.
    """
    indices = np.arange(len(data))
    splitter = KFold(n_splits=k, shuffle=True, random_state=random_state)
    for train_idxs, valid_idxs in splitter.split(indices):
        # Subset keeps a reference to the dataset plus the chosen indices.
        train_subset = torch.utils.data.Subset(data, train_idxs.tolist())
        valid_subset = torch.utils.data.Subset(data, valid_idxs.tolist())
        train_data = torch.utils.data.DataLoader(
            train_subset, shuffle=True, batch_size=batch_size)
        valid_data = torch.utils.data.DataLoader(
            valid_subset, batch_size=batch_size)
        yield train_data, valid_data
def train(data, isTrain=True):
    """Run one pass of the global ``model`` over ``data``.

    Uses the module-level ``model``, ``optimizer``, ``criterion`` and
    ``device``.

    Args:
        data: Iterable of ``(inputs, labels)`` batches (e.g. a DataLoader).
        isTrain: When true the model is put in training mode and its weights
            are updated after every batch; when false the model is put in
            eval mode and no gradients are computed.

    Returns:
        ``(loss, acc)`` as 0-dim tensors on ``device``: the per-sample mean
        loss and the fraction of correctly classified samples.
    """
    if isTrain:
        model.train()
    else:
        model.eval()

    # Tensors (not Python numbers) so callers can keep using ``.to('cpu')``.
    running_loss = torch.zeros((), device=device)
    running_corrects = torch.zeros((), dtype=torch.long, device=device)
    n_samples = 0

    # Autograd is only needed when the weights are going to be updated.
    with torch.set_grad_enabled(isTrain):
        for inputs, labels in data:
            inputs, labels = inputs.to(device), labels.to(device)
            if isTrain:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if isTrain:
                loss.backward()
                optimizer.step()

            preds = outputs.detach().argmax(dim=1)
            n_batch = labels.size(0)
            # Assumes ``criterion`` returns the batch mean; weighting by the
            # batch size keeps the average correct for a short last batch.
            running_loss += loss.detach() * n_batch
            running_corrects += (preds == labels).sum()
            n_samples += n_batch

    # Divide by the number of samples actually seen, not by
    # ``len(data) * batch_size``, which over-counts a partial last batch.
    denom = max(n_samples, 1)
    loss = running_loss / denom
    acc = running_corrects.float() / denom
    return loss, acc
def test(data):
    """Evaluate the global ``model`` on ``data`` and collect its predictions.

    Uses the module-level ``model`` and ``device``.

    Args:
        data: Iterable of ``(inputs, labels)`` batches (e.g. a DataLoader).

    Returns:
        A tuple ``(acc, running_corrects, real_lables, pred_lables)``:
        accuracy as a 0-dim float tensor, the number of correct predictions
        as a 0-dim integer tensor, and the true / predicted class indices as
        plain Python lists in iteration order.
    """
    real_lables, pred_lables = [], []
    model.eval()
    running_corrects = torch.zeros((), dtype=torch.long, device=device)
    n_samples = 0

    # Pure inference: no gradients are needed.
    with torch.no_grad():
        for inputs, labels in data:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            preds = outputs.argmax(dim=1)
            real_lables.extend(labels.tolist())
            pred_lables.extend(preds.tolist())
            running_corrects += (preds == labels).sum()
            n_samples += labels.size(0)

    # Divide by the number of samples actually seen, not by
    # ``len(data) * batch_size``, which over-counts a partial last batch.
    acc = running_corrects.float() / max(n_samples, 1)
    return acc, running_corrects, real_lables, pred_lables
# Hyperparameters and model setup.
# Per-epoch history used later for plotting.
train_iterations, train_loss, test_accuracy = [], [], []

model = models.resnet50(pretrained=False)
# Optional: load pretrained weights before replacing the head, e.g.
# model.load_state_dict(torch.load('data/resnet50-19c8e357.pth'))

# torchvision's ResNet exposes its final linear layer as ``fc``;
# ``classifier[6]`` is the VGG layout and raises AttributeError on ResNet.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 8)  # 8 output classes

k = 5  # number of cross-validation folds
lr, num_epochs = 2e-4, 20
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
device = try_gpu(0)
# Training loop.
# Metrics are recorded/printed every ``log_every`` epochs and always for the
# final epoch; each epoch is recorded at most once.
log_every = 1

print('训练开始 on', device)
model.to(device)
for epoch in range(num_epochs):
    loss, acc = 0.0, 0.0
    # NOTE(review): the same model instance is trained on every fold, so
    # validation samples of one fold are training samples of another.
    # These numbers are therefore not an unbiased k-fold estimate.
    for train_data, valid_data in kfold(train_imgs, k):
        train_losses, train_acc = train(train_data)
        valid_losses, valid_acc = train(valid_data, False)
        loss += valid_losses
        acc += valid_acc

    is_last_epoch = epoch == num_epochs - 1
    if epoch % log_every == 0 or is_last_epoch:
        # Fold-averaged validation loss / accuracy for this epoch.
        train_iterations.append(epoch)
        train_loss.append((loss/k).to('cpu').item())
        test_accuracy.append((acc/k).to('cpu').item())
        print('{},{:.4f},{:.4f}'.format(epoch, loss/k, acc/k))
print('训练结束')
# NOTE(review): the file name still says "vgg16" although the model built
# above is a ResNet-50; the path is left unchanged in case other tooling
# relies on it.
torch.save(model, '../model/20220507-vgg16-数据增强-'+str(lr) +'.pth')
# Evaluate the trained model on the test set and report the result.
print('测试 on', device)
model.to(device)
test_data = torch.utils.data.DataLoader(
    dataset=test_imgs,
    batch_size=batch_size,
    shuffle=True,
)
# ``real`` / ``pre`` hold the true and predicted class indices per sample.
acc, corrects, real, pre = test(test_data)
summary = '准确率: {:.4f} 正确预测个数: {}'.format(acc, corrects)
print(summary)
# k = 0
# while k < len(real_lables):
# print(real_lables[k], pred_lables[k])
# k = k + 1
# Plot the recorded loss and accuracy against the epoch index, using a
# host axis (left y-axis) plus a twin axis (right y-axis).
host = host_subplot(111)
# Shrink the plot area so the right-hand axis label fits in the window.
plt.subplots_adjust(right=0.8)
par1 = host.twinx()

# Axis labels.
host.set_xlabel("iterations")
host.set_ylabel("loss")
par1.set_ylabel("validation accuracy")

# Each series is drawn twice: once as a line, once as point markers.
p1 = host.plot(train_iterations, train_loss, "b-", label="training loss")[0]
p2 = host.plot(train_iterations, train_loss, ".")[0]
p3 = par1.plot(train_iterations, test_accuracy,
               label="validation accuracy")[0]
p4 = par1.plot(train_iterations, test_accuracy, "1")[0]

# Legend location codes: 1 upper right, 2 upper left, 3 lower left,
# 4 lower right, 5 right (vertically centred).
host.legend(loc=5)

# Colour each y-axis label like the curve it belongs to.
host.axis["left"].label.set_color(p1.get_color())
par1.axis["right"].label.set_color(p3.get_color())

# Restrict the x-axis to the epochs that were run.
host.set_xlim((0, num_epochs - 1))
plt.draw()
plt.show()
#--------------------------------------------------------------------------
# Part 1: per-class precision, recall and F-measure (8 classes)
#--------------------------------------------------------------------------
def _count_per_class(real_labels, pred_labels, num_classes=8):
    """Count true, predicted and correctly predicted samples per class."""
    real_counts = [0] * num_classes
    pred_counts = [0] * num_classes
    right_counts = [0] * num_classes
    for truth, guess in zip(real_labels, pred_labels):
        truth, guess = int(truth), int(guess)
        real_counts[truth] += 1
        pred_counts[guess] += 1
        if truth == guess:
            right_counts[truth] += 1
    return real_counts, pred_counts, right_counts


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or 0.0 for a zero denominator."""
    return numerator / denominator if denominator else 0.0


# The counters must start at zero. Seeding them with ``list(range(0, 8))``
# would add the class index to every count and skew all metrics.
real_8, pre_8, right_8 = _count_per_class(real, pre)

# Precision = correct predictions / samples predicted as this class.
precision = [_safe_ratio(right, predicted)
             for right, predicted in zip(right_8, pre_8)]
print('准确率: ')
print(precision)

# Recall = correct predictions / samples that truly belong to this class.
recall = [_safe_ratio(right, actual)
          for right, actual in zip(right_8, real_8)]
print('召回率: ')
print(recall)

# F-measure = 2 * precision * recall / (precision + recall).
f_measure = [_safe_ratio(2 * p * r, p + r)
             for p, r in zip(precision, recall)]
print('F值: ')
print(f_measure)
#--------------------------------------------------------------------------
# Part 2: grouped bar chart of the per-class metrics
#--------------------------------------------------------------------------
# One group of three bars per class.
n_groups = 8
fig, ax = plt.subplots()
index = np.arange(n_groups)
bar_width = 0.2
opacity = 0.4
error_config = {'ecolor': '0.3'}

# To render Chinese tick labels, enable a CJK font, e.g.:
# plt.rcParams['font.sans-serif']=['SimHei']

# Keyword arguments shared by all three series.
shared_bar_kwargs = dict(alpha=opacity, error_kw=error_config)

# Each series is shifted right by one more bar width than the previous one.
rects1 = ax.bar(index, precision, bar_width,
                color='b', label='precision', **shared_bar_kwargs)
rects2 = ax.bar(index + bar_width, recall, bar_width,
                color='m', label='recall', **shared_bar_kwargs)
rects3 = ax.bar(index + bar_width + bar_width, f_measure, bar_width,
                color='r', label='f_measure', **shared_bar_kwargs)

# Tick positions and class names.
tick_offset = 3 * bar_width / 3
class_names = (
    '0-desk', '1-dining table', '2-double bed', '3-sofa',
    '4-squatting toilet', '5-TV cabinet', '6-wardrobe', '7-washbasin',
)
ax.set_xticks(index + tick_offset)
ax.set_xticklabels(class_names)

# Legend, axis titles and figure size.
ax.legend()
plt.xlabel("lable")
plt.ylabel("evaluation")
fig.set_figheight(5)
fig.set_figwidth(10)
fig.tight_layout()
# plt.savefig('result.png', dpi=200)
plt.show()
边栏推荐
- Number of consecutive subarrays with leetcode/ and K
- Control the pop-up window and no pop-up window of the input box
- What is browser fingerprint recognition
- 「活动推荐」冲冲冲!2022 国际开源节有新内容
- (arxiv-2018) 重新审视基于视频的 Person ReID 的时间建模
- 第十五天(VLAN相关知识)
- Qt源码分析--QObject(4)
- LM13丨形态量化-动量周期分析
- H5 background music is played automatically by touch
- 点击回到顶部js
猜你喜欢

JetPack--Navigation实现页面跳转

druid. io index_ Realtime real-time query

Introduction to shared data center agent

FPGA实现10M多功能信号发生器

字符流综合练习解题过程

Monadic linear function perceptron: Rosenblatt perceptron

webview攻击

“蔚来杯“2022牛客暑期多校训练营2,签到题GJK

Realization of digital tube display based on C51

druid. The performance of IO + tranquility real-time tasks is summarized with the help of 2020 double 11
随机推荐
Mathematical modeling -- red wine quality classification
Ciscn 2022 central China Misc
特殊流&Properties属性集实例遇到的问题及解决方法
How to crawl web pages with playwright?
Introduction to shared data center agent
12. < tag dynamic programming and subsequence, subarray> lt.72. edit distance
Feynman learning method (symbol table)
“蔚来杯“2022牛客暑期多校训练营2,签到题GJK
年中总结 | 与自己对话,活在当下,每走一步都算数
Rgbd point cloud down sampling
"Activity recommendation" rush rush! 2022 international open source Festival has new content
【云原生与5G】微服务加持5G核心网
Qt源码分析--QObject(4)
druid. The performance of IO + tranquility real-time tasks is summarized with the help of 2020 double 11
The problem of modifying the coordinate system of point cloud image loaded by ndtmatching function in autoware
基于C51实现数码管的显示
字符流综合练习解题过程
第十四天:续第十三天标签相关知识
Comprehensive analysis of news capture doorway
mobile-picker.js