当前位置:网站首页>PyTorch 中的乘法:mul()、multiply()、matmul()、mm()、mv()、dot()

PyTorch 中的乘法:mul()、multiply()、matmul()、mm()、mv()、dot()

2022-07-07 13:49:00 小瓶盖的猪猪侠

torch.mul()

函数功能:逐个对 input 和 other 中对应的元素相乘。

本操作支持广播,因此 input 和 other 均可以是张量或者数字

import torch
a = torch.randn((1,2))
b = torch.randn((2,1))
print(a,b)
torch.mul(a,b)

在这里插入图片描述

torch.multiply()

torch.mul() 的别称

torch.matmul()

matmul可以进行张量乘法, 输入可以是高维.

torch.dot()

函数功能:计算 input 和 output 的点乘,此函数要求 input 和 output 都必须是一维的张量(其 shape 属性中只有一个值)!并且要求两者元素个数相同!

import torch 

x = torch.Tensor([1,2])
y = torch.Tensor([3,4])
z = torch.dot(x,y)
z

在这里插入图片描述

torch.mm()

函数功能:实现线性代数中的矩阵乘法(matrix multiplication):(n×m) × (m×p) = (n×p) 。

本函数不允许广播!

import torch
x = torch.randn((3,4))
y = torch.randn((4,5))
z = torch.mm(x,y)
z

在这里插入图片描述

torch.mv()

函数功能:实现矩阵和向量(matrix × vector)的乘法,要求 input 的形状为 n×m,output 为 torch.Size([m])的一维 tensor.

import torch
x = torch.randn((3,4))
y = torch.randn(4)
z = torch.mv(x,y)
z
原网站

版权声明
本文为[小瓶盖的猪猪侠]所创,转载请带上原文链接,感谢
https://blog.csdn.net/qq_29983883/article/details/125573047

随机推荐