How to Save a Trained Neural Network Model (PyTorch)
2022-07-05 17:01:00 【追光少年羽】
I. Saving and Loading Models
After training a model on data and reaching a satisfactory result, you don't want to retrain it every time before use, so you save the trained model and simply load it when needed. A model is essentially a collection of parameters stored in a particular structure, which leads to two ways of saving it: one is to save the entire model object and later load it back directly, which costs more memory; the other is to save only the model's parameters, then later create a new model with the same architecture and load the saved parameters into it.
II. Implementing the Two Approaches
(1) Save only the model's parameter dictionary (recommended)
# Save
torch.save(the_model.state_dict(), PATH)
# Load
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
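As a side note (standard PyTorch practice rather than part of the original snippet), torch.load accepts a map_location argument, so a parameter dictionary saved on a GPU machine can be loaded on a CPU-only one:

# Minimal sketch: load the parameter dictionary onto the CPU explicitly.
the_model = TheModelClass(*args, **kwargs)
state = torch.load(PATH, map_location='cpu')  # map_location is optional
the_model.load_state_dict(state)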
(2) Save the entire model
# Save
torch.save(the_model, PATH)
# Load
the_model = torch.load(PATH)
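One caveat with this approach: torch.save pickles the whole model object, and the pickle stores only a reference to the model's class, so the class definition (TheModelClass here) must still be importable in the script that calls torch.load. A minimal sketch, assuming the class lives in a hypothetical module model_defs:

from model_defs import TheModelClass  # hypothetical module; the class must be importable
the_model = torch.load(PATH)          # unpickling re-creates the object by class reference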
III. Saving Only the Model Parameters (Example)
PyTorch keeps a model's parameters in a dictionary (the state_dict); all we need to do is save this dictionary and load it back later.
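To see what this dictionary contains, you can iterate over it; a quick sketch, where rnn is the LSTM instance built in the example below:

# Print each parameter's name and shape from the state_dict.
for name, tensor in rnn.state_dict().items():
    print(name, tuple(tensor.shape))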
For example, define a small LSTM network (the code below stacks two LSTM layers and a linear head), train it, and after training save the model's parameter dictionary to a file rnn.pt in the same folder:
import torch
import torch.nn as nn

# Assumed device setup (not shown in the original)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        # Initial hidden and cell states: (num_layers, batch_size, hidden_size)
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        # Forward propagate the LSTM
        out, _ = self.lstm(x, (h0, c0))
        # out: tensor of shape (batch_size, seq_length, hidden_size)
        out = self.fc(out)
        return out
rnn = LSTM(input_size=1, hidden_size=10, num_layers=2).to(device)
optimizer = torch.optim.Adam(rnn.parameters(), lr=0.001)  # optimize all model parameters
loss_func = nn.MSELoss()  # mean squared error, suitable for this regression setup

# train_tensor / train_labels_tensor are assumed to be prepared beforehand
for epoch in range(1000):
    output = rnn(train_tensor)                      # model output
    loss = loss_func(output, train_labels_tensor)   # MSE loss
    optimizer.zero_grad()                           # clear gradients for this training step
    loss.backward()                                 # backpropagation: compute gradients
    optimizer.step()                                # apply gradients
    output_sum = output                             # keep the most recent output

# Save the model's parameter dictionary
torch.save(rnn.state_dict(), 'rnn.pt')
Once saved, load the trained model and use it to process data:
# Test the saved model
m_state_dict = torch.load('rnn.pt')
new_m = LSTM(input_size=1, hidden_size=10, num_layers=2).to(device)
new_m.load_state_dict(m_state_dict)
predict = new_m(test_tensor)
A few notes on the steps above. When saving, rnn.state_dict() returns the parameter dictionary of the rnn model. To use the saved model, first load that dictionary: m_state_dict = torch.load('rnn.pt'). Then instantiate a new LSTM object, passing exactly the same arguments that were used to build rnn, so the two architectures match: new_m = LSTM(input_size=1, hidden_size=10, num_layers=2).to(device). Next, load the saved parameters into the new model: new_m.load_state_dict(m_state_dict). Finally, the model is ready to process data: predict = new_m(test_tensor).
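One habit worth adding on top of the original example (standard PyTorch practice, not something the original code does): put the loaded model into evaluation mode and run inference without gradient tracking:

new_m.eval()           # disable dropout/batch-norm training behavior
with torch.no_grad():  # no gradients needed for inference
    predict = new_m(test_tensor)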
IV. Saving the Entire Model (Example)
class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        # Initial hidden and cell states: (num_layers, batch_size, hidden_size)
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        # Forward propagate the LSTM
        out, _ = self.lstm(x, (h0, c0))  # out: (batch_size, seq_length, hidden_size)
        # print("output_in=", out.shape)
        # print("fc_in_shape=", out[:, -1, :].shape)
        # Alternative: decode only the hidden state of the last time step
        # out = torch.cat((out[:, 0, :], out[-1, :, :]), axis=0)
        # out = self.fc(out[:, -1, :])
        out = self.fc(out)
        return out
rnn = LSTM(input_size=1, hidden_size=10, num_layers=2).to(device)
print(rnn)
optimizer = torch.optim.Adam(rnn.parameters(), lr=0.001)  # optimize all model parameters
loss_func = nn.MSELoss()  # mean squared error loss

for epoch in range(1000):
    output = rnn(train_tensor)                      # model output
    loss = loss_func(output, train_labels_tensor)   # MSE loss
    optimizer.zero_grad()                           # clear gradients for this training step
    loss.backward()                                 # backpropagation: compute gradients
    optimizer.step()                                # apply gradients
    output_sum = output                             # keep the most recent output

# Save the entire model object
torch.save(rnn, 'rnn1.pt')
Once saved, the whole model can be loaded and used directly, with no need to rebuild the architecture first:
new_m = torch.load('rnn1.pt')
predict = new_m(test_tensor)
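The same inference habits apply here. In addition, if the whole model was saved on a GPU machine but later loaded where no GPU is available, pass map_location to torch.load; a hedged sketch (the CPU-only deployment is an assumption, not part of the original example):

new_m = torch.load('rnn1.pt', map_location=torch.device('cpu'))  # assumed CPU-only machine
new_m.eval()           # disable dropout/batch-norm training behavior
with torch.no_grad():  # no gradients needed for inference
    predict = new_m(test_tensor)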