"Torch" tensor multiplication: matmul, einsum
2022-08-01 20:01:00 【panbaoran913】
Reference post: 《张量相乘matmul函数》 (tensor multiplication with the matmul function)
一、torch.matmul
matmul(input, other, out=None)
The function performs matrix multiplication of the two tensors input and other. torch.matmul dispatches to several overloads depending on the dimensionality of the arguments passed in.
When multiplying tensors, the operation is therefore not restricted to the standard (m, n) × (n, l) = (m, l) form.
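A minimal sketch of this dimension-dependent dispatch (the shapes and variable names here are illustrative, not from the original post):

```python
import torch

a = torch.rand(4)       # 1-D vector
b = torch.rand(4)       # 1-D vector
m = torch.rand(3, 4)    # 2-D matrix

dot = torch.matmul(a, b)    # 1-D x 1-D -> 0-D scalar (dot product)
mv = torch.matmul(m, a)     # 2-D x 1-D -> 1-D matrix-vector product
vm = torch.matmul(a, m.T)   # 1-D x 2-D -> 1-D vector-matrix product

print(dot.shape, mv.shape, vm.shape)
# torch.Size([]) torch.Size([3]) torch.Size([3])
```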
三、Multiplying 1-D and 2-D tensors
3.1 1-D times 2-D: (m) × (m, n) = (n)
A1 = torch.FloatTensor(size=(4,))
A2 = torch.FloatTensor(size=(4, 3))
A12 = torch.matmul(A1, A2)
A12.shape  # torch.Size([3])
3.2 2-D times 1-D: (m, n) × (n) = (m)
A3 = torch.FloatTensor(size=(3, 4))
A31 = torch.matmul(A3, A1)
A31.shape  # torch.Size([3])
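Both 1-D cases can also be written with einsum, mirroring the equivalence checks used in the later sections (this sketch uses random initialized tensors rather than the uninitialized FloatTensor allocations):

```python
import torch

A1 = torch.rand(4)
A2 = torch.rand(4, 3)
A3 = torch.rand(3, 4)

A12 = torch.einsum("m,mn->n", A1, A2)   # (4) x (4,3) -> (3)
A31 = torch.einsum("mn,n->m", A3, A1)   # (3,4) x (4) -> (3)

assert torch.allclose(A12, torch.matmul(A1, A2))
assert torch.allclose(A31, torch.matmul(A3, A1))
```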
四、Multiplying 2-D and 3-D tensors
4.1 2-D times 3-D: (m, n) × (b, n, l) = (b, m, l). The expansion scheme is (b, m, n) × (b, n, l) = (b, m, l).
B1 = torch.FloatTensor(size=(2, 3))
B2 = torch.FloatTensor(size=(5, 3, 4))
B12 = torch.matmul(B1, B2)
B12.shape  # torch.Size([5, 2, 4])
Equivalent form:
B12_ = torch.einsum("ij,bjk->bik", B1, B2)
torch.sum(B12 == B12_)  # 40 = 2*4*5
4.2 3-D times 2-D: (b, m, n) × (n, l) = (b, m, l)
B2 = torch.FloatTensor(size=(5, 3, 4))
B3 = torch.FloatTensor(size=(4, 2))
B23 = torch.matmul(B2, B3)
B23.shape  # torch.Size([5, 3, 2])
Equivalent form:
BB23_ = torch.einsum("bij,jk->bik", [B2, B3])
BB23_.shape  # torch.Size([5, 3, 2])
torch.sum(B23 == BB23_)  # 30 = 5*3*2
4.3 How the 2-D tensor is expanded to 3-D
Method 1: expand the first tensor from 2-D to 3-D
B1 (2, 3) --> B11 (5, 2, 3)
B1 = torch.FloatTensor(size=(2, 3))
B1_ = torch.unsqueeze(B1, axis=0)  # add a batch dimension
print(B1_.shape)  # torch.Size([1, 2, 3])
B11 = torch.cat([B1_, B1_, B1_, B1_, B1_], axis=0)  # concatenate to replicate along the batch dimension
print(B11.shape)  # torch.Size([5, 2, 3])
Compare the result of B1 (2,3) × B2 (5,3,4) with that of B11 (5,2,3) × B2 (5,3,4):
B112 = torch.matmul(B11, B2)  # (5,2,3) x (5,3,4)
torch.sum(B112 == B12)  # 40 = 5*2*4
This shows the two results are identical. Let's explore the multiplication mechanism further.
Multiply B1 (2,3) by the first matrix in B2 (5,3,4) and check whether it equals the first matrix of B112. The following shows they are equal:
B12_0 = torch.matmul(B1, B2[0])
B112[0] == B12_0
out:
tensor([[True, True, True, True],
[True, True, True, True]])
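The per-batch mechanism can be checked in one shot: stacking matmul(B1, B2[i]) over the batch index reproduces the broadcast product (a sketch with random initialized tensors, so the comparison is NaN-safe):

```python
import torch

B1 = torch.rand(2, 3)
B2 = torch.rand(5, 3, 4)

# Broadcast product in one call...
batched = torch.matmul(B1, B2)                                  # (5, 2, 4)
# ...equals multiplying B1 by each batch matrix separately.
stacked = torch.stack([torch.matmul(B1, B2[i]) for i in range(5)])

assert torch.allclose(batched, stacked)
```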
[Figure: diagram of 2-D times 3-D matrix multiplication]
Method 2: expand the second tensor from 2-D to 3-D
B3 (4, 2) --> B33 (5, 4, 2)
B3_ = torch.unsqueeze(B3, axis=0)
print(B3_.shape)  # torch.Size([1, 4, 2])
B33 = torch.cat([B3_, B3_, B3_, B3_, B3_], axis=0)
print(B33.shape)  # torch.Size([5, 4, 2])
B233 = torch.matmul(B2, B33)
print(B233.shape)  # torch.Size([5, 3, 2])
Compare the results of the two multiplications:
print(torch.sum(B233 == BB23_))  # 30
print(torch.sum(B233 == B23))  # 30
Note: torch.FloatTensor(size=...) returns uninitialized memory that may contain NaN values; since NaN never compares equal to itself, these element-wise equality counts can come up short.
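The NaN pitfall can be demonstrated directly; the sketch below uses torch.rand for initialized tensors and torch.allclose for the comparison, which sidesteps it:

```python
import torch

x = torch.tensor([float("nan")])
print(x == x)   # tensor([False]): NaN never equals itself

# torch.FloatTensor(size=...) is uninitialized and may contain NaN;
# initialized tensors plus torch.allclose give a reliable comparison.
B2 = torch.rand(5, 3, 4)
B3 = torch.rand(4, 2)
assert torch.allclose(torch.matmul(B2, B3),
                      torch.einsum("bij,jk->bik", B2, B3))
```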
五、Multiplying 2-D and 4-D tensors
5.1 2-D times 4-D: (m, n) × (b, c, n, l) = (b, c, m, l)
B1 = torch.FloatTensor(size=(2, 3))
B4 = torch.FloatTensor(size=(7, 5, 3, 4))
B14 = torch.matmul(B1, B4)
print(B14.shape)  # torch.Size([7, 5, 2, 4])
Equivalent form:
B14_ = torch.einsum("mn,bcnl->bcml", [B1, B4])
print(torch.sum(B14 == B14_))  # 280 = 7*5*2*4
Expanding dimensions:
B11 = torch.unsqueeze(B1, dim=0)
B11 = torch.concat([B11, B11, B11, B11, B11], dim=0)
print(B11.shape)  # torch.Size([5, 2, 3])
B111 = torch.unsqueeze(B11, dim=0)
B111 = torch.concat([B111, B111, B111, B111, B111, B111, B111], dim=0)
print(B111.shape)  # torch.Size([7, 5, 2, 3])
After broadcasting: 4-D times 4-D
B1114 = torch.matmul(B111, B4)
print(B1114.shape)  # torch.Size([7, 5, 2, 4])
print(torch.sum(B1114 == B14))  # 280
5.2 4-D times 2-D: (b, c, n, l) × (l, p) = (b, c, n, p)
4-D times 2-D:
B43 = torch.matmul(B4,B3)
print("B43 shape",B43.shape) #(7,5,3,2)
Equivalent form:
B43_ = torch.einsum("bcnl,lp->bcnp",[B4,B3])
print("B4 is nan",torch.sum(B4.isnan()))#0
print(torch.sum(B43==B43_))#210 =7*5*3*2
Expanding dimensions:
B33 = torch.unsqueeze(B3, dim=0)
B33 = torch.concat([B33, B33, B33, B33, B33], dim=0)
B333 = torch.unsqueeze(B33, dim=0)
B333 = torch.concat([B333, B333, B333, B333, B333, B333, B333], dim=0)
print("B333 shape is", B333.shape)  # torch.Size([7, 5, 4, 2])
After broadcasting: 4-D times 4-D
B4333 = torch.matmul(B4, B333)
print("B4333 shape is", B4333.shape)  # torch.Size([7, 5, 3, 2])