当前位置:网站首页>25.时间序列预测实战
25.时间序列预测实战
2022-08-04 07:03:00 【派大星的最爱海绵宝宝】
时间序列预测实战
[b,50,1],b为1时,可以理解为只送入一条曲线,每一条曲线有50点的数据,每个点数据都是实数。
start = np.random.randint(3, size=1)[0]
time_step = np.linspace(start, start + 10, num_time_steps)
data = np.sin(time_step)
data = data.reshape(num_time_steps, 1)
x = torch.tensor(data[:-1]).float().view(1, num_time_steps - 1, 1)
y = torch.tensor(data[1:]).float().view(1, num_time_steps - 1, 1)
start是随机的,是每次开始的起点。
我们需要完成的功能是,对于一条曲线,给出红色部分时,要求预测出蓝色部分曲线。
x是给定的0到48的部分,y需要预测出1到49的部分。
Train
out[b,seq_len,h]
h[b,1,h]
hidden_prev是h0,最开始是一个batch,一层,h是10。
我们将output和y之间进行一个MSE求误差,根据这个误差进行网络的更新。
hidden_prev=torch.zeros(1,1,hidden_size)
for iter in range(6000):
start = np.random.randint(3, size=1)[0]
time_step = np.linspace(start, start + 10, num_time_steps)
data = np.sin(time_step)
data = data.reshape(num_time_steps, 1)
x = torch.tensor(data[:-1]).float().view(1, num_time_steps - 1, 1)
y = torch.tensor(data[1:]).float().view(1, num_time_steps - 1, 1)
output,hidden_prev=model(x,hidden_prev)
hidden_prev=hidden_prev.detach()
loss=criteon(output,y)
model.zero_grad()
loss.backward()
optimizer.step()
if iter %100 ==0:
print("Iteration:{} loss:{}".format(iter,loss.item()))
Test
先将预测值做一个空的数组。
x[1,seq,1]。
每次的input等于pred出来的点,每次只画一个点,最后进行串联。
predictions=[]
input=x[:,0,:]
for _ in range(x.shape[1]):
input=input.view(1,1,1)
(pred,hidden_prev)=model(input,hidden_prev)
input=pred
predictions.append(pred.detach().numpy().ravel()[0])
结果


代码
import numpy
import torch
from torch import nn
from torch import optim
import numpy as np
from matplotlib import pyplot as plt
num_time_steps = 50
input_size=1
hidden_size=16
output_size=1
lr = 0.01
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.rnn=nn.RNN(
input_size=input_size,
hidden_size=hidden_size,
num_layers=1,
batch_first=True
)
self.linear=nn.Linear(hidden_size,output_size)
def forward(self,x,hidden_prev):
out,hidden_prev=self.rnn(x,hidden_prev)
#[1,seq,h]->[seq,h]
out=out.view(-1,hidden_size)
out=self.linear(out) #[seq,h]->[seq,1]
out=out.unsqueeze(dim=0) #->[1,seq,-1]
return out,hidden_prev
def main():
model=Net()
criteon=nn.MSELoss()
optimizer=optim.Adam(model.parameters(),lr)
hidden_prev=torch.zeros(1,1,hidden_size)
for iter in range(6000):
start = np.random.randint(3, size=1)[0]
time_step = np.linspace(start, start + 10, num_time_steps)
data = np.sin(time_step)
data = data.reshape(num_time_steps, 1)
x = torch.tensor(data[:-1]).float().view(1, num_time_steps - 1, 1)
y = torch.tensor(data[1:]).float().view(1, num_time_steps - 1, 1)
output,hidden_prev=model(x,hidden_prev)
hidden_prev=hidden_prev.detach()
loss=criteon(output,y)
model.zero_grad()
loss.backward()
optimizer.step()
if iter %100 ==0:
print("Iteration:{} loss:{}".format(iter,loss.item()))
predictions=[]
input=x[:,0,:]
for _ in range(x.shape[1]):
input=input.view(1,1,1)
(pred,hidden_prev)=model(input,hidden_prev)
input=pred
predictions.append(pred.detach().numpy().ravel()[0])
x=x.data.numpy().ravel()
y=y.data.numpy()
plt.scatter(time_step[:-1],x.ravel(),s=90)
plt.plot(time_step[:-1],x.ravel())
plt.scatter(time_step[1:],predictions)
plt.show()
if __name__ == '__main__':
main()
边栏推荐
- 七夕情人节:中英文祝福短信送给你
- FCN - the originator of semantic segmentation (based on tf-Kersa reproduction code)
- Verilog“七宗罪”
- 无人驾驶运用了什么技术,无人驾驶技术是
- Error EPERM operation not permitted, mkdir ‘Dsoftwarenodejsnode_cache_cacach两种解决办法
- 在线问题反馈模块实战(十八):实现excel台账文件记录批量导入功能
- Secondary network security competition C module MS17-010 batch scanning
- 2022的七夕,奉上7个精美的表白代码,同时教大家改源码快速自用
- 分布式计算实验2 线程池
- 【愚公系列】2022年07月 Go教学课程 027-深拷贝和浅拷贝
猜你喜欢
随机推荐
The sorting algorithm including selection, bubble, and insertion
RT-Thread Studio学习(十一)IIC
ContrstrainLayout的动画之ConstraintSet
FCN - the originator of semantic segmentation (based on tf-Kersa reproduction code)
MMDeploy部署实战系列【第四章】:onnx,tensorrt模型推理
MySQL内存淘汰策略
一天学会JDBC04:ResultSet的用法
Lightweight Backbone VGNetG Achieves "No Choice, All" Lightweight Backbone Network
【深度学习实践(二)】上手手写数字识别
将回调函数转为Flow
错误记录:TypeError: object() takes no parameters
unity 循环选择器
带你了解一下PHP搭建的电商商城系统
分布式计算MapReduce | Spark实验
MAML principle explanation and code implementation
MySQL - Row size too large (> 8126). Changing some columns to TEXT or BLOB
电商系统PC商城模块介绍
七牛云上传图片和本地上传
Praat:语音标注工具【保存为TextGrid文件】
【愚公系列】2022年07月 Go教学课程 027-深拷贝和浅拷贝









