Hands on deep learning -- weight decay and code implementation
2022-06-12 08:14:00 【Orange acridine 21】
One、Weight decay
1、Weight decay: a method commonly used to deal with overfitting.
2、Using the squared L2 norm as a hard constraint
Model capacity is controlled by limiting the range of values the parameters can take (one practical way to enforce this is sketched after the list below):

$$\min \; \ell(\mathbf{w}, b) \quad \text{subject to} \quad \|\mathbf{w}\|^2 \le \theta$$

- The bias $b$ is generally not constrained (whether it is limited or not makes little difference).
- A smaller $\theta$ means a stronger regularization term.
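One way to realize such a hard limit in practice is projected gradient descent: after each update, the weights are projected back onto the ball $\|\mathbf{w}\|^2 \le \theta$. A minimal sketch, where the helper project_to_ball is an illustration and not part of the original post:

import torch

def project_to_ball(w, theta):
    # If w has left the constraint set {w : ||w||^2 <= theta},
    # scale it back onto the boundary of the ball
    norm_sq = (w ** 2).sum()
    if norm_sq > theta:
        w = w * (theta / norm_sq).sqrt()
    return w

w = torch.tensor([3.0, 4.0])      # ||w||^2 = 25, outside the ball for theta = 1
w = project_to_ball(w, theta=1.0)
print((w ** 2).sum())             # tensor(1.0000): back inside the limit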
3、Using the squared L2 norm as a soft constraint
For every $\theta$, a $\lambda$ can be found that makes the hard-constrained objective above equivalent to

$$\min \; \ell(\mathbf{w}, b) + \frac{\lambda}{2}\|\mathbf{w}\|^2$$

This equivalence can be proved via Lagrange multipliers.
The hyperparameter $\lambda$ controls the importance of the regularization term: $\lambda = 0$ means no constraint at all, while as $\lambda \to \infty$ the optimal weights $\mathbf{w}^* \to \mathbf{0}$.
4、Parameter update rule
Taking the gradient of the regularized objective gives

$$\frac{\partial}{\partial \mathbf{w}}\left(\ell(\mathbf{w}, b) + \frac{\lambda}{2}\|\mathbf{w}\|^2\right) = \frac{\partial \ell(\mathbf{w}, b)}{\partial \mathbf{w}} + \lambda \mathbf{w}$$

so the gradient-descent update at time step $t$ becomes

$$\mathbf{w}_{t+1} = (1 - \eta\lambda)\,\mathbf{w}_t - \eta\,\frac{\partial \ell(\mathbf{w}_t, b_t)}{\partial \mathbf{w}_t}$$

Since usually $\eta\lambda < 1$, each update first shrinks the current weight by the factor $(1 - \eta\lambda)$ before applying the gradient step, which is why the method is called weight decay.
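To make this concrete, here is a minimal sketch (not from the original post; the toy gradient values are made up) verifying that one SGD step on the penalized loss equals shrinking the weight by $(1 - \eta\lambda)$ and then taking a plain gradient step:

import torch

eta, lambd = 0.1, 0.5            # learning rate and regularization strength
w = torch.tensor([1.0, -2.0])    # current weight vector
grad = torch.tensor([0.3, 0.4])  # stand-in gradient of the unpenalized loss at w

# Path 1: one gradient step on loss + (lambda / 2) * ||w||^2
w1 = w - eta * (grad + lambd * w)

# Path 2: shrink the weight by (1 - eta * lambda), then a plain gradient step
w2 = (1 - eta * lambd) * w - eta * grad

print(torch.allclose(w1, w2))    # True: the two updates are identical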
5、Summary
Weight decay uses an L2 regularization term to keep the model parameters from growing too large, thereby controlling the complexity of the model.
The regularization weight $\lambda$ is a hyperparameter that controls model complexity.
Two、Code implementation

We generate a synthetic dataset for a linear regression problem: $y = 0.05 + \sum_i 0.01\,x_i + \epsilon$, i.e. bias 0.05 plus weight 0.01 times each random input, plus noise $\epsilon$ drawn from a normal distribution with mean 0 and standard deviation 0.01.
import torch
import torch.nn as nn
import numpy as np
import sys
sys.path.append("..")
import d2lzh_pytorch as d2l
from matplotlib import pyplot as plt
"""
Generate a synthetic dataset; it is again a linear regression problem:
bias 0.05 + weight 0.01 times the random inputs x, plus noise drawn from
a normal distribution with mean 0 and standard deviation 0.01
"""
n_train, n_test, num_inputs = 20, 100, 200
# The smaller the training set, the easier it is to overfit: 20 training
# examples, 100 test examples, and a feature dimension of 200.
# The less data there is and the more complex the model, the more likely overfitting becomes.
true_w, true_b = torch.ones(num_inputs, 1) * 0.01, 0.05
# The true weight is 0.01 times an all-ones vector; the true bias b is 0.05.
"""
Read a manual data set
"""
features = torch.randn((n_train + n_test, num_inputs)) # features
labels = torch.matmul(features, true_w) + true_b # Sample size
labels += torch.tensor(np.random.normal(0, 0.01,size=labels.size()), dtype=torch.float)
train_features, test_features = features[:n_train, :],features[n_train:, :]
train_labels, test_labels = labels[:n_train], labels[n_train:]
"""
Initialize model , This function attaches a gradient to each parameter
"""
def init_params():
w=torch.randn((num_inputs,1),requires_grad=True)
b=torch.zeros(1,requires_grad=True)
return [w,b]
"""
Definition L2 Norm penalty , Only the weight parameters of the penalty model
"""
def l2_penalty(w):
return (w**2).sum()/2
"""
Define training and test models
"""
batch_size, num_epochs, lr = 1, 100, 0.003
net, loss = d2l.linreg, d2l.squared_loss
dataset = torch.utils.data.TensorDataset(train_features,train_labels)
train_iter = torch.utils.data.DataLoader(dataset, batch_size,shuffle=True)
def fit_and_plot(lambd):
    w, b = init_params()
    train_ls, test_ls = [], []
    for _ in range(num_epochs):
        for X, y in train_iter:
            # Add the L2 norm penalty term to the loss
            l = loss(net(X, w, b), y) + lambd * l2_penalty(w)
            l = l.sum()
            if w.grad is not None:
                w.grad.data.zero_()
                b.grad.data.zero_()
            l.backward()
            d2l.sgd([w, b], lr, batch_size)
        train_ls.append(loss(net(train_features, w, b), train_labels).mean().item())
        test_ls.append(loss(net(test_features, w, b), test_labels).mean().item())
    d2l.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss',
                 range(1, num_epochs + 1), test_ls, ['train', 'test'])
    print('L2 norm of w:', w.norm().item())

When lambd is set to 0, no weight decay is used. The training error ends up far smaller than the error on the test set, which is a typical overfitting phenomenon.
fit_and_plot(lambd=0)
plt.show()

Here weight decay is used. Although the training error rises, the error on the test set drops, so overfitting is alleviated to a certain degree. In addition, the L2 norm of the weight parameters is smaller than without weight decay, and the weights are now much closer to 0.
fit_and_plot(lambd=3)
plt.show()
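The manual loop above can also be written concisely. Here is a minimal sketch (this concise variant is not part of the original post) that relies on the weight_decay argument of torch.optim.SGD, which applies exactly the $(1 - \eta\lambda)$ shrinkage derived earlier; two separate optimizers are used so that only the weight, and not the bias, is decayed:

def fit_and_plot_pytorch(wd):
    # Concise version: nn.Linear plus the optimizer's built-in weight decay
    net = nn.Linear(num_inputs, 1)
    nn.init.normal_(net.weight, mean=0, std=1)
    nn.init.normal_(net.bias, mean=0, std=1)
    # Decay only the weight; the bias gets its own optimizer without decay
    optimizer_w = torch.optim.SGD(params=[net.weight], lr=lr, weight_decay=wd)
    optimizer_b = torch.optim.SGD(params=[net.bias], lr=lr)
    train_ls, test_ls = [], []
    for _ in range(num_epochs):
        for X, y in train_iter:
            l = loss(net(X), y).mean()
            optimizer_w.zero_grad()
            optimizer_b.zero_grad()
            l.backward()
            # Each step decays its own parameter group
            optimizer_w.step()
            optimizer_b.step()
        train_ls.append(loss(net(train_features), train_labels).mean().item())
        test_ls.append(loss(net(test_features), test_labels).mean().item())
    d2l.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss',
                 range(1, num_epochs + 1), test_ls, ['train', 'test'])
    print('L2 norm of w:', net.weight.data.norm().item())

fit_and_plot_pytorch(3)
plt.show()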