当前位置:网站首页>Pytorch教程Introduction中的神经网络实现示例
Pytorch教程Introduction中的神经网络实现示例
2022-07-31 04:56:00 【KylinSchmidt】
Pytorch Turtorial Introduction中的神经网络示例代码,英文详细介绍参见https://pytorch.org/tutorials/beginner/basics/optimization_tutorial.html。
数据集为衣服等的灰度图片。
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor()
)
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor()
)
train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)
class NeuralNetwork(nn.Module):
def __init__(self):
super(NeuralNetwork, self).__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28*28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 10),
)
def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
model = NeuralNetwork()
# 超参数设置
learning_rate = 1e-3
batch_size = 64
epochs = 10
# Initialize the loss function
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) # 随机梯度下降算法
def train_loop(dataloader, model, loss_fn, optimizer):
size = len(dataloader.dataset)
for batch, (X, y) in enumerate(dataloader):
# Compute prediction and loss
pred = model(X)
loss = loss_fn(pred, y)
# Backpropagation
optimizer.zero_grad()
loss.backward()
optimizer.step()
if batch % 100 == 0:
loss, current = loss.item(), batch * len(X)
print(f"loss: {
loss:>7f} [{
current:>5d}/{
size:>5d}]")
def test_loop(dataloader, model, loss_fn):
size = len(dataloader.dataset)
num_batches = len(dataloader)
test_loss, correct = 0, 0
with torch.no_grad():
for X, y in dataloader:
pred = model(X)
test_loss += loss_fn(pred, y).item()
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
test_loss /= num_batches
correct /= size
print(f"Test Error: \n Accuracy: {
(100*correct):>0.1f}%, Avg loss: {
test_loss:>8f} \n")
for t in range(epochs):
print(f"Epoch {
t+1}\n-------------------------------")
train_loop(train_dataloader, model, loss_fn, optimizer)
test_loop(test_dataloader, model, loss_fn)
print("Done!")
边栏推荐
- EasyExcel的简单读取操作
- .NET-6.WinForm2.NanUI learning and summary
- Error EPERM operation not permitted, mkdir 'Dsoftwarenodejsnode_cache_cacach Two solutions
- Solved (the latest version of selenium framework element positioning error) NameError: name 'By' is not defined
- MySQL忘记密码怎么办
- Minesweeper game (written in c language)
- SQL语句中对时间字段进行区间查询
- DVWA靶场环境搭建
- Unity手机游戏性能优化系列:针对CPU端的性能调优
- 行业落地呈现新进展 | 2022开放原子全球开源峰会OpenAtom OpenHarmony分论坛圆满召开
猜你喜欢

Unity URP渲染管线摄像机核心机制剖析

Simple read operation of EasyExcel

Hand in hand to realize the picture preview plug-in (3)

矩池云快速安装torch-sparse、torch-geometric等包

Lua,ILRuntime, HybridCLR(wolong)/huatuo热更新对比分析

MySQL window function

View source and switch mirrors in two ways: npm and nrm

SQL injection of DVWA

On-line monitoring system for urban waterlogging and water accumulation in bridges and tunnels

益智类游戏关卡设计:逆推法--巧解益智类游戏关卡设计
随机推荐
打造基于ILRuntime热更新的组件化开发
sql statement - how to query data in another table based on the data in one table
mysql uses on duplicate key update to update data in batches
开源汇智创未来 | 2022开放原子全球开源峰会OpenAtom openEuler分论坛圆满召开
Lua,ILRuntime, HybridCLR(wolong)/huatuo hot update comparative analysis
数字经济时代的开源数据库创新 | 2022开放原子全球开源峰会数据库分论坛圆满召开
Unity教程:URP渲染管线实战教程系列【1】
Heavyweight | The Open Atomic School Source Line activity was officially launched
STM32 - DMA
Unity resources management series: Unity framework how to resource management
重磅 | 开放原子校源行活动正式启动
sql语句之多表查询
MySQL database backup
Hand in hand to realize the picture preview plug-in (3)
ABC D - Distinct Trio(k元组的个数
SQL statement to range query time field
input输入框展示两位小数之precision
30 Years of Open Source Community | 2022 Open Atom Global Open Source Summit 30 Years of Open Source Community Special Event Held Successfully
MySQL优化之慢日志查询
专访 | 阿里巴巴首席技术官程立:云+开源共同形成数字世界的可信基础