Four kinds of hooks in deep learning
2022-06-13 08:51:00 【Human high quality Algorithm Engineer】
To save video memory (GPU memory), PyTorch does not retain intermediate variables for later inspection, including the feature maps of intermediate layers and the gradients of non-leaf tensors. When analyzing a network you sometimes need to view or modify these intermediate variables, and for that you register a hook to export them. There are many introductions to this online, but the ones I looked through contain a number of inaccurate or hard-to-follow passages, so here is my own summary, with practical usage and notes.
There are four kinds of hooks:
torch.Tensor.register_hook()
torch.nn.Module.register_forward_hook()
torch.nn.Module.register_backward_hook()
torch.nn.Module.register_forward_pre_hook()
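1. torch.Tensor.register_hook(hook), where hook has the signature hook(grad)
Used to export or modify the gradient of a tensor during back-propagation.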
import torch

def grad_hook(grad):
    # Return a new tensor; a hook should not modify grad in place
    return grad * 1.6

x = torch.tensor([1., 1., 1., 1.], requires_grad=True)
y = torch.pow(x, 2)
z = torch.sum(y)
h = x.register_hook(grad_hook)
z.backward()
print(x.grad)
h.remove()  # removes the hook
The result is:
tensor([3.2000, 3.2000, 3.2000, 3.2000])
A second example, using a mean instead of a sum:
import torch

def grad_hook(grad):
    # Return a new tensor; a hook should not modify grad in place
    return grad * 50

x = torch.tensor([2., 2., 2., 2.], requires_grad=True)
y = torch.pow(x, 2)
z = torch.mean(y)
h = x.register_hook(grad_hook)
z.backward()
print(x.grad)
h.remove()  # removes the hook
>>> tensor([50., 50., 50., 50.])
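Where do these values come from? The hook runs during back-propagation, at the moment the gradient of x has been computed. In the first example dz/dx = 2x = 2, and the hook scales it by 1.6 to give 3.2. In the second example dz/dx = 2x/4 = 1, and the hook scales it by 50 to give 50. A hook can therefore both read the gradient of x and change its value.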
Note:
A hook can be cancelled with remove(), but remove() must come after backward(). PyTorch only starts computing gradients when backward() is executed; x.register_hook(grad_hook) merely registers the gradient hook and performs no computation. remove() simply cancels the hook, so any later call to backward() no longer triggers it.
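The note above can be checked with a short sketch. It registers a hook on a non-leaf tensor to capture its gradient, then shows that a second backward pass after remove() no longer calls the hook. The helper save_grad and the dictionary captured are hypothetical names introduced only for this illustration.

```python
import torch

captured = {}  # hypothetical container for exported gradients

def save_grad(name):
    # Build a hook that stores a copy of the gradient under `name`.
    def hook(grad):
        captured[name] = grad.detach().clone()
    return hook

x = torch.tensor([1., 2., 3.], requires_grad=True)
y = x ** 2   # non-leaf tensor: its gradient is not kept in y.grad
z = y.sum()

handle = y.register_hook(save_grad("y"))

# First backward pass: the hook fires and records dz/dy.
z.backward(retain_graph=True)
grad_y = captured.pop("y")
grad_x = x.grad.clone()

# After remove(), a further backward pass does not call the hook.
handle.remove()
z.backward()
hook_fired_again = "y" in captured

print(grad_y)            # gradient of z with respect to y
print(grad_x)            # gradient of z with respect to x, i.e. 2 * x
print(hook_fired_again)
```

retain_graph=True is only needed here because the same graph is back-propagated twice for the demonstration.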
2. torch.nn.Module.register_forward_hook(hook), where hook has the signature hook(module, inp, out)
Used to export the input and output tensors of a given submodule (a layer, a block, or anything else of type nn.Module). Only the output can be modified. It is often used to export or modify convolutional feature maps.
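inps, outs = [], []

def layer_hook(module, inp, out):
    # inp is a tuple of inputs; out is the module's output tensor
    inps.append(inp[0].detach().cpu().numpy())
    outs.append(out.detach().cpu().numpy())

# net and input are assumed to be defined elsewhere
hook = net.layer1.register_forward_hook(layer_hook)
output = net(input)
hook.remove()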
Notes:
(1) A module can take several inputs, so inp is a tuple; extract the tensor you need before operating on it. When the module returns a single tensor, out is a Tensor and can be used directly.
(2) Do not leave exported tensors in video memory (move them to the CPU first), unless you have an A100.
(3) Only the output value out can be changed. Changes to inp have no effect on the forward pass, because the hook runs after forward() has already finished: inp cannot be returned, and in-place edits come too late. To change the output, return the new value from the hook, for example:
def layer_hook(self, module, inp, out):
    # Method of a class that stores lam and indices as attributes
    out = self.lam * out + (1 - self.lam) * out[self.indices]
    return out
This code comes from manifold mixup, where it mixes the features of an intermediate layer as a form of data augmentation. self.lam is a mixing coefficient in [0, 1], and self.indices holds the shuffled sample indices of the batch.
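A self-contained sketch of this forward hook follows. The toy Net class, the MixupHook class, and the fixed index rotation (standing in for a real shuffle) are assumptions made for illustration, not the original manifold mixup implementation.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy network; `layer1` mirrors the name used in the snippets.
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.layer2 = nn.Conv2d(8, 4, kernel_size=3, padding=1)

    def forward(self, x):
        return self.layer2(torch.relu(self.layer1(x)))

class MixupHook:
    """Callable hook that mixes each sample's features with a partner's."""
    def __init__(self, lam, indices):
        self.lam = lam
        self.indices = indices
        self.raw = None  # unmixed layer1 output, exported to the CPU

    def __call__(self, module, inp, out):
        self.raw = out.detach().cpu()
        # The returned tensor replaces the module's output.
        return self.lam * out + (1 - self.lam) * out[self.indices]

net = Net()
batch = torch.randn(4, 3, 16, 16)
# A fixed rotation stands in for shuffled indices (assumption for the demo).
mix = MixupHook(lam=0.7, indices=torch.tensor([1, 2, 3, 0]))

handle = net.layer1.register_forward_hook(mix)
mixed_out = net(batch)   # layer2 sees the mixed features
handle.remove()
plain_out = net(batch)   # hook removed: ordinary forward pass

print(mix.raw.shape, mixed_out.shape)
```

Using a callable object rather than a plain function keeps lam and indices next to the hook, which matches the self.lam and self.indices attributes in the snippet above.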
3. torch.nn.Module.register_forward_pre_hook(hook), where hook has the signature hook(module, inp)
Used to export or modify the input tensors of a given submodule.
def pre_hook(module, inp):
    # inp is a tuple of the positional inputs to forward()
    inp0 = inp[0]
    inp0 = inp0 * 2
    return (inp0,)

# net and input are assumed to be defined elsewhere
hook = net.layer1.register_forward_pre_hook(pre_hook)
output = net(input)
hook.remove()
Notes:
(1) inp is a tuple, so extract the tensor first, operate on it, and then return the result packed into a tuple again. (If the hook returns a single tensor, PyTorch wraps it in a tuple for you.)
(2) The hook is only called when output = net(input) is executed, so remove() can be called after that statement to cancel the hook.
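As a minimal runnable sketch of a forward pre-hook, the example below doubles the input of a single linear layer. The layer, the weights filled with ones, and the name double_input are choices made for this illustration so that the effect is easy to read off.

```python
import torch
import torch.nn as nn

layer = nn.Linear(3, 2, bias=False)
with torch.no_grad():
    layer.weight.fill_(1.0)  # every output is then the sum of the inputs

def double_input(module, inp):
    # `inp` is the tuple of positional arguments passed to forward().
    return (inp[0] * 2,)

x = torch.tensor([[1., 2., 3.]])

handle = layer.register_forward_pre_hook(double_input)
with_hook = layer(x)     # the layer sees 2 * x
handle.remove()
without_hook = layer(x)  # the layer sees x unchanged

print(with_hook)
print(without_hook)
```

With the hook active each output is the sum of the doubled inputs (12); after remove() it is the plain sum (6).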
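4. torch.nn.Module.register_backward_hook(hook), where hook has the signature hook(module, grad_in, grad_out)
Used to export the gradients with respect to the input and output tensors of a given submodule. Only the gradient with respect to the input can be modified (that is, the hook can only return gin); the gradient with respect to the output cannot be modified. In recent PyTorch releases register_backward_hook is deprecated in favor of register_full_backward_hook, whose hook takes the same arguments.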
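gouts = []

def backward_hook(module, gin, gout):
    # gin and gout are tuples; their lengths depend on the module
    print(len(gin), len(gout))
    gouts.append(gout[0].detach().cpu().numpy())
    # This unpacking assumes the module yields three input gradients
    gin0, gin1, gin2 = gin
    gin1 = gin1 * 2
    gin2 = gin2 * 3
    return (gin0, gin1, gin2)

# net and loss are assumed to be defined elsewhere
hook = net.layer1.register_backward_hook(backward_hook)
loss.backward()
hook.remove()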
Notes:
(1) grad_in and grad_out are both tuples. Unpack them first; to modify a gradient, perform the operation and then return the results packed into a tuple again.
(2) The hook function is called while backward() runs, so remove() placed after backward() cancels the hook.
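The following sketch shows a backward hook on a single linear layer. It uses register_full_backward_hook, the non-deprecated variant in current PyTorch releases, in which grad_input holds only the gradients with respect to the module's inputs; the layer, the dictionary grads, and the doubling factor are assumptions made for this illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

layer = nn.Linear(3, 2)
grads = {}  # hypothetical container for exported gradients

def backward_hook(module, grad_input, grad_output):
    # Both arguments are tuples; an entry can be None when the
    # corresponding tensor does not require a gradient.
    grads["out"] = grad_output[0].detach().clone()
    grads["in"] = grad_input[0].detach().clone()
    # The returned tuple replaces grad_input for the rest of backward.
    return (grad_input[0] * 2,)

x = torch.randn(4, 3, requires_grad=True)

handle = layer.register_full_backward_hook(backward_hook)
layer(x).sum().backward()
handle.remove()

print(grads["out"].shape, grads["in"].shape)
print(torch.allclose(x.grad, 2 * grads["in"]))
```

Because the hook returns a doubled gradient, x.grad ends up at twice the value the layer computed, while the recorded grads["in"] keeps the unmodified gradient.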