Understanding How the Whole Network Model Is Built
2022-06-28 20:04:00 【seven_不是赛文】
# Prepare the dataset
# Preprocess the dataset
# Split the dataset into training and test sets
# Load the dataset with DataLoader
e.g.:
dataloader_train = DataLoader(train_data, batch_size=64, drop_last=False)
dataloader_test = DataLoader(test_data, batch_size=64, drop_last=False)
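The train_data and test_data used above have to be created first. A minimal sketch of the preparation steps, assuming torchvision's CIFAR-10 (its 3×32×32 images and 10 classes match the network below); the dataset choice and the ToTensor transform are assumptions, not stated in the original:

import torchvision
from torch.utils.data import DataLoader

# Download the data and convert the PIL images to tensors (CIFAR-10 assumed for illustration)
transform = torchvision.transforms.ToTensor()
train_data = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=transform, download=True)
test_data = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=transform, download=True)

# CIFAR-10 already ships with a train/test split; record the sizes
# (test_data_size is used later when computing the accuracy)
train_data_size = len(train_data)
test_data_size = len(test_data)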
# Build the neural network. It can live in its own model.py file (where we can also sanity-check that the model works).
import torch
from torch import nn
# Build the neural network
# Note: nn.Sequential is already referenced inside model.py, so don't import the model pieces again in the training script, otherwise you may get errors; just use ==> from model import *
class Test(nn.Module):
    def __init__(self):
        super(Test, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
            nn.Linear(in_features=1024, out_features=64),
            nn.Linear(in_features=64, out_features=10)
        )

    def forward(self, x):
        x = self.model(x)
        return x
# Sanity check the network
if __name__ == '__main__':
    test = Test()
    input = torch.ones((64, 3, 32, 32))
    output = test(input)
    print(output.shape)
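Running model.py directly should print torch.Size([64, 10]). The in_features=1024 of the first Linear layer follows from the layer shapes; a comment trace of one image through the Sequential above (derived from the code, shown only for illustration):

# (3, 32, 32) -> Conv2d  -> (32, 32, 32) -> MaxPool2d -> (32, 16, 16)
#             -> Conv2d  -> (32, 16, 16) -> MaxPool2d -> (32, 8, 8)
#             -> Conv2d  -> (64, 8, 8)   -> MaxPool2d -> (64, 4, 4)
#             -> Flatten -> (1024,)      -> Linear    -> (64,)       -> Linear -> (10,)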
# Create the network model
test = Test()
# Loss function
loss_fc = nn.CrossEntropyLoss()
# Optimizer
learning_rate = 1e-2
optimizer = torch.optim.SGD(test.parameters(), lr=learning_rate)
# Add TensorBoard so the loss curve can be visualized
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("logs")
# A note on two terms used below
# step  - the number of batches processed (one optimizer update per step)
# epoch - the number of full passes over the training data
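The loop below also relies on a few variables that are never initialized in the original snippet; a minimal sketch, where the number of epochs (10) is an assumed value:

epoch = 10              # how many passes over the training data (value assumed for illustration)
total_train_step = 0    # counts optimizer updates across all epochs
total_test_step = 0     # counts how many times the test set has been evaluated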
'------------- Train the model and test whether it has learned ----------------'
# Start training:
for i in range(epoch):
    print("---------- Epoch {} starts: ----------".format(i+1))
    # Training step
    for data in dataloader_train:
        imgs, targets = data
        output_train = test(imgs)
        loss = loss_fc(output_train, targets)  # compute the loss value
        # Use the optimizer to update the model
        optimizer.zero_grad()  # clear the old gradients
        loss.backward()        # backpropagate the loss to get the gradient of every parameter
        optimizer.step()       # let the optimizer update the parameters
        total_train_step = total_train_step + 1  # count training steps
        if total_train_step % 100 == 0:  # print only every 100 steps to keep the output readable
            # loss.item() is not the same as loss: loss.item() returns a plain Python number
            print("Training step: {}, loss: {}".format(total_train_step, loss.item()))
            writer.add_scalar("train_loss", loss.item(), total_train_step)  # log every 100 steps
    # While training, we want to know whether the model is actually improving, so after every epoch
    # we evaluate on the test set and use the test loss to judge the model; no optimizer is used during testing.
    # Test step
    # accumulated loss over the whole test set
    total_test_loss = 0
    # number of correctly classified samples
    total_accuracy = 0
    with torch.no_grad():  # inside this block gradients are not tracked, so no parameters can be updated
        for data in dataloader_test:
            imgs, targets = data
            output_test = test(imgs)
            # this loss only covers the current batch
            loss = loss_fc(output_test, targets)
            # we want the loss over the whole test set, hence the running total below
            total_test_loss = total_test_loss + loss
            accuracy = (output_test.argmax(1) == targets).sum()  # count predictions matching the targets; argmax(1) takes the largest score in each row
            total_accuracy = total_accuracy + accuracy  # total number of correct predictions
    print("Loss over the whole test set: {}".format(total_test_loss.item()))
    print("Accuracy over the whole test set: {}".format(total_accuracy/test_data_size))
    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    writer.add_scalar("test_accuracy", total_accuracy/test_data_size, total_test_step)
    total_test_step = total_test_step + 1
    # Save the model after every epoch
    torch.save(test, "test_{}.pth".format(i))  # method 1: save the whole model
    # Method 2 (below): save only the parameters as a state_dict
    # torch.save(test.state_dict(), "test_{}".format(i))
    print("Model saved")
    # A checkpoint is saved after every epoch, so later we can load each one and compare the results -- they do differ from epoch to epoch.
writer.close()
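A minimal sketch of how the two kinds of checkpoints saved above could be loaded back later (the file names follow the save calls in the loop; this is an illustration, not part of the original script):

# Method 1: the whole model was saved, so torch.load returns a ready-to-use Test object
# (the Test class definition must still be importable, e.g. via `from model import *`)
loaded_model = torch.load("test_0.pth")

# Method 2: only the parameters were saved; build the network first, then load the state_dict
loaded_model_2 = Test()
loaded_model_2.load_state_dict(torch.load("test_0"))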
# Plotting how the loss or accuracy changes:
Either writer = SummaryWriter("logs") (TensorBoard) or plt.figure (matplotlib) can be used.
'Option 1: log the values produced during training directly'
writer = SummaryWriter("logs")
...
writer.add_scalar("train_loss", loss.item(), total_train_step)
writer.add_scalar("test_loss", total_test_loss, total_test_step)
writer.add_scalar("test_accuracy", total_accuracy/test_data_size, total_test_step)
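The curves logged this way are viewed by running `tensorboard --logdir=logs` in a terminal and opening the address it prints (http://localhost:6006 by default).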
'Option 2: matplotlib, which needs a conversion first'
Collect the values into plain Python lists (e.g. by appending loss.item(), or converting tensors with .numpy()) and then plot them with matplotlib.
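A minimal sketch of option 2, where train_loss_history is a made-up list name filled during training (an illustration, not code from the original):

import matplotlib.pyplot as plt

train_loss_history = []   # inside the training loop: train_loss_history.append(loss.item()) every 100 steps

# after training, plot the collected numbers
plt.figure()
plt.plot(train_loss_history, label="train loss")
plt.xlabel("logged step (every 100 batches)")
plt.ylabel("loss")
plt.legend()
plt.show()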