当前位置:网站首页>简单介绍一下tensorflow与pytorch的相互转换(主要是tensorflow转pytorch)
简单介绍一下tensorflow与pytorch的相互转换(主要是tensorflow转pytorch)
2022-06-28 15:48:00 【wendy_ya】
本文以一段代码为例,简单介绍一下tensorflow与pytorch的相互转换(主要是tensorflow转pytorch),可能介绍的没有那么详细,仅供参考。
由于本人只熟悉pytorch,而对tensorflow一知半解,而代码经常遇到tensorflow,而我希望使用pytorch,因此简单介绍一下tensorflow转pytorch,可能存在诸多错误,希望轻喷~
1.变量预定义
在TensorFlow的世界里,变量的定义和初始化是分开的。
tensorflow中一般都是在开头预定义变量,声明其数据类型、形状等,在执行的时候再赋具体的值,如下图所示,而pytorch用到时才会定义,定义和变量初始化是合在一起的。
2.创建变量并初始化
tensorflow中利用tf.Variable创建变量并进行初始化,而pytorch中使用torch.tensor创建变量并进行初始化,如下图所示。
3.语句执行
在TensorFlow的世界里,变量的定义和初始化是分开的,所有关于图变量的赋值和计算都要通过tf.Session的run来进行。
sess.run([G_solver, G_loss_temp, MSE_loss],
feed_dict = {
X: X_mb, M: M_mb, H: H_mb})
而在pytorch中,并不需要通过run进行,赋值完了直接计算即可。
4.tensor
pytorch运算时要创建完的numpy数组转为tensor,如下:
if use_gpu is True:
X_mb = torch.tensor(X_mb, device="cuda")
M_mb = torch.tensor(M_mb, device="cuda")
H_mb = torch.tensor(H_mb, device="cuda")
else:
X_mb = torch.tensor(X_mb)
M_mb = torch.tensor(M_mb)
H_mb = torch.tensor(H_mb)
最后运行完还要将tensor数据类型转换回numpy数组:
if use_gpu is True:
imputed_data=imputed_data.cpu().detach().numpy()
else:
imputed_data=imputed_data.detach().numpy()
而tensorflow中不需要这种操作。
5.其他函数
在tensorflow中包含诸多函数是pytorch中没有的,但是都可以在其他库中找到类似,具体如下表所示。
| tensorflow中函数 | pytorch中代替(所在库) | 参数区别 |
|---|---|---|
| tf.sqrt | np.sqrt(numpy) | 完全相同 |
| tf.random_normal | np.random.normal(numpy) | tf.random_normal(shape = size, stddev = xavier_stddev) np.random.normal(size = size, scale = xavier_stddev) |
| tf.concat | torch.cat(torch) | inputs = tf.concat(values = [x, m], axis = 1) inputs = torch.cat(dim=1, tensors=[x, m]) |
| tf.nn.relu | F.relu(torch.nn.functional) | 完全相同 |
| tf.nn.sigmoid | torch.sigmoid(torch) | 完全相同 |
| tf.matmul | torch.matmul(torch) | 完全相同 |
| tf.reduce_mean | torch.mean(torch) | 完全相同 |
| tf.log | torch.log(torch) | 完全相同 |
| tf.zeros | np.zeros | 完全相同 |
| tf.train.AdamOptimizer | torch.optim.Adam(torch) | optimizer_D = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D) optimizer_D = torch.optim.Adam(params=theta_D) |
【说明】:本文的介绍仅供参考,实际转换请多查阅相关资料,如果有能力,建议这两种深度学习框架都进行掌握~
边栏推荐
- openGauss内核:SQL解析过程分析
- Visual Studio 2010 configuring and using qt5.6.3
- 抖音实战~我关注的博主列表、关注、取关
- 北京有哪些牛逼的中小型公司?
- Go zero micro Service Practice Series (VII. How to optimize such a high demand)
- The past and present life of distributed cap theorem
- Azure Kinect微软摄像头Unity开发小结
- What! One command to get the surveillance?
- Soliciting articles and contributions - building a blog environment with a lightweight application server
- Realization of a springboard machine
猜你喜欢

openGauss内核:SQL解析过程分析

A new 25K byte from the Department showed me what the ceiling is

Analysis of PostgreSQL storage structure

【Spock】处理 Non-ASCII characters in an identifier
PostgreSQL enables grouping statistics by year, month, day, week, hour, minute and second

Technical secrets of ByteDance data platform: implementation and optimization of complex query based on Clickhouse

Application of mongodb in Tencent retail premium code

薅羊毛的机会了,点个“赚”即有机会赚取高额佣金
![The k-th element in the array [heap row + actual time complexity of heap building]](/img/69/bcafdcb09ffbf87246a03bcb9367aa.png)
The k-th element in the array [heap row + actual time complexity of heap building]

Visual Studio 2010 配置和使用Qt5.6.3
随机推荐
VS2013 帮助文档中没有 win32/com
[recommendation system] esmm model of multi task learning (updating)
Soliciting articles and contributions - building a blog environment with a lightweight application server
Fleet | background Discovery issue 3: Status Management
[leetcode] 13. Roman numeral to integer
成功迁移到云端需要采取的步骤
Qt5.5.1 configuring msvc2010 compiler and WinDbg debugger
Technical secrets of ByteDance data platform: implementation and optimization of complex query based on Clickhouse
【初学者必看】vlc实现的rtsp服务器及转储H264文件
Experiment 6 8255 parallel interface experiment [microcomputer principle] [experiment]
What are the most powerful small and medium-sized companies in Beijing?
今天睡眠质量记录80分
Ros21 lecture
What useful supplier management systems are available
MongoDB 在腾讯零售优码中的应用
【高并发基础】MySQL 不同事务隔离级别下的并发隐患及解决方案
隆重推出 Qodana:您最爱的 CI 的代码质量平台
10 years of testing experience, worthless in the face of the physiological age of 35
Classic model transformer
Gartner发布当前至2024年的五大隐私趋势