当前位置:网站首页>25.时间序列预测实战
25.时间序列预测实战
2022-08-04 07:03:00 【派大星的最爱海绵宝宝】
时间序列预测实战
[b,50,1],b为1时,可以理解为只送入一条曲线,每一条曲线有50点的数据,每个点数据都是实数。
start = np.random.randint(3, size=1)[0]
time_step = np.linspace(start, start + 10, num_time_steps)
data = np.sin(time_step)
data = data.reshape(num_time_steps, 1)
x = torch.tensor(data[:-1]).float().view(1, num_time_steps - 1, 1)
y = torch.tensor(data[1:]).float().view(1, num_time_steps - 1, 1)
start是随机的,是每次开始的起点。
我们需要完成的功能是,对于一条曲线,给出红色部分时,要求预测出蓝色部分曲线。
x是给定的0到48的部分,y需要预测出1到49的部分。
Train
out[b,seq_len,h]
h[b,1,h]
hidden_prev是h0,最开始是一个batch,一层,h是10。
我们将output和y之间进行一个MSE求误差,根据这个误差进行网络的更新。
hidden_prev=torch.zeros(1,1,hidden_size)
for iter in range(6000):
start = np.random.randint(3, size=1)[0]
time_step = np.linspace(start, start + 10, num_time_steps)
data = np.sin(time_step)
data = data.reshape(num_time_steps, 1)
x = torch.tensor(data[:-1]).float().view(1, num_time_steps - 1, 1)
y = torch.tensor(data[1:]).float().view(1, num_time_steps - 1, 1)
output,hidden_prev=model(x,hidden_prev)
hidden_prev=hidden_prev.detach()
loss=criteon(output,y)
model.zero_grad()
loss.backward()
optimizer.step()
if iter %100 ==0:
print("Iteration:{} loss:{}".format(iter,loss.item()))
Test
先将预测值做一个空的数组。
x[1,seq,1]。
每次的input等于pred出来的点,每次只画一个点,最后进行串联。
predictions=[]
input=x[:,0,:]
for _ in range(x.shape[1]):
input=input.view(1,1,1)
(pred,hidden_prev)=model(input,hidden_prev)
input=pred
predictions.append(pred.detach().numpy().ravel()[0])
结果


代码
import numpy
import torch
from torch import nn
from torch import optim
import numpy as np
from matplotlib import pyplot as plt
num_time_steps = 50
input_size=1
hidden_size=16
output_size=1
lr = 0.01
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.rnn=nn.RNN(
input_size=input_size,
hidden_size=hidden_size,
num_layers=1,
batch_first=True
)
self.linear=nn.Linear(hidden_size,output_size)
def forward(self,x,hidden_prev):
out,hidden_prev=self.rnn(x,hidden_prev)
#[1,seq,h]->[seq,h]
out=out.view(-1,hidden_size)
out=self.linear(out) #[seq,h]->[seq,1]
out=out.unsqueeze(dim=0) #->[1,seq,-1]
return out,hidden_prev
def main():
model=Net()
criteon=nn.MSELoss()
optimizer=optim.Adam(model.parameters(),lr)
hidden_prev=torch.zeros(1,1,hidden_size)
for iter in range(6000):
start = np.random.randint(3, size=1)[0]
time_step = np.linspace(start, start + 10, num_time_steps)
data = np.sin(time_step)
data = data.reshape(num_time_steps, 1)
x = torch.tensor(data[:-1]).float().view(1, num_time_steps - 1, 1)
y = torch.tensor(data[1:]).float().view(1, num_time_steps - 1, 1)
output,hidden_prev=model(x,hidden_prev)
hidden_prev=hidden_prev.detach()
loss=criteon(output,y)
model.zero_grad()
loss.backward()
optimizer.step()
if iter %100 ==0:
print("Iteration:{} loss:{}".format(iter,loss.item()))
predictions=[]
input=x[:,0,:]
for _ in range(x.shape[1]):
input=input.view(1,1,1)
(pred,hidden_prev)=model(input,hidden_prev)
input=pred
predictions.append(pred.detach().numpy().ravel()[0])
x=x.data.numpy().ravel()
y=y.data.numpy()
plt.scatter(time_step[:-1],x.ravel(),s=90)
plt.plot(time_step[:-1],x.ravel())
plt.scatter(time_step[1:],predictions)
plt.show()
if __name__ == '__main__':
main()
边栏推荐
猜你喜欢

Verilog“七宗罪”

简析强制缓存和协商缓存

data:image/jpg; base64 format data is converted to image

Distributed Computing Experiment 3 PRC-based Book Information Management System

mysql基础(4)

一天学会JDBC03:Statement的用法

在线问题反馈模块实战(十八):实现excel台账文件记录批量导入功能

Detailed ResNet: What problem is ResNet solving?

两日总结四

MMDeploy部署实战系列【第四章】:onnx,tensorrt模型推理
随机推荐
ExoPlayer添加Ffmpeg扩展实现软解功能
两日总结六
串口监听 - 软件方案
带你了解一下PHP搭建的电商商城系统
Redis非关系型数据库
【学习笔记】AGC036
专题讲座7 计算几何 学习心得
redis stream 实现消息队列
The national vocational skills contest competition of network security emergency response
The sorting algorithm including selection, bubble, and insertion
2022的七夕,奉上7个精美的表白代码,同时教大家改源码快速自用
unity3d-Animation&&Animator接口(基本使用)
一天学会JDBC03:Statement的用法
npm包发布与迭代
分布式计算实验3 基于PRC的书籍信息管理系统
CSDN21天学习挑战赛——day1 正则表达式大总结
(19)[系统调用]SSTD hook 阻止关闭
MySQL外键(详解)
Error occurred while trying to proxy request项目突然起不来了
Error ER_NOT_SUPPORTED_AUTH_MODE Client does not support authentication protocol requested by serv