当前位置:网站首页>PyTorch函数中的__call__和forward函数
PyTorch函数中的__call__和forward函数
2022-07-02 17:50:00 【D.ziyu】
初学nn.Module,看不懂各种调用,后来看明白了,估计会忘,故写篇笔记记录
init & call
代码:
class A():
def __init__(self):
print('init函数')
def __call__(self, param):
print('call 函数', param)
a = A()
输出
分析:A进行类的实例化,生成对象a,这个过程自动调用_init_(),没有调用_call_()
上面的代码加一行
class A():
def __init__(self):
print('init函数')
def __call__(self, param):
print('call 函数', param)
a = A()
a(1)
输出
分析:a是对象,python中让对象有了像函数一样加括号(参数)的功能,使用这种功能时,自动调用_call_()
_ call_()中可以调用其它函数,如forward函数
class A():
def __init__(self):
print('init函数')
def __call__(self, param):
print('call 函数', param)
res = self.forward(param)
return res + 2
def forward(self, input_):
print('forward 函数', input_)
return input_
a = A()
b = a(1)
print('结果b =',b)

分析:_call _()成功调用了forward(),且返回值给了b
另外我之前有个误解,以为该类的值只有参数声明了才能用,这是错误的
class A():
def __init__(self):
print('init函数')
self.a = 100 # 声明参数a
def __call__(self, param):
print('call 函数', param)
res = self.forward(param)
return res + 2
def forward(self, input_):
print('forward 函数', input_, self.a)
return input_
a = A()
b = a(1)
print('结果b =',b)
print(a.a)

nn.Module
看了上面的例子,就知道了_call _()的作用,那下面看更CNN的例子
from torch import nn
import torch
class Ding(nn.Module):
def __init__(self):
print('init')
super().__init__()
def forward(self, input):
output = input + 1
print("forward")
return output
dzy = Ding()
x = torch.tensor(1.0)
out = dzy(x)
print(out)
结果:
分析:
这里并没有调用_call_() 和forward(),但还是显示了forward,原因是:Ding这个子类继承了父类nn.Module里的call函数,接下来去源码看
发现_call_调用了_call_impl这个函数,相当于起了个外号一样,那就去这个函数看


这里有很多参数,详细可见参考2。发现这里forward_call 要么是_slow_forward,要么是self.forward(),而这个_slow_forward()也会用self.forward()
所以: _call _()用了forward,而这个父类的forward在子类中重写了(简单代码)
当然,也可以重写__call__(),比如我们不让它使用forward()
from torch import nn
import torch
class Ding(nn.Module):
def __init__(self):
print('init')
super().__init__()
def __call__(self, input_):
print('重写call, 不用forward')
return 'hhh'
def forward(self, input):
output = input + 1
print("forward")
return output
dzy = Ding()
x = torch.tensor(1.0)
out = dzy(x)
print(out)

总结
使用对象dzy(x)时,用了父类nn.Module的call函数,调用了forward,而这个forward又被我们在子类里重写了。
参考
https://blog.csdn.net/dss_dssssd/article/details/83750838
https://zhuanlan.zhihu.com/p/366461413
边栏推荐
- Slam | how to align timestamps?
- Obligatoire pour les débutants, cliquez sur deux boutons pour passer à un contenu différent
- ICDE 2023|TKDE Poster Session(CFP)
- Leetcode (154) -- find the minimum value II in the rotation sort array
- SLC、MLC、TLC 和 QLC NAND SSD 之间的区别:哪个更好?
- FastDFS安装
- Learn the knowledge points of eight part essay ~ ~ 1
- 论文导读 | 机器学习在数据库基数估计中的应用
- 仿京东放大镜效果(pink老师版)
- The R language dplyr package rowwise function and mutate function calculate the maximum value of multiple data columns in each row in the dataframe data, and generate the data column (row maximum) cor
猜你喜欢

【测试开发】软件测试—概念篇

Stm32g0 USB DFU upgrade verification error -2

彻底搞懂基于Open3D的点云处理教程!

新加坡暑假旅遊攻略:一天玩轉新加坡聖淘沙島

Learning summary of MySQL advanced 6: concept and understanding of index, detailed explanation of b+ tree generation process, comparison between MyISAM and InnoDB
![[fluent] dart data type (VaR data type | object data type)](/img/1b/fe2529af5f6663fad1fb7861f14ab5.jpg)
[fluent] dart data type (VaR data type | object data type)

深度学习数学基础

M2dgr: slam data set of multi-source and multi scene ground robot (ICRA 2022)

论文导读 | 关于将预训练语言模型作为知识库的分析与批评

Hospital online inquiry source code hospital video inquiry source code hospital applet source code
随机推荐
[Yugong series] July 2022 go teaching course 001 introduction to go language premise
Have you stepped on the nine common pits in the e-commerce system?
UML class diagram
Redis (7) -- database and expiration key
SLC、MLC、TLC 和 QLC NAND SSD 之间的区别:哪个更好?
[fluent] dart data type (VaR data type | object data type)
【JVM调优实战100例】03——JVM堆调优四例
Eliminate the yellow alarm light on IBM p750 small computer [easy to understand]
产品经理应具备的能力
Hospital online inquiry source code hospital video inquiry source code hospital applet source code
Stratégie touristique d'été de Singapour: un jour pour visiter l'île de San taosha à Singapour
文字编辑器 希望有错误的句子用红色标红,文字编辑器用了markdown
迷你高尔夫球场:伦敦休闲旅游好去处
CDN acceleration and breaking J anti-theft chain function
@Component cannot get Dao layer
Competence of product manager
页面标题组件
The R language dplyr package rowwise function and mutate function calculate the maximum value of multiple data columns in each row in the dataframe data, and generate the data column (row maximum) cor
[test development] takes you to know what software testing is
【JVM调优实战100例】01——JVM的介绍与程序计数器