当前位置:网站首页>pytorch 模型保存的完整例子+pytorch 模型保存只保存可训练参数吗?是(+解决方案)
pytorch 模型保存的完整例子+pytorch 模型保存只保存可训练参数吗?是(+解决方案)
2022-07-02 19:02:00 【FakeOccupational】
测试使用的是一个 linear model(线性模型),还有更多的问题。pytorch 模型保存只保存可训练参数吗?
# 导入包
import glob
import os
import torch
import matplotlib.pyplot as plt
import random #用于数据迭代器生成随机数据
# Build a toy two-class dataset: x1 belongs to class 0, x2 to class 1.
n_data = torch.ones(50, 2)  # template tensor: 50 samples with 2 features each
x1 = torch.normal(2 * n_data, 1)  # class-0 features, mean 2 / std 1, shape (50, 2)
y1 = torch.zeros(50)  # class-0 labels, shape (50,)
x2 = torch.normal(-2 * n_data, 1)  # class-1 features, mean -2 / std 1, shape (50, 2)
y2 = torch.ones(50)  # class-1 labels, shape (50,)

# Concatenate both classes along dim 0 and make sure everything is float32.
x = torch.cat((x1, x2), 0).type(torch.FloatTensor)
y = torch.cat((y1, y2), 0).type(torch.FloatTensor)

# Visualise the dataset, coloured by label.
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=y.data.numpy(), s=100, lw=0, cmap='RdYlGn')
plt.show()


def data_iter(batch_size, x, y):
    """Yield shuffled mini-batches of ``(features, labels)``.

    Args:
        batch_size: Number of samples per batch; the last batch may be smaller.
        x: Feature tensor indexed along dim 0.
        y: Label tensor indexed along dim 0, aligned with ``x``.

    Yields:
        Tuples ``(x_batch, y_batch)`` selected with the same random indices.
    """
    num_examples = len(x)
    indices = list(range(num_examples))
    random.shuffle(indices)  # samples are visited in a fresh random order
    for i in range(0, num_examples, batch_size):
        # Slicing past the end is safe, so the final short batch needs no clamp.
        j = torch.tensor(indices[i:i + batch_size], dtype=torch.long)
        yield x.index_select(0, j), y.index_select(0, j)
def saver(model_state_dict, optimizer_state_dict, model_path, epoch, max_to_save=30):
    """Save a training checkpoint and prune the oldest checkpoints.

    The checkpoint is written to ``model_path + 'h' + str(epoch)`` as a dict
    with the keys ``"model_state_dict"`` and ``"optimizer_state_dict"``.

    Args:
        model_state_dict: Result of ``model.state_dict()``.
        optimizer_state_dict: Result of ``optimizer.state_dict()``.
        model_path: Path prefix (e.g. ``"./"``) the checkpoint name is appended to.
        epoch: Epoch number used as the file-name suffix.
        max_to_save: Maximum number of checkpoints kept under this prefix.
    """
    prefix = model_path + 'h'
    # Only files named "<prefix><digits>" are treated as checkpoints. Globbing
    # "model_path*" would also match unrelated files (with "./" it matches the
    # whole directory), and pruning must never delete those.
    checkpoints = [
        path for path in glob.glob(glob.escape(prefix) + '*')
        if path[len(prefix):].isdecimal()
    ]
    checkpoints.sort(key=lambda path: int(path[len(prefix):]))

    # Make room for the checkpoint about to be written (lowest epochs go first).
    excess = len(checkpoints) - max_to_save + 1
    for stale in checkpoints[:max(excess, 0)]:
        os.remove(stale)

    state_dict = {
        "model_state_dict": model_state_dict,
        "optimizer_state_dict": optimizer_state_dict,
    }
    save_path = prefix + str(epoch)
    torch.save(state_dict, save_path)
    # Report the path that was actually written.
    print('models {} save successfully!'.format(save_path))
import torch.nn as nn
import torch.optim as optim
class net(nn.Module):
    """Tiny model: one ``Linear(2, 1)`` layer followed by ``ReLU``."""

    def __init__(self, **kwargs):
        """Build the layer stack; ``kwargs`` are forwarded to ``nn.Module``."""
        super().__init__(**kwargs)
        self.net = nn.Sequential(nn.Linear(2, 1), nn.ReLU())

    def forward(self, x):
        """Apply the layer stack to ``x`` (last dimension must be 2)."""
        return self.net(x)
