
Hands-On Deep Learning: Weight Decay and Code Implementation

2022-06-12 08:14:00 Orange acridine 21

One. Weight Decay

1. Weight decay: a technique commonly used to deal with overfitting.

2. Using the squared L2 norm as a hard constraint

Model capacity is controlled by restricting the range of values the parameters are allowed to take:

  • The bias b is usually not constrained (constraining it or not makes little practical difference).
  • A smaller threshold means a stronger regularization constraint (θ in the sketch below).
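A minimal LaTeX sketch of the hard-constraint objective described above; here ℓ(w, b) denotes the training loss and θ the constraint threshold (both symbols are my notation, chosen to match the usual statement of this constraint):

\min_{\mathbf{w}, b} \; \ell(\mathbf{w}, b) \quad \text{subject to} \quad \lVert \mathbf{w} \rVert^2 \le \theta

This form is rarely optimized directly; it motivates the soft-constraint (penalty) version in the next subsection.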

3. Using the squared L2 norm as a soft constraint

For each value of the threshold, a λ can be found that makes the constrained objective above equivalent to the following penalized objective:

This equivalence can be shown using Lagrange multipliers.

The hyperparameter λ controls the importance of the regularization term:
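A minimal LaTeX sketch of the penalized (soft-constraint) objective referred to above; the λ/2 scaling is an assumption on my part, matching the division by 2 in l2_penalty in the code below, so that the gradient of the penalty is simply λw:

\min_{\mathbf{w}, b} \; \ell(\mathbf{w}, b) + \frac{\lambda}{2} \lVert \mathbf{w} \rVert^2

When λ = 0 the penalty has no effect; as λ → ∞ the optimal weights are driven toward 0.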

4. Parameter update rule

Introducing λ shrinks the weights a little at every update step, which is why the method is called weight decay (see the update rule sketched below).
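A minimal LaTeX sketch of the resulting update rule, obtained by plugging the penalized objective into plain gradient descent; η denotes the learning rate (my notation):

\mathbf{w}_{t+1} = (1 - \eta \lambda)\, \mathbf{w}_t - \eta \, \frac{\partial \ell(\mathbf{w}_t, b_t)}{\partial \mathbf{w}_t}

Since ηλ is normally less than 1, the factor (1 − ηλ) shrinks the weights slightly at every step, which is exactly the "decay".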

5. Summary

Weight decay uses an L2 regularization term to keep the model parameters from becoming too large, thereby controlling the complexity of the model.

The regularization weight λ is a hyperparameter that controls the complexity of the model.

Two. Code Implementation

Below, we use high-dimensional linear regression as an example to introduce an overfitting problem, and then use weight decay to cope with it. Let the dimension of the sample features be p. For any sample with features x1, x2, ..., xp in the training set or the test set, we use the following linear function to generate its label:

a bias of 0.05, plus weights of 0.01 times the random inputs x, plus a noise term drawn from a normal distribution with mean 0 and standard deviation 0.01.
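A LaTeX rendering of the label-generating function described above; it matches the code below (true weights 0.01, true bias 0.05, noise standard deviation 0.01), with p the feature dimension:

y = 0.05 + \sum_{i=1}^{p} 0.01\, x_i + \epsilon, \qquad \epsilon \sim \mathcal{N}(0,\ 0.01^2)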

import torch
import torch.nn as nn
import numpy as np
import sys
sys.path.append("..")
import d2lzh_pytorch as d2l
from matplotlib import pyplot as plt
"""
 Generate a manual data set , It is also a linear regression problem ,
 deviation 0.05+ The weight 0.01 Multiply by random input x then + The noise , The mean for 0, The variance of 0.01 Is a normal distribution 
"""
n_train, n_test, num_inputs = 20, 100, 200
# The smaller the training set, the easier it is to overfit: 20 training samples,
# 100 test samples, and a feature dimension of 200.
# With little data and a relatively high-dimensional model, overfitting is more likely to occur.

true_w, true_b = torch.ones(num_inputs, 1) * 0.01, 0.05
# True weights: a vector of ones scaled by 0.01; true bias b = 0.05

"""
 Read a manual data set 
"""
features = torch.randn((n_train + n_test, num_inputs))  # feature matrix
labels = torch.matmul(features, true_w) + true_b  # noise-free labels
labels += torch.tensor(np.random.normal(0, 0.01, size=labels.size()), dtype=torch.float)  # add Gaussian noise
train_features, test_features = features[:n_train, :],features[n_train:, :]
train_labels, test_labels = labels[:n_train], labels[n_train:]

"""
 Initialize model , This function attaches a gradient to each parameter 
"""
def init_params():
    w = torch.randn((num_inputs, 1), requires_grad=True)
    b = torch.zeros(1, requires_grad=True)
    return [w, b]

"""
 Definition L2 Norm penalty , Only the weight parameters of the penalty model 
"""
def l2_penalty(w):
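    # Squared L2 norm divided by 2, so the gradient of the penalty with respect to w is simply w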
    return (w**2).sum()/2

"""
 Define training and test models 
"""
batch_size, num_epochs, lr = 1, 100, 0.003
net, loss = d2l.linreg, d2l.squared_loss
dataset = torch.utils.data.TensorDataset(train_features,train_labels)
train_iter = torch.utils.data.DataLoader(dataset, batch_size,shuffle=True)

def fit_and_plot(lambd):
    w, b = init_params()
    train_ls, test_ls = [], []
    for _ in range(num_epochs):
        for X, y in train_iter:
            # Add the L2 norm penalty term to the loss
            l = loss(net(X, w, b), y) + lambd * l2_penalty(w)
            l = l.sum()
            if w.grad is not None:
                w.grad.data.zero_()
                b.grad.data.zero_()
            l.backward()
            # d2l.sgd performs the minibatch SGD update (in the d2l helpers: param -= lr * param.grad / batch_size)
            d2l.sgd([w, b], lr, batch_size)
        train_ls.append(loss(net(train_features, w, b), train_labels).mean().item())
        test_ls.append(loss(net(test_features, w, b), test_labels).mean().item())
    d2l.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss',
                 range(1, num_epochs + 1), test_ls, ['train', 'test'])
    print('L2 norm of w:', w.norm().item())
When lambd is set to 0, weight decay is not used. The resulting training error is far
smaller than the error on the test set, which is a typical case of overfitting.
fit_and_plot(lambd=0)
plt.show()

[Figure: training and test loss curves (semilog scale) with lambd = 0, i.e. without weight decay]

Here we use weight decay. We can see that although the training error has gone up, the error on the test set has gone down, so overfitting is alleviated to a certain degree.
In addition, the L2 norm of the weight parameter is smaller than when weight decay is not used; the weight parameters are now closer to 0.
fit_and_plot(lambd=3)
plt.show()

[Figure: training and test loss curves (semilog scale) with lambd = 3, i.e. with weight decay]
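To complement the from-scratch version above, here is a minimal sketch of the same experiment using the built-in weight_decay argument of torch.optim.SGD instead of the handwritten l2_penalty. The function name fit_and_plot_pytorch, the nn.Linear layer, and the use of two optimizers (so that only the weight is decayed, not the bias) are my own illustration rather than code from the original post; it assumes the features, labels, loss, train_iter, lr, and num_epochs defined earlier.

def fit_and_plot_pytorch(wd):
    # A single linear layer replaces the handwritten net(X, w, b)
    model = nn.Linear(num_inputs, 1)
    nn.init.normal_(model.weight, mean=0, std=1)
    nn.init.normal_(model.bias, mean=0, std=1)
    # Decay the weight but not the bias, mirroring l2_penalty(w) above
    optimizer_w = torch.optim.SGD(params=[model.weight], lr=lr, weight_decay=wd)
    optimizer_b = torch.optim.SGD(params=[model.bias], lr=lr)
    train_ls, test_ls = [], []
    for _ in range(num_epochs):
        for X, y in train_iter:
            l = loss(model(X), y).mean()
            optimizer_w.zero_grad()
            optimizer_b.zero_grad()
            l.backward()
            # weight_decay adds wd * weight to the weight's gradient before the step
            optimizer_w.step()
            optimizer_b.step()
        train_ls.append(loss(model(train_features), train_labels).mean().item())
        test_ls.append(loss(model(test_features), test_labels).mean().item())
    d2l.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss',
                 range(1, num_epochs + 1), test_ls, ['train', 'test'])
    print('L2 norm of w:', model.weight.data.norm().item())

fit_and_plot_pytorch(0)   # no weight decay
fit_and_plot_pytorch(3)   # weight decay with wd = 3
plt.show()

These two calls should behave like fit_and_plot(lambd=0) and fit_and_plot(lambd=3) above: with wd = 3 the test loss and the L2 norm of the weight are expected to be noticeably smaller than with wd = 0.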
