Current location: Home > PyTorch (network model training)
PyTorch (network model training)
2022-06-26 05:40:00 【Yuetun】
Table of contents
Network model training
Aside
Difference between a tensor and `.item()`
import torch

a = torch.tensor(5)  # 0-dim tensor holding a single number
as_python = a.item()  # .item() extracts that number as a plain Python int

print(a)          # the tensor object itself: tensor(5)
print(as_python)  # the plain Python value: 5

import torch

# Scores for 2 samples (rows) x 2 classes (columns).
output = torch.tensor([[0.1, 0.2], [0.05, 0.4]])
# argmax(1): index of the largest value in each row (the predicted class per sample);
# argmax(0) would instead give the index of the largest value in each column.
preds = output.argmax(1)
print(preds)
target = torch.tensor([0, 1])
matches = preds == target  # element-wise comparison -> bool tensor
print(matches)
print(matches.sum())  # number of correct predictions

Building the model (model.py)
import torch
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Linear, Flatten


class Dun(nn.Module):
    """Small CNN mapping a batch of 3x32x32 images (e.g. CIFAR-10) to 10 logits.

    Each Conv2d(kernel 5, padding 2) keeps the spatial size and each
    MaxPool2d(2) halves it: 32 -> 16 -> 8 -> 4. With 64 channels left, the
    flattened feature vector has 64 * 4 * 4 = 1024 entries.
    """

    def __init__(self):
        super().__init__()
        layers = [
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),  # 64 channels * 4 * 4 after three 2x poolings
            Linear(64, 10),
        ]
        # The attribute name ``model1`` fixes the state_dict keys
        # ("model1.0.weight", ...), so it must stay unchanged for saved models.
        self.model1 = Sequential(*layers)

    def forward(self, x):
        """Return the class logits for the image batch ``x``."""
        return self.model1(x)


if __name__ == '__main__':
    # Smoke test: 64 dummy images should give an output of shape (64, 10).
    net = Dun()
    dummy_batch = torch.ones((64, 3, 32, 32))
    print(net(dummy_batch).shape)
Training script
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
# TensorBoard logging of loss/accuracy curves
from torch.utils.tensorboard import SummaryWriter

# Provides the Dun network class from model.py. torch/nn are imported
# explicitly above instead of being relied on transitively via this star import.
from model import *

# Prepare the dataset: CIFAR-10 images converted to float tensors.
# NOTE: the two splits use different root directories, so the archive is
# downloaded twice; pointing both at the same root would avoid that.
train_data = torchvision.datasets.CIFAR10(root="./data_set_train", train=True,
                                          transform=torchvision.transforms.ToTensor(),
                                          download=True)
test_data = torchvision.datasets.CIFAR10(root="./data_set_test", train=False,
                                         transform=torchvision.transforms.ToTensor(),
                                         download=True)
# Dataset sizes (test size turns the correct-prediction count into an accuracy)
train_data_size = len(train_data)
test_data_size = len(test_data)
# Bug fix: the original wrote `print("...{}", format(n))` (comma instead of a
# dot), which printed a literal "{}" followed by the number.
print("train_data_size:{}".format(train_data_size))
print("test_data_size:{}".format(test_data_size))

# Load the data sets in batches of 64
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

# Create the network model
dun = Dun()
# Loss function
loss_fn = nn.CrossEntropyLoss()
# Optimizer: plain SGD
learning_rate = 1e-2
optimizer = torch.optim.SGD(dun.parameters(), lr=learning_rate)

# Step counters, used as the x-axis of the TensorBoard curves
total_train_step = 0
total_test_step = 0
# Number of training epochs
epoch = 10

writer = SummaryWriter("./logs")
try:
    for i in range(epoch):
        print("---------- The first {} Round training ------".format(i + 1))

        # Training phase. train() puts layers such as Dropout/BatchNorm into
        # training mode (harmless if the model has none).
        dun.train()
        for img, target in train_dataloader:
            output = dun(img)
            loss = loss_fn(output, target)
            # Clear old gradients, backpropagate, then update the parameters
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_train_step += 1
            print(" Training times :{},loss:{}".format(total_train_step, loss.item()))
            writer.add_scalar("train_loss", loss.item(), total_train_step)

        # Evaluation phase: summed loss and accuracy over the whole test set
        total_test_loss = 0
        total_accuracy = 0  # number of correctly classified test images
        # eval() puts Dropout/BatchNorm-style layers into inference mode
        dun.eval()
        with torch.no_grad():  # gradients are not needed for evaluation
            for img, target in test_dataloader:
                output = dun(img)
                total_test_loss += loss_fn(output, target).item()
                # argmax(1) = index of the highest score per row = predicted class
                total_accuracy += (output.argmax(1) == target).sum().item()
        # Plain Python float, so it prints/logs as a number rather than a tensor
        test_accuracy = total_accuracy / test_data_size
        print(" On the overall test set Loss:{}".format(total_test_loss))
        writer.add_scalar("test_loss", total_test_loss, total_test_step)
        print(" Accuracy on the overall test set :{}".format(test_accuracy))
        writer.add_scalar("test_accuracy", test_accuracy, total_test_step)
        total_test_step += 1

        # Save the whole model object (not only its state_dict) after each epoch
        torch.save(dun, "dun{}.pth".format(i))
        print(" Save the model ")
