PyTorch multiplication and the broadcasting mechanism
2022-07-27 17:10:00 【遨游的菜鸡】
1. Multiplication
Ways to multiply: *, torch.mul, torch.mm, torch.matmul
https://blog.csdn.net/da_kao_la/article/details/87484403
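As a rough illustration of how the four differ (a minimal sketch; the tensors and variable names are made up for this example): `*` and torch.mul multiply element by element, torch.mm is the matrix product of two 2-D tensors and does not broadcast, and torch.matmul is a matrix product that also broadcasts leading batch dimensions.

```python
import torch

# Hypothetical 2x2 inputs, chosen only to make the results easy to check by hand.
a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
b = torch.tensor([[5.0, 6.0], [7.0, 8.0]])

# Element-wise product: the operator and the function give the same result.
elementwise_op = a * b
elementwise_fn = torch.mul(a, b)

# Matrix product of two 2-D tensors.
matrix_mm = torch.mm(a, b)
matrix_matmul = torch.matmul(a, b)

# torch.matmul also accepts a batch of matrices and broadcasts b across the batch axis.
batch = torch.ones(3, 2, 2)
batched = torch.matmul(batch, b)

print(elementwise_op)
print(matrix_mm)
print(batched.shape)
```

For the same call, torch.mm(batch, b) would raise an error because torch.mm only accepts 2-D inputs.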
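2. Broadcasting
Broadcasting in PyTorch follows the same rules as broadcasting in NumPy, since both implement the same array broadcasting semantics.
Two Tensors with a different number of dimensions can be multiplied, for example: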
import torch

a = torch.arange(0,6).reshape((6,))
''' tensor([0, 1, 2, 3, 4, 5]) shape: torch.Size([6]) ndim: 1 '''
b = torch.arange(0,12).reshape((2,6))
''' tensor([[ 0, 1, 2, 3, 4, 5], [ 6, 7, 8, 9, 10, 11]]) shape: torch.Size([2, 6]) ndim: 2 '''
# a and b differ in ndim, but they can still be multiplied element-wise thanks to broadcasting
res = torch.mul(a,b)
''' tensor([[ 0, 1, 4, 9, 16, 25], [ 0, 7, 16, 27, 40, 55]]) shape: torch.Size([2, 6]) ndim: 2 '''
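How to understand array broadcasting
The explanation below uses the addition of array A and array B; other arithmetic operations work the same way.
Core idea: if the two arrays being added have different shapes, broadcasting is triggered: 1) the operands are automatically stretched (conceptually, without copying data) so that A.shape == B.shape, 2) the elements at matching positions are added.
The shape of the result takes, at each position, the larger of the corresponding entries of A.shape and B.shape. For example, with A.shape=(1,9,4) and B.shape=(15,1,4), the shape of A+B is (15,9,4).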
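The two steps can be spelled out by hand. The following is only a sketch, assuming NumPy is available; np.broadcast_to is used here to make the implicit stretching visible, and the array contents are arbitrary.

```python
import numpy as np

A = np.arange(36).reshape(1, 9, 4)
B = np.arange(60).reshape(15, 1, 4)

# Step 1: stretch both operands to the common shape (read-only views, no data is copied).
A_full = np.broadcast_to(A, (15, 9, 4))
B_full = np.broadcast_to(B, (15, 9, 4))

# Step 2: add the elements at matching positions.
manual = A_full + B_full

# Letting NumPy broadcast implicitly gives the same array.
automatic = A + B
print(automatic.shape)
print(np.array_equal(manual, automatic))
```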
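Broadcasting works in two typical cases
1. A.ndim > B.ndim, and the trailing entries of A.shape match B.shape, as in the three examples below. Take care not to confuse ndim (the number of axes) with shape (the size of each axis).
- A.shape=(2,3,4,5), B.shape=(3,4,5)
- A.shape=(2,3,4,5), B.shape=(4,5)
- A.shape=(2,3,4,5), B.shape=(5,)
2. A.ndim == B.ndim, and at every position the entries of A.shape and B.shape are either equal or one of them is 1, for example
- A.shape=(1,9,4), B.shape=(15,1,4)
- A.shape=(1,9,4), B.shape=(15,1,1)
The two cases can also combine: shapes are compared from the last axis backwards, missing leading axes are treated as 1, and each pair of entries must be equal or contain a 1. For instance, A.shape=(2,3,4,5) and B.shape=(3,1,5) broadcast as well.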
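Both cases reduce to one rule: align the shapes at the last axis, treat a missing axis as 1, and require each pair of sizes to be equal or contain a 1. A minimal pure-Python sketch of that rule follows; the helper name broadcast_shape is hypothetical and is not a NumPy or PyTorch function.

```python
from itertools import zip_longest

def broadcast_shape(shape_a, shape_b):
    """Return the broadcast result shape, or raise ValueError if the shapes are incompatible."""
    result = []
    # Walk both shapes from the last axis backwards; a missing leading axis counts as 1.
    pairs = zip_longest(reversed(shape_a), reversed(shape_b), fillvalue=1)
    for dim_a, dim_b in pairs:
        if dim_a == dim_b or dim_b == 1:
            result.append(dim_a)
        elif dim_a == 1:
            result.append(dim_b)
        else:
            raise ValueError(f"shapes {shape_a} and {shape_b} cannot be broadcast")
    return tuple(reversed(result))

print(broadcast_shape((2, 3, 4, 5), (4, 5)))    # case 1: trailing axes match
print(broadcast_shape((1, 9, 4), (15, 1, 4)))   # case 2: equal or 1 at every position
```

Calling broadcast_shape((2, 3), (4, 3)) raises ValueError, because 2 and 4 differ and neither is 1.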
Each case is illustrated below.
A.ndim greater than B.ndim
import numpy as np
# a.shape=(2,2,3,4)
a = np.arange(1,49).reshape((2,2,3,4))
# b.shape=(3,4)
b = np.arange(1,13).reshape((3,4))
# NumPy treats b as if it had shape (2,2,3,4); the result matches np.tile(b,[2,2,1,1]), although no copy is actually made
res = a + b
print('===================================')
print(a)
print(a.shape)
print('===================================')
print(b)
print(b.shape)
print('===================================')
print(res)
print(res.shape)
print('===================================')
print(a+b == a + np.tile(b,[2,2,1,1]) )
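The word "trailing" matters in this case. As a small sketch (the arrays here are hypothetical and separate from the example above), a smaller array whose shape matches only the leading axes of the larger one is rejected, and it has to be given explicit trailing axes of length 1 before it lines up.

```python
import numpy as np

big = np.zeros((2, 2, 3, 4))
small = np.ones((2, 2))   # matches the leading axes of big, not the trailing ones

try:
    big + small           # the last axes are compared first: 4 vs 2, neither is 1
    failed = False
except ValueError:
    failed = True
print(failed)

# Indexing with None appends axes of length 1, giving shape (2, 2, 1, 1).
aligned = small[:, :, None, None]
res = big + aligned
print(aligned.shape)
print(res.shape)
```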
A.ndim equal to B.ndim
# Example 1
# a.shape=(4,3)
a = np.arange(12).reshape(4,3)
# b.shape=(4,1)
b = np.arange(4).reshape(4,1)
# NumPy treats b as if it had shape (4,3); the result matches np.tile(b,[1,3])
res = a + b
print('===================================')
print(a)
print(a.shape)
print('===================================')
print(b)
print(b.shape)
print('===================================')
print(res)
print(res.shape)
print('===================================')
print((a+b == a + np.tile(b,[1,3])) ) # every element printed is True
# Example 2
# a.shape=(1,9,4)
a = np.arange(1,37).reshape((1,9,4))
# b.shape=(15,1,4)
b = np.arange(1,61).reshape((15,1,4))
res = a + b
print('===================================')
# print(a)
print(a.shape)
print('===================================')
# print(b)
print(b.shape)
print('===================================')
# print(res)
print(res.shape)
print('===================================')
q = np.tile(a,[15,1,1]) + np.tile(b,[1,9,1])
print(q == res) # every element printed is True
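The np.tile comparisons above show that broadcasting produces the same values as tiling. One difference worth noting, sketched here with np.broadcast_to as a stand-in for what happens implicitly: tiling builds a new, larger array, while a broadcast operand is a view over the original memory.

```python
import numpy as np

b = np.arange(4).reshape(4, 1)

view = np.broadcast_to(b, (4, 3))   # read-only view over b's memory
copy = np.tile(b, [1, 3])           # new array holding the repeated data

print(np.array_equal(view, copy))   # the values are identical
print(view.strides[1])              # stride along the stretched axis
print(np.shares_memory(view, b))
print(np.shares_memory(copy, b))
```

A stride of 0 along the stretched axis means every step along that axis lands on the same underlying element, which is why broadcasting costs no extra memory for the operand.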