Loading only the needed pretrained parameters into your own (modified) PyTorch model, and freezing them
2022-06-26 05:26:00 【被rua弄的小狸花】
Parts of this post are based on https://zhuanlan.zhihu.com/p/34147880
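1. This method is fairly general: it goes by your own model's parameters and copies a pretrained value into every entry that has the same name. Layers you added on top of the original model have no counterpart in the pretrained weights, so they are not loaded and keep their initial values.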
import torch

dict_trained = torch.load(self.args.load_path, map_location=torch.device('cpu'))
dict_new = model.state_dict()
# 1. filter out keys that do not exist in the new model
dict_trained = {k: v for k, v in dict_trained.items() if k in dict_new}
# 2. overwrite the matching entries in the new model's state dict
dict_new.update(dict_trained)
model.load_state_dict(dict_new)
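As a runnable sketch of method 1, the snippet below uses two small hypothetical models (`Pretrained`, `MyModel`) in place of a real checkpoint file. It adds one filter that the code above does not have: it also compares shapes. If a key exists under the same name but with a different shape, `load_state_dict` raises a `RuntimeError`, so this version assumes you want to skip such keys rather than fail.

```python
import torch
import torch.nn as nn


class Pretrained(nn.Module):
    """Hypothetical stand-in for the original (pretrained) model."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 8)
        self.head = nn.Linear(8, 3)


class MyModel(nn.Module):
    """Hypothetical modified model: same encoder, resized head, one extra layer."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 8)  # same name and shape -> loaded
        self.head = nn.Linear(8, 5)     # same name, different shape -> skipped
        self.extra = nn.Linear(5, 2)    # new layer -> keeps its own initialization


def load_matching(model, dict_trained):
    """Copy every pretrained tensor whose key AND shape match the model."""
    dict_new = model.state_dict()
    matched = {k: v for k, v in dict_trained.items()
               if k in dict_new and v.shape == dict_new[k].shape}
    dict_new.update(matched)
    model.load_state_dict(dict_new)
    return sorted(matched)


pretrained = Pretrained()
model = MyModel()
loaded = load_matching(model, pretrained.state_dict())
print(loaded)  # ['encoder.bias', 'encoder.weight']
```

Returning the list of loaded keys is convenient later, for example to freeze exactly those parameters.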
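2. This one is considerably more involved and has to be adapted to your own model. In my case, the model adds four layers, 'dense', 'unary_affine', 'binary_affine' and 'classifier', and j += 8 skips their weights and biases (4 layers x 2 = 8 keys); the name-based filtering is similar to the way some parameters are usually excluded from weight decay. In addition, the 'crf' part of the original model's parameters is not loaded. Unlike method 1, keys are paired by position rather than by name, so the two key lists must stay aligned.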
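import torch

dict_trained = torch.load(self.args.load_path, map_location=torch.device('cpu'))
dict_new = self.model.state_dict().copy()
trained_list = list(dict_trained.keys())
new_list = list(dict_new.keys())
# layers that exist only in the new model and receive no pretrained values
no_load = {'dense', 'unary_affine', 'binary_affine', 'classifier'}
j = 0
for i in range(len(trained_list)):
    # do not load the 'crf' part of the pretrained model
    if 'crf' in trained_list[i]:
        continue
    # keys inside 'bert' are excluded, presumably because BERT's own sub-layers are also named 'dense'
    if any(nd in new_list[j] and 'bert' not in new_list[j] for nd in no_load):
        # skip the weight and bias of the 4 added layers (4 x 2 = 8 keys);
        # assumes these 8 keys are contiguous in new_list
        j += 8
    # the current pretrained tensor still goes to the next key after the skipped ones
    dict_new[new_list[j]] = dict_trained[trained_list[i]]
    if new_list[j] != trained_list[i]:
        print("i:{}, new state_dict: {}, trained state_dict: {} do not match".format(
            i, new_list[j], trained_list[i]))
    j += 1  # advance to the next key of the new model together with i
self.model.load_state_dict(dict_new)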
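Method 2 only works if, after skipping the excluded keys, the two key lists line up one to one. A minimal sketch that makes this assumption explicit: drop the excluded keys from each side, check that counts and shapes agree, then pair the keys in order. `align_by_position`, its `skip_trained`/`skip_new` predicates and the toy `Old`/`New` modules are hypothetical names for illustration; the predicates stand in for the 'crf' check and the added-layer set above.

```python
import torch
import torch.nn as nn


def align_by_position(dict_trained, dict_new, skip_trained, skip_new):
    """Pair keys in order after dropping the excluded ones from each side (hypothetical helper)."""
    src = [k for k in dict_trained if not skip_trained(k)]
    dst = [k for k in dict_new if not skip_new(k)]
    if len(src) != len(dst):
        raise ValueError(f"{len(src)} pretrained keys vs {len(dst)} target keys")
    out = dict(dict_new)
    for s, d in zip(src, dst):
        if dict_trained[s].shape != dict_new[d].shape:
            raise ValueError(f"shape mismatch: {s} {tuple(dict_trained[s].shape)} "
                             f"-> {d} {tuple(dict_new[d].shape)}")
        out[d] = dict_trained[s]
    return out


class Old(nn.Module):
    """Hypothetical pretrained model with an encoder and a CRF-like head."""

    def __init__(self):
        super().__init__()
        self.enc = nn.Linear(4, 4)
        self.crf = nn.Linear(4, 2)


class New(nn.Module):
    """Hypothetical new model: encoder renamed, CRF dropped, classifier added."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 4)
        self.classifier = nn.Linear(4, 3)


old, new = Old(), New()
merged = align_by_position(
    old.state_dict(), new.state_dict(),
    skip_trained=lambda k: "crf" in k,
    skip_new=lambda k: k.startswith("classifier"),
)
new.load_state_dict(merged)
print(torch.equal(new.encoder.weight, old.enc.weight))  # True
```

Raising on a count or shape mismatch, instead of printing and continuing, keeps a misaligned pairing from silently loading the wrong tensors.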
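Later I learned of a simpler method:
Once your own model is set up, if you only want to reuse the pretrained parameters where the structure is identical, just pass strict=False when loading. strict defaults to True, which requires the keys of the pretrained state dict and of your model to correspond exactly; otherwise loading fails. With strict=False, missing and unexpected keys are ignored, but a parameter that exists under the same name with a different shape still raises an error. It looks like this: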
model.load_state_dict(torch.load(self.args.load_path, map_location=torch.device('cpu')), strict=False)
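`load_state_dict` returns a named tuple with `missing_keys` and `unexpected_keys`, so even with `strict=False` you can see what was skipped. The sketch below uses toy `nn.Sequential` models in place of a real checkpoint and also shows that a shape mismatch still fails under `strict=False`.

```python
import torch
import torch.nn as nn

src = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
dst = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3), nn.Linear(3, 2))

# the extra layer '3' has no pretrained values, so its keys are reported as missing
result = dst.load_state_dict(src.state_dict(), strict=False)
print(result.missing_keys)     # ['3.weight', '3.bias']
print(result.unexpected_keys)  # []

# same key names but a different output size: strict=False does not help here
bad = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 5))
raised = False
try:
    bad.load_state_dict(src.state_dict(), strict=False)
except RuntimeError as err:
    raised = "size mismatch" in str(err)
print(raised)  # True
```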
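PS: If you run into errors, print the keys of your modified model's state dict and of the loaded checkpoint and compare them; that usually points straight at the problem.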
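Following the PS, a small helper (the name `diff_state_dicts` is hypothetical) can list the keys present on only one side and the keys whose shapes differ. It is shown here on toy models rather than a real checkpoint:

```python
import torch.nn as nn


def diff_state_dicts(model_sd, ckpt_sd):
    """Return keys only in the model, keys only in the checkpoint, and shape mismatches."""
    only_model = [k for k in model_sd if k not in ckpt_sd]
    only_ckpt = [k for k in ckpt_sd if k not in model_sd]
    mismatched = [(k, tuple(ckpt_sd[k].shape), tuple(model_sd[k].shape))
                  for k in model_sd
                  if k in ckpt_sd and ckpt_sd[k].shape != model_sd[k].shape]
    return only_model, only_ckpt, mismatched


ckpt = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 3)).state_dict()
model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 5), nn.Linear(5, 2))
only_model, only_ckpt, mismatched = diff_state_dicts(model.state_dict(), ckpt)

for name in only_model:
    print("only in model (not loaded):", name)
for name in only_ckpt:
    print("only in checkpoint (ignored):", name)
for name, ckpt_shape, model_shape in mismatched:
    print(f"shape mismatch for {name}: checkpoint {ckpt_shape} vs model {model_shape}")
```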
3. Freezing these layers
In short:
# stop gradient updates for every parameter of the model
for param in model.parameters():
    param.requires_grad = False
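The loop above freezes every parameter of the model. To freeze only the layers that actually received pretrained values, one option (an assumption, not the only way) is to match `named_parameters()` names against the keys that were loaded, for example the keys of the filtered `dict_trained` from method 1. `MyModel` and `freeze_loaded` are hypothetical names:

```python
import torch.nn as nn


class MyModel(nn.Module):
    """Hypothetical model: a pretrained encoder plus a newly added classifier."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 8)
        self.classifier = nn.Linear(8, 3)


def freeze_loaded(model, loaded_keys):
    """Set requires_grad=False on every parameter whose name is in loaded_keys."""
    frozen = []
    for name, param in model.named_parameters():
        if name in loaded_keys:
            param.requires_grad = False
            frozen.append(name)
    return frozen


model = MyModel()
# e.g. the keys of the filtered dict_trained from method 1
pretrained_keys = {"encoder.weight", "encoder.bias"}
frozen = freeze_loaded(model, pretrained_keys)
trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(frozen)     # ['encoder.weight', 'encoder.bias']
print(trainable)  # ['classifier.weight', 'classifier.bias']
```

When the frozen part is a whole submodule, `model.encoder.requires_grad_(False)` does the same in one call.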
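There are many ways to do this; here I use the freezing approach that matches the loading methods above.
For more background, I recommend
https://discuss.pytorch.org/t/how-the-pytorch-freeze-network-in-some-layers-only-the-rest-of-the-training/7088
or
https://discuss.pytorch.org/t/correct-way-to-freeze-layers/26714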
Correspondingly, during training the optimizer should only update parameters with requires_grad = True, so pass it just those:
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=lr)
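A quick end-to-end check of the optimizer line above, on a toy `nn.Sequential` with its first layer frozen (the model, random data and `lr` value are illustrative assumptions): after one Adam step the frozen weights are unchanged and the trainable ones have moved.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
for p in net[0].parameters():  # freeze the first layer
    p.requires_grad = False

lr = 1e-3
# only the trainable parameters (layer 2's weight and bias) are handed to Adam
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=lr)

frozen_before = net[0].weight.detach().clone()
head_before = net[2].weight.detach().clone()

x, y = torch.randn(16, 4), torch.randint(0, 3, (16,))
loss = nn.functional.cross_entropy(net(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(torch.equal(net[0].weight, frozen_before))  # frozen layer untouched
print(torch.equal(net[2].weight, head_before))    # trainable layer updated
```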