当前位置:网站首页>在jupyter NoteBook使用Pytorch进行MNIST实现
在jupyter NoteBook使用Pytorch进行MNIST实现
2022-07-06 09:11:00 【一曲无痕奈何】
"流程 "
#1、加载必要的库
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.optim as optim
from torchvision import datasets , transforms
# 2. Define hyperparameters
BATCH_SIZE = 16  # samples per mini-batch
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # prefer GPU when available
EPOCHS = 10  # number of full passes over the training data

# 3. Build the preprocessing pipeline applied to every image
# BUG FIX: Compose must receive an ordered sequence (a list), not a set —
# the original used a {...} set literal, whose iteration order is not
# guaranteed, so Normalize could run before ToTensor and fail on a PIL image.
pipline = transforms.Compose([
    transforms.ToTensor(),  # PIL image -> float tensor scaled to [0, 1]
    # MNIST global mean/std; passed as 1-tuples (one value per channel)
    transforms.Normalize((0.1307,), (0.3081,)),
])
# 4. Download and load the datasets
from torch.utils.data import DataLoader

train_set = datasets.MNIST("data2", train=True, download=True, transform=pipline)
test_set = datasets.MNIST("data2", train=False, download=True, transform=pipline)
# Training loader: shuffle so each epoch sees a fresh sample order
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
# Test loader
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=True)

# Dump the first digit from the raw IDX file
# (16-byte header, then one 28*28 = 784-byte image per record)
with open("./data2/MNIST/raw/train-images-idx3-ubyte", "rb") as f:
    file = f.read()
# BUG FIX: iterating a bytes object already yields ints in [0, 255].
# The original int(str(item).encode("ascii"), 16) re-parsed each pixel's
# decimal string as hexadecimal, yielding wrong values (some > 255, which
# then wrapped around under np.uint8).
image1 = list(file[16:16 + 784])
print(image1)

import cv2
import numpy as np

image1_np = np.array(image1, dtype=np.uint8).reshape(28, 28, 1)
print(image1_np.shape)
# Save the digit as an image file for visual inspection
cv2.imwrite('digit.jpg', image1_np)
#5、构建网络模型
class Digit(nn.Module):
    """Small CNN classifier for 28x28 single-channel MNIST digits.

    Layout: conv(1->10, 5x5) -> ReLU -> 2x2 max-pool ->
            conv(10->20, 3x3) -> ReLU -> flatten ->
            fc(2000->500) -> ReLU -> fc(500->10) -> log-softmax.
    """

    def __init__(self):
        super().__init__()
        # 28x28 input shrinks to 24x24 after the unpadded 5x5 convolution
        self.conv1 = nn.Conv2d(1, 10, 5)
        # 12x12 (after pooling) shrinks to 10x10 after the 3x3 convolution
        self.conv2 = nn.Conv2d(10, 20, 3)
        # 20 channels * 10 * 10 spatial positions, flattened
        self.fc1 = nn.Linear(20 * 10 * 10, 500)
        # 500 hidden units -> 10 class scores
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        """Map a (batch, 1, 28, 28) tensor to (batch, 10) log-probabilities."""
        batch = x.size(0)
        out = F.relu(self.conv1(x))       # (batch, 10, 24, 24)
        out = F.max_pool2d(out, 2, 2)     # (batch, 10, 12, 12) — downsample
        out = F.relu(self.conv2(out))     # (batch, 20, 10, 10)
        out = out.view(batch, -1)         # (batch, 2000) — flatten for the FC layers
        out = F.relu(self.fc1(out))       # (batch, 500)
        out = self.fc2(out)               # (batch, 10)
        return F.log_softmax(out, dim=1)  # per-class log-probabilities
