PyTorch Hands-on Project: FashionMNIST Fashion Classification
2022-07-25 11:37:00 【Alexa2077】
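I. Fashion-MNIST fashion classification workflow with PyTorch
The main code and text in this post come from the DataWhale team's course "深入浅出PyTorch" (thorough-pytorch).
Reference link: https://datawhalechina.github.io/thorough-pytorch
1. Task introduction
Task: classify fashion images into 10 categories using the FashionMNIST dataset. The accompanying figure shows several sample images, each corresponding to one sample.
Samples: the FashionMNIST dataset comes pre-split into a training set and a test set, with 60,000 training images and 10,000 test images. Each image is a single-channel grayscale image of 28*28 pixels and belongs to one of 10 categories.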
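The post does not list the ten categories by name. As a small reference sketch, the mapping below uses the standard FashionMNIST label order (an addition here, not taken from the original text) together with the dataset sizes stated above.

```python
# Label index -> class name for FashionMNIST (standard mapping; not listed in the original post)
FASHION_MNIST_CLASSES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]

def label_name(label: int) -> str:
    """Return the human-readable class name for an integer label in [0, 9]."""
    return FASHION_MNIST_CLASSES[label]

# Dataset sizes and image geometry as stated in the task description
NUM_TRAIN, NUM_TEST = 60_000, 10_000
PIXELS_PER_IMAGE = 28 * 28  # 784 values per single-channel image
```

A lookup like label_name(int(label)) can be handy for titling plots when checking that the data was read correctly.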
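2. Classification workflow
1- Imports and hyperparameter configuration: the basic flow is similar to the previous section, "PyTorch's main modules". Note that Windows users can set num_workers to 0.
# Imports
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
# Configure the GPU; there are two options
# Option 1: use os.environ to choose which GPUs are visible to this process
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# Option 2: use a "device" object, then call .to(device) on anything that should use the GPU
# (with only GPU 0 visible via Option 1, the valid index is cuda:0; cuda:1 would not exist)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Other hyperparameters: batch_size, num_workers, learning rate, and the total number of epochs
batch_size = 256
num_workers = 4   # Windows users should set this to 0 to avoid multiprocessing errors in data loading
lr = 1e-4
epochs = 20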
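Option 2 above says that anything meant for the GPU can be moved with .to(device). A minimal sketch of that pattern follows, assuming a dummy batch shaped like the FashionMNIST input; the tensor x is purely illustrative.

```python
import torch

# Pick the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# A dummy batch shaped like FashionMNIST input: (batch, channel, height, width)
x = torch.zeros(4, 1, 28, 28)
x = x.to(device)          # .to(device) returns a tensor on the chosen device
print(x.device.type)      # "cuda" or "cpu", depending on the machine
```

The same call works for models (model.to(device)), which keeps a script runnable on machines without a GPU, whereas a hard-coded .cuda() call raises an error there.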
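2- Reading the data: there are two ways. Both snippets below pass data_transform, which is defined in step 3, so that definition has to run first.
- Download and use a built-in dataset provided by PyTorch.
- Download the data stored in csv format from the website, read it in, and convert it to the expected format.
## Option 1: use a dataset bundled with torchvision; the download may take a while
## (these datasets already return PIL images, so transforms.ToPILImage() must be left out of data_transform here)
from torchvision import datasets
train_data = datasets.FashionMNIST(root='./', train=True, download=True, transform=data_transform)
test_data = datasets.FashionMNIST(root='./', train=False, download=True, transform=data_transform)
## Option 2: read the csv data and build a custom Dataset class
# csv download link: https://www.kaggle.com/zalando-research/fashionmnist
class FMDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform
        # Column 0 holds the label; the remaining 784 columns hold the pixel values
        self.images = df.iloc[:,1:].values.astype(np.uint8)
        self.labels = df.iloc[:, 0].values
    def __len__(self):
        return len(self.images)
    def __getitem__(self, idx):
        image = self.images[idx].reshape(28,28,1)
        label = int(self.labels[idx])
        if self.transform is not None:
            image = self.transform(image)
        else:
            # Scale to [0, 1] and move the channel axis first, (1, 28, 28), to match ToTensor's layout
            image = torch.tensor(image/255., dtype=torch.float).permute(2, 0, 1)
        label = torch.tensor(label, dtype=torch.long)
        return image, label
train_df = pd.read_csv("./FashionMNIST/fashion-mnist_train.csv")
test_df = pd.read_csv("./FashionMNIST/fashion-mnist_test.csv")
train_data = FMDataset(train_df, data_transform)
test_data = FMDataset(test_df, data_transform)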
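The first approach only works for common datasets such as MNIST and CIFAR10, for which PyTorch officially provides downloads. It is well suited to quickly testing a method (for example, checking whether an idea works on MNIST).
The second approach requires building your own Dataset, which is essential when applying PyTorch to your own work.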
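To make the csv layout that FMDataset relies on concrete (label in the first column, pixel values in the remaining columns), here is a sketch using only the standard library. The two-row csv text and the 2x2 image size are made up for readability; real rows carry 784 pixel columns.

```python
import csv
import io

# A made-up two-row csv in the same layout: label first, then one column per pixel.
# (Real rows have 784 pixel columns; 4 are used here to keep the sketch readable.)
SIDE = 2
fake_csv = "label,pixel1,pixel2,pixel3,pixel4\n9,0,255,128,64\n0,10,20,30,40\n"

def parse_row(row):
    """Split one csv row into (SIDE x SIDE grid of pixels scaled to [0, 1], label)."""
    label = int(row[0])
    pixels = [int(v) / 255.0 for v in row[1:]]
    image = [pixels[r * SIDE:(r + 1) * SIDE] for r in range(SIDE)]
    return image, label

reader = csv.reader(io.StringIO(fake_csv))
next(reader)  # skip the header line
samples = [parse_row(row) for row in reader]
```

FMDataset does the same split with df.iloc and reshapes each row to (28, 28, 1) instead of nested lists.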
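3- Data preprocessing: after the data is read in, it has to be converted into a format that meets the model's input requirements.
For example, images need to be resized to a uniform size so they can be fed to the network, the data needs to be converted to the Tensor type, and so on. These transforms are easy to apply with the torchvision package, PyTorch's official library for image processing, which the built-in dataset approach above also uses.
# First, set up the data transforms
from torchvision import transforms
image_size = 28
data_transform = transforms.Compose([
    # ToPILImage depends on how the data is read: needed for the csv route, not for the built-in datasets
    transforms.ToPILImage(),
    transforms.Resize(image_size),
    transforms.ToTensor()
])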
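As a rough illustration of what the final ToTensor step does to an image in the (H, W, C) uint8 layout that FMDataset produces, the sketch below reproduces the layout change and scaling by hand. It is an approximation for intuition, not torchvision's actual implementation, and the input tensor is synthetic.

```python
import torch

# Synthetic 28x28 single-channel image in (H, W, C) layout with uint8 values 0..255
hwc = (torch.arange(28 * 28) % 256).to(torch.uint8).reshape(28, 28, 1)

# Roughly what ToTensor yields for such an input: channels first, float32, values in [0, 1]
chw = hwc.permute(2, 0, 1).to(torch.float32) / 255.0

print(chw.shape, chw.dtype)
```

The channels-first shape (1, 28, 28) is what nn.Conv2d expects for each sample once the DataLoader adds the batch dimension.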
4- Once the training and test datasets are built, define DataLoader objects to load the data during training and testing.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
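The effect of batch_size and drop_last on the number of batches can be checked with simple arithmetic. The helper num_batches below is a hypothetical name introduced for this sketch; the dataset sizes and batch size are the ones used in this post.

```python
import math

def num_batches(num_samples: int, batch_size: int, drop_last: bool) -> int:
    """Number of batches a DataLoader yields per epoch for a map-style dataset."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)

train_batches = num_batches(60_000, 256, drop_last=True)    # 234
test_batches = num_batches(10_000, 256, drop_last=False)    # 40
dropped = 60_000 - train_batches * 256                      # 96 training images skipped each epoch
```

Because the training loader shuffles, a different set of 96 images is skipped in each epoch, so every image still gets used over the course of training.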
After loading, we can visualize some of the data, mainly to verify that it was read in correctly.
import matplotlib.pyplot as plt
image, label = next(iter(train_loader))
print(image.shape, label.shape)
plt.imshow(image[0][0], cmap="gray")
At this point, print the shapes to check that the input is correct.
The output is:
torch.Size([256, 1, 28, 28]) torch.Size([256])
5- Model design: build a CNN and put it on the GPU for training.
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Dropout(0.3),
            nn.Conv2d(32, 64, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Dropout(0.3)
        )
        self.fc = nn.Sequential(
            # a 28x28 input becomes 64 feature maps of size 4x4 after the conv stack
            nn.Linear(64*4*4, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )
    def forward(self, x):
        x = self.conv(x)
        x = x.view(-1, 64*4*4)   # flatten to (batch, 1024)
        x = self.fc(x)
        # x = nn.functional.normalize(x)
        return x
model = Net()
model = model.cuda()
# model = nn.DataParallel(model).cuda()   # multi-GPU version; covered further in later lessons
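The 64*4*4 input size of the first Linear layer follows from the layer shapes. The sketch below traces the spatial size through the network using the standard output-size formula for convolution and pooling with no padding; conv_out is a helper name introduced here.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output side length of a convolution or pooling layer (floor division)."""
    return (size + 2 * padding - kernel) // stride + 1

side = 28
side = conv_out(side, kernel=5)            # Conv2d(1, 32, 5)    -> 24
side = conv_out(side, kernel=2, stride=2)  # MaxPool2d(2, 2)     -> 12
side = conv_out(side, kernel=5)            # Conv2d(32, 64, 5)   -> 8
side = conv_out(side, kernel=2, stride=2)  # MaxPool2d(2, 2)     -> 4
flat_features = 64 * side * side           # 64 channels * 4 * 4 = 1024
```

If image_size or a kernel size changes, the same trace gives the new value to use in both nn.Linear and x.view.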
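6- Set the loss function and optimizer: use the CrossEntropy loss that ships with the torch.nn module. It accepts integer class labels directly, so no manual one-hot conversion is needed to compute the CE loss. Make sure the labels start from 0, and do not add a softmax layer to the model (the loss is computed from logits). This also shows that the parts of a PyTorch training pipeline are not independent and have to be considered together. Use the Adam optimizer.
criterion = nn.CrossEntropyLoss()
# Note: this passes 0.001 directly rather than the lr = 1e-4 set in the hyperparameter block
optimizer = optim.Adam(model.parameters(), lr=0.001)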
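To illustrate why the model outputs raw logits and the labels are plain integers, here is a standard-library sketch of the per-sample quantity that cross-entropy computes, namely -log(softmax(logits)[label]). It is written for intuition only and is not how torch.nn implements the loss.

```python
import math

def cross_entropy(logits, label):
    """Cross-entropy for one sample: -log(softmax(logits)[label]), computed stably."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]

uniform = cross_entropy([0.0] * 10, label=3)             # equals ln(10), about 2.3026
confident = cross_entropy([10.0] + [0.0] * 9, label=0)   # close to 0
```

The uniform case is a useful sanity check: an untrained 10-class model that predicts every class as equally likely has a loss of ln(10), so a first-epoch loss far above that value often points to a data or label problem.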
7- Training and validation: wrap each in its own function for later reuse.
Training:
def train(epoch):
    model.train()
    train_loss = 0
    for data, label in train_loader:
        data, label = data.cuda(), label.cuda()
        optimizer.zero_grad()   # reset gradients so they do not accumulate across batches
        output = model(data)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()*data.size(0)   # batch mean loss * batch size = summed loss
    train_loss = train_loss/len(train_loader.dataset)
    print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch, train_loss))
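The line train_loss += loss.item()*data.size(0) turns each batch's mean loss back into a sum, so the final division gives a per-sample average. The sketch below shows why this matters when batches differ in size; the numbers and the helper name epoch_loss are hypothetical.

```python
def epoch_loss(batch_means, batch_sizes, dataset_size=None):
    """Combine per-batch mean losses into a per-sample mean, as train() and val() do."""
    total = sum(m * n for m, n in zip(batch_means, batch_sizes))
    denom = dataset_size if dataset_size is not None else sum(batch_sizes)
    return total / denom

# Hypothetical numbers: two full batches and a smaller final batch
means, sizes = [0.50, 0.40, 1.00], [256, 256, 16]
weighted = epoch_loss(means, sizes)      # weights each sample equally
naive = sum(means) / len(means)          # over-weights the small last batch
```

One detail worth knowing: train_loader is created with drop_last=True, so only 59,904 of the 60,000 training images are seen per epoch, and dividing by len(train_loader.dataset) makes the reported training loss very slightly lower (by a factor of about 0.998) than the true per-sample mean.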
Validation:
def val(epoch):
    model.eval()   # evaluation mode (disables dropout)
    val_loss = 0
    gt_labels = []
    pred_labels = []
    with torch.no_grad():   # no gradient computation
        for data, label in test_loader:
            data, label = data.cuda(), label.cuda()
            output = model(data)
            preds = torch.argmax(output, 1)
            gt_labels.append(label.cpu().data.numpy())
            pred_labels.append(preds.cpu().data.numpy())
            loss = criterion(output, label)   # the loss is not backpropagated
            val_loss += loss.item()*data.size(0)
    val_loss = val_loss/len(test_loader.dataset)
    gt_labels, pred_labels = np.concatenate(gt_labels), np.concatenate(pred_labels)
    acc = np.sum(gt_labels==pred_labels)/len(pred_labels)
    print('Epoch: {} \tValidation Loss: {:.6f}, Accuracy: {:.6f}'.format(epoch, val_loss, acc))
for epoch in range(1, epochs+1):
    train(epoch)
    val(epoch)
Result: accuracy reaches 92%.
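Overall accuracy can hide classes that are confused with each other. As a possible extension of the accuracy computation in val(), the standard-library sketch below also reports accuracy per class. The label lists are made-up toy data, not results from this model.

```python
from collections import Counter

def accuracy(gt, pred):
    """Fraction of predictions that match the ground-truth labels."""
    return sum(g == p for g, p in zip(gt, pred)) / len(gt)

def per_class_accuracy(gt, pred):
    """Accuracy computed separately for each ground-truth class."""
    total = Counter(gt)
    correct = Counter(g for g, p in zip(gt, pred) if g == p)
    return {c: correct[c] / total[c] for c in sorted(total)}

# Toy labels: class 0 and class 6 are sometimes confused with each other
gt = [0, 0, 6, 6, 6, 9]
pred = [0, 6, 6, 0, 6, 9]
overall = accuracy(gt, pred)              # 4 of 6 correct
by_class = per_class_accuracy(gt, pred)   # {0: 0.5, 6: 2/3, 9: 1.0}
```

The same functions accept the concatenated gt_labels and pred_labels arrays after converting them with .tolist().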
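8- Saving the model: after training, torch.save can store either the model parameters or the whole model; the model can also be saved during training.
save_path = "./FashionModel.pkl"
torch.save(model, save_path)   # saves the whole model object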
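The snippet above saves the whole model object. The other option mentioned, saving only the parameters, could look like the sketch below. It uses a tiny nn.Linear as a stand-in for the trained Net, and the file name is hypothetical.

```python
import os
import tempfile
import torch
import torch.nn as nn

# A tiny stand-in model; in this post it would be the trained Net instance
model = nn.Linear(4, 2)

path = os.path.join(tempfile.mkdtemp(), "fashion_model_state.pt")  # hypothetical file name
torch.save(model.state_dict(), path)          # save the parameters only

restored = nn.Linear(4, 2)                    # rebuild the same architecture first
restored.load_state_dict(torch.load(path))    # then load the saved parameters into it
restored.eval()                               # switch to evaluation mode before inference
```

Saving the state_dict keeps the file independent of how the class is pickled, at the cost of needing the model definition available when loading.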
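II. PyTorch hands-on project 2
1. Hands-on project 2
Following the fashion classification example, I will pick a project of my own for training practice.
Article link: to be added!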