当前位置:网站首页>torch中乘法整理,*&torch.mul()&torch.mv()&torch.mm()&torch.dot()&@&torch.mutmal()
torch中乘法整理,*&torch.mul()&torch.mv()&torch.mm()&torch.dot()&@&torch.mutmal()
2022-07-27 05:01:00 【CharlesLC的博客】
*位置乘
符号*在pytorch中是按位置相乘,存在广播机制。
例子:
vec1 = torch.arange(4)
vec2 = torch.tensor([4,3,2,1])
mat1 = torch.arange(12).reshape(4,3)
mat2 = torch.arange(12).reshape(3,4)
print(vec1 * vec2)
print(mat2 * vec1)
print(mat1 * mat1)
Output:
tensor([0, 3, 4, 3])
tensor([[ 0, 1, 4, 9],
[ 0, 5, 12, 21],
[ 0, 9, 20, 33]])
tensor([[ 0, 1, 4],
[ 9, 16, 25],
[ 36, 49, 64],
[ 81, 100, 121]])
torch.mul():数乘
官方解释:
就是两个变量对应元素相乘,other可以为一个数,也可以为一个tensor变量
torch.mul()支持广播机制
例子1:
‘’‘python
In[1]: vec = torch.randn(3)
In[2]: vec
Out[1]: tensor([0.3550, 0.0975, 1.3870])
In[3]: torch.mul(vec, 5)
Out[2]: tensor([1.7752, 0.4874, 6.9348])
‘’’
例子2:
‘’'python
In[1]: vec = torch.randn(3)
In[2]: vec
Out[1]: tensor([1.7752, 0.4874, 6.9348])
In[3]: mat = torch.randn(4).view(-1,1)
In[4]: mat
Out[2]: tensor([[-1.5181],
[ 0.4905],
[-0.3388],
[ 0.5626]])
In[5]:torch.mul(vec,mat)
Out[3]:tensor([[-0.5390, -0.1480, -2.1055],
[ 0.1741, 0.0478, 0.6803],
[-0.1203, -0.0330, -0.4699],
[ 0.1998, 0.0548, 0.7803]])
‘’’
torch.mv():矩阵向量乘法
官方文档写道:Performs a matrix-vector product of the matrix input and the vector vec.
说明torch.mv(input, vec, *, out=None)->tensor只支持矩阵向量乘法,如果input为 n × m n\times m n×m的,vec向量的长度为m,那么输出为 n × 1 n\times 1 n×1的向量。torch.mv()不支持广播机制
例子:
In[1]: vec = torch.arange(4)
In[2]: mat = torch.arange(12).reshape(3,4)
In[3]: torch.mv(mat, vec)
Out[1]: tensor([14, 38, 62])
torch.mm() 矩阵乘法
官方文档写道:Performs a matrix multiplication of the matrices input and mat2.torch.mm(input , mat2, *, out=None) → Tensor
对矩阵input 和mat2进行相乘。 如果input 是一个n×m张量,mat2 是一个 m×p张量,将会输出一个 n×p张量out。torch.mm()不支持广播机制
这个就是线性代数中的矩阵乘法。
例子:
In[1]: mat1 = torch.arange(12).reshape(3,4)
In[2]: mat2 = torch.arange(12).reshape(4,3)
In[3]: torch.mm(mat1, mat2)
Out[1]: tensor([[ 42, 48, 54],
[114, 136, 158],
[186, 224, 262]])
torch.dot() 点乘积
官方文档写道:Computes the dot product of two 1D tensors.
只能支持两个一维向量,与numpy中dot()方法不同。torch.dot(input, other, *, out=None) → Tensor
例子:
In[1]: torch.dot(torch.tensor([2, 3]), torch.tensor([2, 1]))
Out[1]: tensor[7]
@操作
torch中的@操作是可以实现前面几个函数,是一种强大的操作。mat1 @ mat2
- 若mat1和mat2都是两个一维向量,那么对应操作就是torch.dot()
- 若mat1是二维向量,mat2是一维向量,那么对应操作就是torch.mv()
- 若mat1和mat2都是两个二维向量,那么对应操作就是torch.mm()
vec1 = torch.arange(4)
vec2 = torch.tensor([4,3,2,1])
mat1 = torch.arange(12).reshape(4,3)
mat2 = torch.arange(12).reshape(3,4)
print(vec1 @ vec2) # 两个一维向量
print(mat2 @ vec1) # 一个二维和一个一维
print(mat1 @ mat2) # 两个二维向量
Output:
tensor(10)
tensor([14, 38, 62])
tensor([[ 20, 23, 26, 29],
[ 56, 68, 80, 92],
[ 92, 113, 134, 155],
[128, 158, 188, 218]])
torch.matmul()
torch.matmul()与@操作类似,但是torch.matmul()不止局限于一维和二维,可以进行高维张量的乘法。torch.matmul(input, other, *, out=None) → Tensor 支持广播
torch.matmul()运算取决于input和other张量的大小:
- 如果输入的两个张量都是一维的,那么返回点积,对应的操作就是
torch.dot()
- 如果输入的两个张量都是一维的,那么返回点积,对应的操作就是
- 如果输入的两个张量都是二维的,那么返回矩阵乘积,对应的操作就是
torch.mm()
- 如果输入的两个张量都是二维的,那么返回矩阵乘积,对应的操作就是
- 如果输入的第一个张量是二维的,第二个张量是一维的,则返回矩阵向量乘积,对应
torch.mv()
- 如果输入的第一个张量是二维的,第二个张量是一维的,则返回矩阵向量乘积,对应
- 如果输入的第一个张量是一维的,第二个参数是二维的,那么
torch.matmul()操作会先将第一个张量的维度前面添加1,在执行矩阵相乘后,再将添加的维度移除。
- 如果输入的第一个张量是一维的,第二个参数是二维的,那么
- 如果两个参数至少是维一维的且至少一个参数是N维(N>2),则进行批处理矩阵乘法,如果第一个参数是一维,则在第一个参数的维度前面加1,在进行批处理矩阵相乘后在删除。如果第二个参数是一维的,则在第二个参数的维度后面加1,在进行批处理矩阵相乘后再删除。
例子1(对应1到4运算):
vec1 = torch.tensor([1,2,3,4])
vec2 = torch.tensor([4,3,2,1])
mat1 = torch.arange(12).reshape(3,4)
mat2 = torch.arange(12).reshape(4,3)
print(torch.matmul(vec1, vec2)) # 两个向量都是 一维的, 类似torch.dot()操作
print(torch.matmul(mat1, mat2)) # 两个向量都是 二维的, 类似torch.mm()操作
print(torch.matmul(mat1, vec1)) # 第一个向量是 二维的,第二个向量是 一维的,类似 torch.mv()操作
print(torch.matmul(vec1, mat2)) # 第一个向量是 一维的,第二个向量是二维的
Output:
tensor(20)
tensor([[ 42, 48, 54],
[114, 136, 158],
[186, 224, 262]])
tensor([ 20, 60, 100])
tensor([60, 70, 80])
第五种情况,当出现多维的情况:记住一点,如果多维,永远是最后面两维度相乘,然后再将前面的维度补上
例子2: 如果第一个向量是一维的,第二个向量是多维的。
tensor1 = torch.randn(2)
tensor2 = torch.randn(10,2,3)
print(torch.matmul(tensor1, tensor2).size())
Output:
torch.Size([10, 3])
先将tensor2前面维度10作为batch提出来,使其变成二维(2 * 3),将tensor1的维度前面加1,那么就变成了1*2(二维),然后维度1*2和维度2 * 3做矩阵乘法,得到1 * 3,再将tensor1添加的维度1删除,最终得到维度10*3。
例子3:如果第一个向量是多维的,第二个向量是一维的
tensor1 = torch.randn(10,3,4)
tensor2 = torch.randn(4)
print(torch.matmul(tensor1, tensor2).shape)
Output:
torch.Size([10, 3])
先将tensor2的维度后面加1,那么tensor2维度就变成了4*1(二维),接着,把tensor1前面的维度10作为batch提出来,使其变成二维(3 * 4),然后维度3 * 4和维度4 * 1做矩阵乘法,得到3 * 1,再将tensor2添加的维度1删掉,最终加上batch的维度,得到维度10 * 3
例子4:第一个向量3维,第二个向量2维
tensor1 = torch.randn(2,5,3)
tensor2 = torch.randn(3,4)
print(torch.matmul(tensor1, tensor2).shape)
Output:
torch.Size([2, 5, 4])
先将tensor1中多出的一维提取出来,其余部分做矩阵乘法
例子5: 第一个向量2维,第二个向量3维
tensor1 = torch.randn(5,3)
tensor2 = torch.randn(2,3,4)
print(torch.matmul(tensor1, tensor2).shape)
Output:
torch.Size([2, 5, 4])
先将tensor2中多出的一维提取出来,其余部分做矩阵乘法
例子6: 当两个都是多维
tensor1 = torch.randn(5,1,5,3)
tensor2 = torch.randn(2,3,4)
print(torch.matmul(tensor1, tensor2).shape)
Output:
torch.Size([5, 2, 5, 4])
先将tensor1中多余的一维提取出来,剩下三维,将tensor1中做广播机制,变成2 * 5 * 3,接着将tensor1和tensor2中最后两维做矩阵乘法,得到5 * 4, 最终得到维度(5 * 2 * 5 * 4)
边栏推荐
- JVM上篇:内存与垃圾回收篇九--运行时数据区-对象的实例化,内存布局与访问定位
- JVM上篇:内存与垃圾回收篇七--运行时数据区-堆
- LeetCode刷题之322 Coin Change
- JVM Part 1: memory and garbage collection part 8 - runtime data area - Method area
- Integrate SSM
- Event Summary - common summary
- B1031 check ID card
- 数据库设计——关系数据理论(超详细)
- Acticiti中startProcessInstanceByKey方法在variable表中的如何存储
- Li Kou achieved the second largest result
猜你喜欢

