
How to calculate the number of parameters in the LSTM layer

2022-06-09 03:28:00 deephub

Long short-term memory networks (usually called "LSTMs") are a special kind of RNN designed to learn long-term dependencies. As the name suggests, an LSTM can capture both long-term and short-term dependencies.

Every LSTM layer has four gates:

  1. Forget gate
  2. Input gate
  3. New cell state gate
  4. Output gate
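To make the four gates concrete, here is a minimal NumPy sketch of one timestep of a standard LSTM. It assumes the common formulation (sigmoid for the forget, input and output gates, tanh for the new cell state); the function name, the dictionary keys and the tiny sizes are hypothetical, and the weights are random placeholders rather than trained values.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, W, b):
    """One timestep of a standard LSTM (hypothetical helper).

    W maps gate name -> weight matrix of shape (num_units, num_units + input_dim)
    b maps gate name -> bias vector of shape (num_units,)
    """
    z = np.concatenate([h_prev, x_t])            # [h(t-1), x(t)]
    f = sigmoid(W["forget"] @ z + b["forget"])   # forget gate
    i = sigmoid(W["input"] @ z + b["input"])     # input gate
    g = np.tanh(W["cell"] @ z + b["cell"])       # new (candidate) cell state
    o = sigmoid(W["output"] @ z + b["output"])   # output gate
    c_t = f * c_prev + i * g                     # updated cell state
    h_t = o * np.tanh(c_t)                       # new hidden state
    return h_t, c_t

num_units, input_dim = 3, 5
rng = np.random.default_rng(0)
gates = ["forget", "input", "cell", "output"]
# Random placeholder weights: one matrix and one bias vector per gate
W = {k: rng.standard_normal((num_units, num_units + input_dim)) for k in gates}
b = {k: np.zeros(num_units) for k in gates}

h, c = lstm_step(rng.standard_normal(input_dim), np.zeros(num_units), np.zeros(num_units), W, b)
print(h.shape, c.shape)  # (3,) (3,)

# Total trainable values across the four gates: 4 * (3 + 5 + 1) * 3
n_params = sum(W[k].size + b[k].size for k in gates)
print(n_params)  # 108
```

Each gate owns a weight matrix of shape (num_units, num_units + input_dim) and a bias of length num_units; adding those up is exactly the parameter count this article derives.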

Let's count the parameters of one LSTM unit:

All four gates have the same form: a linear transform of the same input followed by an activation function. So we only need to count the parameters of one gate and multiply by 4. Take the forget gate as an example:

h(t-1): hidden state from the previous timestep
x(t): n-dimensional input vector at the current timestep
b: bias term
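In the usual notation, the forget gate computes the following. Here m stands for num_units and n for input_dim; these two symbols are introduced only for this equation.

$$
f_t = \sigma\left(W_f \cdot [h_{t-1}, x_t] + b_f\right), \qquad
W_f \in \mathbb{R}^{m \times (m+n)}, \quad b_f \in \mathbb{R}^{m}
$$

Each of the m rows of W_f, together with its entry of b_f, belongs to one unit, which gives m + n + 1 parameters per unit for this gate.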

h(t-1) and x(t) are known; W_f and b_f are the unknowns that the LSTM learns. W_f is applied to [h(t-1), x(t)], the concatenation of the previous hidden state and the current input.

W_f: num_units + input_dim weights per unit, one for each element of concat[h(t-1), x(t)]
b_f: 1 bias per unit

So the parameter count for one unit is:

num_param = no_of_gates * (num_units + input_dim + 1)

An LSTM layer has four gates, so this becomes:

num_param = 4 * (num_units + input_dim + 1)

In practice, a layer contains num_units units rather than a single LSTM cell. To count the parameters of all of them, multiply the per-unit count by num_units:

num_params = 4 * [(num_units + input_dim + 1) * num_units]

where num_units is the number of hidden units, which is also the size of h(t-1) and the layer's output dimension (output_dim).
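The layer-level formula translates directly into a small helper. This is a minimal sketch: the function name lstm_param_count is hypothetical, and it assumes every gate has a bias term (one per unit).

```python
def lstm_param_count(num_units: int, input_dim: int) -> int:
    """Parameters in one LSTM layer, assuming each gate has a bias."""
    per_unit_per_gate = num_units + input_dim + 1  # weights on [h(t-1), x(t)] plus 1 bias
    per_unit = 4 * per_unit_per_gate               # forget, input, new cell state, output
    return per_unit * num_units                    # num_units units in the layer

print(lstm_param_count(64, 32))  # 24832
```

Note that num_units appears twice: once inside the parentheses (because h(t-1) is fed back into every gate) and once as the outer multiplier (one set of weights per unit). That is why the count grows quadratically with the layer size.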

Let's check this against the parameter count Keras reports for an actual LSTM layer:

from keras.models import Sequential
from keras.layers import Input, LSTM

model = Sequential()
# Input sequences: 16 timesteps, 4096 features per timestep
model.add(Input(shape=(16, 4096)))
# 200 LSTM units; the output is the last hidden state, shape (None, 200)
model.add(LSTM(200))
model.summary()

Keras reports:

Model: "sequential_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 lstm_2 (LSTM)               (None, 200)               3437600   
                                                                 
=================================================================
Total params: 3,437,600
Trainable params: 3,437,600
Non-trainable params: 0
_________________________________________________________________

Now compute the same number by hand with the formula:

num_params = 4 * [(num_units + input_dim + 1) * num_units]
num_params = 4*[(200+4096+1) * 200]
num_params = 3437600

The two results match.
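Keras does not store the four gates separately: its LSTM layer concatenates them into three tensors, which you can inspect with model.layers[0].get_weights(). The sketch below reproduces that split without needing Keras installed. It assumes the default use_bias=True and the layout kernel (input_dim, 4 * units), recurrent_kernel (units, 4 * units) and bias (4 * units,); treat it as an illustration of the shapes rather than a call into Keras.

```python
# Assumed Keras LSTM weight layout (use_bias=True, the default):
#   kernel:           (input_dim, 4 * units)  -> multiplies x(t)
#   recurrent_kernel: (units, 4 * units)      -> multiplies h(t-1)
#   bias:             (4 * units,)
from math import prod

units, input_dim = 200, 4096
shapes = {
    "kernel": (input_dim, 4 * units),
    "recurrent_kernel": (units, 4 * units),
    "bias": (4 * units,),
}
counts = {name: prod(shape) for name, shape in shapes.items()}
print(counts)                # {'kernel': 3276800, 'recurrent_kernel': 160000, 'bias': 800}
print(sum(counts.values()))  # 3437600
```

The three terms are the same pieces as in the formula: 4 * input_dim * num_units, 4 * num_units * num_units and 4 * num_units.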

https://avoid.overfit.cn/post/ed5f0d482d5e486387f2708b7d0d58d8

Author: Maheshmj

Original site

Copyright notice
This article was created by [deephub]. Please include a link to the original when reposting. Thanks.
https://yzsam.com/2022/159/202206080956214557.html