当前位置:网站首页>ResNet50+k折交叉验证+数据增强+画图(准确率、召回率、F值)
ResNet50+k折交叉验证+数据增强+画图(准确率、召回率、F值)
2022-07-29 01:33:00 【落难Coder】
k折交叉验证
def kfold(data, k=5):
    """Yield a (train_loader, valid_loader) pair for each fold of a K-fold split.

    The sample indices ``0 .. len(data) - 1`` are split by a shuffled
    ``KFold``; no ``random_state`` is passed, so every call produces a
    different partition.  The batch size is read from the module-level
    ``batch_size``.

    Args:
        data: Indexable dataset with a length (``data[i]`` and ``len(data)``).
        k: Number of folds.

    Yields:
        Tuple ``(train_data, valid_data)`` of ``DataLoader`` objects; the
        training loader shuffles its samples, the validation loader does not.
    """
    splitter = KFold(n_splits=k, shuffle=True)
    for train_idxs, valid_idxs in splitter.split(np.arange(len(data))):
        # Subset defers data[i] until a batch is actually requested, so the
        # whole dataset is no longer decoded into Python lists for every fold.
        train_subset = torch.utils.data.Subset(data, train_idxs.tolist())
        valid_subset = torch.utils.data.Subset(data, valid_idxs.tolist())
        train_data = torch.utils.data.DataLoader(
            train_subset, shuffle=True, batch_size=batch_size)
        valid_data = torch.utils.data.DataLoader(
            valid_subset, batch_size=batch_size)
        yield train_data, valid_data
完整代码
from glob import glob
import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import shutil
from torchvision import transforms
from torchvision import models
import torch
from torch.autograd import Variable
import torch.nn as nn
from torch.optim import lr_scheduler
from torch import optim
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid
from sklearn.model_selection import KFold
from mpl_toolkits.axes_grid1 import host_subplot
import time
%matplotlib inline
def try_gpu(i=0):
    """Return the ``cuda:i`` device if it exists, otherwise the CPU device.

    Args:
        i: Zero-based index of the requested CUDA device.

    Returns:
        ``torch.device('cuda:i')`` when at least ``i + 1`` CUDA devices are
        reported by ``torch.cuda.device_count()``, else ``torch.device('cpu')``.
    """
    if torch.cuda.device_count() >= i + 1:
        # The replacement field must stay on one line: a line break inside
        # the braces is a SyntaxError on Python versions before 3.12.
        return torch.device(f'cuda:{i}')
    return torch.device('cpu')
# Report how many PNG files sit one directory level below each split folder.
# The f-string replacement fields are kept on a single line; splitting them
# across lines is a SyntaxError on Python versions before 3.12.
path = '../data/data-8/train/'
files = glob(os.path.join(path, '*/*.png'))
print(f'Total train of images {len(files)}')
# NOTE(review): the message below says "valid" but the directory is the test split.
path = '../data/data-8/test/'
files = glob(os.path.join(path, '*/*.png'))
print(f'Total valid of images {len(files)}')
# Square input edge length (pixels) and the batch size used by every loader.
imag_size = 224
batch_size = 16

_side = (imag_size, imag_size)

# Training pipeline: resize, random flips / rotation / crop, then to tensor.
# Normalisation with [0.485, 0.456, 0.406] / [0.229, 0.224, 0.225] is left
# disabled, as in the original.
transform = transforms.Compose(
    [
        transforms.Resize(_side),
        transforms.RandomHorizontalFlip(),  # random horizontal flip
        transforms.RandomVerticalFlip(),    # random vertical flip
        transforms.RandomRotation(45),      # random rotation
        transforms.RandomCrop(_side),       # random-position crop
        transforms.ToTensor(),
    ]
)

# Test pipeline: deterministic resize and tensor conversion only.
transform_test = transforms.Compose(
    [
        transforms.Resize(_side),
        transforms.ToTensor(),
    ]
)

