A Comparison and Summary of Basic Tensor Operations in TensorFlow 2 and PyTorch
2022-08-05 05:38:00 【takedachia】
Being a bit of a completionist, I've put together the differences between TensorFlow 2 and PyTorch in basic Tensor operations.
TensorFlow version: 2.0
PyTorch version: 1.2
Code for this article: my GitHub
import tensorflow as tf
import torch
torch.set_printoptions(precision=16)
import numpy as np
Basic Tensor Operations, Part 1
Creating a scalar
# Create a numeric scalar: tf.constant, torch.tensor
b = tf.constant(1.2)
a = torch.tensor(1.3)
display(b)
display(a)
Creating a tensor (from a list)
You can see that a TensorFlow tensor prints rich information, including the shape, the dtype, and the underlying numpy data.
b2 = tf.constant([1,2,3,4,5.5])
a2 = torch.tensor([1,2,3,4,5.5])
b2, a2
Generating a random tensor (with a given shape, from a standard normal distribution)
tf.random.normal
torch.randn
b3 = tf.random.normal([2,3,2])
a3 = torch.normal(mean=0, std=1, size=(3,4,2))
# a3 = torch.randn(3,4,2) # same as above
b3, a3, a3.dtype, a3.shape
Creating a string tensor
The string dtype is unique to TensorFlow.
# Strings
b4 = tf.constant('hello world')
# PyTorch has no native string type
b4
# With the string dtype you can use the various string methods; see: https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/strings
Creating a boolean tensor
p5 = True
b5 = tf.constant(True)
a5 = torch.tensor(True)
# Can be combined directly with Python booleans in logical operations
b5, b5==p5, a5, a5.type(), p5==a5
dtype and numeric precision
# Numeric types and precision
b6 = tf.constant(np.pi, dtype=tf.float32)
b7 = tf.constant(np.pi, dtype=tf.half) # tf.half is the same as tf.float16
b8 = tf.constant(np.pi, dtype=tf.double) # tf.double is the same as tf.float64
a6 = torch.tensor(np.pi, dtype=torch.float32)
a7 = torch.tensor(np.pi, dtype=torch.half) # torch.half is the same as torch.float16
a8 = torch.tensor(np.pi, dtype=torch.double) # torch.double is the same as torch.float64
b6, b7, b8, a6, a7, a8, np.pi
dtype conversion
TensorFlow: cast
PyTorch: to
# Data type conversion: tf.cast, Tensor.to
b = tf.constant(np.pi, dtype=tf.double)
a = torch.tensor(np.pi, dtype=torch.double)
b9 = tf.cast(b, dtype=tf.float32)
a9 = a.to(dtype=torch.float32)
b9, a9
Making a tensor optimizable (enabling gradients)
TensorFlow makes a tensor optimizable by wrapping it in a tf.Variable.
After converting to a Variable, the trainable attribute tells you whether it is trainable.
PyTorch makes a tensor optimizable by enabling gradient tracking on it.
The requires_grad attribute tells you whether backpropagation through it is possible.
# Convert to an optimizable variable (TensorFlow)
b10 = tf.constant([1,2,3,5.5])
b10_grad = tf.Variable(b10)
b11 = tf.Variable([3,4.5])
# After converting to a Variable, check trainable to see whether it is trainable
display(b10, b10_grad, b10_grad.trainable, b11)
print('-'*20)
# Enable gradients (PyTorch)
a10 = torch.tensor([1,2,3,5.5])
a10.requires_grad_(True)
a11 = torch.tensor([1,2,8.99], requires_grad=True)
display(a10, a10.requires_grad, a11)
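As a quick aside (a minimal sketch of my own, not part of the original code), here is how these two mechanisms feed into gradient computation:
# PyTorch: autograd fills in .grad after backward()
x = torch.tensor([2.0, 3.0], requires_grad=True)
y = (x ** 2).sum()   # y = x0^2 + x1^2
y.backward()
print(x.grad)        # tensor([4., 6.]), i.e. dy/dx = 2x
# TensorFlow: GradientTape watches tf.Variable objects by default
v = tf.Variable([2.0, 3.0])
with tf.GradientTape() as tape:
    loss = tf.reduce_sum(v ** 2)
print(tape.gradient(loss, v))  # [4. 6.]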
Basic Tensor Operations, Part 2
Ways to create a Tensor
You can pass in a list or an np.array.
TensorFlow: convert_to_tensor(), constant(), Variable()
PyTorch: tensor()
Note: a tensor created from an np.array of floats defaults to dtype float64 (true for both TensorFlow and PyTorch).
# Create from a list or np.array
c1 = [1, 99, 2.3]
c2 = np.array([[3,4,5],[2.3,5.6,9.8],[2,3,2]])
# Note: a tensor created from an np.array defaults to dtype float64
b1 = tf.convert_to_tensor(c1)
b2 = tf.convert_to_tensor(c2)
b3 = tf.constant(c2)
b4 = tf.Variable(c2)
a1 = torch.tensor(c1)
a2 = torch.tensor(c2)
b1, b2, b3, b4, a1, a2
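A small sketch of my own (the names c3/b5_f32/a5_f32 are mine): if you want float32 instead of the float64 default when creating from an np.array, pass dtype explicitly:
c3 = np.array([1.0, 2.0])
b5_f32 = tf.convert_to_tensor(c3, dtype=tf.float32)
a5_f32 = torch.tensor(c3, dtype=torch.float32)
print(b5_f32.dtype, a5_f32.dtype)  # float32 in both frameworks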
Generating various special tensors (TensorFlow)
All zeros, all ones, filled with a given value (fill), random normal, random uniform, and generated sequences
b3 = tf.zeros((3,4))
b4 = tf.ones((4,4), dtype=tf.float64)
b5 = tf.zeros_like(b2)
b6 = tf.fill((2,2), 8)
b7 = tf.random.normal((3,3), mean=0, stddev=1) # normally distributed values with the given shape; defaults are mean=0, stddev=1
b7_uni = tf.random.uniform((3,3)) # uniformly distributed values over the default interval [0, 1)
# Generate a sequence
b8 = tf.range(10)
b9 = tf.range(3,12, delta=2)
b3, b4, b5, b6, b7, b7_uni, b8, b9
Generating various special tensors (PyTorch)
All zeros, all ones, filled with a given value (fill_), random normal, random uniform, and generated sequences
a3 = torch.zeros((3,4))
a4 = torch.ones((4,4), dtype=torch.double)
a5 = torch.zeros_like(a2)
a6 = torch.zeros((2,2))
a6.fill_(8)
a7 = torch.normal(mean=0, std=1, size=(3,3)) # normally distributed values with the given shape; defaults are mean=0, std=1
a7_2 = torch.randn(3,3) # same effect; directly draws a tensor of the given shape from the standard normal distribution (mean=0, std=1)
a7_uni = torch.rand(3,3) # uniformly distributed values over the default interval [0, 1)
# Generate a sequence
a8 = torch.range(0, 9) # PyTorch requires both start and end, and the interval is closed (end inclusive)
a9 = torch.range(3, 12, step=2)
a3, a4, a5, a6, a7, a7_2, a7_uni, a8, a9
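One hedged note of mine: torch.range is deprecated in favor of torch.arange, which, like tf.range, uses a half-open interval:
a8_alt = torch.arange(10)        # 0..9, end exclusive
a9_alt = torch.arange(3, 12, 2)  # 3, 5, 7, 9, 11
a8_alt, a9_alt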
Index access (identical in PyTorch and TensorFlow)
imgs_tf = tf.random.normal((2, 2, 4, 3))
print(imgs_tf)
print('-'*20)
print(imgs_tf[0][1][3][1])
print('-'*20)
print(imgs_tf[0,1,3,1]) # Equivalent to above
print('-'*10, 'Take the first image', '-'*10)
print(imgs_tf[0])
# print(imgs_tf[0, ::]) # Equivalent to above
print('-'*10, 'Take one image, sampling every other row and column', '-'*10)
print(imgs_tf[0, ::2, ::2, ::])
Reshaping a tensor / changing the view (TensorFlow)
Commonly: reshape
imgs_tf2 = tf.random.normal((2, 2, 2, 2))
img1 = tf.reshape(imgs_tf2, (2,2,4))
print(img1)
print()
img1 = tf.reshape(imgs_tf2, (2,-1)) # pass -1 to let the framework infer that dimension's length
print(img1)
print()
img1 = tf.reshape(imgs_tf2, (16,))
print(img1)
Reshaping a tensor / changing the view (PyTorch)
Commonly: view
imgs_tf3 = torch.randn(2, 2, 2, 2)
img1 = imgs_tf3.view(2,2,4)
print(img1)
print()
img1 = imgs_tf3.view(2,-1) # pass -1 to let the framework infer that dimension's length
print(img1)
print()
img1 = imgs_tf3.view(16,)
print(img1)
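One caveat of mine, not from the original: view() requires contiguous memory, so after an operation such as transpose you need contiguous().view(...) or reshape(...):
t = torch.randn(2, 3).transpose(0, 1)  # a non-contiguous view
# t.view(6)                            # would raise a RuntimeError
print(t.contiguous().view(6).shape)    # torch.Size([6])
print(t.reshape(6).shape)              # reshape copies when needed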
Adding, removing, and swapping dimensions (TensorFlow)
tf.expand_dims, tf.squeeze, tf.transpose
# Add a dimension: tf.expand_dims
imgs_tf2 = tf.random.normal((3, 3, 3, 3))
img2 = tf.expand_dims(imgs_tf2, -1) # insert a new axis at the end
print(img2.shape)
img2 = tf.expand_dims(imgs_tf2, 4) # insert a new axis at position 4; same as above
print(img2.shape)
# Remove a dimension: tf.squeeze
# Only dimensions of length 1 can be removed
img2 = tf.squeeze(img2, -1)
print(img2.shape)
# Swap dimensions: tf.transpose
imgs_tf2 = tf.random.normal((3, 4, 5, 3))
img2 = tf.transpose(imgs_tf2, perm=[0, 2, 1, 3]) # swap the 2nd and 3rd dimensions by passing the permuted index list via perm
print(img2.shape)
Adding, removing, and swapping dimensions (PyTorch)
torch.unsqueeze, torch.squeeze, torch.transpose, permute
# Add a dimension: torch.unsqueeze
imgs_tf3 = torch.randn(3, 3, 3, 3)
img3 = torch.unsqueeze(imgs_tf3, -1) # insert a new axis at the end
print(img3.shape)
img3 = torch.unsqueeze(imgs_tf3, 4) # insert a new axis at position 4; same as above
print(img3.shape)
# Remove a dimension: torch.squeeze
# Only dimensions of length 1 can be removed
img3 = torch.squeeze(img3, -1)
print(img3.shape)
# Swap dimensions: torch.transpose, permute
imgs_tf3 = torch.randn(3, 4, 5, 3)
img3 = torch.transpose(imgs_tf3, 1, 2) # swap the 2nd and 3rd dimensions by passing the two dim indices directly (no perm list here)
print(img3.shape)
# Method 2
img3 = imgs_tf3.permute([0, 2, 1, 3]) # swap the 2nd and 3rd dimensions by passing the full permuted index list to permute
print(img3.shape)
Advanced Tensor Operations
Merging tensors: concatenation and stacking (TensorFlow)
tf.concat, tf.stack
# Merging tensors: concatenation with tf.concat
a = tf.random.normal([4, 35, 8]) # say the data is: 4 classes, 35 students each, grades for 8 subjects
print(a.shape)
b = tf.random.normal([2, 35, 8]) # say this is data from another cohort
print(b.shape)
# Concatenate along the first dimension: put all the classes together
c = tf.concat([a, b], axis=0)
print(c.shape)
print('-'*20)
# Merging tensors: stacking with tf.stack
# Stacking creates a new dimension and combines the data along it. The tensors being stacked must all have the same shape.
a = tf.random.normal([28,28,3])
b = tf.random.normal([28,28,3])
c = tf.stack([a, b], axis=0) # stack the two images to form an image set
print(c.shape)
Merging tensors: concatenation and stacking (PyTorch)
torch.cat, torch.stack
# Merging tensors: concatenation with torch.cat
a2 = torch.randn(4, 35, 8) # say the data is: 4 classes, 35 students each, grades for 8 subjects
print(a2.shape)
b2 = torch.randn(2, 35, 8) # say this is data from another cohort
print(b2.shape)
# Concatenate along the first dimension: put all the classes together
c2 = torch.cat([a2, b2], dim=0)
print(c2.shape)
print('-'*20)
# Merging tensors: stacking with torch.stack
# Stacking creates a new dimension and combines the data along it. The tensors being stacked must all have the same shape.
a2 = torch.randn(28,28,3)
b2 = torch.randn(28,28,3)
c2 = torch.stack([a2, b2], dim=0) # stack the two images to form an image set
print(c2.shape)
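An aside of my own: stack is equivalent to unsqueezing each tensor and concatenating, which makes the "new dimension" behavior explicit:
c3 = torch.cat([a2.unsqueeze(0), b2.unsqueeze(0)], dim=0)
print(torch.equal(c2, c3))  # True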
Splitting tensors (TensorFlow)
tf.split, tf.unstack
# Splitting tensors: tf.split, tf.unstack
# Splitting breaks one tensor into several. You must specify a split scheme (num_or_size_splits: pass a number to split into that many equal parts, or pass a list to split according to the list)
a = tf.random.normal([4, 28, 28, 3]) # a is an image set with 4 images
result = tf.split(a, 2, axis=0) # split into two equal parts
print(type(result)) # returns a list whose elements are the two tensors
print(result[0].shape) # the first part is an image set with 2 images
print('-'*20)
a = tf.random.normal([4, 28, 28, 3]) # a is an image set with 4 images
result = tf.split(a, [1,2,1,0], axis=0) # specify the split scheme by passing a list
print(len(result)) # split into 4 parts
print(result[-1].shape) # the last part has no images
print(result[-1]) # it is an empty tensor that only has a shape
print(result[0][0].shape) # the image in the first part is accessible
print('-'*20)
# To split into equal parts of exactly 1 element each, use unstack. After splitting, the split dimension disappears
a = tf.random.normal([4, 28, 28, 3])
result = tf.unstack(a, axis=3) # split along the last dimension, i.e. the channels
print(result[0].shape) # after splitting, the split dimension is gone
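A quick sanity check of my own: tf.stack inverts tf.unstack, so restacking along the same axis restores the original shape:
restored = tf.stack(result, axis=3)
print(restored.shape)  # (4, 28, 28, 3)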
Splitting tensors (PyTorch)
torch.chunk, torch.split, torch.unbind
# Splitting tensors: torch.chunk, torch.split, torch.unbind
# Splitting breaks one tensor into several. You must specify a split scheme (pass a number or pass a list)
# Passing a list specifies the split scheme according to the list
a2 = torch.randn(4, 28, 28, 3) # a2 is an image set with 4 images
result = torch.chunk(a2, 4, dim=0) # split into 4 parts
print(type(result)) # returns a tuple
print(len(result))
print(result[0].shape) # the first part is an image set with 1 image
print('-'*20)
result = torch.split(a2, 1, dim=0) # note!! here the second argument is the number of elements per part, NOT the number of parts!
print(len(result))
print(result[0].shape) # the first part is an image set with 1 image
print('-'*20)
a2 = torch.randn(4, 28, 28, 3) # a2 is an image set with 4 images
result = torch.split(a2, [1,2,1,0], dim=0) # specify the split scheme by passing a list
print(len(result)) # split into 4 parts
print(result[-1].shape) # the last part has no images
print(result[-1]) # it is an empty tensor that only has a shape
print('-'*20)
# To split into equal parts of exactly 1 element each, use unbind.
# Note: after splitting, the split dimension disappears.
a2 = torch.randn(4, 28, 28, 3)
result = torch.unbind(a2, dim=3) # split along the last dimension, i.e. the channels
print(len(result))
print(result[0].shape) # after splitting, the split dimension is gone
Tensor statistics (TensorFlow)
data = np.random.random((3,4,5))
# 1. Norms
b = tf.constant(data, dtype=tf.float32)
l1_b = tf.norm(b, 1)
print('L1 norm:', l1_b)
l2_b = tf.norm(b, 2)
print('L2 norm:', l2_b)
print('-'*20)
# 2. Max, min, mean, and sum along a dimension
print('Max:', tf.reduce_max(b, axis=0)) # without axis, returns the largest value in the whole tensor
print('Index of the max:', tf.argmax(b)) # default axis=0; returns the positions (indices) of the max along that dimension
print('Min:', tf.reduce_min(b, axis=0))
print('Index of the min:', tf.argmin(b))
print('Mean (axis can be specified):', tf.reduce_mean(b)) # without axis, the mean over all elements
print('Sum (axis can be specified):', tf.reduce_sum(b)) # without axis, the sum over all elements
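A short sketch of mine showing the axis argument on these reductions (b is the (3,4,5) tensor above):
print(tf.reduce_mean(b, axis=1).shape)      # (3, 5)
print(tf.reduce_sum(b, axis=[1, 2]).shape)  # (3,), reducing over two axes at once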
Tensor statistics (PyTorch)
data = np.random.random((3,4,5))
# 1. Norms
a = torch.tensor(data, dtype=torch.float32)
l1_a = torch.norm(a, 1)
print('L1 norm:', l1_a)
l2_a = torch.norm(a, 2)
print('L2 norm:', l2_a)
print('-'*20)
# 2. Max, min, mean, and sum along a dimension
print('Max and indices:', torch.max(a, dim=0)) # without dim, returns only the largest value in the tensor
# With dim specified, max returns two tensors: the matrix of max values and the indices where they occur
print('Min and indices:', torch.min(a, dim=0)) # without dim, returns only the smallest value in the tensor
# With dim specified, min likewise returns the matrix of min values and the indices where they occur
print('Index of the max on its own (argmin works the same way for the min):', torch.argmax(a, dim=0))
print('Mean (dim can be specified):', torch.mean(a)) # without dim, the mean over all elements
print('Sum (dim can be specified):', torch.sum(a)) # without dim, the sum over all elements
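And the PyTorch counterpart, also a sketch of mine (a is the (3,4,5) tensor above):
print(torch.mean(a, dim=1).shape)  # torch.Size([3, 5])
print(torch.sum(a, dim=1).shape)   # torch.Size([3, 5])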
Padding (expansion), tiling, and clipping of Tensors
Padding (expansion) generally means padding with zeros. Padding is what lets a convolution kernel slide all the way to positions near the image edge.
Tiling duplicates data along specified dimensions; you can also think of it as repeating (tile).
Padding / expansion (TensorFlow)
tf.pad
b = tf.random.normal((2,28,28,3))
# Expand to [2,32,32,3]
b_pad = tf.pad(b, [[0,0], [2,2], [2,2], [0,0]])
# The second argument of pad is a nested list giving the padding scheme per dimension; e.g. [2,2] for the second dimension means pad 2 on the left and 2 on the right.
print(b_pad.shape)
print(b_pad) # you can see the default padding value is 0
Padding / expansion (PyTorch)
torch.nn.functional.pad
a = torch.randn(2,28,28,3)
# Expand to [2,32,32,3]
# In PyTorch, use: torch.nn.functional.pad(input, pad, mode='constant', value=0)
# Docs: https://pytorch.org/docs/1.2.0/nn.functional.html?highlight=pad#torch.nn.functional.pad
# The second argument of pad is a tuple giving the padding scheme per dimension.
import torch.nn.functional as F
a_pad = F.pad(a, (0,0, 2,2, 2,2))
# In the second argument, the padding amounts come in pairs, one pair per dimension; read left to right, they apply to the last dimension, then the second-to-last, and so on.
# So (0,0, 2,2, 2,2) means: pad the last dimension by 0 on the left and 0 on the right; the second-to-last by 2 and 2; the third-to-last by 2 and 2
print(a_pad.shape)
# Pads with 0 by default
Tiling (TensorFlow)
tf.tile
# Tiling duplicates the data along the specified dimensions; note the copies genuinely enlarge the data
b = tf.random.normal([2,5])
print(b)
b = tf.tile(b, multiples=[2, 1]) # the second argument, multiples, gives the replication factor per dimension; here the first dimension is copied 2x and the second 1x (unchanged)
print(b)
Tiling (PyTorch)
tensor.repeat
# Tiling duplicates the data along the specified dimensions; note the copies genuinely enlarge the data
a = torch.randn(2,5)
print(a)
# PyTorch 1.2 has no tile method; use repeat instead
a = a.repeat(2,1) # call repeat directly on the tensor; the arguments give the replication factor per dimension, here the first dimension is copied 2x and the second 1x (unchanged)
print(a)
Clipping (TensorFlow)
Clipping constrains the magnitude of values.
Lower clipping: maximum; upper clipping: minimum. The naming is a bit counterintuitive.
# Lower clipping: maximum, upper clipping: minimum (a bit counterintuitive)
b = tf.range(10)
print(b)
print('Lower clipping:', tf.maximum(b, 2)) # the lower bound uses maximum; each value becomes at least 2
print('Upper clipping:', tf.minimum(b, 7)) # the upper bound uses minimum; each value becomes at most 7
print('Both bounds:', tf.clip_by_value(b, 3, 7))
Clipping (PyTorch)
torch.clamp
# Clipping: torch.clamp
a = torch.range(0, 10)
print(a)
print('Lower clipping:', torch.clamp(a, min=2)) # setting only min clips from below
print('Upper clipping:', torch.clamp(a, max=7)) # setting only max clips from above
print('Both bounds:', torch.clamp(a, min=2, max=7))
Indexing and gathering data
# Build some artificial random data
data = np.random.uniform(size=[2,10,4])
data_2 = np.random.uniform(size=[3,3,4])
data_3 = np.random.uniform(size=[2,3,4])
Gathering data by index (TensorFlow)
Slicing can only extract regular (contiguous) data; the gather method can extract irregular, non-contiguous data.
tf.gather(x, index scheme, axis)
# Suppose there are 2 classes, 10 students per class, and 4 grades per student:
b = tf.constant(data)
print(b)
print('*'*40)
print('Grades for the first class:', tf.gather(b, [0], axis=0)) # pass a list
print('Grades for students number 1, 3, 5, 7, 9:', tf.gather(b, [0,2,4,6,8], axis=1)) # pass a list
print('*'*40)
b2 = tf.gather(b, [1], axis=0)
b3 = tf.gather(b2, [0,2,7], axis=1)
b4 = tf.gather(b3, [0,3], axis=2)
print('Class 2, students 1/3/8, subjects 1/4:', b4)
print('Class 2 / student 1 / subject 1, and class 1 / student 2 / subject 4:', tf.gather_nd(b, [[1,0,0],[0,1,3]])) # in gather_nd the second argument is the coordinate scheme, which can be a nested list
Gathering data by index (PyTorch)
For how PyTorch's gather function works, see my other article.
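For completeness, a minimal gather sketch of my own (toy data, not from that article): along dim=1, out[i][j] = t[i][idx[i][j]]:
t = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
idx = torch.tensor([[2, 0],
                    [1, 1]])
print(torch.gather(t, dim=1, index=idx))  # tensor([[3, 1], [5, 5]])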
That said, gather is quite unintuitive to reason about; we can gather data by index another way:
torch.index_select()
# Suppose there are 2 classes, 10 students per class, and 4 grades per student:
a = torch.tensor(data)
print(a)
print('*'*40)
# 我们使用: torch.index_select(input, dim, index, out=None) → Tensor
# Note: the index argument must be a tensor
index1 = torch.tensor([0])
print('Grades for the first class:', torch.index_select(a, dim=0, index=index1))
index2 = torch.tensor([0,2,4,6,8])
print('Grades for students number 1, 3, 5, 7, 9:', torch.index_select(a, dim=1, index=index2))
print('*'*40)
a2 = torch.index_select(a, dim=0, index=torch.tensor([1]))
a3 = torch.index_select(a2, dim=1, index=torch.tensor([0,2,7]))
a4 = torch.index_select(a3, dim=2, index=torch.tensor([0,3]))
print('Class 2, students 1/3/8, subjects 1/4:', a4)
# As for extracting class 2 / student 1 / subject 1 together with class 1 / student 2 / subject 4: here it is simpler to just use plain slice indexing.
Gathering data with a mask / boolean indexing (TensorFlow)
tf.boolean_mask(x, mask, axis)
# tf.boolean_mask(x, mask, axis)
# The mask can be a flat list or a nested-list matrix
b = tf.constant(data_2) # suppose 3 classes, 3 students per class, 4 grades per student
print('Data for classes 1 and 3:', tf.boolean_mask(b, [True, False, True], axis=0))
print()
b = tf.constant(data_3)
print('Class 1, students 1 and 2; class 2, students 2 and 3:',
tf.boolean_mask(b, [[True, True, False], [False, True, True]]))
print('Class 1, student 2, subject 3:',
tf.boolean_mask(b, [[[False, False, False, False],
[False, False, True, False],
[False, False, False, False]],
[[False, False, False, False],
[False, False, False, False],
[False, False, False, False]]]))
Gathering data with a mask / boolean indexing (PyTorch)
We can use PyTorch's masked_select() function.
However, it only accepts a full element-wise mask; you cannot specify a dimension, so the mask tensor has to be expanded by hand.
Also, the returned tensor is flattened to one dimension, so this method is not very convenient.
# torch.masked_select(input, mask, out=None) → Tensor
# The mask is a tensor with the same shape as the original
a = torch.tensor(data_2) # suppose 3 classes, 3 students per class, 4 grades per student
# Since PyTorch's masked_select only accepts a full element-wise mask and no dimension, the mask tensor must be expanded by hand.
# Also, the returned tensor is flattened to one dimension.
index1 = torch.tensor([True, False, True])
index1 = torch.stack([index1 for _ in range(3)], dim=-1)
index1 = torch.stack([index1 for _ in range(4)], dim=-1)
print(index1.shape == a.shape)
print('Data for classes 1 and 3:', torch.masked_select(a, index1))
# So PyTorch's masked_select is not very ergonomic.
a = torch.tensor(data_3)
print('Class 1, student 2, subject 3:',
torch.masked_select(a, torch.tensor([[[False, False, False, False],
[False, False, True, False],
[False, False, False, False]],
[[False, False, False, False],
[False, False, False, False],
[False, False, False, False]]])))
Gathering data by condition (TensorFlow)
tf.where
# Gather data by condition (TensorFlow)
# tf.where(cond, a, b)
# cond is a boolean mask
# if cond[i] is True,  x[i] = a[i]
# if cond[i] is False, x[i] = b[i]
# When a and b are not given, returns the index coordinates of all True elements in cond.
b1 = tf.ones([3,3])
b2 = tf.zeros([3,3])
cond = [[True,False,False], [False,True,False], [False,False,True]]
print(tf.where(cond, b1, b2))
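A small illustration of mine for the one-argument form described above:
cond2 = tf.constant([[True, False], [False, True]])
print(tf.where(cond2))  # [[0 0], [1 1]]: the coordinates of the True entries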
Gathering data by condition (PyTorch)
torch.where
# Essentially the same as in TensorFlow
a1 = torch.ones([3,3])
a2 = torch.zeros([3,3])
cond = torch.tensor([[True,False,False], [False,True,False], [False,False,True]])
print(torch.where(cond, a1, a2))
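A hedged aside of mine: the PyTorch counterpart of the one-argument tf.where is torch.nonzero, which returns the coordinates of the True elements:
print(torch.nonzero(cond))  # tensor([[0, 0], [1, 1], [2, 2]])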