当前位置:网站首页>MNIST implementation using PyTorch in Jupyter Notebook
MNIST implementation using PyTorch in Jupyter Notebook
2022-07-06 10:25:00 【How about a song without trace】
" Workflow "
#1、 Load necessary Libraries
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.optim as optim
from torchvision import datasets , transforms
# 2. Hyperparameters
BATCH_SIZE = 16  # number of samples per mini-batch
EPOCHS = 10  # number of full passes over the training data
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# 3. Build the preprocessing pipeline applied to every image.
# The transforms are passed as a list, not a set: Compose applies them in
# sequence order, and ToTensor has to run before Normalize. A set has no
# guaranteed iteration order.
pipline = transforms.Compose([
    transforms.ToTensor(),  # convert the picture to a tensor
    # Standardize the input with the commonly used MNIST mean / std.
    # This is input normalization, not regularization. One-element tuples
    # give one value per channel for the single grayscale channel.
    transforms.Normalize((0.1307,), (0.3081,)),
])
# 4. Download (when missing) and load the MNIST data
from torch.utils.data import DataLoader

# Both splits live under the "data2" directory and share the same pipeline.
train_set = datasets.MNIST(root="data2", train=True, download=True, transform=pipline)
test_set = datasets.MNIST(root="data2", train=False, download=True, transform=pipline)

# Training batches are reshuffled every epoch.
train_loader = DataLoader(dataset=train_set, batch_size=BATCH_SIZE, shuffle=True)
# NOTE: the test loader is shuffled as well; this only randomizes the
# evaluation order and does not affect the aggregated metrics.
test_loader = DataLoader(dataset=test_set, batch_size=BATCH_SIZE, shuffle=True)
# Show the pictures: dump the first training image from the raw IDX file
with open("./data2/MNIST/raw/train-images-idx3-ubyte", "rb") as f:
    file = f.read()
# Skip the first 16 bytes (the IDX header of this file) and take the next
# 28 * 28 = 784 bytes as the first image. Iterating over bytes already yields
# ints in 0..255, so no string/hex round-trip is needed: parsing str(byte) as
# hexadecimal turned e.g. 255 into 0x255 == 597, which corrupts the pixels
# and does not fit into uint8.
image1 = list(file[16:16 + 784])
print(image1)
import cv2
import numpy as np
image1_np = np.array(image1, dtype=np.uint8).reshape(28, 28, 1)
print(image1_np.shape)
# Save the picture
cv2.imwrite('digit.jpg', image1_np)
# 5. Build the network model
class Digit(nn.Module):
    """Small CNN: two convolution layers followed by two linear layers.

    The shape comments below assume inputs of shape (batch, 1, 28, 28),
    which is what the 20 * 10 * 10 input width of ``fc1`` corresponds to.
    """

    def __init__(self):
        super().__init__()
        # 1 input channel -> 10 feature maps, 5x5 kernel
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)
        # 10 feature maps -> 20 feature maps, 3x3 kernel
        self.conv2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=3)
        # flattened conv features -> 500 hidden units
        self.fc1 = nn.Linear(in_features=20 * 10 * 10, out_features=500)
        # 500 hidden units -> 10 class scores
        self.fc2 = nn.Linear(in_features=500, out_features=10)

    def forward(self, x):
        """Return per-class log-probabilities of shape (batch, 10)."""
        batch = x.size(0)
        # (batch, 1, 28, 28) -> conv 5x5 -> (batch, 10, 24, 24) -> pool 2x2 -> (batch, 10, 12, 12)
        features = F.max_pool2d(F.relu(self.conv1(x)), kernel_size=2, stride=2)
        # (batch, 10, 12, 12) -> conv 3x3 -> (batch, 20, 10, 10)
        features = F.relu(self.conv2(features))
        # Flatten everything except the batch dimension: (batch, 2000)
        flat = features.view(batch, -1)
        hidden = F.relu(self.fc1(flat))  # (batch, 500)
        scores = self.fc2(hidden)  # (batch, 10)
        return F.log_softmax(scores, dim=1)
