A neural network implementation example from the PyTorch tutorial Introduction
2022-07-31 04:56:00 【KylinSchmidt】
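Neural network example code from the PyTorch Tutorials Introduction. For a detailed explanation in English, see https://pytorch.org/tutorials/beginner/basics/optimization_tutorial.html.
The dataset is FashionMNIST: 28×28 grayscale images of clothing, shoes, and bags, labeled with 10 classes.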
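FashionMNIST has 60,000 training images and 10,000 test images. With `batch_size=64` and the `DataLoader` default `drop_last=False`, the last batch of each pass is smaller than the rest. The pure-Python sketch below works out that arithmetic; `batch_sizes` is a hypothetical helper for illustration, not a PyTorch function.

```python
def batch_sizes(num_samples, batch_size):
    """Sizes of the batches a DataLoader yields when drop_last=False (the default)."""
    full, rest = divmod(num_samples, batch_size)
    return [batch_size] * full + ([rest] if rest else [])

train_batches = batch_sizes(60_000, 64)  # FashionMNIST training split
test_batches = batch_sizes(10_000, 64)   # FashionMNIST test split
print(len(train_batches), train_batches[-1])  # 938 32
print(len(test_batches), test_batches[-1])    # 157 16

# Batch indices at which train_loop prints its progress (batch % 100 == 0)
print([b for b in range(len(train_batches)) if b % 100 == 0])
```

This also explains two details of the script: `train_loop` prints at batches 0, 100, ..., 900, and `test_loop` divides the summed loss by 157 batches, so the final 16-sample batch counts as much as a full batch in the reported average loss.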
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)
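
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        # Turn each 1x28x28 image into a flat vector of 784 values
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),  # one output score (logit) per class
        )

    def forward(self, x):
        x = self.flatten(x)
        # Raw logits, no softmax: nn.CrossEntropyLoss applies log-softmax itself
        logits = self.linear_relu_stack(x)
        return logits
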
model = NeuralNetwork()

# Hyperparameters
learning_rate = 1e-3
batch_size = 64  # matches the batch_size=64 passed to the DataLoaders above
epochs = 10

# Initialize the loss function
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)  # stochastic gradient descent

def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")

def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    # No gradients are needed for evaluation
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            # Count samples whose highest logit is the true label
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches  # average of the per-batch mean losses
    correct /= size           # fraction of test samples classified correctly
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(test_dataloader, model, loss_fn)
print("Done!")
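To get a sense of the model's size, the parameters of its three `nn.Linear` layers can be counted by hand: each layer has an out×in weight matrix plus one bias per output, while `nn.Flatten` and `nn.ReLU` have no parameters. A pure-Python sketch; `linear_param_count` is an illustrative helper name.

```python
# Layer widths of NeuralNetwork: 784 -> 512 -> 512 -> 10
layer_sizes = [28 * 28, 512, 512, 10]

def linear_param_count(n_in, n_out):
    """nn.Linear(n_in, n_out) holds an n_out x n_in weight matrix plus n_out biases."""
    return n_in * n_out + n_out

counts = [linear_param_count(a, b) for a, b in zip(layer_sizes, layer_sizes[1:])]
total = sum(counts)
print(counts, total)  # [401920, 262656, 5130] 669706
```

In PyTorch itself, `sum(p.numel() for p in model.parameters())` on the model above should report the same total.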
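The model returns raw logits because `nn.CrossEntropyLoss` applies log-softmax internally and then takes the negative log-likelihood of the target class, averaging over the batch by default (`reduction='mean'`). Per sample this is log(Σ_j exp(z_j)) − z_target. The plain-Python sketch below reproduces that formula for illustration only; it is not how PyTorch computes it internally.

```python
import math

def cross_entropy(logits_batch, targets):
    """Mean over the batch of -log(softmax(logits)[target]), like reduction='mean'."""
    total = 0.0
    for logits, target in zip(logits_batch, targets):
        m = max(logits)  # subtract the max before exp() for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
        total += log_sum_exp - logits[target]  # -log softmax at the target class
    return total / len(targets)

# Two samples, three classes, true classes 0 and 2
print(cross_entropy([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]], [0, 2]))

# All-equal logits over 10 classes give exactly ln(10)
print(cross_entropy([[0.0] * 10], [3]), math.log(10))
```

With near-uniform logits over 10 classes the loss is about ln 10 ≈ 2.303, so the first losses printed by `train_loop` for the untrained network tend to be close to that value.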
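`train_loop` follows the usual three-step update: `optimizer.zero_grad()` clears the previous gradients (PyTorch accumulates gradients across `backward()` calls), `loss.backward()` computes new ones, and `optimizer.step()` applies plain SGD, w ← w − lr · ∂loss/∂w. The toy sketch below mirrors those steps in pure Python for a single weight fitted so that w·x ≈ y, with the squared-error gradient written out by hand; it is a hypothetical example, not the FashionMNIST model.

```python
def fit_single_weight(xs, ys, lr=0.05, epochs=50):
    """Per-sample SGD on the loss (w*x - y)**2, mirroring the steps in train_loop."""
    w = 0.0
    for _ in range(epochs):
        for x, y in zip(xs, ys):
            grad = 0.0                  # like optimizer.zero_grad(): drop the old gradient
            pred = w * x                # forward pass, like model(X)
            grad += 2 * (pred - y) * x  # like loss.backward(): d/dw of (w*x - y)**2
            w -= lr * grad              # like optimizer.step(): w <- w - lr * grad
    return w

print(fit_single_weight([1.0, 2.0, 3.0], [2.0, 4.0, 6.0]))  # converges towards 2.0
```

The learning rate matters here just as `learning_rate` does in the script: for this data, a value much larger than 0.1 would make the update overshoot on x = 3 and diverge.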
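In `test_loop`, `pred.argmax(1)` picks the highest-scoring class for each row of logits, comparing it with `y` marks the correct predictions, and the count is summed over all batches before being divided by the dataset size. A plain-Python sketch of the same computation; the function names are illustrative.

```python
def argmax(values):
    """Index of the largest value; the first one wins on ties."""
    return max(range(len(values)), key=lambda i: values[i])

def accuracy(batches):
    """batches: iterable of (logits_batch, targets) pairs, like iterating a DataLoader."""
    correct, size = 0, 0
    for logits_batch, targets in batches:
        correct += sum(argmax(row) == t for row, t in zip(logits_batch, targets))
        size += len(targets)
    return correct / size

batches = [
    ([[0.1, 2.0, -1.0], [3.0, 0.0, 0.5]], [1, 2]),  # predictions 1 and 0: one correct
    ([[0.2, 0.1, 0.9]], [2]),                        # prediction 2: correct
]
print(f"Accuracy: {100 * accuracy(batches):>0.1f}%")  # Accuracy: 66.7%
```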