def loss(y_hat, y):
    """Return the element-wise halved squared error ``(y_hat - y)**2 / 2``.

    ``y`` is reshaped to ``y_hat``'s shape first, so a flat label vector can be
    compared against column-shaped predictions without broadcasting.
    """
    # reshape (unlike view) also accepts a non-contiguous ``y``.
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2
def accuracy(y_hat, y):  #@save
    """Return the fraction of samples whose thresholded prediction equals ``y``.

    Args:
        y_hat: Model outputs; values above 0.5 are counted as class 1.
        y: Ground-truth labels (0 or 1) with as many elements as ``y_hat``.

    Returns:
        Accuracy as a Python float in ``[0, 1]``; ``0.0`` for empty input.
    """
    # Flatten both sides so (N, 1) outputs and (N,) labels compare one-to-one
    # instead of broadcasting to an (N, N) matrix.
    predicted = (y_hat.reshape(-1) > 0.5).type(y.dtype)
    target = y.reshape(-1)
    if target.numel() == 0:
        return 0.0
    # Count exact matches; a signed sum of (predicted - target) would let
    # false positives cancel false negatives.
    return float((predicted == target).sum()) / target.numel()
lr = 0.03  # NOTE(review): not passed to the optimizer below, which uses 1e-4 -- confirm intent
num_epochs = 3  # number of passes over the dataset
batch_size = 10  # mini-batch size
model = net()
params = list(model.parameters())
optimizer = torch.optim.Adam(params, 1e-4)
for epoch in range(num_epochs):
    for X, y_train in data_iter(batch_size, x, y):
        optimizer.zero_grad()  # clear gradients left over from the previous batch
        l = loss(model(X), y_train).sum()  # summed loss over the mini-batch
        l.backward()
        optimizer.step()
    # Write one checkpoint per epoch (epochs are numbered from 1).
    saver(model.state_dict(), optimizer.state_dict(), "./", epoch + 1, max_to_save=100)
# 导入包
import glob
import os
import torch
import matplotlib.pyplot as plt
import random #用于数据迭代器生成随机数据
# Build a toy two-class dataset: x1 belongs to class 0, x2 to class 1.
n_data = torch.ones(50, 2)  # template tensor: 50 samples with 2 features each
x1 = torch.normal(2 * n_data, 1)  # class-0 features, mean 2 / std 1, shape (50, 2)
y1 = torch.zeros(50)  # class-0 labels, shape (50,)
x2 = torch.normal(-2 * n_data, 1)  # class-1 features, mean -2 / std 1, shape (50, 2)
y2 = torch.ones(50)  # class-1 labels, shape (50,)

# Concatenate both classes along dim 0 and make sure everything is float32.
x = torch.cat((x1, x2), 0).type(torch.FloatTensor)
y = torch.cat((y1, y2), 0).type(torch.FloatTensor)

# Visualise the dataset, coloured by label.
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=y.data.numpy(), s=100, lw=0, cmap='RdYlGn')
plt.show()


def data_iter(batch_size, x, y):
    """Yield shuffled mini-batches of ``(features, labels)``.

    Args:
        batch_size: Number of samples per batch; the last batch may be smaller.
        x: Feature tensor indexed along dim 0.
        y: Label tensor indexed along dim 0, aligned with ``x``.

    Yields:
        Tuples ``(x_batch, y_batch)`` selected with the same random indices.
    """
    num_examples = len(x)
    indices = list(range(num_examples))
    random.shuffle(indices)  # samples are visited in a fresh random order
    for i in range(0, num_examples, batch_size):
        # Slicing past the end is safe, so the final short batch needs no clamp.
        j = torch.tensor(indices[i:i + batch_size], dtype=torch.long)
        yield x.index_select(0, j), y.index_select(0, j)
def saver(model_state_dict, optimizer_state_dict, model_path, epoch, max_to_save=30):
    """Save a training checkpoint and prune the oldest checkpoints.

    The checkpoint is written to ``model_path + 'h' + str(epoch)`` as a dict
    with the keys ``"model_state_dict"`` and ``"optimizer_state_dict"``.

    Args:
        model_state_dict: Result of ``model.state_dict()``.
        optimizer_state_dict: Result of ``optimizer.state_dict()``.
        model_path: Path prefix (e.g. ``"./"``) the checkpoint name is appended to.
        epoch: Epoch number used as the file-name suffix.
        max_to_save: Maximum number of checkpoints kept under this prefix.
    """
    prefix = model_path + 'h'
    # Only files named "<prefix><digits>" are treated as checkpoints. Globbing
    # "model_path*" would also match unrelated files (with "./" it matches the
    # whole directory), and pruning must never delete those.
    checkpoints = [
        path for path in glob.glob(glob.escape(prefix) + '*')
        if path[len(prefix):].isdecimal()
    ]
    checkpoints.sort(key=lambda path: int(path[len(prefix):]))

    # Make room for the checkpoint about to be written (lowest epochs go first).
    excess = len(checkpoints) - max_to_save + 1
    for stale in checkpoints[:max(excess, 0)]:
        os.remove(stale)

    state_dict = {
        "model_state_dict": model_state_dict,
        "optimizer_state_dict": optimizer_state_dict,
    }
    save_path = prefix + str(epoch)
    torch.save(state_dict, save_path)
    # Report the path that was actually written.
    print('models {} save successfully!'.format(save_path))