# One sub-directory per class under each split root.
train_imgs = ImageFolder('../data/data-8/train', transform=transform)
test_imgs = ImageFolder('../data/data-8/test', transform=transform_test)
def kfold(data, k=5):
    """Yield a (train_loader, valid_loader) pair for each fold of a K-fold split.

    The sample indices ``0 .. len(data) - 1`` are split by a shuffled
    ``KFold``; no ``random_state`` is passed, so every call produces a
    different partition.  The batch size is read from the module-level
    ``batch_size``.

    Args:
        data: Indexable dataset with a length (``data[i]`` and ``len(data)``).
        k: Number of folds.

    Yields:
        Tuple ``(train_data, valid_data)`` of ``DataLoader`` objects; the
        training loader shuffles its samples, the validation loader does not.
    """
    splitter = KFold(n_splits=k, shuffle=True)
    for train_idxs, valid_idxs in splitter.split(np.arange(len(data))):
        # Subset defers data[i] until a batch is actually requested, so the
        # whole dataset is no longer decoded into Python lists for every fold.
        train_subset = torch.utils.data.Subset(data, train_idxs.tolist())
        valid_subset = torch.utils.data.Subset(data, valid_idxs.tolist())
        train_data = torch.utils.data.DataLoader(
            train_subset, shuffle=True, batch_size=batch_size)
        valid_data = torch.utils.data.DataLoader(
            valid_subset, batch_size=batch_size)
        yield train_data, valid_data
def train(data, isTrain=True):
    """Run one pass over ``data`` and return the per-sample loss and accuracy.

    Relies on the module-level ``model``, ``optimizer``, ``criterion`` and
    ``device``.

    Args:
        data: Iterable of ``(inputs, labels)`` batches.
        isTrain: When true the model is put in training mode and the
            optimizer is stepped on every batch; otherwise the model is put
            in eval mode and no gradients are computed.

    Returns:
        Tuple ``(loss, acc)`` of 0-dim tensors: the loss averaged over all
        samples seen and the fraction of samples whose highest-scoring class
        equals the label.  Both are zero when ``data`` yields no batches.
    """
    if isTrain:
        model.train()
    else:
        model.eval()

    running_loss = torch.zeros((), device=device)
    running_corrects = torch.zeros((), dtype=torch.long, device=device)
    seen = 0

    # Skip autograd bookkeeping entirely during evaluation.
    with torch.set_grad_enabled(isTrain):
        for inputs, labels in data:
            inputs, labels = inputs.to(device), labels.to(device)
            if isTrain:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if isTrain:
                loss.backward()
                optimizer.step()

            _, preds = torch.max(outputs.detach(), 1)
            n_batch = labels.size(0)
            # criterion is built with the default reduction in this file, so
            # `loss` is a batch mean; weight it by the batch size to get a
            # true per-sample average even when the last batch is short.
            running_loss += loss.detach() * n_batch
            running_corrects += (preds == labels).sum()
            seen += n_batch

    # Divide by the real sample count, not len(data) * batch_size, which
    # overstates the denominator whenever the final batch is partial.
    denom = max(seen, 1)
    loss = running_loss / denom
    acc = running_corrects.float() / denom
    return loss, acc
def test(data):
    """Evaluate the module-level ``model`` on ``data``.

    Relies on the module-level ``model`` and ``device``.

    Args:
        data: Iterable of ``(inputs, labels)`` batches.

    Returns:
        Tuple ``(acc, running_corrects, real_lables, pred_lables)``:
        accuracy as a 0-dim float tensor, the number of correct predictions
        as a 0-dim integer tensor, and the true / predicted class indices as
        flat Python lists in iteration order.
    """
    real_lables, pred_lables = [], []
    model.eval()
    running_corrects = torch.zeros((), dtype=torch.long, device=device)
    seen = 0

    # Inference only: no gradients are needed.
    with torch.no_grad():
        for inputs, labels in data:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            real_lables.extend(labels.tolist())
            pred_lables.extend(preds.tolist())
            running_corrects += (preds == labels).sum()
            seen += labels.size(0)

    # Use the real sample count; len(data) * batch_size is too large when
    # the last batch is partial.
    acc = running_corrects.float() / max(seen, 1)
    return acc, running_corrects, real_lables, pred_lables
