TensorFlow 2.0 (XII) -- Implementing Simple RNN and LSTM Networks
2022-07-28 17:56:00 【A bone loving cat】
Implementing Simple RNN and LSTM Networks
Preface
In the previous post, TensorFlow 2.0 (XI) -- Understanding LSTM Networks, we explained the working principle and structure of LSTM in detail. In this post we use the IMDB dataset to classify text with a simple RNN and an LSTM network.
1. Importing the required libraries
First, we import the Python libraries we need:
# matplotlib for plotting
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
# libraries for data processing
import numpy as np
import sklearn
import pandas as pd
# system libraries
import os
import sys
import time
# TensorFlow libraries
import tensorflow as tf
from tensorflow import keras
2. Loading and building the dataset
2.1 Loading the dataset
The dataset loaded here is IMDB, a movie review dataset with two label classes: positive and negative.
# IMDB: a movie review dataset with two classes, positive and negative
imdb = keras.datasets.imdb
vocab_size = 10000
index_from = 3
(train_data, train_labels), (test_data, test_labels) = imdb.load_data(
    num_words = vocab_size,   # keep only the 10000 most frequent words; rarer words are mapped to a special token
    index_from = index_from)  # offset word indices by 3, reserving the smallest IDs for special tokens
Let's take a look at the data:
print(train_data[0], train_labels[0])
print(train_data.shape, train_labels.shape)   # the training samples have different lengths
print(len(train_data[0]), len(train_data[1])) # e.g. the first sample has length 218, the second 189
print(test_data.shape, test_labels.shape)
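With the standard Keras split, both the training and test sets contain 25,000 variable-length reviews, so both shapes print as (25000,); each label is 0 (negative) or 1 (positive).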

2.2 Building a vocabulary
First, we need to get the vocabulary:
word_index = imdb.get_word_index()    # fetch the vocabulary
print(len(word_index))                # print the vocabulary size
print(list(word_index.items())[:50])  # print the first 50 key:value pairs
The output shows that the vocabulary is a dictionary whose keys are words and whose values are the corresponding IDs. Next we generate the id-to-word mapping:
""" Change Thesaurus ID Because when we read the vocabulary, we subscript 3 When I started to calculate , So add 3 The purpose of offsetting the coordinates of the vocabulary is to add some special characters """
word_index = {
k:(v+3) for k, v in word_index.items()}
word_index['<PAD>'] = 0 # padding Characters used to fill in
word_index['<START>'] = 1 # The character before the beginning of each sentence
word_index['<UNK>'] = 2 # Unrecognized character
word_index['<END>'] = 3 # The character at the end of each sentence
# build the id -> word index
reverse_word_index = dict(
    [(value, key) for key, value in word_index.items()])

def decode_review(text_ids):
    # IDs that are not found fall back to <UNK>
    return ' '.join(
        [reverse_word_index.get(word_id, "<UNK>") for word_id in text_ids])

# print the sentence corresponding to the IDs in train_data[0]
decode_review(train_data[0])
This returns the original review text reconstructed from the IDs.
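As a quick sanity check of the ID offset (a minimal sketch, relying on 'the' being the most frequent word in the standard Keras vocabulary, with raw index 1):

print(word_index['the'])      # 4 -- raw index 1 plus the offset of 3
print(reverse_word_index[4])  # 'the', so ID 4 in train_data decodes to 'the'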
2.3 Processing data
Because the samples have unequal lengths (every review is a different length), we need to preprocess the data: choose a fixed sample length, pad the samples that are too short, and truncate the samples that are too long.
max_length = 500  # sentence length: sentences shorter than 500 are padded, sentences longer than 500 are truncated
""" utilize keras.preprocessing.sequence.pad_sequences Supplement and truncate the data of training set and test set """
# Process training set data
train_data = keras.preprocessing.sequence.pad_sequences(
train_data, # Data to process
value = word_index['<PAD>'], # The value to fill in
padding = 'post', # padding The order of :post Refer to padding Put it at the end of the sentence , pre Refer to padding Put it in front of the sentence
maxlen = max_length) # Maximum length
# Processing test set data
test_data = keras.preprocessing.sequence.pad_sequences(
test_data, # Data to process
value = word_index['<PAD>'], # The value to fill in
padding = 'post', # padding The order of :post Refer to padding Put it at the end of the sentence , pre Refer to padding Put it in front of the sentence
maxlen = max_length) # Maximum length
print(train_data[0])
Since the first review has only 218 words, the printed sample now ends in a run of 0s (the <PAD> token).
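A quick shape check confirms the padding (a small sanity check, not part of the original pipeline):

print(train_data.shape)  # (25000, 500): every review now has length 500
print(test_data.shape)   # (25000, 500)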
3. Building a simple RNN model
3.1 Unidirectional RNN model
First, we build a single-layer unidirectional RNN model:
embedding_dim = 16  # each word is embedded into a vector of length 16
batch_size = 512    # size of each batch
# Single-layer unidirectional RNN.
# What the Embedding layer does:
#   1. defines a matrix of shape [vocab_size, embedding_dim] ([10000, 16])
#   2. turns each sample [1, 2, 3, 4, ...] into data of shape max_length * embedding_dim,
#      i.e. each word becomes a vector of length 16
#   3. the resulting batch is a 3-D tensor: batch_size * max_length * embedding_dim
model = keras.models.Sequential([
    keras.layers.Embedding(vocab_size,                  # vocabulary size
                           embedding_dim,               # embedding length
                           input_length = max_length),  # input length
    # units: dimensionality of the output space
    # return_sequences: boolean; whether to return the full output sequence or only the last output
    keras.layers.SimpleRNN(units = 64, return_sequences = False),
    # fully connected layers
    keras.layers.Dense(64, activation = 'relu'),
    keras.layers.Dense(1, activation = 'sigmoid'),
])
model.summary()
model.summary() prints the layer structure and parameter counts of the unidirectional SimpleRNN model.
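For reference, the counts can be checked by hand (given the sizes above): Embedding 10000 × 16 = 160,000; SimpleRNN 64 × (16 + 64 + 1) = 5,184; Dense(64) 64 × 64 + 64 = 4,160; Dense(1) 64 + 1 = 65, for 169,409 trainable parameters in total.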
3.2 Bidirectional RNN model
If the unidirectional RNN model does not perform well, we can turn it into a bidirectional one by wrapping the recurrent layer in keras.layers.Bidirectional():
embedding_dim = 16
batch_size = 512
model = keras.models.Sequential([
    # 1. define matrix: [vocab_size, embedding_dim]
    # 2. [1, 2, 3, 4, ...] -> max_length * embedding_dim
    # 3. batch_size * max_length * embedding_dim
    keras.layers.Embedding(vocab_size, embedding_dim,
                           input_length = max_length),
    # Bidirectional concatenates the forward and backward outputs, so the
    # output dimension doubles to 128; return_sequences is False because the
    # Dense layers below expect only the last output, not the full sequence
    keras.layers.Bidirectional(
        keras.layers.SimpleRNN(
            units = 64, return_sequences = False)),
    keras.layers.Dense(64, activation = 'relu'),
    keras.layers.Dense(1, activation = 'sigmoid'),
])
model.summary()
model.compile(optimizer = 'adam',
              loss = 'binary_crossentropy',
              metrics = ['accuracy'])
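With the bidirectional wrapper the recurrent parameters double to 2 × 5,184 = 10,368, and the first Dense layer now receives a 128-dimensional input, giving 128 × 64 + 64 = 8,256 parameters (again a hand check against model.summary()).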
4. Building the LSTM model
4.1 Unidirectional LSTM model
Building an LSTM follows almost the same steps as the RNN above; we only need to change the keras.layers.SimpleRNN() layer to keras.layers.LSTM():
embedding_dim = 16  # each word is embedded into a vector of length 16
batch_size = 512    # size of each batch
# Single-layer unidirectional LSTM.
# The Embedding layer behaves exactly as in the RNN model: it maps each sample
# to a max_length * embedding_dim matrix, so a batch is a 3-D tensor of shape
# batch_size * max_length * embedding_dim
model = keras.models.Sequential([
    keras.layers.Embedding(vocab_size,                  # vocabulary size
                           embedding_dim,               # embedding length
                           input_length = max_length),  # input length
    # units: dimensionality of the output space
    # return_sequences: boolean; whether to return the full output sequence or only the last output
    keras.layers.LSTM(units = 64, return_sequences = False),
    keras.layers.Dense(64, activation = 'relu'),
    keras.layers.Dense(1, activation = 'sigmoid'),
])
model.summary()
model.compile(optimizer = 'adam',
              loss = 'binary_crossentropy',
              metrics = ['accuracy'])
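An LSTM layer holds four sets of weights (the input, forget and output gates plus the cell candidate), so with the same sizes its parameter count is four times the SimpleRNN's: 4 × 64 × (16 + 64 + 1) = 20,736.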
4.2 Bidirectional LSTM model
The bidirectional LSTM model is almost identical to the bidirectional RNN model; again we only need to change the keras.layers.SimpleRNN() layer to keras.layers.LSTM():
embedding_dim = 16
batch_size = 512
# Two-layer bidirectional LSTM: the first layer must return the full sequence
# so the second layer has a sequence to consume; the second layer returns
# only its last output for the Dense layers below
model = keras.models.Sequential([
    # 1. define matrix: [vocab_size, embedding_dim]
    # 2. [1, 2, 3, 4, ...] -> max_length * embedding_dim
    # 3. batch_size * max_length * embedding_dim
    keras.layers.Embedding(vocab_size, embedding_dim,
                           input_length = max_length),
    keras.layers.Bidirectional(
        keras.layers.LSTM(
            units = 64, return_sequences = True)),
    keras.layers.Bidirectional(
        keras.layers.LSTM(
            units = 64, return_sequences = False)),
    keras.layers.Dense(64, activation = 'relu'),
    keras.layers.Dense(1, activation = 'sigmoid'),
])
model.summary()
model.compile(optimizer = 'adam',
              loss = 'binary_crossentropy',
              metrics = ['accuracy'])
5. Model compilation and training
Once a model is built, we can compile and train it; the same calls work for any of the models above:
model.compile(optimizer = 'adam',
              loss = 'binary_crossentropy',
              metrics = ['accuracy'])
history = model.fit(
    train_data, train_labels,
    epochs = 30,              # number of passes over the training data
    batch_size = batch_size,
    validation_split = 0.2)   # hold out 20% of the training data for validation
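To see how training went, we can plot the learning curves from the returned history object (a minimal sketch using the matplotlib and pandas imports from step 1; the metric keys 'accuracy' and 'val_accuracy' assume TF 2.x defaults):

def plot_learning_curves(history, label):
    # history.history maps each metric name to its per-epoch values
    pd.DataFrame(history.history)[[label, 'val_' + label]].plot(figsize=(8, 5))
    plt.grid(True)
    plt.xlabel('epoch')
    plt.show()

plot_learning_curves(history, 'accuracy')
plot_learning_curves(history, 'loss')
# finally, evaluate the trained model on the held-out test set
model.evaluate(test_data, test_labels, batch_size = batch_size)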