当前位置:网站首页>Pytorch教程Introduction中的神经网络实现示例
Pytorch教程Introduction中的神经网络实现示例
2022-07-31 04:56:00 【KylinSchmidt】
Pytorch Turtorial Introduction中的神经网络示例代码,英文详细介绍参见https://pytorch.org/tutorials/beginner/basics/optimization_tutorial.html。
数据集为衣服等的灰度图片。
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor()
)
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor()
)
train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)
class NeuralNetwork(nn.Module):
def __init__(self):
super(NeuralNetwork, self).__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28*28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 10),
)
def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
model = NeuralNetwork()
# 超参数设置
learning_rate = 1e-3
batch_size = 64
epochs = 10
# Initialize the loss function
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) # 随机梯度下降算法
def train_loop(dataloader, model, loss_fn, optimizer):
size = len(dataloader.dataset)
for batch, (X, y) in enumerate(dataloader):
# Compute prediction and loss
pred = model(X)
loss = loss_fn(pred, y)
# Backpropagation
optimizer.zero_grad()
loss.backward()
optimizer.step()
if batch % 100 == 0:
loss, current = loss.item(), batch * len(X)
print(f"loss: {
loss:>7f} [{
current:>5d}/{
size:>5d}]")
def test_loop(dataloader, model, loss_fn):
size = len(dataloader.dataset)
num_batches = len(dataloader)
test_loss, correct = 0, 0
with torch.no_grad():
for X, y in dataloader:
pred = model(X)
test_loss += loss_fn(pred, y).item()
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
test_loss /= num_batches
correct /= size
print(f"Test Error: \n Accuracy: {
(100*correct):>0.1f}%, Avg loss: {
test_loss:>8f} \n")
for t in range(epochs):
print(f"Epoch {
t+1}\n-------------------------------")
train_loop(train_dataloader, model, loss_fn, optimizer)
test_loop(test_dataloader, model, loss_fn)
print("Done!")
边栏推荐
- mysql uses on duplicate key update to update data in batches
- unity2d game
- 打造基于ILRuntime热更新的组件化开发
- MySQL fuzzy query can use INSTR instead of LIKE
- Heavyweight | The Open Atomic School Source Line activity was officially launched
- Unity Fighter
- 110 MySQL interview questions and answers (continuously updated)
- MySQL开窗函数
- 【debug锦集】Expected input batch_size (1) to match target batch_size (0)
- 开源汇智创未来 | 2022开放原子全球开源峰会OpenAtom openEuler分论坛圆满召开
猜你喜欢
sql语句-如何以一个表中的数据为条件据查询另一个表中的数据
【线性神经网络】softmax回归
12个MySQL慢查询的原因分析
SOLVED: After accidentally uninstalling pip (two ways to manually install pip)
Unity教程:URP渲染管线实战教程系列【1】
Industry-university-research application to build an open source talent ecosystem | 2022 Open Atom Global Open Source Summit Education Sub-Forum was successfully held
MySQL transaction (transaction) (this is enough..)
ERROR 2003 (HY000) Can't connect to MySQL server on 'localhost3306' (10061)
矩池云快速安装torch-sparse、torch-geometric等包
XSS靶场(三)prompt to win
随机推荐
Summary of MySQL common interview questions (recommended collection!!!)
城市内涝及桥洞隧道积水在线监测系统
【C语言】操作符详解
数字经济时代的开源数据库创新 | 2022开放原子全球开源峰会数据库分论坛圆满召开
SQL statement to range query time field
Doris学习笔记之监控
MySQL database addition, deletion, modification and query (detailed explanation of basic operation commands)
STM32HAL库修改Hal_Delay为us级延时
ERROR 1064 (42000) You have an error in your SQL syntax; check the manual that corresponds to your
SOLVED: After accidentally uninstalling pip (two ways to manually install pip)
Hand in hand to realize the picture preview plug-in (3)
Mysql应用安装后找不到my.ini文件
ES 源码 API调用链路源码分析
XSS靶场(三)prompt to win
Puzzle Game Level Design: Reverse Method--Explaining Puzzle Game Level Design
Go language study notes - dealing with timeout problems - Context usage | Go language from scratch
Create componentized development based on ILRuntime hot update
SQL语句中对时间字段进行区间查询
论治理与创新 | 2022开放原子全球开源峰会OpenAnolis分论坛圆满召开
重磅 | 开放原子校源行活动正式启动