Common functions: detect_image / predict
2022-07-07 02:14:00 【逆夏11111】
detect_image
#---------------------------------------------------#
#   Detect objects in an image
#---------------------------------------------------#
def detect_image(self, image):
    start = timer()
    image_shape = np.array(np.shape(image)[0:2])
    crop_img, x_offset, y_offset = letterbox_image(image, (self.model_image_size[0], self.model_image_size[1]))
    photo = np.array(crop_img, dtype=np.float64)
    # Preprocess the image: add a batch dimension and normalize
    photo = preprocess_input(np.reshape(photo, [1, self.model_image_size[0], self.model_image_size[1], 3]))
    preds = self.ssd_model.predict(photo)
    # Decode the raw predictions
    results = self.bbox_util.detection_out(preds, confidence_threshold=self.confidence)
    # print(results)
    if len(results[0]) <= 0:
        return image, []
    # Keep only the boxes whose score is at least self.confidence
    det_label = results[0][:, 0]
    det_conf = results[0][:, 1]
    det_xmin, det_ymin, det_xmax, det_ymax = results[0][:, 2], results[0][:, 3], results[0][:, 4], results[0][:, 5]
    top_indices = [i for i, conf in enumerate(det_conf) if conf >= self.confidence]
    top_conf = det_conf[top_indices]
    top_label_indices = det_label[top_indices].tolist()
    top_xmin = np.expand_dims(det_xmin[top_indices], -1)
    top_ymin = np.expand_dims(det_ymin[top_indices], -1)
    top_xmax = np.expand_dims(det_xmax[top_indices], -1)
    top_ymax = np.expand_dims(det_ymax[top_indices], -1)
    # Remove the gray letterbox bars (map boxes back onto the original image)
    boxes = retinanet_correct_boxes(top_ymin, top_xmin, top_ymax, top_xmax,
                                    np.array([self.model_image_size[0], self.model_image_size[1]]), image_shape)
    font = ImageFont.truetype(font='model_data/simhei.ttf',
                              size=np.floor(3e-2 * np.shape(image)[1] + 0.5).astype('int32'))
    thickness = (np.shape(image)[0] + np.shape(image)[1]) // self.model_image_size[0]
    record_pred = []
    for i, c in enumerate(top_label_indices):
        predicted_class = self.class_names[int(c) - 1]
        score = top_conf[i]
        top, left, bottom, right = boxes[i]
        # Pad the box by 5 pixels on every side, then clamp it to the image
        top = top - 5
        left = left - 5
        bottom = bottom + 5
        right = right + 5
        top = max(0, np.floor(top + 0.5).astype('int32'))
        left = max(0, np.floor(left + 0.5).astype('int32'))
        bottom = min(np.shape(image)[0], np.floor(bottom + 0.5).astype('int32'))
        right = min(np.shape(image)[1], np.floor(right + 0.5).astype('int32'))
        print(top, left, bottom, right)
        # Draw the box and its label
        label = '{} {:.2f}'.format(predicted_class, score)
        draw = ImageDraw.Draw(image)
        # Note: draw.textsize was removed in Pillow 10; this code assumes an older Pillow
        label_size = draw.textsize(label, font)
        # label = label.encode('utf-8')  # uncomment if the label must be UTF-8 bytes
        print(label, (left, top), (right, bottom))
        record_pred.append(label + " %s %s %s %s" % (left, top, right, bottom))
        if top - label_size[1] >= 0:
            text_origin = np.array([left, top - label_size[1]])
        else:
            text_origin = np.array([left, top + 1])
        # Draw nested rectangles to get an outline `thickness` pixels wide
        for t in range(thickness):
            draw.rectangle(
                [left + t, top + t, right - t, bottom - t],
                outline=self.colors[int(c) - 1])
        draw.rectangle(
            [tuple(text_origin), tuple(text_origin + label_size)],
            fill=self.colors[int(c) - 1])
        # If label was encoded to UTF-8 above, decode it back here: str(label, 'UTF-8')
        draw.text(text_origin, label, fill=(0, 0, 0), font=font)
        del draw
    end = timer()
    print("detect time:", end - start)
    return image, record_pred
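Each entry of record_pred is a string of the form "class score left top right bottom", with the score formatted to two decimals. To read those entries back (for example when checking the generated txt files), something like the following could work. This is only a sketch: `parse_record` is a hypothetical helper that does not appear in the original code, and it assumes the class name may contain spaces while the five trailing fields never do.

```python
def parse_record(line):
    """Split one record_pred entry into (class_name, score, left, top, right, bottom).

    Hypothetical helper, not part of the original code.
    """
    # Split from the right so that class names containing spaces stay intact.
    name, score, left, top, right, bottom = line.rsplit(" ", 5)
    return name, float(score), int(left), int(top), int(right), int(bottom)


record = parse_record("traffic light 0.87 12 34 156 278")
print(record)
```

Splitting from the right is the design choice here: the numeric fields have a fixed count, whereas the class name is the only field of variable shape.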
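The only line that changes from one model to another is the one that removes the gray letterbox bars: boxes = ssd_correct_boxes(top_ymin, top_xmin, top_ymax, top_xmax, np.array([self.model_image_size[0], self.model_image_size[1]]), image_shape)
For an SSD model the function is ssd_correct_boxes, for a RetinaNet model it is retinanet_correct_boxes, and so on for other models.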
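To illustrate what "removing the gray bars" means, here is a minimal sketch of the usual letterbox correction for a single box. It is not the actual ssd_correct_boxes or retinanet_correct_boxes implementation. It assumes the network returns coordinates normalized to [0, 1] on the letterboxed input and that the resized image is centered on the gray canvas. The name `undo_letterbox` is made up for this example.

```python
def undo_letterbox(top, left, bottom, right, input_shape, image_shape):
    """Map one normalized box on the letterboxed input to original-image pixels.

    input_shape and image_shape are (height, width) pairs.
    Hypothetical helper; assumes centered letterboxing.
    """
    in_h, in_w = input_shape
    img_h, img_w = image_shape
    # Letterboxing scales the image uniformly until it fits inside the input.
    ratio = min(in_h / img_h, in_w / img_w)
    new_h, new_w = img_h * ratio, img_w * ratio
    # Fraction of the input taken up by gray padding before the image starts.
    off_y = (in_h - new_h) / 2.0 / in_h
    off_x = (in_w - new_w) / 2.0 / in_w
    # Subtract the padding, rescale to the image content, convert to pixels.
    scale_y, scale_x = in_h / new_h, in_w / new_w
    return (
        (top - off_y) * scale_y * img_h,
        (left - off_x) * scale_x * img_w,
        (bottom - off_y) * scale_y * img_h,
        (right - off_x) * scale_x * img_w,
    )


# A 300x600 (height x width) image letterboxed into a 300x300 input occupies
# the middle half vertically, so this box covers the whole original image.
print(undo_letterbox(0.25, 0.0, 0.75, 1.0, (300, 300), (300, 600)))
# → (0.0, 0.0, 300.0, 600.0)
```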
predict
from keras.layers import Input
from retinanet import Retinanet
from PIL import Image
import os

retinanet = Retinanet()
img_path = "data/train1/domain1"  # TODO: folder containing the images to test
predict_path = "mAP/predict/"
filename_list = os.listdir(img_path)
for filename in filename_list:
    print(filename)
    # Build the full path to the image; the original called input(filename),
    # which would prompt the user on every iteration instead of opening the file
    img = os.path.join(img_path, filename)
    try:
        image = Image.open(img)
    except Exception:
        print('Open Error! Try again!')
    else:
        r_image, record_pred = retinanet.detect_image(image)
        # Write one detection per line to <name>.txt
        with open(predict_path + filename.split(".")[0] + '.txt', "w") as f:
            f.write("\n".join(record_pred))
        # r_image.save("mAP/predict_images/" + filename)
retinanet.close_session()
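The script derives the txt name with filename.split(".")[0], which cuts a name such as img.v2.jpg down to img, and it tries to open every entry in the folder. If the intent is to drop only the final extension and to skip non-image files, a sketch along these lines may help. `prediction_txt_path`, `is_image_file`, and the extension list are assumptions made for illustration and are not part of the original script.

```python
import os

# Assumed set of image extensions; adjust to the dataset at hand.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp")


def is_image_file(filename):
    """Return True if the name ends with one of the assumed image extensions."""
    return filename.lower().endswith(IMAGE_EXTENSIONS)


def prediction_txt_path(predict_dir, filename):
    """Build the output path, replacing only the last extension with .txt."""
    stem, _ = os.path.splitext(os.path.basename(filename))
    return os.path.join(predict_dir, stem + ".txt")


print(prediction_txt_path("mAP/predict/", "img.v2.jpg"))
print(is_image_file("photo.JPG"), is_image_file("notes.txt"))
```

In the loop, these would stand in for the string concatenation: skip the file when is_image_file(filename) is false, and open prediction_txt_path(predict_path, filename) for writing.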