Common functions: detect_image / predict
2022-07-07 02:14:00 【逆夏11111】
detect_image
# Module-level imports this method relies on. letterbox_image, preprocess_input and
# retinanet_correct_boxes are assumed to come from the project's own utility modules.
import numpy as np
from timeit import default_timer as timer
from PIL import ImageDraw, ImageFont

#---------------------------------------------------#
#   Detect objects in an image
#---------------------------------------------------#
def detect_image(self, image):
    start = timer()
    image_shape = np.array(np.shape(image)[0:2])
    crop_img, x_offset, y_offset = letterbox_image(image, (self.model_image_size[0], self.model_image_size[1]))
    photo = np.array(crop_img, dtype=np.float64)
    # Preprocess and normalize the image
    photo = preprocess_input(np.reshape(photo, [1, self.model_image_size[0], self.model_image_size[1], 3]))
    preds = self.ssd_model.predict(photo)
    # Decode the raw predictions
    results = self.bbox_util.detection_out(preds, confidence_threshold=self.confidence)
    # print(results)
    if len(results[0]) <= 0:
        return image, []
    # Keep only the boxes whose score is at least self.confidence
    det_label = results[0][:, 0]
    det_conf = results[0][:, 1]
    det_xmin, det_ymin, det_xmax, det_ymax = results[0][:, 2], results[0][:, 3], results[0][:, 4], results[0][:, 5]
    top_indices = [i for i, conf in enumerate(det_conf) if conf >= self.confidence]
    top_conf = det_conf[top_indices]
    top_label_indices = det_label[top_indices].tolist()
    top_xmin = np.expand_dims(det_xmin[top_indices], -1)
    top_ymin = np.expand_dims(det_ymin[top_indices], -1)
    top_xmax = np.expand_dims(det_xmax[top_indices], -1)
    top_ymax = np.expand_dims(det_ymax[top_indices], -1)
    # Remove the gray letterbox bars
    boxes = retinanet_correct_boxes(top_ymin, top_xmin, top_ymax, top_xmax,
                                    np.array([self.model_image_size[0], self.model_image_size[1]]), image_shape)
    font = ImageFont.truetype(font='model_data/simhei.ttf',
                              size=np.floor(3e-2 * np.shape(image)[1] + 0.5).astype('int32'))
    thickness = (np.shape(image)[0] + np.shape(image)[1]) // self.model_image_size[0]
    record_pred = []
    for i, c in enumerate(top_label_indices):
        predicted_class = self.class_names[int(c) - 1]
        score = top_conf[i]
        top, left, bottom, right = boxes[i]
        # Pad the box by 5 pixels on every side
        top = top - 5
        left = left - 5
        bottom = bottom + 5
        right = right + 5
        # Round to integers and clip to the image bounds
        top = max(0, np.floor(top + 0.5).astype('int32'))
        left = max(0, np.floor(left + 0.5).astype('int32'))
        bottom = min(np.shape(image)[0], np.floor(bottom + 0.5).astype('int32'))
        right = min(np.shape(image)[1], np.floor(right + 0.5).astype('int32'))
        print(top, left, bottom, right)
        # Draw the box
        label = '{} {:.2f}'.format(predicted_class, score)
        draw = ImageDraw.Draw(image)
        # Note: ImageDraw.textsize was removed in Pillow 10; newer versions need draw.textbbox instead
        label_size = draw.textsize(label, font)
        # label = label.encode('utf-8')  # uncomment if the label has to be converted to UTF-8
        print(label, (left, top), (right, bottom))
        record_pred.append(label + " %s %s %s %s" % (left, top, right, bottom))
        # Put the label above the box if there is room, otherwise just inside it
        if top - label_size[1] >= 0:
            text_origin = np.array([left, top - label_size[1]])
        else:
            text_origin = np.array([left, top + 1])
        # Draw nested rectangles to get an outline that is `thickness` pixels wide
        for t in range(thickness):
            draw.rectangle(
                [left + t, top + t, right - t, bottom - t],
                outline=self.colors[int(c) - 1])
        draw.rectangle(
            [tuple(text_origin), tuple(text_origin + label_size)],
            fill=self.colors[int(c) - 1])
        # If label was encoded to UTF-8 above, decode it back here: str(label, 'UTF-8')
        draw.text(text_origin, label, fill=(0, 0, 0), font=font)
        del draw
    end = timer()
    print("detect time:", end - start)
    return image, record_pred
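The confidence-filtering step of detect_image can be tried in isolation. The sketch below is only an illustration: it assumes each detection row is laid out as [label, conf, xmin, ymin, xmax, ymax], as the column indexing in detect_image suggests, and the helper name filter_by_confidence is hypothetical, not part of the project. It uses a boolean mask instead of the list comprehension, which gives the same selection.

```python
import numpy as np

def filter_by_confidence(detections, threshold):
    """Hypothetical helper: keep rows whose confidence is >= threshold.

    detections is assumed to be an (N, 6) array with columns
    [label, conf, xmin, ymin, xmax, ymax].
    """
    detections = np.asarray(detections, dtype=np.float64)
    keep = detections[:, 1] >= threshold           # boolean mask over rows
    kept = detections[keep]
    labels = kept[:, 0].astype(int).tolist()       # class indices as plain ints
    scores = kept[:, 1]                            # confidence of each kept box
    boxes = kept[:, 2:6]                           # xmin, ymin, xmax, ymax
    return labels, scores, boxes

# Made-up detections, for illustration only
demo = np.array([
    [1, 0.9, 0.10, 0.10, 0.50, 0.50],
    [2, 0.3, 0.20, 0.20, 0.60, 0.60],
    [1, 0.5, 0.30, 0.30, 0.70, 0.70],
])
labels, scores, boxes = filter_by_confidence(demo, 0.5)
print(labels, scores, boxes.shape)
```

The second row falls below the threshold and is dropped; the other two are kept in their original order.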
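Across models, the only line that has to change is: boxes = ssd_correct_boxes(top_ymin,top_xmin,top_ymax,top_xmax,np.array([self.model_image_size[0],self.model_image_size[1]]),image_shape)
This is the step that removes the gray bars. For the SSD model use ssd_correct_boxes, for the RetinaNet model use retinanet_correct_boxes, and so on.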
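To show what "removing the gray bars" amounts to, here is a minimal sketch of a letterbox box correction. It is an assumption about what functions such as ssd_correct_boxes and retinanet_correct_boxes do, not their actual source: it assumes box coordinates are normalized to [0, 1] relative to the padded network input, that shapes are given as (height, width), and that the image was scaled with its aspect ratio preserved and centered. The name letterbox_correct_boxes is hypothetical.

```python
import numpy as np

def letterbox_correct_boxes(top, left, bottom, right, input_shape, image_shape):
    """Hypothetical sketch: map normalized boxes on a letterboxed input
    back to pixel coordinates on the original image.

    top/left/bottom/right: arrays of shape (N, 1), normalized to [0, 1].
    input_shape, image_shape: (height, width).
    Returns an (N, 4) array of [top, left, bottom, right] in pixels.
    """
    input_shape = np.asarray(input_shape, dtype=np.float64)
    image_shape = np.asarray(image_shape, dtype=np.float64)
    # Size of the resized image inside the padded input (aspect ratio kept)
    new_shape = image_shape * np.min(input_shape / image_shape)
    # Fraction of the input taken up by the gray bar on each leading side
    offset = (input_shape - new_shape) / 2.0 / input_shape
    # Factor that stretches the un-padded region back to the full [0, 1] range
    scale = input_shape / new_shape
    mins = np.concatenate([top, left], axis=-1)        # (N, 2) as (y, x)
    maxes = np.concatenate([bottom, right], axis=-1)   # (N, 2) as (y, x)
    mins = (mins - offset) * scale * image_shape
    maxes = (maxes - offset) * scale * image_shape
    return np.concatenate([mins, maxes], axis=-1)

# A 150x300 image letterboxed into a 300x300 input has gray bars covering the
# top and bottom quarters, so a box spanning y in [0.25, 0.75] covers the image.
out = letterbox_correct_boxes(
    np.array([[0.25]]), np.array([[0.0]]), np.array([[0.75]]), np.array([[1.0]]),
    (300, 300), (150, 300))
print(out)
```

Because only the padding geometry is involved, the same arithmetic would apply to any model that feeds letterboxed images to the network, which fits the note above that this is the only call to swap.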
predict
from keras.layers import Input
from retinanet import Retinanet
from PIL import Image
import os

retinanet = Retinanet()
img_path = "data/train1/domain1"  # TODO: folder containing the images to test
predict_path = "mAP/predict/"
os.makedirs(predict_path, exist_ok=True)  # make sure the output folder exists

filename_list = os.listdir(img_path)
for filename in filename_list:
    print(filename)
    # Build the full path to the image inside img_path
    img = os.path.join(img_path, filename)
    try:
        image = Image.open(img)
    except OSError:
        print('Open Error! Try again!')
    else:
        r_image, record_pred = retinanet.detect_image(image)
        # Write one "class score left top right bottom" line per detection
        with open(os.path.join(predict_path, os.path.splitext(filename)[0] + '.txt'), "w") as f:
            f.write("\n".join(record_pred))
        # r_image.save("mAP/predict_images/" + filename)
retinanet.close_session()
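Each line written to the .txt files follows the format built in detect_image: class name, score with two decimals, then left, top, right, bottom. As a rough sketch of how such a line could be read back (for example when checking the files before an mAP evaluation), the hypothetical helper parse_prediction_line below splits from the right so that class names containing spaces stay intact. It is an illustration only and not part of the original project.

```python
def parse_prediction_line(line):
    """Hypothetical helper: parse 'class score left top right bottom'.

    Splitting from the right keeps class names with spaces in one piece.
    """
    name, score, left, top, right, bottom = line.strip().rsplit(" ", 5)
    return name, float(score), int(left), int(top), int(right), int(bottom)

# Made-up example line in the same format as record_pred entries
parsed = parse_prediction_line("traffic light 0.87 10 20 110 220")
print(parsed)
```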