当前位置:网站首页>ResNet50+k折交叉验证+数据增强+画图(准确率、召回率、F值)
ResNet50+k折交叉验证+数据增强+画图(准确率、召回率、F值)
2022-07-29 01:33:00 【落难Coder】
k折交叉验证
def kfold(data, k=5):
    """Yield ``(train_loader, valid_loader)`` pairs for K-fold cross-validation.

    Args:
        data: Indexable dataset supporting ``len(data)`` and ``data[i]``.
        k: Number of folds.

    Yields:
        For each fold, a shuffled training ``DataLoader`` and a validation
        ``DataLoader``. Both use the module-level ``batch_size``.
    """
    X = np.arange(len(data))
    KF = KFold(n_splits=k, shuffle=True)
    for train_idxs, valid_idxs in KF.split(X):
        # Subset stores only the indices and fetches samples on demand, so
        # the whole dataset is no longer copied into Python lists per fold.
        train_iter = torch.utils.data.Subset(data, train_idxs.tolist())
        valid_iter = torch.utils.data.Subset(data, valid_idxs.tolist())
        train_data = torch.utils.data.DataLoader(train_iter, shuffle=True,
                                                 batch_size=batch_size)
        valid_data = torch.utils.data.DataLoader(valid_iter,
                                                 batch_size=batch_size)
        yield train_data, valid_data
完整代码
from glob import glob
import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import shutil
from torchvision import transforms
from torchvision import models
import torch
from torch.autograd import Variable
import torch.nn as nn
from torch.optim import lr_scheduler
from torch import optim
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid
from sklearn.model_selection import KFold
from mpl_toolkits.axes_grid1 import host_subplot
import time
%matplotlib inline
def try_gpu(i=0):
    """Return the ``cuda:i`` device if it exists, otherwise the CPU device.

    Args:
        i: Zero-based index of the requested CUDA device.

    Returns:
        ``torch.device('cuda:i')`` when at least ``i + 1`` CUDA devices are
        reported by ``torch.cuda.device_count()``, else ``torch.device('cpu')``.
    """
    if torch.cuda.device_count() >= i + 1:
        # The f-string must stay on one line: a line break inside the
        # replacement field is a syntax error before Python 3.12.
        return torch.device(f'cuda:{i}')
    return torch.device('cpu')
# Report how many PNG files sit one directory level below each split folder.
# The f-strings are kept on a single line: a line break inside the
# replacement field is a syntax error before Python 3.12.
path = '../data/data-8/train/'
files = glob(os.path.join(path, '*/*.png'))
print(f'Total train of images {len(files)}')
path = '../data/data-8/test/'
files = glob(os.path.join(path, '*/*.png'))
print(f'Total valid of images {len(files)}')
imag_size = 224
batch_size = 16

# Every image is resized to this square resolution in both pipelines.
_target_hw = (imag_size, imag_size)

