def fasterrcnn_resnet50_fpn() example test
2022-07-29 04:38:00 【Eden_mm】
Training phase
import torch
import torchvision
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
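# for train
# 4 random images of shape (3, 600, 1200), each with 11 random boxes
images, boxes = torch.rand(4, 3, 600, 1200), torch.rand(4, 11, 4)
# treat the last dim as (x, y, w, h) and convert to (x1, y1, x2, y2), so x2 > x1 and y2 > y1
boxes[:, :, 2:4] = boxes[:, :, 0:2] + boxes[:, :, 2:4]
# class labels in [1, 90]; label 0 is reserved for the background
labels = torch.randint(1, 91, (4, 11))
images = list(image for image in images)
# one target dict per image, with 'boxes' and 'labels' keys
targets = []
for i in range(len(images)):  # i = 0, 1, 2, 3
    d = {}
    d['boxes'] = boxes[i]
    d['labels'] = labels[i]
    targets.append(d)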
output = model(images, targets)
print(output)
#
# # for inference
# model.eval()
#
# x = [torch.rand(3, 300, 400), torch.rand(3, 300, 400)]
# predictions = model(x)
Result:
{'loss_classifier': tensor(0.3633, grad_fn=<NllLossBackward0>), 'loss_box_reg': tensor(0.0262, grad_fn=<DivBackward0>), 'loss_objectness': tensor(1.9085, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>), 'loss_rpn_box_reg': tensor(1.0729, grad_fn=<DivBackward0>)}
Inference phase
import torch
import torchvision
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
# # for train
# images, boxes = torch.rand(4, 3, 600, 1200), torch.rand(4, 11, 4)
# boxes[:, :, 2:4] = boxes[:, :, 0:2] + boxes[:, :, 2:4]
# labels = torch.randint(1, 91, (4, 11))
#
# images = list(image for image in images)
# targets = []
# for i in range(len(images)):  # i = 0, 1, 2, 3
#     d = {}
#     d['boxes'] = boxes[i]
#     d['labels'] = labels[i]
#     targets.append(d)
#
# output = model(images, targets)
# print(output)
#
# for inference
model.eval()
x = [torch.rand(3, 300, 400), torch.rand(3, 300, 400)]
predictions = model(x)
print(predictions)
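Result:
Presumably no boxes are predicted because the input images are random noise: no detection scores above the model's score threshold (box_score_thresh, 0.05 by default in torchvision), so each image gets empty boxes, labels and scores tensors.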
[{'boxes': tensor([], size=(0, 4), grad_fn=<StackBackward0>), 'labels': tensor([], dtype=torch.int64), 'scores': tensor([], grad_fn=<IndexBackward0>)},
 {'boxes': tensor([], size=(0, 4), grad_fn=<StackBackward0>), 'labels': tensor([], dtype=torch.int64), 'scores': tensor([], grad_fn=<IndexBackward0>)}]