TensorFlow 2: converting a SavedModel to a .pb file (frozen graph)
2022-07-01 14:41:00 【Hula Hula hey】
1. Freeze the SavedModel into a frozen graph:
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
frozen_out_path = '/data1/gyx/QR/SSD_Tensorflow2.0-master/convert/pb'
frozen_graph_filename = "frozen_graph_ssd_multi.pb"
saved_model_path= '/data1/gyx/QR/SSD_Tensorflow2.0-master/result/weight/ssd_multi/pb/'
# model = tf.keras.models.load_model(saved_model_path)
# model.summary()
# images=tf.random.uniform((1, 300, 300, 3))
# print(model.predict(images)[0].shape) #(1, 8732, 5)
# print(model.predict(images)[1].shape) #(1, 8732, 4)
# Dummy input; only its shape and dtype are used, to define the input signature
img = tf.random.uniform((1, 300, 300, 3))
# Load model
network = tf.keras.models.load_model(saved_model_path)
# Convert Keras model to ConcreteFunction
full_model = tf.function(lambda x: network(x))
full_model = full_model.get_concrete_function(
    tf.TensorSpec(img.shape, img.dtype))  # (1, 300, 300, 3) <dtype: 'float32'>
# Get frozen ConcreteFunction
frozen_func = convert_variables_to_constants_v2(full_model)
frozen_func.graph.as_graph_def()
layers = [op.name for op in frozen_func.graph.get_operations()]
print("-" * 50)
print("Frozen model layers: ")
for layer in layers:
    print(layer)
print("-" * 50)
print("Frozen model inputs: ") #[<tf.Tensor 'x:0' shape=(1, 300, 300, 3) dtype=float32>]
print(frozen_func.inputs)
print("Frozen model outputs: ") #[<tf.Tensor 'Identity:0' shape=(1, 8732, 5) dtype=float32>, <tf.Tensor 'Identity_1:0' shape=(1, 8732, 4) dtype=float32>]
print(frozen_func.outputs)
# Save frozen graph from frozen ConcreteFunction to hard drive
tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                  logdir=frozen_out_path,
                  name=frozen_graph_filename,
                  as_text=False)
# import netron
# netron.start(os.path.join(frozen_out_path,frozen_graph_filename))
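The second dimension of both printed outputs, 8732, is the number of default boxes (anchors); the last dimension holds the per-anchor class scores (5) and box offsets (4). The sketch below reproduces that count, assuming this model follows the standard SSD300 layout of six prediction feature maps. The layer names in the comments come from the original VGG-16 SSD and are only illustrative, since this repo's backbone may name or build them differently.

```python
# Assumed standard SSD300 prediction layers:
# (feature-map side length, default boxes per cell)
SSD300_LAYERS = [
    (38, 4),  # conv4_3
    (19, 6),  # fc7
    (10, 6),  # conv8_2
    (5, 6),   # conv9_2
    (3, 4),   # conv10_2
    (1, 4),   # conv11_2
]


def count_default_boxes(layers):
    """Total anchors = sum over layers of side * side * boxes_per_cell."""
    return sum(side * side * k for side, k in layers)


total = count_default_boxes(SSD300_LAYERS)
print(total)  # 8732
```

The `generate_anchors()` call in step 2 presumably has to return the same 8732 anchors, since its result is decoded together with `loc_pred[0]`.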
2. Run inference with the frozen .pb graph:
import cv2
import tensorflow.compat.v1 as tf
import numpy as np
import os
import copy
import config as c
from utils.aug_utils import color_normalize
from utils.anchor_utils import generate_anchors, from_offset_to_box,from_offset_to_box_nms
from utils.eval_utils import show_box,show_multibox
tf.disable_v2_behavior()
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
path='/data1/gyx/QR/SSD_Tensorflow2.0-master/convert/pb/frozen_graph_ssd_multi.pb'
image_path = '/data1/gyx/QR/SSD_Tensorflow2.0-master/test_pic/20210330101245.bmp'
img_cv2 = cv2.imread(image_path)
height, width, _ = np.shape(img_cv2)
#### The first image normalization method ####
# (1, 3, 300, 300)
# blob = cv2.dnn.blobFromImage(img_cv2,
# scalefactor=1.0 / 255,
# size=(300, 300),
# mean=(0, 0, 0),
# swapRB=False,
# crop=False)
# blob = cv2.dnn.blobFromImage(img_cv2,
# scalefactor=1.0 / 255,
# size=(300, 300),
# mean=(103.939, 116.779, 123.68),  # blobFromImage subtracts mean before scaling, so use 0-255 values
# swapRB=False,
# crop=False)
# (1, 300, 300, 3)
# input_image = np.transpose(blob, (0,2,3,1))
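#### The second image normalization method (matches the preprocessing used in training) ####
# cv2.resize expects dsize as (width, height); c.input_shape is taken as (height, width, channels), as the box rescaling below assumes
input_image = np.array([color_normalize(cv2.resize(copy.copy(img_cv2), (c.input_shape[1], c.input_shape[0])))], dtype=np.float32)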
print(input_image.shape)
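# Load the serialized GraphDef from disk
with tf.io.gfile.GFile(path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
# Import it into a fresh graph and run that graph in a session bound to it
graph = tf.Graph()
with graph.as_default():
    tf.import_graph_def(graph_def, name='')
with tf.Session(graph=graph) as sess:
    # Input/output tensor names are the ones printed in step 1
    out = sess.run([graph.get_tensor_by_name('Identity:0'),
                    graph.get_tensor_by_name('Identity_1:0')],
                   feed_dict={'x:0': input_image})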
# out={'detection_classes':out[0],'detection_boxes':out[1]}
# print(out)
anchors = generate_anchors()
cls_pred, loc_pred = out[0],out[1]
print(cls_pred.shape, loc_pred.shape)
# print(type(loc_pred)) #<class 'numpy.ndarray'>
# cls_pred = tf.convert_to_tensor(np.array(cls_pred), dtype='float32')
# loc_pred = tf.convert_to_tensor(np.array(loc_pred), dtype='float32')
# print(cls_pred)
# print(loc_pred)
boxes, scores, labels = from_offset_to_box(loc_pred[0], cls_pred[0], anchors,
                                           anchor_belongs_to_one_class=True, score_threshold=0.1)
print(boxes, scores, labels)
# print(boxes)
boxes_new = []
score_new = []
label_new = []
for box, score, label in zip(boxes, scores, labels):
    # Map box coordinates from the network input size back to the original image size
    box[0] = box[0] / c.input_shape[1] * width   # left
    box[1] = box[1] / c.input_shape[0] * height  # top
    box[2] = box[2] / c.input_shape[1] * width   # right
    box[3] = box[3] / c.input_shape[0] * height  # bottom
    print('image: {}\nclass: {}\nconfidence: {:.4f}\n'.format(image_path, c.class_list[label], score))
    boxes_new.append(box)
    score_new.append(score)
    label_new.append(c.class_list[label])
show_multibox(img_cv2, boxes_new, score_new, label_new)
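The per-box loop in step 2 rescales each `[left, top, right, bottom]` box from the network input size back to the original image. A vectorized NumPy version is sketched below. `rescale_boxes` is a hypothetical helper, not part of the repo; it assumes `c.input_shape` is ordered (height, width, channels), as the loop implies, and it additionally clips coordinates to the image, which the original loop does not do.

```python
import numpy as np


def rescale_boxes(boxes, input_hw, image_hw):
    """Map [left, top, right, bottom] boxes from the network input frame
    (input_hw = (in_h, in_w)) to the original image frame (image_hw = (h, w))."""
    boxes = np.asarray(boxes, dtype=np.float32).reshape(-1, 4)
    in_h, in_w = input_hw
    h, w = image_hw
    # x coordinates scale with width, y coordinates with height
    scale = np.array([w / in_w, h / in_h, w / in_w, h / in_h], dtype=np.float32)
    out = boxes * scale
    # Extra step not in the original loop: keep boxes inside the image
    out[:, [0, 2]] = np.clip(out[:, [0, 2]], 0, w - 1)
    out[:, [1, 3]] = np.clip(out[:, [1, 3]], 0, h - 1)
    return out


# Example: a box on a 300x300 input mapped onto a 1200x600 (width x height) image
print(rescale_boxes([[30, 60, 150, 300]], (300, 300), (600, 1200)).tolist())
# [[120.0, 120.0, 600.0, 599.0]]
```

In the script this could replace the four coordinate lines of the loop, e.g. `boxes_new = rescale_boxes(boxes, c.input_shape[:2], (height, width))`, assuming `from_offset_to_box` returns an (N, 4) array-like and `show_multibox` accepts an array of boxes.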