# Data augmentation applied to the training folder.
_augmentation_steps = [
    transforms.Resize(_target_hw),
    transforms.RandomHorizontalFlip(),  # random horizontal flip
    transforms.RandomVerticalFlip(),  # random vertical flip
    transforms.RandomRotation(45),  # random rotation
    transforms.RandomCrop(_target_hw),  # random-position crop
    transforms.ToTensor(),
    # , transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]
transform = transforms.Compose(_augmentation_steps)

# The test folder is only resized and converted to tensors.
_test_steps = [
    transforms.Resize(_target_hw),
    transforms.ToTensor(),
]
transform_test = transforms.Compose(_test_steps)

train_imgs = ImageFolder('../data/data-8/train', transform)
test_imgs = ImageFolder('../data/data-8/test', transform_test)
def kfold(data, k=5):
    """Yield ``(train_loader, valid_loader)`` pairs for K-fold cross-validation.

    Args:
        data: Indexable dataset supporting ``len(data)`` and ``data[i]``.
        k: Number of folds.

    Yields:
        For each fold, a shuffled training ``DataLoader`` and a validation
        ``DataLoader``. Both use the module-level ``batch_size``.
    """
    X = np.arange(len(data))
    KF = KFold(n_splits=k, shuffle=True)
    for train_idxs, valid_idxs in KF.split(X):
        # Subset stores only the indices and fetches samples on demand, so
        # the whole dataset is no longer copied into Python lists per fold.
        train_iter = torch.utils.data.Subset(data, train_idxs.tolist())
        valid_iter = torch.utils.data.Subset(data, valid_idxs.tolist())
        train_data = torch.utils.data.DataLoader(train_iter, shuffle=True,
                                                 batch_size=batch_size)
        valid_data = torch.utils.data.DataLoader(valid_iter,
                                                 batch_size=batch_size)
        yield train_data, valid_data
def train(data, isTrain=True):
    """Run the module-level ``model`` once over ``data``.

    Uses the module-level ``model``, ``criterion``, ``optimizer`` and
    ``device``.

    Args:
        data: Iterable of ``(inputs, labels)`` batches (e.g. a ``DataLoader``).
        isTrain: When true the model is put in training mode and the
            optimizer is stepped on every batch; otherwise the model is put
            in eval mode and no gradients are computed.

    Returns:
        ``(loss, acc)`` as 0-dim tensors: the per-sample average loss and the
        fraction of correctly classified samples.
    """
    if isTrain:
        model.train()
    else:
        model.eval()
    running_loss = torch.zeros((), device=device)
    running_corrects = torch.zeros((), dtype=torch.long, device=device)
    n_samples = 0
    # Autograd bookkeeping is only needed when the weights are updated.
    with torch.set_grad_enabled(isTrain):
        for inputs, labels in data:
            inputs, labels = inputs.to(device), labels.to(device)
            if isTrain:
                optimizer.zero_grad()
            outputs = model(inputs)
            preds = outputs.detach().argmax(dim=1)
            loss = criterion(outputs, labels)
            if isTrain:
                loss.backward()
                optimizer.step()
            batch_len = labels.size(0)
            # The criterion used by this script (nn.CrossEntropyLoss with
            # default settings) returns the batch mean, so weight it by the
            # batch length to accumulate a per-sample sum.
            running_loss = running_loss + loss.detach() * batch_len
            running_corrects = running_corrects + (preds == labels).sum()
            n_samples += batch_len
    # Divide by the samples actually seen; the previous
    # ``len(data) * batch_size`` over-counted when the last batch was
    # partial and scaled the mean loss down by an extra ``batch_size``.
    loss = running_loss / n_samples
    acc = running_corrects / n_samples
    return loss, acc
def test(data):
    """Evaluate the module-level ``model`` on ``data``.

    Uses the module-level ``model`` and ``device``.

    Args:
        data: Iterable of ``(inputs, labels)`` batches (e.g. a ``DataLoader``).

    Returns:
        ``(acc, running_corrects, real_lables, pred_lables)``: accuracy and
        number of correct predictions as 0-dim tensors, followed by the
        true and predicted class indices as flat Python lists.
    """
    real_lables, pred_lables = [], []
    model.eval()
    running_corrects = torch.zeros((), dtype=torch.long, device=device)
    n_samples = 0
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        for inputs, labels in data:
            inputs, labels = inputs.to(device), labels.to(device)
            preds = model(inputs).argmax(dim=1)
            real_lables.extend(labels.tolist())
            pred_lables.extend(preds.tolist())
            running_corrects = running_corrects + (preds == labels).sum()
            n_samples += labels.size(0)
    # Divide by the samples actually seen; ``len(data) * batch_size``
    # over-counted whenever the last batch was partial.
    acc = running_corrects / n_samples
    return acc, running_corrects, real_lables, pred_lables
# Hyper-parameters and the series that are plotted after training.
train_iterations, train_loss, test_accuracy = [], [], []
model = models.resnet50(pretrained=False)
# model_ft.load_state_dict(torch.load('data/resnet50-19c8e357.pth'))
# torchvision's ResNet exposes its classification head as ``fc``; the
# ``classifier[6]`` attribute belongs to VGG/AlexNet and raises
# AttributeError on a ResNet.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 8)
k = 5
lr, num_epochs = 2e-4, 20
log_every = 1  # record and print the validation metrics every N epochs
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
device = try_gpu(0)

# Training
print('训练开始 on', device)
model.to(device)
# NOTE(review): the same model instance is trained on every fold, so each
# validation fold has usually been trained on in an earlier fold or epoch.
# The reported validation metrics are therefore optimistic - confirm whether
# a fresh model per fold was intended.
for epoch in range(num_epochs):
    loss, acc = 0.0, 0.0
    for train_data, valid_data in kfold(train_imgs, k):  # K-fold cross-validation
        train_losses, train_acc = train(train_data)
        valid_losses, valid_acc = train(valid_data, False)
        loss += valid_losses
        acc += valid_acc
    if epoch % log_every == 0:
        train_iterations.append(epoch)
        train_loss.append((loss/k).to('cpu').item())
        test_accuracy.append((acc/k).to('cpu').item())
        print('{},{:.4f},{:.4f}'.format(epoch, loss/k, acc/k))
        # print('Epoch {}/{} avgLoss: {:.4f} Acc: {:.4f}'.format(epoch + 1, num_epochs, loss/k, acc/k))

# Record the final epoch only if the periodic logging above did not already
# capture it; otherwise the last point would appear twice in the curves.
if (num_epochs - 1) % log_every != 0:
    train_iterations.append(num_epochs - 1)
    train_loss.append((loss/k).to('cpu').item())
    test_accuracy.append((acc/k).to('cpu').item())
    print('{},{:.4f},{:.4f}'.format(num_epochs - 1, loss/k, acc/k))
print('训练结束')
# NOTE(review): the file name says "vgg16" although the model is a ResNet-50;
# the path is left unchanged in case other code loads it by this name.
torch.save(model, '../model/20220507-vgg16-数据增强-'+str(lr) +'.pth')

# Evaluation on the held-out test folder
print('测试 on', device)
model.to(device)
test_data = torch.utils.data.DataLoader(test_imgs, shuffle=True, batch_size=batch_size)
acc, corrects, real, pre = test(test_data)
print('准确率: {:.4f} 正确预测个数: {}'.format(acc, corrects))
# k = 0
# while k < len(real_lables):
# print(real_lables[k], pred_lables[k])
# k = k + 1
# Plot validation loss and validation accuracy against the epoch index,
# sharing the x axis and using separate y axes on the left and right.
host = host_subplot(111)
plt.subplots_adjust(right=0.8)  # adjust the right boundary of the plot window
par1 = host.twinx()

# Axis titles
host.set_xlabel("iterations")
host.set_ylabel("loss")
par1.set_ylabel("validation accuracy")

# Each series is drawn twice: once as a line, once as point markers.
(p1,) = host.plot(train_iterations, train_loss, "b-", label="training loss")
(p2,) = host.plot(train_iterations, train_loss, ".")
(p3,) = par1.plot(train_iterations, test_accuracy, label="validation accuracy")
(p4,) = par1.plot(train_iterations, test_accuracy, "1")

# Legend placement codes:
# 1->upper right, 2->upper left, 3->lower left, 4->lower right, 5->right ...
host.legend(loc=5)

# Colour each y-axis label like the curve it belongs to.
loss_colour = p1.get_color()
accuracy_colour = p3.get_color()
host.axis["left"].label.set_color(loss_colour)
par1.axis["right"].label.set_color(accuracy_colour)

# x range covers every epoch index
host.set_xlim([0, num_epochs - 1])
plt.draw()
plt.show()
#--------------------------------------------------------------------------
# Part 1: per-class precision, recall and F-measure
#--------------------------------------------------------------------------
# ``real`` and ``pre`` hold the true and predicted class index per test sample.
num_classes = 8
# Counters must start at zero. The previous ``list(range(0, 8))`` seeded
# class i with i phantom samples and corrupted every metric below.
real_8 = [0] * num_classes   # samples whose true class is i
pre_8 = [0] * num_classes    # samples predicted as class i
right_8 = [0] * num_classes  # samples of class i predicted correctly
for true_label, pred_label in zip(real, pre):
    v1 = int(true_label)
    v2 = int(pred_label)
    real_8[v1] += 1
    pre_8[v2] += 1
    if v1 == v2:
        right_8[v1] += 1


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or 0.0 when the denominator is 0."""
    return numerator / denominator if denominator else 0.0


# precision = correct / predicted (0.0 for a class that was never predicted)
precision = [_safe_ratio(hit, predicted)
             for hit, predicted in zip(right_8, pre_8)]
print('准确率: ')
print(precision)

# recall = correct / actual (0.0 for a class absent from the test data)
recall = [_safe_ratio(hit, actual)
          for hit, actual in zip(right_8, real_8)]
print('召回率: ')
print(recall)

# F-measure = 2 * precision * recall / (precision + recall)
f_measure = [_safe_ratio(2 * p * r, p + r)
             for p, r in zip(precision, recall)]
print('F值: ')
print(f_measure)
#--------------------------------------------------------------------------
# Part 2: grouped bar chart of the per-class metrics
#--------------------------------------------------------------------------
# One group of three bars per class
n_groups = 8
fig, ax = plt.subplots()
index = np.arange(n_groups)
bar_width = 0.2
opacity = 0.4
error_config = {'ecolor': '0.3'}
# Uncomment to render Chinese tick labels correctly
# plt.rcParams['font.sans-serif']=['SimHei']

# Draw precision, recall and F-measure side by side; each series is shifted
# one bar width to the right of the previous one.
bar_series = (
    (precision, 'b', 'precision'),
    (recall, 'm', 'recall'),
    (f_measure, 'r', 'f_measure'),
)
drawn_bars = []
bar_positions = index
for heights, colour, legend_name in bar_series:
    drawn_bars.append(
        ax.bar(bar_positions, heights, bar_width,
               alpha=opacity, color=colour,
               error_kw=error_config,
               label=legend_name)
    )
    bar_positions = bar_positions + bar_width
rects1, rects2, rects3 = drawn_bars

# Tick positions and class names
ax.set_xticks(index + 3 * bar_width / 3)
class_names = ('0-desk', '1-dining table', '2-double bed', '3-sofa', '4-squatting toilet',
               '5-TV cabinet', '6-wardrobe', '7-washbasin')
ax.set_xticklabels(class_names)

# Legend, axis titles and figure size
ax.legend()
plt.xlabel("lable")
plt.ylabel("evaluation")
fig.set_figheight(5)
fig.set_figwidth(10)
fig.tight_layout()
# plt.savefig('result.png', dpi=200)
plt.show()
边栏推荐
- Wonderful use of data analysis
- iVX低代码平台系列详解 -- 概述篇(二)
- 物联网开发--MQTT消息服务器EMQX
- [electronic components] constant voltage, amplify the current of the load (triode knowledge summary)
- druid. io kill -9 index_ Realtime traceability task
- Mobile communication -- simulation model of error control system based on convolutional code
- Type analysis of demultiplexer (demultiplexer)
- 数学建模——自来水管道铺设问题
- 自定义mvc原理和框架实现
- Comprehensive explanation of "search engine crawl"
猜你喜欢

druid. io kill -9 index_ Realtime traceability task

Control buzzer based on C51

FPGA实现10M多功能信号发生器

基于C51控制蜂鸣器

RGBD点云降采样

Mathematical modeling -- heat conduction of subgrade on Permafrost

druid. io index_ Realtime real-time query

基于C51实现数码管的显示

Type analysis of demultiplexer (demultiplexer)

Lxml web page capture the most complete strategy
随机推荐
Idea connection database
[simple implementation and extension of one · data | array heap]
Probability Density Reweight
Using local cache + global cache to realize user rights management of small systems
mobile-picker.js
[MySQL] SQL aliases the table
2022.7.27-----leetcode.592
Form verification hidden input box is displayed before verification
leetcode 242. Valid Anagram(有效的字母异位词)
数学建模——带相变材料的低温防护服御寒仿真模拟
Comprehensive use method of C treeview control
Resolve the conflict with vetur when using eslint, resulting in double quotation marks and comma at the end of saving
Wonderful use of data analysis
Monadic linear function perceptron: Rosenblatt perceptron
Comprehensive analysis of news capture doorway
[cloud native] what is the microservice architecture
「活动推荐」冲冲冲!2022 国际开源节有新内容
Click back to the top JS
2022.7.28-----leetcode.1331
Promise solves asynchrony