# 6. Instantiate the model and its optimizer
model =Digit().to(DEVICE)  # move all parameters to the selected device
optimizer = optim.Adam(model.parameters())  # Adam with library-default hyperparameters
#7、定义训练方法
def train_model(model,device,train_loader,optimizer,epoch):
#模型训练
model.train()
for batch_index,(data,target) in enumerate(train_loader):
#部署到DEVICE上去
data,target = data.to(device),target.to(device)
#梯度初始化为0
optimizer.zero_grad()
#训练后的结果
output = model(data)
#计算损失
loss = F.cross_entropy(output, target) #交叉熵损失函数适用多分类任务
#找到概率值最大的下标
pred = output.max(1,keepdim = True) #1表示横轴 也可以这样写 pred = output.argmax(dim=1)
#反向传播
loss.backward()
#参数优化,也就是每一次的参数的更新
optimizer.step()
if batch_index % 3000 == 0: #每处理3000张图片就打印一次
print("Train Epoch :{}\tLOSS : {:.6f}".format(epoch,loss.item())) #这个loss后面必须加item(),拿到数值
#8、定义测试方法
def test_model(model,device,test_loader):
#模型验证
model.eval()
#正确率
correct = 0.0
#测试损失
test_loss = 0.0
with torch.no_grad(): #不进行梯度计算,也不会进行反向传播
for data, target in test_loader:
#部署到device上去
data,target = data.to(device),target.to(device)
#测试数据
output = model(data)
#计算测试损失
test_loss += F.cross_entropy(output, target).item()
#找到概率最大值的下标
pred = output.argmax(dim = 1)
#累计正确的值
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print("Test-- Average loss {:.4f},Accuracy : {:.3f}\n".format(
test_loss, 100.0 *correct / len(test_loader.dataset)
))
# 9. Run the train / evaluate cycle once per epoch (epochs numbered from 1)
for epoch in range(1,EPOCHS + 1):
    train_model(model,DEVICE,train_loader,optimizer,epoch)
    test_model(model,DEVICE,test_loader)
边栏推荐
- Several errors encountered when installing opencv
- Docker MySQL solves time zone problems
- Embedded development is much more difficult than MCU? Talk about SCM and embedded development and design experience
- Use xtrabackup for MySQL database physical backup
- Southwest University: Hu hang - Analysis on learning behavior and learning effect
- Download address of canoe, download and activation of can demo 16, and appendix of all canoe software versions
- Vh6501 Learning Series
- MySQL的存储引擎
- 颜值爆表,推荐两款JSON可视化工具,配合Swagger使用真香
- MySQL combat optimization expert 12 what does the memory data structure buffer pool look like?
猜你喜欢
How to build an interface automation testing framework?
MySQL combat optimization expert 03 uses a data update process to preliminarily understand the architecture design of InnoDB storage engine
[Julia] exit notes - Serial
Mexican SQL manual injection vulnerability test (mongodb database) problem solution
15 医疗挂号系统_【预约挂号】
What should the redis cluster solution do? What are the plans?
寶塔的安裝和flask項目部署
如何搭建接口自动化测试框架?
再有人问你数据库缓存一致性的问题,直接把这篇文章发给他
使用OVF Tool工具从Esxi 6.7中导出虚拟机
随机推荐
MySQL combat optimization expert 03 uses a data update process to preliminarily understand the architecture design of InnoDB storage engine
Carolyn Rosé博士的社交互通演讲记录
13 medical registration system_ [wechat login]
14 医疗挂号系统_【阿里云OSS、用户认证与就诊人】
A necessary soft skill for Software Test Engineers: structured thinking
South China Technology stack cnn+bilstm+attention
The governor of New Jersey signed seven bills to improve gun safety
MySQL ERROR 1040: Too many connections
MySQL combat optimization expert 10 production experience: how to deploy visual reporting system for database monitoring system?
C杂讲 文件 续讲
MySQL combat optimization expert 04 uses the execution process of update statements in the InnoDB storage engine to talk about what binlog is?
Solution to the problem of cross domain inaccessibility of Chrome browser
软件测试工程师必备之软技能:结构化思维
[one click] it only takes 30s to build a blog with one click - QT graphical tool
[NLP] bert4vec: a sentence vector generation tool based on pre training
MySQL底层的逻辑架构
Target detection -- yolov2 paper intensive reading
Several errors encountered when installing opencv
Contest3145 - the 37th game of 2021 freshman individual training match_ B: Password
Time complexity (see which sentence is executed the most times)