import torch.nn as nn
import torch.optim as optim
class net(nn.Module):
    """Tiny model: one ``Linear(2, 1)`` layer followed by ``ReLU``."""

    def __init__(self, **kwargs):
        """Build the layer stack; ``kwargs`` are forwarded to ``nn.Module``."""
        super().__init__(**kwargs)
        self.net = nn.Sequential(nn.Linear(2, 1), nn.ReLU())

    def forward(self, x):
        """Apply the layer stack to ``x`` (last dimension must be 2)."""
        return self.net(x)
def loss(y_hat, y):
    """Return the element-wise halved squared error ``(y_hat - y)**2 / 2``.

    ``y`` is reshaped to ``y_hat``'s shape first, so a flat label vector can be
    compared against column-shaped predictions without broadcasting.
    """
    # reshape (unlike view) also accepts a non-contiguous ``y``.
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2
def accuracy(y_hat, y):  #@save
    """Return the fraction of samples whose thresholded prediction equals ``y``.

    Args:
        y_hat: Model outputs; values above 0.5 are counted as class 1.
        y: Ground-truth labels (0 or 1) with as many elements as ``y_hat``.

    Returns:
        Accuracy as a Python float in ``[0, 1]``; ``0.0`` for empty input.
    """
    # Flatten both sides so (N, 1) outputs and (N,) labels compare one-to-one
    # instead of broadcasting to an (N, N) matrix.
    predicted = (y_hat.reshape(-1) > 0.5).type(y.dtype)
    target = y.reshape(-1)
    if target.numel() == 0:
        return 0.0
    # Count exact matches; a signed sum of (predicted - target) would let
    # false positives cancel false negatives.
    return float((predicted == target).sum()) / target.numel()
# Hyper-parameters and model/optimizer setup for the checkpoint-loading variant
# of the script.
# NOTE(review): lr is not passed to the optimizer below, which uses 1e-4.
lr = 0.03
num_epochs = 3 # number of epochs
batch_size = 10 # mini-batch size
model = net()
params = list(model.parameters())
optimizer = torch.optim.Adam(params, 1e-4)
# Training is disabled in this variant; the loop is kept for reference:
# for epoch in range(num_epochs):
#     for X, y_train in data_iter(batch_size, x, y):
#         optimizer.zero_grad()
#         l = loss(model(X), y_train).sum()  # loss over the mini-batch X, y
#         l.backward(retain_graph=True)
#         optimizer.step()
#         print(l)
#     saver(model.state_dict(), optimizer.state_dict(), "./", epoch + 1, max_to_save=100)
def loader(model_path, map_location=None):
    """Load a checkpoint and return its model and optimizer state dicts.

    Args:
        model_path: Path of a file holding a dict with the keys
            ``"model_state_dict"`` and ``"optimizer_state_dict"``.
        map_location: Optional device remapping forwarded to ``torch.load``
            (e.g. ``"cpu"`` to open a checkpoint saved on a GPU).

    Returns:
        Tuple ``(model_state_dict, optimizer_state_dict)``.

    Raises:
        KeyError: If either expected key is missing from the checkpoint.
    """
    # NOTE: torch.load unpickles data; only open checkpoints from trusted sources.
    state_dict = torch.load(model_path, map_location=map_location)
    model_state_dict = state_dict["model_state_dict"]
    optimizer_state_dict = state_dict["optimizer_state_dict"]
    return model_state_dict, optimizer_state_dict
