RNN model
2022-06-12 06:07:00 【Singing under the hedge】
# RNN Model
1. Code
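```python
import torch
import torch.nn as nn

# Dummy input: 2 sequences, 3 time steps, 10 features per step (batch_first layout)
x_input = torch.randn(2, 3, 10)


class RNN(nn.Module):
    """Unrolls an nn.RNNCell over the time dimension of a sequence."""

    def __init__(self, input_size, hidden_size, batch_first=False):
        super().__init__()
        self.rnn_cell = nn.RNNCell(input_size, hidden_size)
        self.batch_first = batch_first
        self.hidden_size = hidden_size

    def _initialize_hidden(self, batch_size):
        # Zero initial hidden state of shape (batch_size, hidden_size)
        return torch.zeros((batch_size, self.hidden_size))

    def forward(self, inputs, initial_hidden=None):
        if self.batch_first:
            batch_size, seq_size, feat_size = inputs.size()
            # (batch, seq, feat) -> (seq, batch, feat) so that inputs[t] is one time step
            inputs = inputs.permute(1, 0, 2)
        else:
            # Fixed: unpack the tensor's shape, not the tensor itself
            seq_size, batch_size, feat_size = inputs.size()

        hiddens = []
        if initial_hidden is None:
            initial_hidden = self._initialize_hidden(batch_size)
        initial_hidden = initial_hidden.to(inputs.device)

        hidden_t = initial_hidden
        for t in range(seq_size):
            # One recurrence step: h_t = tanh(W_ih x_t + b_ih + W_hh h_{t-1} + b_hh)
            hidden_t = self.rnn_cell(inputs[t], hidden_t)
            hiddens.append(hidden_t)

        # Stack per-step hidden states: (seq, batch, hidden)
        hiddens = torch.stack(hiddens)
        if self.batch_first:
            # Back to (batch, seq, hidden)
            hiddens = hiddens.permute(1, 0, 2)
        return hiddens


model = RNN(10, 15, batch_first=True)
output = model(x_input)
print(output.shape)  # torch.Size([2, 3, 15])
print(output)
```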
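Each iteration of the loop above delegates to `nn.RNNCell`. With its default `tanh` nonlinearity, PyTorch documents the cell as computing h' = tanh(W_ih x + b_ih + W_hh h + b_hh). The sketch below recomputes a single step by hand from the cell's own parameters (`weight_ih`, `weight_hh`, `bias_ih`, `bias_hh`) and checks it against the cell. The helper `manual_rnn_step` and the sizes are illustrative choices, not part of the original post.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

input_size, hidden_size, batch_size = 10, 15, 2
cell = nn.RNNCell(input_size, hidden_size)  # default nonlinearity is tanh

x_t = torch.randn(batch_size, input_size)      # input at one time step
h_prev = torch.randn(batch_size, hidden_size)  # previous hidden state


def manual_rnn_step(x, h, cell):
    """Hypothetical helper: recompute one RNNCell step from the cell's parameters,
    h_t = tanh((x W_ih^T + b_ih) + (h W_hh^T + b_hh))."""
    input_part = x @ cell.weight_ih.T + cell.bias_ih
    hidden_part = h @ cell.weight_hh.T + cell.bias_hh
    return torch.tanh(input_part + hidden_part)


with torch.no_grad():
    h_manual = manual_rnn_step(x_t, h_prev, cell)
    h_builtin = cell(x_t, h_prev)

print(h_manual.shape)                                # torch.Size([2, 15])
print(torch.allclose(h_manual, h_builtin, atol=1e-5))  # True
```

The small `atol` only absorbs floating-point differences in summation order; the two computations are the same formula.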
2. Output
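A quick way to sanity-check the output is to compare the hand-unrolled cell with PyTorch's built-in `nn.RNN`, which for a single `tanh` layer computes the same recurrence. The sketch below is a self-contained variant of the model above, not the author's code: it copies an `nn.RNNCell`'s parameters into a one-layer `nn.RNN` (`weight_ih_l0`, `weight_hh_l0`, `bias_ih_l0`, `bias_hh_l0`), unrolls the cell with a hypothetical `unroll` helper starting from a zero hidden state, and checks that both give the same `(batch, seq, hidden)` tensor for an input shaped like `x_input`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

input_size, hidden_size = 10, 15
x = torch.randn(2, 3, input_size)  # (batch, seq, feat), same shape as x_input

cell = nn.RNNCell(input_size, hidden_size)
rnn = nn.RNN(input_size, hidden_size, batch_first=True)  # one layer, tanh

# Copy the cell's parameters into nn.RNN so both compute the same function
with torch.no_grad():
    rnn.weight_ih_l0.copy_(cell.weight_ih)
    rnn.weight_hh_l0.copy_(cell.weight_hh)
    rnn.bias_ih_l0.copy_(cell.bias_ih)
    rnn.bias_hh_l0.copy_(cell.bias_hh)


def unroll(cell, x):
    """Hypothetical helper: run the cell over a batch_first sequence from h_0 = 0."""
    h = torch.zeros(x.size(0), cell.hidden_size)
    outputs = []
    for t in range(x.size(1)):
        h = cell(x[:, t], h)  # x[:, t] has shape (batch, feat)
        outputs.append(h)
    return torch.stack(outputs, dim=1)  # (batch, seq, hidden)


with torch.no_grad():
    out_loop = unroll(cell, x)
    out_builtin, h_n = rnn(x)  # h_0 defaults to zeros; h_n has shape (1, batch, hidden)

print(out_loop.shape)                                     # torch.Size([2, 3, 15])
print(torch.allclose(out_loop, out_builtin, atol=1e-5))   # True
print(torch.allclose(out_loop[:, -1], h_n[0], atol=1e-5))  # True
```

The last check reflects that `nn.RNN`'s final hidden state is simply the hidden state at the last time step, which the custom loop also produces as `hiddens[:, -1]`.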
