RNN Implementation of a Regression Model
2022-06-12 06:06:00 【Singing under the hedge】
Implementing a regression model with an RNN
Feed the sin curve into the network as input and train it to predict the corresponding cos curve.
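Why does this task need a recurrent model at all? On every period the same input value maps to two different targets: sin(0) = sin(π) = 0, yet cos(0) = 1 and cos(π) = -1. A regressor that sees only the current x cannot tell these cases apart. It also needs to know whether the curve is rising or falling, and that information comes from earlier time steps. The short NumPy check below is an illustration added here, not part of the original tutorial. It shows the ambiguity and how a single remembered previous value resolves it, since the slope of sin is exactly cos.

```python
import numpy as np

# Two time points with the same input value sin(t) ...
t = np.array([0.0, np.pi])
x = np.sin(t)          # [0, ~1.2e-16]: equal up to floating-point rounding
y = np.cos(t)          # ... but opposite targets: [1, -1]

# Remembering the input one small step earlier resolves the ambiguity:
# the finite-difference slope of sin approximates cos.
dt = 0.1
prev_x = np.sin(t - dt)
approx = (x - prev_x) / dt

print(np.round(x, 6))       # [0. 0.]
print(y)                    # [ 1. -1.]
print(np.round(approx, 3))  # [ 0.998 -0.998]
```

A plain feed-forward regressor could do the same if you handed it a window of past inputs. The RNN instead learns what to remember through its hidden state.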
One 、 Code

```python
"""
View more, visit my tutorial page: https://mofanpy.com/tutorials/
My Youtube Channel: https://www.youtube.com/user/MorvanZhou

Dependencies:
torch: 0.4
matplotlib
numpy
"""
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt

# torch.manual_seed(1)    # reproducible

# Hyper Parameters
TIME_STEP = 10      # rnn time step
INPUT_SIZE = 1      # rnn input size
LR = 0.02           # learning rate

# show data
steps = np.linspace(0, np.pi * 2, 100, dtype=np.float32)  # float32 for converting torch FloatTensor
x_np = np.sin(steps)
y_np = np.cos(steps)
plt.plot(steps, y_np, 'r-', label='target (cos)')
plt.plot(steps, x_np, 'b-', label='input (sin)')
plt.legend(loc='best')
plt.show()
```
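The hyperparameters above fix the shape of every training sample. Each sample is one sequence (batch = 1) of TIME_STEP points with INPUT_SIZE = 1 feature per point, laid out as (batch, time_step, input_size) because the model uses `batch_first=True`. The NumPy sketch below shows how the training loop further down cuts one window of the curve and adds the two extra axes with `np.newaxis`. The helper `make_window` is hypothetical and exists only for this illustration.

```python
import numpy as np

TIME_STEP = 10    # same values as the hyperparameters above
INPUT_SIZE = 1

def make_window(step, time_step=TIME_STEP):
    """Hypothetical helper: the step-th window [step*pi, (step+1)*pi) as (batch, time_step, input_size)."""
    start, end = step * np.pi, (step + 1) * np.pi
    t = np.linspace(start, end, time_step, dtype=np.float32, endpoint=False)
    x = np.sin(t)[np.newaxis, :, np.newaxis]   # batch axis in front, feature axis at the end
    y = np.cos(t)[np.newaxis, :, np.newaxis]
    return t, x, y

t0, x0, y0 = make_window(0)
t1, x1, y1 = make_window(1)

print(x0.shape, x0.dtype)   # (1, 10, 1) float32
# endpoint=False makes consecutive windows tile the time axis without overlap:
print(np.isclose(t0[-1] + np.pi / TIME_STEP, t1[0], atol=1e-5))   # True
```

Keeping the arrays in float32 means `torch.from_numpy` yields a FloatTensor that matches the default float32 weights of `nn.RNN` and `nn.Linear`.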
```python
class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()

        self.rnn = nn.RNN(
            input_size=INPUT_SIZE,
            hidden_size=32,     # rnn hidden unit
            num_layers=1,       # number of rnn layer
            batch_first=True,   # input & output have batch size as the first dimension, e.g. (batch, time_step, input_size)
        )
        self.out = nn.Linear(32, 1)

    def forward(self, x, h_state):
        # x (batch, time_step, input_size)
        # h_state (n_layers, batch, hidden_size)
        # r_out (batch, time_step, hidden_size)
        r_out, h_state = self.rnn(x, h_state)

        outs = []    # save all predictions
        for time_step in range(r_out.size(1)):    # calculate output for each time step
            outs.append(self.out(r_out[:, time_step, :]))
        return torch.stack(outs, dim=1), h_state

        # instead, for simplicity, you can replace the code above with the following
        # r_out = r_out.view(-1, 32)
        # outs = self.out(r_out)
        # outs = outs.view(-1, TIME_STEP, 1)
        # return outs, h_state

        # or even simpler, since nn.Linear accepts inputs of any dimension
        # and returns outputs with the same dimensions except for the last
        # outs = self.out(r_out)
        # return outs, h_state
```
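The class above delegates the recurrence to `nn.RNN`, whose default update (nonlinearity `tanh`) is documented in PyTorch as h_t = tanh(W_ih x_t + b_ih + W_hh h_(t-1) + b_hh). It then maps every hidden state to one output value with `nn.Linear(32, 1)`. The NumPy sketch below spells out both steps with randomly initialised stand-in weights. The names `W_ih`, `W_hh`, `W_out` and the helper `rnn_forward` are illustrative, not PyTorch API. The sketch also checks the claim in the comments of `forward()`: applying the output layer one time step at a time and stacking the results gives the same answer as a single matrix product over the last dimension. For brevity the hidden state here has shape (batch, hidden) instead of nn.RNN's (num_layers, batch, hidden), which is equivalent for num_layers=1.

```python
import numpy as np

rng = np.random.default_rng(0)
INPUT_SIZE, HIDDEN, TIME_STEP, BATCH = 1, 32, 10, 1

# Random stand-ins for the trained parameters (shapes follow
# nn.RNN's weight_ih_l0 / weight_hh_l0 and nn.Linear(32, 1)).
W_ih = rng.normal(scale=0.1, size=(HIDDEN, INPUT_SIZE))
W_hh = rng.normal(scale=0.1, size=(HIDDEN, HIDDEN))
b_ih = np.zeros(HIDDEN)
b_hh = np.zeros(HIDDEN)
W_out = rng.normal(scale=0.1, size=(1, HIDDEN))
b_out = np.zeros(1)

def rnn_forward(x, h0=None):
    """Single-layer tanh RNN, batch_first: x is (batch, time_step, input_size)."""
    batch, steps, _ = x.shape
    h = np.zeros((batch, HIDDEN)) if h0 is None else h0
    outs = []
    for t in range(steps):
        h = np.tanh(x[:, t, :] @ W_ih.T + b_ih + h @ W_hh.T + b_hh)
        outs.append(h)
    return np.stack(outs, axis=1), h   # r_out (batch, time_step, hidden), last h

x = np.sin(np.linspace(0, np.pi, TIME_STEP, endpoint=False)).reshape(BATCH, TIME_STEP, INPUT_SIZE)
r_out, h_last = rnn_forward(x)

# Output layer applied per time step, then stacked (the loop in forward()) ...
per_step = np.stack([r_out[:, t, :] @ W_out.T + b_out for t in range(TIME_STEP)], axis=1)
# ... equals one matmul over the last dimension (the "even simpler" variant).
all_at_once = r_out @ W_out.T + b_out

print(per_step.shape, np.allclose(per_step, all_at_once))   # (1, 10, 1) True
```

The returned last hidden state is simply the final time step of `r_out`. The training loop feeds this value back in as `h_state`.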
```python
rnn = RNN()
print(rnn)

optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)   # optimize all rnn parameters
loss_func = nn.MSELoss()

h_state = None      # for initial hidden state

plt.figure(1, figsize=(12, 5))
plt.ion()           # continuously plot

for step in range(100):
    start, end = step * np.pi, (step + 1) * np.pi   # time range
    # use sin to predict cos
    steps = np.linspace(start, end, TIME_STEP, dtype=np.float32,
                        endpoint=False)  # float32 for converting torch FloatTensor
    x_np = np.sin(steps)
    y_np = np.cos(steps)

    x = torch.from_numpy(x_np[np.newaxis, :, np.newaxis])    # shape (batch, time_step, input_size)
    y = torch.from_numpy(y_np[np.newaxis, :, np.newaxis])

    prediction, h_state = rnn(x, h_state)   # rnn output
    # !! next step is important !!
    h_state = h_state.detach()   # repack the hidden state, break the connection from the last iteration

    loss = loss_func(prediction, y)         # calculate loss
    optimizer.zero_grad()                   # clear gradients for this training step
    loss.backward()                         # backpropagation, compute gradients
    optimizer.step()                        # apply gradients

    # plotting
    plt.plot(steps, y_np.flatten(), 'r-')
    plt.plot(steps, prediction.detach().numpy().flatten(), 'b-')
    plt.draw()
    plt.pause(0.05)

plt.ioff()
plt.show()
```
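In the training loop above, each iteration feeds the next window [step·π, (step+1)·π), which starts exactly where the previous one ended, and passes in the `h_state` returned by the previous call. Carrying the state this way lets the network see one long, continuous curve instead of unrelated 10-point snippets. The line marked important, `h_state = h_state.data` (equivalently `h_state.detach()`), keeps the state's value but drops its link to the previous iteration's graph. As a result, each `loss.backward()` only unrolls TIME_STEP steps (truncated backpropagation through time). Without it, the second `backward()` would try to reach back through the previous iteration's graph, which has already been freed, and PyTorch raises a RuntimeError. NumPy has no autograd, so the sketch below (random weights, hypothetical `rnn_forward` helper) only demonstrates the forward-pass half. Carrying the state across windows reproduces a single pass over the concatenated sequence, while resetting it does not.

```python
import numpy as np

rng = np.random.default_rng(1)
HIDDEN, TIME_STEP = 8, 10
W_ih = rng.normal(scale=0.5, size=(HIDDEN, 1))
W_hh = rng.normal(scale=0.5, size=(HIDDEN, HIDDEN))

def rnn_forward(x, h):
    """tanh RNN over a 1-D input sequence x, starting from hidden state h of shape (HIDDEN,)."""
    outs = []
    for xt in x:
        h = np.tanh(W_ih[:, 0] * xt + W_hh @ h)
        outs.append(h)
    return np.array(outs), h

# Two consecutive windows of length pi, 10 points each, as in the training loop.
t = np.linspace(0, 2 * np.pi, 2 * TIME_STEP, endpoint=False)
x = np.sin(t)

# One pass over the full 20-point sequence.
full, _ = rnn_forward(x, np.zeros(HIDDEN))

# Two 10-point windows, carrying the hidden state across the boundary.
first, h = rnn_forward(x[:TIME_STEP], np.zeros(HIDDEN))
second, _ = rnn_forward(x[TIME_STEP:], h)

# Resetting the state for the second window instead gives different outputs.
reset, _ = rnn_forward(x[TIME_STEP:], np.zeros(HIDDEN))

print(np.allclose(np.concatenate([first, second]), full))   # True
print(np.allclose(reset, full[TIME_STEP:]))                 # False
```

Setting `h_state = None` inside the loop would be the "reset" case: every window would start from zeros, so the model would lose the context that tells it where on the curve it is.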
Two 、 Results
