Understand the construction of the entire network model
2022-06-28 20:29:00 【seven_ Not Severn】
# Prepare the dataset
# Preprocess the dataset
# Split the dataset into a training set and a test set
# Use DataLoader to load the datasets, e.g.:
from torch.utils.data import DataLoader

dataloder_train = DataLoader(train_data, batch_size=64, drop_last=False)
dataloder_test = DataLoader(test_data, batch_size=64, drop_last=False)
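The post does not spell out the dataset steps. The sketch below shows how they can fit together. It assumes a synthetic dataset of 100 random 3×32×32 "images" with 10 classes, standing in for the real data. `random_split` divides the data, `DataLoader` batches it, and the last lines show what `drop_last` controls.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Hypothetical stand-in data: 100 random 3x32x32 "images" with labels 0-9
images = torch.randn(100, 3, 32, 32)
labels = torch.randint(0, 10, (100,))
full_data = TensorDataset(images, labels)

# Divide the dataset: 80 samples for training, 20 for testing (fixed seed for repeatability)
generator = torch.Generator().manual_seed(0)
train_data, test_data = random_split(full_data, [80, 20], generator=generator)

# Load the datasets in batches of 64
dataloder_train = DataLoader(train_data, batch_size=64, drop_last=False)
dataloder_test = DataLoader(test_data, batch_size=64, drop_last=False)

# drop_last=False keeps the final, smaller batch; drop_last=True discards it
batch_sizes_keep = [imgs.shape[0] for imgs, _ in dataloder_train]
batch_sizes_drop = [imgs.shape[0] for imgs, _ in DataLoader(train_data, batch_size=64, drop_last=True)]
print(batch_sizes_keep)  # [64, 16]
print(batch_sizes_drop)  # [64]
```

With 80 training samples, `drop_last=False` produces a full batch of 64 plus a batch of 16. `drop_last=True` silently skips those 16 samples in every epoch.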
# Build the neural network. It can go in a separate model.py file and be checked there to make sure the model is correct.
import torch
from torch import nn


# Build the neural network
# Note: the training script must import this model with `from model import *`;
# otherwise the name Test is not defined there and an error is raised.
class Test(nn.Module):
    def __init__(self):
        super(Test, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
            nn.Linear(in_features=1024, out_features=64),
            nn.Linear(in_features=64, out_features=10)
        )

    def forward(self, x):
        x = self.model(x)
        return x


# Check that the model is correct: a batch of 64 RGB 32x32 images should give an
# output of shape (64, 10). This block runs only when model.py is executed directly,
# not when it is imported.
if __name__ == '__main__':
    test = Test()
    dummy_input = torch.ones((64, 3, 32, 32))
    output = test(dummy_input)
    print(output.shape)  # torch.Size([64, 10])
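The first `nn.Linear` uses `in_features=1024`, and the layer sizes explain why. Each `Conv2d` here uses `kernel_size=5, padding=2`, which keeps the spatial size at 32×32. Each `MaxPool2d(kernel_size=2)` halves it, so the feature map goes 32 → 16 → 8 → 4. `Flatten` then yields 64 × 4 × 4 = 1024 values per image. The sketch below rebuilds the same layer stack and records every intermediate shape, which is a quick way to check that number.

```python
import torch
from torch import nn

# Same layer stack as Test.model above
layers = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=5, padding=2),
    nn.MaxPool2d(kernel_size=2),
    nn.Conv2d(32, 32, kernel_size=5, padding=2),
    nn.MaxPool2d(kernel_size=2),
    nn.Conv2d(32, 64, kernel_size=5, padding=2),
    nn.MaxPool2d(kernel_size=2),
    nn.Flatten(),
    nn.Linear(1024, 64),
    nn.Linear(64, 10),
)

x = torch.ones((64, 3, 32, 32))
shapes = []
for layer in layers:
    x = layer(x)  # pass the batch through one layer at a time
    shapes.append(tuple(x.shape))
    print(type(layer).__name__, tuple(x.shape))
```

If the 1024 were wrong (for example, after changing the input size), the mismatch would show up as an error at the first `Linear` layer.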
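# Create the network model (in the training script, import it first)
import torch
from torch import nn
from model import *  # brings in the Test class defined in model.py

test = Test()
# Loss function: cross-entropy for multi-class classification
loss_fc = nn.CrossEntropyLoss()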
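`nn.CrossEntropyLoss` takes two inputs. The first is the raw scores (logits) that `Test` outputs, with shape `(batch, 10)`. The second is the targets as class indices, with shape `(batch,)`. The loss applies log-softmax itself, so the model needs no softmax layer at the end. The small check below uses made-up numbers and confirms that the loss equals the mean of −log softmax of the correct class's score.

```python
import torch
from torch import nn

loss_fc = nn.CrossEntropyLoss()

# Made-up logits for 2 samples and 3 classes, plus their true class indices
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])
targets = torch.tensor([0, 2])

loss = loss_fc(logits, targets)

# Same value by hand: take the log-probability of each correct class, negate, average
log_probs = torch.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(2), targets].mean()
print(loss.item(), manual.item())
```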
# Optimizer
learning_rate = 1e-2
optimizer = torch.optim.SGD(test.parameters(), lr=learning_rate)
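Plain `torch.optim.SGD` (no momentum or weight decay) updates every parameter on `optimizer.step()` as `p = p - lr * p.grad`. The tiny sketch below uses a single made-up parameter and a toy loss, with `lr=1e-2` as above. It runs the same `zero_grad` / `backward` / `step` sequence that the training loop uses.

```python
import torch

w = torch.tensor([1.0, -2.0], requires_grad=True)  # a made-up parameter
optimizer = torch.optim.SGD([w], lr=1e-2)

loss = (w ** 2).sum()   # toy loss; its gradient is 2 * w = [2.0, -4.0]
optimizer.zero_grad()   # clear any old gradients
loss.backward()         # fills w.grad
optimizer.step()        # w <- w - 0.01 * w.grad

print(w.detach().tolist())  # approximately [0.98, -1.96]
```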
# Use TensorBoard to record how the loss changes during training
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("logs")
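# Meaning of some variables:
# step: number of training steps (one step = one batch), counted in total_train_step
# epoch: number of training rounds (one round = one full pass over the training set)
total_train_step = 0  # training steps taken so far
total_test_step = 0   # test rounds run so far
epoch = 10            # number of training rounds (example value, choose your own)
test_data_size = len(test_data)  # number of samples in the test set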
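One step processes one batch, and one epoch is one full pass over the training set. With `drop_last=False`, each epoch therefore has ⌈N / batch_size⌉ steps. The quick calculation below shows how the counters grow. It assumes a hypothetical training set of 50,000 images and 10 epochs; substitute your own numbers.

```python
import math

train_data_size = 50000  # hypothetical number of training samples
batch_size = 64
epoch = 10  # hypothetical number of training rounds

steps_per_epoch = math.ceil(train_data_size / batch_size)  # the last partial batch counts as a step
total_steps = steps_per_epoch * epoch
print(steps_per_epoch, total_steps)  # 782 7820
```

With printing every 100 steps, as in the loop below, such a run would print 78 training-loss lines.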
# ------------- Train the model and test whether it is well trained -------------
# Start training:
for i in range(epoch):
    print("---------- Round {} of training begins ----------".format(i + 1))
    # The training step begins
    test.train()  # training mode (only matters for layers such as Dropout/BatchNorm)
    for data in dataloder_train:
        imgs, targets = data
        output_train = test(imgs)
        loss = loss_fc(output_train, targets)  # compute the loss value
        # Use the optimizer to optimize the model
        optimizer.zero_grad()  # clear the gradients left over from the previous step
        loss.backward()  # backpropagate: compute the gradient of every parameter
        optimizer.step()  # let the optimizer update the parameters
        total_train_step = total_train_step + 1  # count the training steps
        if total_train_step % 100 == 0:  # print only every 100 steps to reduce output
            # loss is a tensor; loss.item() returns its value as a plain Python number
            print("Training step: {}, Loss: {}".format(total_train_step, loss.item()))
            writer.add_scalar("train_loss", loss.item(), total_train_step)  # log every 100 steps

    # To know whether the model is training well, test it after every round of training.
    # The loss on the test set is used to judge the model; the optimizer does no tuning during testing.
    # The test step begins
    test.eval()  # evaluation mode
    total_test_loss = 0  # loss over the whole test set
    total_accuracy = 0  # number of correct predictions over the whole test set
    with torch.no_grad():  # no gradients are tracked inside this block, so nothing gets tuned
        for data in dataloder_test:
            imgs, targets = data
            output_test = test(imgs)
            # This loss covers only the current batch
            loss = loss_fc(output_test, targets)
            # We want the loss over the entire test set, so accumulate it
            total_test_loss = total_test_loss + loss.item()
            # Count predictions that match the targets; argmax(1) returns, for each row
            # (one sample), the index of the largest score across dim 1 (the classes)
            accuracy = (output_test.argmax(1) == targets).sum().item()
            total_accuracy = total_accuracy + accuracy  # total number of correct predictions
    print("Loss on the whole test set: {}".format(total_test_loss))
    print("Accuracy on the whole test set: {}".format(total_accuracy / test_data_size))
    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    writer.add_scalar("test_accuracy", total_accuracy / test_data_size, total_test_step)
    total_test_step = total_test_step + 1

    # Save the model after every round of training
    torch.save(test, "test_{}.pth".format(i))  # method 1: save the whole model (structure + parameters)
    # Method 2: save only the parameters, as a dict (state_dict)
    # torch.save(test.state_dict(), "test_{}.pth".format(i))
    print("Model saved")
    # The model is saved once per round, so each saved model can be loaded later and its
    # results compared; they do differ from round to round.

writer.close()
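Two pieces of the loop above can be tried on their own.

- **Accuracy:** using made-up scores, the sketch shows how `argmax(1)` turns a `(batch, classes)` output into predicted class indices and how the correct predictions are counted.
- **Saving:** it saves a model with method 2 (`state_dict`) and loads the parameters back into a fresh instance.

The small `nn.Linear` model and the file name `demo_params.pth` are placeholders, not from the post. Method 2 is generally the more portable choice. Loading a whole model saved with method 1 requires the model class to be importable (for example through `from model import *`). PyTorch 2.6 and later also default `torch.load` to `weights_only=True`, so loading a whole pickled model needs `weights_only=False`.

```python
import os
import tempfile

import torch
from torch import nn

# Made-up scores for 4 samples and 3 classes, and the true labels
output_test = torch.tensor([[0.1, 2.0, 0.3],
                            [1.5, 0.2, 0.1],
                            [0.0, 0.1, 0.9],
                            [0.7, 0.8, 0.2]])
targets = torch.tensor([1, 0, 1, 1])

preds = output_test.argmax(1)  # index of the largest score in each row
correct = (preds == targets).sum().item()  # number of matching predictions
accuracy = correct / len(targets)

# Method 2: save only the parameters (state_dict), then load them into a new instance
model = nn.Linear(3, 2)  # placeholder model
path = os.path.join(tempfile.mkdtemp(), "demo_params.pth")  # placeholder file name
torch.save(model.state_dict(), path)

restored = nn.Linear(3, 2)  # the same architecture must be built first
restored.load_state_dict(torch.load(path))
same = all(torch.equal(a, b) for a, b in zip(model.state_dict().values(),
                                              restored.state_dict().values()))
print(preds.tolist(), correct, accuracy, same)
```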
# Plot how the loss or accuracy changes:
Either `writer = SummaryWriter("logs")` (TensorBoard) or `plt.figure` (matplotlib) can be used.
The first method: log the values produced during training directly.
writer = SummaryWriter("logs")
...
writer.add_scalar("train_loss", loss.item(), total_train_step)
writer.add_scalar("test_loss", total_test_loss, total_test_step)
writer.add_scalar("test_accuracy", total_accuracy/test_data_size, total_test_step)
The second method (matplotlib): the values have to be converted first.
Should they become a list or a NumPy array?
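For matplotlib, `plt.plot` accepts ordinary Python lists. Converting each loss tensor with `loss.item()` and appending the result to a list is therefore enough; a NumPy array would also work. The minimal sketch below follows that approach. Made-up loss values stand in for the `loss.item()` results logged in the training loop, and `train_loss.png` is a placeholder file name.

```python
import matplotlib
matplotlib.use("Agg")  # draw to a file; no window needed
import matplotlib.pyplot as plt
import torch

train_losses = []  # plain Python floats, one per logged step
steps = []
for step in range(100, 600, 100):
    loss = torch.tensor(2.3 - step / 1000)  # made-up stand-in for a real loss tensor
    train_losses.append(loss.item())  # .item() turns the 0-dim tensor into a float
    steps.append(step)

plt.figure()
plt.plot(steps, train_losses, label="train_loss")
plt.xlabel("total_train_step")
plt.ylabel("loss")
plt.legend()
plt.savefig("train_loss.png")  # placeholder file name
plt.close()
```

In the real loop, the two `append` calls would sit next to `writer.add_scalar("train_loss", ...)`.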