当前位置:网站首页>MNIST implementation using pytoch in jupyter notebook
MNIST implementation using pytoch in jupyter notebook
2022-07-06 10:25:00 【How about a song without trace】
" technological process "
#1、 Load necessary Libraries
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.optim as optim
from torchvision import datasets , transforms
#2、 Define super parameters
BATCH_SIZE = 16 # Data per batch
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EPOCHS = 10 # Rounds of training data
#3、 structure pipline, Image processing
pipline = transforms.Compose({
transforms.ToTensor(),# Convert the picture to tensor
transforms.Normalize((0.1307),(0.3081)) # Regularization , When the model is over fitted , Reduce model complexity
})
#4、 Download load data
from torch.utils.data import DataLoader
train_set = datasets.MNIST("data2",train=True,download=True,transform=pipline)
test_set = datasets.MNIST("data2",train=False,download=True,transform=pipline)
# Load training dataset
train_loader = DataLoader(train_set,batch_size=BATCH_SIZE,shuffle=True)
# Load test data set
test_loader = DataLoader(test_set,batch_size=BATCH_SIZE,shuffle=True)
# Show the pictures
with open("./data2/MNIST/raw/train-images-idx3-ubyte","rb") as f:
file = f.read()
image1 = [int(str(item).encode("ascii") ,16) for item in file[16:16+784]]
print(image1)
import cv2
import numpy as np
image1_np = np.array(image1,dtype = np.uint8).reshape(28,28,1)
print(image1_np.shape)
# Save the picture
cv2.imwrite('digit.jpg',image1_np)
#5、 Build a network model
class Digit(nn.Module):
def __init__(self): # Construction method
super().__init__() # Call the constructor of the parent class , Inherit the properties of the parent class
self.conv1 = nn.Conv2d(1,10,5) # The input channel is 1, The output channel is 10, Convolution kernels for 5( This is a 5*5 Of )
self.conv2 = nn.Conv2d(10,20,3) # The output of the upper layer is the input of the lower layer
self.fc1 = nn.Linear(20*10*10,500) #20*10*10 The total number of input channels , 500 Output channel
self.fc2 = nn.Linear(500,10) #10 in total 10 The probability of categories
def forward(self,x):
input_size = x.size(0) # The tensor form of the whole picture is batch_size*1*28*28 , So get it directly batch_size
x = self.conv1(x) # Input :batch_size*1*28*28 Output :batch_size*10*24*24 This 10 Is the number of output channels of the first convolution layer ,24 = 28-5+1
x = F.relu(x) # keep shape unchanged , Output batch_size*10*24*24
x = F.max_pool2d(x,2,2) # Input batch_size*10*24*24 Output :batch_size*10*12( halve )*12 # Pooling layer : Compress the picture ( Downsampling ) Extract the most significant features
x = self.conv2(x) # Input :batch_size*10*12*12 Output :batch_size*20*10*10(12-3+1)
x = F.relu(x)
x = x.view(input_size,-1) # The tensile , Or even , This -1 Automatically calculate dimensions This -1 In fact, his value is 20*10*10=2000 Dimensions
x = self.fc1(x) # Input batch_size*2000 Output batch_size*500
x = F.relu(x) # keep shape unchanged
x = self.fc2(x) # Input :batch_size*500 Output :batch_size*10
output = F.log_softmax(x,dim=1) # Loss function After calculation and classification , The probability value of each number
return output
#6、 Define optimizer
model =Digit().to(DEVICE)
optimizer = optim.Adam(model.parameters())
#7、 Define training methods
def train_model(model,device,train_loader,optimizer,epoch):
# model training
model.train()
for batch_index,(data,target) in enumerate(train_loader):
# Deploy to DEVICE Up
data,target = data.to(device),target.to(device)
# The gradient is initialized to 0
optimizer.zero_grad()
# The results after training
output = model(data)
# Calculate the loss
loss = F.cross_entropy(output, target) # The cross entropy loss function is suitable for multi classification tasks
# Find the subscript with the largest probability value
pred = output.max(1,keepdim = True) #1 Represents the horizontal axis You can also write like this pred = output.argmax(dim=1)
# Back propagation
loss.backward()
# Parameter optimization , That is, every parameter update
optimizer.step()
if batch_index % 3000 == 0: # Every processing 3000 Print a picture once
print("Train Epoch :{}\tLOSS : {:.6f}".format(epoch,loss.item())) # This loss It has to be followed by item(), Get the value
#8、 Define test methods
def test_model(model,device,test_loader):
# Model validation
model.eval()
# Accuracy rate
correct = 0.0
# Test loss
test_loss = 0.0
with torch.no_grad(): # No gradient calculation , There will be no back propagation
for data, target in test_loader:
# Deploy to device Up
data,target = data.to(device),target.to(device)
# Test data
output = model(data)
# Calculate the test loss
test_loss += F.cross_entropy(output, target).item()
# Find the subscript of the maximum probability
pred = output.argmax(dim = 1)
# Accumulate correct values
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print("Test-- Average loss {:.4f},Accuracy : {:.3f}\n".format(
test_loss, 100.0 *correct / len(test_loader.dataset)
))
# 9、 Call training and testing methods
for epoch in range(1,EPOCHS + 1):
train_model(model,DEVICE,train_loader,optimizer,epoch)
test_model(model,DEVICE,test_loader)
边栏推荐
- ByteTrack: Multi-Object Tracking by Associating Every Detection Box 论文阅读笔记()
- 再有人问你数据库缓存一致性的问题,直接把这篇文章发给他
- MySQL實戰優化高手08 生產經驗:在數據庫的壓測過程中,如何360度無死角觀察機器性能?
- MySQL storage engine
- Several errors encountered when installing opencv
- 15 医疗挂号系统_【预约挂号】
- docker MySQL解决时区问题
- Constants and pointers
- Not registered via @EnableConfigurationProperties, marked(@ConfigurationProperties的使用)
- 好博客好资料记录链接
猜你喜欢
Jar runs with error no main manifest attribute
Installation of pagoda and deployment of flask project
If someone asks you about the consistency of database cache, send this article directly to him
保姆级手把手教你用C语言写三子棋
Export virtual machines from esxi 6.7 using OVF tool
Sichuan cloud education and double teacher model
MySQL combat optimization expert 04 uses the execution process of update statements in the InnoDB storage engine to talk about what binlog is?
jar运行报错no main manifest attribute
数据库中间件_Mycat总结
14 medical registration system_ [Alibaba cloud OSS, user authentication and patient]
随机推荐
MySQL combat optimization expert 02 in order to execute SQL statements, do you know what kind of architectural design MySQL uses?
jar运行报错no main manifest attribute
vscode 常用的指令
解决在window中远程连接Linux下的MySQL
What is the difference between TCP and UDP?
UEditor国际化配置,支持中英文切换
Contest3145 - the 37th game of 2021 freshman individual training match_ B: Password
该不会还有人不懂用C语言写扫雷游戏吧
软件测试工程师必备之软技能:结构化思维
【C语言】深度剖析数据存储的底层原理
The 32-year-old fitness coach turned to a programmer and got an offer of 760000 a year. The experience of this older coder caused heated discussion
MySQL34-其他数据库日志
CDC: the outbreak of Listeria monocytogenes in the United States is related to ice cream products
实现微信公众号H5消息推送的超级详细步骤
MySQL real battle optimization expert 08 production experience: how to observe the machine performance 360 degrees without dead angle in the process of database pressure test?
MySQL combat optimization expert 12 what does the memory data structure buffer pool look like?
NLP路线和资源
C杂讲 双向循环链表
再有人问你数据库缓存一致性的问题,直接把这篇文章发给他
Chrome浏览器端跨域不能访问问题处理办法