当前位置:网站首页>Sentiment_analysis_cell in TensorFlow
Sentiment_analysis_cell in TensorFlow
2022-06-26 05:04:00 【Rain and dew touch the real king】
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
# Fix the random seeds so runs are reproducible.
tf.random.set_seed(22)
np.random.seed(22)
# The script is written against TensorFlow 2.x / tf.keras.
assert tf.__version__.startswith('2.')

batchsz = 128
# Vocabulary size: keep only the most frequent words (num_words=total_words).
total_words = 10000
# Every review is padded/truncated to this many tokens.
max_review_len = 80
# Dimension of each word-embedding vector.
embeding_len = 100

(x_train, y_train), (x_test, y_test) = keras.datasets.imdb.load_data(num_words=total_words)
# Pad/truncate every review to max_review_len tokens:
# x_train: [b, 80]
# x_test:  [b, 80]
x_train = keras.preprocessing.sequence.pad_sequences(x_train, maxlen=max_review_len)
x_test = keras.preprocessing.sequence.pad_sequences(x_test, maxlen=max_review_len)

# drop_remainder=True keeps every batch at exactly `batchsz` rows.
db_train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
db_train = db_train.shuffle(1000).batch(batchsz, drop_remainder=True)
# Built from the held-out test split: validating/evaluating on the training
# split would report inflated accuracy.
db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
db_test = db_test.batch(batchsz, drop_remainder=True)
print('x_train shapeL:', x_train.shape, tf.reduce_max(y_train), tf.reduce_min(y_train))
print('x_test shape', x_test.shape)
class MyRNN(keras.Model):
    """Binary sentiment classifier built from two manually-unrolled SimpleRNNCells.

    Pipeline: token ids [b, 80] -> Embedding [b, 80, 100]
    -> rnn_cell0 -> rnn_cell1 (hidden size `units`) -> Dense(1) -> sigmoid.
    """

    def __init__(self, units):
        """Create the layers.

        :param units: hidden-state size of both RNN cells (e.g. 64).
        """
        super(MyRNN, self).__init__()
        self.units = units
        # Legacy fixed-size zero states ([batchsz, units]); kept for backward
        # compatibility. call() builds zero states sized to the actual batch.
        self.state0 = [tf.zeros([batchsz, units])]
        self.state1 = [tf.zeros([batchsz, units])]
        # Transform token ids to embedding vectors: [b, 80] => [b, 80, 100]
        self.embedding = layers.Embedding(total_words, embeding_len,
                                          input_length=max_review_len)
        # Two stacked SimpleRNN cells: [b, 80, 100] => [b, units]
        self.rnn_cell0 = layers.SimpleRNNCell(units, dropout=0.5)
        self.rnn_cell1 = layers.SimpleRNNCell(units, dropout=0.5)
        # Fully connected head: [b, units] => [b, 1]
        self.outlayer = layers.Dense(1)

    def call(self, inputs, training=None):
        """Return P(review is positive) for a batch of padded token ids.

        net(x, training=True): train mode (cell dropout active)
        net(x, training=False): test mode

        :param inputs: token ids, [b, 80]
        :param training: forwarded explicitly to both RNN cells.
        :return: probabilities, [b, 1]
        """
        # Embedding: [b, 80] => [b, 80, 100]
        x = self.embedding(inputs)
        # Zero initial states sized to the runtime batch, so any batch size works
        # (not only batches of exactly `batchsz`).
        batch = tf.shape(x)[0]
        state0 = [tf.zeros([batch, self.units])]
        state1 = [tf.zeros([batch, self.units])]
        # The cells cache their dropout masks; clear them so each forward pass
        # (each batch) draws fresh masks, as keras.layers.RNN does.
        self.rnn_cell0.reset_dropout_mask()
        self.rnn_cell1.reset_dropout_mask()
        # Unroll over time: [b, 80, 100] => [b, units]
        for word in tf.unstack(x, axis=1):  # word: [b, 100]
            # h_t = x_t @ W_xh + h_{t-1} @ W_hh
            out0, state0 = self.rnn_cell0(word, state0, training=training)
            out1, state1 = self.rnn_cell1(out0, state1, training=training)
        # Classifier head: [b, units] => [b, 1]
        logits = self.outlayer(out1)
        # p(y is pos | x)
        prob = tf.sigmoid(logits)
        return prob
def main():
    """Build, train and evaluate the MyRNN sentiment model on the IMDB datasets."""
    units = 64
    epochs = 4
    net = MyRNN(units)
    adam = keras.optimizers.Adam(0.001)
    bce = tf.losses.BinaryCrossentropy()
    # NOTE(review): experimental_run_tf_function looks like a deprecated flag
    # that newer TF 2.x releases ignore - confirm for the pinned TF version.
    net.compile(
        optimizer=adam,
        loss=bce,
        metrics=['accuracy'],
        experimental_run_tf_function=False,
    )
    net.fit(db_train, epochs=epochs, validation_data=db_test)
    net.evaluate(db_test)


if __name__ == '__main__':
    main()
边栏推荐
猜你喜欢

UWB超高精度定位系统原理图

LeetCode 19. Remove the Nth node from the end of the linked list

Differences between TCP and UDP

ROS notes (07) - Implementation of client and server

Use fill and fill_between in Matplotlib to fill the blank area between functions
![[unity3d] collider assembly](/img/de/29ecf4612c540e2df715f56c31cf1a.png)
[unity3d] collider assembly

【Latex】错误类型总结(持更)

6.1 - 6.2 公鑰密碼學簡介

Stm8 MCU ADC sampling function is triggered by timer

2022.2.11
随机推荐
22.2.8
YOLOV5超参数设置与数据增强解析
Zuul implements dynamic routing
PSIM software learning ---08 call of C program block
SSH connected to win10 and reported an error: permission denied (publickey, keyboard interactive)
86.(cesium篇)cesium叠加面接收阴影效果(gltf模型)
2022.1.24
Create a binary response variable using the cut sub box operation
Multipass Chinese document - use packer to package multipass image
LISP programming language
Multipass Chinese document - share data with instances
6.1 - 6.2 Introduction à la cryptographie à clé publique
Pycharm package import error without warning
Modify the case of the string title(), upper(), lower()
Computer Vision Tools Chain
Difference between return and yield
Anti withdrawal test record
为什么许多shopify独立站卖家都在用聊天机器人?一分钟读懂行业秘密!
Astype conversion data type
WeChat mini program: exiting the mini program (navigator and the wx.exitMiniProgram API)