Code Execution Logic of Attention-ocr-Chinese-Version-master
2022-06-09 03:55:00 【陳壯實的編程生活】
Table of Contents
- 1. Execution logic
- 2. Data flow through the model
- **Step 1: split the image into four parts along the width. Input: [batch_size, H, W, channels]; output: 4 × [batch_size, H, W/4, channels]**
- **Step 2: feed each view through InceptionV3, then combine the results. Input: 4 × [batch_size, H, W/4, channels]; output: [batch_size, H, W, N], where N is the number of features**
- **Step 3: positional encoding. Input: [batch_size, H, W, N]; output: [batch_size, H, W, N+H+W]**
- **Step 4: RNN decoder with attention. Input: [batch_size, H, W, N+H+W]; output: [batch_size, seq_length, num_char_classes]**
- **Step 5: predict characters. Input: [batch_size, seq_length, num_char_classes]; output: predicted_chars, chars_log_prob, predicted_scores**
- **Step 6: predict text**
1. Execution logic
2. Data flow through the model
Data obtained from the tfrecord files:
images: [batch_size, height, width, channels]
labels_one_hot: [batch_size, seq_length, num_char_classes], e.g. [32, 37, 5642]
**Step 1: split the image into four parts along the width. Input: [batch_size, H, W, channels]; output: 4 × [batch_size, H, W/4, channels]**
```python
views = tf.split(
    value=images, num_or_size_splits=self._params.num_views, axis=2)
# Split into views along the width (axis=2). For example, if a tensor is
# 20*30*40, then tf.split(my_tensor, 2, 0) returns two 10*30*40 tensors.
logging.debug('Views=%d single view: %s', len(views), views[0])
```
Because each image in the original code contains four views arranged horizontally, the split here is along the width, into equal parts according to the number of views. A minimal sketch of the effect is shown below.
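A quick illustration with dummy shapes (my own sketch, TF 2.x eager mode for brevity; the repo itself runs on TF 1.x):

```python
import tensorflow as tf

# Dummy batch of images: [batch_size, H, W, channels].
images = tf.zeros([32, 40, 120, 3])
# Split along the width into num_views = 4 equal views.
views = tf.split(images, num_or_size_splits=4, axis=2)
print(len(views), views[0].shape)  # 4 (32, 40, 30, 3)
```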
**Step 2: feed each view through InceptionV3, then combine the results. Input: 4 × [batch_size, H, W/4, channels]; output: [batch_size, H, W, N], where N is the number of features**
```python
nets = [
    # conv_tower_fn: run InceptionV3 convolutions on one view; returns
    # [batch_size, OH, OW, N], where N is the number of features.
    self.conv_tower_fn(v, is_training, reuse=(i != 0))
    for i, v in enumerate(views)
]
```
This operation corresponds to the red-boxed region of the figure. A minimal sketch of the idea follows.
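In the sketch below (dummy shapes, TF 2.x eager; my own simplification, not the repo's code), a single shared Conv2D stands in for the repo's InceptionV3 conv tower, and the per-view feature maps are concatenated back along the width:

```python
import tensorflow as tf

# A single shared convolution stands in for InceptionV3 (conv_tower_fn);
# reusing one layer instance plays the role of reuse=(i != 0).
conv = tf.keras.layers.Conv2D(filters=16, kernel_size=3, padding='same')

views = tf.split(tf.zeros([32, 40, 120, 3]), 4, axis=2)  # 4 x [32, 40, 30, 3]
nets = [conv(v) for v in views]                          # 4 x [32, 40, 30, 16]
net = tf.concat(nets, axis=2)                            # [32, 40, 120, 16]
print(net.shape)
```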
**Step 3: positional encoding. Input: [batch_size, H, W, N]; output: [batch_size, H, W, N+H+W]**
Positional encoding is the core of this paper. I spent a lot of time trying to understand it, and once I did, it turned out to be very simple. A diagram:

The corresponding diagram from the paper:
The encoding order:
(1) Encode the position (height, width) of every pixel of the image.
(2) The encoding is one-hot: e.g., if the height has two possible positions, the height code occupies two bits; if the width has three possible positions, it occupies three bits.
(3) As in the figure, pixel 1 sits at position (0, 0): its height index is encoded as 1,0 and its width index as 1,0,0, so the code 10100 is appended to the end of its feature vector. Likewise, pixel 14 sits at position (1, 2): its height index is encoded as 01 and its width index as 001, so 01001 is appended to its end.
Food for thought: some personal ideas
Since spatial positions can be encoded this way, then if our data is time-series data, could we encode the time steps by some rule, feed them into the model alongside the data, and see whether accuracy improves? (A small sketch of this idea follows the demo code below.)
I also implemented the paper's positional-encoding code and built a small example to aid understanding. The code:
```python
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''

import tensorflow as tf  # TF 1.x-style graph/session code, as in the repo
import tf_slim as slim

if __name__ == '__main__':
    # Dummy feature map of shape [batch_size=2, h=2, w=3, channels=4].
    data = [
        [
            [[1, 2, 3, 4], [4, 5, 6, 7], [8, 9, 10, 11]],
            [[7, 8, 9, 10], [10, 11, 12, 13], [14, 15, 16, 17]],
        ],
        [
            [[13, 14, 15, 16], [16, 17, 17, 18], [19, 20, 21, 22]],
            [[19, 20, 21, 19], [22, 23, 24, 20], [21, 22, 23, 24]],
        ],
    ]
    net = tf.constant(data, tf.float32)
    print("net.shape = ", net.shape)
    batch_size, h, w, _ = net.shape.as_list()

    # Coordinate grids: x holds each pixel's width index, y its height index.
    x, y = tf.meshgrid(tf.range(w), tf.range(h))
    with tf.Session() as sess:
        print("sess.run(x_tensor): {}".format(sess.run(x)))
        print("sess.run(y_tensor): {}".format(sess.run(y)))

    # One-hot encode the coordinates: w_loc is [h, w, w], h_loc is [h, w, h].
    w_loc = slim.one_hot_encoding(x, num_classes=w)
    h_loc = slim.one_hot_encoding(y, num_classes=h)
    with tf.Session() as sess:
        print("sess.run(w_loc_tensor): {}".format(sess.run(w_loc)))
        print("sess.run(h_loc_tensor): {}".format(sess.run(h_loc)))

    # Concatenate height code then width code, tile over the batch, and
    # append to the feature channels: [batch_size, h, w, channels + h + w].
    loc = tf.concat([h_loc, w_loc], 2)
    loc = tf.tile(tf.expand_dims(loc, 0), [batch_size, 1, 1, 1])
    res = tf.concat([net, loc], 3)
    print("res.shape = ", res.shape)
    with tf.Session() as sess:
        print("sess.run(loc_tensor): {}".format(sess.run(loc)))
        print("sess.run(res_tensor): {}".format(sess.run(res)))
    print("Done")
```
**Step 4: RNN decoder with attention. Input: [batch_size, H, W, N+H+W]; output: [batch_size, seq_length, num_char_classes]**
From this point on, the work is done mostly by library calls, as follows:
(1) sequence_layers.py

```python
lstm_cell = tf.contrib.rnn.LSTMCell(
    self._mparams.num_lstm_units,  # 256
    # Default False; True enables peephole connections, i.e. the gate layers
    # also receive the cell state as input, on top of the basic LSTM.
    use_peepholes=False,
    # 10; the cell state is clipped to this value before the output step.
    cell_clip=self._mparams.lstm_state_clip_value,
    # If True, accepted and returned states are (c, h) 2-tuples, where c is
    # the current cell state and h is the output at the current time step.
    state_is_tuple=True,
    # (Optional) initializer for the weight and projection matrices.
    initializer=orthogonal_initializer)
lstm_outputs, _ = self.unroll_cell(
    decoder_inputs=decoder_inputs,
    initial_state=lstm_cell.zero_state(self._batch_size, tf.float32),
    loop_function=self.get_input,
    cell=lstm_cell)
```
(2) sequence_layers.py

```python
def unroll_cell(self, decoder_inputs, initial_state, loop_function, cell):
    return tf.contrib.legacy_seq2seq.attention_decoder(
        decoder_inputs=decoder_inputs,
        initial_state=initial_state,
        attention_states=self._net,
        cell=cell,
        loop_function=self.get_input)
```
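The heavy lifting happens inside attention_decoder. Here is a simplified sketch in its spirit: my own re-implementation of additive (Bahdanau-style) attention with dummy shapes, in TF 2.x eager mode; it is not the repo's code:

```python
import tensorflow as tf

batch, num_positions, feat, units = 4, 10, 16, 8
attention_states = tf.random.normal([batch, num_positions, feat])  # encoder features
query = tf.random.normal([batch, units])                           # decoder state

W1 = tf.keras.layers.Dense(units)  # projects the encoder features
W2 = tf.keras.layers.Dense(units)  # projects the decoder state
v = tf.keras.layers.Dense(1)       # scoring vector

# score[b, i] = v^T tanh(W1 * states[b, i] + W2 * query[b])
scores = v(tf.tanh(W1(attention_states) + W2(query)[:, tf.newaxis, :]))  # [batch, P, 1]
weights = tf.nn.softmax(scores, axis=1)                      # attention weights
context = tf.reduce_sum(weights * attention_states, axis=1)  # [batch, feat]
print(context.shape)
```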
For details on the attention and LSTM computations, see:
- Self-Attention
- RNN
- Some notes on LSTM

Note: orthogonal initialization is important for RNNs.
**Step 5: predict characters. Input: [batch_size, seq_length, num_char_classes]; output: predicted_chars, chars_log_prob, predicted_scores**
(1) model.py

```python
# Prediction.
predicted_chars, chars_log_prob, predicted_scores = (
    self.char_predictions(chars_logit))
# predicted_chars: the predicted characters, an int32 tensor of shape
#     [batch_size x seq_length].
# chars_log_prob: log probabilities of all characters, a float tensor of
#     shape [batch_size, seq_length, num_char_classes].
# predicted_scores: the corresponding confidence scores, a float tensor of
#     shape [batch_size x seq_length].
```
(2) model.py

Prediction is done via softmax:

```python
def char_predictions(self, chars_logit):
    """Returns confidence scores (softmax values) for predicted characters.

    Args:
      chars_logit: chars logits, a tensor with shape
        [batch_size x seq_length x num_char_classes].

    Returns:
      A tuple (ids, log_prob, scores), where:
        ids - predicted characters, an int32 tensor with shape
          [batch_size x seq_length];
        log_prob - log probabilities of all characters, a float tensor with
          shape [batch_size, seq_length, num_char_classes];
        scores - corresponding confidence scores for characters, a float
          tensor with shape [batch_size x seq_length].
    """
    log_prob = utils.logits_to_log_prob(chars_logit)
    ids = tf.to_int32(tf.argmax(log_prob, axis=2), name='predicted_chars')
    mask = tf.cast(
        slim.one_hot_encoding(ids, self._params.num_char_classes), tf.bool)
    all_scores = tf.nn.softmax(chars_logit)
    selected_scores = tf.boolean_mask(all_scores, mask, name='char_scores')
    scores = tf.reshape(selected_scores, shape=(-1, self._params.seq_length))
    return ids, log_prob, scores
```
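A tiny numeric check (my own illustration, TF 2.x eager) of what char_predictions computes: the boolean mask picks out, at each sequence step, the softmax probability of the argmax character, so scores is simply the per-step maximum softmax value:

```python
import tensorflow as tf

chars_logit = tf.constant([[[2.0, 0.5, 0.1],
                            [0.2, 1.5, 0.3]]])       # [1, seq_length=2, classes=3]
log_prob = tf.nn.log_softmax(chars_logit)
ids = tf.cast(tf.argmax(log_prob, axis=2), tf.int32)  # [[0, 1]]
mask = tf.cast(tf.one_hot(ids, depth=3), tf.bool)
all_scores = tf.nn.softmax(chars_logit)
scores = tf.reshape(tf.boolean_mask(all_scores, mask), (-1, 2))
print(ids.numpy(), scores.numpy())  # scores == per-step max softmax
```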
**Step 6: predict text**
(1) model.py

```python
if self._charset:
    character_mapper = CharsetMapper(self._charset)
    # Returns the text corresponding to the predicted characters.
    predicted_text = character_mapper.get_text(predicted_chars)
else:
    predicted_text = tf.constant([])
```
(2) model.py

```python
def get_text(self, ids):
    """Returns a string corresponding to a sequence of character ids.

    Args:
      ids: a tensor with shape [batch_size, max_sequence_length].
    """
    return tf.reduce_join(
        self.table.lookup(tf.to_int64(ids)), reduction_indices=1)
```
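Here self.table is a lookup table built from the charset file. A minimal sketch of the same id-to-text mapping (my own illustration using the TF 2.x lookup and strings APIs; the repo uses the older tf.contrib lookup and tf.reduce_join):

```python
import tensorflow as tf

# Map character ids to characters via a static hash table.
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant([0, 1, 2], tf.int64),
        values=tf.constant(['a', 'b', 'c'])),
    default_value='')

ids = tf.constant([[0, 1, 2], [2, 2, 0]], tf.int64)
# Join the looked-up characters along the sequence axis.
text = tf.strings.reduce_join(table.lookup(ids), axis=1)
print(text.numpy())  # [b'abc' b'cca']
```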
With that, the execution logic of this codebase is fully clear.