Using an LSTM neural network and an ordinary neural network
2022-06-29 07:20:00 【Cochlear notes】
Note: this code needs a NumPy version below 1.20.1, for example:
pip install -U numpy==1.19.2
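The note above pins NumPy below 1.20.1. The following is a minimal sketch of a version guard that could sit at the top of the script, so that it fails early instead of erroring later inside Keras. The helper names `parse_version` and `numpy_version_ok` are hypothetical, and the bound is simply the one stated in the note.

```python
import re

# Upper bound taken from the note above ("below 1.20.1").
NUMPY_LIMIT = (1, 20, 1)


def parse_version(version: str) -> tuple:
    """Turn '1.19.2' (or '1.19.2rc1') into (1, 19, 2); missing parts count as 0."""
    parts = []
    for piece in version.split(".")[:3]:
        match = re.match(r"\d+", piece)
        parts.append(int(match.group()) if match else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def numpy_version_ok(version: str, limit: tuple = NUMPY_LIMIT) -> bool:
    """True if the given NumPy version string is strictly below `limit`."""
    return parse_version(version) < limit
```

For example, `numpy_version_ok(np.__version__)` returns `False` on a newer NumPy. In that case the `pip install -U numpy==1.19.2` command above applies.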
#!/usr/bin/env python
# coding: utf-8
# In[1]:
# Copyright notice: this is an original blog article licensed under CC 4.0 BY-SA. When reprinting, include a link to the original source and this statement.
# Link to this article: https://blog.csdn.net/tMb8Z9Vdm66wH68VX1/article/details/90423649
import pandas as pd
import numpy as np
# %matplotlib inline  (Jupyter magic; uncomment when running inside a notebook)
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import r2_score
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import LSTM
# Show all columns
pd.set_option('display.max_columns', None)
# Show all rows
pd.set_option('display.max_rows', None)
# Original author: CSDN blogger 「Data pie THU」; same CC 4.0 BY-SA license and source link as above.
# In[2]:
df = pd.read_csv("Test stock data.csv")  # expects Date, Open, High, Low, Close, Adj Close, Volume columns
print(df.head())
# In[3]:
# Keep only the 'Adj Close' column and index it by date
df.drop(['Open', 'High', 'Low', 'Close', 'Volume'], axis=1, inplace=True)
df['Date'] = pd.to_datetime(df['Date'])
df = df.set_index(['Date'], drop=True)
print(df.head(10))
# In[4]:
plt.figure(figsize=(10, 6))
df['Adj Close'].plot();
# In[5]:
split_date = pd.Timestamp('2019-01-01')
df = df['Adj Close']
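# Chronological split; the boolean masks keep a row dated exactly split_date out of the training set
train = df[df.index < split_date]
test = df[df.index >= split_date]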
plt.figure(figsize=(10, 6))
ax = train.plot()
test.plot(ax=ax)
plt.legend(['train', 'test']);
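`.loc` slicing on a `DatetimeIndex` is label-based and includes both endpoints. With `df.loc[:split_date]` and `df.loc[split_date:]`, a row dated exactly `2019-01-01` would therefore land in both sets. For stock data that day is usually a market holiday, so the overlap may not show up in practice. Below is a minimal sketch on a toy daily series. The helper `split_by_date` is hypothetical.

```python
import pandas as pd


def split_by_date(series: pd.Series, split_date: pd.Timestamp):
    """Chronological split: rows strictly before split_date go to train, the rest to test."""
    train = series[series.index < split_date]
    test = series[series.index >= split_date]
    return train, test


# Toy daily series whose index contains the split date itself.
idx = pd.date_range("2018-12-30", periods=5, freq="D")
s = pd.Series([10.0, 11.0, 12.0, 13.0, 14.0], index=idx)
cut = pd.Timestamp("2019-01-01")

# Label-based .loc slicing is inclusive at both ends, so the split-date row appears in both slices.
overlap = s.loc[:cut].index.intersection(s.loc[cut:].index)

train, test = split_by_date(s, cut)
```

Here `overlap` holds the single date 2019-01-01. The mask-based split puts two rows in `train` and three in `test`, with no shared dates.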
# In[6]:
# scikit-learn scalers expect a 2-D array of shape (n_samples, n_features)
train = np.array(train).reshape(-1, 1)
test = np.array(test).reshape(-1, 1)
# In[7]:
# Fit the scaler on the training period only, then apply the same transform to the test period
scaler = MinMaxScaler(feature_range=(-1, 1))
train_sc = scaler.fit_transform(train)
test_sc = scaler.transform(test)
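With `feature_range=(-1, 1)`, `MinMaxScaler` maps each value x to (x - min) / (max - min) * (1 - (-1)) + (-1). Here min and max come from the training data only. Prices in the test period can exceed the training range, so scaled test values may fall outside [-1, 1]. The predictions and plots are also in scaled units. `scaler.inverse_transform(y_pred_test_lstm)` maps them back to prices.

The following is a NumPy re-implementation for illustration, not scikit-learn's own code. The function names are hypothetical.

```python
import numpy as np


def minmax_fit(x: np.ndarray):
    """Learn per-column min and max from the training data only."""
    return x.min(axis=0), x.max(axis=0)


def minmax_transform(x, data_min, data_max, feature_range=(-1.0, 1.0)):
    """Map [data_min, data_max] linearly onto feature_range."""
    lo, hi = feature_range
    return (x - data_min) / (data_max - data_min) * (hi - lo) + lo


def minmax_inverse(x_scaled, data_min, data_max, feature_range=(-1.0, 1.0)):
    """Undo minmax_transform, e.g. to turn scaled predictions back into prices."""
    lo, hi = feature_range
    return (x_scaled - lo) / (hi - lo) * (data_max - data_min) + data_min


train = np.array([[0.0], [5.0], [10.0]])
test = np.array([[5.0], [20.0]])  # 20 lies above the training maximum
dmin, dmax = minmax_fit(train)
test_sc = minmax_transform(test, dmin, dmax)  # [[0.0], [3.0]]: 3.0 is outside [-1, 1]
```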
# In[8]:
train
# In[9]:
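# One-step-ahead pairs: each input is the scaled price on day t, the target is the price on day t+1
X_train = train_sc[:-1]
y_train = train_sc[1:]
X_test = test_sc[:-1]
y_test = test_sc[1:]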
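The pairs `train_sc[:-1]` / `train_sc[1:]` use a lookback of one day: yesterday's scaled price is the only input. The sketch below shows a hypothetical helper, `make_windows`, that builds the same pairs and generalizes to a longer lookback. This extension is an assumption for illustration; the original only uses one step.

```python
import numpy as np


def make_windows(series, lookback: int = 1):
    """Build (X, y): each X row holds `lookback` consecutive values, y is the value that follows."""
    values = np.asarray(series, dtype=float).ravel()
    if not 1 <= lookback < len(values):
        raise ValueError("lookback must be at least 1 and shorter than the series")
    n = len(values) - lookback
    X = np.stack([values[i:i + lookback] for i in range(n)])
    y = values[lookback:].reshape(-1, 1)
    return X, y


s = np.arange(6, dtype=float).reshape(-1, 1)  # stand-in for train_sc: [[0], [1], ..., [5]]
X1, y1 = make_windows(s, lookback=1)  # same pairs as s[:-1], s[1:]
X3, y3 = make_windows(s, lookback=3)  # rows [0, 1, 2], [1, 2, 3], [2, 3, 4]
```

For the LSTM, a window of length k would be reshaped to `(samples, k, 1)`, giving k time steps instead of one.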
# In[10]:
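# Baseline: a small fully connected network (1 input -> 12 ReLU units -> 1 linear output)
nn_model = Sequential()
nn_model.add(Dense(12, input_dim=1, activation='relu'))
nn_model.add(Dense(1))
nn_model.compile(loss='mean_squared_error', optimizer='adam')
# Stop once the training loss has not improved for 2 consecutive epochs
early_stop = EarlyStopping(monitor='loss', patience=2, verbose=1)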
history = nn_model.fit(X_train, y_train, epochs=100, batch_size=1, verbose=1, callbacks=[early_stop], shuffle=False)
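`Dense(12, relu)` followed by `Dense(1)` computes ŷ = W2ᵀ·relu(W1ᵀ·x + b1) + b2 for each scaled price x. The NumPy sketch below runs that forward pass with random weights, not the trained ones. It shows the tensor shapes and the parameter count that `nn_model.summary()` should report: 37.

```python
import numpy as np

rng = np.random.default_rng(0)

# Shapes mirror Dense(12, input_dim=1, activation='relu') followed by Dense(1).
W1 = rng.normal(size=(1, 12))
b1 = np.zeros(12)
W2 = rng.normal(size=(12, 1))
b2 = np.zeros(1)


def forward(x: np.ndarray) -> np.ndarray:
    """x has shape (batch, 1); returns predictions of shape (batch, 1)."""
    hidden = np.maximum(0.0, x @ W1 + b1)  # ReLU hidden layer, 12 units
    return hidden @ W2 + b2                # linear output unit


n_params = W1.size + b1.size + W2.size + b2.size  # 12 + 12 + 12 + 1 = 37
y_hat = forward(np.array([[-0.5], [0.0], [0.5]]))
```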
# In[11]:
y_pred_test_nn = nn_model.predict(X_test)
y_train_pred_nn = nn_model.predict(X_train)
print("The R2 score on the Train set is:\t{:0.3f}".format(r2_score(y_train, y_train_pred_nn)))
print("The R2 score on the Test set is:\t{:0.3f}".format(r2_score(y_test, y_pred_test_nn)))
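`r2_score` compares the model's squared error with that of always predicting the mean of `y_true`: R² = 1 - Σ(y - ŷ)² / Σ(y - ȳ)². A score of 1 is perfect, 0 matches the mean predictor, and negative values are worse than the mean. The following is a NumPy sketch of the single-output case. The function name `r2` is hypothetical.

```python
import numpy as np


def r2(y_true, y_pred) -> float:
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 1.0 - ss_res / ss_tot


y = np.array([1.0, 2.0, 3.0, 4.0])
perfect = r2(y, y)                             # 1.0
mean_only = r2(y, np.full_like(y, y.mean()))   # 0.0
worse = r2(y, y[::-1])                         # -3.0: worse than always predicting the mean
```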
# In[12]:
plt.figure(figsize=(10, 6))
plt.plot(y_test, label='True')
plt.plot(y_pred_test_nn, label='NN')
plt.title("NN's Prediction")
plt.xlabel('Observation')
plt.ylabel('Adj Close Scaled')
plt.legend()
plt.show();
# In[13]:
# LSTM model ('lmse' in the variable names is the author's label for the LSTM inputs)
# Keras LSTM layers expect 3-D input of shape (samples, timesteps, features); here timesteps = features = 1
X_train_lmse = X_train.reshape(X_train.shape[0], 1, X_train.shape[1])
X_test_lmse = X_test.reshape(X_test.shape[0], 1, X_test.shape[1])
lstm_model = Sequential()
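# input_shape is (timesteps, features), i.e. the last two dimensions of X_train_lmse
lstm_model.add(LSTM(7, input_shape=(X_train_lmse.shape[1], X_train_lmse.shape[2]), activation='relu', kernel_initializer='lecun_uniform', return_sequences=False))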
lstm_model.add(Dense(1))
lstm_model.compile(loss='mean_squared_error', optimizer='adam')
early_stop = EarlyStopping(monitor='loss', patience=2, verbose=1)
history_lstm_model = lstm_model.fit(X_train_lmse, y_train, epochs=100, batch_size=1, verbose=1, shuffle=False, callbacks=[early_stop])
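The sketch below shows one LSTM time step in NumPy, using the standard gate equations. In Keras, the author's `activation='relu'` replaces the tanh on the candidate and on the cell output. The input, forget, and output gates keep `recurrent_activation`, which is sigmoid by default. The gate order i, f, c, o follows the Keras kernel layout. Weights are random, not trained. The parameter count 4·((1 + 7)·7 + 7) = 252 should match the LSTM row of `lstm_model.summary()`.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def relu(z):
    return np.maximum(0.0, z)


def lstm_step(x, h, c, W, U, b, activation=np.tanh):
    """One LSTM step. x: (batch, n_in); h, c: (batch, units).
    W: (n_in, 4*units), U: (units, 4*units), b: (4*units,); gate blocks ordered i, f, c, o."""
    z = x @ W + h @ U + b
    i, f, g, o = np.split(z, 4, axis=1)
    i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)  # gates use the sigmoid recurrent activation
    c_new = f * c + i * activation(g)             # candidate passes through `activation`
    h_new = o * activation(c_new)                 # so does the cell output
    return h_new, c_new


units, n_in = 7, 1
rng = np.random.default_rng(0)
W = rng.normal(size=(n_in, 4 * units))
U = rng.normal(size=(units, 4 * units))
b = np.zeros(4 * units)

# Same layout as X_train_lmse: (samples, timesteps=1, features=1).
x_seq = np.array([0.1, -0.2, 0.3]).reshape(-1, 1, 1)
h0 = np.zeros((x_seq.shape[0], units))
c0 = np.zeros_like(h0)
h1, c1 = lstm_step(x_seq[:, 0, :], h0, c0, W, U, b, activation=relu)

n_params = W.size + U.size + b.size  # 28 + 196 + 28 = 252
```

In this setup there is only one time step, and the state starts at zero (the default for a non-stateful layer). The term `h @ U` is therefore always zero, so the 196 recurrent weights have no effect. The LSTM effectively acts as a gated feed-forward map of a single input, which helps explain why its scores end up close to the plain network's.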
# In[14]:
y_pred_test_lstm = lstm_model.predict(X_test_lmse)
y_train_pred_lstm = lstm_model.predict(X_train_lmse)
print("The R2 score on the Train set is:\t{:0.3f}".format(r2_score(y_train, y_train_pred_lstm)))
print("The R2 score on the Test set is:\t{:0.3f}".format(r2_score(y_test, y_pred_test_lstm)))
# In[15]:
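# evaluate() returns the test loss, i.e. the MSE on the scaled data (no extra metrics were compiled)
nn_test_mse = nn_model.evaluate(X_test, y_test, batch_size=1)
lstm_test_mse = lstm_model.evaluate(X_test_lmse, y_test, batch_size=1)
print('NN test MSE: %f' % nn_test_mse)
print('LSTM test MSE: %f' % lstm_test_mse)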
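`X_test` is just `y_test` shifted back by one day. A useful yardstick for both MSE values is therefore the naive forecast "tomorrow's price equals today's". With the script's arrays, that baseline is `persistence_mse(test_sc)`, which is the same as the MSE between `y_test` and `X_test`. The following is a hedged NumPy sketch. The helper names are hypothetical, and this baseline is an addition, not part of the original comparison.

```python
import numpy as np


def mse(y_true, y_pred) -> float:
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    return float(np.mean((y_true - y_pred) ** 2))


def persistence_mse(series) -> float:
    """MSE of the naive forecast 'next value = current value' (inputs series[:-1], targets series[1:])."""
    s = np.asarray(series, dtype=float).ravel()
    return mse(s[1:], s[:-1])


toy = np.array([0.0, 1.0, 3.0, 6.0])
baseline = persistence_mse(toy)  # errors 1, 2, 3 -> (1 + 4 + 9) / 3
```

A model that does not beat this number has learned little beyond copying the previous day's price.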
# In[16]:
plt.figure(figsize=(10, 6))
plt.plot(y_test, label='True')
plt.plot(y_pred_test_lstm, label='LSTM')
plt.title("LSTM's Prediction")
plt.xlabel('Observation')
plt.ylabel('Adj Close scaled')
plt.legend()
plt.show()