
Pytorch: sub model parameter freezing + BN freezing


Use scenario: part of a network's weights, together with the corresponding BN layers, needs to be completely frozen.

When loading a pretrained model, setting only para.requires_grad = False does not fully freeze it: a BN layer's running statistics are not updated by loss.backward() and optimizer.step(), but by the forward pass itself, via its momentum. The BN layers therefore have to be frozen again before every forward pass, as the sketch below demonstrates.
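A minimal, self-contained sketch of this behavior: even with requires_grad = False, a BatchNorm2d in train mode still updates its running statistics on every forward pass, while eval() pins them:

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)
for p in bn.parameters():
    p.requires_grad = False                  # freezes only weight / bias

bn.train()                                   # train mode
before = bn.running_mean.clone()
_ = bn(torch.randn(4, 3, 8, 8))
print(torch.equal(before, bn.running_mean))  # False: stats were updated

bn.eval()                                    # eval mode
before = bn.running_mean.clone()
_ = bn(torch.randn(4, 3, 8, 8))
print(torch.equal(before, bn.running_mean))  # True: stats stay fixed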
The complete freezing method is as follows:


'''  ... other code ...  '''

# Freeze BN: put every BatchNorm layer into eval mode so that its
# running_mean / running_var stop updating during forward()
def freeze_bn(m):
    classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
        m.eval()
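Module.apply() recurses through every sub-module, so model.apply(freeze_bn) flips each BatchNorm1d/2d/3d (and SyncBatchNorm) layer into eval mode while leaving the rest of the network in train mode.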


'''  ... other code ...  '''
freeze_state_dict = torch.load(opt.loadckpt_freeze)
model_dict = model.state_dict()
# Keys that appear both in the checkpoint and in the current model get frozen
frozen_list = [k for k in freeze_state_dict['state_dict'] if k in model_dict]
# Step 1: freeze the learnable parameters; the BN running statistics
# are handled separately by freeze_bn() below
for name, param in model.named_parameters():
    if name in frozen_list:        # parameter is on the to-freeze list
        param.requires_grad = False

# The optimizer only receives the parameters that still require gradients
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                       lr=opt.lr, betas=(0.9, 0.999))
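As a quick sanity check, you can count how many parameters the optimizer will actually update:

n_total = sum(p.numel() for p in model.parameters())
n_train = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'trainable parameters: {n_train} / {n_total}')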

'''  ... other code ...  '''

for epoch in range(opt.epoch):
    model.train()                # puts every layer, BN included, back into train mode
    optimizer.zero_grad()
    # Re-freeze BN after each model.train() call, before the forward pass
    model.apply(freeze_bn)
    # Forward propagation
    output = model(input)
    loss = loss_F(output, gt)    # torch losses expect (prediction, target)
    loss.backward()
    optimizer.step()
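Finally, a small verification sketch (assuming the frozen checkpoint weights were loaded into model in the omitted code): snapshot every frozen tensor before training and confirm afterwards that none of them changed. Since state_dict() also contains the BN buffers (running_mean / running_var), this checks the BN freeze as well:

snapshot = {k: v.clone() for k, v in model.state_dict().items() if k in frozen_list}

# ... run the training loop above for a few iterations ...

for k, v in model.state_dict().items():
    if k in snapshot:
        assert torch.equal(v, snapshot[k]), f'{k} changed despite being frozen'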
