当前位置:网站首页>tensorflow2-savedmodel convert to pb(frozen_graph)
TensorFlow 2: converting a SavedModel to a frozen graph (.pb)
2022-07-01 14:41:00 【Hula Hula hey】
1. Freeze the SavedModel into a single graph file (.pb):
# Freeze a TensorFlow 2 Keras SavedModel into a single frozen-graph .pb file:
# trace the model into a ConcreteFunction with a fixed input signature, inline
# its variables as constants, list the resulting ops / inputs / outputs, and
# write the binary GraphDef to disk.
import os

# Hide all GPUs *before* importing TensorFlow so the conversion runs on the CPU.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

frozen_out_path = '/data1/gyx/QR/SSD_Tensorflow2.0-master/convert/pb'
frozen_graph_filename = "frozen_graph_ssd_multi.pb"
saved_model_path = '/data1/gyx/QR/SSD_Tensorflow2.0-master/result/weight/ssd_multi/pb/'

# Input signature used for tracing: a batch of one 300x300x3 float32 image.
# This shape is baked into the frozen graph's input placeholder.
input_spec = tf.TensorSpec(shape=(1, 300, 300, 3), dtype=tf.float32)

# Load the Keras model from the SavedModel directory.
network = tf.keras.models.load_model(saved_model_path)

# Wrap the model call in a tf.function and trace it for the fixed signature.
full_model = tf.function(lambda x: network(x))
full_model = full_model.get_concrete_function(input_spec)

# Convert the variables captured by the traced function into constants.
frozen_func = convert_variables_to_constants_v2(full_model)

print("-" * 50)
print("Frozen model layers: ")
for op in frozen_func.graph.get_operations():
    print(op.name)
print("-" * 50)
# Author observed: [<tf.Tensor 'x:0' shape=(1, 300, 300, 3) dtype=float32>]
print("Frozen model inputs: ")
print(frozen_func.inputs)
# Author observed for this SSD model: 'Identity:0' (1, 8732, 5) and
# 'Identity_1:0' (1, 8732, 4) — confirm when freezing a different model.
print("Frozen model outputs: ")
print(frozen_func.outputs)

# Serialize the frozen GraphDef as a binary (not text) .pb file.
tf.io.write_graph(graph_or_graph_def=frozen_func.graph.as_graph_def(),
                  logdir=frozen_out_path,
                  name=frozen_graph_filename,
                  as_text=False)

# Optional: inspect the result visually with netron, e.g.
#   import netron
#   netron.start(os.path.join(frozen_out_path, frozen_graph_filename))
2. Run inference with the frozen .pb graph:
# Run SSD inference on a single image with the frozen graph produced above,
# using the TF1-compat Session API, then rescale the detected boxes back to
# the original image size and display them.
import copy
import os

# Hide all GPUs *before* importing TensorFlow so inference runs on the CPU.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

import cv2
import numpy as np
import tensorflow.compat.v1 as tf

import config as c
from utils.aug_utils import color_normalize
from utils.anchor_utils import generate_anchors, from_offset_to_box, from_offset_to_box_nms
from utils.eval_utils import show_box, show_multibox

tf.disable_v2_behavior()

path = '/data1/gyx/QR/SSD_Tensorflow2.0-master/convert/pb/frozen_graph_ssd_multi.pb'
image_path = '/data1/gyx/QR/SSD_Tensorflow2.0-master/test_pic/20210330101245.bmp'

img_cv2 = cv2.imread(image_path)
if img_cv2 is None:
    # cv2.imread signals a missing/unreadable file by returning None, not raising.
    raise FileNotFoundError('Could not read image: {}'.format(image_path))
height, width, _ = np.shape(img_cv2)

# c.input_shape is used below as [0] = height, [1] = width of the network input.
input_h, input_w = c.input_shape[0], c.input_shape[1]

#### Alternative normalization (cv2.dnn), kept for reference ####
# blob = cv2.dnn.blobFromImage(img_cv2, scalefactor=1.0 / 255, size=(300, 300),
#                              mean=(0, 0, 0), swapRB=False, crop=False)
# blob = cv2.dnn.blobFromImage(img_cv2, scalefactor=1.0 / 255, size=(300, 300),
#                              mean=(103.939 / 255, 116.779 / 255, 123.68 / 255),
#                              swapRB=False, crop=False)
# input_image = np.transpose(blob, (0, 2, 3, 1))  # (1, 3, H, W) -> (1, H, W, 3)

#### Normalization matching the training pipeline ####
# cv2.resize takes dsize as (width, height); the result has shape (H, W, 3),
# matching the (1, H, W, 3) input placeholder of the frozen graph.
resized = cv2.resize(copy.copy(img_cv2), (input_w, input_h))
input_image = np.array([color_normalize(resized)], dtype=np.float32)
print(input_image.shape)

