TensorFlow-to-PyTorch notes; converting tf.gather_nd(x, y) to PyTorch
2022-07-03 02:39:00 【strawberry47】
Notes on how some commonly used functions map from TensorFlow to PyTorch when converting code:
Cannot be converted directly
tf.transpose(input,[1, 0, 2])
->input.permute([1, 0, 2])
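This cannot simply be replaced with torch.transpose, because torch.transpose swaps only two dimensions at a time and cannot take a full permutation of several dimensions.
tf.expand_dims(input, axis=1)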
->input.unsqueeze(1)
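A minimal sketch with an illustrative random tensor `x` of shape (2, 3, 4), chosen only for the example. The order `[1, 0, 2]` happens to be a single swap, but a cyclic order such as `[2, 0, 1]` shows why `permute` is the general replacement: it takes the whole permutation in one call, while `torch.transpose` needs two swaps.

```python
import torch

x = torch.randn(2, 3, 4)  # illustrative input

# tf.transpose(x, [1, 0, 2]) -> x.permute(1, 0, 2)
p = x.permute(1, 0, 2)  # shape (3, 2, 4)

# A cyclic reordering like [2, 0, 1] is not a single swap:
# permute does it in one call, torch.transpose needs two.
q = x.permute(2, 0, 1)                  # shape (4, 2, 3)
q2 = x.transpose(0, 2).transpose(1, 2)  # same result via two swaps

# tf.expand_dims(x, axis=1) -> x.unsqueeze(1)
u = x.unsqueeze(1)  # shape (2, 1, 3, 4)
```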
tf.concat([content1, content2], axis=1)
->torch.cat((content1,content2), dim=1)
Remember to change axis to dim.
tf.tile(input, [2, 1])
->input.repeat([2, 1])
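A small sketch with made-up tensors `a` and `b`: `torch.cat` takes a tuple plus `dim`, and `Tensor.repeat` copies the whole tensor like `tf.tile`. `torch.repeat_interleave`, by contrast, repeats each row in place, so it is not a drop-in substitute for `tf.tile`.

```python
import torch

a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5], [6]])

# tf.concat([a, b], axis=1) -> torch.cat((a, b), dim=1)
c = torch.cat((a, b), dim=1)  # [[1, 2, 5], [3, 4, 6]]

# tf.tile(a, [2, 1]) -> a.repeat(2, 1): the whole tensor is copied twice along dim 0
r = a.repeat(2, 1)  # [[1, 2], [3, 4], [1, 2], [3, 4]]

# repeat_interleave repeats each row where it stands instead
ri = a.repeat_interleave(2, dim=0)  # [[1, 2], [1, 2], [3, 4], [3, 4]]
```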
tf.range(10)
->torch.arange(10)
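A quick check of the mapping: the argument must carry over unchanged, since `torch.arange(0)` is an empty tensor. The three-argument form maps `tf.range(start, limit, delta)` to `torch.arange(start, end, step)`. One difference worth knowing: `tf.range(10)` yields int32, while `torch.arange(10)` yields int64.

```python
import torch

# tf.range(10) -> torch.arange(10): integers 0..9, end exclusive
r = torch.arange(10)
print(r.tolist())  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

# torch.arange(0) is empty
empty = torch.arange(0)
print(empty.numel())  # 0

# tf.range(2, 10, 3) -> torch.arange(2, 10, 3)
stepped = torch.arange(2, 10, 3)
print(stepped.tolist())  # [2, 5, 8]
```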
tf.reduce_sum(x, axis=1, keep_dims=True)
->torch.sum(x,dim=1,keepdim=True)
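A sketch with an illustrative 2x3 tensor, showing that the keyword changes from `keep_dims`/`keepdims` to `keepdim` along with `axis` to `dim`. The same pattern should carry over to other reductions, for example `tf.reduce_max` to `torch.amax` (plain `torch.max` with a `dim` returns a values/indices pair instead of a single tensor).

```python
import torch

x = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])

# tf.reduce_sum(x, axis=1, keep_dims=True) -> torch.sum(x, dim=1, keepdim=True)
s = torch.sum(x, dim=1, keepdim=True)
print(s.tolist())  # [[6.0], [15.0]]

# Without keepdim the reduced axis is dropped
print(torch.sum(x, dim=1).shape)  # torch.Size([2])

# tf.reduce_max(x, axis=1, keepdims=True) -> torch.amax(x, dim=1, keepdim=True)
m = torch.amax(x, dim=1, keepdim=True)  # [[3.0], [6.0]]
```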
tf.clip_by_value(x, min, max)
->torch.clamp(x, min, max)
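A one-line illustration with made-up values and bounds 0.0 and 1.0: `torch.clamp` takes the same positional lower and upper bounds as `tf.clip_by_value`, and they can also be passed as the keywords `min=` and `max=`.

```python
import torch

x = torch.tensor([-2.0, 0.5, 3.0])

# tf.clip_by_value(x, 0.0, 1.0) -> torch.clamp(x, 0.0, 1.0)
y = torch.clamp(x, 0.0, 1.0)
print(y.tolist())  # [0.0, 0.5, 1.0]

# Same call with keyword bounds
y_kw = torch.clamp(x, min=0.0, max=1.0)
```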
tf.multinomial(logits=a, num_samples=1)
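->torch.multinomial(torch.softmax(a, dim=-1), num_samples=1, replacement=True)
(torch.multinomial takes non-negative probabilities, not logits; tf.multinomial samples with replacement, which only matters when num_samples > 1)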
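A sketch with made-up logits of shape [batch, num_classes]. Because `torch.multinomial` treats its input as non-negative weights, the logits are first converted with `torch.softmax`. `torch.distributions.Categorical(logits=...)` is an alternative that accepts logits directly. Both return int64 indices, which matches the default output dtype of `tf.multinomial`.

```python
import torch

torch.manual_seed(0)
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.0, 0.0, 0.0]])  # illustrative logits, shape [batch, num_classes]

# tf.multinomial(logits=logits, num_samples=1)
probs = torch.softmax(logits, dim=-1)
samples = torch.multinomial(probs, num_samples=1, replacement=True)  # shape [batch, 1]

# Alternative that takes logits directly; sample((1,)) gives [1, batch], so transpose
alt = torch.distributions.Categorical(logits=logits).sample((1,)).T  # shape [batch, 1]
```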
tf.equal(x, y)
->torch.eq(x, y)
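A short illustration with made-up tensors: `torch.eq` (or the `==` operator) is the element-wise counterpart of `tf.equal`. It is easy to confuse with `torch.equal`, which returns one Python bool for the whole tensor.

```python
import torch

x = torch.tensor([1, 2, 3])
y = torch.tensor([1, 0, 3])

# tf.equal(x, y) -> torch.eq(x, y): element-wise boolean tensor
e = torch.eq(x, y)
print(e.tolist())  # [True, False, True]

# torch.equal compares whole tensors and returns a single bool
same = torch.equal(x, y)
print(same)  # False
```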
tf.nn.embedding_lookup(W_fe, Feature_input + 1)
->torch.index_select(W_fe, 0, Feature_input + 1)
(index_select only accepts a 1-D index; for multi-dimensional ids use W_fe[Feature_input + 1])
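A sketch assuming an illustrative 6x2 table `W_fe` and 2-D `Feature_input` ids, both made up for the example. `tf.nn.embedding_lookup` accepts ids of any shape, but `torch.index_select` needs a 1-D index, so the ids are flattened and the result reshaped. Advanced indexing and `torch.nn.functional.embedding` give the same result without the reshape.

```python
import torch
import torch.nn.functional as F

W_fe = torch.arange(12, dtype=torch.float32).reshape(6, 2)  # row i is [2i, 2i + 1]
Feature_input = torch.tensor([[0, 2], [4, 1]])              # illustrative ids, shape [2, 2]
ids = Feature_input + 1

# index_select: flatten the ids, then restore ids.shape + embedding size
a = torch.index_select(W_fe, 0, ids.reshape(-1)).reshape(*ids.shape, -1)

# Equivalent lookups that handle any id shape directly
b = W_fe[ids]
c = F.embedding(ids, W_fe)
print(b.shape)  # torch.Size([2, 2, 2])
```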
tf.one_hot(indices, depth)
->torch.nn.functional.one_hot(indices, num_classes)
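A sketch with made-up labels, highlighting two differences: `F.one_hot` expects an int64 tensor and returns int64, whereas `tf.one_hot` defaults to float32, so a cast is usually needed. Also, `tf.one_hot` maps an index of -1 to an all-zero row, while `F.one_hot` raises an error on negative values.

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([0, 2, 1])  # illustrative class ids (int64)

# tf.one_hot(labels, depth=3) -> F.one_hot(labels, num_classes=3)
oh = F.one_hot(labels, num_classes=3)
print(oh.tolist())  # [[1, 0, 0], [0, 0, 1], [0, 1, 0]]
print(oh.dtype)     # torch.int64

# Cast to match TensorFlow's default float32 output
oh_float = oh.float()
```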
Converting tf.gather_nd(x, y)
import torch

def gather_nd(params, indices):
    '''Equivalent of tf.gather_nd when each index addresses one element (d == params.dim()).
    params:  tensor shaped [n_1, ..., n_d]     --> d-dimensional tensor
    indices: tensor shaped [m_1, ..., m_i, d]  --> multidimensional list of d-dimensional indices
    returns: tensor shaped [m_1, ..., m_i]
    4D example: params [n_1, n_2, n_3, n_4], indices [m_1, m_2, m_3, m_4, 4] -> [m_1, m_2, m_3, m_4]
    '''
    out_shape = indices.shape[:-1]
    indices = indices.unsqueeze(0).transpose(0, -1)  # roll last axis to front: [d, m_1, ..., m_i, 1]
    ndim = indices.shape[0]
    indices = indices.long()
    idx = torch.zeros_like(indices[0], device=indices.device).long()
    m = 1
    for i in range(ndim)[::-1]:  # build the row-major flat index, last dimension first
        idx += indices[i] * m
        m *= params.size(i)
    out = torch.take(params, idx)  # torch.take indexes params as if it were flattened to 1-D
    return out.view(out_shape)
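As an alternative sketch (not the author's code), PyTorch's advanced indexing can express `tf.gather_nd` with its default `batch_dims=0` by splitting the last axis of `indices` into one index tensor per dimension. Unlike the flat-index approach above, this also covers `d < params.dim()`, where each index selects a slice. The helper name `gather_nd_indexing` is hypothetical.

```python
import torch

def gather_nd_indexing(params, indices):
    """Hypothetical helper mirroring tf.gather_nd(params, indices) with batch_dims=0.

    indices: integer tensor shaped [m_1, ..., m_i, d] with d <= params.dim()
    returns: tensor shaped [m_1, ..., m_i] + params.shape[d:]
    """
    # One index tensor per leading dimension of params, each shaped [m_1, ..., m_i]
    return params[tuple(indices.long().unbind(-1))]

params = torch.arange(24).reshape(2, 3, 4)

# Full indices (d == 3): one element per index
full = gather_nd_indexing(params, torch.tensor([[0, 1, 2], [1, 2, 3]]))
print(full.tolist())  # [6, 23]

# Partial indices (d == 2): one length-4 row per index
rows = gather_nd_indexing(params, torch.tensor([[1, 0]]))
print(rows.tolist())  # [[12, 13, 14, 15]]
```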
Can be converted directly
tf.reshape()
->torch.reshape()
tf.log()
->torch.log()
tf.squeeze()
->torch.squeeze()
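A brief sketch with an illustrative [1, 3, 1] tensor. The function names carry over, but keyword arguments still follow the earlier axis-to-dim rule (`tf.squeeze(x, axis=2)` becomes `torch.squeeze(x, dim=2)`). Note that `tf.log` is the TF 1.x name; TF 2.x code uses `tf.math.log`.

```python
import torch

x = torch.tensor([[[1.0], [2.0], [4.0]]])  # shape [1, 3, 1]

# tf.reshape(x, [3]) -> torch.reshape(x, (3,))
r = torch.reshape(x, (3,))

# tf.log(r) / tf.math.log(r) -> torch.log(r)
l = torch.log(r)

# tf.squeeze(x) -> torch.squeeze(x): drops every size-1 dimension
s = torch.squeeze(x)
# tf.squeeze(x, axis=2) -> torch.squeeze(x, dim=2): drops only that dimension
s2 = torch.squeeze(x, dim=2)
print(s.shape, s2.shape)  # torch.Size([3]) torch.Size([1, 3])
```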