当前位置:网站首页>pytorch模型微调finetuning训练image_dog(kaggle)
pytorch模型微调finetuning训练image_dog(kaggle)
2022-07-31 05:16:00 【王大队长】
针对Kaggle上的狗品种识别(Dog Breed Identification)比赛训练一个分类模型。



train文件夹有10222张狗的图片,test文件夹有10357张狗的图片,labels.csv文件放着train文件夹里的图片名称对应的label值。
模型:选择预训练的resnet50,冻住其全部参数,在它的1000维输出之后接上新的全连接层(1000→256→120),使输出类别变成120,并且只对新加的这些层进行学习
代码:
import torch
import torchvision
from torch import nn
from torch.utils.data import Dataset
from PIL import Image
import os
from torchvision.transforms import transforms
import pandas as pd
from matplotlib import pyplot as plt
import numpy as np
import torch.optim as optim
from tqdm import tqdm #显示进度条的库
class dataset(Dataset):
    """Image dataset over a flat directory, split into train/valid or used as test.

    The file names found in ``root_dir`` are sorted and then split: the first
    ``int(n * (1 - valid_ratio))`` files form the train subset and the rest
    form the valid subset.  In ``'test'`` mode every file is used and no label
    is returned.

    Args:
        root_dir: Directory that contains the image files.
        imgs_labels: Mapping from image id (file name up to the first ``.``)
            to its class label.  Not consulted in ``'test'`` mode.
        transform: Optional callable applied to the RGB ``PIL.Image``.
        valid_ratio: Fraction of the files reserved for the valid subset.
        mode: One of ``'train'``, ``'valid'`` or ``'test'``.
        class_to_num: Optional mapping from class label to integer index.
            When omitted it is derived from the sorted distinct values of
            ``imgs_labels``.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """

    def __init__(self, root_dir, imgs_labels, transform=None, valid_ratio=0.2,
                 mode='train', class_to_num=None):
        if mode not in ('train', 'valid', 'test'):
            raise ValueError(
                "mode must be 'train', 'valid' or 'test', got {!r}".format(mode))
        self.root_dir = root_dir
        self.imgs_labels = imgs_labels
        self.transform = transform
        self.mode = mode
        # os.listdir returns entries in arbitrary order.  The train and valid
        # subsets are built by two separate instances, so sort to guarantee
        # that both see the same order and the subsets never overlap.
        self.imgName = sorted(os.listdir(root_dir))
        self.data_len = len(self.imgName)
        self.train_len = int(self.data_len * (1 - valid_ratio))

        # Slice once here instead of on every __getitem__ call.
        if mode == 'train':
            self.samples = self.imgName[:self.train_len]
        elif mode == 'valid':
            self.samples = self.imgName[self.train_len:]
        else:
            self.samples = self.imgName
        self.real_len = len(self.samples)

        # Labels are only needed for the labelled subsets.
        if class_to_num is None and mode != 'test':
            distinct_labels = sorted(set(imgs_labels.values()))
            class_to_num = {name: i for i, name in enumerate(distinct_labels)}
        self.class_to_num = class_to_num

        print('Finished reading the {} set of Dataset ({} samples found)'
              .format(mode, self.real_len))

    def __getitem__(self, idx):
        """Return ``(image, class_index)``, or just ``image`` in test mode."""
        file_name = self.samples[idx]
        img_item_path = os.path.join(self.root_dir, file_name)
        # Force three channels so grayscale / RGBA files match the model input.
        img = Image.open(img_item_path).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        if self.mode == 'test':
            return img
        label = self.imgs_labels[file_name.split('.')[0]]
        return img, self.class_to_num[label]

    def __len__(self):
        """Number of samples in the selected subset."""
        return self.real_len
def get_net(device):
    """Build the fine-tuning network and place it on ``device``.

    The network is an ``nn.Sequential`` with two named children: ``features``
    (a pretrained ResNet-50) followed by ``output_new`` (a new
    1000 -> 256 -> 120 head).  All ``features`` parameters have
    ``requires_grad`` set to False, so only ``output_new`` receives gradients.

    Args:
        device: Device the model is moved to.

    Returns:
        The assembled ``nn.Sequential`` model.
    """
    backbone = torchvision.models.resnet50(pretrained=True)
    head = nn.Sequential(
        nn.Linear(1000, 256),
        nn.ReLU(),
        nn.Linear(256, 120),
    )

    finetune_net = nn.Sequential()
    # Attribute assignment registers the children in this order, which is
    # also the order nn.Sequential runs them in.
    finetune_net.features = backbone
    finetune_net.output_new = head

    # Move all parameters to the CPU or GPU used for computation.
    finetune_net = finetune_net.to(device)

    # Freeze the backbone.
    for frozen in finetune_net.features.parameters():
        frozen.requires_grad = False

    return finetune_net
