fasterrcnn_resnet50_fpn() example test
2022-07-29 04:38:00 【Eden_mm】
Training phase
import torch
import torchvision
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights

# Faster R-CNN (ResNet-50 + FPN) with COCO-pretrained weights
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)

# for train: 4 random images, 11 random boxes per image
images, boxes = torch.rand(4, 3, 600, 1200), torch.rand(4, 11, 4)
# add (x1, y1) to the last two columns so that x2 > x1 and y2 > y1
boxes[:, :, 2:4] = boxes[:, :, 0:2] + boxes[:, :, 2:4]
labels = torch.randint(1, 91, (4, 11))  # class ids 1..90, 0 is background
images = list(image for image in images)  # the model expects a list of [C, H, W] tensors
targets = []
for i in range(len(images)):  # i = 0, 1, 2, 3
    d = {}
    d['boxes'] = boxes[i]
    d['labels'] = labels[i]
    targets.append(d)
output = model(images, targets)
print(output)
Result (in training mode the model returns a dict of losses instead of detections):
{'loss_classifier': tensor(0.3633, grad_fn=<NllLossBackward0>), 'loss_box_reg': tensor(0.0262, grad_fn=<DivBackward0>), 'loss_objectness': tensor(1.9085, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>), 'loss_rpn_box_reg': tensor(1.0729, grad_fn=<DivBackward0>)}
Inference phase
import torch
import torchvision
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
# for inference
model.eval()
x = [torch.rand(3, 300, 400), torch.rand(3, 300, 400)]
predictions = model(x)
print(predictions)
Result:
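No boxes are returned. Presumably this is because the inputs are random noise, so no detection reaches the model's default score threshold (box_score_thresh=0.05) and every candidate is filtered out.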
[{'boxes': tensor([], size=(0, 4), grad_fn=<StackBackward0>), 'labels': tensor([], dtype=torch.int64), 'scores': tensor([], grad_fn=<IndexBackward0>)},
 {'boxes': tensor([], size=(0, 4), grad_fn=<StackBackward0>), 'labels': tensor([], dtype=torch.int64), 'scores': tensor([], grad_fn=<IndexBackward0>)}]