TensorFlow to PyTorch notes; converting tf.gather_nd(x, y) to PyTorch
2022-07-03 02:43:00 【strawberry47】
Notes on some commonly used function conversions when porting TensorFlow code to PyTorch:
Cannot be converted directly (the PyTorch name or arguments differ):
tf.transpose(input,[1, 0, 2])
->input.permute([1, 0, 2])
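torch.transpose is not a drop-in replacement: it swaps exactly two dimensions, so it cannot express an arbitrary permutation of several dimensions (for [1, 0, 2] specifically, torch.transpose(input, 0, 1) gives the same result).
tf.expand_dims(input, axis=1)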
->input.unsqueeze(1)
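A quick shape check of the transpose and expand_dims conversions, as a PyTorch-only sketch (the input `x` is arbitrary random data): `permute` takes the full new order of dimensions, `torch.transpose` swaps exactly two, and `unsqueeze(1)` plays the role of `tf.expand_dims(..., axis=1)`.

```python
import torch

x = torch.randn(2, 3, 4)

# tf.transpose(x, [1, 0, 2]) -> permute takes the full new order of dims
y = x.permute(1, 0, 2)
print(y.shape)  # torch.Size([3, 2, 4])

# torch.transpose swaps exactly two dims; for [1, 0, 2] that happens to be enough
assert torch.equal(y, torch.transpose(x, 0, 1))

# A general permutation such as [2, 0, 1] needs permute
z = x.permute(2, 0, 1)
print(z.shape)  # torch.Size([4, 2, 3])

# tf.expand_dims(x, axis=1) -> unsqueeze(1) inserts a size-1 dim at position 1
w = x.unsqueeze(1)
print(w.shape)  # torch.Size([2, 1, 3, 4])
```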
tf.concat([content1, content2], axis=1)
->torch.cat((content1,content2), dim=1)
Remember to rename axis to dim.
tf.tile(input, [2, 1])
->input.repeat([2, 1])
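A small sketch of the concat and tile conversions, with made-up tensors `a`, `b` and `x`. One pitfall worth showing: `Tensor.repeat` copies the whole tensor like `tf.tile`, whereas element-wise repetition in the style of `np.repeat` is `torch.repeat_interleave`.

```python
import torch

a = torch.ones(2, 3)
b = torch.zeros(2, 1)

# tf.concat([a, b], axis=1) -> torch.cat takes a tuple/list and uses dim, not axis
c = torch.cat((a, b), dim=1)
print(c.shape)  # torch.Size([2, 4])

# tf.tile(x, [2, 1]) -> Tensor.repeat copies the whole tensor along each dim
x = torch.tensor([[1, 2]])
t = x.repeat(2, 1)
print(t.tolist())  # [[1, 2], [1, 2]]

# Element-wise repetition (like np.repeat) is torch.repeat_interleave instead
r = torch.repeat_interleave(x, 2, dim=1)
print(r.tolist())  # [[1, 1, 2, 2]]
```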
tf.range(10)
->torch.arange(10)
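A minimal check of the range conversion. One difference to keep in mind for default dtypes: `tf.range` with integer arguments produces int32, while `torch.arange` produces int64, so cast with `.int()` if downstream code depends on int32.

```python
import torch

# tf.range(10) -> torch.arange(10): values 0..9, end exclusive
r = torch.arange(10)
print(r.tolist())  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
print(r.dtype)     # torch.int64
```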
tf.reduce_sum(x, axis=1, keepdims=True)
->torch.sum(x,dim=1,keepdim=True)
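A sketch of the reduction with a made-up 2x2 tensor; note the spelling difference between TensorFlow's `keepdims` and PyTorch's `keepdim`.

```python
import torch

x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])

# tf.reduce_sum(x, axis=1, keepdims=True) -> torch.sum(x, dim=1, keepdim=True)
s = torch.sum(x, dim=1, keepdim=True)
print(s.shape, s.tolist())  # torch.Size([2, 1]) [[3.0], [7.0]]

# Without keepdim the reduced dim is dropped
print(torch.sum(x, dim=1).shape)  # torch.Size([2])
```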
tf.clip_by_value(x, min, max)
->torch.clamp(x, min, max)
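A minimal example with made-up values; `torch.clamp` accepts `min`/`max` as keyword arguments, and `torch.clip` is an alias of it in recent PyTorch versions.

```python
import torch

x = torch.tensor([-2.0, 0.5, 3.0])

# tf.clip_by_value(x, 0.0, 1.0) -> torch.clamp(x, min=0.0, max=1.0)
y = torch.clamp(x, min=0.0, max=1.0)
print(y.tolist())  # [0.0, 0.5, 1.0]
```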
tf.multinomial(logits=a, num_samples=1)
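->torch.multinomial(torch.softmax(a, dim=-1), num_samples=1)
torch.multinomial expects non-negative weights (probabilities), not logits, so apply softmax first. tf.multinomial (tf.random.categorical in TF 2) draws samples with replacement, so pass replacement=True when num_samples > 1.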
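A sketch of sampling from logits, assuming `logits` is a [batch, num_classes] tensor of unnormalized log-probabilities as `tf.multinomial` expects. The `-inf` entries give those classes probability exactly 0, which makes the result deterministic for this example. `torch.distributions.Categorical(logits=logits).sample()` is another option; it returns shape [batch] rather than [batch, 1].

```python
import torch

# Hypothetical logits; -inf becomes a class probability of exactly 0 after softmax
logits = torch.tensor([[0.0, float("-inf"), float("-inf")],
                       [float("-inf"), float("-inf"), 0.0]])

# torch.multinomial wants non-negative weights, so convert logits first
probs = torch.softmax(logits, dim=-1)
samples = torch.multinomial(probs, num_samples=1)
print(samples.tolist())  # [[0], [2]]
print(samples.dtype)     # torch.int64
```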
tf.equal(x, y)
->torch.eq(x, y)
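A small example with made-up tensors, contrasting `torch.eq` (element-wise, like `tf.equal`) with `torch.equal`, which is easy to confuse with it and returns a single Python bool.

```python
import torch

x = torch.tensor([1, 2, 3])
y = torch.tensor([1, 0, 3])

# tf.equal(x, y) -> torch.eq(x, y): element-wise comparison returning a bool tensor
m = torch.eq(x, y)
print(m.tolist())  # [True, False, True]

# x == y is equivalent
assert torch.equal(m, x == y)

# torch.equal (not torch.eq) returns one Python bool for the whole tensor
print(torch.equal(x, y))  # False
```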
tf.nn.embedding_lookup(W_fe, Feature_input + 1)
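->torch.index_select(W_fe, 0, Feature_input + 1)
index_select only accepts a 1-D index tensor; for index tensors of any shape, W_fe[Feature_input + 1] or torch.nn.functional.embedding(Feature_input + 1, W_fe) is the direct equivalent.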
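A sketch with a hypothetical 4x3 table `W_fe` and a hypothetical 2-D index batch `ids` (the `+ 1` offset is dropped for brevity). It compares `index_select` on flattened ids with plain indexing and `torch.nn.functional.embedding`, which all give the same [2, 2, 3] result.

```python
import torch
import torch.nn.functional as F

W_fe = torch.arange(12.0).reshape(4, 3)  # hypothetical embedding table: 4 ids x 3 dims
ids = torch.tensor([[0, 2], [1, 0]])      # hypothetical 2-D batch of ids

# index_select only accepts a 1-D index, so flatten and restore the shape
a = torch.index_select(W_fe, 0, ids.reshape(-1)).reshape(*ids.shape, -1)

# Plain indexing and F.embedding handle any index shape directly
b = W_fe[ids]
c = F.embedding(ids, W_fe)

print(b.shape)  # torch.Size([2, 2, 3])
assert torch.equal(a, b) and torch.equal(b, c)
```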
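tf.one_hot()
->torch.nn.functional.one_hot()
The depth argument becomes num_classes; torch.nn.functional.one_hot needs an int64 input and returns int64, whereas tf.one_hot returns float32 by default.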
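A sketch with hypothetical integer `labels`; if the labels come in as int32 they need `.long()` first, since `one_hot` only accepts int64 index tensors.

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([0, 2, 1])  # hypothetical class ids (int64 by default)

# tf.one_hot(labels, depth=3) -> F.one_hot(labels, num_classes=3)
oh = F.one_hot(labels, num_classes=3)
print(oh.dtype)     # torch.int64
print(oh.tolist())  # [[1, 0, 0], [0, 0, 1], [0, 1, 0]]

# tf.one_hot returns float32 by default; cast to match
oh_f = oh.float()
```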
Converting tf.gather_nd(x, y)
import torch

def gather_nd(params, indices):
    '''Port of tf.gather_nd for full indices: the last dim of `indices` must equal params.dim().
    4D example: params: tensor shaped [n_1, n_2, n_3, n_4] --> 4-dimensional
                indices: tensor shaped [m_1, m_2, m_3, m_4, 4] --> multidimensional list of 4D indices
                returns: tensor shaped [m_1, m_2, m_3, m_4]
    ND example: params: tensor shaped [n_1, ..., n_d] --> d-dimensional tensor
                indices: tensor shaped [m_1, ..., m_i, d] --> multidimensional list of d-dimensional indices
                returns: tensor shaped [m_1, ..., m_i]
    '''
    out_shape = indices.shape[:-1]
    indices = indices.unsqueeze(0).transpose(0, -1)  # roll the last axis to the front
    ndim = indices.shape[0]
    indices = indices.long()
    idx = torch.zeros_like(indices[0], device=indices.device).long()
    m = 1
    # Accumulate row-major flat offsets, from the last dim to the first
    for i in range(ndim)[::-1]:
        idx += indices[i] * m
        m *= params.size(i)
    out = torch.take(params, idx)  # torch.take indexes params as if it were flattened to 1-D
    return out.view(out_shape)
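As an alternative sketch (not the original author's code), advanced indexing can do the same gather without manual stride arithmetic. It also covers the case where `indices` addresses only the first d < params.dim() dimensions, returning slices as `tf.gather_nd` does with the default `batch_dims=0`. The function name `gather_nd_indexing` is hypothetical.

```python
import torch

def gather_nd_indexing(params, indices):
    """Hypothetical gather_nd via advanced indexing.

    indices has shape [..., d] with d <= params.dim(); the result has shape
    indices.shape[:-1] + params.shape[d:], as in tf.gather_nd (batch_dims=0).
    """
    indices = indices.long()
    # Split the last axis into d index tensors and index the first d dims at once
    return params[tuple(indices.unbind(-1))]

params = torch.arange(24).reshape(2, 3, 4)

# Full indices (d == params.dim()): one element per index row
full = gather_nd_indexing(params, torch.tensor([[0, 1, 2], [1, 2, 3]]))
print(full.tolist())  # [6, 23]

# Partial indices (d < params.dim()): whole slices are returned
part = gather_nd_indexing(params, torch.tensor([[1, 0]]))
print(part.tolist())  # [[12, 13, 14, 15]]
```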
Can be converted directly (same function name):
tf.reshape()
->torch.reshape()
tf.log() (tf.math.log in TF 2)
->torch.log()
tf.squeeze()
->torch.squeeze() (rename axis to dim)
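A short sketch of the direct conversions on a made-up tensor; apart from the prefix, the only change is that TensorFlow's `axis` argument is called `dim` in `torch.squeeze`.

```python
import torch

x = torch.arange(1.0, 7.0).reshape(1, 2, 3)

# tf.reshape(x, [-1, 3]) -> torch.reshape(x, (-1, 3))
r = torch.reshape(x, (-1, 3))
print(r.shape)  # torch.Size([2, 3])

# tf.log(x) (tf.math.log in TF 2) -> torch.log(x): element-wise natural log
logx = torch.log(x)

# tf.squeeze(x, axis=0) -> torch.squeeze(x, dim=0): removes a size-1 dim
s = torch.squeeze(x, dim=0)
print(s.shape)  # torch.Size([2, 3])
```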