Bilibili 刘二大人 - Softmax Classifier and MNIST Implementation - Lecture 9
2022-07-06 05:33:00 【宁然也】
Series articles:
Softmax classifier
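As a minimal sketch of what a softmax layer computes: each raw score (logit) is exponentiated and divided by the sum of all exponentials, so the outputs are positive and add up to 1. The function name softmax and the input values are illustrative assumptions rather than something taken from the lecture. Subtracting the maximum first is a common numerical-stability trick that leaves the result unchanged.

```python
import numpy as np

def softmax(z):
    """Turn a 1-D array of raw scores into probabilities that sum to 1."""
    z = np.asarray(z, dtype=float)
    shifted = z - z.max()          # avoids overflow in exp for large scores
    exp = np.exp(shifted)
    return exp / exp.sum()

# Illustrative logits for a 3-class problem (assumed values)
probs = softmax([0.2, 0.1, -0.1])
print(probs, probs.sum())
```

The largest logit always receives the largest probability, which is why the class can later be predicted by taking the index of the maximum output.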
Loss function: cross-entropy
Implementing the cross-entropy loss in NumPy
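A possible NumPy version, offered as a sketch: with a one-hot label y and softmax probabilities y_pred, the cross-entropy loss is -sum(y * log(y_pred)), which reduces to the negative log-probability of the true class. The arrays y and z below are assumed example values.

```python
import numpy as np

y = np.array([1, 0, 0])            # one-hot label: the true class is 0 (assumed)
z = np.array([0.2, 0.1, -0.1])     # raw scores from the last linear layer (assumed)

y_pred = np.exp(z) / np.exp(z).sum()     # softmax
loss = (-y * np.log(y_pred)).sum()       # cross-entropy against the one-hot label
print(loss)
```

Because y is one-hot, only the term of the true class survives the sum, so the loss is small exactly when the model puts high probability on the correct class.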
Cross-entropy loss as implemented in PyTorch
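In PyTorch, torch.nn.CrossEntropyLoss expects raw logits and integer class indices, and applies log-softmax internally. The sketch below, with assumed example values, checks that it matches log_softmax followed by nll_loss. This is also why the last layer of the network should output plain logits with no activation.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[0.2, 0.1, -0.1]])   # shape (batch=1, classes=3), assumed values
target = torch.tensor([0])                   # class index, not a one-hot vector

criterion = torch.nn.CrossEntropyLoss()
loss = criterion(logits, target)

# Equivalent two-step computation: log-softmax, then negative log-likelihood
manual = F.nll_loss(F.log_softmax(logits, dim=1), target)
print(loss.item(), manual.item())
```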
MNIST implementation
Imports
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# Provides relu()
import torch.nn.functional as F
# Provides the optimizers
import torch.optim as optim
1 - Prepare the data
# 1 - Prepare the data
batch_size = 64
# Convert the PIL image to a Tensor, then normalize with the MNIST mean and std.
# The pipeline is named `transform` so that it does not shadow the `transforms` module.
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])
# download=True fetches the files only if they are not already under root
train_dataset = datasets.MNIST(root='./datasets/mnist', train=True,
                               transform=transform,
                               download=True)
test_dataset = datasets.MNIST(root='./datasets/mnist', train=False,
                              transform=transform,
                              download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                          shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size,
                         shuffle=False)
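To make the preprocessing concrete, here is a small sketch of the arithmetic behind ToTensor followed by Normalize((0.1307,), (0.3081,)), written with plain tensor operations: pixel values 0..255 are first scaled to [0, 1], then shifted by the mean and divided by the standard deviation. The three sample pixel values are assumptions chosen for illustration.

```python
import torch

MEAN, STD = 0.1307, 0.3081     # the MNIST statistics used in the transform

raw = torch.tensor([0, 128, 255], dtype=torch.uint8)   # assumed sample pixel values
scaled = raw.float() / 255.0                            # what ToTensor does to 8-bit pixels
normalized = (scaled - MEAN) / STD                      # what Normalize does per channel
print(scaled, normalized)
```

A black background pixel therefore maps to a slightly negative value and a fully white pixel to a value well above 1, so the inputs are roughly zero-mean with unit variance.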
2 - Design the network model
# 2 - Design the network model
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.lay1 = torch.nn.Linear(784, 512)
        self.lay2 = torch.nn.Linear(512, 256)
        self.lay3 = torch.nn.Linear(256, 128)
        self.lay4 = torch.nn.Linear(128, 64)
        self.lay5 = torch.nn.Linear(64, 10)

    def forward(self, x):
        # Flatten each 1x28x28 image into a 784-element vector
        x = x.view(-1, 784)
        x = F.relu(self.lay1(x))
        x = F.relu(self.lay2(x))
        x = F.relu(self.lay3(x))
        x = F.relu(self.lay4(x))
        # No activation on the last layer: CrossEntropyLoss expects raw logits
        # and applies log-softmax itself
        return self.lay5(x)
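The shape bookkeeping of this network can be checked with a short sketch. The layer widths are the ones used above; the random batch is a hypothetical stand-in for real MNIST images, and the activations are left out because ReLU does not change tensor shapes.

```python
import torch

# Layer widths from the network; each consecutive pair defines one Linear layer
sizes = [784, 512, 256, 128, 64, 10]
layers = [torch.nn.Linear(n_in, n_out) for n_in, n_out in zip(sizes[:-1], sizes[1:])]

batch = torch.randn(64, 1, 28, 28)   # hypothetical batch of 64 grayscale 28x28 images
x = batch.view(-1, 784)              # flatten: (64, 1, 28, 28) -> (64, 784)

shapes = [tuple(x.shape)]
for layer in layers:
    x = layer(x)
    shapes.append(tuple(x.shape))
print(shapes)
```

The -1 in view lets PyTorch infer the batch dimension, so the same code works for the smaller final batch of an epoch.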
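3 - Build the model, loss function, and optimizer
# 3 - Build the model, loss function, and optimizer
model = Net()
# CrossEntropyLoss combines log-softmax and negative log-likelihood
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.005, momentum=0.5)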
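For readers wondering what momentum=0.5 changes, the sketch below compares torch.optim.SGD against a hand-written update on a toy problem. With the default settings, PyTorch keeps a velocity v = momentum * v + grad and then updates w = w - lr * v. The toy loss (sum of squares) and the starting weights are assumptions made purely for illustration.

```python
import torch

lr, momentum = 0.005, 0.5

w = torch.tensor([1.0, -2.0], requires_grad=True)       # toy parameter (assumed)
opt = torch.optim.SGD([w], lr=lr, momentum=momentum)

w_manual = w.detach().clone()                            # hand-written copy
velocity = torch.zeros_like(w_manual)

for _ in range(3):
    # Optimizer path
    opt.zero_grad()
    loss = (w ** 2).sum()          # toy loss; its gradient is 2 * w
    loss.backward()
    opt.step()

    # Manual path, applying the same rule step by step
    grad = 2 * w_manual
    velocity = momentum * velocity + grad
    w_manual = w_manual - lr * velocity

print(w.detach(), w_manual)
```

Each step therefore reuses half of the previous step's direction, which smooths the updates compared with plain SGD.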
4 - Training and testing
# 4 - Training and testing
def train(epoch):
    running_loss = 0.0
    # enumerate(train_loader, 0): batch_idx starts counting from 0
    for batch_idx, data in enumerate(train_loader, 0):
        inputs, target = data
        optimizer.zero_grad()
        # forward + backward + update
        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # Report the average loss over each block of 300 batches
        if batch_idx % 300 == 299:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))
            running_loss = 0.0

def test():
    correct = 0
    total = 0
    # Testing needs no computation graph: no gradients, no backpropagation
    with torch.no_grad():
        # data is a list of length 2:
        # data[0] holds the inputs, data[1] holds the targets
        for data in test_loader:
            images, label = data
            outputs = model(images)
            # _ receives the maximum values, predicted the indices of those maxima
            _, predicted = torch.max(outputs, dim=1)
            total += label.size(0)
            correct += (predicted == label).sum().item()
    print('Accuracy on test set: %d %%' % (100 * correct / total))

if __name__ == '__main__':
    for epoch in range(10):
        train(epoch)
        test()
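A quick back-of-the-envelope check of the logging schedule, as a sketch: the MNIST training set has 60000 images, so with batch_size = 64 and the DataLoader default of keeping the last partial batch, one epoch has 938 batches. The variable names below are illustrative.

```python
import math

train_size = 60000      # number of MNIST training images
batch_size = 64

num_batches = math.ceil(train_size / batch_size)
# Batches (1-based) at which the condition batch_idx % 300 == 299 fires
logged = [i + 1 for i in range(num_batches) if i % 300 == 299]
last_batch_size = train_size - (num_batches - 1) * batch_size
print(num_batches, logged, last_batch_size)
```

So the loss is printed three times per epoch, and the final batch is smaller than the rest. The same thing happens in the test loader, which is why the test loop accumulates label.size(0) instead of assuming every batch holds batch_size samples.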
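About (predicted == label).sum()
Each element of predicted is compared with the label at the same position, giving True where they match and False where they differ. .sum() then counts the number of True entries.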
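A tiny worked example of this counting step, using made-up scores for four samples and three classes (the numbers are assumptions chosen so the result is easy to verify by eye):

```python
import torch

# Made-up model outputs: 4 samples, 3 classes
outputs = torch.tensor([[0.1, 2.0, 0.3],
                        [1.5, 0.2, 0.1],
                        [0.0, 0.1, 3.0],
                        [0.9, 0.8, 0.7]])
label = torch.tensor([1, 0, 1, 0])

_, predicted = torch.max(outputs, dim=1)   # index of the largest score in each row
matches = predicted == label               # element-wise comparison, a boolean tensor
correct = matches.sum().item()             # True counts as 1, False as 0
print(predicted.tolist(), matches.tolist(), correct)
```

Here the third sample is predicted as class 2 while its label is 1, so three of the four predictions count as correct.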
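About the assignment inputs, target = data: each batch yielded by the loader is a two-element list, so data[0] is bound to inputs and data[1] to target.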
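The unpacking can be tried without downloading MNIST. The sketch below uses a TensorDataset filled with random tensors as a hypothetical stand-in for the real dataset. The names fake_images and fake_labels are illustrative only.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for MNIST: 8 random "images" with integer labels in 0..9
fake_images = torch.randn(8, 1, 28, 28)
fake_labels = torch.randint(0, 10, (8,))
loader = DataLoader(TensorDataset(fake_images, fake_labels), batch_size=4, shuffle=False)

data = next(iter(loader))     # one batch from the loader
inputs, target = data         # same as inputs = data[0]; target = data[1]
print(len(data), inputs.shape, target.shape)
```

The first element stacks the images of the batch along a new leading dimension, and the second holds one label per image.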
Complete code
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# Provides relu()
import torch.nn.functional as F
# Provides the optimizers
import torch.optim as optim

# 1 - Prepare the data
batch_size = 64
# Convert the PIL image to a Tensor, then normalize with the MNIST mean and std.
# The pipeline is named `transform` so that it does not shadow the `transforms` module.
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])
# download=True fetches the files only if they are not already under root
train_dataset = datasets.MNIST(root='./datasets/mnist', train=True,
                               transform=transform,
                               download=True)
test_dataset = datasets.MNIST(root='./datasets/mnist', train=False,
                              transform=transform,
                              download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                          shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size,
                         shuffle=False)

# 2 - Design the network model
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.lay1 = torch.nn.Linear(784, 512)
        self.lay2 = torch.nn.Linear(512, 256)
        self.lay3 = torch.nn.Linear(256, 128)
        self.lay4 = torch.nn.Linear(128, 64)
        self.lay5 = torch.nn.Linear(64, 10)

    def forward(self, x):
        # Flatten each 1x28x28 image into a 784-element vector
        x = x.view(-1, 784)
        x = F.relu(self.lay1(x))
        x = F.relu(self.lay2(x))
        x = F.relu(self.lay3(x))
        x = F.relu(self.lay4(x))
        # No activation on the last layer: CrossEntropyLoss expects raw logits
        return self.lay5(x)

# 3 - Build the model, loss function, and optimizer
model = Net()
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.005, momentum=0.5)

# 4 - Training and testing
def train(epoch):
    running_loss = 0.0
    # enumerate(train_loader, 0): batch_idx starts counting from 0
    for batch_idx, data in enumerate(train_loader, 0):
        inputs, target = data
        optimizer.zero_grad()
        # forward + backward + update
        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # Report the average loss over each block of 300 batches
        if batch_idx % 300 == 299:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))
            running_loss = 0.0

def test():
    correct = 0
    total = 0
    # Testing needs no computation graph: no gradients, no backpropagation
    with torch.no_grad():
        for data in test_loader:
            images, label = data
            outputs = model(images)
            # _ receives the maximum values, predicted the indices of those maxima
            _, predicted = torch.max(outputs, dim=1)
            total += label.size(0)
            correct += (predicted == label).sum().item()
    print('Accuracy on test set: %d %%' % (100 * correct / total))

if __name__ == '__main__':
    for epoch in range(10):
        train(epoch)
        test()