# 6. Create the model on the target device and attach an optimizer
model = Digit()
model = model.to(DEVICE)
# Adam with its default settings over all model parameters
optimizer = optim.Adam(params=model.parameters())
# 7. Define the training routine
def train_model(model, device, train_loader, optimizer, epoch):
    """Run one training pass over ``train_loader``.

    Args:
        model: network to train; switched to training mode here.
        device: device the batches are moved to before the forward pass.
        train_loader: iterable yielding ``(data, target)`` batches.
        optimizer: optimizer that updates the model parameters.
        epoch: epoch number, used only in the progress message.

    Returns:
        None. Progress is printed every 3000 batches.
    """
    model.train()
    for batch_index, (data, target) in enumerate(train_loader):
        # Move the batch to the target device
        data, target = data.to(device), target.to(device)
        # Clear gradients accumulated by the previous step
        optimizer.zero_grad()
        output = model(data)
        # cross_entropy applies log_softmax itself; applying it to outputs
        # that are already log-probabilities (as Digit returns) leaves them
        # unchanged, so the loss is still the negative log-likelihood.
        loss = F.cross_entropy(output, target)
        # Back propagation followed by the parameter update
        loss.backward()
        optimizer.step()
        if batch_index % 3000 == 0:
            # .item() extracts the Python number from the 0-dim loss tensor
            print("Train Epoch :{}\tLOSS : {:.6f}".format(epoch, loss.item()))
# 8. Define the evaluation routine
def test_model(model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print loss and accuracy.

    Args:
        model: network to evaluate; switched to evaluation mode here.
        device: device the batches are moved to before the forward pass.
        test_loader: iterable yielding ``(data, target)`` batches; it must
            expose a ``dataset`` attribute supporting ``len()``.

    Returns:
        None. The average per-sample loss and the accuracy (in percent)
        are printed.
    """
    model.eval()
    correct = 0
    test_loss = 0.0
    with torch.no_grad():  # no gradients are needed for evaluation
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Sum the per-sample losses of the batch. With the default
            # "mean" reduction, dividing by the dataset size below would
            # under-report the loss by roughly the batch size.
            test_loss += F.cross_entropy(output, target, reduction="sum").item()
            # Index of the highest score is the predicted class
            pred = output.argmax(dim=1)
            correct += pred.eq(target.view_as(pred)).sum().item()
    num_samples = len(test_loader.dataset)
    test_loss /= num_samples
    print("Test-- Average loss {:.4f},Accuracy : {:.3f}\n".format(
        test_loss, 100.0 * correct / num_samples
    ))


# The name starts with "test_", so pytest would try to collect this helper
# as a test case (and fail on its parameters) if it is ever imported into a
# test module; opt out explicitly.
test_model.__test__ = False
# 9. Run training followed by evaluation for every epoch
for epoch in range(1, EPOCHS + 1):
    train_model(model, DEVICE, train_loader, optimizer, epoch)
    test_model(model, DEVICE, test_loader)
# End of the article's code. The text that was fused onto the last statement
# is the page's "边栏推荐" (sidebar recommendations) heading for the link
# list that follows; it is not part of the program.
- MySQL ERROR 1040: Too many connections
- 高并发系统的限流方案研究,其实限流实现也不复杂
- 第一篇博客
- MySQL实战优化高手04 借着更新语句在InnoDB存储引擎中的执行流程,聊聊binlog是什么?
- Software test engineer development planning route
- MySQL實戰優化高手08 生產經驗:在數據庫的壓測過程中,如何360度無死角觀察機器性能?
- MySQL实战优化高手08 生产经验:在数据库的压测过程中,如何360度无死角观察机器性能?
- MySQL combat optimization expert 12 what does the memory data structure buffer pool look like?
- A necessary soft skill for Software Test Engineers: structured thinking
- 四川云教和双师模式
猜你喜欢

MySQL底层的逻辑架构

Contest3145 - the 37th game of 2021 freshman individual training match_ C: Tour guide

South China Technology stack cnn+bilstm+attention

MySQL的存储引擎

cmooc互联网+教育

数据库中间件_Mycat总结

Super detailed steps to implement Wechat public number H5 Message push

解决在window中远程连接Linux下的MySQL

Solve the problem of remote connection to MySQL under Linux in Windows

MySQL34-其他数据库日志
随机推荐
评估方法的优缺点
A new understanding of RMAN retention policy recovery window
MySQL实战优化高手11 从数据的增删改开始讲起,回顾一下Buffer Pool在数据库里的地位
基于Pytorch的LSTM实战160万条评论情感分类
Write your own CPU Chapter 10 - learning notes
The appearance is popular. Two JSON visualization tools are recommended for use with swagger. It's really fragrant
MySQL ERROR 1040: Too many connections
NLP routes and resources
UEditor国际化配置,支持中英文切换
Target detection -- yolov2 paper intensive reading
① BOKE
再有人问你数据库缓存一致性的问题,直接把这篇文章发给他
安装OpenCV时遇到的几种错误
Pytorch LSTM实现流程(可视化版本)
Notes of Dr. Carolyn ROS é's social networking speech
【C语言】深度剖析数据存储的底层原理
Preliminary introduction to C miscellaneous lecture document
简单解决phpjm加密问题 免费phpjm解密工具
Safety notes
A necessary soft skill for Software Test Engineers: structured thinking