ResNet50 + k-fold cross-validation + data augmentation + plotting (precision, recall, F-measure)
2022-07-29 02:15:00 【Trouble coder】
K-fold cross-validation
def kfold(data, k=5):
    """K-fold cross-validation: yield a (train, valid) DataLoader pair for each fold."""
    X = np.arange(len(data))
    KF = KFold(n_splits=k, shuffle=True)
    for train_idxs, valid_idxs in KF.split(X):
        train_iter, valid_iter = [], []
        for i in train_idxs:
            train_iter.append(data[i])
        for i in valid_idxs:
            valid_iter.append(data[i])
        train_data = torch.utils.data.DataLoader(train_iter, shuffle=True,
                                                 batch_size=batch_size)
        valid_data = torch.utils.data.DataLoader(valid_iter, batch_size=batch_size)
        yield train_data, valid_data
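Each iteration of the generator yields one fold's training and validation DataLoaders, so it can be consumed directly in a loop. A minimal usage sketch (assuming a dataset such as the train_imgs ImageFolder and the batch_size defined in the complete code below):

# Iterate over the folds and report how many batches each side contains
for fold, (train_data, valid_data) in enumerate(kfold(train_imgs, k=5)):
    print(f'fold {fold}: {len(train_data)} train batches, {len(valid_data)} valid batches')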
Complete code
from glob import glob
import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import shutil
from torchvision import transforms
from torchvision import models
import torch
from torch.autograd import Variable
import torch.nn as nn
from torch.optim import lr_scheduler
from torch import optim
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid
from sklearn.model_selection import KFold
from mpl_toolkits.axes_grid1 import host_subplot
import time
%matplotlib inline
def try_gpu(i=0):
    """Return gpu(i) if it exists, otherwise return cpu()."""
    if torch.cuda.device_count() >= i + 1:
        return torch.device(f'cuda:{i}')
    return torch.device('cpu')
path = '../data/data-8/train/'
files = glob(os.path.join(path, '*/*.png'))
print(f'Total train images: {len(files)}')
path = '../data/data-8/test/'
files = glob(os.path.join(path, '*/*.png'))
print(f'Total test images: {len(files)}')
imag_size = 224
batch_size = 16
# Data augmentation
transform = transforms.Compose([
    transforms.Resize((imag_size, imag_size)),
    transforms.RandomHorizontalFlip(),              # random horizontal flip
    transforms.RandomVerticalFlip(),                # random vertical flip
    transforms.RandomRotation(45),                  # random rotation by up to 45 degrees
    transforms.RandomCrop((imag_size, imag_size)),  # random-position crop
    transforms.ToTensor(),
    # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
transform_test = transforms.Compose([
    transforms.Resize((imag_size, imag_size)),
    transforms.ToTensor(),
])
train_imgs = ImageFolder('../data/data-8/train', transform)
test_imgs = ImageFolder('../data/data-8/test', transform_test)
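Because the transform runs every time an item is fetched, indexing the same training image repeatedly returns a different augmented version each time. An optional sanity-check sketch (using only the imports above, assuming the datasets were created successfully) to see what the augmentation pipeline produces:

# Stack 8 random augmentations of the first training image and display them as a grid
imgs = torch.stack([train_imgs[0][0] for _ in range(8)])
plt.imshow(make_grid(imgs, nrow=4).permute(1, 2, 0))
plt.axis('off')
plt.show()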
def kfold(data, k=5):
    """K-fold cross-validation: yield a (train, valid) DataLoader pair for each fold."""
    X = np.arange(len(data))
    KF = KFold(n_splits=k, shuffle=True)
    for train_idxs, valid_idxs in KF.split(X):
        train_iter, valid_iter = [], []
        for i in train_idxs:
            train_iter.append(data[i])
        for i in valid_idxs:
            valid_iter.append(data[i])
        train_data = torch.utils.data.DataLoader(train_iter, shuffle=True,
                                                 batch_size=batch_size)
        valid_data = torch.utils.data.DataLoader(valid_iter, batch_size=batch_size)
        yield train_data, valid_data
def train(data, isTrain=True):
    if isTrain:
        model.train()
    else:
        model.eval()
    running_loss = 0.0
    running_corrects = 0
    for inputs, labels in data:
        if isTrain:
            optimizer.zero_grad()
        inputs, labels = Variable(
            inputs.to(device)), Variable(labels.to(device))
        outputs = model(inputs)
        _, preds = torch.max(outputs.data, 1)
        loss = criterion(outputs, labels)
        if isTrain:
            loss.backward()
            optimizer.step()
        running_loss += loss.data
        running_corrects += torch.sum(preds == labels.data)
    loss = running_loss / len(data) / batch_size
    acc = running_corrects / len(data) / batch_size
    return loss, acc
def test(data):
    real_labels, pred_labels = [], []
    model.eval()
    running_corrects = 0
    for inputs, labels in data:
        inputs, labels = Variable(
            inputs.to(device)), Variable(labels.to(device))
        outputs = model(inputs)
        _, preds = torch.max(outputs.data, 1)
        for y in labels:
            real_labels.append(y.item())
        for y in preds:
            pred_labels.append(y.item())
        running_corrects += torch.sum(preds == labels.data)
    acc = running_corrects / len(data) / batch_size
    return acc, running_corrects, real_labels, pred_labels
# Set hyperparameters
train_iterations, train_loss, test_accuracy = [], [], []
model = models.resnet50(pretrained=False)
# model.load_state_dict(torch.load('data/resnet50-19c8e357.pth'))
num_ftrs = model.fc.in_features   # ResNet50 exposes its final layer as .fc, not .classifier
model.fc = nn.Linear(num_ftrs, 8)  # replace the 1000-class head with an 8-class head
k = 5
lr, num_epochs = 2e-4, 20
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
device = try_gpu(0)
# Start training
print('Training starts on', device)
model.to(device)
for epoch in range(num_epochs):
    loss, acc = 0.0, 0.0
    for train_data, valid_data in kfold(train_imgs, k):  # k-fold cross-validation
        train_losses, train_acc = train(train_data)
        valid_losses, valid_acc = train(valid_data, False)
        loss += valid_losses
        acc += valid_acc
    if epoch % 1 == 0:  # log a result every epoch (raise the modulus to log less often)
        train_iterations.append(epoch)
        train_loss.append((loss/k).to('cpu').item())
        test_accuracy.append((acc/k).to('cpu').item())
        print('{},{:.4f},{:.4f}'.format(epoch, loss/k, acc/k))
        # print('Epoch {}/{} avgLoss: {:.4f} Acc: {:.4f}'.format(epoch + 1, num_epochs, loss/k, acc/k))
train_iterations.append(num_epochs - 1)
train_loss.append((loss/k).to('cpu').item())
test_accuracy.append((acc/k).to('cpu').item())
print('{},{:.4f},{:.4f}'.format(num_epochs - 1, loss/k, acc/k))
print('Training finished')
torch.save(model, '../model/20220507-resnet50-data-augmentation-' + str(lr) + '.pth')
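Note that the call above saves the whole model object rather than just its state_dict, so it can later be restored in a single call. A minimal sketch (the path simply mirrors the one used above):

# model = torch.load('../model/20220507-resnet50-data-augmentation-' + str(lr) + '.pth')
# model.eval()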
print('Testing on', device)
model.to(device)
test_data = torch.utils.data.DataLoader(test_imgs, shuffle=True, batch_size=batch_size)
acc, corrects, real, pre = test(test_data)
print('Accuracy: {:.4f} Correct predictions: {}'.format(acc, corrects))
# k = 0
# while k < len(real):
#     print(real[k], pre[k])
#     k = k + 1
# Draw a curve
host = host_subplot(111)
plt.subplots_adjust(right=0.8)  # adjust the right boundary of the plot window
par1 = host.twinx()
# Set axis labels
host.set_xlabel("iterations")
host.set_ylabel("loss")
par1.set_ylabel("validation accuracy")
# draw a curve
p1, = host.plot(train_iterations, train_loss, "b-", label="training loss")
p2, = host.plot(train_iterations, train_loss, ".") # Curve point
p3, = par1.plot(train_iterations, test_accuracy, label="validation accuracy")
p4, = par1.plot(train_iterations, test_accuracy, "1")
# Set the legend position
# 1 -> upper right, 2 -> upper left, 3 -> lower left
# 4 -> lower right, 5 -> right ...
host.legend(loc=5)
# Set the color
host.axis["left"].label.set_color(p1.get_color())
par1.axis["right"].label.set_color(p3.get_color())
# set range
host.set_xlim([0, num_epochs - 1])
plt.draw()
plt.show()
#--------------------------------------------------------------------------
# Part 1: compute precision, recall and F-measure
#--------------------------------------------------------------------------
# Per-class tallies (8 classes in total)
real_8 = [0] * 8   # number of ground-truth labels per class
pre_8 = [0] * 8    # number of predicted labels per class
right_8 = [0] * 8  # number of correctly predicted labels per class
k = 0
while k < len(real):
    v1 = int(real[k])
    v2 = int(pre[k])
    # print(v1, v2)
    real_8[v1] = real_8[v1] + 1   # count the ground-truth label
    pre_8[v2] = pre_8[v2] + 1     # count the predicted label
    if v1 == v2:
        right_8[v1] = right_8[v1] + 1
    k = k + 1
# print("Per-class counts")
# print(real_8, pre_8, right_8)
# Precision = correctly predicted / total predicted for the class
# (assumes every class is predicted at least once)
precision = [0.0] * 8
k = 0
while k < len(real_8):
    value = right_8[k] * 1.0 / pre_8[k]
    precision[k] = value
    k = k + 1
print('Precision: ')
print(precision)
# Recall = correctly predicted / total ground-truth samples of the class
recall = [0.0] * 8
k = 0
while k < len(real_8):
    value = right_8[k] * 1.0 / real_8[k]
    recall[k] = value
    k = k + 1
print('Recall: ')
print(recall)
# F-measure = 2 * precision * recall / (precision + recall)
f_measure = [0.0] * 8
k = 0
while k < len(real_8):
    value = (2 * precision[k] * recall[k] * 1.0) / (precision[k] + recall[k])
    f_measure[k] = value
    k = k + 1
print('F-measure: ')
print(f_measure)
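As a cross-check on the hand-rolled loops above, scikit-learn can compute the same per-class metrics in one call. A sketch (assuming a reasonably recent scikit-learn; real and pre are the label lists returned by test()):

# Per-class precision, recall and F1 computed by scikit-learn
from sklearn.metrics import precision_recall_fscore_support
p, r, f, _ = precision_recall_fscore_support(real, pre, labels=list(range(8)), zero_division=0)
print('Precision:', p)
print('Recall:   ', r)
print('F-measure:', f)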
#--------------------------------------------------------------------------
# Part 2: plot the per-class bar chart
#--------------------------------------------------------------------------
# Set category
n_groups = 8
fig, ax = plt.subplots()
index = np.arange(n_groups)
bar_width = 0.2
opacity = 0.4
error_config = {'ecolor': '0.3'}
# Used to display Chinese labels normally
# plt.rcParams['font.sans-serif']=['SimHei']
# draw
rects1 = ax.bar(index, precision, bar_width,
                alpha=opacity, color='b',
                error_kw=error_config,
                label='precision')
rects2 = ax.bar(index + bar_width, recall, bar_width,
                alpha=opacity, color='m',
                error_kw=error_config,
                label='recall')
rects3 = ax.bar(index + 2 * bar_width, f_measure, bar_width,
                alpha=opacity, color='r',
                error_kw=error_config,
                label='f_measure')
# Set the label
ax.set_xticks(index + bar_width)  # center the ticks under the middle bar of each group
ax.set_xticklabels(('0-desk', '1-dining table', '2-double bed', '3-sofa', '4-squatting toilet',
'5-TV cabinet', '6-wardrobe', '7-washbasin'))
# Set class labels
ax.legend()
plt.xlabel("lable")
plt.ylabel("evaluation")
fig.set_figheight(5)
fig.set_figwidth(10)
fig.tight_layout()
# plt.savefig('result.png', dpi=200)
plt.show()