finally:
    # Close the event writer even if training is interrupted
    writer.close()
GPU training
Method 1

Call `.cuda()` on three things: the model, the loss function, and the data (both the training and the test batches). The training script above serves as the example:
# Model: move the network to the GPU
if torch.cuda.is_available():  # only use the GPU when CUDA is available
    dun=dun.cuda()
# Loss function: move it to the GPU as well
if torch.cuda.is_available():
    loss_fn=loss_fn.cuda()
# Data (both training and testing). These lines belong inside each loop over
# a DataLoader, right after `img, target = data`, so every batch is moved.
if torch.cuda.is_available():
    img = img.cuda()
    target = target.cuda()
Method 2:
# Define the training device: the GPU when CUDA is available, otherwise the CPU.
# With several GPUs, a specific one can be selected as "cuda:0", "cuda:1", ...
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
Then replace each `.cuda()` call from method 1 with `.to(device)`:
# Move the model to the selected device (replaces dun.cuda() from method 1)
dun=dun.to(device)
# Do the same for the loss function and for each data batch, e.g.
# loss_fn = loss_fn.to(device); img = img.to(device); target = target.to(device)
Viewing GPU information (for example, run `nvidia-smi` in a terminal)

Complete model verification
Check the class labels of the CIFAR-10 dataset: the index predicted by the model is a position in this class list.
import torch
import torchvision
from PIL import Image
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Linear, Flatten


# Building the neural network. This must match the training-time definition
# (same class name and `model1` layout) for the pickled model to load.
# NOTE(review): torch.save(model) pickles the class by module path. A model
# saved by a training script that did `from model import *` references
# `model.Dun`, so model.py must be importable when loading; this local copy
# is only used for checkpoints saved from a script where Dun lived in __main__.
class Dun(nn.Module):
    def __init__(self):
        super().__init__()
        self.model1 = Sequential(Conv2d(3, 32, 5, padding=2),
                                 MaxPool2d(2),
                                 Conv2d(32, 32, 5, padding=2),
                                 MaxPool2d(2),
                                 Conv2d(32, 64, 5, padding=2),
                                 MaxPool2d(2),
                                 Flatten(),
                                 Linear(1024, 64),
                                 Linear(64, 10))

    def forward(self, x):
        x = self.model1(x)
        return x


# CIFAR-10 class names in label-index order (torchvision CIFAR10.classes)
CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck"]

image_path = "./img/1.png"
image = Image.open(image_path)
print(image)
# Bug fix: PNGs are often RGBA (4 channels) or palette images. The network
# expects 3 channels, and the reshape to (1, 3, 32, 32) below fails otherwise.
image = image.convert("RGB")
# Resize to the 32x32 input size and convert to a float tensor of shape (3, 32, 32)
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)),
                                            torchvision.transforms.ToTensor()])
image = transform(image)
print(image.shape)

# Load the whole pickled model. map_location remaps tensors saved on a GPU
# onto the CPU, so a GPU-trained checkpoint can be evaluated on a CPU machine.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which
# rejects whole-model pickles; on those versions pass weights_only=False
# (only for checkpoints you trust).
model = torch.load("dun0.pth", map_location=torch.device("cpu"))
print(model)
# Add the batch dimension: (3, 32, 32) -> (1, 3, 32, 32)
image = torch.reshape(image, (1, 3, 32, 32))
model.eval()  # switch the model to inference mode
# Run the model without tracking gradients
with torch.no_grad():
    output = model(image)
print(output)
pred = output.argmax(1)
print(pred)
print(CLASSES[pred.item()])  # human-readable predicted class
Sidebar recommendations
- uni-app: sticky (fixed-on-scroll) header style
- 1212312321
- Win socket programming (Mengxin initial battle)
- 9 common classes
- Supplementary course on basic knowledge of IM development (II): how to design a server-side storage architecture for a large number of image files?
- Feelings of virtual project failure
- Describe an experiment of Kali ARP in LAN
- Chapter 9 setting up structured logging (I)
- skimage.morphology.medial_axis
- Cyclic displacement
You may also like
Random recommendations
Redis usage and memory optimization
How Navicat reuses the current connection information to another computer
There are applications related to web network request API in MATLAB (under update)
Official image acceleration
The model defined (modified) in PyTorch loads some required pre-trained model parameters and freezes them
Recursively traverse directory structure and tree presentation
【C语言】深度剖析数据在内存中的存储
Feelings of virtual project failure
1212312321
Positioning setting horizontal and vertical center (multiple methods)
Customize WebService as a proxy to solve the problem of Silverlight calling WebService across domains
Henkel database custom operator '~~‘
最后一次飞翔
Cartographer fast_correlative_scan_matcher_2d: branch-and-bound coarse matching
Internship May 29, 2019
SDN based DDoS attack mitigation
Win socket programming (Mengxin initial battle)
Daily production training report (17)
Thinking about bad money expelling good money
使用Jedis監聽Redis Stream 實現消息隊列功能





![C# 39. Conversion between string type and byte[] type (actual measurement)](/img/33/046aef4e0c1d7c0c0d60c28e707546.png)


