Get the network input dimensions of the pretrained model
2022-08-05 06:45:00 【ProfSnail】
When learning about neural networks, we often rely on pre-trained network packages. For example:
from torchvision import models
resnet = models.resnet18(pretrained=True)
But as beginners, we are often puzzled: what size of image does the model expect as input?
Solutions:
Method 1: Read the torchvision.models documentation
Open the torchvision models page:
https://pytorch.org/hub/research-models
Search for the model name you need to reach the ResNet page:
https://pytorch.org/hub/pytorch_vision_resnet/
There you will find the ResNet usage guide:
import torch
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
# or any of these variants
# model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet34', pretrained=True)
# model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
# model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet101', pretrained=True)
# model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet152', pretrained=True)
model.eval()
All pre-trained models expect input images normalized in the same way, i.e. mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224. The images have to be loaded in to a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225].
Here’s a sample execution.
# sample execution (requires torchvision)
from PIL import Image
from torchvision import transforms
input_image = Image.open(filename)  # `filename` is the path to the image you want to classify
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0)  # create a mini-batch as expected by the model

# move the input and model to GPU for speed if available
if torch.cuda.is_available():
    input_batch = input_batch.to('cuda')
    model.to('cuda')

with torch.no_grad():
    output = model(input_batch)

# Tensor of shape 1000, with confidence scores over ImageNet's 1000 classes
print(output[0])
# The output has unnormalized scores. To get probabilities, you can run a softmax on it.
probabilities = torch.nn.functional.softmax(output[0], dim=0)
print(probabilities)
From this we can see that the input image needs to be scaled to 224 × 224.
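As a quick check (a minimal sketch of my own, not part of the official guide), you can feed a dummy 224 × 224 batch through the network and confirm that it produces the expected 1000-class output:
import torch
from torchvision import models

# Sanity check: a dummy 224x224 batch should yield a (1, 1000) logit tensor.
model = models.resnet18(pretrained=True)
model.eval()

dummy = torch.zeros(1, 3, 224, 224)  # (batch, channels, height, width)
with torch.no_grad():
    out = model(dummy)
print(out.shape)  # expected: torch.Size([1, 1000])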
Method 2: Read the GitHub source code
Besides the torchvision guide, you can also go to GitHub and read the comments in the source code (View on GitHub).
The GitHub code includes the following:
class ResNet18_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnet18-f37072fd.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 11689512,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 69.758,
                    "acc@5": 89.078,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    DEFAULT = IMAGENET1K_V1
From crop_size=224 we know that the input should be scaled or cropped to size 224.
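A related shortcut, assuming a torchvision new enough to ship the weights enum shown above (0.13 or later): the preprocessing pipeline can be obtained programmatically from the weights object instead of reading the source by hand. A sketch:
from torchvision.models import resnet18, ResNet18_Weights

weights = ResNet18_Weights.DEFAULT    # currently IMAGENET1K_V1
preprocess = weights.transforms()     # the ImageClassification(crop_size=224) transform from the source
print(preprocess)                     # prints the resize size, crop size, mean and std

model = resnet18(weights=weights)
model.eval()
Printing the transform shows the same crop_size=224 referenced in the source above, so this confirms the expected input size without digging through GitHub.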
Method 3: Brute-force testing
Loop over a range of input sizes that you think might be feasible. If a size does not fit, PyTorch raises an error about incompatible shapes; use except to skip these failures and keep the sizes for which the try block succeeds.
import torch
from torchvision.models import resnet18

transfer_model = resnet18(pretrained=True)
print(transfer_model)
transfer_model.eval()
transfer_model = transfer_model.cuda()
batch_size = 16
ok_list = []
for length in range(1, 1000):
    x = torch.zeros(batch_size, 3, length, length, requires_grad=False).cuda()
    print('length = {}'.format(length))
    try:
        torch_out = transfer_model(x)
        print('length = {} OK!!!!'.format(length))
        ok_list.append(length)
        # break
    except Exception:
        continue
print(ok_list)
This brute-force test tries input sizes one by one and records which ones the network can run through. Note that because the network contains downsampling operations with rounding, the acceptable sizes fall within a certain range. For example, the output of the code above is:
...
length = 998
length = 999
[193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224]
Size 224 is just one of the acceptable sizes; it is also the size the designers originally intended.
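If you only want to see where the downsampling happens rather than loop over hundreds of candidate sizes, a lighter-weight alternative (a sketch of my own, not from the original post) is to register forward hooks and print each stage's feature-map shape for a single trial input:
import torch
from torchvision.models import resnet18

model = resnet18(pretrained=True).eval()

def report(name):
    # Print the output shape of each top-level stage as the input flows through.
    def hook(module, inputs, output):
        print('{:10s} -> {}'.format(name, tuple(output.shape)))
    return hook

for name, module in model.named_children():
    module.register_forward_hook(report(name))

with torch.no_grad():
    model(torch.zeros(1, 3, 224, 224))
For a 224 × 224 input, layer4 produces a 7 × 7 feature map before average pooling, which is exactly the size the classifier head was designed around.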