Study notes 4: einops // cudnn.benchmark=True // hook
2022-06-13 08:16:00 【zyr_freedom】
einops
import torch
from einops import rearrange,reduce,repeat
x= torch.randn(2,3,8,8)
# 1. Transpose
out1 = x.transpose(1, 2)
out2 = rearrange(x, 'b c h w -> b h c w')
print('verify out1 & out2 ---->:', torch.allclose(out1, out2))
# 2. Reshape
out3 = x.reshape(6, 8, 8)
out4 = rearrange(x, 'b c h w -> (b c) h w')
x_restore = rearrange(out4, '(b c) h w -> b c h w', b=2)
print('verify out3 & out4 ---->:', torch.allclose(out3, out4), '|| verify x & x_restore---->:', torch.allclose(x, x_restore))
# 3. Image to patches
out5 = rearrange(x, 'b c (h1 p1) (w1 p2) -> b c (h1 w1) (p1 p2)', p1=2, p2=2)  # the resulting patches are non-overlapping
print('out5 ---->:', x.size(), out5.size())
out6 = rearrange(out5, 'b c n a -> b n (a c)')  # [batch_size, num_patches, patch_depth]
print('out6 ---->:', out6.size())
# 4. Global average pooling
out7 = reduce(x, 'b c h w -> b c', 'mean')
print('out7---->:', out7.size())
# 5. Stack a list of tensors
x_list = [x, x, x]
out8 = rearrange(x_list, 'n b c h w -> n b c h w')
print('out8---->:', out8.size())
# 6. Add a dimension
out9 = rearrange(x, 'b c h w -> b c h w 1')  # similar to torch.unsqueeze
print('out9---->:', out9.size())
# 7. Repeat
out10 = repeat(out9, 'b c h w 1 -> b c h w 2')  # similar to torch.tile
out11 = repeat(x, 'b c h w -> b (2 c) h w')  # repeat along the channel dimension
print('out10---->:', out10.size(), 'out11---->:', out11.size())
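The image-to-patch step (example 3 in the script) can also be written with plain PyTorch, which makes the axis bookkeeping that einops hides explicit. The following is a minimal sketch, not part of the original notes: the function name `image_to_patches` and the single patch size `p` are assumptions chosen for illustration. It is intended to reproduce the layout of `out6`, that is, patches in row-major grid order with each patch flattened as (row offset, column offset, channel).

```python
import torch


def image_to_patches(x: torch.Tensor, p: int) -> torch.Tensor:
    """Split a (b, c, h, w) tensor into non-overlapping p x p patches.

    Returns a tensor of shape (b, num_patches, p * p * c).
    """
    b, c, h, w = x.shape
    assert h % p == 0 and w % p == 0, "height and width must be divisible by p"
    h1, w1 = h // p, w // p
    # (b, c, h1, p, w1, p): split each spatial axis into (grid index, offset inside patch)
    x = x.reshape(b, c, h1, p, w1, p)
    # (b, h1, w1, p, p, c): patch grid first, then in-patch offsets, channels last
    x = x.permute(0, 2, 4, 3, 5, 1)
    # (b, h1 * w1, p * p * c): flatten the grid and the patch contents
    return x.reshape(b, h1 * w1, p * p * c)


if __name__ == "__main__":
    x = torch.randn(2, 3, 8, 8)
    patches = image_to_patches(x, p=2)
    print(patches.size())
```

The `permute` is the step that decides the ordering inside each patch vector; changing it changes which pixel ends up where, even though the output shape stays the same.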
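Running the einops script prints the following:
verify out1 & out2 ---->: True
verify out3 & out4 ---->: True || verify x & x_restore---->: True
out5 ---->: torch.Size([2, 3, 8, 8]) torch.Size([2, 3, 16, 4])
out6 ---->: torch.Size([2, 16, 12])
out7---->: torch.Size([2, 3])
out8---->: torch.Size([3, 2, 3, 8, 8])
out9---->: torch.Size([2, 3, 8, 8, 1])
out10---->: torch.Size([2, 3, 8, 8, 2]) out11---->: torch.Size([2, 6, 8, 8])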
torch.backends.cudnn.benchmark=True
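Setting torch.backends.cudnn.benchmark=True makes the program spend some extra time up front: the first time each convolution configuration is encountered, cuDNN tries the available convolution algorithms and keeps the fastest one, which then speeds up the rest of the run. It suits networks whose structure is fixed (not changing dynamically) and whose input shape (batch size, image size, input channels) stays the same, which covers most ordinary use. Conversely, if the convolution configuration keeps changing, the search is repeated for every new configuration and the program can end up slower.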
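In practice it is enough to set it once near the top of the script, for example right where the GPU device is selected:
if args.use_gpu and torch.cuda.is_available():
    device = torch.device('cuda')
    torch.backends.cudnn.benchmark = True
else:
    device = torch.device('cpu')
As for when torch.backends.cudnn.benchmark=True is appropriate, in one sentence: if the convolutional network's structure does not change dynamically and its input (batch size, image size, input channels) is fixed, it is safe to use.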
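One way to check whether the flag pays off for a particular model is to time the first few forward passes: with benchmarking on, the first pass for a given input shape is expected to be slower, and later passes faster. The helper below is a minimal sketch under that assumption; `time_forward` and the small `nn.Conv2d` model are hypothetical names and choices, not from the original notes. On a CPU-only machine the script still runs, but the flag has no effect there.

```python
import time

import torch
import torch.nn as nn


def time_forward(model: nn.Module, x: torch.Tensor, n_iters: int = 5) -> list:
    """Return the wall-clock time in seconds of each of n_iters forward passes."""
    times = []
    with torch.no_grad():
        for _ in range(n_iters):
            if x.is_cuda:
                torch.cuda.synchronize()  # finish pending GPU work before starting the clock
            start = time.perf_counter()
            model(x)
            if x.is_cuda:
                torch.cuda.synchronize()  # wait until the forward pass has really finished
            times.append(time.perf_counter() - start)
    return times


if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.backends.cudnn.benchmark = True  # only matters when cuDNN is used (GPU)
    model = nn.Conv2d(3, 16, 3, padding=1).to(device)  # hypothetical toy model
    x = torch.randn(8, 3, 64, 64, device=device)
    for i, t in enumerate(time_forward(model, x)):
        print(f'iteration {i}: {t * 1000:.2f} ms')
```

Running the same script with the flag set to False, and again with input shapes that change between iterations, gives a rough picture of when the up-front search is worth it.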
Getting intermediate feature maps with a hook
import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.conv3 = nn.Conv2d(16, 32, 5)

    def forward(self, x):
        out = self.conv1(x)
        out = F.relu(out)
        out = F.max_pool2d(out, 2)
        out = self.conv2(out)
        out = F.relu(out)
        out = self.conv3(out)
        out = F.max_pool2d(out, 2)
        return out
# Approach 1: modify the network so that forward() also returns the intermediate tensor we want
class LeNet_multi_outputs(nn.Module):
    def __init__(self):
        super(LeNet_multi_outputs, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.conv3 = nn.Conv2d(16, 32, 5)

    def forward(self, x):
        out = self.conv1(x)
        out = F.relu(out)
        out = F.max_pool2d(out, 2)
        out = self.conv2(out)
        out_conv2 = out  # keep the raw conv2 output (before ReLU)
        out = F.relu(out)
        out = self.conv3(out)
        out = F.max_pool2d(out, 2)
        return out, out_conv2
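if __name__ == "__main__":
    model = LeNet_multi_outputs()
    print(model)
    print(model.conv1)
    input = torch.randn(1, 3, 224, 224)
    output, out_conv2 = model(input)
    print('out_conv2:', out_conv2.size())

    # Approach 2: register a forward hook on the layer of interest
    features = []

    # A forward hook takes three arguments. PyTorch supplies them when the module runs, so the signature is fixed.
    def hook(module, input, output):
        features.append(output.clone().detach())

    net = model  # LeNet()
    x = input
    # Pick the layer and call register_forward_hook on it, passing in the hook function:
    handle = net.conv2.register_forward_hook(hook)
    # A forward hook can also change the result: if it returns a non-None value, that value replaces the module's output.
    # Modifying input in place has no effect on forward, because the hook runs after forward() has been called.
    # Here we simply save the output.
    # Remove the hook once it is no longer needed; otherwise it runs on every forward pass and adds overhead.
    y, _ = net(x)
    print('net.conv2.register_forward_hook:', features[0].size())
    print('validation:', torch.allclose(features[0], out_conv2))
    handle.remove()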
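The advice to remove hooks after use can be made automatic by wrapping registration and removal in a context manager. The sketch below is an illustration under stated assumptions rather than the author's method: `capture_outputs` and `toy_net` are hypothetical names, and `toy_net` is a small stand-in model, not the LeNet above. The second half of the demo shows that a hook which returns a value does change what later layers receive.

```python
from contextlib import contextmanager

import torch
import torch.nn as nn


@contextmanager
def capture_outputs(model: nn.Module, layer_names):
    """Record the outputs of the named submodules while the with-block runs."""
    captured = {}
    modules = dict(model.named_modules())
    handles = []

    def make_hook(name):
        def hook(module, inputs, output):
            captured[name] = output.detach().clone()
        return hook

    try:
        for name in layer_names:
            handles.append(modules[name].register_forward_hook(make_hook(name)))
        yield captured
    finally:
        for handle in handles:
            handle.remove()  # always detach the hooks, even if forward raises


if __name__ == "__main__":
    # toy_net is a hypothetical stand-in; nn.Sequential names its children '0', '1', '2'
    toy_net = nn.Sequential(nn.Conv2d(3, 6, 5), nn.ReLU(), nn.Conv2d(6, 16, 5))
    x = torch.randn(1, 3, 32, 32)
    with capture_outputs(toy_net, ['0', '2']) as feats:
        y = toy_net(x)
    print({name: tuple(t.shape) for name, t in feats.items()})

    # A hook that returns a tensor replaces the module's output for the layers after it.
    handle = toy_net[0].register_forward_hook(lambda m, i, o: torch.zeros_like(o))
    y_zeroed = toy_net(x)
    handle.remove()
    print('output unchanged by hook:', torch.allclose(y, y_zeroed))
```

Keying the captured tensors by the names from `named_modules()` avoids relying on list positions when several layers are hooked at once.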