整合SSM

Scientific Computing Library - numpy

Translation of robot and precise vehicle localization based on multi sensor fusion in diverse city scenes

Inspiration from "flying man" Jordan! Another "arena" opened by O'Neill

MQ message queue is used to design the high concurrency of the order placing process, the generation scenarios and solutions of message squeeze, message loss and message repetition

素数筛选(埃氏筛法,区间筛法,欧拉筛法)

JVM Part 1: memory and garbage collection part 7 -- runtime data area heap

Li Kou achieved the second largest result

2021 OWASP top 5: security configuration error

Introduction to Kali system ARP (network disconnection sniffing password packet capturing)
随机推荐
Install pyGame
Introduction to Web Framework
JVM Part 1: memory and garbage collection part 5 -- runtime data area virtual machine stack
Test basis 5
What should test / development programmers over 35 do? Many objective factors
Bean's life cycle & dependency injection * dependency auto assembly
一、MySQL基础
Create datasource using Druid connection pool
Constraints of MySQL table
Static and final keyword learning demo exercise
探寻通用奥特能平台安全、智能、性能的奥秘!
Gradio quickly builds ml/dl Web Services
Solution and principle analysis of feign call missing request header
JVM part I: memory and garbage collection part II -- class loading subsystem
B1022 a+b in d-ary
牛客剑指offer--JZ12 矩阵中的路径
B1025 反转链表*******
LeetCode之268.Missing number
JVM上篇:内存与垃圾回收篇十--运行时数据区-直接内存
JVM上篇:内存与垃圾回收篇十四--垃圾回收器