Bilibili Liu Er: softmax classifier and MNIST implementation (Lecture 9)
2022-07-06 05:42:00 【Ning Ranye】
Series articles:

Table of contents
Softmax classifier
Loss function: cross entropy
Cross-entropy loss implemented in NumPy
Cross-entropy loss implemented in PyTorch
MNIST implementation
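As a quick illustration of the loss items listed above, here is a minimal, self-contained sketch of the cross entropy computed by hand with NumPy and then with torch.nn.CrossEntropyLoss. The concrete numbers are example values for illustration, not figures taken from the lecture.

import numpy as np
import torch

# One 3-class sample: one-hot target and raw scores (example values)
y = np.array([1, 0, 0])
z = np.array([0.2, 0.1, -0.1])
y_pred = np.exp(z) / np.exp(z).sum()      # softmax
print((-y * np.log(y_pred)).sum())        # cross entropy, about 0.97

# The same loss in PyTorch: CrossEntropyLoss takes raw logits and a class index
criterion = torch.nn.CrossEntropyLoss()
logits = torch.tensor([[0.2, 0.1, -0.1]])
target = torch.tensor([0])                # class index, not a one-hot vector
print(criterion(logits, target))          # about 0.97 as well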




Imports
import torch
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
# Use relu()
import torch.nn.functional as F
# Construct optimizer
import torch.optim as optim
1- Prepare the data
# 1- Prepare the data
batch_size = 64
# Convert the PIL Image to a Tensor and normalize with the MNIST mean/std
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST(root='./datasets/mnist', train=True,
                               transform=transform,
                               download=False)   # set download=True on the first run
test_dataset = datasets.MNIST(root='./datasets/mnist', train=False,
                              transform=transform,
                              download=False)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                          shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size,
                         shuffle=False)
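An optional sanity check, not part of the original post (the names images and labels are just for illustration): pull one batch from train_loader and inspect its shapes. With batch_size = 64, each batch should hold 64 single-channel 28×28 images and 64 labels.

images, labels = next(iter(train_loader))
print(images.shape)   # torch.Size([64, 1, 28, 28])
print(labels.shape)   # torch.Size([64])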
2- Design the network model
# 2- Design the network model
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.lay1 = torch.nn.Linear(784, 512)
        self.lay2 = torch.nn.Linear(512, 256)
        self.lay3 = torch.nn.Linear(256, 128)
        self.lay4 = torch.nn.Linear(128, 64)
        self.lay5 = torch.nn.Linear(64, 10)

    def forward(self, x):
        # Flatten (N, 1, 28, 28) images into (N, 784) vectors
        x = x.view(-1, 784)
        x = F.relu(self.lay1(x))
        x = F.relu(self.lay2(x))
        x = F.relu(self.lay3(x))
        x = F.relu(self.lay4(x))
        # No activation on the last layer: CrossEntropyLoss applies log-softmax itself
        return self.lay5(x)
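Another optional check (the names net and dummy are made up for this sketch): push a dummy batch through the network and confirm it returns one 10-dimensional score vector (logits) per image.

net = Net()
dummy = torch.randn(4, 1, 28, 28)   # a fake batch of 4 grayscale images
print(net(dummy).shape)             # torch.Size([4, 10])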
3- Build the model, loss function, and optimizer
# 3- Construct loss function and optimizer
model = Net()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.005,momentum=0.5)
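CrossEntropyLoss combines LogSoftmax and NLLLoss, which is why forward() ends with plain logits rather than a softmax. A small sketch of that equivalence, using made-up tensors:

logits = torch.randn(3, 10)
target = torch.tensor([1, 4, 7])
a = torch.nn.CrossEntropyLoss()(logits, target)
b = F.nll_loss(F.log_softmax(logits, dim=1), target)
print(torch.allclose(a, b))   # True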
4- Training and testing
# 4- Training and testing
def train(epoch):
    running_loss = 0.0
    # enumerate(train_loader, 0): batch_idx counts from 0
    for batch_idx, data in enumerate(train_loader, 0):
        inputs, target = data
        optimizer.zero_grad()
        # forward + backward + update
        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if batch_idx % 300 == 299:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))
            running_loss = 0.0

def test():
    correct = 0
    total = 0
    # Testing needs no computation graph: no gradients, no backpropagation
    with torch.no_grad():
        # data is a list of length 2:
        # the inputs are data[0], the targets are data[1]
        for data in test_loader:
            images, label = data
            outputs = model(images)
            # _ is the maximum score; predicted is the class index of that maximum
            _, predicted = torch.max(outputs.data, dim=1)
            total += label.size(0)
            correct += (predicted == label).sum().item()
    print('Accuracy on test set: %d %%' % (100 * correct / total))

if __name__ == '__main__':
    for epoch in range(10):
        train(epoch)
        test()
About (predicted == label).sum():
It compares each element of predicted with the element of label at the same position, giving True where they match and False where they differ; .sum() then counts the number of True values, i.e. the number of correct predictions in the batch.
inputs, target = data unpacks the (inputs, target) pair that the DataLoader yields for each batch.
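A tiny concrete example of those two lines, with made-up tensors: torch.max along dim=1 returns both the maximum score per row and its column index, and the element-wise comparison then counts the matches.

outputs = torch.tensor([[0.1, 2.0, 0.3],
                        [1.5, 0.2, 0.1]])
_, predicted = torch.max(outputs, dim=1)
label = torch.tensor([1, 2])
print(predicted)                          # tensor([1, 0])
print((predicted == label).sum().item())  # 1 correct prediction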
Complete code
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# Use relu()
import torch.nn.functional as F
# Construct optimizer
import torch.optim as optim

# 1- Prepare the data
batch_size = 64
# Convert the PIL Image to a Tensor and normalize with the MNIST mean/std
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST(root='./datasets/mnist', train=True,
                               transform=transform,
                               download=False)   # set download=True on the first run
test_dataset = datasets.MNIST(root='./datasets/mnist', train=False,
                              transform=transform,
                              download=False)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                          shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size,
                         shuffle=False)

# 2- Design the network model
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.lay1 = torch.nn.Linear(784, 512)
        self.lay2 = torch.nn.Linear(512, 256)
        self.lay3 = torch.nn.Linear(256, 128)
        self.lay4 = torch.nn.Linear(128, 64)
        self.lay5 = torch.nn.Linear(64, 10)

    def forward(self, x):
        # Flatten (N, 1, 28, 28) images into (N, 784) vectors
        x = x.view(-1, 784)
        x = F.relu(self.lay1(x))
        x = F.relu(self.lay2(x))
        x = F.relu(self.lay3(x))
        x = F.relu(self.lay4(x))
        # No activation on the last layer: CrossEntropyLoss applies log-softmax itself
        return self.lay5(x)

# 3- Construct the loss function and optimizer
model = Net()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.5)

# 4- Training and testing
def train(epoch):
    running_loss = 0.0
    # enumerate(train_loader, 0): batch_idx counts from 0
    for batch_idx, data in enumerate(train_loader, 0):
        inputs, target = data
        optimizer.zero_grad()
        # forward + backward + update
        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if batch_idx % 300 == 299:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))
            running_loss = 0.0

def test():
    correct = 0
    total = 0
    # Testing needs no computation graph: no gradients, no backpropagation
    with torch.no_grad():
        for data in test_loader:
            images, label = data
            outputs = model(images)
            # _ is the maximum score; predicted is the class index of that maximum
            _, predicted = torch.max(outputs.data, dim=1)
            total += label.size(0)
            correct += (predicted == label).sum().item()
    print('Accuracy on test set: %d %%' % (100 * correct / total))

if __name__ == '__main__':
    for epoch in range(10):
        train(epoch)
        test()