当前位置:网站首页>PyTorch函数中的__call__和forward函数
PyTorch函数中的__call__和forward函数
2022-07-02 17:50:00 【D.ziyu】
初学nn.Module,看不懂各种调用,后来看明白了,估计会忘,故写篇笔记记录
init & call
代码:
class A():
def __init__(self):
print('init函数')
def __call__(self, param):
print('call 函数', param)
a = A()
输出
分析:A进行类的实例化,生成对象a,这个过程自动调用_init_(),没有调用_call_()
上面的代码加一行
class A():
def __init__(self):
print('init函数')
def __call__(self, param):
print('call 函数', param)
a = A()
a(1)
输出
分析:a是对象,python中让对象有了像函数一样加括号(参数)的功能,使用这种功能时,自动调用_call_()
_ call_()中可以调用其它函数,如forward函数
class A():
def __init__(self):
print('init函数')
def __call__(self, param):
print('call 函数', param)
res = self.forward(param)
return res + 2
def forward(self, input_):
print('forward 函数', input_)
return input_
a = A()
b = a(1)
print('结果b =',b)

分析:_call _()成功调用了forward(),且返回值给了b
另外我之前有个误解,以为该类的值只有参数声明了才能用,这是错误的
class A():
def __init__(self):
print('init函数')
self.a = 100 # 声明参数a
def __call__(self, param):
print('call 函数', param)
res = self.forward(param)
return res + 2
def forward(self, input_):
print('forward 函数', input_, self.a)
return input_
a = A()
b = a(1)
print('结果b =',b)
print(a.a)

nn.Module
看了上面的例子,就知道了_call _()的作用,那下面看更CNN的例子
from torch import nn
import torch
class Ding(nn.Module):
def __init__(self):
print('init')
super().__init__()
def forward(self, input):
output = input + 1
print("forward")
return output
dzy = Ding()
x = torch.tensor(1.0)
out = dzy(x)
print(out)
结果:
分析:
这里并没有调用_call_() 和forward(),但还是显示了forward,原因是:Ding这个子类继承了父类nn.Module里的call函数,接下来去源码看
发现_call_调用了_call_impl这个函数,相当于起了个外号一样,那就去这个函数看


这里有很多参数,详细可见参考2。发现这里forward_call 要么是_slow_forward,要么是self.forward(),而这个_slow_forward()也会用self.forward()
所以: _call _()用了forward,而这个父类的forward在子类中重写了(简单代码)
当然,也可以重写__call__(),比如我们不让它使用forward()
from torch import nn
import torch
class Ding(nn.Module):
def __init__(self):
print('init')
super().__init__()
def __call__(self, input_):
print('重写call, 不用forward')
return 'hhh'
def forward(self, input):
output = input + 1
print("forward")
return output
dzy = Ding()
x = torch.tensor(1.0)
out = dzy(x)
print(out)

总结
使用对象dzy(x)时,用了父类nn.Module的call函数,调用了forward,而这个forward又被我们在子类里重写了。
参考
https://blog.csdn.net/dss_dssssd/article/details/83750838
https://zhuanlan.zhihu.com/p/366461413
边栏推荐
- Deep learning mathematics foundation
- 【每日一题】第二天
- How to write controller layer code gracefully?
- yolov3 训练自己的数据集之生成train.txt
- 论文导读 | 机器学习在数据库基数估计中的应用
- @Component cannot get Dao layer
- The difference between promise and observable
- R language dplyr package filter function filters dataframe data. If the name of the data column (variable) to be filtered contains quotation marks, you need to use!! SYM syntax processing, otherwise n
- 第一次去曼谷旅游怎么玩?这份省钱攻略请收好
- 从list转化成map的时候,如果根据某一属性可能会导致key重复而异常,可以设置处理这种重复的方式
猜你喜欢

ICDE 2023|TKDE Poster Session(CFP)

The text editor hopes to mark the wrong sentences in red, and the text editor uses markdown

Redis (7) -- database and expiration key

拦截器与过滤器的区别

M2DGR:多源多场景 地面机器人SLAM数据集(ICRA 2022 )

Singapore summer tourism strategy: play Singapore Sentosa Island in one day

【测试开发】软件测试—概念篇

LightGroupButton* sender = static_cast<LightGroupButton*>(QObject::sender());
![[daily question] first day](/img/8c/f25cddb6ca86d44538c976fae13c6e.png)
[daily question] first day

性能测试如何创造业务价值
随机推荐
The second bullet of AI development and debugging series: the exploration journey of multi machine distributed debugging
[daily question] the next day
@Component cannot get Dao layer
谷歌官方回应:我们没有放弃TensorFlow,未来与JAX并肩发展
M2DGR:多源多场景 地面机器人SLAM数据集(ICRA 2022 )
Tips for material UV masking
How to use PS to extract image color and analyze color matching
Stm32g0 USB DFU upgrade verification error -2
The R language dplyr package rowwise function and mutate function calculate the maximum value of multiple data columns in each row in the dataframe data, and generate the data column (row maximum) cor
yolov3 训练自己的数据集之生成train.txt
SQL training 2
promise 和 Observable 的区别
Introduction to the paper | analysis and criticism of using the pre training language model as a knowledge base
options should NOT have additional properties
Singapore summer tourism strategy: play Singapore Sentosa Island in one day
How to delete the border of links in IE? [repeat] - how to remove borders around links in IE? [duplicate]
故障排查:kubectl报错ValidationError: unknown field \u00a0
在Tensorflow2中使用mnist_784数据集进行手写数字识别
思维意识转变是施工企业数字化转型成败的关键
Progress progress bar