Getting the Network Input Size of a Pretrained Model
2022-08-05 05:25:00 【ProfSnail】
When you first start learning neural networks, you often use packages of pretrained networks.
For example:
from torchvision import models
resnet = models.resnet18(pretrained=True)  # newer torchvision (>= 0.13) prefers weights=models.ResNet18_Weights.DEFAULT
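But beginners are often stuck on one question: what size should the input images be?
Solutions:
Method 1: Read the torchvision.models documentation
Open the torchvision models page on PyTorch Hub:
https://pytorch.org/hub/research-models
Search for the model you need. For ResNet, this leads to:
https://pytorch.org/hub/pytorch_vision_resnet/
That page contains the following usage guide for ResNet: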
import torch
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
# or any of these variants
# model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet34', pretrained=True)
# model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
# model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet101', pretrained=True)
# model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet152', pretrained=True)
model.eval()
All pre-trained models expect input images normalized in the same way, i.e. mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224. The images have to be loaded in to a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225].
Here’s a sample execution.
# sample execution (requires torchvision)
from PIL import Image
from torchvision import transforms
input_image = Image.open(filename)  # filename: path to the image you want to classify
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model
# move the input and model to GPU for speed if available
if torch.cuda.is_available():
    input_batch = input_batch.to('cuda')
    model.to('cuda')
with torch.no_grad():
    output = model(input_batch)
# Tensor of shape 1000, with confidence scores over Imagenet's 1000 classes
print(output[0])
# The output has unnormalized scores. To get probabilities, you can run a softmax on it.
probabilities = torch.nn.functional.softmax(output[0], dim=0)
print(probabilities)
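From this we learn that the image should be brought to 224 × 224: the documentation asks for H and W of at least 224, and the sample code resizes the shorter side to 256 and then center-crops a 224 × 224 patch.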
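To make the Resize(256) → CenterCrop(224) pipeline concrete, here is a pure-Python sketch that only tracks image sizes, not pixels. It assumes torchvision's rule for an integer Resize size (the shorter side becomes 256, the longer side is scaled proportionally and truncated to an integer); the function name preprocess_shapes is hypothetical.

```python
def preprocess_shapes(h, w, resize=256, crop=224):
    """Sketch of how an image's (height, width) changes under Resize(resize)
    followed by CenterCrop(crop), assuming an int `resize` (shorter side -> resize,
    aspect ratio kept, longer side truncated with int()).

    Returns ((h, w) after resizing, (h, w) after cropping).
    """
    short, long = (h, w) if h <= w else (w, h)
    new_short = resize
    new_long = int(resize * long / short)  # longer side scaled, then truncated
    resized = (new_short, new_long) if h <= w else (new_long, new_short)
    # CenterCrop keeps the central crop x crop window; both sides are already
    # >= resize >= crop here, so no padding is involved.
    cropped = (crop, crop)
    return resized, cropped


print(preprocess_shapes(480, 640))  # landscape photo -> ((256, 341), (224, 224))
print(preprocess_shapes(800, 600))  # portrait photo  -> ((341, 256), (224, 224))
```

Whatever the original aspect ratio, the tensor that reaches the model ends up as 3 × 224 × 224.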
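Method 2: Read the source code on GitHub
Besides the torchvision documentation, you can also follow the "View on GitHub" link and read the source code:
The GitHub source includes the following: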
class ResNet18_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnet18-f37072fd.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 11689512,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 69.758,
                    "acc@5": 89.078,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    DEFAULT = IMAGENET1K_V1
The crop_size=224 in this code tells us that the input should be resized or cropped to 224 × 224.

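Method 3: Brute-force testing
Pick a range of sizes you think is plausible and loop over it. If a size does not fit, PyTorch raises an error saying the shapes are incompatible. Catch and skip these exceptions with except, and keep the sizes for which the try block succeeds.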
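import torch
from torchvision.models import resnet18

transfer_model = resnet18(pretrained=True)
print(transfer_model)
transfer_model.eval()
# use the GPU if there is one, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
transfer_model = transfer_model.to(device)
batch_size = 16
ok_list = []
for length in range(1, 1000):
    x = torch.zeros(batch_size, 3, length, length, device=device)
    print('length = {}'.format(length))
    try:
        with torch.no_grad():  # inference only, no gradients needed
            torch_out = transfer_model(x)
        print('length = {} OK!!!!'.format(length))
        ok_list.append(length)
        # break
    except Exception:  # e.g. RuntimeError for incompatible shapes
        continue
print(ok_list)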
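Because the loop prints every length, the result is easier to read when ok_list is collapsed into runs of consecutive sizes. The helper to_ranges below is a hypothetical addition, not part of the original code; after the loop you would call print(to_ranges(ok_list)).

```python
def to_ranges(sizes):
    """Collapse a sorted list of integers into (start, end) runs of consecutive values."""
    ranges = []
    for s in sizes:
        if ranges and s == ranges[-1][1] + 1:
            ranges[-1][1] = s  # s continues the current run
        else:
            ranges.append([s, s])  # gap found: start a new run
    return [tuple(r) for r in ranges]


print(to_ranges([193, 194, 195, 224, 300, 301]))  # [(193, 195), (224, 224), (300, 301)]
```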
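This is simply trial and error: find the input sizes for which the model runs, then narrow them down step by step. Note that because the network contains downsampling steps that round the feature-map size, the working sizes fall within certain ranges. For example, the code above printed: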
...
length = 998
length = 999
[193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224]
224 is only one of the acceptable sizes, but it is the size the model's designers intended.
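The rounding effect can be checked by hand with the output-size formula for convolution and pooling layers, floor((L + 2p − k) / s) + 1. The sketch below assumes ResNet-18's downsampling layers: a 7×7 stride-2 convolution with padding 3, a 3×3 stride-2 max pool with padding 1, and a 3×3 stride-2 convolution with padding 1 at the start of layer2, layer3 and layer4 (the 1×1 stride-2 shortcut gives the same size). The function names conv_out and resnet18_final_size are hypothetical.

```python
def conv_out(size, kernel, stride, padding):
    """Output side length of a conv/pool layer with PyTorch's default floor rounding."""
    return (size + 2 * padding - kernel) // stride + 1


def resnet18_final_size(length):
    """Side length of the last feature map (before pooling) for a length x length input,
    assuming ResNet-18's downsampling layers."""
    s = conv_out(length, kernel=7, stride=2, padding=3)  # conv1
    s = conv_out(s, kernel=3, stride=2, padding=1)       # maxpool
    for _ in range(3):                                   # first block of layer2, layer3, layer4
        s = conv_out(s, kernel=3, stride=2, padding=1)
    return s


seven = [L for L in range(1, 1000) if resnet18_final_size(L) == 7]
print(seven[0], seven[-1])                                 # 193 224
print(resnet18_final_size(192), resnet18_final_size(225))  # 6 8
```

Under these assumptions, exactly the sizes 193 to 224 yield a 7 × 7 final feature map, the same range the brute-force loop printed, and 224 = 7 × 32 is the largest of them. Which sizes a real model accepts also depends on the layers after this point: torchvision's current ResNet ends with nn.AdaptiveAvgPool2d((1, 1)), which does not require exactly 7 × 7, so results may differ with the torchvision version or a modified model.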