当前位置:网站首页>Bilibili (Station B) Liu Er Da Ren (刘二大人): softmax classifier and MNIST implementation - Lecture 9
Bilibili (Station B) Liu Er Da Ren (刘二大人): softmax classifier and MNIST implementation - Lecture 9
2022-07-06 05:42:00 【Ning Ranye】
Series articles :
List of articles
softmax classifier
Loss function : Cross entropy
NumPy implementation of the cross-entropy loss function
PyTorch implementation of the cross-entropy loss
MNIST Realization
Import packages
import torch
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
# Use relu()
import torch.nn.functional as F
# Construct optimizer
import torch.optim as optim
1- Prepare the data
# 1- Prepare the data
batch_size = 64
# Convert each PIL image to a tensor, then normalize it with
# Normalize(mean, std) using mean 0.1307 and std 0.3081 (the commonly
# quoted MNIST statistics).
# The pipeline is named `transform` (singular) so that the imported
# `transforms` module is not overwritten and stays usable afterwards.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
# download=False: the MNIST files must already exist under `root`.
train_dataset = datasets.MNIST(root='./datasets/mnist', train=True,
                               transform=transform,
                               download=False)
test_dataset = datasets.MNIST(root='./datasets/mnist', train=False,
                              transform=transform,
                              download=False)
# Shuffle only the training data; keep the test order fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                          shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size,
                         shuffle=False)
2- Design the network model
# 2- Design the network model
class Net(torch.nn.Module):
    """Fully connected classifier: 784 -> 512 -> 256 -> 128 -> 64 -> 10.

    ``forward`` returns raw, unnormalized class scores (logits). This is
    the input ``torch.nn.CrossEntropyLoss`` expects, because that loss
    applies log-softmax itself.
    """

    def __init__(self):
        super().__init__()
        self.lay1 = torch.nn.Linear(784, 512)
        self.lay2 = torch.nn.Linear(512, 256)
        self.lay3 = torch.nn.Linear(256, 128)
        self.lay4 = torch.nn.Linear(128, 64)
        self.lay5 = torch.nn.Linear(64, 10)

    def forward(self, x):
        """Return a ``(N, 10)`` tensor of logits for the input batch.

        ``x`` may have any shape whose element count is a multiple of 784
        (for example a batch of 1x28x28 images); it is flattened to
        ``(N, 784)`` before the first layer.
        """
        # -1 lets view() infer the batch dimension.
        x = x.view(-1, 784)
        x = F.relu(self.lay1(x))
        x = F.relu(self.lay2(x))
        x = F.relu(self.lay3(x))
        x = F.relu(self.lay4(x))
        # No activation on the output layer: a ReLU here would clamp the
        # logits to >= 0 before the loss applies softmax.
        return self.lay5(x)
3- Build a model 、 Loss function 、 Optimizer
# 3- Construct loss function and optimizer
# The network whose parameters are trained.
model = Net()
# Cross-entropy classification loss.
criterion = torch.nn.CrossEntropyLoss()
# SGD with momentum over all parameters of the model.
optimizer = optim.SGD(
    model.parameters(),
    lr=0.005,
    momentum=0.5,
)
4- Training 、 test
# 4- Training test
def train(epoch):
    """Run one pass over ``train_loader`` and update ``model`` in place.

    Uses the module-level ``model``, ``criterion``, ``optimizer`` and
    ``train_loader``. After every 300 batches the mean loss of those
    batches is printed, labelled with ``epoch + 1`` and the 1-based
    batch count.
    """
    report_every = 300
    loss_sum = 0.0
    # Count batches from 1 so the counter equals the number printed.
    for step, (inputs, target) in enumerate(train_loader, start=1):
        optimizer.zero_grad()
        # forward + backward + update
        loss = criterion(model(inputs), target)
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        if step % report_every == 0:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, step, loss_sum / report_every))
            loss_sum = 0.0
def test():
    """Evaluate ``model`` on ``test_loader`` and print the accuracy.

    Uses the module-level ``model`` and ``test_loader``. The printed value
    is the percentage of correctly classified samples, truncated to an
    integer by the ``%d`` format. Returns nothing.
    """
    correct = 0
    total = 0
    # Evaluation needs no computation graph, gradients or backpropagation.
    with torch.no_grad():
        # Each batch unpacks into two items: the inputs and their labels.
        for images, label in test_loader:
            outputs = model(images)
            # torch.max over dim=1 returns (max values, indices of the max);
            # the index of the largest score is the predicted class.
            # `.data` is not needed inside no_grad().
            _, predicted = torch.max(outputs, dim=1)
            total += label.size(0)
            correct += (predicted == label).sum().item()
    print('Accuracy on test set : %d %%' % (100 * correct / total))
if __name__ == '__main__':
    num_epochs = 10
    # NOTE(review): the source lost its indentation; test() is assumed to
    # run after every training epoch - confirm against the original script.
    for epoch in range(num_epochs):
        train(epoch)
        test()
About `(predicted == label).sum()`:
Each element of `predicted` is compared with the element of `label` at the same position; the result is True where they are equal and False where they differ. `.sum()` then counts the number of True values.
Explanation of the assignment `inputs, target = data`
Complete code
import torch
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
# Use relu()
import torch.nn.functional as F
# Construct optimizer
import torch.optim as optim
# 1- Prepare the data
batch_size = 64
# Convert each PIL image to a tensor, then normalize it with
# Normalize(mean, std) using mean 0.1307 and std 0.3081 (the commonly
# quoted MNIST statistics).
# The pipeline is named `transform` (singular) so that the imported
# `transforms` module is not overwritten and stays usable afterwards.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
# download=False: the MNIST files must already exist under `root`.
train_dataset = datasets.MNIST(root='./datasets/mnist', train=True,
                               transform=transform,
                               download=False)
