Hands-on Deep Learning (39) -- Gated Recurrent Units (GRU)
2022-07-04 09:41:00 【Stay a little star】
A reminder up front: this article is organized from notes on Mu Li's hands-on deep learning course on Bilibili, together with code reproductions. If you want to watch the videos, click here. Thanks to Mu Li for sharing!
Gated recurrent units (GRU)
Let us think about a sequence. Some early observations are very useful for all future observations, while other observations are useless for all future predictions; in other words, some sequences contain logical breaks between their parts. In summary:
- Not every observation is equally important.
- To remember only the relevant observations, we need:
  - a mechanism that can pay attention (the update gate)
  - a mechanism that can forget (the reset gate)
Academia has proposed many methods to address this problem. One of the earliest is long short-term memory (LSTM) (Hochreiter & Schmidhuber, 1997). The gated recurrent unit (gated recurrent unit, GRU) (Cho et al., 2014) is a slightly simplified variant that usually offers comparable performance and is significantly faster to compute (Chung et al., 2014). Because it is simpler, let us start with the gated recurrent unit.
For those interested, here are the two original papers:
【1】LSTM: Long Short-Term Memory
【2】GRU: Learning Phrase Representations using RNN Encoder–Decoder for Statistical Machine Translation
1. Gated hidden state
The key difference between an ordinary recurrent neural network and a gated recurrent unit is that the latter supports gating of the hidden state. This means there are dedicated mechanisms that determine when the hidden state should be updated and when it should be reset. These mechanisms are learnable and address the problems listed above.
For example, if the first token is of great importance, we will learn not to update the hidden state after the first observation. Likewise, we can learn to skip irrelevant temporary observations. Finally, we will learn to reset the hidden state whenever needed.
1.1 Reset and update gates
The first things we introduce are the reset gate and the update gate. We design them to be vectors with entries in the interval $(0, 1)$ so that we can form convex combinations. For example, a reset gate allows us to control how much of the previous state we might still want to remember. Likewise, an update gate allows us to control how much of the new state is just a copy of the old state.
Let us start by constructing these gates. The figure below illustrates the inputs of the reset and update gates in a GRU, namely the input of the current time step and the hidden state of the previous time step. The outputs of the two gates are given by two fully connected layers with sigmoid activation functions.
In mathematical terms: for a given time step $t$, suppose that the input is a minibatch $\mathbf{X}_t \in \mathbb{R}^{n \times d}$ (number of examples: $n$, number of inputs: $d$) and that the hidden state of the previous time step is $\mathbf{H}_{t-1} \in \mathbb{R}^{n \times h}$ (number of hidden units: $h$). Then the reset gate $\mathbf{R}_t \in \mathbb{R}^{n \times h}$ and the update gate $\mathbf{Z}_t \in \mathbb{R}^{n \times h}$ are computed as follows:
$$
\begin{aligned}
\mathbf{R}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xr} + \mathbf{H}_{t-1} \mathbf{W}_{hr} + \mathbf{b}_r),\\
\mathbf{Z}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xz} + \mathbf{H}_{t-1} \mathbf{W}_{hz} + \mathbf{b}_z),
\end{aligned}
$$
where $\mathbf{W}_{xr}, \mathbf{W}_{xz} \in \mathbb{R}^{d \times h}$ and $\mathbf{W}_{hr}, \mathbf{W}_{hz} \in \mathbb{R}^{h \times h}$ are weight parameters and $\mathbf{b}_r, \mathbf{b}_z \in \mathbb{R}^{1 \times h}$ are bias parameters. Note that broadcasting is triggered during the summation. We use sigmoid functions to map the input values to the interval $(0, 1)$.
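To make the shapes concrete, here is a minimal sketch (my addition, not part of the original notes) that computes both gates for a toy minibatch; the sizes n, d, h are arbitrary illustration values.

import torch

n, d, h = 2, 5, 4  # toy sizes: examples, inputs, hidden units
X_t, H_prev = torch.randn(n, d), torch.randn(n, h)
W_xr, W_hr, b_r = torch.randn(d, h), torch.randn(h, h), torch.zeros(1, h)
W_xz, W_hz, b_z = torch.randn(d, h), torch.randn(h, h), torch.zeros(1, h)
# Both gates have shape (n, h); the (1, h) bias is broadcast over the rows
R_t = torch.sigmoid(X_t @ W_xr + H_prev @ W_hr + b_r)  # reset gate
Z_t = torch.sigmoid(X_t @ W_xz + H_prev @ W_hz + b_z)  # update gate
print(R_t.shape, R_t.min().item(), R_t.max().item())   # all entries in (0, 1)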
1.2 Candidate hidden state
Next, we integrate the reset gate $\mathbf{R}_t$ with the regular hidden state update mechanism of an RNN to obtain the candidate hidden state $\tilde{\mathbf{H}}_t \in \mathbb{R}^{n \times h}$ at time step $t$:
$$\tilde{\mathbf{H}}_t = \tanh(\mathbf{X}_t \mathbf{W}_{xh} + (\mathbf{R}_t \odot \mathbf{H}_{t-1}) \mathbf{W}_{hh} + \mathbf{b}_h),$$
where $\mathbf{W}_{xh} \in \mathbb{R}^{d \times h}$ and $\mathbf{W}_{hh} \in \mathbb{R}^{h \times h}$ are weight parameters, $\mathbf{b}_h \in \mathbb{R}^{1 \times h}$ is a bias term, and the symbol $\odot$ denotes the Hadamard (elementwise) product. Here we use the tanh nonlinearity to keep the values of the candidate hidden state in the interval $(-1, 1)$.
The result is only a candidate, since we still need to incorporate the action of the update gate. Compared with a basic RNN, the elementwise multiplication of $\mathbf{R}_t$ and $\mathbf{H}_{t-1}$ in the candidate hidden state can reduce the influence of previous states. Whenever the entries of the reset gate $\mathbf{R}_t$ are close to $1$, we recover an ordinary recurrent neural network. For entries of the reset gate close to $0$, the candidate hidden state is the result of a multilayer perceptron with $\mathbf{X}_t$ as input, so any pre-existing hidden state is reset to defaults. The figure below shows the computation after applying the reset gate.
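As a quick sanity check of the two extremes just described, we can force the reset gate to all ones or all zeros (again my own sketch, reusing the toy tensors from the previous snippet):

W_xh, W_hh, b_h = torch.randn(d, h), torch.randn(h, h), torch.zeros(1, h)
# Reset gate all ones: the candidate state is the ordinary RNN update
H_tilde_rnn = torch.tanh(X_t @ W_xh + (torch.ones(n, h) * H_prev) @ W_hh + b_h)
# Reset gate all zeros: the previous state is dropped, leaving an MLP on X_t
H_tilde_mlp = torch.tanh(X_t @ W_xh + (torch.zeros(n, h) * H_prev) @ W_hh + b_h)
assert torch.allclose(H_tilde_mlp, torch.tanh(X_t @ W_xh + b_h))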
1.3 Hidden state
Finally, we need to incorporate the effect of the update gate $\mathbf{Z}_t$. It determines the extent to which the new hidden state $\mathbf{H}_t \in \mathbb{R}^{n \times h}$ is just the old state $\mathbf{H}_{t-1}$, and how much of the new candidate state $\tilde{\mathbf{H}}_t$ is used. The update gate $\mathbf{Z}_t$ achieves this simply by taking elementwise convex combinations of $\mathbf{H}_{t-1}$ and $\tilde{\mathbf{H}}_t$. This yields the final update equation of the gated recurrent unit:
$$\mathbf{H}_t = \mathbf{Z}_t \odot \mathbf{H}_{t-1} + (1 - \mathbf{Z}_t) \odot \tilde{\mathbf{H}}_t.$$
Whenever the update gate $\mathbf{Z}_t$ is close to $1$, we simply retain the old state. In this case the information from $\mathbf{X}_t$ is essentially ignored, effectively skipping time step $t$ in the dependency chain. Conversely, whenever $\mathbf{Z}_t$ is close to $0$, the new hidden state $\mathbf{H}_t$ approaches the candidate hidden state $\tilde{\mathbf{H}}_t$. These designs help cope with the vanishing gradient problem in recurrent neural networks and better capture dependencies over long time step distances. For example, if the update gate stays close to $1$ for all time steps of an entire subsequence, the old hidden state from the beginning of the subsequence is easily retained and passed to its end, regardless of the subsequence's length. The figure below illustrates the computation after the update gate is applied.
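Continuing the same toy sketch, the two limiting cases of the convex combination are easy to verify numerically:

H_tilde = torch.tanh(X_t @ W_xh + (R_t * H_prev) @ W_hh + b_h)  # candidate state
Z_one, Z_zero = torch.ones(n, h), torch.zeros(n, h)
# Z_t = 1: the old state is kept and X_t is ignored
assert torch.allclose(Z_one * H_prev + (1 - Z_one) * H_tilde, H_prev)
# Z_t = 0: the new state is exactly the candidate state
assert torch.allclose(Z_zero * H_prev + (1 - Z_zero) * H_tilde, H_tilde)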
In summary, gated recurrent units have the following two distinguishing features:
- Reset gates help capture short-term dependencies in sequences.
- Update gates help capture long-term dependencies in sequences.
2. Implementing the GRU from scratch
import torch
from torch import nn
from d2l import torch as d2l
batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
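As a quick check of what the loader yields (my addition; it assumes, as in the d2l RNN chapters, that train_iter produces (X, Y) minibatches of token indices):

for X, Y in train_iter:
    print(X.shape, Y.shape)  # expected: torch.Size([32, 35]) for both
    break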
2.1 Initialize model parameters
We draw the weights from a Gaussian distribution with standard deviation 0.01 and set the biases to 0. The hyperparameter num_hiddens defines the number of hidden units. Below we instantiate all the weights and biases related to the update gate, the reset gate, the candidate hidden state, and the output layer.
def get_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.randn(size=shape, device=device) * 0.01

    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))

    W_xz, W_hz, b_z = three()  # Update gate parameters
    W_xr, W_hr, b_r = three()  # Reset gate parameters
    W_xh, W_hh, b_h = three()  # Candidate hidden state parameters
    # Output layer parameters
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    # Attach gradients
    params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params
2.2 Defining the model
# Initialize the hidden state: return a tensor of shape (batch size, number of hidden units), filled with zeros
def init_gru_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),)
$$
\begin{aligned}
\mathbf{R}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xr} + \mathbf{H}_{t-1} \mathbf{W}_{hr} + \mathbf{b}_r),\\
\mathbf{Z}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xz} + \mathbf{H}_{t-1} \mathbf{W}_{hz} + \mathbf{b}_z),\\
\tilde{\mathbf{H}}_t &= \tanh(\mathbf{X}_t \mathbf{W}_{xh} + (\mathbf{R}_t \odot \mathbf{H}_{t-1}) \mathbf{W}_{hh} + \mathbf{b}_h),\\
\mathbf{H}_t &= \mathbf{Z}_t \odot \mathbf{H}_{t-1} + (1 - \mathbf{Z}_t) \odot \tilde{\mathbf{H}}_t.
\end{aligned}
$$
# Define the GRU model
def gru(inputs, state, params):
    W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
    H, = state
    outputs = []
    for X in inputs:
        Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)           # update gate
        R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)           # reset gate
        H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)  # candidate state
        H = Z * H + (1 - Z) * H_tilda                              # new hidden state
        Y = H @ W_hq + b_q                                         # output
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H,)
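Before training, a small smoke test of gru (my own addition) can be run on dummy one-hot inputs; it assumes the d2l convention that inputs has shape (num_steps, batch_size, vocab_size):

device = d2l.try_gpu()
params = get_params(len(vocab), 256, device)
state = init_gru_state(2, 256, device)  # batch of 2 sequences
X = torch.nn.functional.one_hot(
    torch.zeros(5, 2, dtype=torch.long), len(vocab)).float().to(device)
Y, new_state = gru(X, state, params)
print(Y.shape, new_state[0].shape)  # (5 * 2, len(vocab)) and (2, 256)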
2.3 Training and prediction
Training and prediction work in exactly the same way as in our from-scratch RNN implementation. After training, we print the perplexity on the training set and the predicted sequences following the prefixes “time traveller” and “traveller”.
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,
init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 1.0, 57290.3 tokens/sec on cuda:0
time travelleryou can show black is white by argument said filby
traveller with a slight accession ofcheerfulness really thi
2.4 Concise implementation
The high-level API contains all the configuration details described above, so we can instantiate a GRU model directly. It uses compiled operators rather than Python to handle many of the details.
batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_inputs = vocab_size
gru_layer = nn.GRU(num_inputs,num_hiddens)
model = d2l.RNNModel(gru_layer,len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 1.0, 353447.0 tokens/sec on cuda:0
time traveller for so it will be convenient to speak of himwas e
travelleryou can show black is white by argument said filby
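For reference, gru_layer can also be called directly. A small shape check (my addition) using PyTorch's (num_steps, batch_size, input_size) convention for nn.GRU:

X = torch.rand(num_steps, batch_size, num_inputs, device=device)
state = torch.zeros(1, batch_size, num_hiddens, device=device)  # (layers, batch, hidden)
Y, new_state = gru_layer(X, state)
print(Y.shape)          # torch.Size([35, 32, 256])
print(new_state.shape)  # torch.Size([1, 32, 256])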
3. Summary
- Gated recurrent neural networks can better capture dependencies on sequences with large time step distances.
- Reset gates help capture short-term dependencies in sequences.
- Update gates help capture long-term dependencies in sequences.
- When the reset gate is open (all ones), the gated recurrent unit contains a basic recurrent neural network; when the update gate is open (all ones), the gated recurrent unit can skip subsequences.
4. Exercises
- Suppose that we only want to use the input at time step $t'$ to predict the output at time step $t > t'$. What are the best values of the reset and update gates for each time step?
If both the update gate and the reset gate are 0 at time step $t'$, the previous hidden state data is not used, so the state is built from the input at $t'$ alone; keeping the update gate at 1 for all later steps then carries that state forward unchanged (see the numerical sketch after this list).
- Adjust the hyperparameters and analyze their influence on running time, perplexity, and the output sequence.
- Compare the different implementations of rnn.RNN and rnn.GRU with respect to running time, perplexity, and the output strings.
- What happens if you implement only parts of the gated recurrent unit, for example, with only a reset gate or only an update gate?
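A tiny numerical check of the answer to the first exercise (my own sketch, using the GRU update equation directly):

d, h = 3, 4
X_tp = torch.randn(1, d)  # input at time step t'
W_xh, b_h = torch.randn(d, h), torch.zeros(1, h)
H = torch.tanh(X_tp @ W_xh + b_h)  # R = 0, Z = 0 at t': state from X_t' alone
H_start = H.clone()
for _ in range(10):  # Z = 1 for every later step: the state is copied through
    H_tilde = torch.randn(1, h)  # arbitrary candidate, ignored when Z = 1
    H = 1.0 * H + (1 - 1.0) * H_tilde
assert torch.equal(H, H_start)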