A Neural Network Implementation Example from the PyTorch Tutorial Introduction
2022-07-31 04:56:00 【KylinSchmidt】
This is the neural network example code from the Introduction of the PyTorch Tutorial; a detailed English walkthrough is available at https://pytorch.org/tutorials/beginner/basics/optimization_tutorial.html.
The dataset is FashionMNIST, which consists of grayscale images of clothing and similar items.
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)
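# Sanity check (an addition to the original snippet): inspect the shape of one
# batch from the DataLoader -- 64 single-channel 28x28 images and 64 integer labels.
for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")  # torch.Size([64, 1, 28, 28])
    print(f"Shape of y: {y.shape} {y.dtype}")     # torch.Size([64]) torch.int64
    break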
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork()
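# Optional (not in the original snippet): printing the model shows the layer
# structure of the flatten + three-Linear-layer stack defined above.
print(model)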
# Hyperparameter settings
learning_rate = 1e-3
batch_size = 64
epochs = 10

# Initialize the loss function
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)  # stochastic gradient descent (SGD)
def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(test_dataloader, model, loss_fn)
print("Done!")