Building an LSTM in PyTorch for Clothing Classification (FashionMNIST)
2022-07-02 09:41:00 【raelum】
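FashionMNIST dataset homepage: https://github.com/zalandoresearch/fashion-mnist.
The dataset is not introduced here; see the homepage for details.
Idea: every image in the dataset is a grayscale image of size $(28, 28)$. We can treat it as a $28\times28$ matrix of numbers. Splitting that matrix row by row gives a sequence of length $28$, in which each "token" also has a feature dimension of $28$.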
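The row-by-row view can be sketched in a few lines. This is only an illustration: the image below is random data standing in for one FashionMNIST sample after `ToTensor()`, and the names `image`, `matrix`, and `tokens` are hypothetical.

```python
import torch

# Stand-in for one FashionMNIST image after ToTensor(): shape (1, 28, 28).
# Random values are used here purely for illustration.
image = torch.rand(1, 28, 28)

# Drop the channel dimension, leaving a 28 x 28 matrix.
matrix = image.squeeze(0)

# Each row becomes one "token": 28 tokens, each with 28 features.
tokens = list(matrix.unbind(dim=0))

seq_len = len(tokens)
feature_dim = tokens[0].shape[0]
print(seq_len, feature_dim)  # 28 28
```

Token `t` is simply row `t` of the image, so the LSTM reads the picture from top to bottom, one row per time step.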
Runtime environment:
- OS: Ubuntu 20.04;
- GPU: RTX 3090;
- PyTorch: 1.11;
- Python: 3.8
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoader
# Data Preprocessing
train_data = torchvision.datasets.FashionMNIST(root='data',
                                               train=True,
                                               transform=torchvision.transforms.ToTensor(),
                                               download=True)
test_data = torchvision.datasets.FashionMNIST(root='data',
                                              train=False,
                                              transform=torchvision.transforms.ToTensor(),
                                              download=True)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True, num_workers=4)
test_loader = DataLoader(test_data, batch_size=64, num_workers=4)
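# Model building
class LSTM(nn.Module):
    def __init__(self):
        super().__init__()
        # Input size 28 (one image row per time step), hidden size 64, 2 stacked layers
        self.lstm = nn.LSTM(28, 64, num_layers=2)
        self.linear = nn.Linear(64, 10)

    def forward(self, x):
        # x: (seq_len, batch, 28); h_n: (num_layers, batch, 64)
        output, (h_n, c_n) = self.lstm(x, None)
        # Note: h_n[0] is the final hidden state of the first layer;
        # h_n[-1] would be that of the top layer
        return self.linear(h_n[0])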
def setup_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
# Setup
setup_seed(42)
NUM_EPOCHS = 20
LR = 4e-3
train_loss, test_loss, test_acc = [], [], []
device = 'cuda' if torch.cuda.is_available() else 'cpu'
lstm = LSTM()
lstm.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(lstm.parameters(), lr=LR)

# Training and testing
for epoch in range(NUM_EPOCHS):
    print(f'[Epoch {epoch + 1}]', end=' ')
    avg_train_loss, avg_test_loss, correct = 0, 0, 0

    # train
    lstm.train()
    for batch_idx, (X, y) in enumerate(train_loader):
        # (64, 1, 28, 28) -> (28, 64, 28)
        # squeeze(1) removes only the channel dimension, even if a batch has size 1
        X = X.squeeze(1).movedim(0, 1)
        X, y = X.to(device), y.to(device)

        # forward
        output = lstm(X)
        loss = criterion(output, y)
        # accumulate a Python float so the autograd graph is not kept alive
        avg_train_loss += loss.item()

        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    avg_train_loss /= (batch_idx + 1)
    train_loss.append(avg_train_loss)

    # test
    lstm.eval()
    with torch.no_grad():
        for batch_idx, (X, y) in enumerate(test_loader):
            X = X.squeeze(1).movedim(0, 1)
            X, y = X.to(device), y.to(device)

            pred = lstm(X)
            loss = criterion(pred, y)
            avg_test_loss += loss.item()
            correct += (pred.argmax(1) == y).sum().item()

    avg_test_loss /= (batch_idx + 1)
    test_loss.append(avg_test_loss)
    correct /= len(test_loader.dataset)
    test_acc.append(correct)

    print(
        f"train loss: {train_loss[-1]:.4f} | test loss: {test_loss[-1]:.4f} | test acc: {correct:.4f}"
    )
# Plot
x = np.arange(1, NUM_EPOCHS + 1)
plt.plot(x, train_loss, label="train loss")
plt.plot(x, test_loss, label="test loss")
plt.plot(x, test_acc, label="test acc")
plt.xlabel("epoch")
plt.legend(loc="best", fontsize=12)
plt.show()
Output:
[Epoch 1] train loss: 0.6602 | test loss: 0.5017 | test acc: 0.8147
[Epoch 2] train loss: 0.4089 | test loss: 0.3979 | test acc: 0.8566
[Epoch 3] train loss: 0.3577 | test loss: 0.3675 | test acc: 0.8669
[Epoch 4] train loss: 0.3268 | test loss: 0.3509 | test acc: 0.8751
[Epoch 5] train loss: 0.3098 | test loss: 0.3395 | test acc: 0.8752
[Epoch 6] train loss: 0.2962 | test loss: 0.3135 | test acc: 0.8854
[Epoch 7] train loss: 0.2823 | test loss: 0.3377 | test acc: 0.8776
[Epoch 8] train loss: 0.2720 | test loss: 0.3196 | test acc: 0.8835
[Epoch 9] train loss: 0.2623 | test loss: 0.3120 | test acc: 0.8849
[Epoch 10] train loss: 0.2547 | test loss: 0.2981 | test acc: 0.8931
[Epoch 11] train loss: 0.2438 | test loss: 0.3140 | test acc: 0.8882
[Epoch 12] train loss: 0.2372 | test loss: 0.3043 | test acc: 0.8909
[Epoch 13] train loss: 0.2307 | test loss: 0.2977 | test acc: 0.8918
[Epoch 14] train loss: 0.2219 | test loss: 0.2888 | test acc: 0.8970
[Epoch 15] train loss: 0.2187 | test loss: 0.2946 | test acc: 0.8959
[Epoch 16] train loss: 0.2132 | test loss: 0.2894 | test acc: 0.8985
[Epoch 17] train loss: 0.2061 | test loss: 0.2835 | test acc: 0.9014
[Epoch 18] train loss: 0.2028 | test loss: 0.2954 | test acc: 0.8971
[Epoch 19] train loss: 0.1966 | test loss: 0.2952 | test acc: 0.8986
[Epoch 20] train loss: 0.1922 | test loss: 0.2910 | test acc: 0.9011
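The corresponding curves (train loss, test loss, and test accuracy per epoch) are produced by the plotting code above.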
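Some takeaways:
- Never use X = X.reshape(28, -1, 28) directly. Otherwise X no longer corresponds to the original images (readers can try torchvision.transforms.ToPILImage to render the images in X and see the effect).
- With the same learning rate, SGD performs worse than Adam.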
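The difference between `reshape` and `movedim` can be checked numerically without rendering any images. The sketch below is only an illustration: it uses a small synthetic batch of 4 "images" filled with distinct values, not real FashionMNIST data, and the variable names are hypothetical.

```python
import torch

# Synthetic stand-in for a batch from the DataLoader: 4 images of size
# 28 x 28 with a channel dimension of 1, filled with distinct values.
batch = torch.arange(4 * 28 * 28, dtype=torch.float32).reshape(4, 1, 28, 28)

# Approach used in the training loop: drop the channel, then swap axes.
# Result has shape (seq_len=28, batch=4, features=28).
moved = batch.squeeze(1).movedim(0, 1)

# Approach the note warns against: same shape, different contents.
reshaped = batch.reshape(28, -1, 28)

# With movedim, time step t of sample b is exactly row t of image b.
movedim_ok = all(
    torch.equal(moved[t, b], batch[b, 0, t])
    for t in range(28) for b in range(4)
)

# With reshape, memory is only re-chunked, so rows from different
# images end up inside the same "sample".
reshape_ok = torch.equal(reshaped, moved)

print(moved.shape == reshaped.shape)  # True
print(movedim_ok)                     # True
print(reshape_ok)                     # False
```

Both tensors have shape `(28, 4, 28)`, which is why the mistake is easy to miss: the model still runs, but each sequence it sees is stitched together from rows of several images.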