当前位置:网站首页>PyTorch 代码模板 (CNN)
PyTorch 代码模板 (CNN)
2022-07-25 09:27:00 【Haulyn5】
前言
这是一篇自用的 PyTorch 代码模板,将模型和数据的相关代码进行替换就可以训练新的模型和数据,因为自己有时候需要测试写一些代码,但是从头写又记不住,直接在这里放一份模板,用的时候改一改就好了。
主要参考资料:
第一个是一个叫做 Aladdin Persson 的 youtuber,视频内容就是带着你手敲代码,受益匪浅,这里的代码主要也源自视频中的代码。https://www.youtube.com/watch?v=wnK3uWv_WkU&t=269s
https://www.youtube.com/watch?v=wnK3uWv_WkU&t=269sQuickstart — PyTorch Tutorials 1.11.0+cu102 documentation
https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html
正文
在 Colab上做了测试,可以正常运行。
torch.__version__ : 1.10.0+cu111
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
# 模型代码编写,forward 就是直接调用 model(x) 时执行的计算流程
class CNN(nn.Module):
def __init__(self, in_channels=1, num_classes=10):
super(CNN,self).__init__()
self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=8, kernel_size=(3,3), stride=(1,1), padding=(1,1))
self.pool = nn.MaxPool2d(kernel_size=(2,2), stride=(2,2))
self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=(3,3), stride=(1,1), padding=(1,1))
self.fc1 = nn.Linear(16*7*7, num_classes)
def forward(self, x):
x = F.relu(self.conv1(x))
x = self.pool(x)
x = F.relu(self.conv2(x))
x = self.pool(x)
x = x.reshape(x.shape[0], -1)
x = self.fc1(x)
return x
# 测试能否使用 GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 超参数设定
input_size = 784
in_channel = 1
num_classes = 10
learning_rate = 0.001
batch_size = 64
num_epochs = 5
# 读取数据集
train_dataset = datasets.MNIST(root='dataset/', train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = datasets.MNIST(root='dataset/', train=False, transform=transforms.ToTensor(), download=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
# 实例化模型
model = CNN().to(device)
# 设定损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# 下面部分是训练,有时可以单独拿出来写函数
num_epochs = 4
for epoch in range(num_epochs):
for batch_idex, (data, targets) in enumerate(train_loader):
# 如果模型在 GPU 上,数据也要读入 GPU
data = data.to(device=device)
targets = targets.to(device=device)
# print(data.shape) # [64,1,28,28] Batch 大小 64 , 1 channel, 28*28 像素
# forward 前向模型计算输出,然后根据输出算损失
scores = model(data)
loss = criterion(scores, targets)
# backward 反向传播计算梯度
optimizer.zero_grad()
loss.backward()
# 梯度下降,优化参数
optimizer.step()
# 评估准确度的函数
def check_accuracy(loader, model):
if loader.dataset.train:
print("Checking acc on training data")
else:
print("Checking acc on testing data")
num_correct = 0
num_samples = 0
model.eval() # 将模型调整为 eval 模式,具体可以搜一下区别
with torch.no_grad():
for x, y in loader:
x = x.to(device=device)
y = y.to(device=device)
scores = model(x)
# 64*10
_, predictions = scores.max(1)
num_correct += (predictions == y).sum()
num_samples += predictions.size(0)
acc = float(num_correct)/float(num_samples)*100
print(f'Got {num_correct} / {num_samples} with accuracy {acc:.2f}')
model.train()
return acc
check_accuracy(train_loader, model)
check_accuracy(test_loader, model) 后记
其实在 GitHub 看到了另一个工程化的模板,将不同的部分都用一个单独的 python 脚本文件分离了,有机会好好研究吧。
边栏推荐
- Subtotal of rospy odometry sinkhole
- NPM details
- UE4 LoadingScreen动态加载启动动画
- 多线程——Runnable接口,龟兔赛跑
- JSP详解
- 力扣刷题组合问题总结(回溯)
- [necessary for growth] Why do I recommend you to write a blog? May you be what you want to be in years to come.
- 多线程——死锁和synchronized
- Pytorch 张量列表转换为张量 List of Tensor to Tensor 使用 torch.stack()
- 多线程——Callable接口,lambda
猜你喜欢

Probabilistic robot learning notes Chapter 2

【专栏】RPC系列(理论)-夜的第一章
![[nearly 10000 words dry goods] don't let your resume don't match your talent -- teach you to make the most suitable resume by hand](/img/2d/e3a326175f04826b9d9c96baedc3a5.png)
[nearly 10000 words dry goods] don't let your resume don't match your talent -- teach you to make the most suitable resume by hand

OSPF协议的配置(以华为eNSP为例)

Probability theory and mathematical statistics 4 continuous random variables and probability distributions (Part 1)
![[necessary for growth] Why do I recommend you to write a blog? May you be what you want to be in years to come.](/img/f5/e6739083f0dce8da1d09d078321633.png)
[necessary for growth] Why do I recommend you to write a blog? May you be what you want to be in years to come.

Subtotal of rospy odometry sinkhole
![[recommended collection] with these learning methods, I joined the world's top 500 - the](/img/95/e34473a1628521d4b07e56877fcff1.png)
[recommended collection] with these learning methods, I joined the world's top 500 - the "fantastic skills and extravagance" in the Internet age

nodejs链接mysql报错:ER_NOT_SUPPORTED_AUTH_MODEError: ER_NOT_SUPPORTED_AUTH_MODE

Detailed explanation of MySQL database
随机推荐
Swing组件
Introduction to armv8 architecture
Qt 6.2的下载和安装
@Import, conditional and @importresource annotations
CentOs安装redis
拷贝过来老的项目变成web项目
NPM details
关于slf4j log4j log4j2的jar包配合使用的那些事
T5 paper summary
[nearly 10000 words dry goods] don't let your resume don't match your talent -- teach you to make the most suitable resume by hand
ROS distributed operation -- launch file starts nodes on multiple machines
数论--约数研究
message from server: “Host ‘xxx.xxx.xxx.xxx‘ is not allowed to connect to this MySQL server“
mysql历史数据补充新数据
TCP传输
Detailed explanation of JDBC operation database
腾讯云之错误[100007] this env is not enable anonymous login
Probability theory and mathematical statistics 4 continuous random variables and probability distributions (Part 1)
Redux使用和剖析
Summary of most consistency problems