当前位置:网站首页>PyTorch函数中的__call__和forward函数
PyTorch函数中的__call__和forward函数
2022-07-02 17:50:00 【D.ziyu】
初学nn.Module,看不懂各种调用,后来看明白了,估计会忘,故写篇笔记记录
init & call
代码:
class A():
def __init__(self):
print('init函数')
def __call__(self, param):
print('call 函数', param)
a = A()
输出
分析:A进行类的实例化,生成对象a,这个过程自动调用_init_(),没有调用_call_()
上面的代码加一行
class A():
def __init__(self):
print('init函数')
def __call__(self, param):
print('call 函数', param)
a = A()
a(1)
输出
分析:a是对象,python中让对象有了像函数一样加括号(参数)的功能,使用这种功能时,自动调用_call_()
_ call_()中可以调用其它函数,如forward函数
class A():
def __init__(self):
print('init函数')
def __call__(self, param):
print('call 函数', param)
res = self.forward(param)
return res + 2
def forward(self, input_):
print('forward 函数', input_)
return input_
a = A()
b = a(1)
print('结果b =',b)

分析:_call _()成功调用了forward(),且返回值给了b
另外我之前有个误解,以为该类的值只有参数声明了才能用,这是错误的
class A():
def __init__(self):
print('init函数')
self.a = 100 # 声明参数a
def __call__(self, param):
print('call 函数', param)
res = self.forward(param)
return res + 2
def forward(self, input_):
print('forward 函数', input_, self.a)
return input_
a = A()
b = a(1)
print('结果b =',b)
print(a.a)

nn.Module
看了上面的例子,就知道了_call _()的作用,那下面看更CNN的例子
from torch import nn
import torch
class Ding(nn.Module):
def __init__(self):
print('init')
super().__init__()
def forward(self, input):
output = input + 1
print("forward")
return output
dzy = Ding()
x = torch.tensor(1.0)
out = dzy(x)
print(out)
结果:
分析:
这里并没有调用_call_() 和forward(),但还是显示了forward,原因是:Ding这个子类继承了父类nn.Module里的call函数,接下来去源码看
发现_call_调用了_call_impl这个函数,相当于起了个外号一样,那就去这个函数看


这里有很多参数,详细可见参考2。发现这里forward_call 要么是_slow_forward,要么是self.forward(),而这个_slow_forward()也会用self.forward()
所以: _call _()用了forward,而这个父类的forward在子类中重写了(简单代码)
当然,也可以重写__call__(),比如我们不让它使用forward()
from torch import nn
import torch
class Ding(nn.Module):
def __init__(self):
print('init')
super().__init__()
def __call__(self, input_):
print('重写call, 不用forward')
return 'hhh'
def forward(self, input):
output = input + 1
print("forward")
return output
dzy = Ding()
x = torch.tensor(1.0)
out = dzy(x)
print(out)

总结
使用对象dzy(x)时,用了父类nn.Module的call函数,调用了forward,而这个forward又被我们在子类里重写了。
参考
https://blog.csdn.net/dss_dssssd/article/details/83750838
https://zhuanlan.zhihu.com/p/366461413
边栏推荐
- Redis (7) -- database and expiration key
- Introduction to the paper | application of machine learning in database cardinality estimation
- What is cloud primordial? This time, I can finally understand!
- STM32G0 USB DFU 升级校验出错-2
- IPtable port redirection masquerade[easy to understand]
- How to play when you travel to Bangkok for the first time? Please keep this money saving strategy
- Websocket of Web real-time communication technology
- 【测试开发】软件测试—概念篇
- R language ggplot2 visualization: gganimate package creates dynamic histogram animation (GIF) and uses transition_ The States function displays a histogram step by step along a given dimension in the
- 论文导读 | 关于将预训练语言模型作为知识库的分析与批评
猜你喜欢

Web version 3D visualization tool, 97 things programmers should know, AI frontier paper | information daily # 2022.07.01

Installation of thingsboard, an open source IOT platform

Kubernetes three open interfaces first sight

ICDE 2023|TKDE Poster Session(CFP)

新手必看,点击两个按钮切换至不同的内容

新加坡暑假旅遊攻略:一天玩轉新加坡聖淘沙島
![[100 cases of JVM tuning practice] 03 -- four cases of JVM heap tuning](/img/54/8a18cd30e6186528599c0556b1ee3b.png)
[100 cases of JVM tuning practice] 03 -- four cases of JVM heap tuning

第一次去曼谷旅游怎么玩?这份省钱攻略请收好

文字编辑器 希望有错误的句子用红色标红,文字编辑器用了markdown

Introduction to the paper | application of machine learning in database cardinality estimation
随机推荐
The R language dplyr package rowwise function and mutate function calculate the maximum value of multiple data columns in each row in the dataframe data, and generate the data column (row maximum) cor
Use MNIST in tensorflow 2_ 784 data set for handwritten digit recognition
ORA-01455: converting column overflows integer datatype
2022软件工程期末考试 回忆版
LightGroupButton* sender = static_ cast<LightGroupButton*>(QObject::sender());
Chain game system development (unity3d chain game development details) - chain game development mature technology source code
Golang并发编程——goroutine、channel、sync
CDN acceleration and breaking J anti-theft chain function
FastDFS安装
UML 类图
The student Tiktok publicized that his alma mater was roast about "reducing the seal of enrollment". Netizen: hahahahahahahaha
Tips for material UV masking
使用 Cheat Engine 修改 Kingdom Rush 中的金钱、生命、星
[daily question] the next day
Learning summary of MySQL advanced 6: concept and understanding of index, detailed explanation of b+ tree generation process, comparison between MyISAM and InnoDB
After 22 years in office, the father of PowerShell will leave Microsoft: he was demoted by Microsoft for developing PowerShell
Meta universe chain game system development (logic development) - chain game system development (detailed analysis)
《病人家属,请来一下》读书笔记
使用CLion编译OGLPG-9th-Edition源码
Thoroughly understand the point cloud processing tutorial based on open3d!