当前位置:网站首页>25.时间序列预测实战
25.时间序列预测实战
2022-08-04 07:03:00 【派大星的最爱海绵宝宝】
时间序列预测实战
[b,50,1],b为1时,可以理解为只送入一条曲线,每一条曲线有50点的数据,每个点数据都是实数。
start = np.random.randint(3, size=1)[0]
time_step = np.linspace(start, start + 10, num_time_steps)
data = np.sin(time_step)
data = data.reshape(num_time_steps, 1)
x = torch.tensor(data[:-1]).float().view(1, num_time_steps - 1, 1)
y = torch.tensor(data[1:]).float().view(1, num_time_steps - 1, 1)
start是随机的,是每次开始的起点。
我们需要完成的功能是,对于一条曲线,给出红色部分时,要求预测出蓝色部分曲线。
x是给定的0到48的部分,y需要预测出1到49的部分。
Train
out[b,seq_len,h]
h[b,1,h]
hidden_prev是h0,最开始是一个batch,一层,h是10。
我们将output和y之间进行一个MSE求误差,根据这个误差进行网络的更新。
hidden_prev=torch.zeros(1,1,hidden_size)
for iter in range(6000):
start = np.random.randint(3, size=1)[0]
time_step = np.linspace(start, start + 10, num_time_steps)
data = np.sin(time_step)
data = data.reshape(num_time_steps, 1)
x = torch.tensor(data[:-1]).float().view(1, num_time_steps - 1, 1)
y = torch.tensor(data[1:]).float().view(1, num_time_steps - 1, 1)
output,hidden_prev=model(x,hidden_prev)
hidden_prev=hidden_prev.detach()
loss=criteon(output,y)
model.zero_grad()
loss.backward()
optimizer.step()
if iter %100 ==0:
print("Iteration:{} loss:{}".format(iter,loss.item()))
Test
先将预测值做一个空的数组。
x[1,seq,1]。
每次的input等于pred出来的点,每次只画一个点,最后进行串联。
predictions=[]
input=x[:,0,:]
for _ in range(x.shape[1]):
input=input.view(1,1,1)
(pred,hidden_prev)=model(input,hidden_prev)
input=pred
predictions.append(pred.detach().numpy().ravel()[0])
结果
代码
import numpy
import torch
from torch import nn
from torch import optim
import numpy as np
from matplotlib import pyplot as plt
num_time_steps = 50
input_size=1
hidden_size=16
output_size=1
lr = 0.01
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.rnn=nn.RNN(
input_size=input_size,
hidden_size=hidden_size,
num_layers=1,
batch_first=True
)
self.linear=nn.Linear(hidden_size,output_size)
def forward(self,x,hidden_prev):
out,hidden_prev=self.rnn(x,hidden_prev)
#[1,seq,h]->[seq,h]
out=out.view(-1,hidden_size)
out=self.linear(out) #[seq,h]->[seq,1]
out=out.unsqueeze(dim=0) #->[1,seq,-1]
return out,hidden_prev
def main():
model=Net()
criteon=nn.MSELoss()
optimizer=optim.Adam(model.parameters(),lr)
hidden_prev=torch.zeros(1,1,hidden_size)
for iter in range(6000):
start = np.random.randint(3, size=1)[0]
time_step = np.linspace(start, start + 10, num_time_steps)
data = np.sin(time_step)
data = data.reshape(num_time_steps, 1)
x = torch.tensor(data[:-1]).float().view(1, num_time_steps - 1, 1)
y = torch.tensor(data[1:]).float().view(1, num_time_steps - 1, 1)
output,hidden_prev=model(x,hidden_prev)
hidden_prev=hidden_prev.detach()
loss=criteon(output,y)
model.zero_grad()
loss.backward()
optimizer.step()
if iter %100 ==0:
print("Iteration:{} loss:{}".format(iter,loss.item()))
predictions=[]
input=x[:,0,:]
for _ in range(x.shape[1]):
input=input.view(1,1,1)
(pred,hidden_prev)=model(input,hidden_prev)
input=pred
predictions.append(pred.detach().numpy().ravel()[0])
x=x.data.numpy().ravel()
y=y.data.numpy()
plt.scatter(time_step[:-1],x.ravel(),s=90)
plt.plot(time_step[:-1],x.ravel())
plt.scatter(time_step[1:],predictions)
plt.show()
if __name__ == '__main__':
main()
边栏推荐
- 一天学会JDBC06:PrepaerdStatemtnt
- unity 循环选择器
- C语言指针
- ContrstrainLayout的动画之ConstraintSet
- MySQL - Row size too large (> 8126). Changing some columns to TEXT or BLOB
- 错误记录:TypeError: object() takes no parameters
- 有人试过用NPGsql驱动连接openGauss开发应用的吗?
- 最强分布式锁工具:Redisson
- Lightweight Backbone VGNetG Achieves "No Choice, All" Lightweight Backbone Network
- 一天搞定JDBC02:开启事务
猜你喜欢
有趣的USB接口和颜色分类
中职网络安全竞赛C模块MS17-010批量扫描
RT-Thread Studio学习(十一)IIC
【愚公系列】2022年07月 Go教学课程 027-深拷贝和浅拷贝
千万级别的表分页查询非常慢,怎么办?
全国职业院校技能大赛网络安全竞赛之应急响应
[想要访问若依后台]若依框架报错401请求访问:error认证失败,无法访问系统资源
babylon 里面加gltf 模型
FCN - the originator of semantic segmentation (based on tf-Kersa reproduction code)
The national vocational skills contest competition of network security emergency response
随机推荐
在线问题反馈模块实战(十八):实现excel台账文件记录批量导入功能
unity 循环选择器
Distributed Computing Experiment 3 PRC-based Book Information Management System
【我想要老婆】
10个程序员可以接私活的平台和一些建议,赚麻...
一天搞定JDBC02:开启事务
npm包发布与迭代
分布式计算实验1 负载均衡
GIS数据与CAD数据间带属性字段互相转换还原工具,解决ArcGIS等软件进行GIS数据转CAD数据无法保留属性字段问题
Activiti 工作流引擎 详解
有趣的USB接口和颜色分类
【学习笔记】状压dp
Praat:语音标注工具【保存为TextGrid文件】
轻量化Backbone VGNetG成就“不做选择,全都要”轻量化主干网络
SQL去重的三种方法汇总
关于我写的循环遍历
[想要访问若依后台]若依框架报错401请求访问:error认证失败,无法访问系统资源
CSRF和SSRF漏洞
中断和异常的处理与抢占式多任务
两日总结六