当前位置:网站首页>pytorch 模型保存的完整例子+pytorch 模型保存只保存可训练参数吗?是(+解决方案)
pytorch 模型保存的完整例子+pytorch 模型保存只保存可训练参数吗?是(+解决方案)
2022-07-02 19:02:00 【FakeOccupational】
测试使用的是一个线性模型(linear model)。本文还讨论一个相关问题:pytorch 模型保存只保存可训练参数吗?
save模型
# 导入包
import glob
import os
import torch
import matplotlib.pyplot as plt
import random #用于数据迭代器生成随机数据
# Toy binary-classification data set, 50 samples per class:
# class 0 (x1 / y1) is drawn around (2, 2), class 1 (x2 / y2) around (-2, -2), std 1.
n_data = torch.ones(50, 2)  # template: the per-sample means are +2 * n_data / -2 * n_data
x1 = torch.normal(2 * n_data, 1)  # class-0 features, shape (50, 2)
y1 = torch.zeros(50)  # class-0 labels, shape (50,)
x2 = torch.normal(-2 * n_data, 1)  # class-1 features, shape (50, 2)
y2 = torch.ones(50)  # class-1 labels, shape (50,)
# Stack both classes along dim 0 and cast to float32:
# x has shape (100, 2), y has shape (100,).
x = torch.cat((x1, x2), dim=0).type(torch.FloatTensor)
y = torch.cat((y1, y2), dim=0).type(torch.FloatTensor)

# Scatter-plot the two clusters, coloured by label (transposing gives the two feature columns).
plt.scatter(*x.data.numpy().T, c=y.data.numpy(), s=100, lw=0, cmap='RdYlGn')
plt.show()


def data_iter(batch_size, x, y):
    """Yield ``(features, labels)`` mini-batches of ``x`` and ``y`` in a random order.

    Every sample is visited exactly once per call; the last batch holds the
    remainder and may be smaller than ``batch_size``.
    """
    order = list(range(len(x)))
    random.shuffle(order)  # random visiting order for this pass
    for start in range(0, len(order), batch_size):
        # Slicing clamps at the end of the list, so the short final batch needs no special case.
        batch_idx = torch.tensor(order[start:start + batch_size], dtype=torch.long)
        yield x.index_select(0, batch_idx), y.index_select(0, batch_idx)
#############################################################################################################
def saver(model_state_dict, optimizer_state_dict, model_path, epoch, max_to_save=30, prefix='h'):
    """Save a checkpoint as ``model_path + prefix + str(epoch)`` and prune old ones.

    The checkpoint is a dict with the keys ``"model_state_dict"`` and
    ``"optimizer_state_dict"``. Before writing, earlier checkpoints created by this
    function (files named ``model_path + prefix + <digits>``) are deleted, lowest
    number first, until fewer than ``max_to_save`` of them remain. Other files under
    ``model_path`` are never touched.

    Args:
        model_state_dict: typically ``model.state_dict()``.
        optimizer_state_dict: typically ``optimizer.state_dict()``.
        model_path: string prefix for the file, e.g. ``"./"`` (concatenated as-is).
        epoch: checkpoint number appended to the file name.
        max_to_save: maximum number of checkpoints kept on disk.
        prefix: file-name prefix of the checkpoints; the default ``'h'`` keeps the
            historical names ``h1``, ``h2``, ... .
    """
    base = model_path + prefix
    target = base + str(epoch)
    stem = os.path.basename(base)
    target_name = os.path.basename(target)
    # Only collect files that look like our checkpoints. Globbing model_path + '*'
    # would match (and then delete) every file in the directory, e.g. the script itself.
    existing = []
    for path in glob.glob(glob.escape(base) + '*'):
        name = os.path.basename(path)
        suffix = name[len(stem):]
        # Skip the file we are about to overwrite: it does not add a new checkpoint.
        if suffix.isdecimal() and name != target_name:
            existing.append((int(suffix), path))
    # Numeric order: a plain string sort would rank 'h10' before 'h2'.
    existing.sort()
    while existing and len(existing) >= max_to_save:
        os.remove(existing.pop(0)[1])
    state_dict = {
        "model_state_dict": model_state_dict,
        "optimizer_state_dict": optimizer_state_dict,
    }
    torch.save(state_dict, target)
    # Report the path that was actually written.
    print('models {} save successfully!'.format(target))
################################################################################################################
import torch.nn as nn
import torch.optim as optim
class net(nn.Module):
    """Minimal binary classifier: a single 2 -> 1 ``nn.Linear`` followed by ``nn.ReLU``."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        layers = [nn.Linear(2, 1), nn.ReLU()]
        # The attribute name `net` is part of the state_dict keys (net.0.weight, net.0.bias).
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Pass the input batch through the Linear + ReLU stack."""
        output = self.net(x)
        return output
def loss(y_hat, y):
    """Element-wise half squared error ``(y_hat - y) ** 2 / 2``.

    ``y`` is first viewed to ``y_hat``'s shape, so e.g. ``(N,)`` labels line up with
    ``(N, 1)`` predictions; the result has ``y_hat``'s shape (no reduction).
    """
    target = y.view(y_hat.size())
    residual = y_hat - target
    return residual ** 2 / 2
def accuracy(y_hat, y):  #@save
    """Return the fraction of samples whose thresholded prediction matches the label.

    A prediction counts as class 1 when ``y_hat > 0.5`` and as class 0 otherwise.

    Args:
        y_hat: model outputs with the same number of elements as ``y``; a column
            of shape ``(N, 1)`` (what ``net`` returns) is reshaped to ``y``'s shape.
        y: 0/1 labels, e.g. shape ``(N,)``.

    Returns:
        Accuracy in ``[0, 1]`` as a Python float (``nan`` when ``y`` is empty).
    """
    # Align shapes first: (N, 1) - (N,) would silently broadcast to an (N, N) matrix.
    preds = (y_hat.reshape(y.shape).type(y.dtype) > 0.5).type(y.dtype)
    # Count exact matches; summing (preds - y) lets +1 and -1 errors cancel each other.
    return float((preds == y).float().mean())
lr = 0.03  # NOTE(review): not used below; the optimizer is created with lr=1e-4
num_epochs = 3  # number of passes over the data
batch_size = 10  # mini-batch size
model = net()
params = list(model.parameters())
optimizer = torch.optim.Adam(params, 1e-4)
for epoch in range(num_epochs):
    for X, y_train in data_iter(batch_size, x, y):
        optimizer.zero_grad()
        l = loss(model(X), y_train).sum()  # summed loss over the mini-batch X / y_train
        # A fresh graph is built every iteration, so nothing needs to be retained
        # after backward(); retain_graph=True only kept every graph alive in memory.
        l.backward()
        optimizer.step()
    print(l)  # loss of the last mini-batch of this epoch
    # Checkpoint once per epoch via saver(); the file name ends in the 1-based epoch number.
    saver(model.state_dict(), optimizer.state_dict(), "./", epoch + 1, max_to_save=100)
load模型
# 导入包
import glob
import os
import torch
import matplotlib.pyplot as plt
import random #用于数据迭代器生成随机数据
# Toy binary-classification data set, 50 samples per class:
# class 0 (x1 / y1) is drawn around (2, 2), class 1 (x2 / y2) around (-2, -2), std 1.
n_data = torch.ones(50, 2)  # template: the per-sample means are +2 * n_data / -2 * n_data
x1 = torch.normal(2 * n_data, 1)  # class-0 features, shape (50, 2)
y1 = torch.zeros(50)  # class-0 labels, shape (50,)
x2 = torch.normal(-2 * n_data, 1)  # class-1 features, shape (50, 2)
y2 = torch.ones(50)  # class-1 labels, shape (50,)
# Stack both classes along dim 0 and cast to float32:
# x has shape (100, 2), y has shape (100,).
x = torch.cat((x1, x2), dim=0).type(torch.FloatTensor)
y = torch.cat((y1, y2), dim=0).type(torch.FloatTensor)

# Scatter-plot the two clusters, coloured by label (transposing gives the two feature columns).
plt.scatter(*x.data.numpy().T, c=y.data.numpy(), s=100, lw=0, cmap='RdYlGn')
plt.show()


def data_iter(batch_size, x, y):
    """Yield ``(features, labels)`` mini-batches of ``x`` and ``y`` in a random order.

    Every sample is visited exactly once per call; the last batch holds the
    remainder and may be smaller than ``batch_size``.
    """
    order = list(range(len(x)))
    random.shuffle(order)  # random visiting order for this pass
    for start in range(0, len(order), batch_size):
        # Slicing clamps at the end of the list, so the short final batch needs no special case.
        batch_idx = torch.tensor(order[start:start + batch_size], dtype=torch.long)
        yield x.index_select(0, batch_idx), y.index_select(0, batch_idx)
#############################################################################################################
def saver(model_state_dict, optimizer_state_dict, model_path, epoch, max_to_save=30, prefix='h'):
    """Save a checkpoint as ``model_path + prefix + str(epoch)`` and prune old ones.

    The checkpoint is a dict with the keys ``"model_state_dict"`` and
    ``"optimizer_state_dict"``. Before writing, earlier checkpoints created by this
    function (files named ``model_path + prefix + <digits>``) are deleted, lowest
    number first, until fewer than ``max_to_save`` of them remain. Other files under
    ``model_path`` are never touched.

    Args:
        model_state_dict: typically ``model.state_dict()``.
        optimizer_state_dict: typically ``optimizer.state_dict()``.
        model_path: string prefix for the file, e.g. ``"./"`` (concatenated as-is).
        epoch: checkpoint number appended to the file name.
        max_to_save: maximum number of checkpoints kept on disk.
        prefix: file-name prefix of the checkpoints; the default ``'h'`` keeps the
            historical names ``h1``, ``h2``, ... .
    """
    base = model_path + prefix
    target = base + str(epoch)
    stem = os.path.basename(base)
    target_name = os.path.basename(target)
    # Only collect files that look like our checkpoints. Globbing model_path + '*'
    # would match (and then delete) every file in the directory, e.g. the script itself.
    existing = []
    for path in glob.glob(glob.escape(base) + '*'):
        name = os.path.basename(path)
        suffix = name[len(stem):]
        # Skip the file we are about to overwrite: it does not add a new checkpoint.
        if suffix.isdecimal() and name != target_name:
            existing.append((int(suffix), path))
    # Numeric order: a plain string sort would rank 'h10' before 'h2'.
    existing.sort()
    while existing and len(existing) >= max_to_save:
        os.remove(existing.pop(0)[1])
    state_dict = {
        "model_state_dict": model_state_dict,
        "optimizer_state_dict": optimizer_state_dict,
    }
    torch.save(state_dict, target)
    # Report the path that was actually written.
    print('models {} save successfully!'.format(target))
################################################################################################################
import torch.nn as nn
import torch.optim as optim
class net(nn.Module):
    """Minimal binary classifier: a single 2 -> 1 ``nn.Linear`` followed by ``nn.ReLU``."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        layers = [nn.Linear(2, 1), nn.ReLU()]
        # The attribute name `net` is part of the state_dict keys (net.0.weight, net.0.bias),
        # so it must match the model that wrote the checkpoint being loaded.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Pass the input batch through the Linear + ReLU stack."""
        output = self.net(x)
        return output
def loss(y_hat, y):
    """Element-wise half squared error ``(y_hat - y) ** 2 / 2``.

    ``y`` is first viewed to ``y_hat``'s shape, so e.g. ``(N,)`` labels line up with
    ``(N, 1)`` predictions; the result has ``y_hat``'s shape (no reduction).
    """
    target = y.view(y_hat.size())
    residual = y_hat - target
    return residual ** 2 / 2
def accuracy(y_hat, y):  #@save
    """Return the fraction of samples whose thresholded prediction matches the label.

    A prediction counts as class 1 when ``y_hat > 0.5`` and as class 0 otherwise.

    Args:
        y_hat: model outputs with the same number of elements as ``y``; a column
            of shape ``(N, 1)`` (what ``net`` returns) is reshaped to ``y``'s shape.
        y: 0/1 labels, e.g. shape ``(N,)``.

    Returns:
        Accuracy in ``[0, 1]`` as a Python float (``nan`` when ``y`` is empty).
    """
    # Align shapes first: (N, 1) - (N,) would silently broadcast to an (N, N) matrix.
    preds = (y_hat.reshape(y.shape).type(y.dtype) > 0.5).type(y.dtype)
    # Count exact matches; summing (preds - y) lets +1 and -1 errors cancel each other.
    return float((preds == y).float().mean())
# Hyper-parameters, identical to the training script.
batch_size = 10  # mini-batch size
num_epochs = 3  # number of epochs (only used by the disabled training loop below)
lr = 0.03  # NOTE(review): unused; the optimizer below is created with lr=1e-4
# Rebuild the model and the optimizer the same way as in the training script,
# so the saved state dicts can be loaded into them.
model = net()
params = list(model.parameters())
optimizer = torch.optim.Adam(params, lr=1e-4)
# Training is switched off in this script; the loop that produced the checkpoints was:
# for epoch in range(num_epochs):
#     for X, y_train in data_iter(batch_size, x, y):
#         optimizer.zero_grad()
#         l = loss(model(X), y_train).sum()  # summed loss over the mini-batch
#         l.backward()
#         optimizer.step()
#     print(l)
#     saver(model.state_dict(), optimizer.state_dict(), "./", epoch + 1, max_to_save=100)
def loader(model_path, map_location=None):
    """Load a checkpoint written by ``saver`` and return its two state dicts.

    Args:
        model_path: path of the checkpoint file, e.g. ``"h1"``.
        map_location: forwarded to ``torch.load``. Pass e.g. ``"cpu"`` to load a
            checkpoint saved from GPU tensors on a CPU-only machine; the default
            ``None`` keeps ``torch.load``'s default behaviour.

    Returns:
        ``(model_state_dict, optimizer_state_dict)``.

    Raises:
        KeyError: if the file lacks ``"model_state_dict"`` or ``"optimizer_state_dict"``.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location=map_location)
    return checkpoint["model_state_dict"], checkpoint["optimizer_state_dict"]
# Restore the network weights first, then the Adam state, from checkpoint "h1".
model_state_dict, optimizer_state_dict = loader("h1")
for restorable, saved_state in ((model, model_state_dict),
                                (optimizer, optimizer_state_dict)):
    restorable.load_state_dict(saved_state)
print('pretrained models loaded!')
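pytorch 模型保存只保存可训练参数吗?对下面的例子来说是的:model.state_dict() 只包含注册到模块上的参数(nn.Parameter)和缓冲区(通过 register_buffer 注册的张量);像下面 self.notrain 这样直接赋值的普通张量不会被保存。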
class net(nn.Module):
    """The 2 -> 1 Linear + ReLU model plus an extra plain tensor attribute ``notrain``.

    ``self.notrain`` is an ordinary ``torch.Tensor`` attribute, not an
    ``nn.Parameter`` or a registered buffer, so it does not show up in
    ``state_dict()`` and is not written when the state dict is saved.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        layers = [nn.Linear(2, 1), nn.ReLU()]
        self.net = nn.Sequential(*layers)
        # Random 64x64 float tensor stored as a plain attribute (unregistered).
        self.notrain = torch.rand(64, 64, dtype=torch.float)

    def forward(self, x):
        """Pass the input batch through the Linear + ReLU stack; ``notrain`` is unused."""
        output = self.net(x)
        return output
解决方案:把这个张量注册到模块上(例如包装成 nn.Parameter,或用 register_buffer 注册为缓冲区),它就会出现在 state_dict 中并随模型一起保存。
class net(nn.Module):
    """2 -> 1 Linear + ReLU model whose extra ``notrain`` tensor is saved in ``state_dict()``.

    Wrapping the tensor in ``nn.Parameter`` registers it on the module, so it is
    saved and restored together with the layer weights under the key ``"notrain"``.
    """

    def __init__(self, **kwargs):
        super(net, self).__init__(**kwargs)
        self.net = nn.Sequential(nn.Linear(2, 1), nn.ReLU())
        # A plain tensor attribute (e.g. torch.rand((64, 64))) would NOT be saved;
        # registering it as a Parameter makes it part of state_dict().
        # requires_grad=False: the tensor is state, not something to learn. It stays in
        # parameters(), so the optimizer layout (and its saved state) is unchanged.
        # register_buffer('notrain', ...) is the other standard option, but it would
        # remove the tensor from parameters() and change the optimizer's param groups.
        self.notrain = torch.nn.Parameter(torch.ones(64, 64), requires_grad=False)

    def forward(self, x):
        """Pass the input batch through the Linear + ReLU stack; ``notrain`` is unused."""
        return self.net(x)
for epoch in range(num_epochs):
    for X, y_train in data_iter(batch_size, x, y):
        optimizer.zero_grad()
        l = loss(model(X), y_train).sum()  # summed loss over the mini-batch X / y_train
        # A fresh graph is built every iteration, so nothing needs to be retained
        # after backward(); retain_graph=True only kept every graph alive in memory.
        l.backward()
        optimizer.step()
    print(l)  # loss of the last mini-batch of this epoch
    # Bump the extra tensor by 2 each epoch so the checkpoints show whether it is saved.
    # Done in place under no_grad rather than by rebinding `.data`.
    with torch.no_grad():
        model.notrain.add_(2)
    saver(model.state_dict(), optimizer.state_dict(), "./", epoch + 1, max_to_save=100)
参考与更多
边栏推荐
- 【JS】获取hash模式下URL的搜索参数
- [daily question] 241 Design priorities for operational expressions
- AcWing 181. Turnaround game solution (search ida* search)
- At compilation environment setup -win
- Chapter 7 - class foundation
- 【NLP】一文详解生成式文本摘要经典论文Pointer-Generator
- Cuckoo filter
- [译]深入了解现代web浏览器(一)
- What are the benefits of multi terminal applet development? Covering Baidu applet, Tiktok applet, wechat applet development, and seizing the multi platform traffic dividend
- 自动生成VGG图像注释文件
猜你喜欢
B端电商-订单逆向流程
SQLite 3.39.0 release supports right external connection and all external connection
数据湖(十二):Spark3.1.2与Iceberg0.12.1整合
Set up sentinel mode. Reids and redis leave the sentinel cluster from the node
八年测开经验,面试28K公司后,吐血整理出高频面试题和答案
KS004 基于SSH通讯录系统设计与实现
Self-Improvement! Daliangshan boys all award Zhibo! Thank you for your paper
浏览器缓存机制概述
Design and implementation of ks004 based on SSH address book system
编写完10万行代码,我发了篇长文吐槽Rust
随机推荐
JS how to get integer
Educational codeforces round 129 (rated for Div. 2) supplementary problem solution
Yes, that's it!
rxjs Observable 自定义 Operator 的开发技巧
【Hot100】22. bracket-generating
Kt148a voice chip IC user end self replacement voice method, upper computer
为什么我对流程情有独钟?
Cuckoo filter
Common problems and description of kt148a voice chip IC development
Refactoring: improving the design of existing code (Part 2)
450 Shenxin Mianjing 1
Infix expression is converted to suffix expression (C language code + detailed explanation)
upload-labs
Postman download and installation
AcWing 1126. 最小花费 题解(最短路—dijkstra)
[NLP] a detailed generative text Abstract classic paper pointer generator
AcWing 341. Optimal trade solution (shortest path, DP)
Shardingsphere jdbc5.1.2 about select last_ INSERT_ ID () I found that there was still a routing problem
Data Lake (XII): integration of spark3.1.2 and iceberg0.12.1
AcWing 340. Solution to communication line problem (binary + double ended queue BFS for the shortest circuit)