Understanding how the complete network model is built
2022-06-28 20:29:00 【seven_ Not Severn】
# Prepare the dataset
# Process the dataset
# Split the dataset (training set / test set)
# Use DataLoader to load the dataset, e.g.:
from torch.utils.data import DataLoader
dataloder_train = DataLoader(train_data, batch_size=64, drop_last=False)
dataloder_test = DataLoader(test_data, batch_size=64, drop_last=False)
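The article does not show the preparation and splitting steps themselves. A minimal sketch, assuming random tensors as stand-in data (3x32x32 inputs with 10 classes, which matches the model below; a real dataset such as CIFAR-10 could instead be loaded with `torchvision.datasets.CIFAR10(root, train=..., transform=..., download=True)`). The 80/20 split ratio and the names `images`, `labels` and `full_data` are illustrative assumptions.

```python
import torch
from torch.utils.data import TensorDataset, random_split, DataLoader

torch.manual_seed(0)

# Hypothetical stand-in data: 1000 random 3x32x32 "images" with labels in 0..9
images = torch.randn(1000, 3, 32, 32)
labels = torch.randint(0, 10, (1000,))
full_data = TensorDataset(images, labels)

# Divide the dataset: 80% for training, 20% for testing (assumed ratio)
train_size = int(0.8 * len(full_data))
test_size = len(full_data) - train_size
train_data, test_data = random_split(
    full_data, [train_size, test_size],
    generator=torch.Generator().manual_seed(42),
)

# Load with DataLoader; only the training set is shuffled
dataloder_train = DataLoader(train_data, batch_size=64, shuffle=True, drop_last=False)
dataloder_test = DataLoader(test_data, batch_size=64, drop_last=False)

train_data_size = len(train_data)
test_data_size = len(test_data)
print(train_data_size, test_data_size)            # 800 200
print(len(dataloder_train), len(dataloder_test))  # 13 4 (the last batch is partial)
```

With `drop_last=False`, the final, smaller batch is kept: 800 / 64 gives 12 full batches plus one batch of 32.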
# Build the neural network. It can go in a separate model.py file, which can also be run on its own to check that the model is correct.
import torch
from torch import nn
# Build the neural network
# Note: model.py already imports nn (which provides Sequential), so the training script should not import these again separately (otherwise an error is reported). Instead, it uses: from model import *
class Test(nn.Module):
    def __init__(self):
        super(Test, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
            nn.Linear(in_features=1024, out_features=64),
            nn.Linear(in_features=64, out_features=10)
        )

    def forward(self, x):
        x = self.model(x)
        return x
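# Check that the model is correct: the output should have shape [64, 10]
if __name__ == '__main__':
    test = Test()
    input = torch.ones((64, 3, 32, 32))  # a dummy batch: 64 RGB images of size 32x32
    output = test(input)
    print(output.shape)  # torch.Size([64, 10])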
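To see why the first `nn.Linear` takes `in_features=1024`, one can push the same dummy input through the layers one at a time and print each output shape. A `Conv2d` with `kernel_size=5`, `padding=2` and stride 1 keeps the spatial size ((32 + 2*2 - 5) / 1 + 1 = 32), and each `MaxPool2d(2)` halves it (32 -> 16 -> 8 -> 4), so `Flatten` yields 64 * 4 * 4 = 1024 features per image. This sketch rebuilds the same layer stack as a standalone `nn.Sequential` so that it runs on its own.

```python
import torch
from torch import nn

# Same layer stack as the Test model, rebuilt so this snippet is self-contained
model = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=5, padding=2),
    nn.MaxPool2d(2),
    nn.Conv2d(32, 32, kernel_size=5, padding=2),
    nn.MaxPool2d(2),
    nn.Conv2d(32, 64, kernel_size=5, padding=2),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(1024, 64),
    nn.Linear(64, 10),
)

x = torch.ones((64, 3, 32, 32))
shapes = []
for layer in model:  # iterating a Sequential yields its layers in order
    x = layer(x)
    shapes.append(tuple(x.shape))
    print(f"{type(layer).__name__:<10} {tuple(x.shape)}")
```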
# Create a network model
test = Test()
# Loss function
loss_fc = nn.CrossEntropyLoss()
# Optimizer
learning_rate = 1e-2
optimizer = torch.optim.SGD(test.parameters(), lr=learning_rate)
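A small illustration of what the loss function and one optimizer step do, using made-up values rather than anything from the article. With all-zero logits over 10 classes, every class gets probability 1/10, so the cross-entropy loss is ln 10 ≈ 2.3026. For a single toy parameter w = 1 with loss w², the gradient is 2, so one SGD step with `lr=1e-2` moves w to 1 - 0.01 * 2 = 0.98.

```python
import math
import torch
from torch import nn

loss_fc = nn.CrossEntropyLoss()

# Made-up logits: 4 samples, 10 classes, all scores equal -> uniform prediction
logits = torch.zeros(4, 10)
targets = torch.tensor([0, 3, 7, 9])
loss = loss_fc(logits, targets)
print(loss.item(), math.log(10))  # both about 2.302585

# One SGD step on a single toy parameter w, with loss = w ** 2
w = torch.tensor(1.0, requires_grad=True)
optimizer = torch.optim.SGD([w], lr=1e-2)

optimizer.zero_grad()   # clear any old gradient
toy_loss = w ** 2
toy_loss.backward()     # d(w^2)/dw = 2w = 2
optimizer.step()        # w <- w - lr * grad = 1 - 0.01 * 2
print(w.item())         # about 0.98
```

The same three calls (`zero_grad`, `backward`, `step`) appear in the training loop, applied there to all of the model's parameters at once.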
# Add TensorBoard to plot how the loss changes
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("logs")
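# Some parameters used below
total_train_step = 0  # step: number of training steps (batches) so far
total_test_step = 0   # number of test rounds so far
epoch = 10            # epoch: number of training rounds (example value)
test_data_size = len(test_data)  # number of samples in the test set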
# ---------- Train the model and test whether it is trained well ----------
# Start training:
for i in range(epoch):
    print("---------- Round {} of training begins ----------".format(i + 1))
    # Training step
    for data in dataloder_train:
        imgs, targets = data
        output_train = test(imgs)
        loss = loss_fc(output_train, targets)  # compute the loss
        # Use the optimizer to update the model
        optimizer.zero_grad()  # clear the gradients from the previous step
        loss.backward()        # backpropagate: compute the gradient of every parameter
        optimizer.step()       # update the parameters using those gradients
        total_train_step = total_train_step + 1  # count the training steps
        if total_train_step % 100 == 0:  # print only every 100 steps to keep the output short
            # loss is a tensor; loss.item() returns its value as a plain Python number
            print("Training step: {}, loss: {}".format(total_train_step, loss.item()))
            writer.add_scalar("train_loss", loss.item(), total_train_step)  # logged every 100 steps
    # To know whether the model is trained well, run a test after each round of training and
    # use the loss on the test set to evaluate it. The optimizer does no tuning during testing.
    # Test step
    total_test_loss = 0  # loss summed over the whole test set
    total_accuracy = 0   # number of correct predictions
    with torch.no_grad():  # no gradients are tracked inside this block, so nothing is tuned
        for data in dataloder_test:
            imgs, targets = data
            output_test = test(imgs)
            # this loss only covers one batch of the test data
            loss = loss_fc(output_test, targets)
            # we want the loss over the entire test set, so accumulate it
            total_test_loss = total_test_loss + loss.item()
            # argmax(1): for each row (sample), the index of the largest score, i.e. the predicted class
            accuracy = (output_test.argmax(1) == targets).sum().item()  # correct predictions in this batch
            total_accuracy = total_accuracy + accuracy  # total number of correct predictions
    print("Loss on the whole test set: {}".format(total_test_loss))
    print("Accuracy on the whole test set: {}".format(total_accuracy / test_data_size))
    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    writer.add_scalar("test_accuracy", total_accuracy / test_data_size, total_test_step)
    total_test_step = total_test_step + 1
    # Save the model after every round of training
    torch.save(test, "test_{}.pth".format(i))  # method 1: save the whole model
    # Method 2: save only the parameters, as a dict (state_dict)
    # torch.save(test.state_dict(), "test_{}.pth".format(i))
    print("Model saved")
    # Because a checkpoint is saved every round, each one can later be loaded and evaluated;
    # their results do differ.
writer.close()
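Loading the checkpoints back is the counterpart of the two saving methods. A minimal sketch, assuming a tiny stand-in model (`nn.Linear(4, 2)`) and a temporary directory instead of the article's `Test` model and `test_{i}.pth` files, and assuming PyTorch 1.13 or newer (for the `weights_only` argument). Method 1 pickles the whole module, so loading it needs the class definition to be importable (which is why the training script uses `from model import *`); since PyTorch 2.6 `torch.load` defaults to `weights_only=True`, so a whole pickled module needs `weights_only=False`. Method 2 stores only the `state_dict`, which is loaded into a freshly constructed model.

```python
import os
import tempfile
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(4, 2)  # hypothetical stand-in for the article's Test model
x = torch.randn(3, 4)

with tempfile.TemporaryDirectory() as tmp:
    # Method 1: save the whole model object (pickled; needs the class definition to load)
    path1 = os.path.join(tmp, "test_0.pth")
    torch.save(model, path1)
    model1 = torch.load(path1, weights_only=False)

    # Method 2: save only the parameters (a dict mapping names to tensors)
    path2 = os.path.join(tmp, "test_0_state.pth")
    torch.save(model.state_dict(), path2)
    model2 = nn.Linear(4, 2)  # rebuild the architecture first
    model2.load_state_dict(torch.load(path2))

print(list(model.state_dict().keys()))  # ['weight', 'bias']
print(torch.equal(model(x), model1(x)), torch.equal(model(x), model2(x)))  # True True
```

Method 2 produces smaller, more portable files because it does not depend on pickling the class itself.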
# Plot how the loss or accuracy changes:
Either TensorBoard (writer = SummaryWriter("logs")) or matplotlib (plt.figure) can be used.
' Method 1: TensorBoard, which logs the values produced during training directly (view them with: tensorboard --logdir=logs) '
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("logs")
...
writer.add_scalar("train_loss", loss.item(), total_train_step)
writer.add_scalar("test_loss", total_test_loss, total_test_step)
writer.add_scalar("test_accuracy", total_accuracy/test_data_size, total_test_step)
' Method 2: matplotlib, which needs the values to be converted first '
Should they become a Python list, or a NumPy array?
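For the matplotlib route, one workable option is to collect the `loss.item()` values (plain Python floats) in a list during training; `plt.plot` accepts lists directly, and a NumPy array would work equally well. A minimal sketch with made-up loss values standing in for the recorded ones; the list name `train_losses` and the file name `train_loss.png` are assumptions.

```python
import os
import tempfile
import matplotlib
matplotlib.use("Agg")  # render without a display
import matplotlib.pyplot as plt

# Hypothetical values; during training you would do: train_losses.append(loss.item())
train_losses = [2.30, 2.10, 1.95, 1.80, 1.72]
steps = [100 * (k + 1) for k in range(len(train_losses))]  # logged every 100 steps

plt.figure()
plt.plot(steps, train_losses, label="train_loss")
plt.xlabel("total_train_step")
plt.ylabel("loss")
plt.legend()

out_path = os.path.join(tempfile.gettempdir(), "train_loss.png")
plt.savefig(out_path)
plt.close()
print(out_path)
```

Calling `.item()` before appending matters: storing the loss tensors themselves would keep each one's computation graph alive.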