Getting the Network Input Size of a Pretrained Model
2022-08-05 05:25:00 【ProfSnail】
When you first start learning neural networks, you often use pretrained networks from a package.
For example:
from torchvision import models
resnet = models.resnet18(pretrained=True)  # torchvision >= 0.13 prefers weights=models.ResNet18_Weights.DEFAULT
But beginners often get stuck at this point: what size should the input images be?
Solutions:
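Method 1: Read the torchvision.models documentation
Open the PyTorch Hub model listing:
https://pytorch.org/hub/research-models
Search for the model name you need. For ResNet, this leads to:
https://pytorch.org/hub/pytorch_vision_resnet/
That page contains the following usage guide for ResNet: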
import torch
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
# or any of these variants
# model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet34', pretrained=True)
# model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
# model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet101', pretrained=True)
# model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet152', pretrained=True)
model.eval()
All pre-trained models expect input images normalized in the same way, i.e. mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224. The images have to be loaded into a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225].
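The preprocessing described in this quote is a per-channel affine map. Each 8-bit pixel value v is first scaled to [0, 1] (which is what ToTensor does for 8-bit images) and then normalized as (v / 255 - mean) / std. The plain-Python sketch below applies that formula to a single RGB pixel. The helper name normalize_pixel is hypothetical; only the mean and std values come from the documentation above.

```python
# mean/std taken from the ResNet documentation quoted above
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb):
    """Hypothetical helper: map an (R, G, B) tuple of ints in 0..255
    to the normalized floats the pretrained model expects."""
    # scale to [0, 1] first, then subtract the channel mean and divide by the channel std
    return [(v / 255.0 - m) / s for v, m, s in zip(rgb, MEAN, STD)]


print(normalize_pixel((255, 128, 0)))
```

A black pixel therefore maps to roughly -2.1 in the red channel, and a white pixel maps to roughly +2.6 in the blue channel. The network was trained on values in this range, not on raw 0..255 values.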
Here’s a sample execution.
# sample execution (requires torchvision)
from PIL import Image
from torchvision import transforms
input_image = Image.open(filename).convert('RGB')  # filename: path to your image file
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model
# move the input and model to GPU for speed if available
if torch.cuda.is_available():
    input_batch = input_batch.to('cuda')
    model.to('cuda')
with torch.no_grad():
    output = model(input_batch)
# Tensor of shape 1000, with confidence scores over Imagenet's 1000 classes
print(output[0])
# The output has unnormalized scores. To get probabilities, you can run a softmax on it.
probabilities = torch.nn.functional.softmax(output[0], dim=0)
print(probabilities)
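From this we learn that the image is resized so that its shorter side is 256, then center-cropped to [224, 224]. That is the size the network is actually fed.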
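Method 2: Read the source code on GitHub
Besides the torchvision documentation, you can also go to GitHub (the "View on GitHub" link) and read the comments in the source code.
The GitHub code contains the following: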
class ResNet18_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnet18-f37072fd.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 11689512,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 69.758,
                    "acc@5": 89.078,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    DEFAULT = IMAGENET1K_V1
The crop_size=224 in this code tells us that the image should be resized or cropped to 224.
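Method 3: Brute-force testing
Pick a range of sizes you consider plausible and loop over it. If a size is unsuitable, PyTorch raises an error saying the shapes are incompatible. Catch these exceptions with except, skip them, and keep the sizes for which the try block succeeds.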
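import torch
from torchvision.models import resnet18

transfer_model = resnet18(pretrained=True)
print(transfer_model)
transfer_model.eval()
# use the GPU if there is one, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
transfer_model = transfer_model.to(device)
batch_size = 16
ok_list = []
with torch.no_grad():  # inference only, no gradients needed
    for length in range(1, 1000):
        x = torch.zeros(batch_size, 3, length, length, device=device)
        print('length = {}'.format(length))
        try:
            torch_out = transfer_model(x)
            print('length = {} OK!!!!'.format(length))
            ok_list.append(length)
            # break
        except RuntimeError:
            # a shape mismatch inside the network: this size is not accepted
            continue
print(ok_list)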
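In other words, just try it: find out which input sizes let the model run end to end, then keep testing from there. Note that the network contains downsampling steps that round down, so the accepted sizes form a range. For example, the code above printed: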
...
length = 998
length = 999
[193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224]
224 is only one of the accepted sizes, though it is the size the designers intended.
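The rounding effect can be checked by hand. For a convolution or pooling layer (dilation 1, floor mode), the output length is floor((L + 2·padding - kernel) / stride) + 1. ResNet-18 halves the spatial size five times: in conv1 (7×7, stride 2, padding 3), in the max pool (3×3, stride 2, padding 1), and in the first 3×3 stride-2 convolution of layer2, layer3 and layer4. The sketch below applies that formula and keeps the input sizes that end in a 7×7 feature map. This rests on an assumption: the printed list appears to match a ResNet whose head is a fixed nn.AvgPool2d(7) followed by a 512-input fully connected layer, as in older torchvision releases. Current torchvision ResNets use nn.AdaptiveAvgPool2d((1, 1)), which pools any feature-map size down to 1×1. With a recent torchvision, the same loop is therefore likely to accept a much wider range of sizes. The function names are hypothetical.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a conv/pool layer (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1


def resnet_final_size(length):
    """Spatial size after the five stride-2 stages of ResNet-18."""
    s = conv_out(length, 7, 2, 3)  # conv1
    s = conv_out(s, 3, 2, 1)       # maxpool
    for _ in range(3):             # first 3x3 stride-2 conv of layer2, layer3, layer4
        s = conv_out(s, 3, 2, 1)
    return s


# sizes whose final feature map is exactly 7x7 (what a fixed AvgPool2d(7) + fc(512) needs)
ok = [L for L in range(1, 1000) if resnet_final_size(L) == 7]
print(ok[0], ok[-1], len(ok))  # 193 224 32
```

Each stage maps L to ceil(L / 2), so the final size is ceil(L / 32). This equals 7 exactly when 192 < L <= 224, which is why the accepted sizes come out as one contiguous block.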