A complete example of saving a PyTorch model, plus: does saving a PyTorch model only save trainable parameters? Yes (with a solution)
2022-07-02 19:55:00 【FakeOccupational】
The test uses a linear model. It also raises a follow-up question: does saving a PyTorch model only save trainable parameters?
Save the model
# Import packages
import glob
import os
import random  # used by the data iterator to shuffle the samples

import matplotlib.pyplot as plt
import torch

# Generate the dataset: x1 belongs to class 0, x2 to class 1
n_data = torch.ones(50, 2)  # base tensor that fixes the shape of the data
x1 = torch.normal(2 * n_data, 1)  # class 0 samples, shape=(50, 2)
y1 = torch.zeros(50)  # class 0 labels, shape=(50,)
x2 = torch.normal(-2 * n_data, 1)  # class 1 samples, shape=(50, 2)
y2 = torch.ones(50)  # class 1 labels, shape=(50,)

# x and y must have the following form (torch.cat concatenates the data)
x = torch.cat((x1, x2), 0).type(torch.FloatTensor)
y = torch.cat((y1, y2), 0).type(torch.FloatTensor)

# Visualize the dataset
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=y.data.numpy(), s=100, lw=0, cmap='RdYlGn')
plt.show()

# Data iterator
def data_iter(batch_size, x, y):
    num_examples = len(x)
    indices = list(range(num_examples))
    random.shuffle(indices)  # read the samples in random order
    for i in range(0, num_examples, batch_size):
        # the last batch may be smaller than batch_size
        j = torch.LongTensor(indices[i: min(i + batch_size, num_examples)])
        yield x.index_select(0, j), y.index_select(0, j)
#############################################################################################################
def saver(model_state_dict, optimizer_state_dict, model_path, epoch, max_to_save=30):
    # Only look at checkpoint files ('h<epoch>'), not at every file under model_path
    total_models = glob.glob(model_path + 'h*')
    if len(total_models) >= max_to_save:
        total_models.sort()
        os.remove(total_models[0])
    state_dict = {}
    state_dict["model_state_dict"] = model_state_dict
    state_dict["optimizer_state_dict"] = optimizer_state_dict
    save_path = model_path + 'h' + str(epoch)
    torch.save(state_dict, save_path)
    print('model {} saved successfully!'.format(save_path))
################################################################################################################
import torch.nn as nn

class net(nn.Module):
    def __init__(self, **kwargs):
        super(net, self).__init__(**kwargs)
        self.net = nn.Sequential(nn.Linear(2, 1), nn.ReLU())

    def forward(self, x):
        return self.net(x)

def loss(y_hat, y):
    # squared error; y is reshaped to match y_hat
    return (y_hat - y.view(y_hat.size())) ** 2 / 2

def accuracy(y_hat, y):
    """Return the fraction of correct predictions."""
    # outputs above 0.5 count as class 1; y_hat is reshaped to match y
    result = (y_hat.view(y.size()) > 0.5).type(y.dtype)
    return float((result == y).type(y.dtype).mean())
lr = 0.03  # not used below; Adam is created with a learning rate of 1e-4
num_epochs = 3  # number of epochs
batch_size = 10  # batch size
model = net()
params = list(model.parameters())
optimizer = torch.optim.Adam(params, 1e-4)
for epoch in range(num_epochs):
    for X, y_train in data_iter(batch_size, x, y):
        optimizer.zero_grad()
        l = loss(model(X), y_train).sum()  # loss over the mini-batch X, y_train
        l.backward()
        optimizer.step()
        print(l)
    saver(model.state_dict(), optimizer.state_dict(), "./", epoch + 1, max_to_save=100)
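One caveat with saver: total_models.sort() orders the file names as strings, so h10 sorts before h2, and once there are ten or more checkpoints the file that gets removed is not necessarily the oldest. A minimal sketch of a numeric ordering follows, assuming checkpoints are named h<epoch> as above. The helper name checkpoints_to_remove is hypothetical and not part of the original code.

```python
import re

def checkpoints_to_remove(file_names, max_to_save):
    """Return the checkpoint names to delete so that fewer than max_to_save remain.

    Hypothetical helper: assumes names end in 'h<epoch>', e.g. './h12'.
    """
    pattern = re.compile(r"h(\d+)$")
    numbered = []
    for name in file_names:
        match = pattern.search(name)
        if match:  # skip files that are not checkpoints
            numbered.append((int(match.group(1)), name))
    numbered.sort()  # numeric order: lowest epoch first
    excess = len(numbered) - max_to_save + 1  # leave room for one new file
    return [name for _, name in numbered[:max(excess, 0)]]

names = ["./h1", "./h2", "./h10", "./notes.txt"]
print(sorted(names))                    # string order puts h10 before h2
print(checkpoints_to_remove(names, 3))  # numeric order picks the lowest epoch
```

Inside saver, the returned names could be passed to os.remove in place of total_models[0].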
Load the model
# Import packages
import torch
import torch.nn as nn

# Data generation, saver(), loss() and accuracy() are identical to the save script above.
# They are not needed for loading, so they are omitted here, and no training is run.
class net(nn.Module):
    def __init__(self, **kwargs):
        super(net, self).__init__(**kwargs)
        self.net = nn.Sequential(nn.Linear(2, 1), nn.ReLU())

    def forward(self, x):
        return self.net(x)

# Build the same model and optimizer as in the save script before restoring their state
model = net()
params = list(model.parameters())
optimizer = torch.optim.Adam(params, 1e-4)
def loader(model_path):
    state_dict = torch.load(model_path)
    model_state_dict = state_dict["model_state_dict"]
    optimizer_state_dict = state_dict["optimizer_state_dict"]
    return model_state_dict, optimizer_state_dict

model_state_dict, optimizer_state_dict = loader("h1")  # checkpoint written after epoch 1
model.load_state_dict(model_state_dict)
optimizer.load_state_dict(optimizer_state_dict)
print('pretrained models loaded!')
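To check that a checkpoint really restores the model, the loaded weights can be compared with the ones that were saved. The sketch below is an illustration under stated assumptions rather than the author's code: it writes to an in-memory io.BytesIO buffer instead of the file h1, and it uses a bare nn.Sequential with the same layers (the helper make_model is hypothetical).

```python
import io

import torch
import torch.nn as nn

def make_model():
    # same layers as the article's model (hypothetical helper)
    return nn.Sequential(nn.Linear(2, 1), nn.ReLU())

torch.manual_seed(0)
trained = make_model()
optimizer = torch.optim.Adam(trained.parameters(), lr=1e-4)

# One optimisation step so that the optimizer has state worth saving
optimizer.zero_grad()
trained(torch.randn(10, 2)).pow(2).sum().backward()
optimizer.step()

# Save model and optimizer state in one dictionary, as saver() does
buffer = io.BytesIO()
torch.save({"model_state_dict": trained.state_dict(),
            "optimizer_state_dict": optimizer.state_dict()}, buffer)
buffer.seek(0)

# Restore into a freshly initialised model and optimizer
checkpoint = torch.load(buffer)
restored = make_model()
restored.load_state_dict(checkpoint["model_state_dict"])
restored_opt = torch.optim.Adam(restored.parameters(), lr=1e-4)
restored_opt.load_state_dict(checkpoint["optimizer_state_dict"])

# Every tensor in the restored state_dict should equal the saved one
same = all(torch.equal(a, b) for a, b in
           zip(trained.state_dict().values(), restored.state_dict().values()))
print(same)
```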
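Does saving a PyTorch model only save trainable parameters? Yes
More precisely, state_dict() only contains what is registered with the module (its parameters and buffers). In the model below, notrain is a plain tensor attribute, so it is not registered and is not saved:
class net(nn.Module):
    def __init__(self, **kwargs):
        super(net, self).__init__(**kwargs)
        self.net = nn.Sequential(nn.Linear(2, 1), nn.ReLU())
        self.notrain = torch.rand((64, 64), dtype=torch.float)  # plain tensor attribute

    def forward(self, x):
        return self.net(x)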
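This can be checked by listing the keys of state_dict(). The sketch below repeats the model under the hypothetical name NetPlain so that it runs on its own.

```python
import torch
import torch.nn as nn

class NetPlain(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(2, 1), nn.ReLU())
        # plain tensor attribute: not registered with the module
        self.notrain = torch.rand((64, 64), dtype=torch.float)

    def forward(self, x):
        return self.net(x)

plain_keys = list(NetPlain().state_dict().keys())
print(plain_keys)  # ['net.0.weight', 'net.0.bias'], no 'notrain'
```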

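Solution
Wrap the tensor in torch.nn.Parameter so that it is registered with the module and included in state_dict():
class net(nn.Module):
    def __init__(self, **kwargs):
        super(net, self).__init__(**kwargs)
        self.net = nn.Sequential(nn.Linear(2, 1), nn.ReLU())
        # self.notrain = torch.rand((64, 64), dtype=torch.float)
        self.notrain = torch.nn.Parameter(torch.ones(64, 64))

    def forward(self, x):
        return self.net(x)

# Re-create the model and optimizer so that they use the new class.
# notrain is not used in forward(), so it receives no gradient and Adam leaves it unchanged.
model = net()
optimizer = torch.optim.Adam(list(model.parameters()), 1e-4)
for epoch in range(num_epochs):
    for X, y_train in data_iter(batch_size, x, y):
        optimizer.zero_grad()
        l = loss(model(X), y_train).sum()  # loss over the mini-batch X, y_train
        l.backward()
        optimizer.step()
        print(l)
    # change notrain by hand so that the saved value can be told apart from the initial one
    model.notrain.data = model.notrain.data + 2
    saver(model.state_dict(), optimizer.state_dict(), "./", epoch + 1, max_to_save=100)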
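An alternative to nn.Parameter, not used in the original post, is register_buffer: a buffer is saved in state_dict() but is not returned by parameters(), so it is never handed to the optimizer. A minimal sketch, with the class name NetBuffer chosen here for illustration:

```python
import torch
import torch.nn as nn

class NetBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(2, 1), nn.ReLU())
        # registered buffer: saved with the model, but not a trainable parameter
        self.register_buffer("notrain", torch.ones(64, 64))

    def forward(self, x):
        return self.net(x)

model = NetBuffer()
model.notrain.add_(2)  # in-place update, mirroring the manual "+ 2" above

buffer_keys = list(model.state_dict().keys())
param_names = [name for name, _ in model.named_parameters()]
print("notrain" in buffer_keys)   # the buffer is part of the saved state
print("notrain" in param_names)   # but it is not among the parameters
```

Which of the two fits depends on whether the tensor should ever receive gradients: nn.Parameter is trainable by default, while a buffer is state that only travels with the model.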