
PyTorch model fine-tuning: loading only some layers of a pre-trained model

2022-06-13 08:51:00 Human high quality Algorithm Engineer

Redefine the model structure. In practice, fine-tuning means removing the original model's fc layer and adding a new fc (linear) layer; in other words, load only some layers of the pre-trained model and skip, for example, the parameters of the classification layer.

    if opt.continue_model != '':
        print(f'loading pretrained model from {opt.continue_model}')
        pretrained_dict = torch.load(opt.continue_model)
        model_dict = model.state_dict()
        # Keep only weights that exist in the new model and do not belong
        # to the prediction (classification) head.
        pretrained_dict = {k: v for k, v in pretrained_dict.items()
                           if (k in model_dict and 'Prediction' not in k)}
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

The key line is pretrained_dict = {k: v for k, v in pretrained_dict.items() if (k in model_dict and 'Prediction' not in k)}. The if clause filters the parameters; write the filter condition to match your own needs. My changes are as follows:

backbone_pth = r"J:\code\gate-decorator-pruning\pre_train\face\orbbec_ir.pth".replace('\\', '/')
# Prepend 'module.' (typically because the target model is wrapped in
# nn.DataParallel) and drop the fc and features layers.
pretrained_dict = {'module.' + k: v for k, v in torch.load(backbone_pth).items()
                   if ('fc' not in k and 'features' not in k)}
# print(pretrained_dict.keys())
model_dict = pack.net.state_dict()
model_dict.update(pretrained_dict)
pack.net.load_state_dict(model_dict)
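As an alternative to merging dictionaries by hand, load_state_dict(..., strict=False) skips mismatched keys and reports them back to you. The following is a minimal, self-contained sketch of that approach (the toy Net and its layer names are made up for illustration, not the model from this post):

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3)
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x):
        x = self.conv1(x).mean(dim=(2, 3))  # global average pool
        return self.fc(x)

pretrained = Net(num_classes=10)
ckpt = pretrained.state_dict()

model = Net(num_classes=5)  # new head: fc shape no longer matches
# Drop the classifier weights, then load the rest non-strictly.
ckpt = {k: v for k, v in ckpt.items() if 'fc' not in k}
missing, unexpected = model.load_state_dict(ckpt, strict=False)
print(missing)     # ['fc.weight', 'fc.bias'] -- left at their fresh init
print(unexpected)  # []
```

The returned lists let you verify that only the classification head was skipped, which is safer than silently ignoring mismatches.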

Typically, model.state_dict() returns an OrderedDict:

OrderedDict([('module.conv1.weight', tensor([[[[-1.1209e-13, -1.1880e-13, -4.3911e-14],
          [ 1.0236e-13,  2.1390e-13,  2.1669e-13],
          [ 1.1124e-13,  2.7313e-13,  2.8194e-13]],

         [[-1.1167e-13, -1.1761e-13, -4.2205e-14],
          [ 1.0327e-13,  2.1355e-13,  2.2042e-13],
          [ 1.0807e-13,  2.7697e-13,  2.8198e-13]],

         [[-1.1175e-13, -1.1731e-13, -4.3114e-14],
          [ 1.0148e-13,  2.1327e-13,  2.1844e-13],
          [ 1.1156e-13,  2.7669e-13,  2.8263e-13]]],


        [[[ 8.3619e-24, -2.7362e-23, -5.2695e-23],
          [ 2.8068e-23,  5.4774e-24, -6.4889e-23],
          [-3.7576e-24,  3.1777e-24, -4.6432e-23]],

         [[ 2.2920e-24, -2.1073e-23, -5.5368e-23],
          [ 3.4743e-23,  6.5649e-24, -5.4008e-23],
          [-8.6869e-24, -5.3493e-24, -4.7047e-23]],

         [[ 1.0901e-23, -2.7621e-23, -6.0086e-23],
          [ 3.6049e-23,  1.0260e-23, -6.2163e-23],
          [-5.4703e-24, -4.8527e-25, -4.3399e-23]]],


        [[[-9.7617e-03,  5.3041e-03, -1.0028e-02],
          [ 2.2801e-03,  3.3919e-02,  1.9303e-02],
          [-7.5236e-03,  2.8626e-02,  5.3559e-03]],

Calling .items() on it returns each weight name and its value as a (key, value) tuple:

import collections
 
dic = collections.OrderedDict()
dic['k1'] = 'v1'
dic['k2'] = 'v2'
print(dic.items())
 
# Output :odict_items([('k1', 'v1'), ('k2', 'v2')])
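Incidentally, the 'module.' + k rewrite in the earlier snippet reflects a common situation: wrapping a model in nn.DataParallel nests it under an attribute called module, so every key in its state_dict() gains a module. prefix. A minimal illustration (the tiny nn.Linear is just a stand-in):

```python
import torch.nn as nn

net = nn.Linear(4, 2)
print(list(net.state_dict().keys()))      # ['weight', 'bias']

# DataParallel stores the original model as self.module, so the
# saved keys become 'module.weight', 'module.bias', etc.
wrapped = nn.DataParallel(net)
print(list(wrapped.state_dict().keys()))  # ['module.weight', 'module.bias']
```

This is why checkpoints saved from a DataParallel-wrapped model need the prefix added (or stripped) when loading into a plain model.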

Copyright notice
This article was written by [Human high quality Algorithm Engineer]; when reposting, please include the original link. Thanks.
https://yzsam.com/2022/02/202202270536289588.html