Implementing a regression model with an RNN
2022-06-12 06:06:00 【Singing under the hedge】
An RNN regression model: use the sin curve as input to predict the cos curve.
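Each training step feeds the network one short window of the two curves. The following is a minimal sketch of how one (input, target) window could be built and shaped, mirroring the training loop in the script. `make_window` is a hypothetical helper introduced only for illustration; it is not part of the original script.

```python
import numpy as np
import torch

TIME_STEP = 10  # same window length as the training script


def make_window(step: int):
    """Hypothetical helper: return (x, y) tensors for training step `step`.

    Window `step` covers [step*pi, (step+1)*pi) sampled at TIME_STEP points;
    x holds the sin values (input), y holds the cos values (target).
    """
    start, end = step * np.pi, (step + 1) * np.pi
    t = np.linspace(start, end, TIME_STEP, dtype=np.float32, endpoint=False)
    # Add batch and feature axes: (time_step,) -> (batch=1, time_step, input_size=1)
    x = torch.from_numpy(np.sin(t)[np.newaxis, :, np.newaxis])
    y = torch.from_numpy(np.cos(t)[np.newaxis, :, np.newaxis])
    return x, y


x, y = make_window(0)
print(x.shape, y.shape)  # torch.Size([1, 10, 1]) torch.Size([1, 10, 1])
```

Because `endpoint=False`, window k ends exactly where window k+1 begins, so the hidden state carried from one iteration to the next continues along the same curve.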
1. Code
""" View more, visit my tutorial page: https://mofanpy.com/tutorials/ My Youtube Channel: https://www.youtube.com/user/MorvanZhou Dependencies: torch: 0.4 matplotlib numpy """
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
# torch.manual_seed(1) # reproducible
# Hyper Parameters
TIME_STEP = 10 # rnn time step
INPUT_SIZE = 1 # rnn input size
LR = 0.02 # learning rate
# show data
steps = np.linspace(0, np.pi * 2, 100, dtype=np.float32)  # float32 for converting to torch FloatTensor
x_np = np.sin(steps)
y_np = np.cos(steps)
plt.plot(steps, y_np, 'r-', label='target (cos)')
plt.plot(steps, x_np, 'b-', label='input (sin)')
plt.legend(loc='best')
plt.show()
class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()
        self.rnn = nn.RNN(
            input_size=INPUT_SIZE,
            hidden_size=32,     # rnn hidden units
            num_layers=1,       # number of rnn layers
            batch_first=True,   # input & output have batch size as the first dimension, e.g. (batch, time_step, input_size)
        )
        self.out = nn.Linear(32, 1)

    def forward(self, x, h_state):
        # x (batch, time_step, input_size)
        # h_state (n_layers, batch, hidden_size)
        # r_out (batch, time_step, hidden_size)
        r_out, h_state = self.rnn(x, h_state)
        outs = []    # save all predictions
        for time_step in range(r_out.size(1)):    # calculate output for each time step
            outs.append(self.out(r_out[:, time_step, :]))
        return torch.stack(outs, dim=1), h_state

        # instead, for simplicity, you can replace the code above with the following
        # (reshape rather than view, since r_out is not guaranteed to be contiguous)
        # r_out = r_out.reshape(-1, 32)
        # outs = self.out(r_out)
        # outs = outs.view(-1, TIME_STEP, 1)
        # return outs, h_state

        # or even simpler, since nn.Linear accepts inputs of any dimension
        # and returns outputs with the same shape except for the last dimension
        # outs = self.out(r_out)
        # return outs, h_state
rnn = RNN()
print(rnn)
optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)  # optimize all rnn parameters
loss_func = nn.MSELoss()
h_state = None # for initial hidden state
plt.figure(1, figsize=(12, 5))
plt.ion() # continuously plot
for step in range(100):
    start, end = step * np.pi, (step + 1) * np.pi   # time range
    # use sin to predict cos
    steps = np.linspace(start, end, TIME_STEP, dtype=np.float32,
                        endpoint=False)  # float32 for converting to torch FloatTensor
    x_np = np.sin(steps)
    y_np = np.cos(steps)

    x = torch.from_numpy(x_np[np.newaxis, :, np.newaxis])    # shape (batch, time_step, input_size)
    y = torch.from_numpy(y_np[np.newaxis, :, np.newaxis])

    prediction, h_state = rnn(x, h_state)   # rnn output
    # !! next step is important !!
    h_state = h_state.detach()   # repack the hidden state, break the connection from the last iteration

    loss = loss_func(prediction, y)         # calculate loss
    optimizer.zero_grad()                   # clear gradients for this training step
    loss.backward()                         # backpropagation, compute gradients
    optimizer.step()                        # apply gradients

    # plotting
    plt.plot(steps, y_np.flatten(), 'r-')
    plt.plot(steps, prediction.detach().numpy().flatten(), 'b-')
    plt.draw()
    plt.pause(0.05)
plt.ioff()
plt.show()
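The comments at the end of `RNN.forward` say the per-time-step loop can be replaced by a single call to `self.out`, because `nn.Linear` operates on the last dimension of an input of any rank. A minimal sketch that checks this on random data; the seed, the batch size of 4 and the random inputs are arbitrary choices for illustration, not from the original post. `reshape` is used instead of `view` because, with `batch_first=True`, `r_out` is not guaranteed to be contiguous.

```python
import torch
from torch import nn

torch.manual_seed(0)  # fixed seed so the random weights are repeatable

HIDDEN_SIZE = 32  # matches hidden_size in the RNN class
rnn_layer = nn.RNN(input_size=1, hidden_size=HIDDEN_SIZE, num_layers=1, batch_first=True)
out_layer = nn.Linear(HIDDEN_SIZE, 1)

x = torch.randn(4, 10, 1)  # (batch, time_step, input_size); random stand-in data
r_out, h_state = rnn_layer(x, None)  # r_out: (batch, time_step, hidden_size)

# Version 1: loop over time steps and stack, as RNN.forward does
looped = torch.stack([out_layer(r_out[:, t, :]) for t in range(r_out.size(1))], dim=1)

# Version 2: nn.Linear acts on the last dimension, so one call covers every step
direct = out_layer(r_out)

# Version 3: flatten to 2-D, apply the layer, then restore (batch, time_step, 1)
viewed = out_layer(r_out.reshape(-1, HIDDEN_SIZE)).reshape(x.size(0), x.size(1), 1)

print(looped.shape)  # torch.Size([4, 10, 1])
print(torch.allclose(looped, direct, atol=1e-6), torch.allclose(looped, viewed, atol=1e-6))  # True True
```

The single-call version avoids the Python loop and the `torch.stack`, which is why it is the simplest option when every time step needs an output.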
2. Results