def train(net, loss, optimizer, train_dataloader, device, batch_size,
          epoch=None, train_losses=None, train_acces=None):
    """Run one training epoch and append its mean loss and accuracy to the history.

    Args:
        net: Model to train.
        loss: Loss callable taking ``(output, targets)``.
        optimizer: Optimizer that updates the trainable parameters.
        train_dataloader: Iterable yielding ``(imgs, targets)`` batches.
        device: Device the batches are moved to.
        batch_size: Unused; kept so existing callers keep working.  Accuracy
            is computed from the real number of samples seen, so a smaller
            last batch no longer skews it.
        epoch: Epoch number used in the log line.  Defaults to the
            module-level ``epoch`` variable when one exists.
        train_losses: List the epoch loss is appended to.  Defaults to the
            module-level ``train_losses`` list (created on first use).
        train_acces: List the epoch accuracy is appended to.  Defaults to the
            module-level ``train_acces`` list (created on first use).

    Returns:
        Tuple ``(train_losses, train_acces)`` with this epoch appended.

    Raises:
        ValueError: If ``train_dataloader`` yields no samples.
    """
    # Fall back to the script-level history lists / epoch counter so the
    # original six-argument call keeps accumulating into the same objects.
    module_globals = globals()
    if train_losses is None:
        train_losses = module_globals.setdefault('train_losses', [])
    if train_acces is None:
        train_acces = module_globals.setdefault('train_acces', [])
    if epoch is None:
        epoch = module_globals.get('epoch')

    # NOTE(review): train() also puts any BatchNorm layers of a frozen
    # backbone into training mode, so their running statistics keep
    # updating -- confirm that this is intended.
    net.train()

    loss_sum = 0.0
    num_batches = 0
    num_correct = 0
    num_samples = 0
    for imgs, targets in tqdm(train_dataloader, desc='训练'):
        imgs = imgs.to(device)
        targets = targets.to(device)
        output = net(imgs)
        batch_loss = loss(output, targets)

        # Optimisation step.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()
        num_batches += 1
        num_correct += (output.argmax(dim=1) == targets).sum().item()
        num_samples += targets.size(0)

    if num_samples == 0:
        raise ValueError('train_dataloader yielded no samples')

    epoch_loss = loss_sum / num_batches
    epoch_acc = num_correct / num_samples
    print("epoch: {}, Loss: {}, Acc: {}".format(epoch, epoch_loss, epoch_acc))
    train_acces.append(epoch_acc)
    train_losses.append(epoch_loss)
    return train_losses, train_acces
def valid(net, loss, valid_dataloader, device, batch_size, eval_acces=None):
    """Evaluate ``net`` on the validation set and append the accuracy to the history.

    Args:
        net: Model to evaluate.
        loss: Loss callable taking ``(output, targets)``.
        valid_dataloader: Iterable yielding ``(imgs, targets)`` batches.
        device: Device the batches are moved to.
        batch_size: Unused; kept so existing callers keep working.  Accuracy
            is computed from the real number of samples seen, so a smaller
            last batch no longer skews it.
        eval_acces: List the accuracy is appended to.  Defaults to the
            module-level ``eval_acces`` list (created on first use).

    Returns:
        ``eval_acces`` with this evaluation's accuracy appended.

    Raises:
        ValueError: If ``valid_dataloader`` yields no samples.
    """
    # Fall back to the script-level history list so the original
    # five-argument call keeps accumulating into the same object.
    if eval_acces is None:
        eval_acces = globals().setdefault('eval_acces', [])

    net.eval()
    loss_sum = 0.0
    num_batches = 0
    num_correct = 0
    num_samples = 0
    with torch.no_grad():
        for imgs, targets in tqdm(valid_dataloader, desc='验证'):
            imgs = imgs.to(device)
            targets = targets.to(device)
            output = net(imgs)
            # .item() keeps the running sum a plain float instead of a tensor
            # on the compute device.
            loss_sum += loss(output, targets).item()
            num_batches += 1
            num_correct += (output.argmax(dim=1) == targets).sum().item()
            num_samples += targets.size(0)

    if num_samples == 0:
        raise ValueError('valid_dataloader yielded no samples')

    eval_losses = loss_sum / num_batches
    eval_acc = num_correct / num_samples
    eval_acces.append(eval_acc)
    print("整体验证集上的Loss: {}".format(eval_losses))
    print("整体验证集上的正确率: {}".format(eval_acc))
    return eval_acces
