当前位置:网站首页>Pytorch教程Introduction中的神经网络实现示例
Pytorch教程Introduction中的神经网络实现示例
2022-07-31 04:56:00 【KylinSchmidt】
Pytorch Turtorial Introduction中的神经网络示例代码,英文详细介绍参见https://pytorch.org/tutorials/beginner/basics/optimization_tutorial.html。
数据集为衣服等的灰度图片。
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor()
)
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor()
)
train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)
class NeuralNetwork(nn.Module):
def __init__(self):
super(NeuralNetwork, self).__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28*28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 10),
)
def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
model = NeuralNetwork()
# 超参数设置
learning_rate = 1e-3
batch_size = 64
epochs = 10
# Initialize the loss function
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) # 随机梯度下降算法
def train_loop(dataloader, model, loss_fn, optimizer):
size = len(dataloader.dataset)
for batch, (X, y) in enumerate(dataloader):
# Compute prediction and loss
pred = model(X)
loss = loss_fn(pred, y)
# Backpropagation
optimizer.zero_grad()
loss.backward()
optimizer.step()
if batch % 100 == 0:
loss, current = loss.item(), batch * len(X)
print(f"loss: {
loss:>7f} [{
current:>5d}/{
size:>5d}]")
def test_loop(dataloader, model, loss_fn):
size = len(dataloader.dataset)
num_batches = len(dataloader)
test_loss, correct = 0, 0
with torch.no_grad():
for X, y in dataloader:
pred = model(X)
test_loss += loss_fn(pred, y).item()
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
test_loss /= num_batches
correct /= size
print(f"Test Error: \n Accuracy: {
(100*correct):>0.1f}%, Avg loss: {
test_loss:>8f} \n")
for t in range(epochs):
print(f"Epoch {
t+1}\n-------------------------------")
train_loop(train_dataloader, model, loss_fn, optimizer)
test_loop(test_dataloader, model, loss_fn)
print("Done!")
边栏推荐
- sql语句-如何以一个表中的数据为条件据查询另一个表中的数据
- MySQL事务(transaction) (有这篇就足够了..)
- 开源汇智创未来 | 2022开放原子全球开源峰会OpenAtom openEuler分论坛圆满召开
- Lua,ILRuntime, HybridCLR(wolong)/huatuo热更新对比分析
- Unity shader forge和自带的shader graph,有哪些优缺点?
- ERROR 2003 (HY000) Can't connect to MySQL server on 'localhost3306' (10061)
- Lua,ILRuntime, HybridCLR(wolong)/huatuo hot update comparative analysis
- 扫雷小游戏——C语言
- 论治理与创新 | 2022开放原子全球开源峰会OpenAnolis分论坛圆满召开
- Unity框架设计系列:Unity 如何设计网络框架
猜你喜欢

npm、nrm两种方式查看源和切换镜像

ENSP, VLAN division, static routing, comprehensive configuration of Layer 3 switches

HCIP第十天_BGP路由汇总实验

On-line monitoring system for urban waterlogging and water accumulation in bridges and tunnels

From scratch, a mirror to the end, a pure system builds a grasscutter (Grasscutter)

MySQL常见面试题汇总(建议收藏!!!)

ES source code API call link source code analysis

Puzzle Game Level Design: Reverse Method--Explaining Puzzle Game Level Design

MySQL开窗函数

prompt.ml/15中<svg>标签使用解释
随机推荐
Interview | Cheng Li, CTO of Alibaba: Cloud + open source together form a credible foundation for the digital world
CentOS7 —— yum安装mysql
MySQL优化之慢日志查询
.NET-9.乱七八糟的理论笔记(概念,思想)
Three oj questions on leetcode
Centos7 install mysql5.7
Lua,ILRuntime, HybridCLR(wolong)/huatuo hot update comparative analysis
SQL行列转换
Industry-university-research application to build an open source talent ecosystem | 2022 Open Atom Global Open Source Summit Education Sub-Forum was successfully held
Minesweeper game (written in c language)
MySQL optimization slow log query
MySQL事务(transaction) (有这篇就足够了..)
SQL row-column conversion
开源汇智创未来 | 2022开放原子全球开源峰会OpenAtom openEuler分论坛圆满召开
sql语句之多表查询
The MySQL database installed configuration nanny level tutorial for 8.0.29 (for example) have hands
益智类游戏关卡设计:逆推法--巧解益智类游戏关卡设计
Musk talks to the "virtual version" of Musk, how far is the brain-computer interaction technology from us
CentOS7 安装MySQL 图文详细教程
SQL语句中对时间字段进行区间查询