TensorFlow 2: convert a SavedModel to pb (frozen graph)
2022-07-01 14:41:00 【Hula Hula hey】
1. Freeze the SavedModel into a frozen graph (.pb):
```python
import os

# Hide all GPUs so the conversion runs on CPU (safest to set before TensorFlow is imported)
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

frozen_out_path = '/data1/gyx/QR/SSD_Tensorflow2.0-master/convert/pb'
frozen_graph_filename = "frozen_graph_ssd_multi.pb"
saved_model_path = '/data1/gyx/QR/SSD_Tensorflow2.0-master/result/weight/ssd_multi/pb/'

# Optional: check that the SavedModel loads and produces the expected output shapes
# model = tf.keras.models.load_model(saved_model_path)
# model.summary()
# images = tf.random.uniform((1, 300, 300, 3))
# print(model.predict(images)[0].shape)  # (1, 8732, 5)
# print(model.predict(images)[1].shape)  # (1, 8732, 4)

# Define the input format with a dummy tensor: shape (1, 300, 300, 3), dtype float32
img = tf.random.uniform((1, 300, 300, 3))

# Load the model
network = tf.keras.models.load_model(saved_model_path)

# Convert the Keras model to a ConcreteFunction with a fixed input signature
full_model = tf.function(lambda x: network(x))
full_model = full_model.get_concrete_function(
    tf.TensorSpec(img.shape, img.dtype))  # (1, 300, 300, 3) <dtype: 'float32'>

# Get the frozen ConcreteFunction: all variables are replaced by constants
frozen_func = convert_variables_to_constants_v2(full_model)
graph_def = frozen_func.graph.as_graph_def()

# List the operations of the frozen graph
layers = [op.name for op in frozen_func.graph.get_operations()]
print("-" * 50)
print("Frozen model layers: ")
for layer in layers:
    print(layer)
print("-" * 50)
print("Frozen model inputs: ")  # [<tf.Tensor 'x:0' shape=(1, 300, 300, 3) dtype=float32>]
print(frozen_func.inputs)
print("Frozen model outputs: ")  # [<tf.Tensor 'Identity:0' shape=(1, 8732, 5) dtype=float32>, <tf.Tensor 'Identity_1:0' shape=(1, 8732, 4) dtype=float32>]
print(frozen_func.outputs)

# Save the frozen graph to disk as a binary protobuf (.pb)
tf.io.write_graph(graph_or_graph_def=graph_def,
                  logdir=frozen_out_path,
                  name=frozen_graph_filename,
                  as_text=False)

# Optional: inspect the saved graph in Netron
# import netron
# netron.start(os.path.join(frozen_out_path, frozen_graph_filename))
```
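The 8732 in both output shapes is the number of anchors (default boxes): for each anchor the model outputs 5 class scores (`Identity:0`) and 4 box offsets (`Identity_1:0`). Assuming the repository follows the standard SSD300 layout from the SSD paper (feature maps with 38, 19, 10, 5, 3 and 1 cells per side and 4, 6, 6, 6, 4 and 4 anchors per cell; the real values live in the repository's `config` and `generate_anchors`, which are not shown here), the count can be checked with a few lines of arithmetic:

```python
# Assumed SSD300 anchor layout (standard values from the SSD paper);
# the repository's own configuration may differ.
FEATURE_MAP_SIZES = [38, 19, 10, 5, 3, 1]
ANCHORS_PER_CELL = [4, 6, 6, 6, 4, 4]


def count_anchors(feature_map_sizes, anchors_per_cell):
    """Total anchors = sum over feature maps of (size * size * anchors per cell)."""
    return sum(s * s * k for s, k in zip(feature_map_sizes, anchors_per_cell))


total = count_anchors(FEATURE_MAP_SIZES, ANCHORS_PER_CELL)
print(total)  # 8732
```

Per feature map this is 5776 + 2166 + 600 + 150 + 36 + 4 anchors, which is why a different input size or anchor configuration would change the second dimension of both outputs.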
2. Run prediction with the frozen .pb:
```python
import os
import copy

# Hide all GPUs so inference runs on CPU (safest to set before TensorFlow is imported)
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

import cv2
import numpy as np
import tensorflow.compat.v1 as tf

# Helpers from the SSD_Tensorflow2.0 project
import config as c
from utils.aug_utils import color_normalize
from utils.anchor_utils import generate_anchors, from_offset_to_box, from_offset_to_box_nms
from utils.eval_utils import show_box, show_multibox

# Use TF1 graph mode so the frozen GraphDef can be run in a Session
tf.disable_v2_behavior()

path = '/data1/gyx/QR/SSD_Tensorflow2.0-master/convert/pb/frozen_graph_ssd_multi.pb'
image_path = '/data1/gyx/QR/SSD_Tensorflow2.0-master/test_pic/20210330101245.bmp'
img_cv2 = cv2.imread(image_path)  # BGR image of shape (height, width, 3)
height, width, _ = np.shape(img_cv2)

#### Image normalization, method 1: cv2.dnn.blobFromImage ####
# blobFromImage computes (pixel - mean) * scalefactor, so mean is given in 0-255 pixel units.
# It returns an NCHW blob of shape (1, 3, 300, 300).
# blob = cv2.dnn.blobFromImage(img_cv2,
#                              scalefactor=1.0 / 255,
#                              size=(300, 300),
#                              mean=(0, 0, 0),
#                              swapRB=False,
#                              crop=False)
# blob = cv2.dnn.blobFromImage(img_cv2,
#                              scalefactor=1.0 / 255,
#                              size=(300, 300),
#                              mean=(103.939, 116.779, 123.68),  # BGR channel means
#                              swapRB=False,
#                              crop=False)
# The model expects NHWC, so transpose to (1, 300, 300, 3)
# input_image = np.transpose(blob, (0, 2, 3, 1))

#### Image normalization, method 2 (same preprocessing as used for training) ####
# Note: cv2.resize takes (width, height); tuple(c.input_shape[:2]) works here because the input is square
input_image = np.array([color_normalize(cv2.resize(copy.copy(img_cv2), tuple(c.input_shape[:2])))],
                       dtype=np.float32)
print(input_image.shape)  # (1, 300, 300, 3)
```
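Method 1 relies on two details: `cv2.dnn.blobFromImage` subtracts `mean` before multiplying by `scalefactor`, and it returns an NCHW array while this model was frozen with an NHWC input. The NumPy sketch below mimics that arithmetic on an image that is already 300×300 (no resizing, cropping or channel swap) to show both steps; `blob_like` is a hypothetical helper for illustration, not an OpenCV function.

```python
import numpy as np


def blob_like(image_hwc, scalefactor, mean):
    """Mimic the arithmetic of cv2.dnn.blobFromImage for an image already at the target size.

    Hypothetical helper: no resizing, cropping or swapRB handling.
    Returns a float32 NCHW array of shape (1, C, H, W).
    """
    x = image_hwc.astype(np.float32)
    x = (x - np.asarray(mean, dtype=np.float32)) * scalefactor  # subtract mean first, then scale
    return np.transpose(x, (2, 0, 1))[np.newaxis, ...]  # HWC -> CHW, then add the batch axis


# Dummy all-white BGR image standing in for the output of cv2.imread
image = np.full((300, 300, 3), 255, dtype=np.uint8)
mean = (103.939, 116.779, 123.68)

blob = blob_like(image, scalefactor=1.0 / 255, mean=mean)
nhwc = np.transpose(blob, (0, 2, 3, 1))  # NCHW -> NHWC, the layout the frozen graph expects
print(blob.shape, nhwc.shape)  # (1, 3, 300, 300) (1, 300, 300, 3)
```

Because the mean is subtracted before scaling, passing means that were already divided by 255 would shift each pixel by less than one gray level, which leaves the input essentially un-centered.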
```python
# Load the frozen GraphDef from disk
with tf.gfile.GFile(path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

with tf.Session() as sess:
    # Import the frozen graph into the session's (default) graph
    tf.import_graph_def(graph_def, name='')
    # Tensor names as printed when the graph was frozen in step 1
    out = sess.run([sess.graph.get_tensor_by_name('Identity:0'),
                    sess.graph.get_tensor_by_name('Identity_1:0')],
                   feed_dict={'x:0': input_image})
    # out = {'detection_classes': out[0], 'detection_boxes': out[1]}
    # print(out)

anchors = generate_anchors()
cls_pred, loc_pred = out[0], out[1]
print(cls_pred.shape, loc_pred.shape)  # (1, 8732, 5) (1, 8732, 4)
# print(type(loc_pred))  # <class 'numpy.ndarray'>
# cls_pred = tf.convert_to_tensor(np.array(cls_pred), dtype='float32')
# loc_pred = tf.convert_to_tensor(np.array(loc_pred), dtype='float32')
# print(cls_pred)
# print(loc_pred)

# Repository helper: turn the predicted offsets into boxes using the anchors,
# keeping only predictions whose score is above score_threshold
boxes, scores, labels = from_offset_to_box(loc_pred[0], cls_pred[0], anchors,
                                           anchor_belongs_to_one_class=True,
                                           score_threshold=0.1)
print(boxes, scores, labels)
```
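`from_offset_to_box` comes from the repository and its code is not shown here. As a rough idea of what such a decoder does, the sketch below uses a common SSD-style center-size encoding, assuming anchors are given as (cx, cy, w, h) and variances of 0.1 for the center and 0.2 for the size; the repository may use a different anchor format, units or variances, and score thresholding and class selection are left out. `decode_offsets` is a hypothetical name, not the repository's implementation.

```python
import numpy as np


def decode_offsets(offsets, anchors, variances=(0.1, 0.2)):
    """Decode SSD-style offsets (tx, ty, tw, th) into (left, top, right, bottom) boxes.

    Hypothetical helper. Assumes anchors are (cx, cy, w, h) and center-size encoding:
    cx = a_cx + tx * v0 * a_w,  w = a_w * exp(tw * v1)  (likewise for y and h).
    offsets, anchors: array-like of shape (N, 4).
    """
    t = np.asarray(offsets, dtype=np.float32).reshape(-1, 4)
    a = np.asarray(anchors, dtype=np.float32).reshape(-1, 4)
    cx = a[:, 0] + t[:, 0] * variances[0] * a[:, 2]
    cy = a[:, 1] + t[:, 1] * variances[0] * a[:, 3]
    w = a[:, 2] * np.exp(t[:, 2] * variances[1])
    h = a[:, 3] * np.exp(t[:, 3] * variances[1])
    return np.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], axis=1)


# One anchor centered at (150, 150) with size 60x60; zero offsets reproduce the anchor itself
anchor = [[150.0, 150.0, 60.0, 60.0]]
print(decode_offsets([[0.0, 0.0, 0.0, 0.0]], anchor))
```

Predicting offsets relative to anchors, with log-space width and height, keeps the regression targets small and roughly scale-invariant, which is the usual motivation for this encoding.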
```python
boxes_new = []
score_new = []
label_new = []
for box, score, label in zip(boxes, scores, labels):
    # Map each box from network-input coordinates back to the original image size
    box[0] = box[0] / c.input_shape[1] * width   # left
    box[1] = box[1] / c.input_shape[0] * height  # top
    box[2] = box[2] / c.input_shape[1] * width   # right
    box[3] = box[3] / c.input_shape[0] * height  # bottom
    print('image: {}\nclass: {}\nconfidence: {:.4f}\n'.format(image_path, c.class_list[label], score))
    boxes_new.append(box)
    score_new.append(score)
    label_new.append(c.class_list[label])

show_multibox(img_cv2, boxes_new, score_new, label_new)
```
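The loop above rescales each box from network-input coordinates back to the original image, using `c.input_shape[1]` for x and `c.input_shape[0]` for y, which implies `c.input_shape` is (height, width, channels). Assuming boxes form an (N, 4) array in (left, top, right, bottom) order, as the loop's comments indicate, the same mapping can be applied to all boxes at once with NumPy; `rescale_boxes` is a hypothetical helper, not part of the repository.

```python
import numpy as np


def rescale_boxes(boxes, input_hw, image_hw):
    """Map (left, top, right, bottom) boxes from network-input pixels to original-image pixels.

    Hypothetical helper. boxes: array-like of shape (N, 4);
    input_hw and image_hw: (height, width) of the network input and of the original image.
    """
    boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4)  # copy; the caller's data is untouched
    in_h, in_w = input_hw
    img_h, img_w = image_hw
    boxes[:, [0, 2]] *= img_w / in_w  # x coordinates: left, right
    boxes[:, [1, 3]] *= img_h / in_h  # y coordinates: top, bottom
    return boxes


# A box predicted on a 300x300 input, mapped back to an image 600 pixels high and 800 wide
scaled = rescale_boxes([[30, 60, 150, 300]], input_hw=(300, 300), image_hw=(600, 800))
print(scaled)
```

In the script this would be called as `rescale_boxes(boxes, c.input_shape[:2], (height, width))`. Working on a copy avoids the in-place writes to `box[...]` in the loop, which silently modify the `boxes` array returned by `from_offset_to_box`.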