TensorFlow-to-PyTorch notes: converting tf.gather_nd(x, y) to PyTorch
2022-07-03 02:39:00 【strawberry47】
These notes record how some commonly used functions map when converting TensorFlow code to PyTorch.
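Not directly convertible (the PyTorch name or arguments differ):
- tf.transpose(input, [1, 0, 2]) -> input.permute([1, 0, 2]). torch.transpose is not a drop-in replacement, because it only swaps two dimensions and cannot apply a permutation across several at once.
- tf.expand_dims(input, axis=1) -> input.unsqueeze(1)
- tf.concat([content1, content2], axis=1) -> torch.cat((content1, content2), dim=1). Remember to rename axis to dim.
- tf.tile(input, [2, 1]) -> input.repeat([2, 1])
- tf.range(10) -> torch.arange(10)
- tf.reduce_sum(x, axis=1, keep_dims=True) -> torch.sum(x, dim=1, keepdim=True)
- tf.clip_by_value(x, min, max) -> torch.clamp(x, min, max)
- tf.multinomial(logits=a, num_samples=1) -> torch.multinomial(torch.softmax(a, dim=-1), num_samples=1, replacement=False). tf.multinomial takes unnormalized log-probabilities (logits), while torch.multinomial takes non-negative weights, so apply softmax first.
- tf.equal(x, y) -> torch.eq(x, y)
- tf.nn.embedding_lookup(W_fe, Feature_input + 1) -> torch.index_select(W_fe, 0, Feature_input + 1). index_select only accepts a 1-D index; for multi-dimensional indices use W_fe[Feature_input + 1] or torch.nn.functional.embedding(Feature_input + 1, W_fe).
- tf.one_hot() -> torch.nn.functional.one_hot(). F.one_hot expects an int64 input and returns int64, whereas tf.one_hot returns float32 by default, so add .float() when needed.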
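A quick sanity check of the mappings above, written as a minimal sketch: the toy tensors (x, m, W, ids, labels) are made up for illustration, and only the PyTorch side is executed, so the TensorFlow calls appear in comments only.

```python
import torch
import torch.nn.functional as F

# Hypothetical toy tensors, chosen only to exercise each mapping.
x = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)

# tf.transpose(x, [1, 0, 2]) -> x.permute(1, 0, 2): a full axis permutation
x_t = x.permute(1, 0, 2)                      # shape [3, 2, 4]

# tf.expand_dims(x, axis=1) -> x.unsqueeze(1)
x_e = x.unsqueeze(1)                          # shape [2, 1, 3, 4]

# tf.concat([x, x], axis=1) -> torch.cat((x, x), dim=1)
x_c = torch.cat((x, x), dim=1)                # shape [2, 6, 4]

# tf.tile(m, [2, 1]) -> m.repeat(2, 1)
m = torch.tensor([[1, 2], [3, 4]])
m_tiled = m.repeat(2, 1)                      # shape [4, 2]

# tf.range(10) -> torch.arange(10)
r = torch.arange(10)

# tf.reduce_sum(x, axis=1, keep_dims=True) -> torch.sum(x, dim=1, keepdim=True)
s = torch.sum(x, dim=1, keepdim=True)         # shape [2, 1, 4]

# tf.clip_by_value(x, 0, 5) -> torch.clamp(x, 0, 5)
c = torch.clamp(x, 0, 5)

# tf.equal(a, b) -> torch.eq(a, b)
e = torch.eq(torch.tensor([1, 2, 3]), torch.tensor([1, 0, 3]))

# tf.multinomial(logits=logits, num_samples=1): torch.multinomial wants
# non-negative weights, so turn the logits into probabilities first.
logits = torch.tensor([[0.0, 10.0, -10.0]])
sample = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)  # shape [1, 1]

# tf.nn.embedding_lookup(W, ids): index_select needs a 1-D index,
# while F.embedding (or W[ids]) also accepts multi-dimensional ids.
W = torch.arange(15, dtype=torch.float32).reshape(5, 3)
ids = torch.tensor([[0, 4], [2, 2]])
emb_flat = torch.index_select(W, 0, ids.flatten())  # shape [4, 3]
emb = F.embedding(ids, W)                            # shape [2, 2, 3]

# tf.one_hot(labels, 3) -> F.one_hot(labels, 3); cast to float to match TF's default dtype
labels = torch.tensor([0, 2, 1])
oh = F.one_hot(labels, num_classes=3).float()
```

Sampling is random, so only the shape and range of sample are meaningful here.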
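Converting tf.gather_nd(x, y)
The function below turns each index tuple in the last axis of indices into a row-major flat offset and then reads the values with torch.take. It only handles full-depth indices: indices.shape[-1] must equal params.dim(), so every index tuple selects a single element.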
```python
import torch

def gather_nd(params, indices):  # originally a method; the unused `self` is dropped
    '''Element-wise tf.gather_nd for full-depth indices.
    4D example: params [n_1, n_2, n_3, n_4], indices [m_1, m_2, m_3, m_4, 4] (4D indices)
                -> returns [m_1, m_2, m_3, m_4]
    ND example: params [n_1, ..., n_d], indices [m_1, ..., m_i, d] (d-dimensional indices)
                -> returns [m_1, ..., m_i]
    '''
    out_shape = indices.shape[:-1]
    ndim = indices.shape[-1]
    # Only index tuples that address single elements are supported.
    assert ndim == params.dim(), "indices.shape[-1] must equal params.dim()"
    indices = indices.unsqueeze(0).transpose(0, -1).long()  # roll last axis to front
    idx = torch.zeros_like(indices[0])
    m = 1
    for i in range(ndim)[::-1]:  # build the row-major flat index, last dim first
        idx += indices[i] * m
        m *= params.size(i)
    out = torch.take(params, idx)  # take() indexes params as if it were 1-D
    return out.view(out_shape)
```
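An alternative sketch, not from the original post: PyTorch advanced indexing can express the same lookup in one line, and it also covers the case the function above does not, where indices.shape[-1] is smaller than params.dim() and tf.gather_nd returns whole slices. The helper name gather_nd_indexing is hypothetical, and it assumes the default batch_dims=0 behavior of tf.gather_nd.

```python
import torch

def gather_nd_indexing(params: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper mirroring tf.gather_nd(params, indices) with batch_dims=0.

    indices has shape [m_1, ..., m_i, d] with d <= params.dim(); the result has
    shape [m_1, ..., m_i, n_{d+1}, ..., n_p].
    """
    # Split the last axis into d index tensors, each shaped [m_1, ..., m_i],
    # and use them to advanced-index the first d dimensions of params.
    return params[tuple(indices.long().unbind(-1))]


params = torch.arange(24).reshape(2, 3, 4)

# Full-depth indices: each row picks one element, here params[0, 1, 2] and params[1, 2, 3].
print(gather_nd_indexing(params, torch.tensor([[0, 1, 2], [1, 2, 3]])).tolist())  # [6, 23]

# Partial indices: each row picks a slice, here params[0, 2], giving shape [1, 4].
print(gather_nd_indexing(params, torch.tensor([[0, 2]])).shape)  # torch.Size([1, 4])
```

Because the result shape is simply the index shape followed by the untouched trailing dimensions, no flat-offset arithmetic or final view is needed.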
Directly convertible (same function name in PyTorch):
- tf.reshape() -> torch.reshape()
- tf.log() -> torch.log()
- tf.squeeze() -> torch.squeeze()
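For these, the function names match, but keyword arguments can still differ (TensorFlow's axis becomes PyTorch's dim). A minimal check on a made-up tensor, with the TensorFlow equivalents noted in comments:

```python
import torch

x = torch.rand(2, 1, 3) + 0.5  # hypothetical input; kept positive so log is finite

r = torch.reshape(x, (3, 2))       # tf.reshape(x, [3, 2])
l = torch.log(x)                   # tf.log(x), tf.math.log in TF 2.x
sq_all = torch.squeeze(x)          # tf.squeeze(x): drops every size-1 dim -> [2, 3]
sq_one = torch.squeeze(x, dim=1)   # tf.squeeze(x, axis=[1]) -> [2, 3]
```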