LSTMCell in TensorFlow: hands-on practice
2022-06-26 05:04:00 【Rain and dew touch the real king】
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
tf.random.set_seed(22)
np.random.seed(22)
assert tf.__version__.startswith('2.')
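batchsz = 128           # batch size
# vocabulary size: keep only the most frequent words
total_words = 10000
max_review_len = 80     # every review is padded or truncated to 80 tokens
embedding_len = 100     # dimension of each word embedding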
(x_train, y_train), (x_test, y_test) = keras.datasets.imdb.load_data(num_words=total_words)
# pad or truncate every review to max_review_len tokens
# x_train: [b, 80], x_test: [b, 80]
x_train = keras.preprocessing.sequence.pad_sequences(x_train, maxlen=max_review_len)
x_test = keras.preprocessing.sequence.pad_sequences(x_test, maxlen=max_review_len)
db_train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
db_train = db_train.shuffle(1000).batch(batchsz, drop_remainder=True)
db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
db_test = db_test.batch(batchsz, drop_remainder=True)
print('x_train shape:', x_train.shape, tf.reduce_max(y_train), tf.reduce_min(y_train))
print('x_test shape:', x_test.shape)
class MyRNN(keras.Model):

    def __init__(self, units):
        super(MyRNN, self).__init__()

        # initial [h, c] state of each LSTM cell, every tensor is [b, 64];
        # the batch size is hard-coded here, so each batch must contain
        # exactly batchsz samples (hence drop_remainder=True on the datasets)
        self.state0 = [tf.zeros([batchsz, units]), tf.zeros([batchsz, units])]
        self.state1 = [tf.zeros([batchsz, units]), tf.zeros([batchsz, units])]

        # transform text to embedding representation
        # [b, 80] => [b, 80, 100]
        self.embedding = layers.Embedding(total_words, embedding_len,
                                          input_length=max_review_len)

        # [b, 80, 100], h_dim: 64
        # two stacked recurrent cells: cell0 feeds cell1
        # SimpleRNN alternative:
        # self.rnn_cell0 = layers.SimpleRNNCell(units, dropout=0.5)
        # self.rnn_cell1 = layers.SimpleRNNCell(units, dropout=0.5)
        self.rnn_cell0 = layers.LSTMCell(units, dropout=0.5)
        self.rnn_cell1 = layers.LSTMCell(units, dropout=0.5)

        # fc, [b, 80, 100] => [b, 64] => [b, 1]
        self.outlayer = layers.Dense(1)

    def call(self, inputs, training=None):
        """
        net(x) or net(x, training=True): train mode
        net(x, training=False): test mode
        :param inputs: [b, 80] integer word indices
        :param training: whether dropout is active
        :return: [b, 1] probability that the review is positive
        """
        # [b, 80]
        x = inputs
        # embedding: [b, 80] => [b, 80, 100]
        x = self.embedding(x)
        # rnn cell compute
        # [b, 80, 100] => [b, 64]
        state0 = self.state0
        state1 = self.state1
        for word in tf.unstack(x, axis=1):  # word: [b, 100]
            # a SimpleRNN cell would compute h1 = x*wxh + h0*whh;
            # an LSTM cell also updates a cell state, so state is [h, c]
            # out0: [b, 64]
            out0, state0 = self.rnn_cell0(word, state0, training=training)
            # out1: [b, 64]
            out1, state1 = self.rnn_cell1(out0, state1, training=training)

        # only the output of the last time step is used
        # out: [b, 64] => [b, 1]
        x = self.outlayer(out1)
        # p(y is pos|x)
        prob = tf.sigmoid(x)

        return prob
def main():
    units = 64
    epochs = 4

    import time

    t0 = time.time()

    model = MyRNN(units)
    # experimental_run_tf_function=False was needed on TF 2.0;
    # drop this argument if your TensorFlow version rejects it
    model.compile(optimizer=keras.optimizers.Adam(0.001),
                  loss=tf.losses.BinaryCrossentropy(),
                  metrics=['accuracy'], experimental_run_tf_function=False)
    model.fit(db_train, epochs=epochs, validation_data=db_test)

    model.evaluate(db_test)

    t1 = time.time()
    # reported run: 64.3 seconds, 83.4% accuracy
    print('total time cost:', t1 - t0)


if __name__ == '__main__':
    main()
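
The model above gives each LSTMCell a state made of two tensors rather than one. The NumPy sketch below shows, under stated assumptions, what a single LSTM step does with that pair and why the cell's output is the same tensor as the first state entry. It is an illustration, not the Keras implementation: `lstm_step` and `sigmoid` are hypothetical helper names, the weights are random, and the gate order (input, forget, candidate, output) is assumed to mirror the layout Keras uses.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def lstm_step(x, state, kernel, recurrent_kernel, bias):
    """One LSTM time step (hypothetical helper).

    x:     [b, input_dim] input at the current time step
    state: [h, c], each of shape [b, units]
    """
    h_prev, c_prev = state
    # all four gates are computed with one matrix product: [b, 4 * units]
    z = x @ kernel + h_prev @ recurrent_kernel + bias
    i, f, g, o = np.split(z, 4, axis=1)   # assumed gate order: i, f, candidate, o
    i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)
    g = np.tanh(g)
    c = f * c_prev + i * g                # new cell state
    h = o * np.tanh(c)                    # new hidden state, also the output
    return h, [h, c]


rng = np.random.default_rng(22)
b, input_dim, units, steps = 4, 100, 64, 80

# random stand-ins for trained weights
kernel = rng.normal(scale=0.1, size=(input_dim, 4 * units))
recurrent_kernel = rng.normal(scale=0.1, size=(units, 4 * units))
bias = np.zeros(4 * units)

x_seq = rng.normal(size=(b, steps, input_dim))        # like the embedded reviews
state = [np.zeros((b, units)), np.zeros((b, units))]  # zero [h, c], as in state0

# unroll over the time axis, mirroring the tf.unstack loop in call()
for t in range(steps):
    out, state = lstm_step(x_seq[:, t, :], state, kernel, recurrent_kernel, bias)

print(out.shape)  # (4, 64)
```

After the loop, `out` holds the hidden state of the last time step, which is the only tensor the model above passes to its Dense layer.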