当前位置:网站首页>学习记录4: einops // cudnn.benchamark=true // hook
学习记录4: einops // cudnn.benchamark=true // hook
2022-06-13 08:16:00 【zyr_freedom】
einops
import torch
from einops import rearrange,reduce,repeat
x= torch.randn(2,3,8,8)
#1 转置操作
out1 = x.transpose(1,2)
out2 = rearrange(x,'b c h w ->b h c w')
print('verify out1 & out2 ---->:',torch.allclose(out1,out2))
#2 变形
out3 = x.reshape(6,8,8)
out4 = rearrange(x,'b c h w -> (b c) h w')
x_restore = rearrange(out4,'(b c) h w -> b c h w ',b =2)
print('verify out3 & out4 ---->:',torch.allclose(out3,out4),'|| verify x & x_restore---->:',torch.allclose(x,x_restore))
#3 image2patch
out5= rearrange(x,'b c (h1 p1) (w1 p2) -> b c (h1 w1) (p1 p2)',p1=2,p2=2) # 这个得到的patch 是non-overlapping的
print('out5 ---->:',x.size(),out5.size())
out6 = rearrange(out5,'b c n a -> b n (a c)') # [batchsize, num_of_patches, patches_depth]
print('out6 ---->:',out6.size())
#4 求平均池化
out7 = reduce(x,'b c h w -> b c','mean')
print('out7---->:',out7.size())
#5 堆叠tensor
x_list = [x,x,x]
out8 = rearrange(x_list,'n b c h w -> n b c h w ')
print('out8---->:',out8.size())
#6 扩维
out9 = rearrange(x,'b c h w -> b c h w 1 ') #类似于 torch.unsqueeze
#
print('out9---->:',out9.size())
#7 复制
out10 = repeat(out9,'b c h w 1 -> b c h w 2 ') #类似于 torch.tile
out11 = repeat(x,'b c h w -> b (2 c) h w ') #沿着通道复制
print('out10---->:',out10.size(),'out11---->:',out11.size())
打印的结果如下:

torch.backends.cudnn.benchmark=True
设置 torch.backends.cudnn.benchmark=True 将会让程序在开始时花费一点额外时间,为整个网络的每个卷积层搜索最适合它的卷积实现算法,进而实现网络的加速。适用场景是网络结构固定(不是动态变化的),网络的输入形状(包括 batch size,图片大小,输入的通道)是不变的,其实也就是一般情况下都比较适用。反之,如果卷积层的设置一直变化,将会导致程序不停地做优化,反而会耗费更多的时间 .
其实一般加在开头就好,比如在设置使用 GPU 的同时,后边补一句:
if args.use_gpu and torch.cuda.is_available():
device = torch.device('cuda')
torch.backends.cudnn.benchmark = True
else:
device = torch.device('cpu')对于何时适合设置 torch.backends.cudnn.benchmark=True,一句话就是:如果卷积网络结构不是动态变化的,网络的输入 (batch size,图像的大小,输入的通道) 是固定的,那么就放心用。
hook 获取中间特征层
import torch
import torch.nn as nn
import torch.nn.functional as F
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.conv3 = nn.Conv2d(16, 32, 5)
def forward(self, x):
out = self.conv1(x)
out = F.relu(out)
out = F.max_pool2d(out, 2)
out = self.conv2(out)
out = F.relu(out)
out = self.conv3(out)
out = F.max_pool2d(out, 2)
return out
#第一种是修改网络结构,通过网络return 返回想要的变量
class LeNet_multi_outputs(nn.Module):
def __init__(self):
super(LeNet_multi_outputs, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.conv3 = nn.Conv2d(16, 32, 5)
def forward(self, x):
out = self.conv1(x)
out = F.relu(out)
out = F.max_pool2d(out, 2)
out = self.conv2(out)
out_conv2 = out
out = F.relu(out)
out = self.conv3(out)
out = F.max_pool2d(out, 2)
return out,out_conv2
if __name__ == "__main__":
model = LeNet_multi_outputs()
print(model)
print(model.conv1)
input = torch.randn(1, 3, 224, 224)
output,out_conv2 = model(input)
print('out_conv2:',out_conv2.size())
features = []
# hook函数需要三个参数,这三个参数是系统传给hook函数的,自己不能修改这三个参数
def hook(module, input, output):
features.append(output.clone().detach())
net = model #LeNet()
x = input
# 取出网络的相应层后,对该层调用register_forward_hook方法。这个方法需要传入一个hook方法:
handle = net.conv2.register_forward_hook(hook)
#从这里可以发现hook甚至可以更改输入输出(不过并不会影响网络forward的实际结果),不过在这里我们只是简单地将output给保存下来。
# 需要注意的是hook函数在使用后应及时删除,以避免每次都运行增加运行负载。
y,_ = net(x)
print('net.conv2.register_forward_hook:',features[0].size())
print('validation:',torch.allclose(features[0], out_conv2) )
handle.remove()边栏推荐
- Cosmos star application case
- Daffodil upgrade (self idempotent)
- 字符串的逆序与比较
- 【博弈论-完全信息静态博弈】 Nash均衡的应用
- 中小型照明灯饰行业如何利用数字化转型突出重围?
- 22 | adventure and prediction (I): hazard is both "danger" and "opportunity"
- Founder of Starbucks: no longer open "public toilets" to non store consumers for safety reasons
- Free file server storage technology
- 星巴克创始人:出于安全考量 或不再向非店内消费者开放“公厕”
- AcWing 1977. 信息中继(基环树,并查集)
猜你喜欢
Import the robot model built by SolidWorks into ROS
![[game theory complete information static game] Nash equilibrium](/img/db/9923f5a7465c8b57182f09810b65bf.jpg)
[game theory complete information static game] Nash equilibrium

获取类的属性

The method of SolidWorks modifying text font in engineering drawing

Did decentralized digital identity

2022起重机械指挥考试题模拟考试题库及在线模拟考试

【PYTORCH】Expected object of type torch. xxxTensor but found type torch. cuda. xxxTensor(torch0.4.0)
![[problem record] taberror: inconsistent use of tabs and spaces in indentation](/img/dd/5ba456ac4201c8330d16f4b3bed81d.jpg)
[problem record] taberror: inconsistent use of tabs and spaces in indentation

CCNP_ BT-MGRE

How to download and install stm32cubemx
随机推荐
How to efficiently manage commodities and inventory in the beverage wholesale industry
杨氏矩阵查找数字是否存在
钉钉小程序 父子传参数对象 子组件页面不更新?
口碑好的食材配送信息化管理系统怎么样?
Start from scratch - implement the jpetstore website -1- establish the project framework and project introduction
[problem record] taberror: inconsistent use of tabs and spaces in indentation
Plane merging (matlab)
Founder of Starbucks: no longer open "public toilets" to non store consumers for safety reasons
Sizeof, strlen find character length
ERP基础数据 金蝶
第115页的gtk+编程例子——最简单的进度条2附带使用Anjuta写gtk程序的步骤
MySQL interview questions
Microservice system architecture construction I: Environment Construction
ES6 deleting an attribute of an object
Maternal and infant supplies wholesale industry uses management software to improve efficiency and realize cost reduction and efficiency increase
Idea shortcut summary
将solidworks建的机器人模型导入到ros中
Redis interview questions
关于redis使用分布式锁的封装工具类
Rust writes near smart contract