def show_acces(train_losses, train_acces, valid_acces, num_epoch):
    """Plot training loss, training accuracy and validation accuracy per epoch.

    Each series is drawn as a dashed line against 1-based epoch numbers, and
    the x axis is ticked at every integer from 1 to ``num_epoch``.

    Args:
        train_losses: Per-epoch training loss values.
        train_acces: Per-epoch training accuracy values.
        valid_acces: Per-epoch validation accuracy values.
        num_epoch: Number of epochs, used for the x-axis ticks.
    """
    curves = (
        (train_losses, 'train_losses'),
        (train_acces, 'train_acces'),
        (valid_acces, 'valid_acces'),
    )
    for values, legend_name in curves:
        epochs = 1 + np.arange(len(values))
        plt.plot(epochs, values, linewidth=1.5, linestyle='dashed', label=legend_name)

    plt.grid()
    plt.xlabel('epoch')
    plt.xticks(range(1, 1 + num_epoch, 1))
    plt.legend()
    plt.show()
if __name__ == '__main__':
    # Script entry point: read the Kaggle data, build the datasets, loaders
    # and model, preview a few training images, then train and validate for
    # num_epoch epochs.
    train_path = r'/kaggle/input/dogbreedidentification/train/'
    test_path = r'/kaggle/input/dogbreedidentification/test/'
    imgs_train = os.listdir(train_path)
    imgs_test = os.listdir(test_path)
    print('训练集图片数:', len(imgs_train))
    # This counts the test folder, so label it as the test set.
    print('测试集图片数:', len(imgs_test))

    # labels.csv: column 0 is the image id used as lookup key by the dataset,
    # column 1 is its label.
    data_info = pd.read_csv(r'/kaggle/input/dogbreedidentification/labels.csv')
    imgs = data_info.iloc[0:, 0]
    labels = data_info.iloc[0:, 1]
    imgs_labels = dict(zip(imgs, labels))
    train_root_dir = train_path
    test_root_dir = test_path

    # Map every breed name to an integer class index (and back).
    labels_sorted = sorted(set(data_info['breed']))
    n_classes = len(labels_sorted)
    class_to_num = dict(zip(labels_sorted, range(n_classes)))
    print('len:', len(labels_sorted))
    num_to_class = {v: k for k, v in class_to_num.items()}

    # config
    batchsize = 64
    valid_batchsize = batchsize // 2
    num_epoch = 5
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # get_net already moves the model to the device.
    net = get_net(device)
    loss = nn.CrossEntropyLoss()
    # Only the new head is optimised.
    optimizer = optim.Adam([{'params': net.output_new.parameters()}], lr=0.001)

    # History lists; these module-level names are what train() and valid()
    # accumulate into.
    train_losses = []
    train_acces = []
    eval_acces = []

    # Per-channel statistics used by Normalize (and to undo it for display).
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]
    transform = {
        "train": transforms.Compose([
            # Randomly crop 8%-100% of the image area with an aspect ratio
            # between 3/4 and 4/3, then resize the crop to 224x224.
            transforms.RandomResizedCrop(224, scale=(0.08, 1.0),
                                         ratio=(3.0 / 4.0, 4.0 / 3.0)),
            transforms.RandomHorizontalFlip(),
            # Randomly change brightness, contrast and saturation.
            transforms.ColorJitter(brightness=0.4,
                                   contrast=0.4,
                                   saturation=0.4),
            # Convert the PIL image to a tensor.
            transforms.ToTensor(),
            # Standardise every channel.
            transforms.Normalize(norm_mean, norm_std),
        ]),
        "valid": transforms.Compose([
            torchvision.transforms.Resize(256),
            # Cut a 224x224 patch out of the image centre.
            torchvision.transforms.CenterCrop(224),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(norm_mean, norm_std),
        ]),
    }

    train_dataset = dataset(train_root_dir, imgs_labels, transform=transform["train"],
                            valid_ratio=0.1, mode='train')
    valid_dataset = dataset(train_root_dir, imgs_labels, transform=transform["valid"],
                            valid_ratio=0.1, mode='valid')
    # The test set uses the deterministic (non-augmenting) pipeline.
    test_dataset = dataset(test_root_dir, imgs_labels, transform=transform["valid"],
                           mode='test')

    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batchsize,
                                                   shuffle=True)
    # Evaluation does not need shuffling; keeping the test order fixed also
    # keeps predictions aligned with the file order.
    valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=valid_batchsize,
                                                   shuffle=False)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batchsize,
                                                  shuffle=False)

    # Show a few training samples.
    example_data, example_targets = next(iter(train_dataloader))
    plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        img = np.asarray(example_data[i]).transpose((1, 2, 0))
        # Undo Normalize so imshow receives values in [0, 1].
        img = np.clip(img * np.array(norm_std) + np.array(norm_mean), 0.0, 1.0)
        plt.imshow(img)
        plt.title(num_to_class[int(example_targets[i])])
        plt.xticks([])
        plt.yticks([])
    plt.show()

    print("参数数:{}".format(sum(x.numel() for x in net.parameters())))
    for epoch in range(num_epoch):
        print("——————第 {} 轮训练开始——————".format(epoch + 1))
        train_losses, train_acces = train(net, loss, optimizer, train_dataloader, device, batchsize)
        valid_acces = valid(net, loss, valid_dataloader, device, valid_batchsize)
