当前位置:网站首页>机器学习之深度学习卷积神经网络,实现基于CNN网络的手写字体识别
机器学习之深度学习卷积神经网络,实现基于CNN网络的手写字体识别
2022-06-28 15:24:00 【华为云】
实现基于CNN网络的手写字体识别
首先下载数据
1、搭建CNN网络模型;
class CNN(nn.Module): def __init__(self): super(CNN,self).__init__() ''' 一般来说,卷积网络包括以下内容: 1.卷积层 2.神经网络 3.池化层 ''' self.conv1=nn.Sequential( nn.Conv2d( #--> (1,28,28) in_channels=1, #传入的图片是几层的,灰色为1层,RGB为三层 out_channels=16, #输出的图片是几层 kernel_size=5, #代表扫描的区域点为5*5 stride=1, #就是每隔多少步跳一下 padding=2, #边框补全,其计算公式=(kernel_size-1)/2=(5-1)/2=2 ), # 2d代表二维卷积 --> (16,28,28) nn.ReLU(), #非线性激活层 nn.MaxPool2d(kernel_size=2), #设定这里的扫描区域为2*2,且取出该2*2中的最大值 --> (16,14,14) ) self.conv2=nn.Sequential( nn.Conv2d( # --> (16,14,14) in_channels=16, #这里的输入是上层的输出为16层 out_channels=32, #在这里我们需要将其输出为32层 kernel_size=5, #代表扫描的区域点为5*5 stride=1, #就是每隔多少步跳一下 padding=2, #边框补全,其计算公式=(kernel_size-1)/2=(5-1)/2= ), # --> (32,14,14) nn.ReLU(), nn.MaxPool2d(kernel_size=2), #设定这里的扫描区域为2*2,且取出该2*2中的最大值 --> (32,7,7),这里是三维数据 ) self.out=nn.Linear(32*7*7,10) #注意一下这里的数据是二维的数据 def forward(self,x): x=self.conv1(x) x=self.conv2(x) #(batch,32,7,7) #然后接下来进行一下扩展展平的操作,将三维数据转为二维的数据 x=x.view(x.size(0),-1) #(batch ,32 * 7 * 7) output=self.out(x) return output2、设计损失函数,选择优化函数;
# 添加优化方法optimizer=torch.optim.Adam(cnn.parameters(),lr=LR)# 指定损失函数使用交叉信息熵loss_fn=nn.CrossEntropyLoss()3、实现模型训练与测试。
step=0for epoch in range(EPOCH): #加载训练数据 for step,data in enumerate(train_loader): x,y=data #分别得到训练数据的x和y的取值 b_x=Variable(x) b_y=Variable(y) output=cnn(b_x) #调用模型预测 loss=loss_fn(output,b_y)#计算损失值 optimizer.zero_grad() #每一次循环之前,将梯度清零 loss.backward() #反向传播 optimizer.step() #梯度下降 #每执行50次,输出一下当前epoch、loss、accuracy if (step%50==0): #计算一下模型预测正确率 test_output=cnn(test_x) y_pred=torch.max(test_output,1)[1].data.squeeze() accuracy=sum(y_pred==test_y).item()/test_y.size(0) print('now epoch : ', epoch, ' | loss : %.4f ' % loss.item(), ' | accuracy : ' , accuracy)
代码:
import torchimport torch.nn as nnfrom torch.autograd import Variableimport torch.utils.data as Dataimport torchvision#Hyper prametersEPOCH=1BATCH_SIZE=50LR=0.001DOWNLOAD_MNIST=Falsetrain_data = torchvision.datasets.MNIST( root='./mnist', train=True, transform=torchvision.transforms.ToTensor(), #将下载的文件转换成pytorch认识的tensor类型,且将图片的数值大小从(0-255)归一化到(0-1) download=DOWNLOAD_MNIST)train_loader=Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)test_data=torchvision.datasets.MNIST( root='./mnist', train=False,)with torch.no_grad(): test_x=Variable(torch.unsqueeze(test_data.data, dim=1)).type(torch.FloatTensor)[:2000]/255 #只取前两千个数据吧,差不多已经够用了,然后将其归一化。 test_y=test_data.targets[:2000]'''开始建立CNN网络'''class CNN(nn.Module): def __init__(self): super(CNN,self).__init__() ''' 一般来说,卷积网络包括以下内容: 1.卷积层 2.神经网络 3.池化层 ''' self.conv1=nn.Sequential( nn.Conv2d( #--> (1,28,28) in_channels=1, #传入的图片是几层的,灰色为1层,RGB为三层 out_channels=16, #输出的图片是几层 kernel_size=5, #代表扫描的区域点为5*5 stride=1, #就是每隔多少步跳一下 padding=2, #边框补全,其计算公式=(kernel_size-1)/2=(5-1)/2=2 ), # 2d代表二维卷积 --> (16,28,28) nn.ReLU(), #非线性激活层 nn.MaxPool2d(kernel_size=2), #设定这里的扫描区域为2*2,且取出该2*2中的最大值 --> (16,14,14) ) self.conv2=nn.Sequential( nn.Conv2d( # --> (16,14,14) in_channels=16, #这里的输入是上层的输出为16层 out_channels=32, #在这里我们需要将其输出为32层 kernel_size=5, #代表扫描的区域点为5*5 stride=1, #就是每隔多少步跳一下 padding=2, #边框补全,其计算公式=(kernel_size-1)/2=(5-1)/2= ), # --> (32,14,14) nn.ReLU(), nn.MaxPool2d(kernel_size=2), #设定这里的扫描区域为2*2,且取出该2*2中的最大值 --> (32,7,7),这里是三维数据 ) self.out=nn.Linear(32*7*7,10) #注意一下这里的数据是二维的数据 def forward(self,x): x=self.conv1(x) x=self.conv2(x) #(batch,32,7,7) #然后接下来进行一下扩展展平的操作,将三维数据转为二维的数据 x=x.view(x.size(0),-1) #(batch ,32 * 7 * 7) output=self.out(x) return output cnn=CNN()# print(cnn)# 添加优化方法optimizer=torch.optim.Adam(cnn.parameters(),lr=LR)# 指定损失函数使用交叉信息熵loss_fn=nn.CrossEntropyLoss()'''开始训练我们的模型哦'''step=0for epoch in range(EPOCH): #加载训练数据 for step,data in enumerate(train_loader): x,y=data #分别得到训练数据的x和y的取值 b_x=Variable(x) b_y=Variable(y) output=cnn(b_x) #调用模型预测 loss=loss_fn(output,b_y)#计算损失值 optimizer.zero_grad() #每一次循环之前,将梯度清零 loss.backward() #反向传播 optimizer.step() #梯度下降 #每执行50次,输出一下当前epoch、loss、accuracy if (step%50==0): #计算一下模型预测正确率 test_output=cnn(test_x) y_pred=torch.max(test_output,1)[1].data.squeeze() accuracy=sum(y_pred==test_y).item()/test_y.size(0) print('now epoch : ', epoch, ' | loss : %.4f ' % loss.item(), ' | accuracy : ' , accuracy)'''打印十个测试集的结果'''test_output=cnn(test_x[:10])y_pred=torch.max(test_output,1)[1].data.squeeze() #选取最大可能的数值所在的位置print(y_pred.tolist(),'predecton Result')print(test_y[:10].tolist(),'Real Result')边栏推荐
- 厨卫电器行业S2B2C系统网站解决方案:打造S2B2C平台全渠道商业系统
- R language uses the multinom function of NNET package to build an unordered multi classification logistic regression model, and uses regression coefficients and their standard errors to calculate the
- 华为能成“口红一哥”,或者“带货女王”吗?
- Successful cases of rights protection of open source projects: successful rights protection of SPuG open source operation and maintenance platform
- 抽奖动画 - 鲤鱼跳龙门
- SQL statement exercises
- Opengauss kernel: analysis of SQL parsing process
- R language ggplot2 visualization: use the patchwork package (directly use the plus sign +) to horizontally combine a ggplot2 visualization result and a plot function visualization result to form a fin
- 隆重推出 Qodana:您最爱的 CI 的代码质量平台
- ORACLE中dbms_output.put_line输出问题的解决过程
猜你喜欢

With 120billion yuan, she will ring the bell for IPO again

halcon 基础总结(一)裁切图片并旋转图像

Expand Disk C (allocate the memory of disk d to Disk C)

Leetcode 48. Rotate image (yes, resolved)

Jackie Chan and fast brand, who is the Savior of Kwai?

Application of mongodb in Tencent retail premium code

Technical secrets of ByteDance data platform: implementation and optimization of complex query based on Clickhouse

Complete model training routine (I)

智慧园区数智化供应链管理平台如何优化流程管理,驱动园区发展提速增质?

New offline retail stores take off against the trend, and consumption enthusiasm under the dark cloud of inflation
随机推荐
R语言ggplot2可视化:使用patchwork包(直接使用加号+)将一个ggplot2可视化结果和一段文本内容横向组合起来形成最终结果图
张同学还没学会当主播
环保产品“绿色溢价”高?低碳生活方式离人们还有多远
GBASE南大通用亮相第六届世界智能大会
Xinchuang operating system -- kylin kylin desktop operating system (project 10 security center)
PostgreSQL实现按年、月、日、周、时、分、秒的分组统计
最长连续序列
一个bug肝一周...忍不住提了issue
How does Seata server 1.5.0 support mysql8.0?
从莫高窟到太平洋,海量数据找到了新家园
R language ggplot2 visualization: use the patchwork package to stack two ggplot2 visualization results vertically to form a composite diagram, and stack one visualization result on the other visualiza
R language ggplot2 visualization: the patchwork package horizontally combines a ggplot2 visualization result and a plot function visualization result to form a final result graph, aligns the two visua
ORACLE中dbms_output.put_line输出问题的解决过程
Smart supplier management system for chemical manufacturing industry deeply explores the field of supplier management and improves supply chain collaboration
High "green premium" of environmental protection products? How far is the low-carbon lifestyle from people
一种跳板机的实现思路
实验6 8255并行接口实验【微机原理】【实验】
R language uses the multinom function of NNET package to build an unordered multi classification logistic regression model, and uses regression coefficients and their standard errors to calculate the
Expand Disk C (allocate the memory of disk d to Disk C)
Opengauss kernel: analysis of SQL parsing process