当前位置:网站首页>PyTorch函数中的__call__和forward函数
PyTorch函数中的__call__和forward函数
2022-07-02 17:50:00 【D.ziyu】
初学nn.Module,看不懂各种调用,后来看明白了,估计会忘,故写篇笔记记录
init & call
代码:
class A():
def __init__(self):
print('init函数')
def __call__(self, param):
print('call 函数', param)
a = A()
输出
分析:A进行类的实例化,生成对象a,这个过程自动调用_init_(),没有调用_call_()
上面的代码加一行
class A():
def __init__(self):
print('init函数')
def __call__(self, param):
print('call 函数', param)
a = A()
a(1)
输出
分析:a是对象,python中让对象有了像函数一样加括号(参数)的功能,使用这种功能时,自动调用_call_()
_ call_()中可以调用其它函数,如forward函数
class A():
def __init__(self):
print('init函数')
def __call__(self, param):
print('call 函数', param)
res = self.forward(param)
return res + 2
def forward(self, input_):
print('forward 函数', input_)
return input_
a = A()
b = a(1)
print('结果b =',b)
分析:_call _()成功调用了forward(),且返回值给了b
另外我之前有个误解,以为该类的值只有参数声明了才能用,这是错误的
class A():
def __init__(self):
print('init函数')
self.a = 100 # 声明参数a
def __call__(self, param):
print('call 函数', param)
res = self.forward(param)
return res + 2
def forward(self, input_):
print('forward 函数', input_, self.a)
return input_
a = A()
b = a(1)
print('结果b =',b)
print(a.a)
nn.Module
看了上面的例子,就知道了_call _()的作用,那下面看更CNN的例子
from torch import nn
import torch
class Ding(nn.Module):
def __init__(self):
print('init')
super().__init__()
def forward(self, input):
output = input + 1
print("forward")
return output
dzy = Ding()
x = torch.tensor(1.0)
out = dzy(x)
print(out)
结果:
分析:
这里并没有调用_call_() 和forward(),但还是显示了forward,原因是:Ding这个子类继承了父类nn.Module里的call函数,接下来去源码看
发现_call_调用了_call_impl这个函数,相当于起了个外号一样,那就去这个函数看
这里有很多参数,详细可见参考2。发现这里forward_call 要么是_slow_forward,要么是self.forward(),而这个_slow_forward()也会用self.forward()
所以: _call _()用了forward,而这个父类的forward在子类中重写了(简单代码)
当然,也可以重写__call__(),比如我们不让它使用forward()
from torch import nn
import torch
class Ding(nn.Module):
def __init__(self):
print('init')
super().__init__()
def __call__(self, input_):
print('重写call, 不用forward')
return 'hhh'
def forward(self, input):
output = input + 1
print("forward")
return output
dzy = Ding()
x = torch.tensor(1.0)
out = dzy(x)
print(out)
总结
使用对象dzy(x)时,用了父类nn.Module的call函数,调用了forward,而这个forward又被我们在子类里重写了。
参考
https://blog.csdn.net/dss_dssssd/article/details/83750838
https://zhuanlan.zhihu.com/p/366461413
边栏推荐
- [100 cases of JVM tuning practice] 02 - five cases of virtual machine stack and local method stack tuning
- R language dplyr package filter function filters dataframe data. If the name of the data column (variable) to be filtered contains quotation marks, you need to use!! SYM syntax processing, otherwise n
- 昨天阿里学长写了一个责任链模式,竟然出现了无数个bug
- 论文导读 | 机器学习在数据库基数估计中的应用
- ORA-01455: converting column overflows integer datatype
- Hongmeng's fourth learning
- The second bullet of AI development and debugging series: the exploration journey of multi machine distributed debugging
- Exness in-depth good article: dynamic series - Case Analysis of gold liquidity (V)
- Introduction to the paper | application of machine learning in database cardinality estimation
- 【每日一题】第一天
猜你喜欢
Have you stepped on the nine common pits in the e-commerce system?
[100 cases of JVM tuning practice] 01 - introduction of JVM and program counter
Mini Golf Course: a good place for leisure and tourism in London
SLC、MLC、TLC 和 QLC NAND SSD 之间的区别:哪个更好?
Kubernetes three open interfaces first sight
Mysql高级篇学习总结7:Mysql数据结构-Hash索引、AVL树、B树、B+树的对比
【JVM调优实战100例】01——JVM的介绍与程序计数器
Industrial software lecture - core technology analysis of 3D CAD design software - the second lecture of the Forum
The text editor hopes to mark the wrong sentences in red, and the text editor uses markdown
[daily question] the next day
随机推荐
Gstore weekly gstore source code analysis (4): black and white list configuration analysis of security mechanism
Redis (6) -- object and data structure
R language uses Cox of epidisplay package Display function obtains the summary statistical information of Cox regression model (risk rate HR, adjusted risk rate and its confidence interval, P value of
Use MNIST in tensorflow 2_ 784 data set for handwritten digit recognition
【测试开发】一文带你了解什么是软件测试
论文导读 | 关于将预训练语言模型作为知识库的分析与批评
Typical application of "stack" - expression evaluation (implemented in C language)
Learning summary of MySQL advanced 6: concept and understanding of index, detailed explanation of b+ tree generation process, comparison between MyISAM and InnoDB
《病人家属,请来一下》读书笔记
Exness in-depth good article: dynamic series - Case Analysis of gold liquidity (V)
Mysql高级篇学习总结8:InnoDB数据存储结构页的概述、页的内部结构、行格式
The student Tiktok publicized that his alma mater was roast about "reducing the seal of enrollment". Netizen: hahahahahahahaha
【JVM调优实战100例】02——虚拟机栈与本地方法栈调优五例
[论文阅读] CA-Net: Leveraging Contextual Features for Lung Cancer Prediction
Google's official response: we have not given up tensorflow and will develop side by side with Jax in the future
R语言ggplot2可视化分面图(facet):gganimate包基于transition_time函数创建动态散点图动画(gif)
Stretchdibits function
Kubernetes three open interfaces first sight
鸿蒙第四次学习
性能测试如何创造业务价值