Traffic Flow Prediction Pitfalls (II): Predicting Traffic Flow with the Simplest LSTM in TensorFlow 2
2022-07-29 04:48:00 【__ Meursault__】
When it comes to time-series prediction, RNNs are probably the first thing that comes to mind, followed by LSTMs. I won't go over how LSTMs work here; there are plenty of articles on that online.
Implementing the prediction with TensorFlow 2.0
TensorFlow 2.0 is really great and really simple: anyone with a pair of hands can do it.
In TensorFlow you only need to call the built-in LSTM layer, as in the following code:
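import tensorflow as tf
from tensorflow.keras.layers import Dense, LSTM, Dropout

model = tf.keras.Sequential([
    LSTM(80, return_sequences=True),  # returns the whole sequence so the next LSTM can consume it
    Dropout(0.2),
    LSTM(80),                         # returns only its final output
    Dropout(0.2),
    Dense(1)                          # one predicted flow value
])
model.compile(optimizer='adam',
              loss='mse')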
This builds a two-layer LSTM with 80 units per layer, adds Dropout after each LSTM layer to reduce overfitting, uses the Adam optimizer, and uses mean squared error (MSE) as the loss function. It really is that simple.
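As a sanity check on the size of this network, the trainable parameters of each layer can be counted by hand. This is a minimal sketch assuming a single input feature per time step (total flow only, as in the data used later); it uses the standard LSTM count of four gates, each with an input kernel, a recurrent kernel and a bias. The helper names `lstm_params` and `dense_params` are hypothetical. Keras' `model.summary()` should report the same numbers once the model has been built (for example after the first call to `fit`).

```python
def lstm_params(units: int, input_dim: int) -> int:
    # 4 gates (input, forget, cell candidate, output), each with an
    # input kernel (input_dim x units), a recurrent kernel (units x units)
    # and a bias vector (units)
    return 4 * (units * (input_dim + units) + units)


def dense_params(units: int, input_dim: int) -> int:
    # weight matrix plus one bias per output unit
    return input_dim * units + units


# Layers of the model above; Dropout layers have no weights.
# Assumption: one input feature per time step (total flow only).
layer_params = {
    "LSTM(80), input_dim=1": lstm_params(80, 1),
    "LSTM(80), input_dim=80": lstm_params(80, 80),
    "Dense(1), input_dim=80": dense_params(1, 80),
}
total = sum(layer_params.values())

for name, count in layer_params.items():
    print(f"{name}: {count}")
print("total:", total)
```

This prints 26240, 51520 and 81 for the three layers, 77841 in total. The second LSTM is about twice as large as the first because its input is the 80-dimensional output of the first layer rather than a single flow value.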
The real work is in data processing. For time-series prediction, the idea is to use the previous n time steps to predict the next one, so the training data should be organized as in the following figure: each sample is a window of n consecutive values paired with the value that comes right after it.
Getting the data into this shape is the hard part.
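The windowing idea can be sketched with NumPy on a toy series. `make_windows` is a hypothetical helper for illustration; it produces inputs of shape (samples, time_step, 1) and one next-step target per window, which is the same pairing the loops later in this post build.

```python
import numpy as np


def make_windows(series, time_step: int):
    """Split a 1-D series into (previous time_step values -> next value) pairs.

    Hypothetical helper for illustration. Returns
      X with shape (samples, time_step, 1): the LSTM input
      y with shape (samples, 1): the value right after each window
    """
    series = np.asarray(series, dtype=np.float64).reshape(-1)
    n_samples = len(series) - time_step
    if n_samples <= 0:
        raise ValueError("series must be longer than time_step")
    X = np.stack([series[i:i + time_step] for i in range(n_samples)])
    y = series[time_step:].reshape(-1, 1)
    return X.reshape(n_samples, time_step, 1), y


flow = np.arange(10, 90, 10)   # toy series: 10, 20, ..., 80
X, y = make_windows(flow, time_step=5)
print(X.shape, y.shape)        # (3, 5, 1) (3, 1)
print(X[0].ravel(), y[0])      # [10. 20. 30. 40. 50.] [60.]
```

With 8 values and a window of 5 there are only 3 samples: a series of length T always yields T - time_step windows.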
The data used below is the UK site data mentioned in the previous article. Other data can be handled the same way.
Baidu Netdisk: https://pan.baidu.com/s/19vKN2eZZPbOg36YEWts4aQ
Extraction code: 4uh7
When importing the data, the column highlighted in red caused an error for reasons I never figured out, so I simply deleted it. That column has no effect on the forecast.
The following code then produces a DataFrame containing the date/time and the traffic flow:
# Raw string so the backslashes in the Windows path are not treated as escape sequences
f = pd.read_csv(r'..\Desktop\AE86.csv')

# Row 2 of the file holds the real column labels; use them as the header
def set_columns():
    columns = []
    for i in f.loc[2]:
        columns.append(i.strip())
    return columns

f.columns = set_columns()
# Drop rows 0-2 (everything up to and including the header row just used)
f.drop([0, 1, 2], inplace=True)

# data holds only the columns we work with
data = pd.DataFrame()
# Add any other columns you want to keep to data here
data['datetime'] = f['Local Date'] + ' ' + f['Local Time']
data['total_flow'] = f['Total Carriageway Flow']
# data['speed'] = f['Speed Value']  # speed is not used in this article
data['datetime'] = pd.to_datetime(data['datetime'])
data['month'] = data['datetime'].apply(lambda date: date.month)
data['day'] = data['datetime'].apply(lambda date: date.day)
data['hour'] = data['datetime'].apply(lambda date: date.hour)
data['minute'] = data['datetime'].apply(lambda date: date.minute)
# Convert the flow column to floating point
data['total_flow'] = np.array(data['total_flow']).astype(np.float64)
The processed data (columns datetime, total_flow, month, day, hour and minute) are shown in the figure below.
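The four `apply(lambda ...)` calls can also be written with pandas' `.dt` accessor, which extracts the date parts for the whole column at once instead of calling a Python lambda per row. A small sketch on made-up timestamps (illustrative values, not taken from the AE86 file):

```python
import pandas as pd

# Illustrative timestamps only; the real ones come from 'Local Date' + 'Local Time'
data = pd.DataFrame({
    "datetime": pd.to_datetime(["2018-01-01 00:00:00", "2018-01-25 13:45:00"]),
})

# Vectorized equivalents of data['datetime'].apply(lambda date: date.month), etc.
data["month"] = data["datetime"].dt.month
data["day"] = data["datetime"].dt.day
data["hour"] = data["datetime"].dt.hour
data["minute"] = data["datetime"].dt.minute

print(data[["month", "day", "hour", "minute"]])
```

The result is the same as with `apply`; the difference only shows up as speed on long series.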
Next, split the data into a training set and a test set and normalize them.
# Position of the first record on January 25. get_loc turns the index label into a
# position: the labels are shifted by the rows dropped above, and iloc needs positions
d25 = data.index.get_loc(data.query('day == 25').index[0])
# Training set: 2211 records, the first three weeks of January 2018
train_set = data.iloc[:d25, 1:2]
# Test set: 669 records, the last week of January 2018
test_set = data.iloc[d25:, 1:2]
# Normalization: fit the scaler on the training set only, then apply it to the test set
sc = MinMaxScaler(feature_range=(0, 1))
train_set_sc = sc.fit_transform(train_set)
test_set_sc = sc.transform(test_set)
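`MinMaxScaler` learns the minimum and maximum from the training set only and reuses them for the test set, so no information from the test week leaks into training; the flip side is that test values outside the training range map outside [0, 1]. A NumPy sketch of the same arithmetic, x_scaled = (x - min_train) / (max_train - min_train), on toy numbers:

```python
import numpy as np

train = np.array([[100.0], [200.0], [300.0]])  # toy training flows
test = np.array([[150.0], [350.0]])            # toy test flows; 350 exceeds the training max

# "fit": the statistics come from the training data only
data_min = train.min(axis=0)
data_range = train.max(axis=0) - data_min


def transform(x):
    return (x - data_min) / data_range


def inverse_transform(x_scaled):
    return x_scaled * data_range + data_min


train_sc = transform(train)           # [[0.], [0.5], [1.]]
test_sc = transform(test)             # [[0.25], [1.25]]: 1.25 is outside [0, 1]
restored = inverse_transform(test_sc)  # back to [[150.], [350.]]
print(train_sc.ravel(), test_sc.ravel(), restored.ravel())
```

The inverse step is what `sc.inverse_transform` does later to turn the model's scaled predictions back into vehicle counts.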
Next, build the LSTM input data. With time_step = 5, the previous 5 time periods are used to predict the next one.
time_step = 5
# Cut the series into windows of time_step values
x_train = []
y_train = []
x_test = []
y_test = []
for i in range(time_step, len(train_set_sc)):
    x_train.append(train_set_sc[i - time_step:i])  # the previous time_step values
    y_train.append(train_set_sc[i:i + 1])          # the value right after them
for i in range(time_step, len(test_set_sc)):
    x_test.append(test_set_sc[i - time_step:i])
    y_test.append(test_set_sc[i:i + 1])
# Shuffle the training samples (optional); re-seeding before each shuffle
# applies the same permutation to x_train and y_train
np.random.seed(7)
np.random.shuffle(x_train)
np.random.seed(7)
np.random.shuffle(y_train)
# Fix TensorFlow's seed so weight initialization is reproducible
tf.random.set_seed(7)
# Convert to NumPy arrays
x_train, y_train = np.array(x_train), np.array(y_train)
x_test, y_test = np.array(x_test), np.array(y_test)
# LSTM input shape: (samples, time_step, features=1)
x_train = np.reshape(x_train, (x_train.shape[0], time_step, 1))
x_test = np.reshape(x_test, (x_test.shape[0], time_step, 1))
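A more explicit way to shuffle while keeping each input window paired with its target is to draw one permutation and index both arrays with it. This is a sketch on toy arrays, not the author's code; note also that Keras' `model.fit` already shuffles the training data each epoch by default (`shuffle=True`), which is one reason this step is optional.

```python
import numpy as np

time_step = 5
# Toy data: 6 windows whose target is simply the last input value + 1
x_train = np.arange(30, dtype=np.float64).reshape(6, time_step, 1)
y_train = x_train[:, -1, :] + 1           # shape (6, 1)

rng = np.random.default_rng(7)            # seed 7 chosen to mirror the post
perm = rng.permutation(len(x_train))      # a single permutation for both arrays
x_shuffled, y_shuffled = x_train[perm], y_train[perm]

# Every (input window, target) pair is still matched after shuffling
assert np.array_equal(x_shuffled[:, -1, :] + 1, y_shuffled)
print(perm)
```

Indexing with one `perm` makes the alignment obvious in the code itself instead of relying on two `shuffle` calls consuming the random stream identically.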
The rest builds the model, makes predictions, analyzes the error and visualizes the results.
The complete code is as follows:
import tensorflow as tf
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Dense, LSTM, Dropout, Flatten
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error
import math
from matplotlib.font_manager import FontProperties  # lets plots use a Chinese font

# Raw string so the backslashes in the Windows path are not treated as escape sequences
f = pd.read_csv(r'..\Desktop\AE86.csv')

# Use row 2 of the file as the column labels
def set_columns():
    columns = []
    for i in f.loc[2]:
        columns.append(i.strip())
    return columns

f.columns = set_columns()
f.drop([0, 1, 2], inplace=True)
# data holds only the columns we work with
data = pd.DataFrame()
data['datetime'] = f['Local Date'] + ' ' + f['Local Time']
data['total_flow'] = f['Total Carriageway Flow']
# data['speed'] = f['Speed Value']
data['datetime'] = pd.to_datetime(data['datetime'])
data['month'] = data['datetime'].apply(lambda date: date.month)
data['day'] = data['datetime'].apply(lambda date: date.day)
data['hour'] = data['datetime'].apply(lambda date: date.hour)
data['minute'] = data['datetime'].apply(lambda date: date.minute)
# Convert the flow column to floating point
data['total_flow'] = np.array(data['total_flow']).astype(np.float64)
# Position (not index label) of the first record on January 25, for use with iloc
d25 = data.index.get_loc(data.query('day == 25').index[0])
# Training set: 2211 records, the first three weeks of January 2018
train_set = data.iloc[:d25, 1:2]
# Test set: 669 records, the last week of January 2018
test_set = data.iloc[d25:, 1:2]
# Normalization (fit on the training set only)
sc = MinMaxScaler(feature_range=(0, 1))
train_set_sc = sc.fit_transform(train_set)
test_set_sc = sc.transform(test_set)
# Cut the series into windows of time_step values
time_step = 5
x_train = []
y_train = []
x_test = []
y_test = []
for i in range(time_step, len(train_set_sc)):
    x_train.append(train_set_sc[i - time_step:i])
    y_train.append(train_set_sc[i:i + 1])
for i in range(time_step, len(test_set_sc)):
    x_test.append(test_set_sc[i - time_step:i])
    y_test.append(test_set_sc[i:i + 1])
# Shuffle the training samples (optional)
np.random.seed(7)
np.random.shuffle(x_train)
np.random.seed(7)
np.random.shuffle(y_train)
tf.random.set_seed(7)
# Convert to NumPy arrays
x_train, y_train = np.array(x_train), np.array(y_train)
x_test, y_test = np.array(x_test), np.array(y_test)
x_train = np.reshape(x_train, (x_train.shape[0], time_step, 1))
x_test = np.reshape(x_test, (x_test.shape[0], time_step, 1))
# LSTM model
model = tf.keras.Sequential([
    LSTM(80, return_sequences=True),
    Dropout(0.2),
    LSTM(80),
    Dropout(0.2),
    Dense(1)
])
model.compile(optimizer='adam',
              loss='mse')
# Train the model; epochs and batch_size can be changed as you like
history = model.fit(x_train, y_train,
                    epochs=5,
                    batch_size=32,  # the Keras default, stated explicitly
                    validation_data=(x_test, y_test))
# Predict on the test set
pre_flow = model.predict(x_test)
# Undo the normalization to get flows in vehicles
pre_flow = sc.inverse_transform(pre_flow)
real_flow = sc.inverse_transform(y_test.reshape(y_test.shape[0], 1))
# Compute the errors (sklearn expects y_true first, then y_pred)
mse = mean_squared_error(real_flow, pre_flow)
rmse = math.sqrt(mse)
mae = mean_absolute_error(real_flow, pre_flow)
print('Mean squared error (MSE):', mse)
print('Root mean squared error (RMSE):', rmse)
print('Mean absolute error (MAE):', mae)
# Plot the predictions
# SimSun (Song typeface), size 15; only needed for Chinese labels, and the path is Windows-specific
font_set = FontProperties(fname=r"C:\Windows\Fonts\simsun.ttc", size=15)
plt.figure(figsize=(15, 10))
plt.plot(real_flow, label='Real_Flow', color='r')
plt.plot(pre_flow, label='Pre_Flow')
plt.xlabel('Test sequence', fontproperties=font_set)
plt.ylabel('Traffic flow (vehicles)', fontproperties=font_set)
plt.legend()
# Save the prediction plot (raw string, otherwise '\123' would be read as an octal escape)
# plt.savefig(r'...\Desktop\123.jpg')
plt.show()
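The three error measures printed by the script follow directly from their definitions: MSE = mean((y_true - y_pred)^2), RMSE = sqrt(MSE), MAE = mean(|y_true - y_pred|). A NumPy sketch on made-up numbers, useful for checking the sklearn output by hand:

```python
import numpy as np

real_flow = np.array([[100.0], [120.0], [90.0]])  # made-up observed flows
pre_flow = np.array([[110.0], [115.0], [90.0]])   # made-up predictions

err = real_flow - pre_flow      # [-10, 5, 0]
mse = np.mean(err ** 2)         # (100 + 25 + 0) / 3
rmse = np.sqrt(mse)             # same units as the flow itself
mae = np.mean(np.abs(err))      # (10 + 5 + 0) / 3 = 5.0

print(mse, rmse, mae)
```

RMSE and MAE are both in vehicles, which makes them easier to read than MSE; RMSE weights the large misses more heavily, so RMSE noticeably above MAE points to a few big errors (for example around the morning and evening peaks).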
The code above is the simplest version: it uses traffic flow only and predicts for a single detector site.
You could also feed speed, occupancy and other information into the model to predict flow. Doing this properly takes real work, but if you just want something quick, here is one idea:
Window the other features the same way with time_step = 5, feed them directly into the model, and add a Flatten layer (which flattens all the data into one dimension) at the end of the model. Then you can write "This paper considers traffic flow, speed, lane occupancy and other factors, and shows significant improvements over previous work."
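For the multi-feature idea, the windowing is the same except that each time step now carries several values, so the input becomes (samples, time_step, n_features) while the target stays the next flow value. A NumPy sketch with made-up flow, speed and occupancy columns (the column order and values are assumptions for illustration):

```python
import numpy as np

time_step = 5
n_steps = 8
# Made-up per-interval features; columns are flow, speed, occupancy
flow = np.linspace(100, 170, n_steps)
speed = np.linspace(60, 53, n_steps)
occupancy = np.linspace(0.05, 0.12, n_steps)
features = np.column_stack([flow, speed, occupancy])  # shape (8, 3)

X, y = [], []
for i in range(time_step, len(features)):
    X.append(features[i - time_step:i])  # all features over the previous 5 steps
    y.append(features[i, 0])             # target: flow only, at the next step
X = np.array(X)                          # shape (3, 5, 3)
y = np.array(y).reshape(-1, 1)           # shape (3, 1)
print(X.shape, y.shape)
```

Keras LSTM layers accept this 3-D input directly, with the last dimension equal to the number of features. In practice each feature would also need its own scaling before windowing, since speed and occupancy live on very different ranges from flow.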