当前位置:网站首页>机器学习之卷积神经网络Lenet5训练模型
机器学习之卷积神经网络Lenet5训练模型
2022-06-28 15:49:00 【华为云】
Lenet5训练模型
下载数据集

可以提前下载也可以在线下载
train_data = torchvision.datasets.MNIST(root='./',download=True,train=True,transform=transform)test_data = torchvision.datasets.MNIST(root='./',download=True,train=False,transform=transform)训练模型
import torchimport torchvisionclass Lenet5(torch.nn.Module): def __init__(self): super(Lenet5, self).__init__() self.model = torch.nn.Sequential( torch.nn.Conv2d(in_channels=1,out_channels=6,kernel_size=5), # 1*32*32 # 6*28*28 torch.nn.ReLU(), torch.nn.MaxPool2d(kernel_size=2, # 6*14*14 stride=2), torch.nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5), # 16 *10*10 torch.nn.ReLU(), torch.nn.MaxPool2d(kernel_size=2, stride=2), # 16*5*5 torch.nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5), # 120*1*1 torch.nn.ReLU(), torch.nn.Flatten(), # 展平 成120*1维 torch.nn.Linear(120, 84), torch.nn.Linear(84, 10) ) def forward(self,x): x = self.model(x) return xtransform = torchvision.transforms.Compose( [torchvision.transforms.Resize(32), torchvision.transforms.ToTensor()])train_data = torchvision.datasets.MNIST(root='./',download=True,train=True,transform=transform)test_data = torchvision.datasets.MNIST(root='./',download=True,train=False,transform=transform)#分批次加载数据 64 128train_loader =torch.utils.data.DataLoader(train_data,batch_size=64,shuffle=True)test_loader =torch.utils.data.DataLoader(test_data,batch_size=64,shuffle=True)#gpudevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')net = Lenet5().to(device)loss_func = torch.nn.CrossEntropyLoss().to(device)optim = torch.optim.Adam(net.parameters(),lr=0.001)net.train()for epoch in range(10): for step,(x,y) in enumerate(train_loader): x = x.to(device) y = y.to(device) ouput = net(x) loss = loss_func(ouput,y) #计算损失 optim.zero_grad() loss.backward() optim.step() print('epoch:',epoch,"loss:",loss)torch.save(net,'net.pkl')
import torchimport torchvisionclass Lenet5(torch.nn.Module): def __init__(self): super(Lenet5, self).__init__() self.model = torch.nn.Sequential( torch.nn.Conv2d(in_channels=1,out_channels=6,kernel_size=5), # 1*32*32 # 6*28*28 torch.nn.ReLU(), torch.nn.MaxPool2d(kernel_size=2, # 6*14*14 stride=2), torch.nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5), # 16 *10*10 torch.nn.ReLU(), torch.nn.MaxPool2d(kernel_size=2, stride=2), # 16*5*5 torch.nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5), # 120*1*1 torch.nn.ReLU(), torch.nn.Flatten(), # 展平 成120*1维 torch.nn.Linear(120, 84), torch.nn.Linear(84, 10) ) def forward(self,x): x = self.model(x) return xtransform = torchvision.transforms.Compose( [torchvision.transforms.Resize(32), torchvision.transforms.ToTensor()])test_data = torchvision.datasets.MNIST(root='./',download=False,train=False,transform=transform)test_loader =torch.utils.data.DataLoader(test_data,batch_size=64,shuffle=True)device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')net = torch.load('net.pkl')net.eval() #表明进行推理with torch.no_grad(): for step,(x,y) in enumerate(test_loader): x,y = x.to(device),y.to(device) pre = net(x) print(pre) pre_y = torch.max(pre.cpu(),1)[1].numpy() print(pre_y) y = y.cpu().numpy() acc = (pre_y == y).sum()/len(y) print("accu:",acc)
边栏推荐
- 开源大咖说 - Linus 与 Jim 对话中国开源
- 字节跳动数据平台技术揭秘:基于 ClickHouse 的复杂查询实现与优化
- Fleet |「後臺探秘」第 3 期:狀態管理
- 一种跳板机的实现思路
- 面试官: 线程池是如何做到线程复用的?有了解过吗,说说看
- Grand launch of qodana: your favorite CI code quality platform
- Visual Studio 2010 configuring and using qt5.6.3
- [leetcode] 13. Roman numeral to integer
- How can the digital intelligent supply chain management platform of the smart Park optimize process management and drive the development of the park to increase speed and quality?
- 【LeetCode】13、罗马数字转整数
猜你喜欢

开源技术交流丨一站式全自动化运维管家ChengYing入门介绍

字节跳动数据平台技术揭秘:基于 ClickHouse 的复杂查询实现与优化

经典模型——Transformer

Visual Studio 2019软件安装包和安装教程

wallys/DR7915-wifi6-MT7915-MT7975-2T2R-support-OpenWRT-802.11AX-supporting-MiniPCIe-Module

讲师征集令 | Apache DolphinScheduler Meetup分享嘉宾,期待你的议题和声音!
![数组中的第K大元素[堆排 + 建堆的实际时间复杂度]](/img/69/bcafdcb09ffbf87246a03bcb9367aa.png)
数组中的第K大元素[堆排 + 建堆的实际时间复杂度]

What! One command to get the surveillance?

3. caller service call - dapr

平台即代码的未来是Kubernetes扩展
随机推荐
NAACL 2022 | 机器翻译SOTA模型的蒸馏
Basic grammar of C language
10 years of testing experience, worthless in the face of the physiological age of 35
有哪些好用的供应商管理系统
The past and present life of distributed cap theorem
【Spock】处理 Non-ASCII characters in an identifier
零钱兑换(动态规划)
Flutter dart语言特点总结
Lecturer solicitation order | Apache dolphin scheduler meetup sharing guests, looking forward to your topic and voice!
知道这几个命令让你掌握Shell自带工具
REDIS00_详解redis.conf配置文件
Opengauss kernel: analysis of SQL parsing process
ROS knowledge points - build an ROS development environment using vscode
Focus on the 35 year old Kan: fear is because you don't have the ability to match your age
【高并发基础】MySQL 不同事务隔离级别下的并发隐患及解决方案
隆重推出 Qodana:您最爱的 CI 的代码质量平台
开源技术交流丨一站式全自动化运维管家ChengYing入门介绍
【初学者必看】vlc实现的rtsp服务器及转储H264文件
A little hesitant in the morning
The k-th element in the array [heap row + actual time complexity of heap building]