Understand the construction of the entire network model
2022-06-28 20:29:00 【seven_ Not Severn】
# Prepare the dataset
# Dataset processing
# Split the dataset
# Load the dataset with DataLoader
eg: dataloader_train = DataLoader(train_data, batch_size=64, drop_last=False)
dataloader_test = DataLoader(test_data, batch_size=64, drop_last=False)
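The train_data and test_data above have to be prepared first. Below is a minimal sketch of the dataset steps, assuming the CIFAR-10 dataset from torchvision (the post does not name the dataset, but CIFAR-10 matches the 3x32x32 input and the 10 output classes of the network below); test_data_size is used later when computing accuracy.

import torchvision
from torch.utils.data import DataLoader

# Download the train/test splits and convert the images to tensors (assumed: CIFAR-10)
train_data = torchvision.datasets.CIFAR10(root="./dataset", train=True,
                                          transform=torchvision.transforms.ToTensor(), download=True)
test_data = torchvision.datasets.CIFAR10(root="./dataset", train=False,
                                         transform=torchvision.transforms.ToTensor(), download=True)

# Dataset sizes; test_data_size is needed later for the accuracy calculation
train_data_size = len(train_data)
test_data_size = len(test_data)
print("Training set size: {}".format(train_data_size))
print("Test set size: {}".format(test_data_size))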
# Build the neural network. It can be put in a separate model.py file (where the model can also be sanity-checked).
import torch
from torch import nn
# Building neural networks
# Note: the network (including the Sequential inside it) is defined here in model.py; in the training script do not redefine it, bring it in with `from model import *` instead, otherwise an error will be reported.
class Test(nn.Module):
    def __init__(self):
        super(Test, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
            # After three 2x2 poolings a 32x32 input becomes 4x4, so 64 * 4 * 4 = 1024 features
            nn.Linear(in_features=1024, out_features=64),
            nn.Linear(in_features=64, out_features=10)
        )

    def forward(self, x):
        x = self.model(x)
        return x

# Sanity-check the network: run a dummy batch through it and check the output shape
if __name__ == '__main__':
    test = Test()
    input = torch.ones((64, 3, 32, 32))
    output = test(input)
    print(output.shape)  # expected: torch.Size([64, 10])
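Assuming the class above lives in its own model.py, the training script that follows could start like this (a minimal sketch; the file name train.py is only illustrative):

# train.py (illustrative name)
import torch
from torch import nn
from model import *   # brings in the Test class defined in model.py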
# Create a network model
test = Test()
# Loss function
loss_fc = nn.CrossEntropyLoss()
# Optimizer
learning_rate = 1e-2
optimizer = torch.optim.SGD(test.parameters(), lr=learning_rate)
# Add TensorBoard to record how the loss changes
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("logs")
# Understanding of some parameters:
# step: the number of training iterations (one batch per step)
# epoch: the number of training rounds (one full pass over the dataset per round)
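The counters and the number of rounds used below are not initialized anywhere in the snippet; a minimal sketch of the missing setup (the value of epoch is an assumption):

total_train_step = 0   # number of batches trained so far
total_test_step = 0    # number of evaluations done so far
epoch = 10             # number of training rounds (assumed value)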
'------------- Train the model and test whether the model is well trained ----------------------'
# Start training:
for i in range(epoch):
    print("---------- Round {} of training begins ----------".format(i + 1))

    # The training steps begin
    for data in dataloader_train:
        imgs, targets = data
        output_train = test(imgs)
        loss = loss_fc(output_train, targets)  # compute the loss value

        # Use the optimizer to update the model
        optimizer.zero_grad()   # clear the gradients
        loss.backward()         # backpropagate the loss to get the gradient of each parameter
        optimizer.step()        # let the optimizer update the parameters

        total_train_step = total_train_step + 1  # record the number of training steps
        if total_train_step % 100 == 0:  # print only every 100 steps to reduce output
            # loss and loss.item() can differ: loss is a tensor, loss.item() returns a plain Python number
            print("Training step: {}, loss: {}".format(total_train_step, loss.item()))
            writer.add_scalar("train_loss", loss.item(), total_train_step)  # record every 100 steps
    # To know whether the model is being trained well, evaluate it after each round of training.
    # The loss on the test set is used to judge the model; the optimizer must not update any parameters during testing.
    # The test steps begin
    total_test_loss = 0   # accumulated loss over the test set
    total_accuracy = 0    # number of correct predictions over the test set
    with torch.no_grad():  # inside this with-block no gradients are tracked, so no parameters can be updated
        for data in dataloader_test:
            imgs, targets = data
            output_test = test(imgs)
            # This loss is only for the current batch of the test set
            loss = loss_fc(output_test, targets)
            # We want the loss over the whole test set, so accumulate a total loss
            total_test_loss = total_test_loss + loss
            # Count how many predictions match the targets; argmax(1) picks the predicted class in each row
            accuracy = (output_test.argmax(1) == targets).sum()
            total_accuracy = total_accuracy + accuracy  # total number of correct predictions
    print("Loss over the whole test set: {}".format(total_test_loss.item()))
    print("Accuracy over the whole test set: {}".format(total_accuracy / test_data_size))
    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    writer.add_scalar("test_accuracy", total_accuracy / test_data_size, total_test_step)
    total_test_step = total_test_step + 1
    # Save the model after each round of training
    torch.save(test, "test_{}.pth".format(i))  # save method 1: save the whole model
    # Save method 2: save only the parameters as a dictionary (state_dict)
    # torch.save(test.state_dict(), "test_{}".format(i))
    print("Model saved")
    # The model is saved once per round of training, so afterwards each saved checkpoint can be loaded and compared; the results really do differ between rounds.

writer.close()
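As a follow-up, a minimal sketch of how the saved checkpoints can be loaded back later (the file name test_0.pth simply follows the save pattern above with i = 0; in both cases the Test class definition must be importable, e.g. via from model import *):

# Method 1: the whole model was saved, so torch.load returns a ready-to-use model
model = torch.load("test_0.pth")

# Method 2: only the state dict was saved, so build the model first, then load the parameters
model = Test()
model.load_state_dict(torch.load("test_0.pth"))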
# Plot how the loss or the accuracy changes:
# Either writer = SummaryWriter("logs") (TensorBoard) or matplotlib's plt.figure can be used.
'The first method: log the values directly as the model trains'
writer = SummaryWriter("logs")
...
writer.add_scalar("train_loss", loss.item(), total_train_step)
writer.add_scalar("test_loss", total_test_loss, total_test_step)
writer.add_scalar("test_accuracy", total_accuracy/test_data_size, total_test_step)
'The second method: the values need to be collected and converted first'
# Should they be collected into a list, or a NumPy array? See the sketch below.
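A minimal sketch of the second method, assuming the values are collected into plain Python lists during training and plotted with matplotlib afterwards (the names train_steps and train_losses are illustrative):

import matplotlib.pyplot as plt

# Assumed: these lists are filled inside the training loop, e.g.
#   train_steps.append(total_train_step)
#   train_losses.append(loss.item())
train_steps = []
train_losses = []

# After training, plot the loss curve
plt.figure()
plt.plot(train_steps, train_losses)
plt.xlabel("training step")
plt.ylabel("loss")
plt.title("train_loss")
plt.show()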