test_dataset = datasets.MNIST(root='./datasets/mnist', train=False,
                              transform=transform,
                              download=False)
# Shuffle only the training data; keep the test order fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                          shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size,
                         shuffle=False)
# 2- Design the network model
class Net(torch.nn.Module):
    """Fully connected classifier: 784 -> 512 -> 256 -> 128 -> 64 -> 10.

    ``forward`` returns raw, unnormalized class scores (logits). This is
    the input ``torch.nn.CrossEntropyLoss`` expects, because that loss
    applies log-softmax itself.
    """

    def __init__(self):
        super().__init__()
        self.lay1 = torch.nn.Linear(784, 512)
        self.lay2 = torch.nn.Linear(512, 256)
        self.lay3 = torch.nn.Linear(256, 128)
        self.lay4 = torch.nn.Linear(128, 64)
        self.lay5 = torch.nn.Linear(64, 10)

    def forward(self, x):
        """Return a ``(N, 10)`` tensor of logits for the input batch.

        ``x`` may have any shape whose element count is a multiple of 784
        (for example a batch of 1x28x28 images); it is flattened to
        ``(N, 784)`` before the first layer.
        """
        # -1 lets view() infer the batch dimension.
        x = x.view(-1, 784)
        x = F.relu(self.lay1(x))
        x = F.relu(self.lay2(x))
        x = F.relu(self.lay3(x))
        x = F.relu(self.lay4(x))
        # No activation on the output layer: a ReLU here would clamp the
        # logits to >= 0 before the loss applies softmax.
        return self.lay5(x)
# 3- Construct loss function and optimizer
# The network whose parameters are trained.
model = Net()
# Cross-entropy classification loss.
criterion = torch.nn.CrossEntropyLoss()
# SGD with momentum over all parameters of the model.
optimizer = optim.SGD(
    model.parameters(),
    lr=0.005,
    momentum=0.5,
)
# 4- Training test
def train(epoch):
    """Run one pass over ``train_loader`` and update ``model`` in place.

    Uses the module-level ``model``, ``criterion``, ``optimizer`` and
    ``train_loader``. After every 300 batches the mean loss of those
    batches is printed, labelled with ``epoch + 1`` and the 1-based
    batch count.
    """
    report_every = 300
    loss_sum = 0.0
    # Count batches from 1 so the counter equals the number printed.
    for step, (inputs, target) in enumerate(train_loader, start=1):
        optimizer.zero_grad()
        # forward + backward + update
        loss = criterion(model(inputs), target)
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        if step % report_every == 0:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, step, loss_sum / report_every))
            loss_sum = 0.0
def test():
    """Evaluate ``model`` on ``test_loader`` and print the accuracy.

    Uses the module-level ``model`` and ``test_loader``. The printed value
    is the percentage of correctly classified samples, truncated to an
    integer by the ``%d`` format. Returns nothing.
    """
    correct = 0
    total = 0
    # Evaluation needs no computation graph, gradients or backpropagation.
    with torch.no_grad():
        # Each batch unpacks into two items: the inputs and their labels.
        for images, label in test_loader:
            outputs = model(images)
            # torch.max over dim=1 returns (max values, indices of the max);
            # the index of the largest score is the predicted class.
            # `.data` is not needed inside no_grad().
            _, predicted = torch.max(outputs, dim=1)
            total += label.size(0)
            correct += (predicted == label).sum().item()
    print('Accuracy on test set : %d %%' % (100 * correct / total))
if __name__ == '__main__':
    num_epochs = 10
    # NOTE(review): the source lost its indentation; test() is assumed to
    # run after every training epoch - confirm against the original script.
    for epoch in range(num_epochs):
        train(epoch)
        test()
边栏推荐
猜你喜欢
无代码六月大事件|2022无代码探索者大会即将召开;AI增强型无代码工具推出...
Garbage collector with serial, throughput priority and response time priority
Vulhub vulnerability recurrence 68_ ThinkPHP
数字经济破浪而来 ,LTD是权益独立的Web3.0网站?
A master in the field of software architecture -- Reading Notes of the beauty of Architecture
The ECU of 21 Audi q5l 45tfsi brushes is upgraded to master special adjustment, and the horsepower is safely and stably increased to 305 horsepower
B站刘二大人-数据集及数据加载 Lecture 8
【经验】win11上安装visio
移植InfoNES到STM32
注释、接续、转义等符号
随机推荐
Jvxetable用slot植入j-popup
Garbage collector with serial, throughput priority and response time priority
Node 之 nvm 下载、安装、使用,以及node 、nrm 的相关使用
LeetCode_字符串反转_简单_557. 反转字符串中的单词 III
实践分享:如何安全快速地从 Centos迁移到openEuler
How to download GB files from Google cloud hard disk
Remember an error in MySQL: the user specified as a definer ('mysql.infoschema '@' localhost ') does not exist
Web Security (V) what is a session? Why do I need a session?
26file filter anonymous inner class and lambda optimization
UCF (2022 summer team competition I)
Auto.js学习笔记17:基础监听事件和UI简单的点击事件操作
Questions d'examen écrit classiques du pointeur
数字经济破浪而来 ,LTD是权益独立的Web3.0网站?
Jvxetable implant j-popup with slot
[imgui] unity MenuItem shortcut key
Analysis of grammar elements in turtle Library
Summary of deep learning tuning tricks
Unity gets the width and height of Sprite
[Tang Laoshi] C -- encapsulation: classes and objects
YYGH-11-定时统计