[pytorch] 1.6 Basic tensor operations
2022-07-29 13:26:00 【Enzo 想砸电脑】
I. Element-wise operations

1. torch.addcdiv(t, t1, t2, value=v)
Divides t1 by t2 element-wise, multiplies by v, and adds t: out = t + v * (t1 / t2). (Older releases accepted the scalar positionally as torch.addcdiv(t, v, t1, t2); that form is deprecated.)
import torch
t = torch.randn(1, 3)
t1 = torch.randn(3, 1)
t2 = torch.randn(1, 3)
a = t + 0.1 * (t1 / t2)
print(a)
# tensor([[-0.2492, -0.6960,  2.3492],
#         [-0.1057, -0.3203,  2.2584],
#         [-0.0774, -0.2463,  2.2405]])

# pass the scalar as the keyword argument `value`;
# the positional form torch.addcdiv(t, 0.1, t1, t2) is deprecated
b = torch.addcdiv(t, t1, t2, value=0.1)
print(b)
# tensor([[-0.2492, -0.6960,  2.3492],
#         [-0.1057, -0.3203,  2.2584],
#         [-0.0774, -0.2463,  2.2405]])
2. torch.addcmul(t, t1, t2, value=v)
Multiplies t1 by t2 element-wise, multiplies by v, and adds t: out = t + v * (t1 * t2).
import torch
t = torch.randn(1, 3)
t1 = torch.randn(3, 1)
t2 = torch.randn(1, 3)
a = t + 0.1 * t1 * t2
print(a)
# tensor([[-0.4994,  1.4826, -0.3377],
#         [-0.4893,  1.4880, -0.3338],
#         [-0.5957,  1.4314, -0.3756]])

b = torch.addcmul(t, t1, t2, value=0.1)  # keyword form; the positional scalar is deprecated
print(b)
# tensor([[-0.4994,  1.4826, -0.3377],
#         [-0.4893,  1.4880, -0.3338],
#         [-0.5957,  1.4314, -0.3756]])
3. torch.clamp(input, min, max)
Clamps every element of the tensor into the range [min, max].
import torch
x = torch.arange(1, 8)
y = torch.clamp(x, 2, 5)
print(y)
# tensor([2, 2, 3, 4, 5, 5, 5])
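clamp also accepts only one bound; a minimal sketch using the keyword arguments min and max on the same input:

```python
import torch

x = torch.arange(1, 8)
# min and max can each be given on their own to bound only one side
low = torch.clamp(x, min=3)   # values below 3 are raised to 3
high = torch.clamp(x, max=4)  # values above 4 are lowered to 4
print(low)   # tensor([3, 3, 3, 4, 5, 6, 7])
print(high)  # tensor([1, 2, 3, 4, 4, 4, 4])
```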
4. torch.ceil(input), torch.floor(input)
torch.ceil(input): round up to the nearest integer
torch.floor(input): round down to the nearest integer
import torch
torch.manual_seed(8)
x = torch.randn(3) * 10
y = torch.ceil(x)
z = torch.floor(x)
print(x)
print(y)
print(z)
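Because randn gives different values on each run, here is a small deterministic sketch of both functions; note how they behave on negative numbers:

```python
import torch

x = torch.tensor([-1.5, 0.2, 2.7])
print(torch.ceil(x))   # tensor([-1.,  1.,  3.])
print(torch.floor(x))  # tensor([-2.,  0.,  2.])
```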
II. Reduction operations

1. torch.cumprod(t, dim)
Cumulative product of t along the given dimension (cumprod = cumulative product)
import torch
a = torch.arange(1, 10).reshape(3, 3)
print(a)
b_x = torch.cumprod(a, dim=0)  # cumulative product down the rows (dim 0)
print("\ncumulative product:\n", b_x)
b_y = torch.cumprod(a, dim=1)  # cumulative product across the columns (dim 1)
print("\ncumulative product:\n", b_y)
# tensor([[1, 2, 3],
#         [4, 5, 6],
#         [7, 8, 9]])
#
# cumulative product:
# tensor([[  1,   2,   3],
#         [  4,  10,  18],
#         [ 28,  80, 162]])
#
# cumulative product:
# tensor([[  1,   2,   6],
#         [  4,  20, 120],
#         [  7,  56, 504]])
2. torch.cumsum(t, dim)
Cumulative sum of t along the given dimension
import torch
a = torch.linspace(0, 10, 6).view(2, 3)
b = a.sum(dim=0)
c = torch.cumsum(a, dim=0)
print(a)
print(b)
print(c)
# tensor([[ 0.,  2.,  4.],
#         [ 6.,  8., 10.]])
#
# tensor([ 6., 10., 14.])
#
# tensor([[ 0.,  2.,  4.],
#         [ 6., 10., 14.]])
d = a.sum(dim=1)
e = torch.cumsum(a, dim=1)
print(d)
print(e)
# tensor([ 6., 24.])
#
# tensor([[ 0.,  2.,  6.],
#         [ 6., 14., 24.]])
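A useful sanity check relating the two reductions: the last slice of a cumulative sum along a dimension equals the plain sum along that dimension. A minimal sketch:

```python
import torch

a = torch.linspace(0, 10, 6).view(2, 3)
# the final row of cumsum along dim 0 is exactly sum along dim 0
assert torch.equal(torch.cumsum(a, dim=0)[-1], a.sum(dim=0))
# likewise for dim 1: the last column equals the row sums
assert torch.equal(torch.cumsum(a, dim=1)[:, -1], a.sum(dim=1))
```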
III. Comparison operations
Comparison operations are usually element-wise, though some work along a specified dimension.
1. torch.eq(a, b): element-wise equality
a = torch.tensor([[1, 2, 3]])
b = torch.tensor([[1, 2, 4]])
print(a.eq(b))
# tensor([[ True, True, False]])
print(torch.eq(a, b))
# tensor([[ True, True, False]])
2. torch.equal(a, b): whether two tensors are identical
Returns True only if the two tensors have the same shape and all corresponding elements are equal.
a = torch.tensor([[1, 2, 3]])
b = torch.tensor([[1, 2, 4]])
print(a.equal(b)) # False
print(torch.equal(a, b)) # False
c = torch.tensor([[1, 2, 3]])
d = torch.tensor([[1, 2, 3]])
print(c.equal(d)) # True
print(torch.equal(c, d)) # True
3. torch.ge, torch.le, torch.gt, torch.lt: greater-or-equal / less-or-equal / greater-than / less-than
a = torch.tensor([[1, 2, 5]])
b = torch.tensor([[1, 3, 3]])
# greater than or equal
print(a.ge(b))         # tensor([[ True, False,  True]])
print(torch.ge(a, b))  # tensor([[ True, False,  True]])
# less than or equal
print(a.le(b))         # tensor([[ True,  True, False]])
print(torch.le(a, b))  # tensor([[ True,  True, False]])
# greater than
print(a.gt(b))         # tensor([[False, False,  True]])
print(torch.gt(a, b))  # tensor([[False, False,  True]])
# less than
print(a.lt(b))         # tensor([[False,  True, False]])
print(torch.lt(a, b))  # tensor([[False,  True, False]])
4. torch.max, torch.min: maximum / minimum
If a dim is specified, the indices of the extrema are returned as well.
a = torch.tensor([[1, 8, 3],
                  [2, 5, 3]])
print(a.max()) # tensor(8)
print(torch.min(a)) # tensor(1)
print(a.max(0))
# torch.return_types.max(
#     values=tensor([2, 8, 3]),
#     indices=tensor([1, 0, 0]))
print(torch.min(a, 0))
# torch.return_types.min(
#     values=tensor([1, 5, 3]),
#     indices=tensor([0, 1, 0]))
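The torch.return_types.max result is a named tuple, so it can be unpacked directly; a small sketch:

```python
import torch

a = torch.tensor([[1, 8, 3],
                  [2, 5, 3]])
# unpack values and indices in one step
# (equivalently: a.max(dim=0).values and a.max(dim=0).indices)
values, indices = a.max(0)
print(values)   # tensor([2, 8, 3])
print(indices)
```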
5. torch.topk: the k largest values along a given dimension
a = torch.tensor([[1, 8, 3],
                  [2, 5, 3]])
# the 2 largest values along dimension 1
print(a.topk(2, 1))
# torch.return_types.topk(
#     values=tensor([[8, 3],
#                    [5, 3]]),
#     indices=tensor([[1, 2],
#                     [1, 2]]))
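Passing largest=False flips topk to return the k smallest values instead (sorted ascending by default); a short sketch on the same tensor:

```python
import torch

a = torch.tensor([[1, 8, 3],
                  [2, 5, 3]])
# the 2 smallest values along dimension 1
values, indices = a.topk(2, dim=1, largest=False)
print(values)   # tensor([[1, 3],
                #         [2, 3]])
```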
IV. Matrix operations

1. torch.dot(a, b): dot product
a = torch.tensor([2, 3])
b = torch.tensor([3, 4])
print(torch.dot(a, b))
# tensor(18)
torch.dot only accepts two 1-D tensors; anything else raises an error, whereas NumPy's dot has no such restriction.
a = torch.tensor([[2, 3],
                  [3, 4]])
b = torch.tensor([[3, 4],
                  [1, 2]])
print(torch.dot(a, b))
# raises RuntimeError: 1D tensors expected
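For 2-D tensors, the counterpart of NumPy's dot is torch.matmul (or the @ operator); a quick sketch with the same matrices:

```python
import torch

a = torch.tensor([[2, 3],
                  [3, 4]])
b = torch.tensor([[3, 4],
                  [1, 2]])
# matmul (and the @ operator) handles 2-D matrix products
print(torch.matmul(a, b))
# tensor([[ 9, 14],
#         [13, 20]])
print(a @ b)  # same result
```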

2. torch.mm(a, b): matrix multiplication
3. torch.mul(a, b), a * b: element-wise multiplication
a = torch.tensor([[2, 3],
                  [3, 4]])
b = torch.tensor([[3, 4],
                  [1, 2]])
print(torch.mm(a, b))
# tensor([[ 9, 14],
#         [13, 20]])
print(torch.mul(a, b))
# tensor([[ 6, 12],
#         [ 3,  8]])
print(a * b)
# tensor([[ 6, 12],
#         [ 3,  8]])
4. torch.mv(a, b): matrix-vector multiplication
In torch.mv(a, b) the matrix a is the first argument and the vector b the second; swapping them raises an error.
a = torch.tensor([[1, 2, 3],
                  [2, 3, 4]])
b = torch.tensor([1, 2, 3])
print(torch.mv(a, b))
# tensor([14, 20])
5. tensor.T: transpose
a = torch.randint(10, (2, 3))
print(a)
# tensor([[6, 9, 6],
#         [0, 4, 8]])
print(a.T)
# tensor([[6, 0],
#         [9, 4],
#         [6, 8]])
6. torch.svd(a): singular value decomposition (SVD)
a = torch.randn(2, 3)
print(torch.svd(a))
# torch.return_types.svd(
#     U=tensor([[-0.5960,  0.8030],
#               [-0.8030, -0.5960]]),
#     S=tensor([3.3907, 1.0873]),
#     V=tensor([[ 0.2531,  0.8347],
#               [-0.7722,  0.4789],
#               [-0.5828, -0.2721]]))