【Kaggle】Classify Leaves
2022-08-01 20:35:00 【鱼树(◔◡◔)】
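1. Preface
This is my second practice problem. There was a lot I was unsure about in terms of code conventions, so I drew on the analysis workflows and code of many more experienced people, and I am recording the result here.
Online courses I have taken so far (or am still taking):
Xiaojiayu (小甲鱼) Python course: link1
Xiaotudui (小土堆) deep learning quick-start course: link2
Hung-yi Lee (李宏毅) machine learning course: link3
Mu Li (李沐) Dive into Deep Learning course: link4
Mu Li's book: link5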
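2. Problem description
This problem comes from Mu Li's Dive into Deep Learning course on Bilibili, and the accuracy reached on it is by now very high. The dataset has three parts: a collection of leaf images, each with an ID; a training set that gives image IDs together with their species; and a test set that gives only image IDs, with no species. The task is to train on the images whose species are known and then predict the leaf species for each image ID in the test set.
Kaggle link: link6
3. Implementation
3.1 Importing packages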
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms, models
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
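Image from PIL is used to read the images, and transforms from torchvision is used mainly to transform the images for augmentation.
Processing images on the CPU can be slow, so the GPU is used here when one is available.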
3.2 Data preprocessing
# Data loading
train = pd.read_csv(r'./classify-leaves/train.csv')
test = pd.read_csv(r'./classify-leaves/test.csv')
# Leaf label mapping
species_name_list = sorted(set(train['label']))  # set removes duplicates
species_to_num = dict(zip(species_name_list,
                          range(len(species_name_list))))
num_to_species = {value: key for key, value in species_to_num.items()}
num_class = len(species_name_list)  # 176
The preprocessing step converts the leaf species (strings) to numbers and builds two dictionaries, "leaf species: number" and "number: leaf species". The first is used for training the model; the second is used at the end to turn test-set predictions back into species names.
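As a small illustration of the two dictionaries, the sketch below builds them from a hypothetical list of labels (the species names are made up, not taken from train.csv) and checks that encoding followed by decoding returns the original labels.

```python
# Hypothetical labels standing in for train['label']
labels = ["maple", "oak", "maple", "birch", "oak"]

# Unique names in a fixed (sorted) order, so the mapping is reproducible
species_name_list = sorted(set(labels))
species_to_num = {name: i for i, name in enumerate(species_name_list)}
num_to_species = {i: name for name, i in species_to_num.items()}

encoded = [species_to_num[name] for name in labels]   # used for training
decoded = [num_to_species[i] for i in encoded]        # used after prediction

print(species_to_num)
print(encoded)
```

Sorting the set matters: a plain set has no guaranteed order, so without sorting the numbers assigned to each species could differ between runs.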
3.3 Dataset
train_data_copy = train.copy()
test_data_copy = test.copy()
imagePath = r'./classify-leaves/images/'
train_data_len = train_data_copy.shape[0]
idx1 = [i for i in range(train_data_len) if i % 10 != 0]
idx2 = [i for i in range(train_data_len) if i % 10 == 0]
train_data = train_data_copy.iloc[idx1,:]
valid_data = train_data_copy.iloc[idx2,:]
The split here is deliberately simple: rows whose index is divisible by 10 form the validation set, and the rest form the training set.
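The same split rule can be tried on a toy range to see how many rows land in each part. This is only a sketch; `split_indices` and its `every` parameter are hypothetical names introduced for illustration.

```python
def split_indices(n, every=10):
    """Send every `every`-th row position to validation, the rest to training."""
    train_idx = [i for i in range(n) if i % every != 0]
    valid_idx = [i for i in range(n) if i % every == 0]
    return train_idx, valid_idx

train_idx, valid_idx = split_indices(25)
print(valid_idx)
print(len(train_idx), len(valid_idx))
```

Because the split depends only on row position, roughly one row in ten is held out, and how well each species is represented in the validation set depends on the order of rows in the CSV.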
class MyDataset(Dataset):
    def __init__(self, data, species_to_num, transform=None):
        super().__init__()
        self.data = data
        self.species_to_num = species_to_num
        self.labels = data['label'].values
        self.transform = transform
    def __getitem__(self, idx):
        # split('/')[1] keeps only the file-name part of the stored path
        image = Image.open(imagePath +
                           self.data['image'].values[idx].split('/')[1])
        if self.transform:
            image = self.transform(image)
        label = self.labels[idx]
        label = self.species_to_num[label]
        return image, label
    def __len__(self):
        return len(self.data)

class TestDataset(Dataset):
    def __init__(self, data, transform=None):
        super().__init__()
        self.data = data
        self.transform = transform
    def __getitem__(self, idx):
        image = Image.open(imagePath +
                           self.data['image'].values[idx].split('/')[1])
        if self.transform:
            image = self.transform(image)
        else:
            # No transform given: only convert the PIL image to a tensor
            image = transforms.ToTensor()(image)
        return image
    def __len__(self):
        return len(self.data)
To simplify the training and validation steps that follow, MyDataset and TestDataset are created by subclassing Dataset.
3.4 Image augmentation
data_transforms = {
    'train': transforms.Compose(
        [
            # Randomly crop the image to 0.08~1 of the original area with an
            # aspect ratio between 3/4 and 4/3, then resize to 224x224
            transforms.RandomResizedCrop(224, scale=(0.08, 1),
                                         ratio=(3.0/4.0, 4.0/3.0)),
            # Random horizontal and vertical flips
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            # Randomly change brightness, contrast, and saturation
            transforms.ColorJitter(brightness=0.2,
                                   contrast=0.2,
                                   saturation=0.2),
            transforms.ToTensor(),
            # Normalize each channel of the image
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225])
        ]
    ),
    'valid': transforms.Compose(
        [
            transforms.Resize(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225])
        ]
    )
}
For the training set, images are randomly cropped, flipped, and changed in brightness, contrast, and saturation. Real photos may well be flipped or lit differently, so augmenting the images during training helps the model generalize.
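The Normalize step applies (value - mean) / std to each channel after ToTensor has scaled pixel values to [0, 1]. The sketch below reproduces that arithmetic for a single RGB pixel in plain Python; `normalize_pixel` and `denormalize_pixel` are hypothetical helpers for illustration, not part of torchvision.

```python
# Per-channel statistics used in the transforms above (ImageNet mean and std)
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

def normalize_pixel(rgb):
    """Apply (value - mean) / std per channel to one RGB pixel in [0, 1]."""
    return [(v - m) / s for v, m, s in zip(rgb, MEAN, STD)]

def denormalize_pixel(rgb):
    """Invert the normalization, e.g. before displaying an image."""
    return [v * s + m for v, m, s in zip(rgb, MEAN, STD)]

white = normalize_pixel([1.0, 1.0, 1.0])
print(white)
print(denormalize_pixel(white))
```

The inverse is handy when plotting augmented images with matplotlib, since normalized tensors fall outside the [0, 1] range that image display expects.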
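3.5 Model
def get_net():
    ret_net = nn.Sequential()
    ret_net.features = models.resnet34(pretrained=True)
    # Define a new output network with 176 output classes
    ret_net.output_new = nn.Sequential(
        nn.Linear(1000, 300),
        nn.ReLU(),
        nn.Linear(300, 176))
    ret_net = ret_net.to(device)
    return ret_net
The model uses the ResNet-34 that ships with torchvision. Setting pretrained=True loads pretrained weights, which makes training faster. The last layer of ResNet-34 outputs 1000 dimensions, so linear layers are added after it to map these to the 176 output classes.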
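To get a feel for the size of the added head, the sketch below counts its parameters by hand. For comparison it also counts a common alternative that the post does not use: replacing ResNet-34's final fully connected layer, which takes 512 input features, with a single 512-to-176 layer. `linear_params` is a hypothetical helper.

```python
def linear_params(in_features, out_features):
    """Number of weights plus biases in one fully connected layer."""
    return in_features * out_features + out_features

# Head used in this post: Linear(1000, 300) -> ReLU -> Linear(300, 176)
head_in_post = linear_params(1000, 300) + linear_params(300, 176)

# Alternative (an assumption, not the post's method): a single Linear(512, 176)
# swapped in for the network's original final layer
replaced_fc = linear_params(512, 176)

print(head_in_post, replaced_fc)
```

Both designs produce 176 outputs; the head in the post keeps the original 1000-way layer and stacks two more layers on top of it.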
3.6 Loss function
loss = nn.CrossEntropyLoss()
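nn.CrossEntropyLoss takes raw scores (logits) and integer class labels, and by default averages the per-sample loss over the batch. For one sample, the loss is the negative log of the softmax probability assigned to the true class. The following plain-Python sketch shows that computation; the `cross_entropy` function is a hypothetical illustration, not the PyTorch implementation.

```python
import math

def cross_entropy(logits, target):
    """-log(softmax(logits)[target]) for a single sample."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

# A model that scores all 176 classes equally has loss ln(176)
uniform_loss = cross_entropy([0.0] * 176, target=5)
print(uniform_loss)

# A confident, correct prediction has a loss close to zero
print(cross_entropy([10.0, 0.0, 0.0], target=0))
```

This gives a useful sanity check: at the start of training, before the new head has learned anything, a loss in the neighbourhood of ln(176) ≈ 5.17 is what one would expect.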
3.7 Evaluation function
def evaluate_loss(net, dataloader):
    valid_loss, valid_acc = [], []
    net.eval()
    for images, labels in dataloader:
        with torch.no_grad():
            images = images.to(device)
            labels = labels.to(device)
            output = net(images)
            ls = loss(output, labels)
            acc = (output.argmax(dim=-1) == labels).float().mean()
            valid_loss.append(ls.item())
            valid_acc.append(acc.item())  # store a Python float, not a tensor
    return sum(valid_loss)/len(valid_loss), sum(valid_acc)/len(valid_acc)
The evaluation function reports two metrics: the loss and the accuracy.
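One detail worth knowing about this kind of evaluation: averaging per-batch accuracies gives every batch the same weight, so a small final batch counts as much as a full one. The sketch below compares that with a per-sample average on made-up numbers; both function names are hypothetical.

```python
def mean_of_batch_means(batch_accs):
    """Average of per-batch accuracies, as in evaluate_loss."""
    return sum(batch_accs) / len(batch_accs)

def sample_weighted_mean(batch_accs, batch_sizes):
    """Accuracy over all samples: weight each batch by its size."""
    correct = sum(acc * n for acc, n in zip(batch_accs, batch_sizes))
    return correct / sum(batch_sizes)

# Made-up example: two full batches of 64 and a last batch of only 2 samples
batch_accs = [0.5, 0.5, 1.0]
batch_sizes = [64, 64, 2]

print(mean_of_batch_means(batch_accs))
print(sample_weighted_mean(batch_accs, batch_sizes))
```

With many batches of equal size the two numbers are nearly identical, so the simpler average is usually adequate for monitoring training.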
3.8 Training function
# Note: this definition rebinds the name `train`, which held the DataFrame
# loaded in 3.2; the copies made in 3.3 are used from here on.
def train(num_epochs, lr, wd, batch_size):
    net = get_net()
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=lr, weight_decay=wd)
    train_dataset = MyDataset(train_data, species_to_num,
                              data_transforms['train'])
    valid_dataset = MyDataset(valid_data, species_to_num,
                              data_transforms['valid'])
    train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True)
    valid_dataloader = DataLoader(valid_dataset, batch_size, shuffle=False)
    best_acc = 0
    global_step = 0  # epochs since the last improvement
    for epoch in range(num_epochs):
        net.train()
        train_loss, train_accs = [], []
        step = 0
        print(f'------Epoch {epoch + 1} started------')
        for imgs, labels in train_dataloader:
            imgs = imgs.to(device)
            labels = labels.to(device)
            output = net(imgs)
            ls = loss(output, labels)
            optimizer.zero_grad()
            ls.backward()
            optimizer.step()
            acc = (output.argmax(dim=-1) == labels).float().mean()
            train_loss.append(ls.item())
            train_accs.append(acc.item())
            print(f'  Step {step + 1} finished, train loss: {train_loss[-1]}, '
                  f'train accuracy: {train_accs[-1]}')
            step += 1
        train_avgeloss = sum(train_loss) / len(train_loss)
        train_avgeacc = sum(train_accs) / len(train_accs)
        valid_avgeloss, valid_avgeacc = evaluate_loss(net, valid_dataloader)
        print(f'Epoch {epoch + 1} finished, train average loss: {train_avgeloss}, '
              f'average accuracy: {train_avgeacc}, '
              f'validation average loss: {valid_avgeloss}, '
              f'average accuracy: {valid_avgeacc}')
        if valid_avgeacc > best_acc:
            best_acc = valid_avgeacc
            global_step = 0
            torch.save(net.state_dict(), r'./net.pth')
            print('Model saved')
        else:
            global_step += 1
            if global_step > 300:
                break
Whenever the accuracy on the validation set improves during training, the model is saved so that it can be used later for prediction.
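The save-on-improvement logic, together with the counter that stops training after too many epochs without improvement, can be isolated from the model code. The sketch below runs it over a made-up list of validation accuracies; `track_best` and `patience` are hypothetical names, and the accuracies are not results from this post.

```python
def track_best(valid_accs, patience):
    """Return (best_epoch, best_acc, last_epoch_run) for a list of accuracies."""
    best_acc, best_epoch, stale, last_epoch = 0.0, None, 0, None
    for epoch, acc in enumerate(valid_accs):
        last_epoch = epoch
        if acc > best_acc:
            # This is the point where the checkpoint would be saved
            best_acc, best_epoch, stale = acc, epoch, 0
        else:
            stale += 1
            if stale > patience:
                break  # too many epochs without improvement
    return best_epoch, best_acc, last_epoch

accs = [0.60, 0.72, 0.70, 0.71, 0.69, 0.80]  # made-up values
print(track_best(accs, patience=2))
print(track_best(accs, patience=300))
```

With a threshold of 300 and only 10 epochs, as in the settings used in this post, the early stop can never trigger; a much smaller value would be needed for it to have any effect.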
3.9 Training
num_epochs = 10
learning_rate = 0.0001
weight_decay = 0.0001
batch_size = 64
train(num_epochs, lr=learning_rate, wd=weight_decay, batch_size=batch_size)

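3.10 Prediction
test_dataset = TestDataset(test_data_copy)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
model = get_net()  # get_net() already moves the model to device
model.load_state_dict(torch.load('net.pth', map_location=device))
model.eval()  # use BatchNorm running statistics at inference time
# TestDataset is created without a transform here, so apply the same
# normalization as the 'valid' transform to each batch
normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225])
preds = []
for image in test_dataloader:
    with torch.no_grad():
        pred = model(normalize(image.to(device)))
    preds.extend(pred.argmax(dim=-1).cpu().tolist())
preds_species = [num_to_species[i] for i in preds]
test['label'] = preds_species
submission = pd.concat([test['image'], test['label']], axis=1)
submission.to_csv('./submission.csv', index=False)
The predicted scores are converted to class indices with argmax (each index corresponds to a species). The tensors are then turned into lists, and the "number: leaf species" dictionary created earlier maps the indices to species names. After concatenating the columns, this yields the submission file.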
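The decoding step can be traced by hand on a tiny example. The sketch below uses made-up scores, file names, and a three-class mapping (all hypothetical) and writes the result in the same two-column layout as the submission file, using only the standard library.

```python
import csv
import io

# Hypothetical stand-ins for num_to_species, test['image'], and model outputs
num_to_species = {0: "birch", 1: "maple", 2: "oak"}
image_names = ["images/a.jpg", "images/b.jpg"]
scores = [[0.1, 2.3, -0.5],
          [1.7, 0.2, 0.9]]

def argmax(row):
    """Index of the largest score in one row."""
    return max(range(len(row)), key=row.__getitem__)

pred_labels = [num_to_species[argmax(row)] for row in scores]

# Write header plus one row per image, as in submission.csv
buffer = io.StringIO()
writer = csv.writer(buffer, lineterminator="\n")
writer.writerow(["image", "label"])
writer.writerows(zip(image_names, pred_labels))
print(buffer.getvalue())
```

Because argmax only picks the largest score, it makes no difference here whether the model outputs are raw scores or softmax probabilities.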
Submission
I trained for only 10 epochs here; increasing the number of epochs should bring the result to a fairly good level.
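4. Summary
A few years ago the accuracy of image recognition was still quite low; now it seems to be even higher than that of humans (people also make mistakes when recognizing images). The field really is developing fast.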