# Hyper-parameters, model, loss and optimizer.
train_iterations, train_loss, test_accuracy = [], [], []
model = models.resnet50(pretrained=False)
# model_ft.load_state_dict(torch.load('data/resnet50-19c8e357.pth'))
# torchvision's ResNet keeps its classification head in `fc`; it has no
# `classifier` attribute (that is the VGG layout), so the head is replaced
# through `fc` to output the 8 classes.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 8)
k = 5
lr, num_epochs = 2e-4, 20
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
device = try_gpu(0)
# Training loop.
print('训练开始 on', device)
model.to(device)
log_every = 1  # record and print the averaged validation metrics every N epochs
for epoch in range(num_epochs):
    loss, acc = 0.0, 0.0
    # NOTE(review): the same model keeps training across all folds and
    # epochs without being re-initialised, so samples in a validation fold
    # have generally been trained on before; these numbers are not an
    # unbiased cross-validation estimate.
    for train_data, valid_data in kfold(train_imgs, k):  # k-fold cross-validation
        train_losses, train_acc = train(train_data)
        valid_losses, valid_acc = train(valid_data, False)
        loss += valid_losses
        acc += valid_acc
    # Always record the final epoch, but only once: the previous code
    # appended it inside the loop and then again after the loop.
    if epoch % log_every == 0 or epoch == num_epochs - 1:
        train_iterations.append(epoch)
        train_loss.append((loss / k).to('cpu').item())
        test_accuracy.append((acc / k).to('cpu').item())
        print('{},{:.4f},{:.4f}'.format(epoch, loss / k, acc / k))
print('训练结束')
# NOTE(review): the file name says "vgg16" although the model is a ResNet-50;
# left unchanged because other code may load this exact path.
torch.save(model, '../model/20220507-vgg16-数据增强-' + str(lr) + '.pth')
# Evaluate the trained model on the test images and print the summary line.
print('测试 on', device)
model.to(device)
test_data = torch.utils.data.DataLoader(
    test_imgs,
    batch_size=batch_size,
    shuffle=True,
)
acc, corrects, real, pre = test(test_data)
_summary = '准确率: {:.4f} 正确预测个数: {}'.format(acc, corrects)
print(_summary)
# k = 0
# while k < len(real_lables):
# print(real_lables[k], pred_lables[k])
# k = k + 1
# Plot the recorded loss and accuracy against the epoch index, using a
# second y-axis on the right for the accuracy.
host = host_subplot(111)
plt.subplots_adjust(right=0.8)  # adjust the right boundary of the plot area
par1 = host.twinx()

# Axis titles.
host.set_xlabel("iterations")
host.set_ylabel("loss")
par1.set_ylabel("validation accuracy")

# Each series is drawn twice: once as a line, once as point markers.
# NOTE(review): `train_loss` is filled from the validation-fold losses in
# the training loop, although the legend entry reads "training loss".
p1 = host.plot(train_iterations, train_loss, "b-", label="training loss")[0]
p2 = host.plot(train_iterations, train_loss, ".")[0]  # point markers
p3 = par1.plot(train_iterations, test_accuracy, label="validation accuracy")[0]
p4 = par1.plot(train_iterations, test_accuracy, "1")[0]

# Legend position codes: 1 upper right, 2 upper left, 3 lower left,
# 4 lower right, 5 right (centre).
host.legend(loc=5)

# Colour each axis label like the curve it belongs to.
host.axis["left"].label.set_color(p1.get_color())
par1.axis["right"].label.set_color(p3.get_color())

# x range covers every epoch index.
host.set_xlim([0, num_epochs - 1])
plt.draw()
plt.show()
#--------------------------------------------------------------------------
# Part 1: per-class precision, recall and F-measure
#--------------------------------------------------------------------------
# Per-class tallies over the 8 classes.  They must start at zero: the
# previous list(range(0, 8)) initialisation gave class c a head start of c
# in every counter, which skewed all three metrics.
n_classes = 8
real_8 = [0] * n_classes   # samples whose true label is the class
pre_8 = [0] * n_classes    # samples predicted as the class
right_8 = [0] * n_classes  # samples predicted as the class and correct
for v1, v2 in zip(real, pre):
    v1, v2 = int(v1), int(v2)
    real_8[v1] += 1
    pre_8[v2] += 1
    if v1 == v2:
        right_8[v1] += 1

