Hands on deep learning (39) -- Gated Recurrent Unit (GRU)
2022-07-04 09:41:00 【Stay a little star】
Note: this article is mainly a write-up of notes and code reproduction for Mu Li's hands-on deep learning course on Bilibili. If you want to watch the videos, you can click through to the course. Many thanks to Mu Li for sharing!
Gated Recurrent Unit (GRU)
Consider a sequence: some early observations may be highly significant for predicting all future observations; some observations are of no use for any future prediction; and some sequences contain logical breaks between their parts. In summary:
- Not every observation is equally important.
- Remembering only the relevant observations requires:
  - a mechanism that can pay attention (the update gate)
  - a mechanism that can forget (the reset gate)
Many methods have been proposed in academia to address this problem. One of the earliest is long short-term memory (LSTM) (Hochreiter & Schmidhuber, 1997). The gated recurrent unit (GRU) (Cho et al., 2014) is a slightly simplified variant that usually offers comparable performance and is significantly faster to compute (Chung et al., 2014). Because it is simpler, let us start with the GRU.
For reference, here are the original papers:
【1】LSTM: Long Short-Term Memory
【2】GRU: Learning Phrase Representations using RNN Encoder–Decoder for Statistical Machine Translation
1. Gated Hidden State
The key difference between a vanilla recurrent neural network and a GRU is that the latter supports gating of the hidden state. This means that there are dedicated, learnable mechanisms for deciding when the hidden state should be updated and when it should be reset, which addresses the issues listed above.
For example, if the first token is of great importance, we can learn not to update the hidden state after the first observation. Likewise, we can learn to skip irrelevant, temporary observations. Finally, we can learn to reset the hidden state whenever needed.
1.1 Reset gate and update gate
The first constructs we introduce are the reset gate and the update gate. We design them to be vectors with entries in the interval $(0, 1)$, so that we can form convex combinations. For example, the reset gate controls how much of the past state we might still want to remember, while the update gate controls how much of the new state is just a copy of the old state.
Let us start by constructing these gates. The figure below describes the inputs of the reset and update gates in a GRU: the input of the current time step and the hidden state of the previous time step. The outputs of the two gates are given by two fully connected layers with a sigmoid activation function.

Mathematically, for a given time step $t$, suppose that the input is a minibatch $\mathbf{X}_t \in \mathbb{R}^{n \times d}$ (number of examples: $n$, number of inputs: $d$) and the hidden state of the previous time step is $\mathbf{H}_{t-1} \in \mathbb{R}^{n \times h}$ (number of hidden units: $h$). Then the reset gate $\mathbf{R}_t \in \mathbb{R}^{n \times h}$ and the update gate $\mathbf{Z}_t \in \mathbb{R}^{n \times h}$ are computed as follows:

$$\begin{aligned} \mathbf{R}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xr} + \mathbf{H}_{t-1} \mathbf{W}_{hr} + \mathbf{b}_r),\\ \mathbf{Z}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xz} + \mathbf{H}_{t-1} \mathbf{W}_{hz} + \mathbf{b}_z), \end{aligned}$$

where $\mathbf{W}_{xr}, \mathbf{W}_{xz} \in \mathbb{R}^{d \times h}$ and $\mathbf{W}_{hr}, \mathbf{W}_{hz} \in \mathbb{R}^{h \times h}$ are weight parameters, and $\mathbf{b}_r, \mathbf{b}_z \in \mathbb{R}^{1 \times h}$ are bias parameters. Note that the broadcasting mechanism is triggered during the summation. We use sigmoid functions to transform the input values to the interval $(0, 1)$.
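To make the shapes concrete, here is a minimal sketch (toy sizes and randomly initialized parameters chosen only for illustration; not part of the original notes) that evaluates the two gate equations directly:

```python
import torch

n, d, h = 2, 5, 4                       # batch size, number of inputs, number of hidden units (toy values)
X_t = torch.randn(n, d)                 # current input X_t
H_prev = torch.randn(n, h)              # previous hidden state H_{t-1}

W_xr, W_hr, b_r = torch.randn(d, h), torch.randn(h, h), torch.zeros(1, h)
W_xz, W_hz, b_z = torch.randn(d, h), torch.randn(h, h), torch.zeros(1, h)

R_t = torch.sigmoid(X_t @ W_xr + H_prev @ W_hr + b_r)   # reset gate, entries in (0, 1)
Z_t = torch.sigmoid(X_t @ W_xz + H_prev @ W_hz + b_z)   # update gate, entries in (0, 1)
print(R_t.shape, Z_t.shape)             # both torch.Size([2, 4]); the (1, h) biases are broadcast
```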
1.2 Candidate hidden state
Next, we integrate the reset gate $\mathbf{R}_t$ with the regular hidden-state update mechanism of an RNN, obtaining the candidate hidden state $\tilde{\mathbf{H}}_t \in \mathbb{R}^{n \times h}$ at time step $t$:

$$\tilde{\mathbf{H}}_t = \tanh(\mathbf{X}_t \mathbf{W}_{xh} + \left(\mathbf{R}_t \odot \mathbf{H}_{t-1}\right) \mathbf{W}_{hh} + \mathbf{b}_h),$$

where $\mathbf{W}_{xh} \in \mathbb{R}^{d \times h}$ and $\mathbf{W}_{hh} \in \mathbb{R}^{h \times h}$ are weight parameters, $\mathbf{b}_h \in \mathbb{R}^{1 \times h}$ is a bias term, and the symbol $\odot$ denotes the Hadamard (elementwise) product. Here we use the $\tanh$ nonlinearity to keep the values of the candidate hidden state in the interval $(-1, 1)$.
The result is only a candidate, because we still need to incorporate the action of the update gate. Compared with a basic RNN, the elementwise product of $\mathbf{R}_t$ and $\mathbf{H}_{t-1}$ in the candidate hidden state reduces the influence of previous states. Whenever the entries of the reset gate $\mathbf{R}_t$ are close to $1$, we recover an ordinary recurrent neural network. For entries of the reset gate close to $0$, the candidate hidden state becomes the output of a multilayer perceptron with $\mathbf{X}_t$ as input; any pre-existing hidden state is thus reset to defaults. The following figure shows the computation flow after applying the reset gate.

1.3 Hidden state
Finally, we need to incorporate the effect of the update gate $\mathbf{Z}_t$. It determines to what extent the new hidden state $\mathbf{H}_t \in \mathbb{R}^{n \times h}$ is just the old state $\mathbf{H}_{t-1}$, and to what extent the new candidate state $\tilde{\mathbf{H}}_t$ is used. The update gate $\mathbf{Z}_t$ achieves this simply by taking elementwise convex combinations of $\mathbf{H}_{t-1}$ and $\tilde{\mathbf{H}}_t$, which leads to the final update equation of the GRU:

$$\mathbf{H}_t = \mathbf{Z}_t \odot \mathbf{H}_{t-1} + (1 - \mathbf{Z}_t) \odot \tilde{\mathbf{H}}_t.$$

Whenever the update gate $\mathbf{Z}_t$ is close to $1$, we simply retain the old state; the information from $\mathbf{X}_t$ is essentially ignored, effectively skipping time step $t$ in the dependency chain. Conversely, when $\mathbf{Z}_t$ is close to $0$, the new hidden state $\mathbf{H}_t$ approaches the candidate hidden state $\tilde{\mathbf{H}}_t$. ==These designs help us cope with the vanishing gradient problem in recurrent neural networks and better capture dependencies over long time-step distances.== For example, if the update gate stays close to $1$ for all time steps of an entire subsequence, the old hidden state at the beginning of the subsequence is easily retained and passed to its end, regardless of the subsequence length. The following figure illustrates the computation flow after the update gate takes effect.

In summary, GRUs have the following two distinguishing characteristics (a small numeric illustration of the second point follows this list):
- Reset gates help capture short-term dependencies in sequences.
- Update gates help capture long-term dependencies in sequences.
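As promised above, here is a tiny numeric illustration of the update-gate behavior (a standalone sketch with made-up values): when $\mathbf{Z}_t$ is all ones the old state passes through unchanged, and when it is all zeros the state is replaced by the candidate.

```python
import torch

n, h = 2, 4
H_prev = torch.randn(n, h)              # old hidden state H_{t-1}
H_tilde = torch.randn(n, h)             # some candidate hidden state (made up for this check)

Z_ones = torch.ones(n, h)               # update gate ~ 1: keep the old state
Z_zeros = torch.zeros(n, h)             # update gate ~ 0: adopt the candidate

H_keep = Z_ones * H_prev + (1 - Z_ones) * H_tilde
H_new = Z_zeros * H_prev + (1 - Z_zeros) * H_tilde

print(torch.allclose(H_keep, H_prev))   # True: the old state is carried forward unchanged
print(torch.allclose(H_new, H_tilde))   # True: the state is replaced by the candidate
```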
2. Implementing a GRU from Scratch
import torch
from torch import nn
from d2l import torch as d2l
batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
2.1 Initialize model parameters
We draw the weights from a Gaussian distribution with standard deviation 0.01 and set the biases to 0. The hyperparameter num_hiddens defines the number of hidden units. Below we instantiate all the weights and biases related to the update gate, the reset gate, the candidate hidden state, and the output layer.
def get_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.randn(size=shape, device=device) * 0.01

    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))

    W_xz, W_hz, b_z = three()  # Update gate parameters
    W_xr, W_hr, b_r = three()  # Reset gate parameters
    W_xh, W_hh, b_h = three()  # Candidate hidden state parameters
    # Output layer parameters
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    # Attach gradients
    params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params
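As a quick sanity check (a small sketch, not part of the original notes; the exact total depends on the vocabulary size of the loaded dataset), we can count the tensors and parameters this function creates:

```python
# Count the parameter tensors returned by get_params (CPU used here just for the check)
params = get_params(len(vocab), 256, torch.device('cpu'))
total = sum(p.numel() for p in params)
print(f'{len(params)} parameter tensors, {total} parameters in total')
```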
2.2 Defining the model
# Define a hidden-state initialization function: it returns a tuple containing one all-zero tensor of shape (batch_size, num_hiddens)
def init_gru_state(batch_size,num_hiddens,device):
return (torch.zeros((batch_size,num_hiddens),device=device),)
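A one-line check of the state layout (a sketch with toy sizes): the function returns a tuple containing a single all-zero tensor.

```python
# Verify the initial state: a 1-tuple holding an all-zero (batch_size, num_hiddens) tensor
state = init_gru_state(4, 8, torch.device('cpu'))
print(len(state), state[0].shape, state[0].abs().sum().item())  # 1 torch.Size([4, 8]) 0.0
```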
$$\begin{aligned} \mathbf{R}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xr} + \mathbf{H}_{t-1} \mathbf{W}_{hr} + \mathbf{b}_r),\\ \mathbf{Z}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xz} + \mathbf{H}_{t-1} \mathbf{W}_{hz} + \mathbf{b}_z),\\ \tilde{\mathbf{H}}_t &= \tanh(\mathbf{X}_t \mathbf{W}_{xh} + \left(\mathbf{R}_t \odot \mathbf{H}_{t-1}\right) \mathbf{W}_{hh} + \mathbf{b}_h),\\ \mathbf{H}_t &= \mathbf{Z}_t \odot \mathbf{H}_{t-1} + (1 - \mathbf{Z}_t) \odot \tilde{\mathbf{H}}_t. \end{aligned}$$

# Define the GRU model
def gru(inputs, state, params):
    W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
    H, = state
    outputs = []
    for X in inputs:
        Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)            # update gate
        R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)            # reset gate
        H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)   # candidate hidden state
        H = Z * H + (1 - Z) * H_tilda                               # final hidden state
        Y = H @ W_hq + b_q                                          # output layer
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H,)
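Before training, it can help to run the forward pass once on toy one-hot inputs to confirm the output shapes (a sketch; during training d2l.RNNModelScratch performs the one-hot encoding for us, so here we build the inputs by hand):

```python
# Shape check: inputs of shape (num_steps, batch_size, vocab_size)
num_steps_demo, batch_demo, hiddens_demo = 5, 2, 256
demo_device = torch.device('cpu')
tokens = torch.randint(0, len(vocab), (num_steps_demo, batch_demo))
X_demo = torch.nn.functional.one_hot(tokens, len(vocab)).float()
state = init_gru_state(batch_demo, hiddens_demo, demo_device)
params = get_params(len(vocab), hiddens_demo, demo_device)
Y, (H,) = gru(X_demo, state, params)
print(Y.shape, H.shape)   # (num_steps*batch_size, vocab_size) and (batch_size, num_hiddens)
```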
2.3 Training and prediction
Training and prediction work in exactly the same way as in the RNN implementation from scratch. After training, we print the perplexity on the training set and the predicted sequences following the prefixes "time traveller" and "traveller".
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,
init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 1.0, 57290.3 tokens/sec on cuda:0
time travelleryou can show black is white by argument said filby
traveller with a slight accession ofcheerfulness really thi
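For further text generation after training, we can call the book's prediction helper directly (a sketch; it assumes the installed d2l version exposes predict_ch8 with the signature used in the book's RNN chapter: prefix, number of predicted tokens, model, vocab, device):

```python
# Generate 50 more tokens after a prefix (assumes d2l.predict_ch8 is available as in the book)
print(d2l.predict_ch8('time traveller ', 50, model, vocab, device))
```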

2.4 Concise implementation
The high-level API contains all the configuration details described above, so we can instantiate a GRU directly. It uses compiled operators for the computation rather than spelling out the many details in Python.
batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_inputs = vocab_size
gru_layer = nn.GRU(num_inputs,num_hiddens)
model = d2l.RNNModel(gru_layer,len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 1.0, 353447.0 tokens/sec on cuda:0
time traveller for so it will be convenient to speak of himwas e
travelleryou can show black is white by argument said filby
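For reference, here is a minimal standalone sketch of the tensor shapes nn.GRU works with (standard PyTorch conventions; the sizes below are just examples): the input has shape (num_steps, batch_size, input_size), the output has shape (num_steps, batch_size, num_hiddens), and the returned hidden state has shape (num_layers, batch_size, num_hiddens).

```python
import torch
from torch import nn

gru_demo = nn.GRU(input_size=28, hidden_size=256)   # a single layer by default
X = torch.randn(35, 32, 28)                          # (num_steps, batch_size, input_size)
output, h_n = gru_demo(X)                            # initial hidden state defaults to zeros
print(output.shape)                                  # torch.Size([35, 32, 256])
print(h_n.shape)                                     # torch.Size([1, 32, 256])
```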

3. Summary
- Gated recurrent neural networks can better capture dependencies in sequences with large time-step distances.
- Resetting the gate helps capture short-term dependencies in the sequence .
- Update gates help capture long-term dependencies in sequences .
- When the reset gate is open (close to 1), the GRU contains a basic recurrent neural network; when the update gate is open (close to 1), the GRU can skip subsequences.
4. Exercises
- Suppose that we only want to use the input at time step $t'$ to predict the output at time step $t > t'$. What are the best values for the reset and update gates at each time step?
When both the update gate and the reset gate are 0, the previous hidden state is not used: at time step $t'$ this makes the hidden state depend only on $\mathbf{X}_{t'}$; for the steps after $t'$, an update gate close to 1 then carries that state forward unchanged until time step $t$, ignoring the inputs in between.
- Adjust the hyperparameters and analyze their influence on running time, perplexity, and the output sequences.
- Compare the rnn.RNN and rnn.GRU implementations with each other in terms of running time, perplexity, and the output strings.
- What happens if you implement only parts of a GRU, for example, with only a reset gate or only an update gate?