PyTorch multiplication and the broadcasting mechanism
2022-07-27 17:10:00 【遨游的菜鸡】
1. Multiplication
Multiplication operators and functions: *, torch.mul, torch.mm, torch.matmul
Reference: https://blog.csdn.net/da_kao_la/article/details/87484403
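The four operations listed above differ mainly in whether they multiply element-wise or as matrices. A minimal sketch, assuming two small float tensors chosen only for illustration (the values and variable names are not from the original post):

```python
import torch

a = torch.tensor([[1., 2.], [3., 4.]])
b = torch.tensor([[5., 6.], [7., 8.]])

# `*` and torch.mul both compute the element-wise product (and broadcast).
elementwise = a * b
assert torch.equal(elementwise, torch.mul(a, b))
print(elementwise)  # tensor([[ 5., 12.], [21., 32.]])

# torch.mm is a matrix product for 2-D tensors only; it does not broadcast.
matrix = torch.mm(a, b)
print(matrix)  # tensor([[19., 22.], [43., 50.]])

# torch.matmul matches torch.mm on 2-D inputs and also broadcasts batch dimensions.
assert torch.equal(torch.matmul(a, b), matrix)
batch = torch.ones(3, 2, 2)
batched = torch.matmul(batch, b)  # b is reused for each of the 3 matrices
print(batched.shape)  # torch.Size([3, 2, 2])
```

In short, use `*` or torch.mul for element-wise products and torch.matmul when a matrix product (possibly batched) is intended.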
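2. Broadcasting
Broadcasting in PyTorch follows the same rules as broadcasting in NumPy, since both implement array broadcasting.
Two tensors with a different number of dimensions can be multiplied, for example: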
import torch

a = torch.arange(0,6).reshape((6,))
''' tensor([0, 1, 2, 3, 4, 5]) shape: torch.Size([6]) ndim: 1 '''
b = torch.arange(0,12).reshape((2,6))
''' tensor([[ 0, 1, 2, 3, 4, 5], [ 6, 7, 8, 9, 10, 11]]) shape: torch.Size([2, 6]) ndim: 2 '''
# a and b have different ndim, but they can still be multiplied element-wise thanks to broadcasting
res = torch.mul(a,b)
''' tensor([[ 0, 1, 4, 9, 16, 25], [ 0, 7, 16, 27, 40, 55]]) shape: torch.Size([2, 6]) ndim: 2 '''
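How to understand array broadcasting
Take the addition of array A and array B as an example; other arithmetic operations work the same way.
Core idea: if the two arrays being added have different shapes, broadcasting is triggered: 1) the arrays are automatically (virtually) expanded so that A.shape == B.shape, 2) the elements at corresponding positions are added.
The shape of the result takes, at each position, the larger of the corresponding sizes in A.shape and B.shape (with the shapes aligned from the right). For example, if A.shape=(1,9,4) and B.shape=(15,1,4), then A+B has shape (15,9,4).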
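The two steps described above (expand to a common shape, then add element by element) can be reproduced explicitly in NumPy. This is only an illustrative sketch: the array contents are arbitrary, and `np.broadcast_shapes` assumes NumPy 1.20 or newer.

```python
import numpy as np

A = np.arange(36).reshape(1, 9, 4)
B = np.arange(60).reshape(15, 1, 4)

# The result shape can be predicted from the two shapes alone.
out_shape = np.broadcast_shapes(A.shape, B.shape)
print(out_shape)  # (15, 9, 4)

# Step 1: expand both operands to the common shape (read-only views, no copy).
A_full = np.broadcast_to(A, out_shape)
B_full = np.broadcast_to(B, out_shape)

# Step 2: add the elements at corresponding positions.
manual = A_full + B_full

# The explicit version matches what A + B does implicitly.
assert np.array_equal(manual, A + B)
print(manual.shape)  # (15, 9, 4)
```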
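Broadcasting is possible in two situations:
1. A.ndim > B.ndim, and the trailing dimensions of A.shape match B.shape, as in the three cases below. Take care not to confuse the two basic concepts ndim and shape.
- A.shape=(2,3,4,5), B.shape=(3,4,5)
- A.shape=(2,3,4,5), B.shape=(4,5)
- A.shape=(2,3,4,5), B.shape=(5,)
2. A.ndim == B.ndim, and each pair of corresponding dimensions in A.shape and B.shape is either equal or contains a 1, for example:
- A.shape=(1,9,4), B.shape=(15,1,4)
- A.shape=(1,9,4), B.shape=(15,1,1)
The two situations can also be combined: the shapes are aligned from the right, and every aligned pair of dimensions must be equal or contain a 1, e.g. A.shape=(2,3,4,5), B.shape=(4,1).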
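Both situations follow from a single rule: align the two shapes from the last dimension, treat a missing dimension as 1, and require each aligned pair to be equal or to contain a 1. The helper below is a hypothetical pure-Python sketch of that rule (`broadcast_shape` is not a NumPy or PyTorch function), written only to make the check explicit.

```python
from itertools import zip_longest

def broadcast_shape(shape_a, shape_b):
    """Return the broadcast result shape, or raise ValueError if incompatible."""
    result = []
    # Walk both shapes from the last dimension; a missing dimension counts as 1.
    for da, db in zip_longest(reversed(shape_a), reversed(shape_b), fillvalue=1):
        if da == db or db == 1:
            result.append(da)
        elif da == 1:
            result.append(db)
        else:
            raise ValueError(f"cannot broadcast {shape_a} with {shape_b}")
    return tuple(reversed(result))

print(broadcast_shape((2, 3, 4, 5), (4, 5)))   # (2, 3, 4, 5)
print(broadcast_shape((1, 9, 4), (15, 1, 1)))  # (15, 9, 4)
print(broadcast_shape((2, 3, 4, 5), (4, 1)))   # (2, 3, 4, 5)
```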
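Examples of each case follow.
A.ndim greater than B.ndim
import numpy as np

# a.shape=(2,2,3,4)
a = np.arange(1,49).reshape((2,2,3,4))
# b.shape=(3,4)
b = np.arange(1,13).reshape((3,4))
# NumPy treats b as if it had shape (2,2,3,4); the result is the same as np.tile(b,[2,2,1,1]), although no copy is actually made
res = a + b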
print('===================================')
print(a)
print(a.shape)
print('===================================')
print(b)
print(b.shape)
print('===================================')
print(res)
print(res.shape)
print('===================================')
print(a+b == a + np.tile(b,[2,2,1,1]) )
A.ndim equal to B.ndim
# Example 1
# a.shape=(4,3)
a = np.arange(12).reshape(4,3)
# b.shape=(4,1)
b = np.arange(4).reshape(4,1)
# NumPy treats b as if it had shape (4,3); the result is the same as np.tile(b,[1,3])
res = a + b
print('===================================')
print(a)
print(a.shape)
print('===================================')
print(b)
print(b.shape)
print('===================================')
print(res)
print(res.shape)
print('===================================')
print((a+b == a + np.tile(b,[1,3])) ) # every element printed is True
# Example 2
# a.shape=(1,9,4)
a = np.arange(1,37).reshape((1,9,4))
# b.shape=(15,1,4)
b = np.arange(1,61).reshape((15,1,4))
res = a + b
print('===================================')
# print(a)
print(a.shape)
print('===================================')
# print(b)
print(b.shape)
print('===================================')
# print(res)
print(res.shape)
print('===================================')
q = np.tile(a,[15,1,1]) + np.tile(b,[1,9,1])
print(q == res) # every element printed is True
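For contrast, it may help to see a pair of shapes that does not satisfy the rule. The sketch below uses arrays chosen only for illustration: shapes (4,3) and (4,) fail because the trailing dimensions 3 and 4 are neither equal nor 1, and adding a length-1 axis to b fixes it.

```python
import numpy as np

a = np.arange(12).reshape(4, 3)
b = np.arange(4)  # shape (4,)

try:
    a + b  # trailing dimensions are 3 and 4: neither equal nor 1
    failed = False
except ValueError as err:
    failed = True
    print(err)

# Adding a length-1 axis gives shapes (4, 3) and (4, 1), which do broadcast.
fixed = a + b[:, np.newaxis]
print(fixed.shape)  # (4, 3)
```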