当前位置:网站首页>深度学习 pytorch cifar10数据集训练「建议收藏」
深度学习 pytorch cifar10数据集训练「建议收藏」
2022-06-25 15:35:00 【全栈程序员站长】
大家好,又见面了,我是你们的朋友全栈君。
1.加载数据集,并对数据集进行增强,类型转换 官网cifar10数据集 附链接:https://www.cs.toronto.edu/~kriz/cifar.html
读取数据过程中,可以改变batch_size和num_workers来加快训练速度
transform=transforms.Compose([
#图像增强
transforms.Resize(120),
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(96),
transforms.ColorJitter(brightness=0.5,contrast=0.5,hue=0.5),
#转变为tensor 正则化
transforms.ToTensor(),
transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)) #正则化
])
trainset=tv.datasets.CIFAR10(
root=r'E:\桌面\资料\cv3\数据集\cifar-10-batches-py',
train=True,
download=True,
transform=transform
)
trainloader=data.DataLoader(
trainset,
batch_size=8,
shuffle=True, #乱序
num_workers=4,
)
testset=tv.datasets.CIFAR10(
root=r'E:\桌面\资料\cv3\数据集\cifar-10-batches-py',
train=False,
download=True,
transform=transform
)
testloader=data.DataLoader(
testset,
batch_size=2,
shuffle=False,
num_workers=2
)net网络:
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1=nn.Conv2d(in_channels=3,out_channels=6,kernel_size=5)
self.conv2=nn.Conv2d(in_channels=6,out_channels=16,kernel_size=5)
self.max=nn.MaxPool2d(2,2)
self.q1=nn.Linear(16*441,120)
self.q2=nn.Linear(120,84)
self.q3=nn.Linear(84,10)
self.relu=nn.ReLU()
def forward(self,x):
x1=self.max(F.relu(self.conv1(x)))
x2=F.max_pool2d(self.relu(self.conv2(x1)),2)
x3=x2.view(x2.size()[0],-1)
x4=F.relu(self.q1(x3))
x5=F.relu(self.q2(x4))
x6=self.q3(x5)
return x6训练模型
net=Net()
#损失函数
loss=nn.CrossEntropyLoss()
opt=optim.SGD(net.parameters(),lr=0.001)
for epoch in range(5):
running_loss=0.0
for i,data in enumerate(trainloader,0):
inputs,labels=data
inputs=inputs.cuda()
labels=labels.cuda()
inputs,labels=Variable(inputs),Variable(labels)
opt.zero_grad()
net.to(torch.device('cuda:0'))
h=net(inputs)
cost=loss(h,labels)
cost.backward()
opt.step()
running_loss+=cost.item()
if i%2000==1999:
print('[%d,%5d] loss:%.3f' %(epoch+1,i+1,running_loss/2000))
running_loss=0.0
torch.save(net.state_dict(),r'net.pth')
correct=0
total=0
for data in testloader:
images,labels=data
optputs=net(Variable(images.cuda()))
_,predicted=torch.max(optputs.cpu(),1)
total+=labels.size(0)
correct+=(predicted==labels).sum()
print("准确率: %d %%" %(100*correct/total))接下来可以直接进行训练
在运行过程中会出现虚拟内存不够的情况,可以调整虚拟内存大小,解决这一问题。
发布者:全栈程序员栈长,转载请注明出处:https://javaforall.cn/152102.html原文链接:https://javaforall.cn
边栏推荐
- Simulating Sir disease transmission model with netlogo
- Client development (electron) data store
- Rapport de la main - d'oeuvre du Conseil de développement de l'aecg air32f103cbt6
- 剑指 Offer II 091. 粉刷房子
- Is Guoxin golden sun reliable? Is it legal? Is it safe to open a stock account?
- Open the box to experience rust, come on!!!
- Detailed description of crontab command format and summary of common writing methods
- Built in methods for data types
- 合宙Air32F103CBT6開發板上手報告
- Mark the same items in the Li list in red
猜你喜欢

Sword finger offer II 091 Paint the house
After the project is pushed to the remote warehouse, Baota webhook automatically publishes it

解析数仓lazyagg查询重写优化

Highly concurrent optimized Lua + openresty+redis +mysql (multi-level cache implementation) + current limit +canal synchronization solution
Desktop development (Tauri) opens the first chapter
Client development (electron) system level API usage

Brief object memory layout
Client development (electron) system level API usage 2

剑指 Offer II 091. 粉刷房子

不要再「外包」AI 模型了!最新研究发现:有些破坏机器学习模型安全的「后门」无法被检测到
随机推荐
How to convert a recorded DOM to a video file
Yadali brick playing game based on deep Q-learning
JS中的==和===的区别(详解)
Brief object memory layout
Talk about the creation process of JVM objects
剑指 Offer 04. 二维数组中的查找
MySQL modifier l'instruction de champ
Day01: learning notes
基于深度Q学习的雅达利打砖块游戏博弈
到底要不要去外包公司?这篇带你全面了解外包那些坑!
Client development (electron) system level API usage
Multithreading, parallelism, concurrency, thread safety
Detailed summary of reasons why alertmanager fails to send alarm messages at specified intervals / irregularly
Sword finger offer 09 Implementing queues with two stacks
TFIDF与BM25
Kali modify IP address
剑指 Offer 07. 重建二叉树
国信金太阳靠谱吗?是否合法?开股票账户安全吗?
Go build reports an error missing go sum entry for module providing package ... to add:
[paper notes] contextual transformer networks for visual recognition