# Read the epoch-1 checkpoint and actually restore it: fetching the state dicts
# alone leaves the model and optimizer at their freshly initialised values.
model_state_dict, optimizer_state_dict = loader("h1")
model.load_state_dict(model_state_dict)
optimizer.load_state_dict(optimizer_state_dict)
print('pretrained models loaded!')
pytorch 模型保存只保存可训练参数吗?基本是:state_dict 只保存已注册的参数(nn.Parameter)和 buffer(register_buffer),直接赋给 self 的普通 tensor 不会被保存
class net(nn.Module):
    """Linear model carrying an extra plain-tensor attribute.

    ``notrain`` is an ordinary tensor attribute, so it is neither a parameter
    nor a buffer and does not appear in ``state_dict()``; a checkpoint built
    from the state dict will not contain it.
    """

    def __init__(self, **kwargs):
        """Build the layer stack; ``kwargs`` are forwarded to ``nn.Module``."""
        super().__init__(**kwargs)
        self.net = nn.Sequential(nn.Linear(2, 1), nn.ReLU())
        # Plain tensor: not registered with the module, hence not saved.
        self.notrain = torch.rand((64, 64), dtype=torch.float)

    def forward(self, x):
        """Apply the layer stack to ``x``; ``notrain`` is not used here."""
        return self.net(x)
class net(nn.Module):
    """Linear model whose extra tensor is registered so that it gets saved.

    Wrapping ``notrain`` in ``nn.Parameter`` registers it with the module, so
    it appears in ``state_dict()`` and is written to checkpoints.
    """

    def __init__(self, **kwargs):
        """Build the layer stack; ``kwargs`` are forwarded to ``nn.Module``."""
        super().__init__(**kwargs)
        self.net = nn.Sequential(nn.Linear(2, 1), nn.ReLU())
        # Registered as a parameter (instead of a plain torch.rand tensor) so
        # it is saved; requires_grad=False keeps it out of gradient updates,
        # matching its name.
        self.notrain = torch.nn.Parameter(torch.ones(64, 64), requires_grad=False)

    def forward(self, x):
        """Apply the layer stack to ``x``; ``notrain`` is not used here."""
        return self.net(x)
# NOTE(review): assumes `model` was built from the `net` variant that defines
# `notrain` -- confirm against the setup code.
for epoch in range(num_epochs):
    for X, y_train in data_iter(batch_size, x, y):
        optimizer.zero_grad()  # clear gradients left over from the previous batch
        l = loss(model(X), y_train).sum()  # summed loss over the mini-batch
        l.backward()
        optimizer.step()
    # Change the extra tensor by hand so each checkpoint holds a different
    # value, which makes it easy to check that it is really being saved.
    # NOTE(review): placed once per epoch; the original nesting was lost.
    model.notrain.data = model.notrain.data + 2
    saver(model.state_dict(), optimizer.state_dict(), "./", epoch + 1, max_to_save=100)
- What are the benefits of multi terminal applet development? Covering Baidu applet, Tiktok applet, wechat applet development, and seizing the multi platform traffic dividend
- 450 Shenxin Mianjing 1
- Notes on hardware design of kt148a voice chip IC
- 多端小程序开发有什么好处?覆盖百度小程序抖音小程序微信小程序开发,抢占多平台流量红利
- Design and implementation of ks004 based on SSH address book system
- At compilation environment setup -win
- 蓝牙芯片ble是什么,以及该如何选型,后续技术发展的路径是什么
- Build a master-slave mode cluster redis
- How to avoid duplicate data in gaobingfa?
- 勵志!大凉山小夥全獎直博!論文致謝看哭網友
KS004 基于SSH通讯录系统设计与实现
嵌入式(PLD) 系列,EPF10K50RC240-3N 可编程逻辑器件
Introduction to program ape (XII) -- data storage
Windows2008r2 installing php7.4.30 requires localsystem to start the application pool, otherwise 500 error fastcgi process exits unexpectedly
Automatically generate VGg image annotation file
[daily question] 241 Design priorities for operational expressions
Zabbix5 client installation and configuration
Shardingsphere jdbc5.1.2 about select last_ INSERT_ ID () I found that there was still a routing problem
AcWing 340. Solution to communication line problem (binary + double ended queue BFS for the shortest circuit)
JASMINER X4 1U deep disassembly reveals the secret behind high efficiency and power saving
AcWing 1127. Sweet butter solution (shortest path SPFA)
Chapter 7 - class foundation
台湾SSS鑫创SSS1700替代Cmedia CM6533 24bit 96KHZ USB音频编解码芯片
AcWing 1129. 热浪 题解(最短路—spfa)
[ERP software] what are the dangers of the secondary development of ERP system?
AcWing 1125. 牛的旅行 题解(最短路、直径)
AcWing 340. 通信线路 题解(二分+双端队列BFS求最短路)
解决方案:VS2017 无法打开源文件 stdio.h main.h 等头文件[通俗易懂]
Development skills of rxjs observable custom operator