TensorFlow to PyTorch notes: converting tf.gather_nd(x, y) to PyTorch
2022-07-03 02:43:00 【strawberry47】
While porting TensorFlow code to PyTorch, I recorded some commonly used function conversions:
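Cannot be converted directly (the PyTorch function name or arguments differ):
- tf.transpose(input, [1, 0, 2]) -> input.permute([1, 0, 2]). It cannot be replaced directly by torch.transpose, because torch.transpose only swaps two dimensions and cannot reorder several dimensions at once.
- tf.expand_dims(input, axis=1) -> input.unsqueeze(1)
- tf.concat([content1, content2], axis=1) -> torch.cat((content1, content2), dim=1). Remember to rename axis to dim.
- tf.tile(input, [2, 1]) -> input.repeat([2, 1])
- tf.range(10) -> torch.arange(10)
- tf.reduce_sum(x, axis=1, keep_dims=True) -> torch.sum(x, dim=1, keepdim=True). Newer TensorFlow versions spell the argument keepdims.
- tf.clip_by_value(x, min, max) -> torch.clamp(x, min, max)
- tf.multinomial(logits=a, num_samples=1) -> torch.multinomial(torch.softmax(a, dim=-1), num_samples=1, replacement=False). tf.multinomial takes unnormalized log-probabilities, while torch.multinomial expects non-negative weights, so apply softmax first.
- tf.equal(x, y) -> torch.eq(x, y)
- tf.nn.embedding_lookup(W_fe, Feature_input + 1) -> torch.index_select(W_fe, 0, Feature_input + 1). index_select only accepts a 1-D index; for multi-dimensional indices use W_fe[Feature_input + 1] or torch.nn.functional.embedding(Feature_input + 1, W_fe).
- tf.one_hot() -> torch.nn.functional.one_hot(). It takes num_classes instead of depth, requires an int64 input, and returns int64 rather than TensorFlow's default float32.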
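A minimal sketch that checks several of the conversions above on small tensors, assuming a recent PyTorch release. The tensor names (x, m, v, logits, W, ids, labels) are illustrative only; the comment on each step names the TensorFlow call it replaces.

```python
import torch
import torch.nn.functional as F

x = torch.arange(24).reshape(2, 3, 4)

# tf.transpose(x, [1, 0, 2]): permute takes the full axis order,
# while torch.transpose(x, 0, 1) can only swap two axes
p = x.permute(1, 0, 2)
assert p.shape == (3, 2, 4)
assert torch.equal(p, torch.transpose(x, 0, 1))  # equal here only because it is a single swap

# tf.expand_dims(x, axis=1)
assert x.unsqueeze(1).shape == (2, 1, 3, 4)

# tf.concat([x, x], axis=1)
assert torch.cat((x, x), dim=1).shape == (2, 6, 4)

# tf.tile(m, [2, 1])
m = torch.tensor([[1, 2]])
assert m.repeat(2, 1).tolist() == [[1, 2], [1, 2]]

# tf.range(10)
assert torch.arange(10).tolist() == list(range(10))

# tf.reduce_sum(x, axis=1, keepdims=True)
assert torch.sum(x, dim=1, keepdim=True).shape == (2, 1, 4)

# tf.clip_by_value(v, 0.0, 1.0)
v = torch.tensor([-1.0, 0.5, 2.0])
assert torch.clamp(v, 0.0, 1.0).tolist() == [0.0, 0.5, 1.0]

# tf.multinomial(logits=logits, num_samples=1): turn logits into probabilities first
logits = torch.tensor([[0.0, 10.0, 0.0]])
sample = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)
assert sample.shape == (1, 1)

# tf.nn.embedding_lookup(W, ids) with a 2-D ids tensor
W = torch.randn(5, 3)
ids = torch.tensor([[0, 2], [4, 1]])
emb = W[ids]  # plain indexing keeps the shape of ids
assert emb.shape == (2, 2, 3)
assert torch.equal(emb, F.embedding(ids, W))
# index_select only works with a flattened (1-D) index
assert torch.equal(W[ids.flatten()], torch.index_select(W, 0, ids.flatten()))

# tf.one_hot(labels, depth=4)
labels = torch.tensor([0, 3])
oh = F.one_hot(labels, num_classes=4)
assert oh.tolist() == [[1, 0, 0, 0], [0, 0, 0, 1]]
assert oh.dtype == torch.int64  # call .float() to match tf.one_hot's default float32
```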
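Converting tf.gather_nd(x, y)
The function below reimplements tf.gather_nd(params, indices) for the case where the last dimension of indices equals the number of dimensions of params, so that each index selects a single element. It converts every index into a flat row-major offset and reads the values with torch.take: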
import torch

def gather_nd(params, indices):
    '''Port of tf.gather_nd for full indices (indices.shape[-1] == params.dim()).

    4D example
        params:  tensor shaped [n_1, n_2, n_3, n_4] --> 4-dimensional
        indices: tensor shaped [m_1, m_2, m_3, m_4, 4] --> multidimensional list of 4D indices
        returns: tensor shaped [m_1, m_2, m_3, m_4]
    ND example
        params:  tensor shaped [n_1, ..., n_d] --> d-dimensional tensor
        indices: tensor shaped [m_1, ..., m_i, d] --> multidimensional list of d-dimensional indices
        returns: tensor shaped [m_1, ..., m_i]
    '''
    out_shape = indices.shape[:-1]
    indices = indices.unsqueeze(0).transpose(0, -1)  # roll last axis to front: [d, m_1, ..., m_i, 1]
    ndim = indices.shape[0]
    indices = indices.long()
    idx = torch.zeros_like(indices[0])  # flat row-major offsets, same device as indices
    m = 1  # stride of the current axis
    for i in range(ndim)[::-1]:
        idx += indices[i] * m
        m *= params.size(i)
    out = torch.take(params, idx)  # torch.take indexes params as if it were 1-D
    return out.view(out_shape)
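The take-based function above only handles the case where every index addresses a single element; tf.gather_nd also accepts shorter indices (last dimension smaller than params.dim()) and then gathers whole slices. A sketch of an alternative based on PyTorch advanced indexing, which covers both cases. gather_nd_index is a hypothetical name, and TensorFlow's batch_dims argument is not supported.

```python
import torch

def gather_nd_index(params, indices):
    """Hypothetical helper: tf.gather_nd (without batch_dims) via advanced indexing.

    indices: integer tensor shaped [m_1, ..., m_i, d] with d <= params.dim().
    returns: tensor shaped [m_1, ..., m_i, n_{d+1}, ..., n_p]; when d < params.dim(),
    whole slices are gathered, as tf.gather_nd does.
    """
    # unbind(-1) splits the last axis into d index tensors of shape [m_1, ..., m_i],
    # one per leading dimension of params
    return params[tuple(indices.long().unbind(-1))]

params = torch.arange(12).reshape(3, 4)

# Full indices (d == 2): one element per index, same case as the take-based version
idx_full = torch.tensor([[0, 1], [2, 3]])
assert gather_nd_index(params, idx_full).tolist() == [1, 11]

# Partial indices (d == 1): whole rows are gathered
idx_rows = torch.tensor([[2], [0]])
assert gather_nd_index(params, idx_rows).tolist() == [[8, 9, 10, 11], [0, 1, 2, 3]]
```

Indexing with a tuple of index tensors lets PyTorch broadcast the d index tensors together, so no manual stride arithmetic is needed.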
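Can be converted directly (same function name):
- tf.reshape() -> torch.reshape()
- tf.log() -> torch.log() (in TensorFlow 2.x the function is tf.math.log)
- tf.squeeze() -> torch.squeeze(); rename axis to dim.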
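A short sketch of the direct conversions, assuming PyTorch. One behavioral difference is worth knowing: tf.squeeze raises an error if the given axis does not have size 1, while torch.squeeze silently leaves the tensor unchanged.

```python
import torch

x = torch.zeros(2, 1, 3)

# tf.reshape(x, [3, 2]) -> torch.reshape(x, (3, 2))
assert torch.reshape(x, (3, 2)).shape == (3, 2)

# tf.squeeze(x, axis=1) -> torch.squeeze(x, dim=1)
assert torch.squeeze(x, dim=1).shape == (2, 3)

# Axis 0 has size 2: TensorFlow would raise here, PyTorch returns x unchanged
assert torch.squeeze(x, dim=0).shape == (2, 1, 3)

# tf.log(x) / tf.math.log(x) -> torch.log(x)
assert torch.log(torch.tensor([1.0])).item() == 0.0
```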