Mitigating Forgetting in Online Continual Learning with Neuron Calibration: paper analysis + code reading
2022-06-12 07:18:00 【Programmer long】
The paper can be found here.
It is a paper from Baidu Research.
1. Introduction
To address catastrophic forgetting, this paper focuses on replay-based approaches (for example, GEM and MEGA, which we discussed before). These methods allow the model limited access to data from past tasks so that it can rehearse past experience. However, replay-based methods easily suffer from data imbalance, i.e. the stability-plasticity dilemma: on the one hand, the model may be so constrained by past knowledge that it cannot learn new knowledge quickly; on the other hand, past knowledge may fade away as learning proceeds.
In this paper, the authors tackle the problem from a new angle, seeking the balance between stability and plasticity through neuron calibration. Concretely, neuron calibration mathematically adjusts the transformation function of each layer of a deep neural network. It regularizes parameter updates with a trainable soft mask to prevent catastrophic forgetting, and it affects both the forward inference path and the backward optimization path, i.e. both how the model predicts and how it is trained. In other words, the paper trains a shared calibration model that interleaves data from different task distributions so the model is optimized effectively, instead of preserving task knowledge by keeping a separate set of parameters for each task.
2. Related work
Existing methods for dealing with catastrophic forgetting mainly fall into three categories:
Episodic-memory replay: store part of the past data in an episodic memory for future rehearsal. Memory-based approaches handle catastrophic forgetting well, but with a limited memory budget they are prone to interference.
Regularization-based methods: extend the loss function used in continual learning to selectively consolidate past knowledge stored in the model parameters. These approaches use parameter-importance information to identify parameters that matter more for past tasks, so as to avoid forgetting.
Dynamic architectures: avoid catastrophic forgetting by, roughly speaking, training a separate network for each task.
3. NCCL (neuron calibration for online continual learning)
Symbol definition
$\{\mathcal{T}_1,...,\mathcal{T}_T\}$: an online continual learning task sequence. Each task is given a small amount of storage space to save past data.
$\mathcal{M}_t$: the episodic memory saved while training task $t$, containing part of the data of task $t$.
$\{\theta_i\}^L_{i=1}$: the parameters of each layer of an $L$-layer neural network.
3.1 Neuron calibration
By applying neuron calibration, the goal is to adapt the transformation functions of the deep network layers, so as to mitigate catastrophic changes of the model parameters and keep knowledge from different tasks within a stable range. Concretely, the paper calibrates the two most common layer types: fully connected layers and convolution layers. The authors give a diagram to illustrate how this works.
The figure shows two calibration modules. The first is the weight calibration module (WCM); the second is the feature calibration module (FCM). The WCM learns to scale the weights of the transformation function, while the FCM learns to scale the output feature maps that the transformation function predicts. For notation, $\theta_i$ denotes the parameters before the WCM and $\tilde{\theta}_i$ the parameters after the WCM; similarly, $h_i, \tilde{h}_i$ denote the output feature maps before and after the FCM.
WCM
Let $\Omega_{\psi_i}(\cdot)$ denote the weight calibration function deployed at the $i$-th layer, with parameters $\psi_i$. Weight calibration is designed to be modular: element-wise multiplication is applied between the base network parameters and the calibration parameters, as follows:
$$\Omega_{\psi_i}(\theta_i)=\begin{dcases} tile(\psi_i)\odot \theta_i & \psi_i\in\mathbb{R}^{O\times I} \ \ \ \text{(Convolution Layer)}\\ tile(\psi_i)\odot \theta_i & \psi_i\in\mathbb{R}^{O} \ \ \ \text{(Fully Connected Layer)} \end{dcases} \tag{1}$$
Here $O$ and $I$ denote the numbers of output and input channels. To keep the number of calibration parameters small, $\psi_i$ is much smaller than $\theta_i$ and is expanded with the $tile$ function (an operation that repeats a tensor along given dimensions). Weight calibration plays a crucial role in both directions: in forward propagation it scales the base network parameter values used for prediction, and in backward propagation it regularizes the update of important parameters by acting as a priority weight ($\nabla_{\theta_i}\mathcal{L}_b$ is derived as $\nabla_{\tilde{\theta}_i}\mathcal{L}_b\odot tile(\psi_i)$, i.e. the gradient is scaled by the calibrator parameters).
After weight calibration, the output of the $i$-th layer is:
$$h_i = \mathcal{F}_{\tilde{\theta}_i}(\tilde{h}_{i-1}) \quad s.t. \quad \tilde\theta_i =\Omega_{\psi_i}(\theta_i) \tag{2}$$
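To make Eq. (1)-(2) concrete, here is a minimal PyTorch sketch of weight calibration for a convolution layer; the shapes follow the convention used in the repository code later in this post, and the tensors themselves are made-up placeholders.

import torch

O, I, k = 8, 3, 3                                  # output channels, input channels, kernel size
theta = torch.randn(O, I, k, k)                    # base conv weight theta_i
psi = torch.ones(O, I, 1, 1)                       # calibration parameters psi_i in R^{O x I}

# WCM (Eq. 1): tile psi over the kernel window, then multiply element-wise
theta_tilde = torch.tile(psi, (1, 1, k, k)) * theta

# Eq. 2: the layer transformation now uses the calibrated weight
x = torch.randn(4, I, 32, 32)                      # input feature map h_{i-1}
h = torch.nn.functional.conv2d(x, theta_tilde, padding=1)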
FCM
After the WCM and the layer transformation we obtain a feature output, and the FCM is applied to it next. Let $\Omega_{\lambda_i}(\cdot)$ denote the FCM function. The FCM multiplies the calibration parameters with the output features, as follows:
$$\Omega_{\lambda_i}(h_i)=\begin{dcases} tile(\lambda_i)\odot h_i & \lambda_i\in\mathbb{R}^{O} \ \ \ \text{(Convolution Layer)}\\ \lambda_i\odot h_i & \lambda_i\in\mathbb{R}^{O} \ \ \ \text{(Fully Connected Layer)} \end{dcases} \tag{3}$$
After this, similar to a residual connection in ResNet, the calibrated feature output and the original feature output are added together.
Therefore, the complete processing from layer $i-1$ to layer $i$ is:
$$\tilde{h}_i = \sigma(\mathcal{BN}(\Omega_{\lambda_i}(\mathcal{F}_{\tilde{\theta}_i}(\tilde{h}_{i-1}))\oplus \mathcal{F}_{\tilde{\theta}_i}(\tilde{h}_{i-1}))) \quad s.t. \quad \tilde\theta_i =\Omega_{\psi_i}(\theta_i) \tag{4}$$
where $\mathcal{BN}$ is batch normalization and $\sigma$ is the activation function.
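The code-reading section below walks through the convolutional version of this block; as an illustrative counterpart, here is a minimal sketch of Eq. (4) for a fully connected layer. The class name, shapes, and structure are assumptions made for this example, not the repository's API.

import torch
import torch.nn as nn
import torch.nn.functional as F

class CalibratedLinear(nn.Module):
    # Sketch of one calibrated fully connected layer following Eq. (1)-(4).
    def __init__(self, in_features, out_features):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features, bias=False)
        self.bn = nn.BatchNorm1d(out_features)
        # psi (WCM) and lambda (FCM), both trainable soft masks of size O
        self.psi = nn.Parameter(torch.ones(out_features, 1))
        self.lam = nn.Parameter(torch.ones(out_features))

    def forward(self, h_prev):
        # WCM: theta_tilde = tile(psi) * theta (element-wise)   (Eq. 1)
        theta_tilde = torch.tile(self.psi, (1, self.fc.in_features)) * self.fc.weight
        h = F.linear(h_prev, theta_tilde)                       # F_{theta_tilde}(h_{i-1})  (Eq. 2)
        h = self.lam * h + h                                    # FCM plus residual add     (Eq. 3)
        return torch.relu(self.bn(h))                           # BN and activation         (Eq. 4)

layer = CalibratedLinear(16, 32)
out = layer(torch.randn(8, 16))                                 # -> shape (8, 32)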
3.2 Parameter learning
With the calibrated layers in place, the loss function also has to be adapted so that the parameters are updated properly, using the Fisher information from EWC as the basis. Consolidation happens while the base model parameters are trained to absorb new knowledge, and past knowledge is rehearsed by replaying the data in the episodic memory. The loss is computed as follows:
$$\mathcal{L}_c(\{\psi,\lambda,\theta\},(x,y,k)) = \underbrace{\frac{1}{2}vec(\tilde{\theta}-\tilde{\theta}^t)^T\Lambda_t\,vec(\tilde{\theta}-\tilde{\theta}^t)}_{term(a)}+\underbrace{\beta D_{KL}\left(S\left(\frac{\hat{z}}{\tau}\right) \parallel S\left(\frac{\hat{z}_k}{\tau}\right)\right)}_{term(b)} \tag{5}$$
In this loss, $\beta$ is a balancing coefficient, $S(\cdot)$ is the softmax function, $\tau$ is the softmax distillation temperature, $\hat{z}$ is the prediction for the current task, and $\hat{z}_k$ is the (stored) prediction for previous task $k$. $vec(\cdot)$ flattens its argument into a vector.
$\Lambda_t$ is the Fisher information, as in EWC, estimated from the knowledge-distillation loss on the stored data. Term (a) softly freezes the parts of the weights that matter for past tasks, which guards against catastrophic forgetting, while term (b) provides stability during training.
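As a small illustration of term (b), the sketch below computes the temperature-scaled distillation loss in PyTorch, following the calling convention used later in the repository code (log-probabilities of the current prediction as input, stored probabilities as target). All tensors here are placeholders.

import torch
import torch.nn.functional as F

tau, beta = 2.0, 1.0                        # distillation temperature and balance weight
z_hat = torch.randn(32, 10)                 # current model's logits on replayed samples
z_k_hat = torch.randn(32, 10)               # logits stored when the old task was learned

# term (b) of Eq. (5): beta-weighted KL between the temperature-softened distributions
term_b = beta * F.kl_div(
    F.log_softmax(z_hat / tau, dim=1),      # input must be log-probabilities
    F.softmax(z_k_hat / tau, dim=1),        # target given as probabilities
    reduction="batchmean",
)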
3.3 Optimization
The optimization of NCCL is similar to MAML: it is split into an inner optimization and an outer optimization. The inner loop updates $\theta$, while the outer loop updates $\psi,\lambda$. The optimization objectives are:
$$\text{Outer Loop: }(\psi^*,\lambda^*) = \arg\min_{(\psi,\lambda)}\mathcal{L}_c((\psi,\lambda),\theta^*,\mathcal{M}_{<t}) \tag{6}$$
$$\text{Inner Loop: }\theta^* = \arg\min_{\theta}\mathcal{L}_b((\psi,\lambda),\theta,\mathcal{M}_{\le t}) \tag{7}$$
Finally, each loop performs gradient updates with its own learning rate.
The complete algorithm process is shown in the figure :
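As a rough, first-order illustration of the bi-level update in Eq. (6)-(7), the toy sketch below uses a plain linear model and ignores second-order terms; the real training loop from the repository is shown in the code-reading section.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
theta = torch.randn(2, 10, requires_grad=True)        # base weight (classes x features), inner loop
psi = torch.ones(2, 1, requires_grad=True)            # calibration mask, one scale per output unit, outer loop
inner_lr, outer_lr = 0.1, 0.01

x_new, y_new = torch.randn(16, 10), torch.randint(0, 2, (16,))   # current-task batch
x_mem, y_mem = torch.randn(16, 10), torch.randint(0, 2, (16,))   # replayed memory batch

def predict(x, theta, psi):
    theta_tilde = torch.tile(psi, (1, theta.shape[1])) * theta    # WCM (Eq. 1)
    return F.linear(x, theta_tilde)

# inner loop (Eq. 7): update theta on current data plus memory, psi held fixed
inner_loss = F.cross_entropy(predict(torch.cat([x_new, x_mem]), theta, psi),
                             torch.cat([y_new, y_mem]))
g_theta, = torch.autograd.grad(inner_loss, theta)
theta_star = theta - inner_lr * g_theta

# outer loop (Eq. 6): update the calibration parameters on memory only, with the updated theta
outer_loss = F.cross_entropy(predict(x_mem, theta_star, psi), y_mem)
g_psi, = torch.autograd.grad(outer_loss, psi)
psi = (psi - outer_lr * g_psi).detach().requires_grad_()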
4. Code reading
The authors' GitHub code can be found here.
The main difficulty of this paper lies in how the network layers are built, i.e. the WCM and FCM modules and the per-layer processing, so let's look at that code first.
The structure of each layer is:
$$\tilde{h}_i = \sigma(\mathcal{BN}(\Omega_{\lambda_i}(\mathcal{F}_{\tilde{\theta}_i}(\tilde{h}_{i-1}))\oplus \mathcal{F}_{\tilde{\theta}_i}(\tilde{h}_{i-1}))) \quad s.t. \quad \tilde\theta_i =\Omega_{\psi_i}(\theta_i)$$
where the WCM is:
$$\Omega_{\psi_i}(\theta_i)=\begin{dcases} tile(\psi_i)\odot \theta_i & \psi_i\in\mathbb{R}^{O\times I} \ \ \ \text{(Convolution Layer)}\\ tile(\psi_i)\odot \theta_i & \psi_i\in\mathbb{R}^{O} \ \ \ \text{(Fully Connected Layer)} \end{dcases}$$
and the FCM is:
$$\Omega_{\lambda_i}(h_i)=\begin{dcases} tile(\lambda_i)\odot h_i & \lambda_i\in\mathbb{R}^{O} \ \ \ \text{(Convolution Layer)}\\ \lambda_i\odot h_i & \lambda_i\in\mathbb{R}^{O} \ \ \ \text{(Fully Connected Layer)} \end{dcases}$$
First, the parameters of each layer are defined. Here the author wraps two CNN layers (together with their calibration) into one block, as follows:
class CalibratedBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, activation='relu', norm='batch_norm', downsample=None):
        super(CalibratedBlock, self).__init__()
        ## first CNN layer
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        ## second CNN layer
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.stride = stride
        self.sigma = 0.05
        self.downsample = downsample
        ## calibration parameters: cw (weight calibration of the WCM), cb (convolution bias), cf (output scaling of the FCM)
        self.calib_w_conv1 = torch.nn.Parameter(torch.ones(planes, in_planes, 1, 1), requires_grad=True)
        self.calib_b_conv1 = torch.nn.Parameter(torch.zeros([planes]), requires_grad=True)
        self.calib_f_conv1 = torch.nn.Parameter(torch.ones([1, planes, 1, 1]), requires_grad=True)
        self.calib_w_conv2 = torch.nn.Parameter(torch.ones(planes, planes, 1, 1), requires_grad=True)
        self.calib_b_conv2 = torch.nn.Parameter(torch.zeros([planes, 1, 1, 1]), requires_grad=True)
        self.calib_f_conv2 = torch.nn.Parameter(torch.ones([1, planes, 1, 1]), requires_grad=True)
        ## register them as parameters of the module
        self.register_parameter('calib_w_conv1', self.calib_w_conv1)
        self.register_parameter('calib_b_conv1', self.calib_b_conv1)
        self.register_parameter('calib_f_conv1', self.calib_f_conv1)
        self.register_parameter('calib_w_conv2', self.calib_w_conv2)
        self.register_parameter('calib_b_conv2', self.calib_b_conv2)
        self.register_parameter('calib_f_conv2', self.calib_f_conv2)
        # an ordinary shortcut conv (as in ResNet) when the input and output shapes differ
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )
        self.activation = activation
        self.norm = norm
Next, let's look at the forward pass:
def forward(self, x):
    # pick the activation function first
    if self.activation == 'relu':
        activation = nn.functional.relu
    elif self.activation == 'leaky_relu':
        activation = nn.functional.leaky_relu
    else:
        activation = None
    [dim0, dim1] = self.conv1.weight.shape[2:]
    calibrated_conv1 = self.calib_w_conv1
    ## tile operation: expand psi to the kernel size
    this_ss_weights = torch.tile(calibrated_conv1, (1, 1, dim0, dim1))
    ## element-wise multiply to get the calibrated weight (WCM)
    cw = self.conv1.weight * this_ss_weights
    ## convolve with the calibrated weight
    conv_output = torch.nn.functional.conv2d(x, cw, stride=self.stride,
                                             padding=1, bias=self.calib_b_conv1.squeeze())
    ## after computing h, apply the FCM
    [dim1, dim2] = conv_output.shape[2:]
    this_scale_weights = torch.tile(self.calib_f_conv1, (conv_output.shape[0], 1, dim1, dim2))
    conv_output = conv_output * this_scale_weights
    # normalize
    if self.norm == 'batch_norm':
        normed = self.bn1(conv_output)
    elif self.norm == 'layer_norm':
        normed = torch.nn.functional.layer_norm(conv_output, conv_output.shape[1:])
    else:
        normed = conv_output
    ## finally, apply the activation
    out = activation(normed)
    #### repeat the same full WCM + FCM processing on this output for the second conv layer
    # second conv layer
    [dim0, dim1] = self.conv2.weight.shape[2:]
    # epsilon_weight = torch.randn(self.masked_conv2.shape).to(self.masked_conv2.device) * self.sigma
    calibrated_conv2 = self.calib_w_conv2  # + epsilon_weight * self.masked_conv2_sigma
    this_ss_weights = torch.tile(calibrated_conv2, (1, 1, dim0, dim1))
    cw = self.conv2.weight * this_ss_weights
    # <resnet_conv_block_scale>
    conv_output = torch.nn.functional.conv2d(out, cw,
                                             stride=1, padding=1, bias=self.calib_b_conv2.squeeze())
    [dim1, dim2] = conv_output.shape[2:]
    this_scale_weights = torch.tile(self.calib_f_conv2, (conv_output.shape[0], 1, dim1, dim2))
    conv_output = conv_output * this_scale_weights
    # normalize
    if self.norm == 'batch_norm':
        normed = self.bn2(conv_output)
    elif self.norm == 'layer_norm':
        normed = torch.nn.functional.layer_norm(conv_output, conv_output.shape[1:])
    else:
        normed = conv_output
    out = activation(normed)
    # residual
    ## shortcut branch on the original input
    residual = self.shortcut(x)
    ## add the residual directly
    return out + residual
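A quick usage sketch for this block (the conv3x3 helper is not shown in the excerpt above, so a standard ResNet-style definition is assumed here):

import torch
import torch.nn as nn

# assumed definition of the conv3x3 helper used by CalibratedBlock (standard ResNet style)
def conv3x3(in_planes, planes, stride=1):
    return nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)

block = CalibratedBlock(in_planes=16, planes=32, stride=2)
out = block(torch.randn(4, 16, 32, 32))      # -> shape (4, 32, 16, 16)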
Then let's look at the corresponding inner update and outer update
First, the inner update. Suppose the current task is task t. loss1 is the cross-entropy loss on task t's data. Then part of the stored data from tasks [1, t) is sampled (the author samples randomly: a few examples from the first task, a few from the second, and so on). On this replayed data we compute loss2 (cross-entropy) and loss3 (the KL divergence). loss2 and loss3 represent learning from old tasks and are added to loss1. The process is as follows:
for step in range(self.inner_steps):
    self.zero_grad()
    self.opt.zero_grad()
    offset1, offset2 = self.compute_offsets(t)
    copy_net = copy.deepcopy(self.net)
    # select data from the current task and compute the loss
    if step == 0:
        pred = self.forward(x, t)
        pred = pred[:, offset1:offset2]
        yy = y - offset1
    elif self.count >= step * self.batch_size:
        xx, yy, _, mask, list_t = self.memory_sampling(t, self.batch_size, intra_class=True)
        pred = self.net(xx)
        pred = torch.gather(pred, 1, mask)
    else:
        pred = self.forward(x, t)
        pred = pred[:, offset1:offset2]
        yy = y - offset1
        # return 0.0
    loss1 = self.bce(pred, yy)
    ## sample data from old tasks and compute their losses
    if t > 0:
        xx, yy, feat, mask, list_t = self.memory_sampling(t, self.replay_batch_size)
        pred_ = self.net(xx)
        pred = torch.gather(pred_, 1, mask)
        ## cross-entropy loss on the old tasks
        loss2 = self.bce(pred, yy)
        ## KL divergence; feat holds the stored softmax outputs of the replayed data
        loss3 = self.reg * self.kl(F.log_softmax(pred / self.temp, dim=1), feat)
        loss = loss1 + (loss2 + loss3) * self.gamma
    else:
        loss = loss1
    ## gradient update
    grads = torch.autograd.grad(loss, self.net.base_param(), create_graph=True, allow_unused=True, retain_graph=True)
    # update only theta (the base parameters)
    num_none, num_grad = 0, 0
    for param, grad in zip(self.net.base_param(), grads):
        if grad is not None:
            new_param = param.data.clone()
            if self.inner_clip > 0:
                grad.data.clamp_(-self.inner_clip, self.inner_clip)
            new_param = new_param - self.inner_lr * grad
            param.data.copy_(new_param)
            num_grad += 1
        else:
            num_none += 1
After the inner update we handle the outer loss. The outer step likewise samples some old-task data and computes the KL distillation loss on it. The gradient of this loss is then used to estimate the Fisher information as in EWC. Finally, the Fisher information is used to compute the regularization loss (term (a)), and that loss is used to update $\psi$ and $\lambda$.
if t > 0:
    self.net.zero_grad()
    self.opt.zero_grad()
    xval, yval, feat, mask, list_t = self.memory_sampling(t, self.batch_size)
    pred_ = self.net(xval)
    pred_ = torch.gather(pred_, 1, mask)
    # 1st loss: distillation loss on the replayed data
    outer_loss = self.reg * self.kl(F.log_softmax(pred_ / self.temp, dim=1), feat)
    outer_grad = torch.autograd.grad(outer_loss, self.net.context_param() + self.net.base_weight_params(),
                                     retain_graph=True, allow_unused=True,)
    # 2nd loss: EWC-style regularization built from the gradients above
    old_masked_params, _, _ = copy_net.base_and_calibrated_params()
    cur_masked_params, cur_tiled_mask_params, cur_base_params = self.net.base_and_calibrated_params()
    reg = self.beta * self.reg  # * self.reg
    ewc_loss = 0.0
    num_meta_params = len(self.net.context_param())
    for ii, p in enumerate(cur_masked_params):
        ## compute the Fisher information: (p.grad / tile(psi))^2
        pg = (outer_grad[num_meta_params + ii].data / (cur_tiled_mask_params[ii].data + 1e-12)).pow(2)
        cur_loss = reg * pg.detach() * (p - old_masked_params[ii].data.clone()).pow(2)
        ewc_loss += cur_loss.sum()
    ewc_loss.backward()
    self.opt.step()