Deep Learning: Training on the CIFAR-10 Dataset with PyTorch
2022-06-25 15:35:00 【全栈程序员站长】
Hello everyone, good to see you again. I am your friend 全栈君.
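1. Load the dataset, apply data augmentation, and convert the images to tensors. The official CIFAR-10 dataset is available at https://www.cs.toronto.edu/~kriz/cifar.html
When reading the data, batch_size and num_workers can be adjusted to speed up training.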
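The type-conversion step in the pipeline that follows is ToTensor and then Normalize with a mean and standard deviation of 0.5 per channel. As a minimal sketch of the arithmetic only (plain Python, not the torchvision implementation; the helper name `normalize` is made up for illustration), this shows how 8-bit pixel values end up in the range [-1, 1]:

```python
def normalize(value, mean=0.5, std=0.5):
    """Per-channel arithmetic applied by Normalize: (value - mean) / std."""
    return (value - mean) / std


# ToTensor scales 8-bit pixel values from [0, 255] to [0.0, 1.0];
# Normalize((0.5, ...), (0.5, ...)) then maps [0.0, 1.0] to [-1.0, 1.0].
for pixel in (0, 128, 255):
    scaled = pixel / 255
    print(pixel, round(scaled, 3), round(normalize(scaled), 3))
```

This step is normalization of the input values, which is a different thing from regularization of the model.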
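import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision as tv
import torchvision.transforms as transforms

transform = transforms.Compose([
    # Image augmentation
    transforms.Resize(120),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(96),
    transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),
    # Convert to a tensor, then normalize
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # normalization, not regularization
])
trainset = tv.datasets.CIFAR10(
    root=r'E:\桌面\资料\cv3\数据集\cifar-10-batches-py',
    train=True,
    download=True,
    transform=transform,
)
trainloader = data.DataLoader(
    trainset,
    batch_size=8,
    shuffle=True,  # shuffle the training samples
    # On Windows, num_workers > 0 requires the script's entry point
    # to be wrapped in `if __name__ == '__main__':`
    num_workers=4,
)
testset = tv.datasets.CIFAR10(
    root=r'E:\桌面\资料\cv3\数据集\cifar-10-batches-py',
    train=False,
    download=True,
    # Note: this reuses the random training augmentations; a fixed
    # Resize + CenterCrop(96) is more usual for evaluation
    transform=transform,
)
testloader = data.DataLoader(
    testset,
    batch_size=2,
    shuffle=False,
    num_workers=2,
)
The network (Net):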
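The first fully connected layer in this network takes 16*441 input features. That number follows from the 96x96 crop passing through two 5x5 convolutions (no padding, stride 1) and two 2x2 max-pool steps. A small sketch of the size arithmetic in plain Python (the helper names `conv_out`, `pool_out` and `flattened_features` are made up for illustration):

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output width/height of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel=2, stride=2):
    """Output width/height of a max-pool along one spatial dimension."""
    return (size - kernel) // stride + 1


def flattened_features(input_size, channels=16):
    """Feature count after conv5 -> pool2 -> conv5 -> pool2, flattened."""
    size = pool_out(conv_out(input_size, 5))  # 96 -> 92 -> 46
    size = pool_out(conv_out(size, 5))        # 46 -> 42 -> 21
    return channels * size * size


print(flattened_features(96))  # 7056, i.e. 16 * 441
print(flattened_features(32))  # 400, i.e. 16 * 25 for unresized CIFAR-10 images
```

If the Resize or RandomCrop sizes in the transform are changed, the input size of the first Linear layer has to be recomputed the same way.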
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.max = nn.MaxPool2d(2, 2)
        # A 96x96 input leaves 16 channels of 21x21, i.e. 16*441 features
        self.q1 = nn.Linear(16 * 441, 120)
        self.q2 = nn.Linear(120, 84)
        self.q3 = nn.Linear(84, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x1 = self.max(F.relu(self.conv1(x)))
        x2 = F.max_pool2d(self.relu(self.conv2(x1)), 2)
        x3 = x2.view(x2.size()[0], -1)  # flatten to (batch, features)
        x4 = F.relu(self.q1(x3))
        x5 = F.relu(self.q2(x4))
        x6 = self.q3(x5)  # raw class scores (logits), no softmax
        return x6
Training the model:
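The training code uses nn.CrossEntropyLoss, which works on the raw scores returned by the network and applies log-softmax internally; that is why forward has no softmax layer. A minimal sketch of the per-sample computation in plain Python (an illustration of the formula, not PyTorch's implementation; the function name `cross_entropy` is made up here):

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of one sample from raw logits: -log(softmax(logits)[target])."""
    peak = max(logits)  # subtract the maximum for numerical stability
    log_sum_exp = peak + math.log(sum(math.exp(v - peak) for v in logits))
    return log_sum_exp - logits[target]


# With ten equal scores every class has probability 1/10, so the loss is ln(10).
print(round(cross_entropy([0.0] * 10, target=3), 3))  # 2.303
```

A freshly initialized 10-class network often starts with a loss of roughly this size, so a first printed loss near 2.3 is not by itself a sign of a problem.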
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net = Net().to(device)  # move the model once, before the loop
# Loss function
loss = nn.CrossEntropyLoss()
opt = optim.SGD(net.parameters(), lr=0.001)
for epoch in range(5):
    running_loss = 0.0
    # `batch` avoids shadowing the `data` module imported above
    for i, batch in enumerate(trainloader, 0):
        inputs, labels = batch
        inputs = inputs.to(device)
        labels = labels.to(device)
        opt.zero_grad()
        h = net(inputs)
        cost = loss(h, labels)
        cost.backward()
        opt.step()
        running_loss += cost.item()
        if i % 2000 == 1999:  # report the mean loss every 2000 batches
            print('[%d,%5d] loss:%.3f' % (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0
torch.save(net.state_dict(), r'net.pth')
correct = 0
total = 0
net.eval()
with torch.no_grad():  # gradients are not needed for evaluation
    for batch in testloader:
        images, labels = batch
        outputs = net(images.to(device))
        _, predicted = torch.max(outputs.cpu(), 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print("Accuracy: %d %%" % (100 * correct / total))
With this in place, training can be started directly.
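While running, you may hit an out-of-virtual-memory error; increasing the virtual memory (page file) size resolves it. Lowering num_workers may also help, since each DataLoader worker runs as a separate process.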
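When reading the training log, it helps to know how many progress lines to expect per epoch. The sketch below assumes the 50,000-image CIFAR-10 training set, the batch size of 8 used above, and a DataLoader that keeps the last partial batch (the default); the helper name `report_steps` is made up for illustration:

```python
import math


def report_steps(num_samples, batch_size, report_every):
    """Return iterations per epoch, the steps that print, and the unreported tail."""
    iterations = math.ceil(num_samples / batch_size)
    reports = [i + 1 for i in range(iterations)
               if i % report_every == report_every - 1]
    unreported = iterations - (reports[-1] if reports else 0)
    return iterations, reports, unreported


iterations, reports, unreported = report_steps(50_000, 8, 2000)
print(iterations, reports, unreported)  # 6250 [2000, 4000, 6000] 250
```

Because running_loss is reset at the start of every epoch, the loss of the last 250 batches in each epoch is never printed. A larger batch_size reduces the iteration count, so the reporting interval of 2000 may need to shrink with it.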
Publisher: 全栈程序员栈长. Please credit the source when reposting: https://javaforall.cn/152102.html Original link: https://javaforall.cn