LSTMCell in TensorFlow: a hands-on example
2022-06-26 05:04:00 【Rain and dew touch the real king】
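The script below cuts or pads every review to `max_review_len = 80` token ids with `keras.preprocessing.sequence.pad_sequences` before batching. As far as I know, its defaults are `padding='pre'` and `truncating='pre'` with fill value 0. Short reviews therefore get zeros in front, and long reviews keep only their last 80 ids. Here is a small pure-Python sketch of that default behaviour. The helper `pad_pre` is hypothetical and is not part of Keras:

```python
def pad_pre(sequences, maxlen, value=0):
    """Hypothetical stand-in for pad_sequences with its default
    padding='pre', truncating='pre' behaviour (plain lists, no dtype handling)."""
    padded = []
    for seq in sequences:
        seq = list(seq)
        if len(seq) > maxlen:
            seq = seq[len(seq) - maxlen:]  # truncating='pre': drop tokens from the front
        padded.append([value] * (maxlen - len(seq)) + seq)  # padding='pre': fill values in front
    return padded


reviews = [[11, 12, 13], [1, 2, 3, 4, 5, 6]]
print(pad_pre(reviews, maxlen=5))  # [[0, 0, 11, 12, 13], [2, 3, 4, 5, 6]]
```

With pre-padding the padding ids come first, so the last time steps the LSTM reads are real tokens. The classifier only looks at the final hidden state `out1`, so this ordering matters.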
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
tf.random.set_seed(22)
np.random.seed(22)
assert tf.__version__.startswith('2.')
batchsz = 128
# vocabulary size: keep only the 10,000 most frequent words
total_words = 10000
max_review_len = 80
embedding_len = 100
(x_train, y_train), (x_test, y_test) = keras.datasets.imdb.load_data(num_words=total_words)
# x_train:[b, 80]
# x_test: [b, 80]
x_train = keras.preprocessing.sequence.pad_sequences(x_train, maxlen=max_review_len)
x_test = keras.preprocessing.sequence.pad_sequences(x_test, maxlen=max_review_len)
db_train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
db_train = db_train.shuffle(1000).batch(batchsz, drop_remainder=True)
db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
db_test = db_test.batch(batchsz, drop_remainder=True)
print('x_train shape:', x_train.shape, tf.reduce_max(y_train), tf.reduce_min(y_train))
print('x_test shape:', x_test.shape)
class MyRNN(keras.Model):

    def __init__(self, units):
        super(MyRNN, self).__init__()
        # initial [h, c] state of each LSTM cell, each tensor [b, 64];
        # the batch size is fixed here, which is why both datasets use drop_remainder=True
        self.state0 = [tf.zeros([batchsz, units]), tf.zeros([batchsz, units])]
        self.state1 = [tf.zeros([batchsz, units]), tf.zeros([batchsz, units])]
        # transform text to embedding representation
        # [b, 80] => [b, 80, 100]
        self.embedding = layers.Embedding(total_words, embedding_len,
                                          input_length=max_review_len)
        # [b, 80, 100], h_dim: 64
        # RNN: two stacked cells, cell0 -> cell1
        # SimpleRNN version:
        # self.rnn_cell0 = layers.SimpleRNNCell(units, dropout=0.5)
        # self.rnn_cell1 = layers.SimpleRNNCell(units, dropout=0.5)
        self.rnn_cell0 = layers.LSTMCell(units, dropout=0.5)
        self.rnn_cell1 = layers.LSTMCell(units, dropout=0.5)
        # fc, [b, 80, 100] => [b, 64] => [b, 1]
        self.outlayer = layers.Dense(1)

    def call(self, inputs, training=None):
        """
        net(x, training=True): training mode (dropout active; fit() sets this)
        net(x, training=False): test mode
        :param inputs: [b, 80]
        :param training: whether dropout is applied
        :return: [b, 1], probability that the review is positive
        """
        # [b, 80]
        x = inputs
        # embedding: [b, 80] => [b, 80, 100]
        x = self.embedding(x)
        # rnn cell compute
        # [b, 80, 100] => [b, 64]
        state0 = self.state0
        state1 = self.state1
        for word in tf.unstack(x, axis=1):  # word: [b, 100]
            # LSTM step: gates from the current word and previous h, then update c and h
            # out0: [b, 64]
            out0, state0 = self.rnn_cell0(word, state0, training=training)
            # out1: [b, 64]
            out1, state1 = self.rnn_cell1(out0, state1, training=training)
        # out: [b, 64] => [b, 1]
        x = self.outlayer(out1)
        # p(y is pos|x)
        prob = tf.sigmoid(x)
        return prob
def main():
    units = 64
    epochs = 4

    import time
    t0 = time.time()

    model = MyRNN(units)
    # call() already applies a sigmoid, so the loss receives probabilities (from_logits=False)
    model.compile(optimizer=keras.optimizers.Adam(0.001),
                  loss=keras.losses.BinaryCrossentropy(),
                  metrics=['accuracy'])
    model.fit(db_train, epochs=epochs, validation_data=db_test)

    model.evaluate(db_test)

    t1 = time.time()
    print('total time cost:', t1 - t0)  # author's run: 64.3 seconds, 83.4% accuracy


if __name__ == '__main__':
    main()
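Inside `call`, each `LSTMCell` step takes the current word vector `[b, 100]` and the previous state. It returns an output of shape `[b, 64]` together with a new state. For an LSTM that state is a pair `[h, c]`, which is why `state0` and `state1` each hold two zero tensors.

The NumPy sketch below spells out one step in what I understand to be the Keras formulation. The gate blocks are ordered input, forget, candidate, output, with `sigmoid` as the recurrent activation and `tanh` as the activation. It then unrolls two stacked cells over 80 time steps, the same way the loop over `tf.unstack(x, axis=1)` does. The weights are random placeholders, dropout is left out, and the helper names `lstm_step` and `init_cell` are hypothetical:

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def lstm_step(x, state, kernel, recurrent_kernel, bias):
    """One LSTM step (hypothetical helper). x: [b, d]; state = [h, c], each [b, u].
    kernel: [d, 4u], recurrent_kernel: [u, 4u], bias: [4u].
    Gate blocks assumed to be laid out as (i, f, c, o), as in Keras' LSTMCell."""
    h_prev, c_prev = state
    z = x @ kernel + h_prev @ recurrent_kernel + bias  # [b, 4u]
    z_i, z_f, z_c, z_o = np.split(z, 4, axis=1)         # four [b, u] blocks
    i = sigmoid(z_i)                    # input gate
    f = sigmoid(z_f)                    # forget gate
    c = f * c_prev + i * np.tanh(z_c)   # new cell state
    o = sigmoid(z_o)                    # output gate
    h = o * np.tanh(c)                  # new hidden state, also the cell's output
    return h, [h, c]


def init_cell(rng, in_dim, units):
    """Random illustrative weights (not trained, not Keras' default initializers)."""
    return (rng.normal(scale=0.1, size=(in_dim, 4 * units)),
            rng.normal(scale=0.1, size=(units, 4 * units)),
            np.zeros(4 * units))


rng = np.random.default_rng(22)
b, T, d, u = 4, 80, 100, 64          # batch, max_review_len, embedding_len, units
x = rng.normal(size=(b, T, d))       # stands in for the embedding output [b, 80, 100]
cell0, cell1 = init_cell(rng, d, u), init_cell(rng, u, u)
state0 = [np.zeros((b, u)), np.zeros((b, u))]
state1 = [np.zeros((b, u)), np.zeros((b, u))]
for t in range(T):                   # same role as iterating over tf.unstack(x, axis=1)
    out0, state0 = lstm_step(x[:, t, :], state0, *cell0)
    out1, state1 = lstm_step(out0, state1, *cell1)
print(out1.shape)  # (4, 64)
```

In this formulation the cell's output is simply the new `h`, so `out1` and `state1[0]` are the same tensor. The script feeds that final `[b, 64]` vector to `Dense(1)`.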