当前位置:网站首页>25.时间序列预测实战
25.时间序列预测实战
2022-08-04 07:03:00 【派大星的最爱海绵宝宝】
时间序列预测实战
[b,50,1],b为1时,可以理解为只送入一条曲线,每一条曲线有50点的数据,每个点数据都是实数。
start = np.random.randint(3, size=1)[0]
time_step = np.linspace(start, start + 10, num_time_steps)
data = np.sin(time_step)
data = data.reshape(num_time_steps, 1)
x = torch.tensor(data[:-1]).float().view(1, num_time_steps - 1, 1)
y = torch.tensor(data[1:]).float().view(1, num_time_steps - 1, 1)
start是随机的,是每次开始的起点。
我们需要完成的功能是,对于一条曲线,给出红色部分时,要求预测出蓝色部分曲线。
x是给定的0到48的部分,y需要预测出1到49的部分。
Train
out[b,seq_len,h]
h[b,1,h]
hidden_prev是h0,最开始是一个batch,一层,h是10。
我们将output和y之间进行一个MSE求误差,根据这个误差进行网络的更新。
hidden_prev=torch.zeros(1,1,hidden_size)
for iter in range(6000):
start = np.random.randint(3, size=1)[0]
time_step = np.linspace(start, start + 10, num_time_steps)
data = np.sin(time_step)
data = data.reshape(num_time_steps, 1)
x = torch.tensor(data[:-1]).float().view(1, num_time_steps - 1, 1)
y = torch.tensor(data[1:]).float().view(1, num_time_steps - 1, 1)
output,hidden_prev=model(x,hidden_prev)
hidden_prev=hidden_prev.detach()
loss=criteon(output,y)
model.zero_grad()
loss.backward()
optimizer.step()
if iter %100 ==0:
print("Iteration:{} loss:{}".format(iter,loss.item()))
Test
先将预测值做一个空的数组。
x[1,seq,1]。
每次的input等于pred出来的点,每次只画一个点,最后进行串联。
predictions=[]
input=x[:,0,:]
for _ in range(x.shape[1]):
input=input.view(1,1,1)
(pred,hidden_prev)=model(input,hidden_prev)
input=pred
predictions.append(pred.detach().numpy().ravel()[0])
结果
代码
import numpy
import torch
from torch import nn
from torch import optim
import numpy as np
from matplotlib import pyplot as plt
num_time_steps = 50
input_size=1
hidden_size=16
output_size=1
lr = 0.01
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.rnn=nn.RNN(
input_size=input_size,
hidden_size=hidden_size,
num_layers=1,
batch_first=True
)
self.linear=nn.Linear(hidden_size,output_size)
def forward(self,x,hidden_prev):
out,hidden_prev=self.rnn(x,hidden_prev)
#[1,seq,h]->[seq,h]
out=out.view(-1,hidden_size)
out=self.linear(out) #[seq,h]->[seq,1]
out=out.unsqueeze(dim=0) #->[1,seq,-1]
return out,hidden_prev
def main():
model=Net()
criteon=nn.MSELoss()
optimizer=optim.Adam(model.parameters(),lr)
hidden_prev=torch.zeros(1,1,hidden_size)
for iter in range(6000):
start = np.random.randint(3, size=1)[0]
time_step = np.linspace(start, start + 10, num_time_steps)
data = np.sin(time_step)
data = data.reshape(num_time_steps, 1)
x = torch.tensor(data[:-1]).float().view(1, num_time_steps - 1, 1)
y = torch.tensor(data[1:]).float().view(1, num_time_steps - 1, 1)
output,hidden_prev=model(x,hidden_prev)
hidden_prev=hidden_prev.detach()
loss=criteon(output,y)
model.zero_grad()
loss.backward()
optimizer.step()
if iter %100 ==0:
print("Iteration:{} loss:{}".format(iter,loss.item()))
predictions=[]
input=x[:,0,:]
for _ in range(x.shape[1]):
input=input.view(1,1,1)
(pred,hidden_prev)=model(input,hidden_prev)
input=pred
predictions.append(pred.detach().numpy().ravel()[0])
x=x.data.numpy().ravel()
y=y.data.numpy()
plt.scatter(time_step[:-1],x.ravel(),s=90)
plt.plot(time_step[:-1],x.ravel())
plt.scatter(time_step[1:],predictions)
plt.show()
if __name__ == '__main__':
main()
边栏推荐
猜你喜欢
随机推荐
Mac安装PHP开发环境
[想要访问若依后台]若依框架报错401请求访问:error认证失败,无法访问系统资源
Error EPERM operation not permitted, mkdir ‘Dsoftwarenodejsnode_cache_cacach两种解决办法
SQL如何从字符串截取指定字符(LEFT、MID、RIGHT三大函数)
Promise.all 使用方法
分布式计算实验3 基于PRC的书籍信息管理系统
CSDN21天学习挑战赛——day1 正则表达式大总结
C语言实现-华为太空人手表
SQL存储过程详解
Detailed ResNet: What problem is ResNet solving?
Distributed Computing MapReduce | Spark Experiment
QT + msvc2017编译器
一天搞定JDBC01:连接数据库并执行sql语句
中断和异常的处理与抢占式多任务
反射与枚举
powershell和cmd对比
在线问题反馈模块实战(十八):实现excel台账文件记录批量导入功能
【论文笔记】—低照度图像增强—Supervised—RetinexNet—2018-BMVC
C语言指针
将回调函数转为Flow