LSTM cells in TensorFlow: a hands-on example with LSTMCell
2022-06-26 05:04:00 【Rain and dew touch the real king】
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
tf.random.set_seed(22)
np.random.seed(22)
assert tf.__version__.startswith('2.')
batchsz = 128
# keep only the 10,000 most frequent words
total_words = 10000
max_review_len = 80
embedding_len = 100
(x_train, y_train), (x_test, y_test) = keras.datasets.imdb.load_data(num_words=total_words)
# x_train:[b, 80]
# x_test: [b, 80]
x_train = keras.preprocessing.sequence.pad_sequences(x_train, maxlen=max_review_len)
x_test = keras.preprocessing.sequence.pad_sequences(x_test, maxlen=max_review_len)
db_train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
db_train = db_train.shuffle(1000).batch(batchsz, drop_remainder=True)
db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
db_test = db_test.batch(batchsz, drop_remainder=True)
print('x_train shape:', x_train.shape, tf.reduce_max(y_train), tf.reduce_min(y_train))
print('x_test shape:', x_test.shape)
class MyRNN(keras.Model):

    def __init__(self, units):
        super(MyRNN, self).__init__()
        # initial LSTM state of each cell: [h, c], each of shape [b, 64]
        # (the batch size is fixed to batchsz, hence drop_remainder=True above)
        self.state0 = [tf.zeros([batchsz, units]), tf.zeros([batchsz, units])]
        self.state1 = [tf.zeros([batchsz, units]), tf.zeros([batchsz, units])]
        # transform text to embedding representation
        # [b, 80] => [b, 80, 100]
        self.embedding = layers.Embedding(total_words, embedding_len,
                                          input_length=max_review_len)
        # [b, 80, 100], h_dim: 64
        # two stacked recurrent cells: cell0 -> cell1
        # SimpleRNN alternative:
        # self.rnn_cell0 = layers.SimpleRNNCell(units, dropout=0.5)
        # self.rnn_cell1 = layers.SimpleRNNCell(units, dropout=0.5)
        self.rnn_cell0 = layers.LSTMCell(units, dropout=0.5)
        self.rnn_cell1 = layers.LSTMCell(units, dropout=0.5)
        # fc, [b, 80, 100] => [b, 64] => [b, 1]
        self.outlayer = layers.Dense(1)

    def call(self, inputs, training=None):
        """
        net(x) or net(x, training=True): train mode
        net(x, training=False): test mode
        :param inputs: word indices, [b, 80]
        :param training: whether dropout is active
        :return: probability that each review is positive, [b, 1]
        """
        # [b, 80]
        x = inputs
        # embedding: [b, 80] => [b, 80, 100]
        x = self.embedding(x)
        # rnn cell compute
        # [b, 80, 100] => [b, 64]
        state0 = self.state0
        state1 = self.state1
        for word in tf.unstack(x, axis=1):  # word: [b, 100]
            # each cell takes the current input and its previous state [h, c]
            # and returns the output together with the updated state
            # out0: [b, 64]
            out0, state0 = self.rnn_cell0(word, state0, training=training)
            # out1: [b, 64]
            out1, state1 = self.rnn_cell1(out0, state1, training=training)
        # only the output of the last time step is used
        # out: [b, 64] => [b, 1]
        x = self.outlayer(out1)
        # p(y is pos|x)
        prob = tf.sigmoid(x)
        return prob
def main():
    units = 64
    epochs = 4
    import time
    t0 = time.time()
    model = MyRNN(units)
    # experimental_run_tf_function is a TF 2.0-era flag;
    # drop it if your TensorFlow version rejects it
    model.compile(optimizer=keras.optimizers.Adam(0.001),
                  loss=tf.losses.BinaryCrossentropy(),
                  metrics=['accuracy'], experimental_run_tf_function=False)
    model.fit(db_train, epochs=epochs, validation_data=db_test)
    model.evaluate(db_test)
    t1 = time.time()
    print('total time cost:', t1 - t0)  # 64.3 seconds, 83.4%


if __name__ == '__main__':
    main()
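
The loop in call() hands each LSTMCell a state made of two tensors, which is why state0 and state1 above each hold two zero tensors of shape [b, 64]. As a rough illustration of what one such step computes, here is a minimal NumPy sketch of the standard LSTM update. The function lstm_cell_step, the weight names W, U, b and the gate ordering are assumptions made for this sketch only; it is not how Keras lays out its weights internally, and it leaves out dropout.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def lstm_cell_step(x, state, W, U, b):
    """One LSTM time step (hypothetical helper, not the Keras implementation).

    x:     [b, input_dim]  input at the current time step
    state: (h, c)          hidden state and cell state, each [b, units]
    W:     [input_dim, 4 * units], U: [units, 4 * units], b: [4 * units]
    """
    h_prev, c_prev = state
    # all four gate pre-activations in one matrix product
    z = x @ W + h_prev @ U + b
    i, f, g, o = np.split(z, 4, axis=1)
    i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)  # input, forget, output gates
    g = np.tanh(g)                                # candidate cell values
    c = f * c_prev + i * g                        # new cell state
    h = o * np.tanh(c)                            # new hidden state
    # like a Keras cell, return the output and the updated state
    return h, (h, c)


# toy dimensions mirroring the script: 80 steps, 100-dim embeddings, 64 units
rng = np.random.default_rng(22)
batch, input_dim, units, steps = 4, 100, 64, 80
W = rng.normal(scale=0.1, size=(input_dim, 4 * units))
U = rng.normal(scale=0.1, size=(units, 4 * units))
b = np.zeros(4 * units)
x_seq = rng.normal(size=(batch, steps, input_dim))

state = (np.zeros((batch, units)), np.zeros((batch, units)))
for t in range(steps):
    out, state = lstm_cell_step(x_seq[:, t, :], state, W, U, b)

print(out.shape)  # (4, 64)
```

The output of a step is the new hidden state h, while the cell state c is only carried forward inside the state. That matches the script, where out1 from the last time step is the only value passed to the Dense layer.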