Quick object detection with PyTorch and Faster R-CNN
2022-07-06 06:35:00 【GIS开发者】
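Faster R-CNN improves on R-CNN structurally: feature extraction, proposal extraction, bounding box regression (rect refine), and classification are all integrated into a single network. This raises overall performance considerably, most noticeably in detection speed.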
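The "rect refine" step can be pictured with a small sketch. It assumes the standard R-CNN box parameterization, in which the network predicts offsets (dx, dy, dw, dh) relative to a proposal box. `decode_box` is a hypothetical helper written for illustration, not torchvision's API, and real implementations add per-coordinate weights and clamping on top of this.

```python
import math


def decode_box(proposal, deltas):
    """Apply predicted offsets (dx, dy, dw, dh) to a proposal (x1, y1, x2, y2)."""
    x1, y1, x2, y2 = proposal
    dx, dy, dw, dh = deltas
    w, h = x2 - x1, y2 - y1
    cx, cy = x1 + 0.5 * w, y1 + 0.5 * h
    # Shift the centre by a fraction of the proposal's size
    new_cx = cx + dx * w
    new_cy = cy + dy * h
    # Scale width and height exponentially so they always stay positive
    new_w = w * math.exp(dw)
    new_h = h * math.exp(dh)
    return (
        new_cx - 0.5 * new_w,
        new_cy - 0.5 * new_h,
        new_cx + 0.5 * new_w,
        new_cy + 0.5 * new_h,
    )


# Made-up proposal and offsets: move the box right by 10% of its width
refined = decode_box((10, 20, 50, 60), (0.1, 0.0, 0.0, 0.0))
print(refined)
```

With all offsets at zero the proposal comes back unchanged, which is why the regression head only has to learn small corrections.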
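PyTorch is an open-source Python machine learning library based on Torch, used for applications such as natural language processing. Compared with TensorFlow, it is often regarded as more lightweight and a better fit for research and small projects.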
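This post walks through a simple example that uses the Faster R-CNN algorithm and PyTorch to get object detection running quickly. It relies on a model already trained on the COCO dataset, which torchvision downloads on first use, so it is easy to run.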
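As a rough picture of what the pre-trained model hands back: for each input image, torchvision's detection models return a dict with `boxes`, `labels`, and `scores`, ordered from highest to lowest score. The sketch below uses plain lists in place of tensors, so it runs without PyTorch. The values in `mock_pred`, the two-entry `NAMES` map, and the `filter_detections` helper are all made up for illustration.

```python
# Hypothetical subset of the COCO label map (label id -> class name)
NAMES = {1: "person", 18: "dog"}

# Made-up prediction for one image; plain lists stand in for tensors
mock_pred = {
    "boxes": [
        [12.0, 30.0, 220.0, 310.0],
        [200.0, 40.0, 400.0, 300.0],
        [5.0, 5.0, 40.0, 40.0],
    ],
    "labels": [18, 1, 18],
    "scores": [0.98, 0.76, 0.12],
}


def filter_detections(pred, threshold=0.5):
    """Return (name, box, score) for every detection scoring above threshold."""
    return [
        (NAMES[label], box, score)
        for box, label, score in zip(pred["boxes"], pred["labels"], pred["scores"])
        if score > threshold
    ]


for name, box, score in filter_detections(mock_pred):
    print(name, box, score)
```

Filtering per detection like this also returns an empty list cleanly when nothing clears the threshold.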
Code
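import os

# Work around the "duplicate OpenMP runtime" error that some setups raise;
# set it before importing torch so it is in place when the libraries load
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

from PIL import Image
import matplotlib.pyplot as plt
# pip install -U matplotlib
import torch
# pip install torch
import torchvision.transforms as T
import torchvision
# pip install torchvision
import cv2
# pip install opencv-python

# Download the pre-trained model.
# torchvision 0.13+ deprecates pretrained=True in favour of weights="DEFAULT".
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model.eval()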

COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]


def get_prediction(img_path, threshold):
    # Force three channels so grayscale or RGBA files are handled as well
    img = Image.open(img_path).convert("RGB")
    # ToTensor converts a PIL image or NumPy array to a tensor and rescales [0, 255] -> [0, 1]
    transform = T.Compose([T.ToTensor()])
    img = transform(img)
    # Inference only, so gradient tracking is not needed
    with torch.no_grad():
        pred = model([img])
    print(pred[0]['labels'].numpy())
    # Extract the class names
    pred_class = [COCO_INSTANCE_CATEGORY_NAMES[i] for i in pred[0]['labels'].numpy()]
    # Extract the box corners as [(x1, y1), (x2, y2)]
    pred_boxes = [[(b[0], b[1]), (b[2], b[3])] for b in pred[0]['boxes'].numpy()]
    # Keep only detections whose score exceeds the threshold;
    # this yields empty lists instead of an IndexError when nothing qualifies
    pred_score = pred[0]['scores'].numpy()
    keep = [i for i, score in enumerate(pred_score) if score > threshold]
    pred_boxes = [pred_boxes[i] for i in keep]
    pred_class = [pred_class[i] for i in keep]
    print("pred_class:", pred_class)
    print("pred_boxes:", pred_boxes)
    return pred_boxes, pred_class


def object_detection_api(img_path, threshold=0.5, rect_th=3, text_size=3, text_th=3):
    boxes, pred_cls = get_prediction(img_path, threshold)
    img = cv2.imread(img_path)
    # OpenCV loads images as BGR; convert to RGB for matplotlib
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    for box, label in zip(boxes, pred_cls):
        top_left = (int(box[0][0]), int(box[0][1]))
        bottom_right = (int(box[1][0]), int(box[1][1]))
        # Draw a rectangle around the detected object
        cv2.rectangle(img, top_left, bottom_right, color=(0, 255, 0), thickness=rect_th)
        # Label it with the class name
        cv2.putText(img, label, top_left, cv2.FONT_HERSHEY_SIMPLEX, text_size,
                    (0, 255, 0), thickness=text_th)
    plt.imshow(img)
    plt.show()


if __name__ == '__main__':
    object_detection_api(img_path=r"C:\Users\hanbo\Pictures\dog.jpg")
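
`get_prediction` returns the kept boxes and class names as two parallel lists, so summarising a result takes only a few lines. The sketch below uses a hypothetical `count_classes` helper, and `sample_classes` holds made-up values standing in for real output.

```python
from collections import Counter


def count_classes(pred_class):
    """Count how many detections were reported for each class name."""
    return Counter(pred_class)


# Made-up class list in the shape get_prediction returns
sample_classes = ["dog", "person", "dog"]
counts = count_classes(sample_classes)
for name, n in counts.most_common():
    print(f"{name}: {n}")
```

A summary like this can be handy for checking a batch of images without opening a plot window for each one.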
Results
Example 1
Example 2