TensorFlow to PyTorch notes: converting tf.gather_nd(x, y) to PyTorch
2022-07-03 02:43:00 【strawberry47】
These notes record some commonly used function conversions for porting TensorFlow code to PyTorch.
Cannot be converted directly
tf.transpose(input,[1, 0, 2])
->input.permute([1, 0, 2])
torch.transpose cannot be used as a direct replacement, because it only swaps two dimensions at a time.
tf.expand_dims(input, axis=1)
->input.unsqueeze(1)
tf.concat([content1, content2], axis=1)
->torch.cat((content1,content2), dim=1)
Remember to change axis to dim.
tf.tile(input, [2, 1])
->input.repeat([2, 1])
tf.range(10)
->torch.arange(10)
tf.reduce_sum(x, axis=1, keep_dims=True)
->torch.sum(x,dim=1,keepdim=True)
tf.clip_by_value(x, min, max)
->torch.clamp(x, min, max)
tf.multinomial(logits=a, num_samples=1)
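->torch.multinomial(input=a, num_samples=1, replacement=False)
Note that tf.multinomial takes logits, while torch.multinomial expects non-negative weights, so apply a softmax first if a holds logits.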
tf.equal(x, y)
->torch.eq(x, y)
tf.nn.embedding_lookup(W_fe, Feature_input + 1)
->torch.index_select(W_fe, 0, Feature_input + 1)
tf.one_hot()
->torch.nn.functional.one_hot()
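The mappings listed above can be sanity-checked on small tensors. The following is a minimal sketch using PyTorch only; all tensor names, shapes, and values are made up for illustration. Two choices here are assumptions rather than part of the original notes: a softmax is applied before torch.multinomial because it samples from non-negative weights, and torch.nn.functional.embedding is used for the lookup because torch.index_select only accepts a 1-D index.

```python
import torch
import torch.nn.functional as F

x = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)

# tf.transpose(x, [1, 0, 2]) -> x.permute(1, 0, 2)
permuted = x.permute(1, 0, 2)                # shape (3, 2, 4)

# tf.expand_dims(x, axis=1) -> x.unsqueeze(1)
expanded = x.unsqueeze(1)                    # shape (2, 1, 3, 4)

# tf.concat([x, x], axis=1) -> torch.cat((x, x), dim=1)
concatenated = torch.cat((x, x), dim=1)      # shape (2, 6, 4)

# tf.tile(m, [2, 1]) -> m.repeat(2, 1)
m = torch.tensor([[1, 2], [3, 4]])
tiled = m.repeat(2, 1)                       # shape (4, 2)

# tf.range(10) -> torch.arange(10)
r = torch.arange(10)

# tf.reduce_sum(x, axis=1, keepdims=True) -> torch.sum(x, dim=1, keepdim=True)
summed = torch.sum(x, dim=1, keepdim=True)   # shape (2, 1, 4)

# tf.clip_by_value(x, 0.0, 5.0) -> torch.clamp(x, 0.0, 5.0)
clipped = torch.clamp(x, 0.0, 5.0)

# tf.multinomial(logits=...) -> torch.multinomial(probabilities, ...)
# Assumption: the input holds logits, so convert to probabilities first.
logits = torch.tensor([[0.0, 1.0, 2.0], [2.0, 1.0, 0.0]])
probs = torch.softmax(logits, dim=-1)
samples = torch.multinomial(probs, num_samples=1)   # shape (2, 1)

# tf.nn.embedding_lookup(W, ids) with a 2-D ids tensor.
# Assumption: F.embedding is used instead of index_select for multi-dim ids.
W = torch.randn(5, 8)
ids = torch.tensor([[0, 2], [4, 1]])
embedded = F.embedding(ids, W)               # shape (2, 2, 8)

# tf.one_hot(labels, 3) -> F.one_hot(labels, num_classes=3)
one_hot = F.one_hot(torch.tensor([0, 2, 1]), num_classes=3)
```

For a 1-D ids tensor, torch.index_select(W, 0, ids) gives the same rows as the embedding call.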
Converting tf.gather_nd(x, y)
import torch

# Written as a method in the original; drop `self` to use it as a plain function.
def gather_nd(self, params, indices):
    '''
    4D example
        params: tensor shaped [n_1, n_2, n_3, n_4] --> 4-dimensional tensor
        indices: tensor shaped [m_1, m_2, m_3, m_4, 4] --> multidimensional list of 4D indices
        returns: tensor shaped [m_1, m_2, m_3, m_4]

    ND example
        params: tensor shaped [n_1, ..., n_d] --> d-dimensional tensor
        indices: tensor shaped [m_1, ..., m_i, d] --> multidimensional list of d-dimensional indices
        returns: tensor shaped [m_1, ..., m_i]
    '''
    out_shape = indices.shape[:-1]
    indices = indices.unsqueeze(0).transpose(0, -1)  # roll last axis to front
    ndim = indices.shape[0]
    indices = indices.long()
    idx = torch.zeros_like(indices[0], device=indices.device).long()
    m = 1
    # Build flat (row-major) offsets, starting from the last dimension.
    for i in range(ndim)[::-1]:
        idx += indices[i] * m
        m *= params.size(i)
    out = torch.take(params, idx)  # torch.take indexes params as if it were 1-D
    return out.view(out_shape)
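One way to check the function above is to compare it with PyTorch advanced indexing on a small tensor. The sketch below restates the function without self so that it runs standalone; gather_nd_indexing is a hypothetical helper added here for comparison, not part of the original notes, and the test tensors are made up.

```python
import torch

def gather_nd(params, indices):
    """Standalone restatement of the method above (no `self`)."""
    out_shape = indices.shape[:-1]
    indices = indices.unsqueeze(0).transpose(0, -1).long()  # last axis to front
    ndim = indices.shape[0]
    idx = torch.zeros_like(indices[0])
    m = 1
    for i in reversed(range(ndim)):
        idx += indices[i] * m      # accumulate the flat row-major offset
        m *= params.size(i)
    return torch.take(params, idx).view(out_shape)

def gather_nd_indexing(params, indices):
    """Hypothetical alternative: split the last axis into one index tensor per dimension."""
    return params[tuple(indices.long().unbind(dim=-1))]

params = torch.arange(24).reshape(2, 3, 4)
indices = torch.tensor([[[0, 1, 2], [1, 2, 3]],
                        [[1, 0, 0], [0, 0, 0]]])   # shape (2, 2, 3)

a = gather_nd(params, indices)
b = gather_nd_indexing(params, indices)
print(a)                  # tensor([[ 6, 23], [12,  0]])
print(torch.equal(a, b))  # True
```

The flat-offset version assumes the last axis of indices has one entry per dimension of params. The indexing version also works when fewer index components are given, in which case it returns whole slices.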
Can be converted directly
tf.reshape()
->torch.reshape()
tf.log()
->torch.log()
tf.squeeze()
->torch.squeeze()