[PyTorch] 1.6 Basic Tensor Operations
2022-07-29 13:26:00, by Enzo 想砸电脑
I. Element-wise Operations
1. torch.addcdiv(t, t1, t2, value=v)
Divides t1 by t2 elementwise, multiplies by v, then adds t: out = t + v * (t1 / t2). (Older PyTorch accepted the scale as the second positional argument, torch.addcdiv(t, v, t1, t2); that form is deprecated.)
import torch
t = torch.randn(1, 3)
t1 = torch.randn(3, 1)
t2 = torch.randn(1, 3)
a = t + 0.1 *(t1 / t2)
print(a)
# tensor([[-0.2492, -0.6960, 2.3492],
# [-0.1057, -0.3203, 2.2584],
# [-0.0774, -0.2463, 2.2405]])
b = torch.addcdiv(t, t1, t2, value=0.1)  # deprecated form: torch.addcdiv(t, 0.1, t1, t2)
print(b)
# tensor([[-0.2492, -0.6960, 2.3492],
# [-0.1057, -0.3203, 2.2584],
# [-0.0774, -0.2463, 2.2405]])
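The snippet above uses random inputs, so its printed values change on every run. A minimal sketch with fixed inputs (assuming a recent PyTorch where `value` is a keyword argument) makes the formula easy to verify by hand:

```python
import torch

# Fixed inputs so the result is easy to check by hand:
# out = t + 0.5 * (t1 / t2)
t  = torch.tensor([1.0, 1.0, 1.0])
t1 = torch.tensor([2.0, 4.0, 6.0])
t2 = torch.tensor([2.0, 2.0, 2.0])

out = torch.addcdiv(t, t1, t2, value=0.5)
print(out)  # tensor([1.5000, 2.0000, 2.5000])
```

Here t1 / t2 = [1, 2, 3], scaled by 0.5 and added to t, giving [1.5, 2.0, 2.5].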
2. torch.addcmul(t, t1, t2, value=v)
Multiplies t1 by t2 elementwise, multiplies by v, then adds t: out = t + v * (t1 * t2). (The positional form torch.addcmul(t, v, t1, t2) is likewise deprecated.)
import torch
t = torch.randn(1, 3)
t1 = torch.randn(3, 1)
t2 = torch.randn(1, 3)
a = t + 0.1 * t1 * t2
print(a)
# tensor([[-0.4994, 1.4826, -0.3377],
# [-0.4893, 1.4880, -0.3338],
# [-0.5957, 1.4314, -0.3756]])
b = torch.addcmul(t, t1, t2, value=0.1)  # deprecated form: torch.addcmul(t, 0.1, t1, t2)
print(b)
# tensor([[-0.4994, 1.4826, -0.3377],
# [-0.4893, 1.4880, -0.3338],
# [-0.5957, 1.4314, -0.3756]])
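As with addcdiv, a deterministic version (fixed values, keyword `value`) shows the fused computation directly:

```python
import torch

# Fixed inputs so the result is easy to check by hand:
# out = t + 0.1 * (t1 * t2)
t  = torch.tensor([1.0, 2.0, 3.0])
t1 = torch.tensor([2.0, 2.0, 2.0])
t2 = torch.tensor([3.0, 4.0, 5.0])

out = torch.addcmul(t, t1, t2, value=0.1)
print(out)  # tensor([1.6000, 2.8000, 4.0000])
```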
3. torch.clamp(input, min, max)
Clamps every element of the tensor into the range [min, max].
import torch
x = torch.arange(1, 8)
y = torch.clamp(x, 2, 5)
print(y)
# tensor([2, 2, 3, 4, 5, 5, 5])
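Either bound can also be given alone; the other defaults to None. A short sketch (a lower bound of 0 is exactly the ReLU activation):

```python
import torch

x = torch.tensor([-3.0, -1.0, 0.0, 2.0, 7.0])

# Upper bound only: min defaults to None
capped = torch.clamp(x, max=5.0)
print(capped)  # tensor([-3., -1.,  0.,  2.,  5.])

# Lower bound only; clamping at 0 is exactly ReLU
relu = torch.clamp(x, min=0.0)
print(relu)    # tensor([0., 0., 0., 2., 7.])
```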
4. torch.ceil(input), torch.floor(input)
torch.ceil(input): round up (toward +infinity)
torch.floor(input): round down (toward -infinity)
import torch
torch.manual_seed(8)
x = torch.randn(3) * 10
y = torch.ceil(x)
z = torch.floor(x)
print(x)
print(y)
print(z)
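Since the random example above prints different values each run, a fixed example makes the direction of each rounding mode explicit. `torch.trunc` (round toward zero) is a related function worth knowing alongside ceil and floor:

```python
import torch

x = torch.tensor([-1.7, -0.5, 0.5, 1.7])
print(torch.ceil(x))   # tensor([-1., -0.,  1.,  2.])
print(torch.floor(x))  # tensor([-2., -1.,  0.,  1.])
print(torch.trunc(x))  # tensor([-1., -0.,  0.,  1.])  (drops the fractional part)
```

Note how the three modes differ only for negative inputs and values between integers.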
II. Reduction Operations
1. torch.cumprod(t, dim)
Cumulative product of t along the given dimension (cumprod = cumulative product).
import torch
a = torch.arange(1, 10).reshape(3, 3)
print(a)
b_x = torch.cumprod(a, dim=0)  # accumulate down each column (along dim 0)
print("\ncumulative product:\n", b_x)
b_y = torch.cumprod(a, dim=1)  # accumulate across each row (along dim 1)
print("\ncumulative product:\n", b_y)
# tensor([[1, 2, 3],
# [4, 5, 6],
# [7, 8, 9]])
#
# cumulative product:
# tensor([[ 1, 2, 3],
# [ 4, 10, 18],
# [ 28, 80, 162]])
#
# cumulative product:
# tensor([[ 1, 2, 6],
# [ 4, 20, 120],
# [ 7, 56, 504]])
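A compact way to see what cumprod does on a 1-D tensor: the running product of 1..5 yields the factorials 1!, 2!, ..., 5!:

```python
import torch

n = torch.arange(1, 6)
fact = torch.cumprod(n, dim=0)
print(fact)  # tensor([  1,   2,   6,  24, 120])
```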
2. torch.cumsum(t, dim)
Cumulative sum along the given dimension; unlike sum, which collapses the dimension, cumsum keeps it and returns the running partial sums.
import torch
a = torch.linspace(0, 10, 6).view(2, 3)
b = a.sum(dim=0)
c = torch.cumsum(a, dim=0)
print(a)
print(b)
print(c)
# tensor([[ 0., 2., 4.],
# [ 6., 8., 10.]])
#
# tensor([ 6., 10., 14.])
#
# tensor([[ 0., 2., 4.],
# [ 6., 10., 14.]])
d = a.sum(dim=1)
e = torch.cumsum(a, dim=1)
print(d)
print(e)
# tensor([ 6., 24.])
#
# tensor([[ 0., 2., 6.],
# [ 6., 14., 24.]])
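One common practical use of cumsum, sketched here with made-up probabilities: turning a probability vector into a cumulative distribution (CDF), e.g. for sampling:

```python
import torch

p = torch.tensor([0.1, 0.2, 0.3, 0.4])
cdf = torch.cumsum(p, dim=0)
print(cdf)  # tensor([0.1000, 0.3000, 0.6000, 1.0000])
```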
III. Comparison Operations
Comparison operations are generally performed elementwise; some operate along a specified dimension.
1. torch.eq(a, b): elementwise equality
a = torch.tensor([[1, 2, 3]])
b = torch.tensor([[1, 2, 4]])
print(a.eq(b))
# tensor([[ True, True, False]])
print(torch.eq(a, b))
# tensor([[ True, True, False]])
2. torch.equal(a, b): whether two tensors are identical
Returns True only if the two tensors have the same shape and all corresponding elements are equal.
a = torch.tensor([[1, 2, 3]])
b = torch.tensor([[1, 2, 4]])
print(a.equal(b)) # False
print(torch.equal(a, b)) # False
c = torch.tensor([[1, 2, 3]])
d = torch.tensor([[1, 2, 3]])
print(c.equal(d)) # True
print(torch.equal(c, d)) # True
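The shape requirement is the key difference from eq. A small sketch with the same values in different shapes:

```python
import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([[1, 2, 3]])  # same values, different shape

# equal checks shape as well as values, so this is False
print(torch.equal(a, b))  # False
# eq broadcasts and compares elementwise
print(torch.eq(a, b))     # tensor([[True, True, True]])
```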
3. torch.ge, torch.le, torch.gt, torch.lt: greater-or-equal / less-or-equal / greater-than / less-than
a = torch.tensor([[1, 2, 5]])
b = torch.tensor([[1, 3, 3]])
# greater than or equal
print(a.ge(b)) # tensor([[ True, False, True]])
print(torch.ge(a, b)) # tensor([[ True, False, True]])
# less than or equal
print(a.le(b)) # tensor([[ True, True, False]])
print(torch.le(a, b)) # tensor([[ True, True, False]])
# greater than
print(a.gt(b)) # tensor([[False, False, True]])
print(torch.gt(a, b)) # tensor([[False, False, True]])
# less than
print(a.lt(b)) # tensor([[False, True, False]])
print(torch.lt(a, b)) # tensor([[False, True, False]])
4. torch.max, torch.min: maximum / minimum
If a dim argument is given, the indices of the extrema are returned alongside the values.
a = torch.tensor([[1, 8, 3],
[2, 5, 3]])
print(a.max()) # tensor(8)
print(torch.min(a)) # tensor(1)
print(a.max(0))
# torch.return_types.max(
# values=tensor([2, 8, 3]),
# indices=tensor([1, 0, 0]))
print(torch.min(a, 0))
# torch.return_types.min(
# values=tensor([1, 5, 3]),
# indices=tensor([0, 1, 0]))
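The return value with a dim is a named tuple, so it can be unpacked directly; when only the positions are needed, argmax/argmin return the indices alone:

```python
import torch

a = torch.tensor([[1, 8, 3],
                  [2, 5, 3]])

# max with a dim returns a (values, indices) named tuple
values, indices = a.max(dim=1)
print(values)           # tensor([8, 5])
print(indices)          # tensor([1, 1])
print(a.argmax(dim=1))  # indices only: tensor([1, 1])
```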
5. torch.topk: the k largest values along a given dimension
a = torch.tensor([[1, 8, 3],
[2, 5, 3]])
# the 2 largest values along dim 1
print(a.topk(2, 1))
# torch.return_types.topk(
# values=tensor([[8, 3],
# [5, 3]]),
# indices=tensor([[1, 2],
# [1, 2]]))
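topk can also fetch the smallest values: passing largest=False flips the selection, with results returned in ascending order by default:

```python
import torch

a = torch.tensor([[1, 8, 3],
                  [2, 5, 3]])

# largest=False returns the k smallest values instead
values, indices = a.topk(2, dim=1, largest=False)
print(values)   # tensor([[1, 3],
                #         [2, 3]])
print(indices)  # tensor([[0, 2],
                #         [0, 2]])
```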
IV. Matrix Operations
1. torch.dot(a, b): dot product
a = torch.tensor([2, 3])
b = torch.tensor([3, 4])
print(torch.dot(a, b))
# tensor(18)
torch.dot only works on two 1-D tensors and raises an error otherwise; NumPy's dot has no such restriction.
a = torch.tensor([[2, 3],
[3, 4]])
b = torch.tensor([[3, 4],
[1, 2]])
print(torch.dot(a, b))  # raises RuntimeError: dot expects 1-D tensors
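For 1-D tensors, the dot product is simply the sum of the elementwise product, which is a handy way to remember what it computes:

```python
import torch

a = torch.tensor([2, 3])
b = torch.tensor([3, 4])

# dot is the sum of the elementwise product: 2*3 + 3*4 = 18
print(torch.dot(a, b))  # tensor(18)
print((a * b).sum())    # tensor(18), the same value
```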
2. torch.mm(a, b): matrix multiplication
3. torch.mul(a, b), a * b: elementwise multiplication
a = torch.tensor([[2, 3],
[3, 4]])
b = torch.tensor([[3, 4],
[1, 2]])
print(torch.mm(a, b))
# tensor([[ 9, 14],
# [13, 20]])
print(torch.mul(a, b))
# tensor([[ 6, 12],
# [ 3, 8]])
print(a * b)
# tensor([[ 6, 12],
# [ 3, 8]])
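torch.mm only accepts 2-D inputs. torch.matmul (and the @ operator) generalizes it, broadcasting over leading batch dimensions; a quick shape sketch:

```python
import torch

# matmul multiplies the 4 matrix pairs in one call
a = torch.ones(4, 2, 3)  # batch of 4 matrices of shape (2, 3)
b = torch.ones(4, 3, 5)  # batch of 4 matrices of shape (3, 5)
out = torch.matmul(a, b)
print(out.shape)  # torch.Size([4, 2, 5])
```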
4. torch.mv(a, b): matrix-vector multiplication
torch.mv(a, b) takes the matrix a as the first argument and the vector b as the second; the order cannot be swapped, or an error is raised.
a = torch.tensor([[1, 2, 3],
[2, 3, 4]])
b = torch.tensor([1, 2, 3])
print(torch.mv(a, b))
# tensor([14, 20])
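For these shapes the @ operator (torch.matmul) computes the same matrix-vector product, which can read more naturally:

```python
import torch

a = torch.tensor([[1, 2, 3],
                  [2, 3, 4]])
b = torch.tensor([1, 2, 3])

# mv and @ agree for a 2-D matrix times a 1-D vector
print(torch.mv(a, b))  # tensor([14, 20])
print(a @ b)           # tensor([14, 20])
```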
5. tensor.T: transpose
a = torch.randint(10, (2, 3))
print(a)
# tensor([[6, 9, 6],
# [0, 4, 8]])
print(a.T)
# tensor([[6, 0],
# [9, 4],
# [6, 8]])
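On a 2-D tensor, .T is shorthand for transpose(0, 1); a small deterministic check:

```python
import torch

a = torch.tensor([[6, 9, 6],
                  [0, 4, 8]])

# .T swaps the two dimensions of a 2-D tensor
print(torch.equal(a.T, a.transpose(0, 1)))  # True
print(a.T.shape)  # torch.Size([3, 2])
```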
6. torch.svd(a): singular value decomposition (SVD)
a = torch.randn(2, 3)
print(torch.svd(a))
# torch.return_types.svd(
# U=tensor([[-0.5960, 0.8030],
# [-0.8030, -0.5960]]),
# S=tensor([3.3907, 1.0873]),
# V=tensor([[ 0.2531, 0.8347],
# [-0.7722, 0.4789],
# [-0.5828, -0.2721]]))
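A useful sanity check on the factors is rebuilding the original matrix as U @ diag(S) @ V^T. The sketch below assumes float32 tolerances; note that recent PyTorch deprecates torch.svd in favor of torch.linalg.svd, which returns V transposed (Vh) instead of V:

```python
import torch

torch.manual_seed(0)
a = torch.randn(2, 3)

# torch.svd returns the reduced factorization: U (2x2), S (2,), V (3x2)
U, S, V = torch.svd(a)
recon = U @ torch.diag(S) @ V.T
print(torch.allclose(recon, a, atol=1e-5))  # True
```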