Hands-on deep learning (39) -- gated recurrent unit (GRU)
2022-07-04 09:41:00 【Stay a little star】
Once again: this article is mainly a set of notes and code reproductions based on Mu Li's hands-on deep learning course on Bilibili. If you want to watch the videos, you can click through to the course. Many thanks to Mu Li for sharing!
Gated recurrent units (GRU)
Consider a sequence: some early observations are very useful for predicting all future observations, while other observations are useless for all future predictions; in other words, some sequences have logical breaks between their parts. In short:
- Not every observation is equally important.
- Remembering only the relevant observations requires:
  - a mechanism that can pay attention (the update gate)
  - a mechanism that can forget (the reset gate)
Many methods have been proposed in the literature to address this problem. One of the earliest is long short-term memory (LSTM) (Hochreiter.Schmidhuber.1997). The gated recurrent unit (GRU) (Cho.Van-Merrienboer.Bahdanau.ea.2014) is a slightly simplified variant that usually offers comparable performance and is significantly faster to compute (Chung.Gulcehre.Cho.ea.2014). Because it is simpler, we start with the gated recurrent unit.
Since we are here, the original papers are worth posting:
【1】LSTM: Long Short-Term Memory
【2】GRU: Learning Phrase Representations using RNN Encoder–Decoder for Statistical Machine Translation
1. Gated hidden state
The key distinction between a vanilla recurrent neural network and a GRU is that the latter supports gating of the hidden state. This means there are dedicated mechanisms that determine when the hidden state should be updated and when it should be reset. These mechanisms are learnable and address the problems listed above.
For example, if the first token is of great importance, the model can learn not to update the hidden state after the first observation. Likewise, it can learn to skip irrelevant, temporary observations, and to reset the hidden state whenever needed.
1.1 Reset gate and update gate
The first things we introduce are the reset gate and the update gate. We design them as vectors with entries in the interval $(0, 1)$, so that we can form convex combinations. For example, the reset gate lets us control how much of the previous state we may still want to remember, while the update gate lets us control how much of the new state is simply a copy of the old state.
Let us start by constructing these gates. The figure below shows the inputs of the reset and update gates in a GRU: the input of the current time step and the hidden state of the previous time step. The outputs of the two gates are produced by two fully connected layers with a sigmoid activation function.

Mathematically, for a given time step $t$, suppose the input is a minibatch $\mathbf{X}_t \in \mathbb{R}^{n \times d}$ (number of examples: $n$, number of inputs: $d$) and the hidden state of the previous time step is $\mathbf{H}_{t-1} \in \mathbb{R}^{n \times h}$ (number of hidden units: $h$). Then the reset gate $\mathbf{R}_t \in \mathbb{R}^{n \times h}$ and the update gate $\mathbf{Z}_t \in \mathbb{R}^{n \times h}$ are computed as follows:
$$\begin{aligned} \mathbf{R}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xr} + \mathbf{H}_{t-1} \mathbf{W}_{hr} + \mathbf{b}_r),\\ \mathbf{Z}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xz} + \mathbf{H}_{t-1} \mathbf{W}_{hz} + \mathbf{b}_z), \end{aligned}$$
where $\mathbf{W}_{xr}, \mathbf{W}_{xz} \in \mathbb{R}^{d \times h}$ and $\mathbf{W}_{hr}, \mathbf{W}_{hz} \in \mathbb{R}^{h \times h}$ are weight parameters and $\mathbf{b}_r, \mathbf{b}_z \in \mathbb{R}^{1 \times h}$ are bias parameters. Note that broadcasting (see :numref:subsec_broadcasting) is triggered during the summation. We use sigmoid functions (as introduced in :numref:sec_mlp) to map the input values into the interval $(0, 1)$.
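To make the shapes concrete, here is a minimal sketch (the toy dimensions are our own, not taken from the text) that computes the two gates directly in PyTorch:

import torch

n, d, h = 2, 4, 3           # toy sizes: batch, inputs, hidden units
X_t = torch.randn(n, d)     # current minibatch input
H_prev = torch.randn(n, h)  # hidden state of the previous time step
W_xr, W_hr, b_r = torch.randn(d, h), torch.randn(h, h), torch.zeros(h)
W_xz, W_hz, b_z = torch.randn(d, h), torch.randn(h, h), torch.zeros(h)

R_t = torch.sigmoid(X_t @ W_xr + H_prev @ W_hr + b_r)  # reset gate, shape (n, h)
Z_t = torch.sigmoid(X_t @ W_xz + H_prev @ W_hz + b_z)  # update gate, shape (n, h)
print(R_t.shape, Z_t.shape)                      # torch.Size([2, 3]) twice
print(bool(R_t.min() > 0), bool(R_t.max() < 1))  # True True: entries stay in (0, 1)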
1.2 Candidate hidden state
Next, we integrate the reset gate $\mathbf{R}_t$ with the regular hidden-state update mechanism of an RNN, which gives the candidate hidden state $\tilde{\mathbf{H}}_t \in \mathbb{R}^{n \times h}$ at time step $t$:
$$\tilde{\mathbf{H}}_t = \tanh(\mathbf{X}_t \mathbf{W}_{xh} + \left(\mathbf{R}_t \odot \mathbf{H}_{t-1}\right) \mathbf{W}_{hh} + \mathbf{b}_h),$$
where $\mathbf{W}_{xh} \in \mathbb{R}^{d \times h}$ and $\mathbf{W}_{hh} \in \mathbb{R}^{h \times h}$ are weight parameters, $\mathbf{b}_h \in \mathbb{R}^{1 \times h}$ is a bias term, and $\odot$ denotes the Hadamard (elementwise) product. Here we use the tanh nonlinearity to keep the values of the candidate hidden state in the interval $(-1, 1)$.
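A small sketch of this computation (again with hypothetical toy tensors); in particular, when the reset gate is all zeros, the previous hidden state drops out and only the current input matters:

import torch

n, d, h = 2, 4, 3
X_t, H_prev = torch.randn(n, d), torch.randn(n, h)
W_xh, W_hh, b_h = torch.randn(d, h), torch.randn(h, h), torch.zeros(h)

R_ones = torch.ones(n, h)    # reset gate fully open: the usual RNN update
H_tilde_rnn = torch.tanh(X_t @ W_xh + (R_ones * H_prev) @ W_hh + b_h)

R_zeros = torch.zeros(n, h)  # reset gate fully closed: history is wiped out
H_tilde_mlp = torch.tanh(X_t @ W_xh + (R_zeros * H_prev) @ W_hh + b_h)
print(torch.allclose(H_tilde_mlp, torch.tanh(X_t @ W_xh + b_h)))  # True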
The result is only a candidate, since we still need to incorporate the action of the update gate. Compared with a vanilla RNN, the elementwise product of $\mathbf{R}_t$ and $\mathbf{H}_{t-1}$ in the candidate hidden state can reduce the influence of previous states. Whenever the entries of the reset gate $\mathbf{R}_t$ are close to $1$, we recover a vanilla RNN. For entries of $\mathbf{R}_t$ that are close to $0$, the candidate hidden state is the output of a multilayer perceptron with $\mathbf{X}_t$ as input, so any pre-existing hidden state is reset to its default value. The figure below shows the computation after the reset gate is applied.

1.3 Hidden state
Finally, we need to incorporate the effect of the update gate $\mathbf{Z}_t$. It determines to what extent the new hidden state $\mathbf{H}_t \in \mathbb{R}^{n \times h}$ is simply the old state $\mathbf{H}_{t-1}$, and to what extent the new candidate state $\tilde{\mathbf{H}}_t$ is used. The update gate $\mathbf{Z}_t$ achieves this by taking elementwise convex combinations of $\mathbf{H}_{t-1}$ and $\tilde{\mathbf{H}}_t$, which yields the final update equation of the GRU:
$$\mathbf{H}_t = \mathbf{Z}_t \odot \mathbf{H}_{t-1} + (1 - \mathbf{Z}_t) \odot \tilde{\mathbf{H}}_t.$$
Whenever the update gate $\mathbf{Z}_t$ is close to $1$, we simply retain the old state; in that case the information from $\mathbf{X}_t$ is essentially ignored, effectively skipping time step $t$ in the dependency chain. Conversely, when $\mathbf{Z}_t$ is close to $0$, the new hidden state $\mathbf{H}_t$ approaches the candidate hidden state $\tilde{\mathbf{H}}_t$. **These designs help us cope with the vanishing-gradient problem in recurrent neural networks and better capture dependencies between time steps that are far apart.** For example, if the update gate stays close to $1$ for every time step of an entire subsequence, the old hidden state from the beginning of the subsequence is easily retained and passed to its end, regardless of the subsequence's length. The figure below illustrates the computation after the update gate takes effect (a short numerical sketch of this update follows the summary points below).

In summary, the gated recurrent unit has the following two distinguishing properties:
- Reset gates help capture short-term dependencies in sequences.
- Update gates help capture long-term dependencies in sequences.
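A tiny numerical sketch (values made up for illustration) of the update equation above, showing the two extremes of the update gate:

import torch

H_prev  = torch.tensor([[1.0, -1.0]])  # old hidden state
H_tilde = torch.tensor([[0.2,  0.5]])  # candidate hidden state

for z in (0.99, 0.01):                 # update gate nearly 1 vs. nearly 0
    Z = torch.full_like(H_prev, z)
    H = Z * H_prev + (1 - Z) * H_tilde
    print(z, H)
# 0.99 tensor([[ 0.9920, -0.9850]])  -> essentially keeps the old state
# 0.01 tensor([[ 0.2080,  0.4850]])  -> essentially adopts the candidate state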
2. Implementing a GRU from scratch
import torch
from torch import nn
from d2l import torch as d2l
batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
2.1 Initialize model parameters
We draw the weights from a Gaussian distribution with standard deviation 0.01 and set the biases to 0. The hyperparameter num_hiddens defines the number of hidden units. We instantiate all the weights and biases related to the update gate, the reset gate, the candidate hidden state, and the output layer.
def get_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.randn(size=shape, device=device) * 0.01

    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))

    W_xz, W_hz, b_z = three()  # Update gate parameters
    W_xr, W_hr, b_r = three()  # Reset gate parameters
    W_xh, W_hh, b_h = three()  # Candidate hidden state parameters
    # Output layer parameters
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    # Attach gradients
    params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params
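As a quick sanity check (a small sketch we add here, not part of the original notebook), we can confirm the number of parameter tensors, their shapes, and that gradients are enabled:

params = get_params(len(vocab), num_hiddens=256, device=d2l.try_gpu())
print(len(params))                           # 11 parameter tensors in total
print(params[0].shape, params[-2].shape)     # W_xz: (vocab_size, 256), W_hq: (256, vocab_size)
print(all(p.requires_grad for p in params))  # True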
2.2 Defining the model
# Define a function that initializes the hidden state: it returns a tuple holding
# one tensor of shape (batch size, number of hidden units), filled with zeros
def init_gru_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),)
$$\begin{aligned} \mathbf{R}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xr} + \mathbf{H}_{t-1} \mathbf{W}_{hr} + \mathbf{b}_r),\\ \mathbf{Z}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xz} + \mathbf{H}_{t-1} \mathbf{W}_{hz} + \mathbf{b}_z),\\ \tilde{\mathbf{H}}_t &= \tanh(\mathbf{X}_t \mathbf{W}_{xh} + \left(\mathbf{R}_t \odot \mathbf{H}_{t-1}\right) \mathbf{W}_{hh} + \mathbf{b}_h),\\ \mathbf{H}_t &= \mathbf{Z}_t \odot \mathbf{H}_{t-1} + (1 - \mathbf{Z}_t) \odot \tilde{\mathbf{H}}_t. \end{aligned}$$

# Define the GRU model
def gru(inputs, state, params):
    W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
    H, = state
    outputs = []
    for X in inputs:
        Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)           # update gate
        R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)           # reset gate
        H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)  # candidate hidden state
        H = Z * H + (1 - Z) * H_tilda                              # new hidden state
        Y = H @ W_hq + b_q                                         # output for this time step
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H,)
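Before training, a minimal forward-pass check (a sketch we add here, not in the original notes) can verify the shapes. d2l.RNNModelScratch will later one-hot encode the token indices into shape (num_steps, batch_size, vocab_size), so we feed a random tensor of that shape:

device = d2l.try_gpu()
num_hiddens = 256
params = get_params(len(vocab), num_hiddens, device)
state = init_gru_state(batch_size, num_hiddens, device)
# Random token indices, one-hot encoded to (num_steps, batch_size, vocab_size)
X = torch.nn.functional.one_hot(
    torch.randint(0, len(vocab), (num_steps, batch_size)), len(vocab)
).float().to(device)
Y, new_state = gru(X, state, params)
print(Y.shape)             # (num_steps * batch_size, vocab_size)
print(new_state[0].shape)  # (batch_size, num_hiddens)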
2.3 Training and prediction
Training and prediction work exactly as in the earlier RNN implementation. After training, we print the perplexity on the training set together with the predicted sequences for the prefixes “time traveller” and “traveller”.
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,
                            init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 1.0, 57290.3 tokens/sec on cuda:0
time travelleryou can show black is white by argument said filby
traveller with a slight accession ofcheerfulness really thi

2.4 Concise implementation
The high-level API includes all the configuration details described above, so we can instantiate the GRU model directly. It uses compiled operators rather than Python to handle many of the details.
batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_inputs = vocab_size
gru_layer = nn.GRU(num_inputs, num_hiddens)
model = d2l.RNNModel(gru_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 1.0, 353447.0 tokens/sec on cuda:0
time traveller for so it will be convenient to speak of himwas e
travelleryou can show black is white by argument said filby
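For reference, here is a short sketch (with hypothetical sizes, not from the original notes) of how nn.GRU itself consumes and produces tensors; roughly speaking, d2l.RNNModel wraps this layer together with an output projection:

import torch
from torch import nn

gru_layer = nn.GRU(input_size=28, hidden_size=256)  # 28 is just an assumed vocabulary size
X = torch.randn(35, 32, 28)      # (num_steps, batch_size, input_size)
state = torch.zeros(1, 32, 256)  # (num_layers, batch_size, hidden_size)
Y, new_state = gru_layer(X, state)
print(Y.shape)          # torch.Size([35, 32, 256]): hidden state at every time step
print(new_state.shape)  # torch.Size([1, 32, 256]): hidden state of the final step only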

3. Summary
- Gated recurrent neural networks can better capture dependencies in sequences with large time-step distances.
- Reset gates help capture short-term dependencies in sequences.
- Update gates help capture long-term dependencies in sequences.
- When the reset gate is open, the gated recurrent unit contains a vanilla recurrent neural network; when the update gate is open, the gated recurrent unit can skip subsequences.
4. Exercises
- Suppose we only want to use the input at time step $t'$ to predict the output at time step $t > t'$. What are the best values for the reset gate and the update gate at each time step?
Setting both the update gate and the reset gate to 0 means that the earlier hidden-state information is not used.
- Adjust the hyperparameters and analyze their influence on running time, perplexity, and the output sequence.
- Compare the different implementations of rnn.RNN and rnn.GRU with respect to running time, perplexity, and the output strings.
- What happens if you implement only parts of the gated recurrent unit, e.g., only a reset gate or only an update gate?