Getting the Network Input Size of a Pretrained Model
2022-08-05 05:25:00 【ProfSnail】
When you first start learning neural networks, you often use pretrained networks that ship with a package.
For example:
```python
from torchvision import models

resnet = models.resnet18(pretrained=True)
```
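But beginners often get stuck at this point: what image size should I feed in?
Solutions:
Method 1: Read the torchvision.models documentation
Open the torchvision.models page on PyTorch Hub: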
https://pytorch.org/hub/research-models
Search for the model you need. For ResNet, this leads to:
https://pytorch.org/hub/pytorch_vision_resnet/
That page contains the ResNet usage guide:
```python
import torch
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
# or any of these variants
# model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet34', pretrained=True)
# model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
# model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet101', pretrained=True)
# model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet152', pretrained=True)
model.eval()
```
All pre-trained models expect input images normalized in the same way, i.e. mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224. The images have to be loaded in to a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225].
Here’s a sample execution.
```python
# sample execution (requires torchvision)
from PIL import Image
from torchvision import transforms

filename = "dog.jpg"  # hypothetical path: replace with your own image
input_image = Image.open(filename)
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0)  # create a mini-batch as expected by the model

# move the input and model to GPU for speed if available
if torch.cuda.is_available():
    input_batch = input_batch.to('cuda')
    model.to('cuda')

with torch.no_grad():
    output = model(input_batch)
# Tensor of shape 1000, with confidence scores over Imagenet's 1000 classes
print(output[0])
# The output has unnormalized scores. To get probabilities, you can run a softmax on it.
probabilities = torch.nn.functional.softmax(output[0], dim=0)
print(probabilities)
```
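From this we learn that the model expects 3 x 224 x 224 inputs: the documentation asks for H and W of at least 224, and its preprocessing resizes the shorter side to 256 and then center-crops the image to 224 x 224.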
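To see why that preprocessing always ends in a 224 x 224 image whatever the original size, the pure-Python sketch below reproduces only the geometry of transforms.Resize(256) followed by transforms.CenterCrop(224): the shorter side is scaled to 256 with the aspect ratio kept, then a 224 x 224 window is cut from the center. The helper name is hypothetical, and the rounding details are my reading of torchvision's behavior, not a copy of its code.

```python
def resize_then_center_crop(width, height, resize=256, crop=224):
    """Mimic the geometry of Resize(resize) followed by CenterCrop(crop).

    Returns the (width, height) after resizing and the crop box
    (left, top, right, bottom) in the resized image.
    """
    # Resize with an int: the shorter side becomes `resize`, aspect ratio kept
    if width <= height:
        short_side, long_side = width, height
    else:
        short_side, long_side = height, width
    new_short, new_long = resize, int(resize * long_side / short_side)
    if width <= height:
        new_w, new_h = new_short, new_long
    else:
        new_w, new_h = new_long, new_short

    # CenterCrop: take a crop x crop window from the middle of the resized image
    left = int(round((new_w - crop) / 2.0))
    top = int(round((new_h - crop) / 2.0))
    return (new_w, new_h), (left, top, left + crop, top + crop)


print(resize_then_center_crop(500, 500))  # ((256, 256), (16, 16, 240, 240))
print(resize_then_center_crop(640, 480))
```

For a 640 x 480 photo, the resized image is 341 x 256, and the crop box is still exactly 224 pixels wide and tall, which is what the network receives.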
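Method 2: Read the source code on GitHub
Besides the torchvision documentation, you can also follow the "View on GitHub" link and read the source code and its comments.
The GitHub code contains the following: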
```python
class ResNet18_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnet18-f37072fd.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 11689512,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 69.758,
                    "acc@5": 89.078,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    DEFAULT = IMAGENET1K_V1
```
The crop_size=224 passed to ImageClassification tells us that the input should be resized and then cropped to 224 x 224.

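Method 3: Brute-force testing
Pick a range of sizes you think is plausible and loop over it. If a size does not fit, PyTorch raises an error saying that the tensor shapes are incompatible. Catch these exceptions with except to skip them, and keep the sizes for which the try block succeeded.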
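```python
import torch
from torchvision.models import resnet18

transfer_model = resnet18(pretrained=True)
print(transfer_model)
transfer_model.eval()
# use the GPU if one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
transfer_model = transfer_model.to(device)

batch_size = 16
ok_list = []
with torch.no_grad():  # inference only, no gradients needed
    for length in range(1, 1000):
        x = torch.zeros(batch_size, 3, length, length, device=device)
        print('length = {}'.format(length))
        try:
            torch_out = transfer_model(x)
            print('length = {} OK!!!!'.format(length))
            ok_list.append(length)
            # break
        except RuntimeError:
            # a shape mismatch inside the network: this size does not work
            continue
print(ok_list)
```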
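This is trial and error: find which input sizes let the network run end to end, then narrow things down from there. Note that because the network's downsampling layers round intermediate sizes down, the working sizes form a contiguous range. For example, the code above printed: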
```text
...
length = 998
length = 999
[193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224]
```
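So 224 is only one of the accepted sizes, although it is the size the designers intended. The exact range of accepted sizes depends on how the network ends, and therefore on the torchvision version.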
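The range 193 to 224 can also be derived without running the model. This is a sketch under an assumption the author does not state: that the ResNet-18 used here ends with a fixed nn.AvgPool2d(7, stride=1), as early torchvision releases did. Each stride-2 layer maps a side n to floor((n + 2p - k) / s) + 1, and after conv1, the max pool, and the first stride-2 convolution of layer2, layer3 and layer4, the side is exactly 7 only for inputs from 193 to 224. Only then does the 7 x 7 pool leave a 1 x 1 map that matches the 512 inputs of the final fc layer; 192 leaves a 6 x 6 map (too small to pool) and 225 leaves 8 x 8 (too many features for fc). Current torchvision ResNets use nn.AdaptiveAvgPool2d((1, 1)) instead, so a brute-force run on a recent version may accept many more sizes. The function names below are hypothetical helpers, not torchvision APIs.

```python
def conv_out(size, kernel, stride, padding):
    """Output side of a conv/pool layer: floor((size + 2 * padding - kernel) / stride) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


def resnet18_last_feature_size(size):
    """Side of the layer4 feature map of a ResNet-18 for a square input of side `size`."""
    size = conv_out(size, kernel=7, stride=2, padding=3)  # conv1
    size = conv_out(size, kernel=3, stride=2, padding=1)  # max pool
    for _ in range(3):  # first (stride-2) 3x3 conv of layer2, layer3, layer4
        size = conv_out(size, kernel=3, stride=2, padding=1)
    return size  # the remaining 3x3 convs use stride 1 and padding 1, so they keep the size


def sizes_for_fixed_avgpool(pool=7, max_len=999):
    """Input sides for which a fixed AvgPool2d(pool, stride=1) leaves a 1x1 map,
    i.e. the layer4 feature map is exactly pool x pool."""
    return [n for n in range(1, max_len + 1) if resnet18_last_feature_size(n) == pool]


print(resnet18_last_feature_size(224))                         # 7
print(sizes_for_fixed_avgpool() == list(range(193, 225)))      # True
```

Because the five stride-2 stages together divide the side by 32 and round up, the same reasoning says the layer4 map is 7 x 7 exactly when ceil(n / 32) = 7, which is the list the brute-force loop printed.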