14. Example - Multi classification problem
2022-07-27 05:59:00 【Pie star's favorite spongebob】
Use cross-entropy loss to optimize a multi-class classification problem.
Network Architecture
The output layer has 10 nodes, one for each of the 10 classes.
Linear layers have not been covered yet, so a few low-level tensor operations stand in for them.
Create three linear layers, each with a weight tensor w and a bias tensor b.
Note that in PyTorch the first dimension of a weight is out and the second dimension is in.
w1,b1 = torch.randn(200,784,requires_grad=True),torch.zeros(200,requires_grad=True)
w2,b2 = torch.randn(200,200,requires_grad=True),torch.zeros(200,requires_grad=True)
w3,b3 = torch.randn(10,200,requires_grad=True),torch.zeros(10,requires_grad=True)
The first linear layer reduces the 784 (28×28) input dimensions to 200. requires_grad must be set to True; otherwise an error is raised.
The second (hidden) layer maps 200 to 200. It only transforms the features and does not reduce the dimension.
The third layer is the output layer, which produces the scores for the 10 classes.
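The shapes described above can be checked with a small sketch. The batch of 4 random inputs is a made-up stand-in for flattened MNIST images, and the activations are left out because only the shapes matter here.

```python
import torch

# Weights use the (out, in) layout described above.
w1, b1 = torch.randn(200, 784), torch.zeros(200)
w2, b2 = torch.randn(200, 200), torch.zeros(200)
w3, b3 = torch.randn(10, 200), torch.zeros(10)

x = torch.randn(4, 784)   # hypothetical batch: 4 flattened 28x28 images
h1 = x @ w1.t() + b1      # [4, 784] @ [784, 200] -> [4, 200]
h2 = h1 @ w2.t() + b2     # [4, 200] @ [200, 200] -> [4, 200]
out = h2 @ w3.t() + b3    # [4, 200] @ [200, 10]  -> [4, 10]

print(h1.shape, h2.shape, out.shape)
```

Because each weight is stored as (out, in), it has to be transposed with .t() before the matrix multiplication.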
def forward(x):
    x = x @ w1.t() + b1
    x = F.relu(x)
    x = x @ w2.t() + b2
    x = F.relu(x)
    x = x @ w3.t() + b3
    x = F.relu(x)
    return x
The ReLU on the last layer is optional and can be left out. Do not apply sigmoid or softmax there, because softmax is applied later inside the loss.
These are the network's tensors and its forward pass.
Train
Next, define an optimizer. It optimizes the variables of the three fully connected layers: [w1,b1,w2,b2,w3,b3].
optimizer=optim.SGD([w1,b1,w2,b2,w3,b3],lr=learning_rate)
criteon=nn.CrossEntropyLoss()
nn.CrossEntropyLoss and F.cross_entropy do the same thing: both are softmax + log + nll_loss.
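This equivalence can be verified numerically. The following is a minimal sketch; the logits and labels are random, made-up values rather than outputs of the network in this post.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
logits = torch.randn(5, 10)              # hypothetical raw scores: 5 samples, 10 classes
target = torch.tensor([3, 0, 9, 1, 7])   # hypothetical class labels

loss_module = nn.CrossEntropyLoss()(logits, target)
loss_functional = F.cross_entropy(logits, target)
# The same computation done by hand: log(softmax(x)) followed by nll_loss.
loss_manual = F.nll_loss(F.log_softmax(logits, dim=1), target)

print(loss_module.item(), loss_functional.item(), loss_manual.item())
```

All three values should agree up to floating-point error, which is why forward() must return raw scores and not softmax probabilities.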
for epoch in range(epochs):
    for batch_idx,(data,target) in enumerate(train_loader):
        data=data.view(-1,28*28)
        logits=forward(data)
        loss=criteon(logits,target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train epoch:{}[{}/{} ({:.0f}%)]\tLoss:{:.6f}'.format(
                epoch,batch_idx*len(data),len(train_loader.dataset),
                100.*batch_idx/len(train_loader),loss.item()))
A step is one batch, while an epoch is one pass over the entire dataset.
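As a quick worked example of the difference, the number of steps follows from the dataset size and the batch size. The sketch below uses the batch size and epoch count from this post and the 60000 images of the MNIST training split.

```python
import math

dataset_size = 60000   # size of the MNIST training split
batch_size = 200       # same value as in this post
epochs = 10            # same value as in this post

steps_per_epoch = math.ceil(dataset_size / batch_size)   # one step per batch
total_steps = steps_per_epoch * epochs

print(steps_per_epoch, total_steps)
```

With these numbers one epoch is 300 steps, so the `batch_idx % 100 == 0` check in the training loop logs three times per epoch.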

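The loss does not decrease and the result stays at the 10% level (chance for 10 classes). The cause is the initialization.
The fix is to initialize w1, w2, and w3 explicitly. The biases b need no change, since they are initialized directly with torch.zeros.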
torch.nn.init.kaiming_normal_(w1)
torch.nn.init.kaiming_normal_(w2)
torch.nn.init.kaiming_normal_(w3)
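To see why the initialization matters, the scale of the two kinds of weights can be compared. This is a rough sketch: the input x is a made-up batch of standard-normal values standing in for normalized MNIST pixels, and only the first layer is examined.

```python
import math
import torch

torch.manual_seed(0)
w_randn = torch.randn(200, 784)             # plain standard-normal weights
w_kaiming = torch.empty(200, 784)
torch.nn.init.kaiming_normal_(w_kaiming)    # default fan_in mode: std = sqrt(2 / 784)

x = torch.randn(64, 784)                    # hypothetical batch of normalized inputs
pre_randn = x @ w_randn.t()                 # first-layer pre-activations
pre_kaiming = x @ w_kaiming.t()

print(w_randn.std().item(), w_kaiming.std().item(), math.sqrt(2.0 / 784))
print(pre_randn.std().item(), pre_kaiming.std().item())
```

With torch.randn the first-layer outputs have a standard deviation on the order of sqrt(784) = 28, while Kaiming initialization keeps them near sqrt(2). Very large activations that are then pushed through two more layers are a plausible reason for the stalled training described above.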

Complete code
import torch
import torch.nn.functional as F
from torch import optim
from torch import nn
import torchvision
batch_size = 200
learning_rate=0.01
epochs=10
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,))
                               ])),
    batch_size=batch_size, shuffle=True)
# ToTensor converts the image data to a tensor
# Normalize centers the values around 0, which can improve performance
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,))
                               ])),
    batch_size=batch_size, shuffle=False)
w1,b1 = torch.randn(200,784,requires_grad=True),torch.zeros(200,requires_grad=True)
w2,b2 = torch.randn(200,200,requires_grad=True),torch.zeros(200,requires_grad=True)
w3,b3 = torch.randn(10,200,requires_grad=True),torch.zeros(10,requires_grad=True)
torch.nn.init.kaiming_normal_(w1)
torch.nn.init.kaiming_normal_(w2)
torch.nn.init.kaiming_normal_(w3)
def forward(x):
    x = x @ w1.t() + b1
    x = F.relu(x)
    x = x @ w2.t() + b2
    x = F.relu(x)
    x = x @ w3.t() + b3
    x = F.relu(x)
    return x
optimizer=optim.SGD([w1,b1,w2,b2,w3,b3],lr=learning_rate)
criteon=nn.CrossEntropyLoss()
for epoch in range(epochs):
    for batch_idx,(data,target) in enumerate(train_loader):
        data=data.view(-1,28*28)
        logits=forward(data)
        loss=criteon(logits,target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train epoch:{}[{}/{} ({:.0f}%)]\tLoss:{:.6f}'.format(
                epoch,batch_idx*len(data),len(train_loader.dataset),
                100.*batch_idx/len(train_loader),loss.item()))
    # Evaluate on the test set after each epoch
    test_loss=0
    correct=0
    with torch.no_grad():  # no gradients are needed for evaluation
        for data,target in test_loader:
            data=data.view(-1,28*28)
            logits=forward(data)
            test_loss+=criteon(logits,target).item()
            pred=logits.argmax(dim=1)  # index of the highest score = predicted class
            correct+=pred.eq(target).sum().item()
    # criteon already averages within each batch, so divide by the number of batches
    test_loss/=len(test_loader)
    print('\nTest set:Average loss:{:.4f},Accuracy:{}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
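The accuracy computation in the test loop can be traced on a tiny example. The logits and labels below are made-up values with 3 classes, chosen only to illustrate the steps.

```python
import torch

# Hypothetical scores for 4 samples and 3 classes, plus their true labels.
logits = torch.tensor([[0.1, 2.0, 0.3],
                       [1.5, 0.2, 0.1],
                       [0.0, 0.1, 3.0],
                       [0.9, 0.8, 0.7]])
target = torch.tensor([1, 0, 2, 1])

pred = logits.argmax(dim=1)               # predicted class = index of the largest score
correct = pred.eq(target).sum().item()    # number of matching predictions
accuracy = 100.0 * correct / len(target)

print(pred.tolist(), correct, accuracy)
```

Here logits.max(1)[1] returns the same indices as logits.argmax(dim=1); the last sample is misclassified, so 3 of 4 predictions are correct.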