TensorFlow to PyTorch notes; converting tf.gather_nd(x, y) to PyTorch
2022-07-03 02:39:00 【strawberry47】
Notes on commonly used function conversions when porting TensorFlow code to PyTorch:
Cannot be converted directly
tf.transpose(input, [1, 0, 2]) -> input.permute([1, 0, 2]) (torch.transpose is not a drop-in replacement: it only swaps two dimensions and cannot apply a full permutation)
tf.expand_dims(input, axis=1) -> input.unsqueeze(1)
tf.concat([content1, content2], axis=1) -> torch.cat((content1, content2), dim=1) (remember to rename axis to dim)
tf.tile(input, [2, 1]) -> input.repeat([2, 1])
tf.range(10) -> torch.arange(10)
tf.reduce_sum(x, axis=1, keep_dims=True) -> torch.sum(x, dim=1, keepdim=True)
tf.clip_by_value(x, min, max) -> torch.clamp(x, min, max)
tf.multinomial(logits=a, num_samples=1) -> torch.multinomial(input=a, num_samples=1, replacement=False) (tf.multinomial takes unnormalized log-probabilities, while torch.multinomial expects non-negative weights, so apply softmax to the logits first)
tf.equal(x, y) -> torch.eq(x, y)
tf.nn.embedding_lookup(W_fe, Feature_input + 1) -> torch.index_select(W_fe, 0, Feature_input + 1) (index_select requires a 1-D index tensor)
tf.one_hot() -> torch.nn.functional.one_hot()
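A few of the mappings above can be checked with a short sketch. The tensors and variable names here are made up for illustration, and using softmax before torch.multinomial and F.embedding for multi-dimensional lookups are suggestions based on the documented input requirements of those functions, not part of the original notes.

```python
import torch
import torch.nn.functional as F

x = torch.arange(24).reshape(2, 3, 4)

# tf.transpose(x, [1, 0, 2]): permute takes the full axis order
permuted = x.permute(1, 0, 2)                      # shape (3, 2, 4)

# tf.expand_dims(x, axis=1)
expanded = x.unsqueeze(1)                          # shape (2, 1, 3, 4)

# tf.tile(m, [2, 1])
m = torch.tensor([[1, 2], [3, 4]])
tiled = m.repeat(2, 1)                             # shape (4, 2)

# tf.range(10)
r = torch.arange(10)

# tf.multinomial takes logits; torch.multinomial takes probabilities,
# so normalize with softmax first (assumption: logits are unnormalized)
logits = torch.tensor([[0.0, 1.0, 2.0]])
probs = F.softmax(logits, dim=-1)
sample = torch.multinomial(probs, num_samples=1)   # shape (1, 1)

# tf.nn.embedding_lookup with a 2-D index tensor: index_select only accepts
# 1-D indices, so F.embedding or plain indexing is used here instead
W = torch.arange(12.0).reshape(4, 3)
ids = torch.tensor([[0, 2], [3, 1]])
looked_up = F.embedding(ids, W)                    # shape (2, 2, 3)
same = W[ids]
```

When the index tensor is 1-D, torch.index_select(W, 0, ids) gives the same rows as W[ids].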
Converting tf.gather_nd(x, y)
import torch

# Written as a class method (hence `self`); assumes indices.shape[-1] == params.dim()
def gather_nd(self, params, indices):
    '''
    4D example
      params:  tensor shaped [n_1, n_2, n_3, n_4] --> 4-dimensional tensor
      indices: tensor shaped [m_1, m_2, m_3, m_4, 4] --> multidimensional list of 4D indices
      returns: tensor shaped [m_1, m_2, m_3, m_4]
    ND example
      params:  tensor shaped [n_1, ..., n_d] --> d-dimensional tensor
      indices: tensor shaped [m_1, ..., m_i, d] --> multidimensional list of d-dimensional indices
      returns: tensor shaped [m_1, ..., m_i]
    '''
    out_shape = indices.shape[:-1]
    indices = indices.unsqueeze(0).transpose(0, -1)  # roll last axis to front
    ndim = indices.shape[0]
    indices = indices.long()
    idx = torch.zeros_like(indices[0], device=indices.device).long()
    m = 1
    for i in range(ndim)[::-1]:
        idx += indices[i] * m
        m *= params.size(i)
    out = torch.take(params, idx)
    return out.view(out_shape)
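The loop in gather_nd accumulates a row-major flat offset for each multi-index, and torch.take then reads the tensor as if it were flattened to 1-D. The arithmetic can be illustrated in plain Python; flat_index is a hypothetical helper written for this sketch, not part of the original code.

```python
from itertools import product

def flat_index(index, shape):
    """Row-major (C-order) offset of a full multi-index into a tensor of `shape`."""
    offset, stride = 0, 1
    # Walk from the last axis to the first, as the loop in gather_nd does
    for i, n in zip(reversed(index), reversed(shape)):
        offset += i * stride
        stride *= n
    return offset

shape = (2, 3, 4)

# product() enumerates multi-indices with the last axis varying fastest,
# so the k-th multi-index should map to flat offset k
all_indices = list(product(*(range(n) for n in shape)))
offsets = [flat_index(ix, shape) for ix in all_indices]

example = flat_index((1, 2, 3), shape)  # 1*12 + 2*4 + 3
```

This only covers the case where each index addresses a single element, that is, the last dimension of indices equals the number of dimensions of params.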
Can be converted directly
tf.reshape() -> torch.reshape()
tf.log() -> torch.log()
tf.squeeze() -> torch.squeeze()