How to Save a Trained Neural Network Model (PyTorch Version)
2022-07-05 17:01:00 【追光少年羽】
1. Saving and loading models
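Training a model on data gives us a reasonably good model. In real use, though, we cannot retrain it every time before using it. So we save the trained model once and simply load it whenever we need it.

A model is essentially a set of parameters stored in a particular structure, so there are two ways to save it:
- Save the entire model object and later load the whole object back. This is convenient, but the saved file is tied to the exact class definition and code layout used at save time, which makes it less flexible.
- Save only the model's parameters. When they are needed, create a new model with the same structure and load the saved parameters into it.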
2. How to implement each approach
(1) Save only the model's parameter dictionary (recommended)
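```python
# Save
torch.save(the_model.state_dict(), PATH)
# Load: first rebuild a model with the same structure, then load the saved parameters into it
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
the_model.eval()  # switch to inference mode (affects layers such as dropout and batch norm)
```
(2) Save the entire model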
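```python
# Save
torch.save(the_model, PATH)
# Load: TheModelClass must still be importable here, because the whole object is unpickled.
# On PyTorch 2.6+, torch.load defaults to weights_only=True, so a full model needs weights_only=False
# (only do this for files from a trusted source).
the_model = torch.load(PATH, weights_only=False)
```
3. Example: saving only the model parameters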
PyTorch keeps a model's parameters in a dictionary, returned by `state_dict()`. All we need to do is save this dictionary and load it back later.
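To see what this dictionary looks like, here is a minimal sketch that prints the `state_dict` of a tiny model. The class `TinyLSTM` and its sizes are hypothetical. They are chosen only for illustration and are not the article's model.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in model: one LSTM layer followed by one Linear layer
class TinyLSTM(nn.Module):
    def __init__(self, input_size=1, hidden_size=4):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=1, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        out, _ = self.lstm(x)
        return self.fc(out)

model = TinyLSTM()
sd = model.state_dict()  # maps each parameter/buffer name to its tensor

for name, tensor in sd.items():
    print(name, tuple(tensor.shape))
```

Each key is the attribute path of a parameter (for example `lstm.weight_ih_l0`), and each value is a tensor. That is why this dictionary alone is enough to restore the weights into a model with the same structure.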
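For example, design an LSTM network (instantiated below with two stacked layers) and train it. After training, save the model's parameter dictionary to a file named `rnn.pt` in the same folder: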
```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        # Set initial hidden and cell states (zeros) on the same device as x
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device)
        # Forward propagate LSTM; out: tensor of shape (batch_size, seq_length, hidden_size)
        out, _ = self.lstm(x, (h0, c0))
        # Map every time step to one output value: (batch_size, seq_length, 1)
        out = self.fc(out)
        return out

rnn = LSTM(input_size=1, hidden_size=10, num_layers=2).to(device)
# Optimize all LSTM parameters
optimizer = torch.optim.Adam(rnn.parameters(), lr=0.001)
# Regression target, so mean squared error (labels are not one-hot encoded)
loss_func = nn.MSELoss()

# train_tensor and train_labels_tensor are assumed to be prepared earlier (not shown here),
# already moved to `device`, with shape (batch_size, seq_length, 1)
for epoch in range(1000):
    output = rnn(train_tensor)                     # LSTM output
    loss = loss_func(output, train_labels_tensor)  # MSE loss
    optimizer.zero_grad()                          # clear gradients for this training step
    loss.backward()                                # backpropagation, compute gradients
    optimizer.step()                               # apply gradients

# Save the model parameters
torch.save(rnn.state_dict(), 'rnn.pt')
```
After saving, use the trained model to process data:
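```python
# Test the saved model
m_state_dict = torch.load('rnn.pt', map_location=device)
new_m = LSTM(input_size=1, hidden_size=10, num_layers=2).to(device)
new_m.load_state_dict(m_state_dict)
new_m.eval()           # inference mode
with torch.no_grad():  # no gradients are needed for prediction
    predict = new_m(test_tensor)  # test_tensor is assumed to be prepared like train_tensor
```
A few notes on these steps:
1. When saving, `rnn.state_dict()` is the parameter dictionary of the `rnn` model. To test the saved model, first load this dictionary with `m_state_dict = torch.load('rnn.pt', map_location=device)`. The `map_location` argument lets a file saved on a GPU be loaded on a CPU, and vice versa.
2. Next, instantiate a new LSTM object with `new_m = LSTM(input_size=1, hidden_size=10, num_layers=2).to(device)`. The arguments must be exactly the same as those used to instantiate `rnn`, so that both models have the same structure.
3. Then load the previously saved parameters into the new model with `new_m.load_state_dict(m_state_dict)`.
4. Finally, switch the model to inference mode with `new_m.eval()` and use it to process data with `predict = new_m(test_tensor)`.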
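`load_state_dict` enforces the requirement that the new model have the same structure. It raises a `RuntimeError` when tensor shapes differ, and by default (`strict=True`) also when keys are missing or unexpected. The minimal sketch below checks this. It uses a hypothetical `TinyLSTM` class and an in-memory `io.BytesIO` buffer in place of `rnn.pt`.

```python
import io
import torch
import torch.nn as nn

# Hypothetical stand-in for the article's LSTM class
class TinyLSTM(nn.Module):
    def __init__(self, input_size=1, hidden_size=10, num_layers=2):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        out, _ = self.lstm(x)
        return self.fc(out)

trained = TinyLSTM(hidden_size=10, num_layers=2)
buffer = io.BytesIO()  # in-memory file standing in for 'rnn.pt'
torch.save(trained.state_dict(), buffer)

def try_load(model):
    buffer.seek(0)  # rewind the buffer before each read
    state = torch.load(buffer)
    try:
        model.load_state_dict(state)
        return "ok"
    except RuntimeError as err:
        return f"failed: {type(err).__name__}"

same = try_load(TinyLSTM(hidden_size=10, num_layers=2))       # identical structure
different = try_load(TinyLSTM(hidden_size=20, num_layers=2))  # different hidden_size
print(same)       # ok
print(different)  # failed: RuntimeError
```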
4. Example: saving the entire model
```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        # Set initial hidden and cell states (zeros) on the same device as x
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device)
        # Forward propagate LSTM; out: tensor of shape (batch_size, seq_length, hidden_size)
        out, _ = self.lstm(x, (h0, c0))
        # print("output_in=", out.shape)
        # print("fc_in_shape=", out[:, -1, :].shape)
        # Alternative: decode only the hidden state of the last time step
        # out = self.fc(out[:, -1, :])  # take the last time step as out
        out = self.fc(out)
        return out

rnn = LSTM(input_size=1, hidden_size=10, num_layers=2).to(device)
print(rnn)
optimizer = torch.optim.Adam(rnn.parameters(), lr=0.001)  # optimize all LSTM parameters
loss_func = nn.MSELoss()  # regression target; labels are not one-hot encoded

# train_tensor and train_labels_tensor are assumed to be prepared earlier, as in section 3
for epoch in range(1000):
    output = rnn(train_tensor)                     # LSTM output
    loss = loss_func(output, train_labels_tensor)  # MSE loss
    optimizer.zero_grad()                          # clear gradients for this training step
    loss.backward()                                # backpropagation, compute gradients
    optimizer.step()                               # apply gradients

# Save the entire model object
torch.save(rnn, 'rnn1.pt')
```
After saving, use the trained model to process data:
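```python
# The LSTM class definition must be available here, because torch.load unpickles the whole object.
# weights_only=False is required on PyTorch 2.6+ (where the default is True); use it only for trusted files.
new_m = torch.load('rnn1.pt', map_location=device, weights_only=False)
new_m.eval()
with torch.no_grad():
    predict = new_m(test_tensor)
```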
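The following is a minimal sketch of the whole-model round trip. It assumes a hypothetical `TinyNet` class and uses an `io.BytesIO` buffer in place of `rnn1.pt`. It passes `weights_only=False` explicitly, which PyTorch 2.6+ needs for full-model files. This assumes a PyTorch version that accepts that argument.

```python
import io
import torch
import torch.nn as nn

# Hypothetical small model; its class must be defined wherever the file is loaded
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 1)

    def forward(self, x):
        return self.fc(x)

model = TinyNet()
buffer = io.BytesIO()       # in-memory file standing in for 'rnn1.pt'
torch.save(model, buffer)   # pickles the whole module object, not just the tensors

buffer.seek(0)
loaded = torch.load(buffer, weights_only=False)  # full unpickling; trusted data only
loaded.eval()

x = torch.randn(2, 3)
with torch.no_grad():
    same = torch.equal(model(x), loaded(x))  # identical weights give identical outputs
print(type(loaded).__name__, same)
```

Pickle stores a reference to the class rather than its code. Loading this file in a different script therefore only works if `TinyNet` can be imported there under the same module path. That dependency is the main drawback of saving the whole model.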