当前位置:网站首页>机器学习之卷积神经网络Lenet5训练模型
机器学习之卷积神经网络Lenet5训练模型
2022-06-28 15:49:00 【华为云】
Lenet5训练模型
下载数据集

可以提前下载也可以在线下载
train_data = torchvision.datasets.MNIST(root='./',download=True,train=True,transform=transform)test_data = torchvision.datasets.MNIST(root='./',download=True,train=False,transform=transform)训练模型
import torchimport torchvisionclass Lenet5(torch.nn.Module): def __init__(self): super(Lenet5, self).__init__() self.model = torch.nn.Sequential( torch.nn.Conv2d(in_channels=1,out_channels=6,kernel_size=5), # 1*32*32 # 6*28*28 torch.nn.ReLU(), torch.nn.MaxPool2d(kernel_size=2, # 6*14*14 stride=2), torch.nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5), # 16 *10*10 torch.nn.ReLU(), torch.nn.MaxPool2d(kernel_size=2, stride=2), # 16*5*5 torch.nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5), # 120*1*1 torch.nn.ReLU(), torch.nn.Flatten(), # 展平 成120*1维 torch.nn.Linear(120, 84), torch.nn.Linear(84, 10) ) def forward(self,x): x = self.model(x) return xtransform = torchvision.transforms.Compose( [torchvision.transforms.Resize(32), torchvision.transforms.ToTensor()])train_data = torchvision.datasets.MNIST(root='./',download=True,train=True,transform=transform)test_data = torchvision.datasets.MNIST(root='./',download=True,train=False,transform=transform)#分批次加载数据 64 128train_loader =torch.utils.data.DataLoader(train_data,batch_size=64,shuffle=True)test_loader =torch.utils.data.DataLoader(test_data,batch_size=64,shuffle=True)#gpudevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')net = Lenet5().to(device)loss_func = torch.nn.CrossEntropyLoss().to(device)optim = torch.optim.Adam(net.parameters(),lr=0.001)net.train()for epoch in range(10): for step,(x,y) in enumerate(train_loader): x = x.to(device) y = y.to(device) ouput = net(x) loss = loss_func(ouput,y) #计算损失 optim.zero_grad() loss.backward() optim.step() print('epoch:',epoch,"loss:",loss)torch.save(net,'net.pkl')
import torchimport torchvisionclass Lenet5(torch.nn.Module): def __init__(self): super(Lenet5, self).__init__() self.model = torch.nn.Sequential( torch.nn.Conv2d(in_channels=1,out_channels=6,kernel_size=5), # 1*32*32 # 6*28*28 torch.nn.ReLU(), torch.nn.MaxPool2d(kernel_size=2, # 6*14*14 stride=2), torch.nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5), # 16 *10*10 torch.nn.ReLU(), torch.nn.MaxPool2d(kernel_size=2, stride=2), # 16*5*5 torch.nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5), # 120*1*1 torch.nn.ReLU(), torch.nn.Flatten(), # 展平 成120*1维 torch.nn.Linear(120, 84), torch.nn.Linear(84, 10) ) def forward(self,x): x = self.model(x) return xtransform = torchvision.transforms.Compose( [torchvision.transforms.Resize(32), torchvision.transforms.ToTensor()])test_data = torchvision.datasets.MNIST(root='./',download=False,train=False,transform=transform)test_loader =torch.utils.data.DataLoader(test_data,batch_size=64,shuffle=True)device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')net = torch.load('net.pkl')net.eval() #表明进行推理with torch.no_grad(): for step,(x,y) in enumerate(test_loader): x,y = x.to(device),y.to(device) pre = net(x) print(pre) pre_y = torch.max(pre.cpu(),1)[1].numpy() print(pre_y) y = y.cpu().numpy() acc = (pre_y == y).sum()/len(y) print("accu:",acc)
边栏推荐
- Deep learning convolutional neural network of machine learning to realize handwritten font recognition based on CNN network
- go-zero 微服务实战系列(七、请求量这么高该如何优化)
- 【LeetCode】13、罗马数字转整数
- ROS knowledge points - ROS create workspace
- openGauss内核:SQL解析过程分析
- Qt5.5.1 configuring msvc2010 compiler and WinDbg debugger
- VC2010 编绎Qt5.6.3 提示 CVTRES : fatal error CVT1107:
- 隆重推出 Qodana:您最爱的 CI 的代码质量平台
- 开源技术交流丨一站式全自动化运维管家ChengYing入门介绍
- 大神详解开源 BUFF 增益攻略丨直播讲座
猜你喜欢

Technical secrets of ByteDance data platform: implementation and optimization of complex query based on Clickhouse
![[recommendation system] esmm model of multi task learning (updating)](/img/21/8e38d3903eb1110efc4773edb2d09c.png)
[recommendation system] esmm model of multi task learning (updating)

Azure Kinect微软摄像头Unity开发小结

Privacy computing fat - offline prediction

Realization of a springboard machine

C语言基础语法

10:00面试,10:02就出来了 ,问的实在是太...
![[leetcode] 13. Roman numeral to integer](/img/3c/7c57d0c407f5302115f69f44b473c5.png)
[leetcode] 13. Roman numeral to integer

Web3.0时代来了,看天翼云存储资源盘活系统如何赋能新基建(上)

10 years of testing experience, worthless in the face of the physiological age of 35
随机推荐
Classic model transformer
Introduction to deep learning in machine learning
平台即代码的未来是Kubernetes扩展
关于针对tron API签名广播时使用curl的json解析问题解决方案及针对json.loads方法的问题记录
首次失败后,爱美客第二次冲刺港交所上市,财务负责人变动频繁
Navicat 15 for MySQL
开源大咖说 - Linus 与 Jim 对话中国开源
What are the most powerful small and medium-sized companies in Beijing?
Do not use short circuit logic to write STL sorter multi condition comparison
3. caller service call - dapr
看界面控件DevExpress WinForms如何创建一个虚拟键盘
IPDK — Overview
Flutter简单实现多语言国际化
Application of mongodb in Tencent retail premium code
Deep learning convolutional neural network of machine learning to realize handwritten font recognition based on CNN network
Operating excel with openpyxl
隐私计算 FATE - 离线预测
深度学习基础汇总
OpenHarmony—内核对象事件之源码详解
Qt5.5.1 configuring msvc2010 compiler and WinDbg debugger