当前位置:网站首页>ResNet50+k折交叉验证+数据增强+画图(准确率、召回率、F值)
ResNet50+k折交叉验证+数据增强+画图(准确率、召回率、F值)
2022-07-29 01:33:00 【落难Coder】
k折交叉验证
def kfold(data, k=5):
    """Yield ``(train_loader, valid_loader)`` pairs for k-fold cross-validation.

    The indices ``0 .. len(data) - 1`` are shuffled and split into ``k`` folds
    with scikit-learn's ``KFold``.  Each fold is used once for validation while
    the remaining folds form the training set.

    Args:
        data: Indexable dataset supporting ``len(data)`` and ``data[i]``.
        k: Number of folds.

    Yields:
        Tuples ``(train_loader, valid_loader)`` of ``DataLoader`` objects built
        with the module-level ``batch_size``; only the training loader
        shuffles its samples.
    """
    fold_splitter = KFold(n_splits=k, shuffle=True)
    for train_idxs, valid_idxs in fold_splitter.split(np.arange(len(data))):
        # Subset only stores the indices and calls data[i] when a batch is
        # drawn, instead of copying every sample of the fold into a Python
        # list before training can start.
        train_subset = torch.utils.data.Subset(data, train_idxs.tolist())
        valid_subset = torch.utils.data.Subset(data, valid_idxs.tolist())
        train_loader = torch.utils.data.DataLoader(
            train_subset, shuffle=True, batch_size=batch_size)
        valid_loader = torch.utils.data.DataLoader(
            valid_subset, batch_size=batch_size)
        yield train_loader, valid_loader
完整代码
from glob import glob
import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import shutil
from torchvision import transforms
from torchvision import models
import torch
from torch.autograd import Variable
import torch.nn as nn
from torch.optim import lr_scheduler
from torch import optim
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid
from sklearn.model_selection import KFold
from mpl_toolkits.axes_grid1 import host_subplot
import time
%matplotlib inline
def try_gpu(i=0):
    """Return the ``cuda:i`` device if it exists, otherwise the CPU device.

    Args:
        i: Zero-based index of the requested CUDA device.

    Returns:
        ``torch.device(f'cuda:{i}')`` when at least ``i + 1`` CUDA devices are
        reported by ``torch.cuda.device_count()``, else ``torch.device('cpu')``.
    """
    if torch.cuda.device_count() >= i + 1:
        return torch.device(f'cuda:{i}')
    return torch.device('cpu')
# Count the PNG files found one directory level below the training folder.
path = '../data/data-8/train/'
files = glob(os.path.join(path, '*/*.png'))
print(f'Total train of images {len(files)}')
# Same count for the test folder; `path` and `files` are rebound here.
path = '../data/data-8/test/'
files = glob(os.path.join(path, '*/*.png'))
print(f'Total valid of images {len(files)}')
# Side length (pixels) every image is resized to, and the mini-batch size
# used by the data loaders.
imag_size, batch_size = 224, 16

