Neural Network Implementation Example from the PyTorch Tutorial Introduction
2022-07-31 04:56:00 【KylinSchmidt】
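This is the neural network example code from the PyTorch Tutorial Introduction. For a detailed explanation in English, see https://pytorch.org/tutorials/beginner/basics/optimization_tutorial.html.
The dataset is FashionMNIST: 28×28 grayscale images of clothing items (T-shirts, trousers, sneakers and so on), split into 10 classes.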
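Each batch that train_dataloader yields in the script below is a pair (X, y): a float image tensor and a vector of integer class labels. As a minimal sketch of what such a batch looks like (assuming the batch_size=64 used in this post), the code builds a random stand-in batch instead of downloading FashionMNIST, and shows how nn.Flatten turns each 1×28×28 image into the 784-value vector that the first nn.Linear(28*28, 512) layer expects. The CLASSES list is typed out by hand for illustration; torchvision exposes the same names, in label order, as datasets.FashionMNIST.classes.

```python
import torch
from torch import nn

# FashionMNIST class names in label order (index 0..9), written out by hand
# for illustration; torchvision provides the same list as FashionMNIST.classes
CLASSES = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
           "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]

# Random stand-in for one (X, y) batch: ToTensor() produces float pixels in
# [0, 1] with a single grayscale channel, and labels are int64 class indices
torch.manual_seed(0)
X = torch.rand(64, 1, 28, 28)      # [batch, channel, height, width]
y = torch.randint(0, 10, (64,))    # one label per image

# nn.Flatten keeps the batch dimension and flattens everything else: 1*28*28 = 784
flat = nn.Flatten()(X)

print(X.shape, y.dtype, flat.shape)
print("label of first image:", CLASSES[y[0].item()])
```

With real data, replacing the two random tensors by `X, y = next(iter(train_dataloader))` should give the same shapes and dtypes.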
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)
train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
model = NeuralNetwork()
# Hyperparameters
learning_rate = 1e-3
batch_size = 64  # not used below: the DataLoaders above already pass batch_size=64
epochs = 10
# Initialize the loss function
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)  # stochastic gradient descent
def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    # No gradients are needed for evaluation
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            # Count predictions whose highest-scoring class matches the label
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches  # average loss per batch
    correct /= size           # fraction of correctly classified images
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(test_dataloader, model, loss_fn)
print("Done!")
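train_loop repeats the same three calls for every batch: optimizer.zero_grad(), loss.backward(), optimizer.step(). As a minimal sketch of what one such iteration does, the code below uses a hypothetical tiny nn.Linear(4, 3) model and a random batch in place of NeuralNetwork and FashionMNIST, so nothing is downloaded. It checks two things: that nn.CrossEntropyLoss applied to raw logits equals log_softmax followed by F.nll_loss (which is why forward() returns logits with no softmax layer), and that torch.optim.SGD with its default settings (no momentum, no weight decay, as in this post) updates each weight as w - lr * grad.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical tiny model and batch standing in for NeuralNetwork and one
# (X, y) batch from train_dataloader
model = nn.Linear(4, 3)
X = torch.randn(8, 4)
y = torch.randint(0, 3, (8,))
lr = 1e-3

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

# Forward pass: raw, unnormalized class scores (logits)
logits = model(X)
loss = loss_fn(logits, y)

# Same value computed by hand: log-softmax over classes, then negative log-likelihood
manual = F.nll_loss(F.log_softmax(logits, dim=1), y)

# Snapshot the weights before the update
w_before = model.weight.detach().clone()

optimizer.zero_grad()  # clear gradients accumulated by any previous batch
loss.backward()        # populate .grad for every parameter
grad = model.weight.grad.detach().clone()
optimizer.step()       # plain SGD update: w <- w - lr * grad

w_expected = w_before - lr * grad
print(torch.allclose(loss, manual), torch.allclose(model.weight, w_expected))
```

Since CrossEntropyLoss already applies log-softmax internally, appending nn.Softmax to linear_relu_stack would normalize the scores twice; to get class probabilities from the trained model, apply torch.softmax(logits, dim=1) to its output instead.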