Building an LSTM in PyTorch for clothing classification (FashionMNIST)
2022-07-02 11:57:00 【raelum】
FashionMNIST official repository: https://github.com/zalandoresearch/fashion-mnist
The dataset is not introduced here; see the official repository for details.
Idea: each image in the dataset is a (28, 28) grayscale image. We can view it as a 28 × 28 matrix of numbers and split it row by row into a sequence of length 28, where each "token" in the sequence (one row of pixels) has a feature dimension of 28.
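The row-by-row layout change can be checked on a random tensor before touching real data. This is a minimal sketch, assuming a hypothetical batch of 4 random "images" in place of FashionMNIST; it uses `squeeze(1)` so that only the channel axis is dropped, which keeps the batch axis even when a batch holds a single image.

```python
import torch

# Hypothetical batch shaped like DataLoader + ToTensor() output:
# (batch, channels=1, height=28, width=28)
images = torch.rand(4, 1, 28, 28)

# Drop only the channel axis (dim 1)
seqs = images.squeeze(1)        # (batch, 28, 28)
# nn.LSTM with batch_first=False expects (seq_len, batch, features),
# so move the batch axis behind the row axis
seqs = seqs.movedim(0, 1)       # (28, batch, 28)
print(seqs.shape)               # torch.Size([28, 4, 28])

# Time step t of sample b is exactly row t of image b
b, t = 2, 5
assert torch.equal(seqs[t, b], images[b, 0, t])

# With a batch of one image, a bare squeeze() also removes the batch axis
single = torch.rand(1, 1, 28, 28)
print(single.squeeze().shape)   # torch.Size([28, 28])
print(single.squeeze(1).shape)  # torch.Size([1, 28, 28])
```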
Environment:
- OS: Ubuntu 20.04
- GPU: RTX 3090
- PyTorch: 1.11
- Python: 3.8
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoader
# Data Preprocessing
train_data = torchvision.datasets.FashionMNIST(root='data',
                                               train=True,
                                               transform=torchvision.transforms.ToTensor(),
                                               download=True)
test_data = torchvision.datasets.FashionMNIST(root='data',
                                              train=False,
                                              transform=torchvision.transforms.ToTensor(),
                                              download=True)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True, num_workers=4)
test_loader = DataLoader(test_data, batch_size=64, num_workers=4)
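# Model building
class LSTM(nn.Module):
    def __init__(self):
        super().__init__()
        # input size 28 (pixels per row), hidden size 64, 2 stacked layers
        self.lstm = nn.LSTM(28, 64, num_layers=2)
        # map a 64-dim hidden state to 10 class logits
        self.linear = nn.Linear(64, 10)

    def forward(self, x):
        # x: (seq_len=28, batch, 28); (h_0, c_0) default to zeros when omitted
        output, (h_n, c_n) = self.lstm(x)
        # h_n: (num_layers, batch, 64); h_n[0] is the final hidden state of the
        # first layer (h_n[-1] would be the top layer's)
        return self.linear(h_n[0])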
def setup_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
# Setup
setup_seed(42)
NUM_EPOCHS = 20
LR = 4e-3
train_loss, test_loss, test_acc = [], [], []
device = 'cuda' if torch.cuda.is_available() else 'cpu'
lstm = LSTM()
lstm.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(lstm.parameters(), lr=LR)
# Training and testing
for epoch in range(NUM_EPOCHS):
    print(f'[Epoch {epoch + 1}]', end=' ')
    avg_train_loss, avg_test_loss, correct = 0, 0, 0
    # train
    lstm.train()
    for batch_idx, (X, y) in enumerate(train_loader):
        # (64, 1, 28, 28) -> (28, 64, 28): one row of pixels per time step;
        # squeeze(1) drops only the channel axis, never the batch axis
        X = X.squeeze(1).movedim(0, 1)
        X, y = X.to(device), y.to(device)
        # forward
        output = lstm(X)
        loss = criterion(output, y)
        # .item() yields a Python float, so the running sum does not
        # keep each batch's autograd graph alive
        avg_train_loss += loss.item()
        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    avg_train_loss /= (batch_idx + 1)
    train_loss.append(avg_train_loss)
    # test
    lstm.eval()
    with torch.no_grad():
        for batch_idx, (X, y) in enumerate(test_loader):
            X = X.squeeze(1).movedim(0, 1)
            X, y = X.to(device), y.to(device)
            pred = lstm(X)
            loss = criterion(pred, y)
            avg_test_loss += loss.item()
            correct += (pred.argmax(1) == y).sum().item()
    avg_test_loss /= (batch_idx + 1)
    test_loss.append(avg_test_loss)
    correct /= len(test_loader.dataset)
    test_acc.append(correct)
    print(
        f"train loss: {train_loss[-1]:.4f} | test loss: {test_loss[-1]:.4f} | test acc: {correct:.4f}"
    )
# Plot
x = np.arange(1, NUM_EPOCHS + 1)
plt.plot(x, train_loss, label="train loss")
plt.plot(x, test_loss, label="test loss")
plt.plot(x, test_acc, label="test acc")
plt.xlabel("epoch")
plt.legend(loc="best", fontsize=12)
plt.show()
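The classifier in `forward` reads `h_n[0]`. For a stacked `nn.LSTM`, `h_n` has shape `(num_layers, batch, hidden_size)`, so index 0 is the bottom layer's final state, while the top layer's final state is `h_n[-1]` (which, for a unidirectional LSTM, equals `output[-1]`). The sketch below checks this on random input with the same layer sizes as the model; the batch of 8 random sequences is a hypothetical stand-in for real data.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Same configuration as the model: input size 28, hidden size 64, 2 layers
lstm = nn.LSTM(28, 64, num_layers=2)
x = torch.rand(28, 8, 28)  # (seq_len, batch, features), hypothetical input

output, (h_n, c_n) = lstm(x)
print(output.shape)  # torch.Size([28, 8, 64]): top layer's state at every step
print(h_n.shape)     # torch.Size([2, 8, 64]): final state of each layer

# Top layer's last time step is the same tensor values as h_n[-1]
assert torch.allclose(output[-1], h_n[-1])
# h_n[0] comes from the bottom layer and generally differs
assert not torch.allclose(h_n[0], h_n[-1])
```

If the intent is to classify from the output of the whole 2-layer stack, `h_n[-1]` is the index to use; with `h_n[0]` the second layer does not influence the prediction.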
Output:
[Epoch 1] train loss: 0.6602 | test loss: 0.5017 | test acc: 0.8147
[Epoch 2] train loss: 0.4089 | test loss: 0.3979 | test acc: 0.8566
[Epoch 3] train loss: 0.3577 | test loss: 0.3675 | test acc: 0.8669
[Epoch 4] train loss: 0.3268 | test loss: 0.3509 | test acc: 0.8751
[Epoch 5] train loss: 0.3098 | test loss: 0.3395 | test acc: 0.8752
[Epoch 6] train loss: 0.2962 | test loss: 0.3135 | test acc: 0.8854
[Epoch 7] train loss: 0.2823 | test loss: 0.3377 | test acc: 0.8776
[Epoch 8] train loss: 0.2720 | test loss: 0.3196 | test acc: 0.8835
[Epoch 9] train loss: 0.2623 | test loss: 0.3120 | test acc: 0.8849
[Epoch 10] train loss: 0.2547 | test loss: 0.2981 | test acc: 0.8931
[Epoch 11] train loss: 0.2438 | test loss: 0.3140 | test acc: 0.8882
[Epoch 12] train loss: 0.2372 | test loss: 0.3043 | test acc: 0.8909
[Epoch 13] train loss: 0.2307 | test loss: 0.2977 | test acc: 0.8918
[Epoch 14] train loss: 0.2219 | test loss: 0.2888 | test acc: 0.8970
[Epoch 15] train loss: 0.2187 | test loss: 0.2946 | test acc: 0.8959
[Epoch 16] train loss: 0.2132 | test loss: 0.2894 | test acc: 0.8985
[Epoch 17] train loss: 0.2061 | test loss: 0.2835 | test acc: 0.9014
[Epoch 18] train loss: 0.2028 | test loss: 0.2954 | test acc: 0.8971
[Epoch 19] train loss: 0.1966 | test loss: 0.2952 | test acc: 0.8986
[Epoch 20] train loss: 0.1922 | test loss: 0.2910 | test acc: 0.9011
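The losses in this log are the sum of per-batch mean losses divided by the number of batches. Because the last batch is smaller (60000 = 937 × 64 + 32 for training, 10000 = 156 × 64 + 16 for testing), its samples weigh slightly more than the others; the effect here should be tiny. A sample-weighted average avoids it. The sketch below uses a hypothetical helper `weighted_avg_loss` and a toy linear model on random data, not the LSTM above.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def weighted_avg_loss(model, loader, criterion):
    """Hypothetical helper: per-sample mean loss over all batches."""
    total, n = 0.0, 0
    model.eval()
    with torch.no_grad():
        for X, y in loader:
            # criterion averages over the batch (reduction='mean'),
            # so scale by batch size to recover the batch's summed loss
            total += criterion(model(X), y).item() * y.size(0)
            n += y.size(0)
    return total / n


def batch_avg_loss(model, loader, criterion):
    """Mean of per-batch mean losses, as computed in the training script."""
    losses = []
    model.eval()
    with torch.no_grad():
        for X, y in loader:
            losses.append(criterion(model(X), y).item())
    return sum(losses) / len(losses)


# Toy data: 10 samples, so batches of 4, 4 and 2
torch.manual_seed(0)
X = torch.randn(10, 5)
y = torch.randint(0, 3, (10,))
model = nn.Linear(5, 3)
criterion = nn.CrossEntropyLoss()
loader = DataLoader(TensorDataset(X, y), batch_size=4)

with torch.no_grad():
    exact = criterion(model(X), y).item()  # mean over all 10 samples at once
print(weighted_avg_loss(model, loader, criterion), exact)
print(batch_avg_loss(model, loader, criterion))
```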
Corresponding curves: [figure: train loss, test loss and test acc versus epoch]
Lessons learned:
- Do not simply use X = X.reshape(28, -1, 28). It produces a tensor of the right shape, but its contents no longer correspond to the original images (you can check this by converting X back into a picture with torchvision.transforms.ToPILImage and looking at it).
- With the same learning rate, SGD performed worse than Adam.
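The reshape pitfall can be reproduced without downloading anything. In this minimal sketch a random tensor stands in for a real batch: `reshape` only reinterprets the memory order, whereas `squeeze(1).movedim(0, 1)` actually swaps the batch and row axes.

```python
import torch

images = torch.rand(4, 1, 28, 28)          # hypothetical batch: (batch, 1, H, W)

wrong = images.reshape(28, -1, 28)         # reinterprets memory order only
right = images.squeeze(1).movedim(0, 1)    # swaps batch and row axes

print(wrong.shape == right.shape)          # True: both are (28, 4, 28)
print(torch.equal(wrong, right))           # False: the contents differ

# wrong[0] holds the first 4 rows of image 0 ...
assert torch.equal(wrong[0], images[0, 0, :4])
# ... while right[0] holds row 0 of every image in the batch
assert torch.equal(right[0], images[:, 0, 0])
```

To see the difference, `torchvision.transforms.ToPILImage()(right[:, 0].unsqueeze(0))` should show image 0 intact, while the same call on `wrong[:, 0]` gives a picture stitched together from rows of different images.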