TensorFlow 2: convert a SavedModel to a .pb file (frozen graph)
2022-07-01 14:37:00 【哗啦呼啦嘿】
1. Freeze the SavedModel into a frozen graph (.pb):
```python
import os

# Hide all GPUs so the conversion runs on CPU; set this before TensorFlow initializes its devices
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

frozen_out_path = '/data1/gyx/QR/SSD_Tensorflow2.0-master/convert/pb'
frozen_graph_filename = "frozen_graph_ssd_multi.pb"
saved_model_path = '/data1/gyx/QR/SSD_Tensorflow2.0-master/result/weight/ssd_multi/pb/'

# model = tf.keras.models.load_model(saved_model_path)
# model.summary()
# images = tf.random.uniform((1, 300, 300, 3))
# print(model.predict(images)[0].shape)  # (1, 8732, 5)
# print(model.predict(images)[1].shape)  # (1, 8732, 4)

# Define the input format (only its shape and dtype are used below)
img = tf.random.uniform((1, 300, 300, 3))
# Load the model
network = tf.keras.models.load_model(saved_model_path)

# Convert the Keras model to a ConcreteFunction
full_model = tf.function(lambda x: network(x))
full_model = full_model.get_concrete_function(
    tf.TensorSpec(img.shape, img.dtype))  # (1, 300, 300, 3) <dtype: 'float32'>

# Get the frozen ConcreteFunction (variables folded into constants)
frozen_func = convert_variables_to_constants_v2(full_model)

layers = [op.name for op in frozen_func.graph.get_operations()]
print("-" * 50)
print("Frozen model layers: ")
for layer in layers:
    print(layer)

print("-" * 50)
print("Frozen model inputs: ")  # [<tf.Tensor 'x:0' shape=(1, 300, 300, 3) dtype=float32>]
print(frozen_func.inputs)
print("Frozen model outputs: ")  # [<tf.Tensor 'Identity:0' shape=(1, 8732, 5) dtype=float32>, <tf.Tensor 'Identity_1:0' shape=(1, 8732, 4) dtype=float32>]
print(frozen_func.outputs)

# Save the frozen graph from the frozen ConcreteFunction to disk
tf.io.write_graph(graph_or_graph_def=frozen_func.graph.as_graph_def(),
                  logdir=frozen_out_path,
                  name=frozen_graph_filename,
                  as_text=False)

# import netron
# netron.start(os.path.join(frozen_out_path, frozen_graph_filename))
```
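The printed output shapes, (1, 8732, 5) and (1, 8732, 4), match the standard SSD300 default-box layout. That layout uses six feature maps of size 38, 19, 10, 5, 3 and 1, with 4 or 6 default boxes per cell. The post does not say which configuration the project uses, so treat the layer list below as an assumption. The 8732 total is consistent with it.

The printout also shows TensorFlow's `<op_name>:<output_index>` tensor naming. This is why step 2 feeds `x:0` and fetches `Identity:0` and `Identity_1:0`. The sketch below reproduces the anchor count.

```python
# Assumed SSD300 layout: (feature-map size, default boxes per cell) for each detection layer.
SSD300_LAYERS = [(38, 4), (19, 6), (10, 6), (5, 6), (3, 4), (1, 4)]


def count_anchors(layers):
    """Total anchors = sum over feature maps of size * size * boxes_per_cell."""
    return sum(size * size * boxes for size, boxes in layers)


total = count_anchors(SSD300_LAYERS)
print(total)  # 8732
```

In the two outputs, the last dimension differs: 5 is presumably the per-class scores (object classes plus background), and 4 is the four box offsets per anchor.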
2. Run prediction with the frozen .pb graph:
```python
import os

# Run on CPU; set this before TensorFlow initializes its devices
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

import copy
import cv2
import numpy as np
import tensorflow.compat.v1 as tf

# Project-local modules from the SSD_Tensorflow2.0-master repository
import config as c
from utils.aug_utils import color_normalize
from utils.anchor_utils import generate_anchors, from_offset_to_box, from_offset_to_box_nms
from utils.eval_utils import show_box, show_multibox

tf.disable_v2_behavior()

path = '/data1/gyx/QR/SSD_Tensorflow2.0-master/convert/pb/frozen_graph_ssd_multi.pb'
image_path = '/data1/gyx/QR/SSD_Tensorflow2.0-master/test_pic/20210330101245.bmp'
img_cv2 = cv2.imread(image_path)  # BGR image, shape (H, W, 3)
height, width, _ = np.shape(img_cv2)

#### Normalization method 1 (cv2.dnn.blobFromImage) ####
# blobFromImage returns NCHW: (1, 3, 300, 300)
# blob = cv2.dnn.blobFromImage(img_cv2,
#                              scalefactor=1.0 / 255,
#                              size=(300, 300),
#                              mean=(0, 0, 0),
#                              swapRB=False,
#                              crop=False)
# blobFromImage subtracts `mean` before multiplying by `scalefactor`, so the mean is in 0-255 pixel units
# blob = cv2.dnn.blobFromImage(img_cv2,
#                              scalefactor=1.0 / 255,
#                              size=(300, 300),
#                              mean=(103.939, 116.779, 123.68),
#                              swapRB=False,
#                              crop=False)
# The frozen graph expects NHWC: (1, 300, 300, 3)
# input_image = np.transpose(blob, (0, 2, 3, 1))

#### Normalization method 2 (same preprocessing as used in training) ####
# cv2.resize takes (width, height), while c.input_shape is (height, width, channels)
input_image = np.array([color_normalize(cv2.resize(copy.copy(img_cv2),
                                                   (c.input_shape[1], c.input_shape[0])))],
                       dtype=np.float32)
print(input_image.shape)
```
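This NumPy sketch shows what normalization method 1 computes and why the transpose is needed. It reproduces the arithmetic `cv2.dnn.blobFromImage` performs, leaving out resizing, cropping and `swapRB`:

- subtract `mean` per channel;
- multiply by `scalefactor`;
- move channels first to get NCHW.

The frozen graph's input `x:0` is NHWC, (1, 300, 300, 3), according to the printout in step 1. `blob_like` is a hypothetical helper for illustration, not an OpenCV function.

```python
import numpy as np


def blob_like(img_hwc, scalefactor, mean):
    """Hypothetical helper: mimic cv2.dnn.blobFromImage's arithmetic without resize/crop/swapRB.

    Returns a float32 array of shape (1, C, H, W).
    """
    x = img_hwc.astype(np.float32)
    # The per-channel mean is subtracted first, in the image's own 0-255 units
    x = x - np.asarray(mean, dtype=np.float32)
    # Only then are the values multiplied by scalefactor
    x = x * np.float32(scalefactor)
    # HWC -> CHW, then add the batch axis
    return np.transpose(x, (2, 0, 1))[np.newaxis, ...]


img = np.full((2, 3, 3), 255, dtype=np.uint8)  # tiny all-white BGR image, H=2, W=3
blob = blob_like(img, 1.0 / 255, (0, 0, 0))    # NCHW, shape (1, 3, 2, 3)
nhwc = np.transpose(blob, (0, 2, 3, 1))         # NHWC, shape (1, 2, 3, 3), as the frozen graph expects
print(blob.shape, nhwc.shape)
```

Because the mean is subtracted before scaling, a mean such as 123.68 must be given in pixel units. Dividing it by 255 first would subtract less than one gray level.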
```python
# Load the serialized GraphDef from the .pb file
with tf.io.gfile.GFile(path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# Import it into a fresh graph; the tensor names come from the printout in step 1
graph = tf.Graph()
with graph.as_default():
    tf.import_graph_def(graph_def, name='')

with tf.Session(graph=graph) as sess:
    out = sess.run([graph.get_tensor_by_name('Identity:0'),
                    graph.get_tensor_by_name('Identity_1:0')],
                   feed_dict={'x:0': input_image})
# out = {'detection_classes': out[0], 'detection_boxes': out[1]}
# print(out)

anchors = generate_anchors()
cls_pred, loc_pred = out[0], out[1]
print(cls_pred.shape, loc_pred.shape)
# print(type(loc_pred))  # <class 'numpy.ndarray'>
# cls_pred = tf.convert_to_tensor(np.array(cls_pred), dtype='float32')
# loc_pred = tf.convert_to_tensor(np.array(loc_pred), dtype='float32')
# print(cls_pred)
# print(loc_pred)

# Decode the first (and only) image in the batch into boxes, scores and labels
boxes, scores, labels = from_offset_to_box(loc_pred[0], cls_pred[0], anchors,
                                           anchor_belongs_to_one_class=True, score_threshold=0.1)
print(boxes, scores, labels)
# print(boxes)
```
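`from_offset_to_box` is a project-local helper whose source the post does not show. The NumPy sketch below illustrates the decoding step that SSD implementations commonly perform. It assumes:

- the offsets use the original SSD center-size encoding with variances 0.1 and 0.2;
- anchors are stored as (cx, cy, w, h).

The project's actual encoding, anchor format and score filtering may differ. `decode_offsets` is a hypothetical name.

```python
import numpy as np


def decode_offsets(loc, anchors, variances=(0.1, 0.2)):
    """Decode (dx, dy, dw, dh) offsets into (x1, y1, x2, y2) corner boxes.

    Assumes anchors are (cx, cy, w, h) and the standard SSD center-size encoding.
    """
    loc = np.asarray(loc, dtype=np.float64)
    anchors = np.asarray(anchors, dtype=np.float64)
    # Shift anchor centers by offsets scaled by anchor size and the first variance
    cx = anchors[:, 0] + loc[:, 0] * variances[0] * anchors[:, 2]
    cy = anchors[:, 1] + loc[:, 1] * variances[0] * anchors[:, 3]
    # Scale anchor width and height exponentially, using the second variance
    w = anchors[:, 2] * np.exp(loc[:, 2] * variances[1])
    h = anchors[:, 3] * np.exp(loc[:, 3] * variances[1])
    # Convert center-size boxes to corner form
    return np.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], axis=1)


demo_anchors = np.array([[0.5, 0.5, 0.2, 0.4]])
print(decode_offsets(np.zeros((1, 4)), demo_anchors))  # zero offsets give back the anchor's corners
```

In the real pipeline, thresholding on `score_threshold` (and NMS, in the `from_offset_to_box_nms` variant the script imports) would follow this step.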
```python
# Rescale boxes from network-input pixels back to the original image size
boxes_new = []
score_new = []
label_new = []
for box, score, label in zip(boxes, scores, labels):
    box[0] = box[0] / c.input_shape[1] * width   # left
    box[1] = box[1] / c.input_shape[0] * height  # top
    box[2] = box[2] / c.input_shape[1] * width   # right
    box[3] = box[3] / c.input_shape[0] * height  # bottom
    print('image: {}\nclass: {}\nconfidence: {:.4f}\n'.format(image_path, c.class_list[label], score))
    boxes_new.append(box)
    score_new.append(score)
    label_new.append(c.class_list[label])

show_multibox(img_cv2, boxes_new, score_new, label_new)
```
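The final loop maps each box from the network-input resolution back to the original image. The sketch below pulls that mapping out as a pure function, `rescale_box`, which is a hypothetical name. As the loop implies, it assumes boxes are `(left, top, right, bottom)` in input-pixel coordinates.

```python
def rescale_box(box, input_hw, image_hw):
    """Map a (left, top, right, bottom) box from network-input pixels to original-image pixels.

    input_hw and image_hw are (height, width) pairs, matching c.input_shape[:2]
    and the (height, width) read from the original image.
    """
    in_h, in_w = input_hw
    img_h, img_w = image_hw
    sx = img_w / in_w  # horizontal scale applies to left/right
    sy = img_h / in_h  # vertical scale applies to top/bottom
    left, top, right, bottom = box
    return (left * sx, top * sy, right * sx, bottom * sy)


# A 300x300 network input mapped onto an image 600 pixels high and 1200 wide
print(rescale_box((30, 60, 150, 300), (300, 300), (600, 1200)))  # (120.0, 120.0, 600.0, 600.0)
```

This version returns a new tuple instead of mutating `box` in place. That avoids surprising side effects if the decoded `boxes` array is reused later.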