# precision = correct / predicted; 0.0 when the class was never predicted.
precision = [
    right_8[c] / pre_8[c] if pre_8[c] else 0.0
    for c in range(n_classes)
]
print('准确率: ')
print(precision)

# recall = correct / actual; 0.0 when the class has no samples.
recall = [
    right_8[c] / real_8[c] if real_8[c] else 0.0
    for c in range(n_classes)
]
print('召回率: ')
print(recall)

# F = 2 * precision * recall / (precision + recall); 0.0 when both are zero.
f_measure = [
    (2 * p * r) / (p + r) if (p + r) else 0.0
    for p, r in zip(precision, recall)
]
print('F值: ')
print(f_measure)
#--------------------------------------------------------------------------
# Part 2: grouped bar chart of the per-class metrics
#--------------------------------------------------------------------------
# One group of three bars (precision, recall, F-measure) per class.
n_groups = 8
fig, ax = plt.subplots()
index = np.arange(n_groups)
bar_width = 0.2
opacity = 0.4
error_config = {'ecolor': '0.3'}

# Uncomment to render Chinese labels correctly:
# plt.rcParams['font.sans-serif']=['SimHei']

# (x positions, heights, colour, legend label) for each of the three series.
_series = (
    (index, precision, 'b', 'precision'),
    (index + bar_width, recall, 'm', 'recall'),
    (index + bar_width + bar_width, f_measure, 'r', 'f_measure'),
)
rects1, rects2, rects3 = [
    ax.bar(xs, heights, bar_width,
           alpha=opacity, color=colour,
           error_kw=error_config,
           label=name)
    for xs, heights, colour, name in _series
]

# Tick positions and class names.
ax.set_xticks(index + 3 * bar_width / 3)
ax.set_xticklabels(('0-desk', '1-dining table', '2-double bed', '3-sofa', '4-squatting toilet',
                    '5-TV cabinet', '6-wardrobe', '7-washbasin'))

# Legend, axis titles and figure size.
ax.legend()
plt.xlabel("lable")
plt.ylabel("evaluation")
fig.set_figheight(5)
fig.set_figwidth(10)
fig.tight_layout()
# plt.savefig('result.png', dpi=200)
plt.show()
边栏推荐
猜你喜欢

What is the function of data parsing?

ciscn 2022 华中赛区 misc

12.< tag-动态规划和子序列, 子数组>lt.72. 编辑距离

Realization of digital tube display based on C51

Navigation--实现Fragment之间数据传递和数据共享

MySQL stores JSON format data

Understand the clock tree in STM32 in simple terms

数学建模——红酒品质分类

Comprehensive explanation of "search engine crawl"

数学建模——永冻土层上关于路基热传导问题
随机推荐
表单校验 隐藏的输入框 显示才校验
Complete collection of common error handling in MySQL installation
Leetcode exercise - Sword finger offer 45. arrange the array into the smallest number
How to find the right agent type? Multi angle analysis for you!
QT memory management tips
12.< tag-动态规划和子序列, 子数组>lt.72. 编辑距离
Qt 内存管理小技巧
(arxiv-2018) reexamine the time modeling of person Reid based on video
Sharpness evaluation method without reference image
The number of consecutive subarrays whose leetcode/ product is less than k
Cookie和Session
Promise solves asynchrony
"Wei Lai Cup" 2022 Niuke summer multi school training camp 2, sign in question GJK
防止重复点击
自定义mvc原理和框架实现
Jetpack--了解ViewModel和LiveData的使用
leetcode/和为k的连续子数组个数
Try to understand the essence of low code platform design from another angle
Understand the working principle of timer in STM32 in simple terms
MySQL stores JSON format data