RuntimeError:Input and parameter tensors are not at the same device, found input tensor at cuda:0 an
2022-06-12 08:43:00 【嘿 是我】
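Cause: the input x and the model parameters are stored on different devices, one in CPU memory and the other on the GPU. PyTorch requires both to be on the same device before the forward pass.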
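If you hit Error 1: the input x is on cuda (GPU) and the model parameters are on the CPU
Find the code that defines the model and add the line model.to(device) right after it (where device = "cuda"), so that the parameters move to the GPU alongside x.
If you hit Error 2: the input x is on the CPU and the model parameters are on cuda (GPU)
Find the input x and, before it is passed to the model, add the line x = x.to(device). The result must be assigned back, because x.to(device) returns a new tensor rather than moving x in place.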
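Before changing anything, it can help to confirm which of the two cases applies. The following is a minimal sketch, not part of the original demo: `report_devices` is a hypothetical helper name, and it assumes the model has at least one parameter.

```python
import torch
import torch.nn as nn


def report_devices(model: nn.Module, x: torch.Tensor):
    """Return the input's device and the set of devices holding the parameters."""
    param_devices = {str(p.device) for p in model.parameters()}
    return str(x.device), param_devices


# A freshly created module and tensor both live on the CPU.
model = nn.LSTM(input_size=3, hidden_size=4)
x = torch.zeros(2, 1, 3)  # (seq_len, batch, input_size)

input_device, param_devices = report_devices(model, x)
print("input:", input_device, "parameters:", param_devices)
```

If the two values differ, the forward call will fail with one of the two errors described here.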
The details are as follows:
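Error 1: RuntimeError: Input and parameter tensors are not at the same device, found input tensor at cuda:0 and parameter tensor at cpu
1.1 The input x is on cuda (GPU) and the model parameters are on the CPU
Test demo:
Here the input x is on the GPU but the model is stored on the CPU, so running the following code raises Error 1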
import torch
import torch.nn as nn
from torch.nn import LSTM

device = "cuda" if torch.cuda.is_available() else "cpu"  # use the GPU if available, otherwise the CPU
x = torch.Tensor([[1, 2, 3], [2, 3, 4]])  # x shape (2, 3): (seq_len, embedding dimension)


class Testmodel(nn.Module):
    def __init__(self, input_dim, lstm_layer, lstm_hidden_dim, dropout):
        super(Testmodel, self).__init__()
        self.lstm_encoding = LSTM(input_dim, num_layers=lstm_layer, hidden_size=lstm_hidden_dim,
                                  dropout=dropout)

    def forward(self, x: torch.Tensor):
        output, (hn, cn) = self.lstm_encoding(x)
        return output


model = Testmodel(
    input_dim=3,
    lstm_layer=2,
    lstm_hidden_dim=4,
    dropout=0.5,
)
# x is on the GPU but model is still on the CPU, so the forward call raises Error 1
x = x.to(device)  # move x into GPU memory
output = model(x)  # calls forward; x is (2, 3), LSTM input dimension 3, output dimension 4
print(output)  # output shape (2, 4)
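1.2 Solutions
Method 1: comment out the line below, so that x stays in CPU memory, on the same device as the model parameters
x = x.to(device)  # move x into GPU memory
Method 2 (recommended): add the line model.to(device) to move the model parameters to the GPU, the same device as the input x. The example code after the change: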
import torch
import torch.nn as nn
from torch.nn import LSTM

device = "cuda" if torch.cuda.is_available() else "cpu"  # use the GPU if available, otherwise the CPU
x = torch.Tensor([[1, 2, 3], [2, 3, 4]])  # x shape (2, 3): (seq_len, embedding dimension)


class Testmodel(nn.Module):
    def __init__(self, input_dim, lstm_layer, lstm_hidden_dim, dropout):
        super(Testmodel, self).__init__()
        self.lstm_encoding = LSTM(input_dim, num_layers=lstm_layer, hidden_size=lstm_hidden_dim,
                                  dropout=dropout)

    def forward(self, x: torch.Tensor):
        output, (hn, cn) = self.lstm_encoding(x)
        return output


model = Testmodel(
    input_dim=3,
    lstm_layer=2,
    lstm_hidden_dim=4,
    dropout=0.5,
)
model.to(device)  # !!! newly added line: move the model parameters to the same device as x
# x and model are now on the same device, so the forward call succeeds
x = x.to(device)  # move x into GPU memory
output = model(x)  # calls forward; x is (2, 3), LSTM input dimension 3, output dimension 4
print(output)  # output shape (2, 4)
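One detail behind the two fixes: `model.to(...)` changes the module in place, while `x.to(...)` returns a new tensor and leaves `x` untouched, which is why the tensor version must be written as `x = x.to(device)`. The sketch below illustrates the difference. It is not from the original post, and it converts the dtype instead of the device purely so that it runs on a machine without a GPU; the variable names are illustrative.

```python
import torch
import torch.nn as nn

# Tensor.to returns a new tensor when a conversion is needed; x is unchanged.
x = torch.zeros(2, 3)        # float32 by default
y = x.to(torch.float64)
x_unchanged = x.dtype == torch.float32
y_converted = y.dtype == torch.float64

# Module.to converts the module's parameters in place and returns the module itself.
layer = nn.Linear(3, 4)
returned = layer.to(torch.float64)
same_object = returned is layer
layer_converted = layer.weight.dtype == torch.float64

print(x_unchanged, y_converted, same_object, layer_converted)
```

The same rule holds when the argument is a device: a bare `x.to(device)` whose result is discarded has no effect on `x`.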
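Error 2: RuntimeError: Input and parameter tensors are not at the same device, found input tensor at cpu and parameter tensor at cuda:0
2.1 The input x is on the CPU and the model parameters are on cuda (GPU)
Test demo:
Here the input x is on the CPU but the model is stored on the GPU, so running the following code raises Error 2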
import torch
import torch.nn as nn
from torch.nn import LSTM

device = "cuda" if torch.cuda.is_available() else "cpu"  # use the GPU if available, otherwise the CPU
x = torch.Tensor([[1, 2, 3], [2, 3, 4]])  # x shape (2, 3): (seq_len, embedding dimension)


class Testmodel(nn.Module):
    def __init__(self, input_dim, lstm_layer, lstm_hidden_dim, dropout):
        super(Testmodel, self).__init__()
        self.lstm_encoding = LSTM(input_dim, num_layers=lstm_layer, hidden_size=lstm_hidden_dim,
                                  dropout=dropout)

    def forward(self, x: torch.Tensor):
        output, (hn, cn) = self.lstm_encoding(x)
        return output


model = Testmodel(
    input_dim=3,
    lstm_layer=2,
    lstm_hidden_dim=4,
    dropout=0.5,
)
model.to(device)  # move the model parameters to the GPU
# x is still on the CPU but model is on the GPU, so the forward call raises Error 2
output = model(x)  # calls forward; x is (2, 3), LSTM input dimension 3, output dimension 4
print(output)  # output shape (2, 4)
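2.2 Solutions
Method 1: find the line model.to(device) and comment it out, so that the model parameters stay in CPU memory, on the same device as the input x
model.to(device)  # move the model parameters to the GPU
Method 2 (recommended): add the line x = x.to(device) to move the input x to the GPU, the same device as the model parameters. The example code after the change: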
import torch
import torch.nn as nn
from torch.nn import LSTM

device = "cuda" if torch.cuda.is_available() else "cpu"  # use the GPU if available, otherwise the CPU
x = torch.Tensor([[1, 2, 3], [2, 3, 4]])  # x shape (2, 3): (seq_len, embedding dimension)


class Testmodel(nn.Module):
    def __init__(self, input_dim, lstm_layer, lstm_hidden_dim, dropout):
        super(Testmodel, self).__init__()
        self.lstm_encoding = LSTM(input_dim, num_layers=lstm_layer, hidden_size=lstm_hidden_dim,
                                  dropout=dropout)

    def forward(self, x: torch.Tensor):
        output, (hn, cn) = self.lstm_encoding(x)
        return output


model = Testmodel(
    input_dim=3,
    lstm_layer=2,
    lstm_hidden_dim=4,
    dropout=0.5,
)
model.to(device)  # move the model parameters into GPU memory
# x and model are now on the same device, so the forward call succeeds
x = x.to(device)  # !!! newly added line: move x to the same device as the model
output = model(x)  # calls forward; x is (2, 3), LSTM input dimension 3, output dimension 4
print(output)  # output shape (2, 4)
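Both errors can also be avoided by moving the input to wherever the model currently lives, instead of relying on a separate device variable staying in sync. The following is a minimal sketch under stated assumptions: `to_model_device` is a hypothetical helper, not something from the original post, and it assumes all of the model's parameters sit on a single device.

```python
import torch
import torch.nn as nn


def to_model_device(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Move x to the device of the model's first parameter.

    Assumption: every parameter of the model is on the same device.
    """
    return x.to(next(model.parameters()).device)


device = "cuda" if torch.cuda.is_available() else "cpu"
lstm = nn.LSTM(input_size=3, hidden_size=4).to(device)

# Created on the CPU; shape (seq_len=2, batch=1, input_size=3).
x = torch.tensor([[[1.0, 2.0, 3.0]], [[2.0, 3.0, 4.0]]])

output, _ = lstm(to_model_device(lstm, x))
print(output.shape, output.device)
```

This runs the same way on a CPU-only machine and on a machine with a GPU, because the input always follows the model.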
My knowledge is limited, so corrections and discussion are welcome if you spot any mistakes.