14. Example: Multi-class classification problem
2022-07-27 05:59:00 【Pie star's favorite spongebob】
Using cross-entropy loss to optimize a multi-class classification problem
Network Architecture
The output layer has 10 units, one for each of the 10 classes (the digits 0 to 9).
Because I haven't covered linear layers yet, the network is built from low-level tensor operations instead.
Create three linear layers by hand, each with a weight tensor w and a bias tensor b.
Note that in PyTorch the weight is laid out as [out, in]: the first dimension is the output size and the second is the input size.
w1,b1 = torch.randn(200,784,requires_grad=True),torch.zeros(200,requires_grad=True)
w2,b2 = torch.randn(200,200,requires_grad=True),torch.zeros(200,requires_grad=True)
w3,b3 = torch.randn(10,200,requires_grad=True),torch.zeros(10,requires_grad=True)
The first linear layer reduces the 784 (28×28) input features to 200. requires_grad must be set to True; otherwise autograd does not track these tensors and loss.backward() raises an error.
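A minimal sketch of what goes wrong without requires_grad, assuming only torch. The input batch and the toy loss `(x @ w.t()).sum()` are made up for illustration; they stand in for the real network and cross-entropy loss.

```python
import torch

x = torch.randn(4, 784)  # hypothetical batch of 4 flattened images

# Without requires_grad=True, autograd does not track the weight,
# so backward() has no graph to differentiate and raises a RuntimeError.
w_plain = torch.randn(200, 784)
loss = (x @ w_plain.t()).sum()
try:
    loss.backward()
    raised = False
except RuntimeError:
    raised = True
print("backward raised:", raised)

# With requires_grad=True the same computation yields a gradient
# with the same shape as the weight.
w = torch.randn(200, 784, requires_grad=True)
loss = (x @ w.t()).sum()
loss.backward()
print("grad shape:", tuple(w.grad.shape))
```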
The second (hidden) layer maps 200 to 200: it only transforms the features, without reducing the dimension.
The third layer is the output layer and produces the final 10 class scores.
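The shape flow through the three layers can be checked with random data. This sketch assumes a hypothetical batch of 4 flattened images and leaves out requires_grad and the ReLUs, since only shapes matter here. It also compares the manual `x @ w.t() + b` with PyTorch's `F.linear`, which uses the same [out, in] weight layout.

```python
import torch
import torch.nn.functional as F

# Weights follow PyTorch's [out_features, in_features] layout
w1, b1 = torch.randn(200, 784), torch.zeros(200)
w2, b2 = torch.randn(200, 200), torch.zeros(200)
w3, b3 = torch.randn(10, 200), torch.zeros(10)

x = torch.randn(4, 784)   # hypothetical batch of 4 flattened 28x28 images
h1 = x @ w1.t() + b1      # [4, 784] -> [4, 200]: reduce 784 features to 200
h2 = h1 @ w2.t() + b2     # [4, 200] -> [4, 200]: feature transform, same width
out = h2 @ w3.t() + b3    # [4, 200] -> [4, 10]: one score per class
print(h1.shape, h2.shape, out.shape)

# F.linear(x, w, b) computes x @ w.t() + b; a small atol absorbs float32 rounding
same = torch.allclose(F.linear(x, w1, b1), h1, atol=1e-4)
print("matches F.linear:", same)
```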
def forward(x):
    x = x @ w1.t() + b1
    x = F.relu(x)
    x = x @ w2.t() + b2
    x = F.relu(x)
    x = x @ w3.t() + b3
    x = F.relu(x)
    return x
The ReLU on the last layer is optional and can be left out. Do not use sigmoid or softmax there, though: CrossEntropyLoss applies softmax internally, so applying it here as well would apply it twice.
That completes the network's tensors and its forward pass.
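To see why softmax should not be applied before CrossEntropyLoss, the sketch below compares the loss on raw logits with the loss on already-softmaxed scores. The random logits and labels, and the hand-built "perfect" prediction for class 3, are all illustrative. With softmax applied twice, the loss for 10 classes can never fall below log(e + 9) - 1 ≈ 1.46, even for a confident, correct prediction.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
criteon = nn.CrossEntropyLoss()

# Hypothetical raw scores (logits) for 8 samples and their labels
logits = torch.randn(8, 10)
target = torch.randint(0, 10, (8,))

loss_raw = criteon(logits, target)                    # intended usage: raw logits
loss_double = criteon(logits.softmax(dim=1), target)  # softmax applied twice
print(loss_raw.item(), loss_double.item())

# A confident, correct prediction for class 3
perfect = torch.full((1, 10), -20.0)
perfect[0, 3] = 20.0
label = torch.tensor([3])
perfect_raw = criteon(perfect, label).item()                    # close to 0
perfect_double = criteon(perfect.softmax(dim=1), label).item()  # stuck near the floor
floor = math.log(math.e + 9) - 1                                # about 1.46 for 10 classes
print(perfect_raw, perfect_double, floor)
```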
Train
Next, define an optimizer. What it optimizes are the parameters of the 3 fully connected layers: [w1, b1, w2, b2, w3, b3].
optimizer=optim.SGD([w1,b1,w2,b2,w3,b3],lr=learning_rate)
criteon=nn.CrossEntropyLoss()
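With plain SGD (no momentum or weight decay, which is what `optim.SGD(params, lr=...)` uses by default), `optimizer.step()` updates every tensor in the list as w ← w - lr · grad. A small sketch with made-up shapes and a toy squared loss checks this:

```python
import torch
from torch import optim

torch.manual_seed(0)
learning_rate = 0.01

# Hypothetical tiny layer: 4 inputs -> 3 outputs
w = torch.randn(3, 4, requires_grad=True)
b = torch.zeros(3, requires_grad=True)
optimizer = optim.SGD([w, b], lr=learning_rate)

x = torch.randn(5, 4)
loss = (x @ w.t() + b).pow(2).mean()  # toy loss, only for producing gradients

optimizer.zero_grad()
loss.backward()
# Expected plain-SGD update, computed before step() modifies w and b in place
expected_w = w.detach() - learning_rate * w.grad
expected_b = b.detach() - learning_rate * b.grad
optimizer.step()

print(torch.allclose(w, expected_w), torch.allclose(b, expected_b))
```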
nn.CrossEntropyLoss and F.cross_entropy do the same thing: both combine softmax, log and nll_loss in a single call.
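A quick check of that equivalence on random logits (the shapes and labels are illustrative). The hand-written softmax, log, nll_loss chain is included only for comparison; F.log_softmax computes the same log-probabilities in a numerically safer way.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(6, 10)          # hypothetical scores for 6 samples
target = torch.randint(0, 10, (6,))  # hypothetical labels

loss_module = nn.CrossEntropyLoss()(logits, target)
loss_functional = F.cross_entropy(logits, target)
loss_composed = F.nll_loss(F.log_softmax(logits, dim=1), target)
# softmax -> log -> nll_loss spelled out step by step
loss_manual = F.nll_loss(torch.log(F.softmax(logits, dim=1)), target)

print(loss_module.item(), loss_functional.item(),
      loss_composed.item(), loss_manual.item())
```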
for epoch in range(epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(-1, 28*28)
        logits = forward(data)
        loss = criteon(logits, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train epoch:{}[{}/{} ({:.0f}%)]\tLoss:{:.6f}'.format(
                epoch, batch_idx*len(data), len(train_loader.dataset),
                100.*batch_idx/len(train_loader), loss.item()))
One step processes one batch, while one epoch is a full pass over the entire dataset.
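With the batch size from the complete code (200) and the 60,000-image MNIST training set, the condition `batch_idx % 100 == 0` fires three times per epoch. The arithmetic behind the printed `[seen/total (percent)]` values, as a small sketch:

```python
import math

train_size = 60000  # MNIST training set size
batch_size = 200    # as in the complete code

steps_per_epoch = math.ceil(train_size / batch_size)  # equals len(train_loader)
logged = [i for i in range(steps_per_epoch) if i % 100 == 0]
samples_seen = [i * batch_size for i in logged]
percent = [round(100 * i / steps_per_epoch) for i in logged]
print(steps_per_epoch, logged, samples_seen, percent)
```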

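With the plain torch.randn initialization, training stalls: the loss stays essentially unchanged and the accuracy stays around 10%, no better than guessing among 10 classes. The cause is the initialization.
The fix is to re-initialize w1, w2 and w3 with Kaiming (He) normal initialization after creating them; the biases b are already initialized directly with torch.zeros and can stay as they are.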
torch.nn.init.kaiming_normal_(w1)
torch.nn.init.kaiming_normal_(w2)
torch.nn.init.kaiming_normal_(w3)
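Why randn breaks training shows up in the scale of the first layer's outputs. The sketch below assumes standardized inputs (random normal data standing in for normalized MNIST) and compares plain torch.randn weights with kaiming_normal_, whose default settings (fan_in mode, ReLU gain) draw weights with std = sqrt(2 / fan_in).

```python
import torch

torch.manual_seed(0)
x = torch.randn(1000, 784)  # hypothetical standardized inputs

w_randn = torch.randn(200, 784)           # plain randn: std 1 per weight
w_kaiming = torch.empty(200, 784)
torch.nn.init.kaiming_normal_(w_kaiming)  # std = sqrt(2 / 784), fan_in = 784 here

std_randn = (x @ w_randn.t()).std().item()      # roughly sqrt(784) = 28
std_kaiming = (x @ w_kaiming.t()).std().item()  # roughly sqrt(2), about 1.4
print(std_randn, std_kaiming)
```

kaiming_normal_ fills the tensor in place under torch.no_grad(), so it also works on the leaf tensors created with requires_grad=True above.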

Complete code
import torch
import torch.nn.functional as F
from torch import optim
from torch import nn
import torchvision

batch_size = 200
learning_rate = 0.01
epochs = 10

# ToTensor converts each image to a [1, 28, 28] float tensor in [0, 1]
# Normalize standardizes it with the MNIST mean/std so values sit near 0, which helps training
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,))
                               ])),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,))
                               ])),
    batch_size=batch_size, shuffle=False)

# Weights are [out, in]; biases start at zero
w1, b1 = torch.randn(200, 784, requires_grad=True), torch.zeros(200, requires_grad=True)
w2, b2 = torch.randn(200, 200, requires_grad=True), torch.zeros(200, requires_grad=True)
w3, b3 = torch.randn(10, 200, requires_grad=True), torch.zeros(10, requires_grad=True)

# Kaiming (He) initialization; with plain randn the loss does not go down
torch.nn.init.kaiming_normal_(w1)
torch.nn.init.kaiming_normal_(w2)
torch.nn.init.kaiming_normal_(w3)


def forward(x):
    x = x @ w1.t() + b1
    x = F.relu(x)
    x = x @ w2.t() + b2
    x = F.relu(x)
    x = x @ w3.t() + b3
    x = F.relu(x)  # optional on the output layer
    return x


optimizer = optim.SGD([w1, b1, w2, b2, w3, b3], lr=learning_rate)
criteon = nn.CrossEntropyLoss()

for epoch in range(epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(-1, 28*28)  # flatten [B, 1, 28, 28] -> [B, 784]
        logits = forward(data)
        loss = criteon(logits, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print('Train epoch:{}[{}/{} ({:.0f}%)]\tLoss:{:.6f}'.format(
                epoch, batch_idx*len(data), len(train_loader.dataset),
                100.*batch_idx/len(train_loader), loss.item()))

    # Evaluate on the test set after each epoch; no gradients are needed here
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.view(-1, 28*28)
            logits = forward(data)
            # Sum (not mean) over the batch, so dividing by the dataset size gives a per-sample average
            test_loss += F.cross_entropy(logits, target, reduction='sum').item()
            pred = logits.max(1)[1]  # index of the largest logit = predicted class
            correct += pred.eq(target).sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
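The evaluation step picks the predicted class with `max(1)[1]`, i.e. the index of the largest logit in each row, which is the same as `argmax(dim=1)`. A tiny hand-made example (the values are chosen only for illustration) shows how the correct count is formed:

```python
import torch

logits = torch.tensor([[0.1, 2.0, -1.0],
                       [1.5, 0.2, 0.3],
                       [0.0, 0.1, 0.9]])
target = torch.tensor([1, 2, 2])

pred = logits.max(1)[1]  # max over dim 1 returns (values, indices); keep the indices
correct = pred.eq(target).sum().item()  # number of rows where prediction == label
print(pred.tolist(), correct)  # [1, 0, 2] 2
```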