show_acces(train_losses, train_acces, valid_acces, num_epoch)

结果:

编辑
添加图片注释,不超过 140 字(可选)

编辑
添加图片注释,不超过 140 字(可选)
训练5个epoch后,验证集上的准确率约为0.83,不过训练到第2到3个epoch时就已基本收敛(这可能正是微调的特点:冻住了其他参数,只学习最后新加的Linear层,需要训练的参数不多)。
边栏推荐
- cv2.imread()
- 安装Multisim出现 No software will be installed or removed解决方法
- MySQL compressed package installation, fool teaching
- Gradle sync failed: Uninitialized object exists on backward branch 142
- 360 hardening file path not exists.
- js中流程控制语句
- js中的函数
- MySQL面试题大全(陆续更新)
- [windows]--- SQL Server 2008 super detailed installation tutorial
- powershell统计文件夹大小
猜你喜欢

QT VS中双击ui文件无法打开的问题

Sqlite A列数据复制到B列

2021美赛C题M奖思路
![[windows]--- SQL Server 2008 super detailed installation tutorial](/img/b7/dc802c63b07edc4298b6e6b90d865c.png)
[windows]--- SQL Server 2008 super detailed installation tutorial

MySQL错误-this is incompatible with sql_mode=only_full_group_by完美解决方案

Pytorch实现ResNet

Gradle sync failed: Uninitialized object exists on backward branch 142

js中的全局作用域与函数作用域

JS写一段代码,判断一个字符串中出现次数最多的字符串,并统计出现的次数JS

活体检测FaceBagNet阅读笔记
随机推荐
understand js operators
sql add default constraint
Markdown help documentation
MySQL高级SQL语句(二)
cocos2d-x-3.2图片灰化效果
sql 添加 default 约束
Principle analysis of famous website msdn.itellyou.cn
unicloud 发布后小程序提示连接本地调试服务失败,请检查客户端是否和主机在同一局域网下
RuntimeError: CUDA error: no kernel image is available for execution on the device问题记录
cocos2d-x-3.2 Physics
configure:error no SDL library found
UiBot存在已打开的MicrosoftEdge浏览器,无法执行安装
微信小程序源码获取与反编译方式
一文速学-玩转MySQL获取时间、格式转换各类操作方法详解
通信原理——纠错编码 | 汉明码(海明码)手算详解
js中的对象与函数的理解
C语言 | 获取字符串里逗号间隔的内容
禅道安装及使用教程
Attribute Changer的几种形态
cocoscreator3.5.2打包微信小游戏发布到QQ小游戏修改