# Load the serialized GraphDef.
graph_def = tf.GraphDef()
with tf.io.gfile.GFile(path, 'rb') as f:
    graph_def.ParseFromString(f.read())

# Import into a dedicated graph; name='' keeps the original tensor names so
# 'x:0', 'Identity:0' and 'Identity_1:0' resolve unchanged.
graph = tf.Graph()
with graph.as_default():
    tf.import_graph_def(graph_def, name='')

with tf.Session(graph=graph) as sess:
    out = sess.run([graph.get_tensor_by_name('Identity:0'),
                    graph.get_tensor_by_name('Identity_1:0')],
                   feed_dict={'x:0': input_image})

anchors = generate_anchors()
# out[0] holds class predictions, out[1] box-offset predictions.
cls_pred, loc_pred = out
print(cls_pred.shape, loc_pred.shape)

# Decode offsets relative to the anchors for the single image in the batch.
boxes, scores, labels = from_offset_to_box(loc_pred[0], cls_pred[0], anchors,
                                           anchor_belongs_to_one_class=True, score_threshold=0.1)
print(boxes, scores, labels)

boxes_new = []
score_new = []
label_new = []
for box, score, label in zip(boxes, scores, labels):
    # Boxes are [left, top, right, bottom] in network-input pixels; map them
    # back to the original image size (box is modified in place).
    box[0] = box[0] / input_w * width    # left
    box[1] = box[1] / input_h * height   # top
    box[2] = box[2] / input_w * width    # right
    box[3] = box[3] / input_h * height   # bottom
    print('image: {}\nclass: {}\nconfidence: {:.4f}\n'.format(image_path, c.class_list[label], score))
    boxes_new.append(box)
    score_new.append(score)
    label_new.append(c.class_list[label])
show_multibox(img_cv2, boxes_new, score_new, label_new)
边栏推荐
- 【14. 区间和(离散化)】
- [15. Interval consolidation]
- 【阶段人生总结】放弃考研,参与到工作中,已经顺利毕业了,昨天刚领毕业证
- tensorflow2-savedmodel convert to tflite
- Sorting learning sorting
- [repair version] imitating the template of I love watching movies website / template of ocean CMS film and television system
- Provincial election + noi Part XI others
- 用对场景,事半功倍!TDengine 的窗口查询功能及使用场景全介绍
- NPDP产品经理国际认证报名有什么要求?
- sqlilabs less-8
猜你喜欢

241. 为运算表达式设计优先级

SQLAlchemy 常用操作

Build your own website (21)
![[commercial terminal simulation solution] Shanghai daoning brings you Georgia introduction, trial and tutorial](/img/44/b65aaf11b1e632f2dab55b6fc699f6.jpg)
[commercial terminal simulation solution] Shanghai daoning brings you Georgia introduction, trial and tutorial
![[R language data science]: common evaluation indicators of machine learning](/img/c8/dbfb041fa72799fae1892fe8ac0050.png)
[R language data science]: common evaluation indicators of machine learning

关于重载运算符的再整理

sqlilabs less-11~12

Summary of leetcode's dynamic programming 5

JVM performance tuning and practical basic theory part II

Using CMD to repair and recover virus infected files
随机推荐
Generate random numbers (4-bit, 6-bit)
被裁三個月,面試到處碰壁,心態已經開始崩了
Salesforce、约翰霍普金斯、哥大 | ProGen2: 探索蛋白语言模型的边界
一波三折,终于找到src漏洞挖掘的方法了【建议收藏】
光環效應——誰說頭上有光的就算英雄
Research Report on the development trend and competitive strategy of the global camera filter bracket industry
tensorflow2-savedmodel convert to tflite
百度上找的期货公司安全吗?期货公司怎么确定正规
TDengine 连接器上线 Google Data Studio 应用商店
博文推荐 | 深入研究 Pulsar 中的消息分块
Research Report on the development trend and competitive strategy of the global traditional computer industry
111. Minimum depth of binary tree
Research Report on development trend and competitive strategy of global consumer glassware industry
2022-2-15 learning the imitation Niuke project - post in Section 2
深度合作 | 涛思数据携手长虹佳华为中国区客户提供 TDengine 强大企业级产品与完善服务保障
MIT团队使用图神经网络,加速无定形聚合物电解质筛选,促进下一代锂电池技术开发
Opencv mat class
Vnctf2022 open web gocalc0
Microservice development steps (Nacos)
Summary of leetcode's dynamic programming 5