# Training pipeline: resize, random augmentation, then tensor conversion.
transform = transforms.Compose([
    transforms.Resize(size=(imag_size, imag_size)),
    transforms.RandomHorizontalFlip(),                   # random horizontal flip
    transforms.RandomVerticalFlip(),                     # random vertical flip
    transforms.RandomRotation(degrees=45),               # random rotation
    transforms.RandomCrop(size=(imag_size, imag_size)),  # random crop position
    transforms.ToTensor(),
    # Normalisation is currently disabled:
    # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Test pipeline: deterministic resize and tensor conversion only.
transform_test = transforms.Compose([
    transforms.Resize(size=(imag_size, imag_size)),
    transforms.ToTensor(),
])

# Folder-per-class datasets for the two splits.
train_imgs = ImageFolder('../data/data-8/train', transform=transform)
test_imgs = ImageFolder('../data/data-8/test', transform=transform_test)
def kfold(data, k=5):
    """Yield ``(train_loader, valid_loader)`` pairs for k-fold cross-validation.

    The indices ``0 .. len(data) - 1`` are shuffled and split into ``k`` folds
    with scikit-learn's ``KFold``.  Each fold is used once for validation while
    the remaining folds form the training set.

    Args:
        data: Indexable dataset supporting ``len(data)`` and ``data[i]``.
        k: Number of folds.

    Yields:
        Tuples ``(train_loader, valid_loader)`` of ``DataLoader`` objects built
        with the module-level ``batch_size``; only the training loader
        shuffles its samples.
    """
    fold_splitter = KFold(n_splits=k, shuffle=True)
    for train_idxs, valid_idxs in fold_splitter.split(np.arange(len(data))):
        # Subset only stores the indices and calls data[i] when a batch is
        # drawn, instead of copying every sample of the fold into a Python
        # list before training can start.
        train_subset = torch.utils.data.Subset(data, train_idxs.tolist())
        valid_subset = torch.utils.data.Subset(data, valid_idxs.tolist())
        train_loader = torch.utils.data.DataLoader(
            train_subset, shuffle=True, batch_size=batch_size)
        valid_loader = torch.utils.data.DataLoader(
            valid_subset, batch_size=batch_size)
        yield train_loader, valid_loader
def train(data, isTrain=True):
    """Run one pass over ``data`` and return the mean loss and the accuracy.

    Operates on the module-level ``model``, ``criterion``, ``optimizer`` and
    ``device``.

    Args:
        data: Iterable of ``(inputs, labels)`` tensor batches, e.g. a
            ``DataLoader``.
        isTrain: If true, the model is put in training mode and the optimizer
            is stepped after every batch.  Otherwise the model is put in
            evaluation mode and the pass runs with gradients disabled.

    Returns:
        ``(loss, acc)`` as tensors: the loss averaged over all samples seen and
        the fraction of samples whose arg-max prediction equals the label.

    Raises:
        ValueError: If ``data`` yields no samples.
    """
    if isTrain:
        model.train()
    else:
        model.eval()
    running_loss = 0.0
    running_corrects = 0
    n_samples = 0
    # Autograd is only needed when the weights are going to be updated.
    with torch.set_grad_enabled(isTrain):
        for inputs, labels in data:
            inputs, labels = inputs.to(device), labels.to(device)
            if isTrain:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if isTrain:
                loss.backward()
                optimizer.step()
            _, preds = torch.max(outputs.detach(), 1)
            n_batch = labels.size(0)
            # Assumes `criterion` averages over the batch (the default for
            # nn.CrossEntropyLoss), so weight by the batch size to obtain a
            # true per-sample mean even when the last batch is smaller.
            running_loss += loss.detach() * n_batch
            running_corrects += torch.sum(preds == labels)
            n_samples += n_batch
    if n_samples == 0:
        raise ValueError('data yielded no samples')
    # Divide by the real sample count rather than len(data) * batch_size,
    # which over-counts whenever the final batch is partial.
    loss = running_loss / n_samples
    acc = running_corrects / n_samples
    return loss, acc
def test(data):
    """Evaluate the module-level ``model`` on ``data``.

    Args:
        data: Iterable of ``(inputs, labels)`` tensor batches, e.g. a
            ``DataLoader``.

    Returns:
        A tuple ``(acc, running_corrects, real_lables, pred_lables)``:
        the fraction of correct predictions (tensor), the number of correct
        predictions (tensor), the true labels and the predicted labels as
        flat Python lists in the order the batches were seen.

    Raises:
        ValueError: If ``data`` yields no samples.
    """
    real_lables, pred_lables = [], []
    model.eval()
    running_corrects = 0
    n_samples = 0
    # Inference only: no gradients are needed.
    with torch.no_grad():
        for inputs, labels in data:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            real_lables.extend(labels.tolist())
            pred_lables.extend(preds.tolist())
            running_corrects += torch.sum(preds == labels)
            n_samples += labels.size(0)
    if n_samples == 0:
        raise ValueError('data yielded no samples')
    # Use the real sample count; len(data) * batch_size over-counts when the
    # final batch is partial.
    acc = running_corrects / n_samples
    return acc, running_corrects, real_lables, pred_lables
# Hyper-parameters and model setup.
# Per-epoch history lists, filled by the training loop and read by the plots.
train_iterations, train_loss, test_accuracy = [], [], []
model = models.resnet50(pretrained=False)
# To start from saved weights, load them before replacing the head:
# model.load_state_dict(torch.load('data/resnet50-19c8e357.pth'))
# torchvision's ResNet exposes its classification head as the single Linear
# layer `fc`; `classifier[6]` is the VGG layout and raises AttributeError here.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 8)  # 8 output classes
k = 5
lr, num_epochs = 2e-4, 20
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
device = try_gpu(0)
# Start training.
print('训练开始 on', device)
model.to(device)
# Record/print the fold-averaged validation metrics every `log_every` epochs;
# the final epoch is always recorded, exactly once.
log_every = 1
for epoch in range(num_epochs):
    loss, acc = 0.0, 0.0
    # NOTE(review): the same model instance keeps training across folds and
    # epochs, so each validation fold has already been trained on in earlier
    # folds; the reported validation numbers are therefore optimistic.
    for train_data, valid_data in kfold(train_imgs, k):  # k-fold cross-validation
        train_losses, train_acc = train(train_data)
        valid_losses, valid_acc = train(valid_data, False)
        loss += valid_losses
        acc += valid_acc
    is_last_epoch = epoch == num_epochs - 1
    if epoch % log_every == 0 or is_last_epoch:
        train_iterations.append(epoch)
        train_loss.append((loss / k).to('cpu').item())
        test_accuracy.append((acc / k).to('cpu').item())
        print('{},{:.4f},{:.4f}'.format(epoch, loss / k, acc / k))
        # Alternative verbose format:
        # print('Epoch {}/{} avgLoss: {:.4f} Acc: {:.4f}'.format(epoch + 1, num_epochs, loss/k, acc/k))
print('训练结束')
# NOTE(review): the file name says "vgg16" although the model is a ResNet-50;
# left unchanged so existing consumers of the path keep working.
torch.save(model, '../model/20220507-vgg16-数据增强-' + str(lr) + '.pth')
# Evaluate on the held-out test split.
print('测试 on', device)
model.to(device)
test_data = torch.utils.data.DataLoader(test_imgs, shuffle=True, batch_size=batch_size)
acc, corrects, real, pre = test(test_data)
print('准确率: {:.4f} 正确预测个数: {}'.format(acc, corrects))
# Debug aid: `for r, p in zip(real, pre): print(r, p)` lists every
# true/predicted label pair returned by the test run.
# Plot the recorded curves on one x axis with two y axes.
host = host_subplot(111)
plt.subplots_adjust(right=0.8)  # adjust the right boundary of the plot window
par1 = host.twinx()
# Axis labels.
host.set_xlabel("iterations")
host.set_ylabel("loss")
par1.set_ylabel("validation accuracy")
# Curves.  `train_loss` is filled by the training loop with the fold-averaged
# loss of the validation passes, so the legend names it accordingly.
p1, = host.plot(train_iterations, train_loss, "b-", label="validation loss")
p2, = host.plot(train_iterations, train_loss, ".")  # point markers on the loss curve
p3, = par1.plot(train_iterations, test_accuracy, label="validation accuracy")
p4, = par1.plot(train_iterations, test_accuracy, "1")  # point markers on the accuracy curve
# Legend position codes: 1 upper right, 2 upper left, 3 lower left,
# 4 lower right, 5 right (vertically centred).
host.legend(loc=5)
# Colour each y-axis label like the curve it belongs to.
host.axis["left"].label.set_color(p1.get_color())
par1.axis["right"].label.set_color(p3.get_color())
# x-axis range.
host.set_xlim([0, num_epochs - 1])
plt.draw()
plt.show()
#--------------------------------------------------------------------------
# Part 1: per-class precision, recall and F-measure
#--------------------------------------------------------------------------
# Tallies for the 8 classes.  They must start at zero: seeding them with
# list(range(8)) would add the class index itself to every count.
n_classes = 8
real_8 = [0] * n_classes   # samples whose true label is class c
pre_8 = [0] * n_classes    # samples predicted as class c
right_8 = [0] * n_classes  # samples of class c that were predicted correctly
for true_label, pred_label in zip(real, pre):
    v1, v2 = int(true_label), int(pred_label)
    real_8[v1] += 1
    pre_8[v2] += 1
    if v1 == v2:
        right_8[v1] += 1
# print("class tallies:", real_8, pre_8, right_8)


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


# Precision = correct / predicted (0.0 for a class that was never predicted).
precision = [_safe_ratio(hit, n_pred) for hit, n_pred in zip(right_8, pre_8)]
print('准确率: ')
print(precision)
# Recall = correct / actual (0.0 for a class absent from the test data).
recall = [_safe_ratio(hit, n_real) for hit, n_real in zip(right_8, real_8)]
print('召回率: ')
print(recall)
# F-measure = 2 * precision * recall / (precision + recall).
f_measure = [_safe_ratio(2 * p * r, p + r) for p, r in zip(precision, recall)]
print('F值: ')
print(f_measure)
#--------------------------------------------------------------------------
# Part 2: grouped bar chart of the per-class scores
#--------------------------------------------------------------------------
# One group of three bars per class.
n_groups = 8
fig, ax = plt.subplots()
index = np.arange(n_groups)
bar_width = 0.2
opacity = 0.4
error_config = {'ecolor': '0.3'}
# Uncomment to render Chinese labels correctly:
# plt.rcParams['font.sans-serif']=['SimHei']
# Draw the three series side by side, in the order precision, recall,
# f_measure; each entry is (x positions, heights, colour, legend label).
rects1, rects2, rects3 = (
    ax.bar(positions, scores, bar_width,
           alpha=opacity, color=colour,
           error_kw=error_config,
           label=series_name)
    for positions, scores, colour, series_name in (
        (index, precision, 'b', 'precision'),
        (index + bar_width, recall, 'm', 'recall'),
        (index + bar_width + bar_width, f_measure, 'r', 'f_measure'),
    )
)
# Tick positions and class names.
ax.set_xticks(index + 3 * bar_width / 3)
class_names = ('0-desk', '1-dining table', '2-double bed', '3-sofa', '4-squatting toilet',
               '5-TV cabinet', '6-wardrobe', '7-washbasin')
ax.set_xticklabels(class_names)
# Legend, axis titles and figure size.
ax.legend()
plt.xlabel("lable")
plt.ylabel("evaluation")
fig.set_figheight(5)
fig.set_figwidth(10)
fig.tight_layout()
# plt.savefig('result.png', dpi=200)
plt.show()
边栏推荐
- 数学建模——自来水管道铺设问题
- Mathematical modeling -- bus scheduling optimization
- Resolve the conflict with vetur when using eslint, resulting in double quotation marks and comma at the end of saving
- Motionlayout -- realize animation in visual editor
- Mobile communication -- simulation model of error control system based on convolutional code
- Probability Density Reweight
- Ciscn 2022 central China Misc
- Mathematical modeling -- Optimization of picking in warehouse
- Mysql存储json格式数据
- Control buzzer based on C51
猜你喜欢

第十四天:续第十三天标签相关知识

基于 ICA 与 DL 的语音信号盲分离

Understand the working principle of timer in STM32 in simple terms

【RT学习笔记1】RT-Thread外设例程——控制Led灯闪烁

年中总结 | 与自己对话,活在当下,每走一步都算数

Comprehensive explanation of "search engine crawl"

基于C51实现数码管的显示

Using local cache + global cache to realize user rights management of small systems

In 2022, the official data of programming language ranking came, which was an eye opener

Have you ever encountered the situation that the IP is blocked when crawling web pages?
随机推荐
(arxiv-2018) reexamine the time modeling of person Reid based on video
Have you ever encountered the situation that the IP is blocked when crawling web pages?
【RT学习笔记1】RT-Thread外设例程——控制Led灯闪烁
Qt源码分析--QObject(4)
基于C51实现数码管的显示
年中总结 | 与自己对话,活在当下,每走一步都算数
数学建模——自来水管道铺设问题
mobile-picker.js
druid. IO custom real-time task scheduling policy
leetcode/乘积小于K 的连续子数组的个数
Rgbd point cloud down sampling
控制输入框弹出弹窗 和不弹出窗口
Control the pop-up window and no pop-up window of the input box
leetcode 242. Valid Anagram(有效的字母异位词)
[云原生]微服务架构是什么
Blind separation of speech signals based on ICA and DL
“蔚来杯“2022牛客暑期多校训练营3,签到题CAJHF
基于C51控制蜂鸣器
TI C6000 TMS320C6678 DSP+ Zynq-7045的PS + PL异构多核案例开发手册(2)
Click back to the top JS