当前位置:网站首页>简单介绍一下tensorflow与pytorch的相互转换(主要是tensorflow转pytorch)
简单介绍一下tensorflow与pytorch的相互转换(主要是tensorflow转pytorch)
2022-06-28 15:48:00 【wendy_ya】
本文以一段代码为例,简单介绍一下tensorflow与pytorch的相互转换(主要是tensorflow转pytorch),可能介绍的没有那么详细,仅供参考。
由于本人只熟悉pytorch,而对tensorflow一知半解,而代码经常遇到tensorflow,而我希望使用pytorch,因此简单介绍一下tensorflow转pytorch,可能存在诸多错误,希望轻喷~
1.变量预定义
在TensorFlow的世界里,变量的定义和初始化是分开的。
tensorflow中一般都是在开头预定义变量,声明其数据类型、形状等,在执行的时候再赋具体的值,如下图所示,而pytorch用到时才会定义,定义和变量初始化是合在一起的。
2.创建变量并初始化
tensorflow中利用tf.Variable创建变量并进行初始化,而pytorch中使用torch.tensor创建变量并进行初始化,如下图所示。
3.语句执行
在TensorFlow的世界里,变量的定义和初始化是分开的,所有关于图变量的赋值和计算都要通过tf.Session的run来进行。
sess.run([G_solver, G_loss_temp, MSE_loss],
feed_dict = {
X: X_mb, M: M_mb, H: H_mb})
而在pytorch中,并不需要通过run进行,赋值完了直接计算即可。
4.tensor
pytorch运算时要创建完的numpy数组转为tensor,如下:
if use_gpu is True:
X_mb = torch.tensor(X_mb, device="cuda")
M_mb = torch.tensor(M_mb, device="cuda")
H_mb = torch.tensor(H_mb, device="cuda")
else:
X_mb = torch.tensor(X_mb)
M_mb = torch.tensor(M_mb)
H_mb = torch.tensor(H_mb)
最后运行完还要将tensor数据类型转换回numpy数组:
if use_gpu is True:
imputed_data=imputed_data.cpu().detach().numpy()
else:
imputed_data=imputed_data.detach().numpy()
而tensorflow中不需要这种操作。
5.其他函数
在tensorflow中包含诸多函数是pytorch中没有的,但是都可以在其他库中找到类似,具体如下表所示。
| tensorflow中函数 | pytorch中代替(所在库) | 参数区别 |
|---|---|---|
| tf.sqrt | np.sqrt(numpy) | 完全相同 |
| tf.random_normal | np.random.normal(numpy) | tf.random_normal(shape = size, stddev = xavier_stddev) np.random.normal(size = size, scale = xavier_stddev) |
| tf.concat | torch.cat(torch) | inputs = tf.concat(values = [x, m], axis = 1) inputs = torch.cat(dim=1, tensors=[x, m]) |
| tf.nn.relu | F.relu(torch.nn.functional) | 完全相同 |
| tf.nn.sigmoid | torch.sigmoid(torch) | 完全相同 |
| tf.matmul | torch.matmul(torch) | 完全相同 |
| tf.reduce_mean | torch.mean(torch) | 完全相同 |
| tf.log | torch.log(torch) | 完全相同 |
| tf.zeros | np.zeros | 完全相同 |
| tf.train.AdamOptimizer | torch.optim.Adam(torch) | optimizer_D = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D) optimizer_D = torch.optim.Adam(params=theta_D) |
【说明】:本文的介绍仅供参考,实际转换请多查阅相关资料,如果有能力,建议这两种深度学习框架都进行掌握~
边栏推荐
- go-zero 微服务实战系列(七、请求量这么高该如何优化)
- Expand Disk C (allocate the memory of disk d to Disk C)
- 【Proteus仿真】L297驱动步进电机
- 大神详解开源 BUFF 增益攻略丨直播讲座
- Visual Studio 2019软件安装包和安装教程
- Qt5.5.1 configuring msvc2010 compiler and WinDbg debugger
- Jenkins的安装及使用
- 【LeetCode】13、罗马数字转整数
- Do not use short circuit logic to write STL sorter multi condition comparison
- Notes to distributed theory
猜你喜欢

Fleet |「後臺探秘」第 3 期:狀態管理

S2b2c system website solution for kitchen and bathroom electrical appliance industry: create s2b2c platform Omni channel commercial system

Curve 替换 Ceph 在网易云音乐的实践

See how the interface control devaxpress WinForms creates a virtual keyboard

讲师征集令 | Apache DolphinScheduler Meetup分享嘉宾,期待你的议题和声音!
![Experiment 6 8255 parallel interface experiment [microcomputer principle] [experiment]](/img/70/394ccf6e08a0774acade1eb1b8bf00.png)
Experiment 6 8255 parallel interface experiment [microcomputer principle] [experiment]

What! 一条命令搞定监控?

抖音实战~我关注的博主列表、关注、取关

Openharmony - detailed source code of Kernel Object Events

10年测试经验,在35岁的生理年龄面前,一文不值
随机推荐
No win32/com in vs2013 help document
【LeetCode】13、罗马数字转整数
机器学习之深度学习简介
Coding Devops helps Sinochem information to build a new generation of research efficiency platform and drive the new future of "online Sinochem"
[leetcode] 13. Roman numeral to integer
SaaS application management platform solution in the education industry: help enterprises realize the integration of operation and management
QT interface library
Visual Studio 2019软件安装包和安装教程
不要使用短路逻辑编写 stl sorter 多条件比较
The world has embraced Web3.0 one after another, and many countries have clearly begun to seize the initiative
国债与定期存款哪个更安全 两者之间有何区别
Flutter simply implements multilingual internationalization
Notes to distributed theory
Opengauss kernel: analysis of SQL parsing process
Today's sleep quality record is 80 points
Etcd可视化工具:Kstone简介(一)
Realization of a springboard machine
Experiment 6 8255 parallel interface experiment [microcomputer principle] [experiment]
Summary of language features of fluent dart
tablestore中可以使用sql查询可以查出表中所有的数据吗?