The __call__ and forward Functions in PyTorch
2022-07-02 19:11:00 【D.ziyu】
When I first started with nn.Module, I couldn't follow all the different calls. I eventually figured it out, but since I'll probably forget, I'm writing this note.
__init__ & __call__
Code:
class A:
    def __init__(self):
        print('init function')

    def __call__(self, param):
        print('call function', param)

a = A()
Output:
init function
Analysis: instantiating class A creates the object a. This automatically calls __init__(). It does not call __call__().
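The next sketch makes the timing explicit by counting how often each method runs. The class name Counter and its attributes are made up for illustration. __init__ runs once per instantiation. __call__ runs once for every call of the existing object.

```python
class Counter:
    instances_created = 0  # class attribute, shared by all instances

    def __init__(self):
        # Runs once, when Counter() creates a new object
        Counter.instances_created += 1
        self.times_called = 0

    def __call__(self, param):
        # Runs every time an existing object is called like a function
        self.times_called += 1
        return param


c = Counter()
print(Counter.instances_created, c.times_called)  # 1 0
c(1)
c(2)
print(Counter.instances_created, c.times_called)  # 1 2
```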
Now add one line to the code above:
class A:
    def __init__(self):
        print('init function')

    def __call__(self, param):
        print('call function', param)

a = A()
a(1)
Output:
init function
call function 1
Analysis: a is an object. Python lets you call an object like a function, with parentheses and arguments, as long as its class defines __call__(). Writing a(1) therefore automatically invokes __call__() with param=1.
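Here is a quick sketch for checking this behaviour. The classes A and B and the strings are illustrative. It relies on three facts:

- callable() reports whether an object can be called.
- a(1) is equivalent to type(a).__call__(a, 1).
- Calling an object whose class defines no __call__ raises TypeError.

Python looks __call__ up on the class, not on the instance. That is why changing the call behaviour means defining __call__ in the class body.

```python
class A:
    def __call__(self, param):
        return f'called with {param}'


class B:
    pass  # defines no __call__


a = A()
print(callable(a))             # True
print(a(1))                    # called with 1
print(type(a).__call__(a, 1))  # called with 1 (equivalent to a(1))

# __call__ is looked up on the class, so an instance attribute is ignored
a.__call__ = lambda param: 'instance attribute'
print(a(1))                    # still: called with 1

b = B()
print(callable(b))             # False
try:
    b(1)
except TypeError as err:
    print('TypeError:', err)   # 'B' object is not callable
```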
__call__() can in turn call other methods, such as a forward() method:
class A:
    def __init__(self):
        print('init function')

    def __call__(self, param):
        print('call function', param)
        res = self.forward(param)
        return res + 2

    def forward(self, input_):
        print('forward function', input_)
        return input_

a = A()
b = a(1)
print('result b =', b)
Analysis: __call__() successfully called forward(). Its return value plus 2 was assigned to b, so the last line prints result b = 3.
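The __call__ above accepts exactly one argument. The slightly more general sketch below forwards every positional and keyword argument, so forward() can have any signature. The class names Base and AddScaled are hypothetical. nn.Module's __call__ likewise passes its arguments through to forward.

```python
class Base:
    def __call__(self, *args, **kwargs):
        # Pass every positional and keyword argument straight through
        return self.forward(*args, **kwargs)


class AddScaled(Base):
    def forward(self, x, y, scale=1):
        return x + scale * y


m = AddScaled()
print(m(1, 2))            # 3
print(m(1, 2, scale=10))  # 21
```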
I also had a misconception: I thought an attribute could only be used in the method where it was declared. That is wrong. An attribute assigned to self in __init__() can be read in every other method, and from outside the object:
class A:
    def __init__(self):
        print('init function')
        self.a = 100  # declare attribute a

    def __call__(self, param):
        print('call function', param)
        res = self.forward(param)
        return res + 2

    def forward(self, input_):
        print('forward function', input_, self.a)  # self.a is usable here
        return input_

a = A()
b = a(1)
print('result b =', b)
print(a.a)  # and from outside the object
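Because self.a lives on the instance, every method shares it, and it can be reassigned later. The sketch below is a variant of the class above: forward() now adds self.a so the effect is visible, and obj is an illustrative name. It shows forward() picking up the new value and vars() listing the instance's attributes.

```python
class A:
    def __init__(self):
        self.a = 100  # stored in the instance's __dict__

    def __call__(self, param):
        return self.forward(param) + 2

    def forward(self, input_):
        # self.a was assigned in __init__, but any method can read it
        return input_ + self.a


obj = A()
print(obj(1))     # 1 + 100 + 2 = 103
obj.a = 5         # attributes can be reassigned from outside, too
print(obj(1))     # 1 + 5 + 2 = 8
print(vars(obj))  # {'a': 5}
```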
nn.Module
The examples above show what __call__() does. Next, let's see how this works in nn.Module, the base class for CNNs and other PyTorch models:
from torch import nn
import torch

class Ding(nn.Module):
    def __init__(self):
        print('init')
        super().__init__()

    def forward(self, input):
        output = input + 1
        print("forward")
        return output

dzy = Ding()
x = torch.tensor(1.0)
out = dzy(x)
print(out)
Result:
init
forward
tensor(2.)
Analysis:
We never call __call__() or forward() explicitly, yet "forward" is printed. The reason is that the subclass Ding inherits __call__ from its parent class nn.Module. Let's look at the source code.
There we find that __call__ is assigned _call_impl, which acts like an alias, so let's look at that function.
It takes many parameters (see reference 2 for details). Inside it, forward_call is either self._slow_forward or self.forward, and _slow_forward() itself also calls self.forward().
Therefore __call__() ends up calling forward(), and the forward defined in the subclass overrides the parent class's forward.
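The dispatch described above can be modelled in a few lines of plain Python. This is a heavily simplified, hypothetical sketch, not PyTorch's code. MiniModule and add_forward_hook are made-up names, and the real _call_impl also handles pre-hooks, backward hooks and tracing. The sketch shows three things:

- the __call__ = _call_impl alias;
- a base-class forward() that raises NotImplementedError, as nn.Module's does;
- why the subclass's forward() is the one that runs.

```python
class MiniModule:
    """Hypothetical, heavily simplified stand-in for torch.nn.Module."""

    def __init__(self):
        self._forward_hooks = []  # callables run after forward()

    def add_forward_hook(self, hook):
        # hook(module, args, output) may return a replacement output
        self._forward_hooks.append(hook)

    def forward(self, *args, **kwargs):
        # The base class has no usable forward(); subclasses must override it
        raise NotImplementedError(f"{type(self).__name__} must define forward()")

    def _call_impl(self, *args, **kwargs):
        # self.forward is looked up dynamically, so a subclass override wins
        output = self.forward(*args, **kwargs)
        for hook in self._forward_hooks:
            result = hook(self, args, output)
            if result is not None:
                output = result
        return output

    # Alias: obj(...) runs _call_impl, like nn.Module's __call__
    __call__ = _call_impl


class Ding(MiniModule):
    def forward(self, x):
        return x + 1


dzy = Ding()
print(dzy(1))          # 2: MiniModule.__call__ -> Ding.forward
dzy.add_forward_hook(lambda module, args, out: out * 10)
print(dzy(1))          # 20: the hook ran after forward()
print(dzy.forward(1))  # 2: calling forward() directly skips the hook
```

The last line shows the practical reason PyTorch code writes model(x) rather than model.forward(x). The PyTorch documentation notes that calling the module instance runs the registered hooks, while calling forward() directly silently ignores them.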
Of course, you can also override __call__() yourself, for example so that it does not use forward() at all:
from torch import nn
import torch

class Ding(nn.Module):
    def __init__(self):
        print('init')
        super().__init__()

    def __call__(self, input_):
        print('rewrite call, no forward needed')
        return 'hhh'

    def forward(self, input):
        output = input + 1
        print("forward")
        return output

dzy = Ding()
x = torch.tensor(1.0)
out = dzy(x)
print(out)  # hhh (forward() never ran)
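You may want to override __call__ but still have forward(), and everything the parent's __call__ does around it, run. One option is to delegate to the parent class with super().__call__(). The plain-Python sketch below illustrates both variants. Base, Replaced and Wrapped are hypothetical names, and Base stands in for nn.Module.

```python
class Base:
    """Hypothetical stand-in for nn.Module: __call__ dispatches to forward()."""

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


class Replaced(Base):
    def __call__(self, x):
        return 'hhh'  # forward() is never reached

    def forward(self, x):
        return x + 1


class Wrapped(Base):
    def __call__(self, x):
        print('before forward')
        return super().__call__(x)  # still ends up in Wrapped.forward()

    def forward(self, x):
        return x + 1


print(Replaced()(1))  # hhh
print(Wrapped()(1))   # prints "before forward", then 2
```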
Summary
When we call the object as dzy(x), Python uses the __call__ that Ding inherits from its parent class nn.Module. That __call__ calls forward(), which we overrode in the subclass.
References
1. https://blog.csdn.net/dss_dssssd/article/details/83750838
2. https://zhuanlan.zhihu.com/p/366461413