BERT Chinese text classification: model training, inference and deployment
2022-06-12 06:18:00 【Caicaicaicaicaicai】
Article overview:
0. A brief introduction to BERT
- BERT (Bidirectional Encoder Representations from Transformers) is a pre-trained model proposed by google-research in October 2018. It set new state-of-the-art (SOTA) results on 11 different NLP tasks and became a milestone in the history of NLP.
- This article takes a hands-on approach and walks you through training a BERT Chinese text classification model and running inference with it.
[Figure: BERT architecture]
1. Training a BERT Chinese classification model
BERT training has two main stages: pre-training and fine-tuning.
What is pre-training?
BERT is a pre-trained model, so what does pre-training mean? Here is a brief explanation with an example.
Suppose we already have a training set for task A. We first pre-train a network on A, learn its parameters on task A and save them for later use. When a new task B comes along, we use the same network structure: at initialization, the layers shared with A are loaded with the parameters learned on A, while the other, higher-level parameters are initialized randomly, and the network is then trained on task B's training data. If the loaded parameters stay fixed during this training, they are called "frozen"; if they keep changing as task B is trained, this is called "fine-tuning", i.e. the parameters are adjusted to better suit the current task B.
Advantage: when task B has little training data, it is hard to train the network well from scratch, but starting from the parameters learned on A gives better results than training on B's data alone.
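To make the frozen vs. fine-tuning distinction concrete, here is a toy sketch in plain Python. The parameter names ("encoder", "head"), the numbers and the gradients are all made up for illustration and have nothing to do with BERT's real variables; the only point is which parameters a training step is allowed to change.

```python
import random


def init_task_b(pretrained, head_size, seed=0):
    """Copy the parameters learned on task A and add a randomly initialised head for task B."""
    rng = random.Random(seed)
    params = {name: list(values) for name, values in pretrained.items()}  # copy; task A stays intact
    params["head"] = [rng.uniform(-0.1, 0.1) for _ in range(head_size)]
    return params


def sgd_step(params, grads, freeze_pretrained, lr=0.1):
    """One plain gradient-descent step; frozen parameters are copied unchanged."""
    updated = {}
    for name, values in params.items():
        trainable = name == "head" or not freeze_pretrained
        if trainable:
            updated[name] = [v - lr * g for v, g in zip(values, grads[name])]
        else:
            updated[name] = list(values)
    return updated


pretrained = {"encoder": [0.5, -0.2, 0.8]}                # hypothetical parameters learned on task A
params = init_task_b(pretrained, head_size=2)
grads = {"encoder": [1.0, 1.0, 1.0], "head": [1.0, 1.0]}  # made-up gradients from task B data

frozen = sgd_step(params, grads, freeze_pretrained=True)
fine_tuned = sgd_step(params, grads, freeze_pretrained=False)
print(frozen["encoder"])      # [0.5, -0.2, 0.8]: the loaded parameters stay fixed
print(fine_tuned["encoder"])  # every value moved by -lr * grad
```

In both modes the new head is trained; the only difference is whether the loaded parameters move as well.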
Pre-training
The pre-training stage works like Word2Vec, ELMo and similar methods: the model is trained on some pre-training tasks over a large corpus. Pre-training requires huge computing resources; Google reports that pre-training one language model takes four days on 4 to 16 Cloud TPUs. Fortunately, Google has open-sourced many pre-trained models, including a Chinese one, so most NLP researchers never need to pre-train their own model from scratch.
Fine-tuning
The fine-tuning stage serves downstream tasks such as text classification, part-of-speech tagging and question answering: the pre-trained model is further trained on the task's data. BERT can be fine-tuned for different tasks without restructuring the model.
1 Download the BERT project code
https://github.com/google-research/bert
Code structure
- Pre-training
In the open-source code, the entry point for pre-training is run_pretraining.py.
- Fine-tuning
The fine-tuning entry points depend on the task: run_classifier.py and run_squad.py.
run_classifier.py trains text classification tasks,
and run_squad.py trains reading comprehension (SQuAD) tasks.
2 Download the Chinese pre-trained model
For Chinese, Google has released a base-size pre-trained model, BERT-Base, Chinese (12 layers, hidden size 768, 12 attention heads, hence the folder name chinese_L-12_H-768_A-12).
- Model download address:
https://github.com/google-research/bert/blob/master/multilingual.md
- On the download page, choose BERT-Base, Chinese and download it.

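Once the download is complete, unzip it into the same directory as run_classifier.py.
Model files

bert_model.ckpt: the checkpoint storing the model variables (on disk: bert_model.ckpt.index, bert_model.ckpt.data-00000-of-00001 and bert_model.ckpt.meta)
vocab.txt: the vocabulary used for Chinese text
bert_config.json: the model's configuration parameters, such as hidden size, number of layers and the maximum supported sequence length (max_position_embeddings)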
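Before training, it can help to check that the unzipped folder is complete. Below is a minimal sketch, assuming the folder layout described above and that vocab.txt holds one token per line (so its line count should equal vocab_size in bert_config.json); check_model_dir is a hypothetical helper, not part of the BERT code.

```python
import json
import os

# Files expected in the unzipped chinese_L-12_H-768_A-12 folder; the checkpoint
# prefix "bert_model.ckpt" is backed by an .index file (plus .data-* and .meta).
REQUIRED = ["bert_config.json", "vocab.txt", "bert_model.ckpt.index"]


def check_model_dir(model_dir):
    """Verify the model folder and return the parsed bert_config.json."""
    missing = [f for f in REQUIRED if not os.path.exists(os.path.join(model_dir, f))]
    if missing:
        raise FileNotFoundError("missing files: %s" % ", ".join(missing))
    with open(os.path.join(model_dir, "bert_config.json"), encoding="utf-8") as f:
        config = json.load(f)
    with open(os.path.join(model_dir, "vocab.txt"), encoding="utf-8") as f:
        vocab_lines = sum(1 for _ in f)  # one token per line
    if vocab_lines != config["vocab_size"]:
        raise ValueError("vocab.txt has %d entries but bert_config.json says %d"
                         % (vocab_lines, config["vocab_size"]))
    return config
```

After unzipping, calling check_model_dir("chinese_L-12_H-768_A-12") either raises with the missing file names or returns the configuration, whose max_position_embeddings is the upper bound for --max_seq_length.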
3 Create the Chinese training data set
To train a classifier on Chinese text, you need to write code that processes the training data into the form the BERT model expects. In the BERT code, processor classes are responsible for processing the model's input data.
Taking the classification task as an example, here is how to write a processor class so you can fine-tune on your own data set. In run_classifier.py you can see that Google provides processors for several public data sets, such as XnliProcessor, MnliProcessor, MrpcProcessor and ColaProcessor. They are good examples of how to write a processor for your own data set.
```python
# Add this class to run_classifier.py next to the built-in processors;
# os, csv, tokenization, DataProcessor, InputExample and FLAGS are already defined there.
class kedataProcessor(DataProcessor):
  """Processor for the custom 'kedata' classification data set."""

  def get_train_examples(self, data_dir):
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

  def get_dev_examples(self, data_dir):
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

  def get_test_examples(self, data_dir):
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

  def get_labels(self):
    # Return the categories of your classification task.
    # label.csv is tab-separated; the second column holds the label name.
    # os.path.join avoids the Windows-only '\label.csv' path; GBK also reads GB2312 files.
    label_path = os.path.join(FLAGS.data_dir, "label.csv")
    with open(label_path, "r", newline="", encoding="gbk") as labelf:
      return [line[1] for line in csv.reader(labelf, delimiter="\t")]

  def _create_examples(self, lines, set_type):
    """Creates examples for the training, dev and test sets."""
    examples = []
    for (i, line) in enumerate(lines):
      guid = "%s-%s" % (set_type, i)
      # Column 0 is the label, column 1 is the text.
      text_a = tokenization.convert_to_unicode(line[1])
      label = tokenization.convert_to_unicode(line[0])
      examples.append(
          InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
    return examples
```
A custom processor must inherit from DataProcessor and override get_labels, which returns the labels, and get_train_examples, get_dev_examples and get_test_examples, which return the input examples. The three get_*_examples functions are called in main() when FLAGS.do_train, FLAGS.do_eval and FLAGS.do_predict are set, respectively.
These three functions are almost identical; the only difference is the file each one reads.
Taking get_train_examples as an example, the function must return a list of InputExample objects. InputExample is a very simple class with only an initializer. Its guid argument distinguishes the examples and can be defined as 'train-%d' % i. text_a is a string and text_b is another (optional) string; the subsequent input processing (already included in the BERT code, so you don't have to write it) combines them into the form [CLS] text_a [SEP] text_b [SEP]. The last argument, label, is also a string, and its value must appear in the list returned by get_labels.
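A toy sketch of what that input processing produces. It follows the idea of convert_single_example in run_classifier.py (truncate, add [CLS]/[SEP], build segment ids, pad to max_seq_length) but is a simplified re-implementation, not the BERT code itself: the vocabulary is made up, and splitting Chinese text into single characters only approximates what FullTokenizer does for Chinese characters.

```python
def truncate_pair(tokens_a, tokens_b, max_length):
    """Drop tokens from the longer sequence until the pair fits."""
    a, b = list(tokens_a), list(tokens_b)
    while len(a) + len(b) > max_length:
        if len(a) > len(b):
            a.pop()
        else:
            b.pop()
    return a, b


def pack_pair(tokens_a, tokens_b, vocab, max_seq_length):
    """Build input_ids, input_mask and segment_ids for [CLS] a [SEP] (b [SEP])."""
    if tokens_b:
        # Leave room for [CLS] and two [SEP] tokens
        tokens_a, tokens_b = truncate_pair(tokens_a, tokens_b, max_seq_length - 3)
    else:
        tokens_a = list(tokens_a)[:max_seq_length - 2]  # room for [CLS] and [SEP]
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
    segment_ids = [0] * len(tokens)          # sentence A -> segment 0
    if tokens_b:
        tokens += tokens_b + ["[SEP]"]
        segment_ids += [1] * (len(tokens_b) + 1)  # sentence B -> segment 1
    input_ids = [vocab.get(t, vocab["[UNK]"]) for t in tokens]
    input_mask = [1] * len(input_ids)        # 1 = real token, 0 = padding
    padding = max_seq_length - len(input_ids)
    input_ids += [0] * padding
    input_mask += [0] * padding
    segment_ids += [0] * padding
    return tokens, input_ids, input_mask, segment_ids


# Toy vocabulary for illustration only (not the real vocab.txt ids)
vocab = {"[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3, "你": 4, "好": 5, "吗": 6}
tokens, ids, mask, segs = pack_pair(list("你好"), None, vocab, 6)
print(tokens)  # ['[CLS]', '你', '好', '[SEP]']
print(ids)     # [2, 4, 5, 3, 0, 0]
```

For the classification task in this article text_b is None, so every example is a single [CLS] text_a [SEP] sequence and all segment ids are 0.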
For example, in the classification task used in this article, each line of the training data in data_dir holds a label and a sentence separated by a tab, as below (the processor above reads this format from train.tsv, dev.tsv and test.tsv):
```text
sure	Um, yes, yes.
Job search status_Found a job	Uh, I already have a job, thank you.
Neutral	Oh, go ahead.
Ask for the post address	Well then, if I work for you, will it be arranged somewhere nearby? I think you have many of them.
sure	Can you hear me?
Neutral	Ah, go ahead, go ahead.
Intelligent assistant	The subscriber you dialed cannot answer your call right now. For an SMS notification please hang up, to leave a voice message please press 1, for a human operator please press 0.
I didn't hear anything clearly	What are you doing?
Busy	Oh, I'm on my way to an interview.
no	Um, we don't have it now, sorry.
```
label.csv (tab-separated: label id, label name and a third numeric column; get_labels reads only the label name in the second column):
```text
0	Wrong number	226
1	I called	127
2	Refuses the call	177
3	Phone number acquisition	19
4	Asks whether they can be heard	55
5	Non-owner	285
6	no	4477
7	welfare_other	15
8	welfare_Accommodation	47
9	welfare_Social insurance and housing fund	83
10	Position_Age requirements	58
```
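The processor above expects train.tsv, dev.tsv and test.tsv plus label.csv in data_dir. If your labelled data starts out as a list of (label, text) pairs, a sketch like the following could produce those files. write_dataset, the 80/10/10 split and the fixed random seed are assumptions for illustration; the label.csv written here has only the id and name columns, since get_labels reads just the second column. The TSV files are written as UTF-8, and label.csv as GBK to match the label-reading code in this article.

```python
import os
import random


def write_dataset(rows, out_dir, dev_ratio=0.1, test_ratio=0.1, seed=42):
    """Split (label, text) pairs into train/dev/test.tsv and write label.csv."""
    # Collapse tabs/newlines inside the text so every example stays on one TSV line
    rows = [(label, " ".join(text.split())) for label, text in rows]
    random.Random(seed).shuffle(rows)
    n_dev = int(len(rows) * dev_ratio)
    n_test = int(len(rows) * test_ratio)
    splits = {
        "dev.tsv": rows[:n_dev],
        "test.tsv": rows[n_dev:n_dev + n_test],
        "train.tsv": rows[n_dev + n_test:],
    }
    os.makedirs(out_dir, exist_ok=True)
    for name, part in splits.items():
        with open(os.path.join(out_dir, name), "w", encoding="utf-8") as f:
            for label, text in part:
                f.write("%s\t%s\n" % (label, text))  # column 0: label, column 1: text
    # A fixed label order matters: a label's position in this list becomes its class id
    labels = sorted({label for label, _ in rows})
    with open(os.path.join(out_dir, "label.csv"), "w", encoding="gbk") as f:
        for i, label in enumerate(labels):
            f.write("%d\t%s\n" % (i, label))
    return {name: len(part) for name, part in splits.items()}
```

Because the label list is saved once and reused at inference time, the class ids stay identical between training and the deployed model.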
Register the processor in the processors dictionary
After writing the processor, add the new class to the processors dictionary in main(); you can then select it with the --task_name run parameter.
```python
  # In main() of run_classifier.py; keys must be lower-case
  processors = {
      "cola": ColaProcessor,
      "mnli": MnliProcessor,
      "mrpc": MrpcProcessor,
      "xnli": XnliProcessor,
      "kedata": kedataProcessor,
  }
```
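Why the key matters: main() lower-cases --task_name and looks it up in this dictionary. A simplified, self-contained sketch of that lookup; the processor classes here are empty stand-ins rather than the real ones, and get_processor is a hypothetical helper name.

```python
class DataProcessor(object):
    """Simplified stand-in for run_classifier.DataProcessor."""


class ColaProcessor(DataProcessor):
    pass


class kedataProcessor(DataProcessor):
    pass


processors = {
    "cola": ColaProcessor,
    "kedata": kedataProcessor,
}


def get_processor(task_name):
    # --task_name is lower-cased before the lookup, so dictionary keys must be lower-case
    task_name = task_name.lower()
    if task_name not in processors:
        raise ValueError("Task not found: %s" % task_name)
    return processors[task_name]()


print(type(get_processor("KeData")).__name__)  # kedataProcessor
```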
Run run_classifier.py
You can then run run_classifier.py directly to train the model. Some parameters must be specified at run time; a fairly complete set of run parameters is:
```shell
# Directory of the downloaded Chinese BERT pre-trained model
export BERT_BASE_DIR=/path/to/bert/chinese_L-12_H-768_A-12
# Directory of your data set (train.tsv, dev.tsv, test.tsv, label.csv)
export MY_DATASET=/path/to/kedata

# --task_name is the key of your processor in the processors dictionary;
# --max_seq_length and the following flags are model/training parameters;
# --output_dir is where checkpoints and results are written.
# (Comments cannot follow a trailing backslash, so they are kept up here.)
python run_classifier.py \
  --task_name=kedata \
  --do_train=true \
  --do_eval=true \
  --do_predict=true \
  --data_dir=$MY_DATASET \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
  --max_seq_length=128 \
  --train_batch_size=32 \
  --learning_rate=5e-5 \
  --num_train_epochs=2.0 \
  --output_dir=/tmp/selfsim_output/
```
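To see how these flags translate into training length: run_classifier.py derives the number of optimizer steps from the number of training examples, train_batch_size and num_train_epochs, and warms the learning rate up over the first warmup_proportion (default 0.1) of those steps. A sketch of that arithmetic, assuming a hypothetical training set of 10,000 examples:

```python
def schedule(num_train_examples, train_batch_size=32, num_train_epochs=2.0, warmup_proportion=0.1):
    """Number of training steps and warm-up steps, computed as in run_classifier.py."""
    num_train_steps = int(num_train_examples / train_batch_size * num_train_epochs)
    num_warmup_steps = int(num_train_steps * warmup_proportion)
    return num_train_steps, num_warmup_steps


print(schedule(10000))  # (625, 62)
```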
2. BERT model inference
1. TensorFlow inference

- Problem: every call to Estimator.predict loads the model again, which makes it useless for production services.
- Solution:
  - Use a Python generator (built with yield) that makes the program "believe" there is an endless sequence of inputs to predict;
  - Wrap the generator with tf.data.Dataset.from_generator, declaring the structure and types of the data;
  - Store the input in an instance variable (self.text) so that it is "fed" into the generator through self, which keeps the data flowing indefinitely;
  - Provide a close mechanism so the generator can be stopped.
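The same trick can be shown without TensorFlow. In this toy sketch, FakeEstimator is a hypothetical stand-in whose predict() counts how often the "model" is loaded; Predictor keeps one long-lived generator, so the load happens once no matter how many times predict is called.

```python
class FakeEstimator(object):
    """Toy stand-in for an Estimator: each predict() call 'loads the model' once."""

    def __init__(self):
        self.loads = 0

    def predict(self, input_fn):
        self.loads += 1  # the expensive part: build the graph, restore the checkpoint
        for batch in input_fn():
            yield [len(text) for text in batch]  # pretend prediction: text length


class Predictor(object):
    def __init__(self, estimator):
        self.estimator = estimator
        self.inputs = None
        self.closed = False
        self.predictions = None

    def _generator(self):
        # Never-ending: every pull hands the latest inputs to the model
        while not self.closed:
            yield self.inputs

    def predict(self, texts):
        self.inputs = texts
        if self.predictions is None:  # create the prediction loop only once
            self.predictions = self.estimator.predict(self._generator)
        return next(self.predictions)

    def close(self):
        self.closed = True  # the generator ends on its next pull


estimator = FakeEstimator()
predictor = Predictor(estimator)
print(predictor.predict(["ab", "c"]))  # [2, 1]
print(predictor.predict(["xyz"]))      # [3]
print(estimator.loads)                 # 1
predictor.close()
```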
```python
import csv
import os

import numpy as np
import tensorflow as tf  # TensorFlow 1.x, as required by google-research/bert

from modeling import BertConfig
from run_classifier import FLAGS, InputExample, convert_single_example, model_fn_builder
from tokenization import FullTokenizer, validate_case_matches_checkpoint


class model(object):
    def __init__(self, labelfile, vocabfile, bert_config_file, init_checkpoint, ckptdir, max_seq_length):
        self.label = self.loadLabel(labelfile)
        self.closed = False
        self.first_run = True
        self.bert_config_file = bert_config_file
        self.ckptdir = ckptdir
        self.tokenizer = FullTokenizer(vocab_file=vocabfile, do_lower_case=True)
        self.init_checkpoint = init_checkpoint
        self.seq_length = max_seq_length
        self.text = None         # current batch of texts, read by the generator
        self.predictions = None  # long-lived prediction iterator
        self.estimator = self.get_estimator()

    def loadLabel(self, labelfile):
        # label.csv is tab-separated; the second column is the label name
        with open(labelfile, 'r', newline='', encoding='gbk') as labelf:
            return [line[1] for line in csv.reader(labelf, delimiter="\t")]

    def get_estimator(self):
        validate_case_matches_checkpoint(True, self.init_checkpoint)
        bert_config = BertConfig.from_json_file(self.bert_config_file)  # load the BERT configuration
        if self.seq_length > bert_config.max_position_embeddings:  # sanity-check the sequence length
            raise ValueError(
                "Cannot use sequence length %d because the BERT model "
                "was only trained up to sequence length %d" %
                (self.seq_length, bert_config.max_position_embeddings))
        run_config = tf.estimator.RunConfig(
            model_dir=self.ckptdir,  # directory containing the fine-tuned checkpoint
            save_checkpoints_steps=FLAGS.save_checkpoints_steps)
        # model_fn_builder returns the model_fn the Estimator uses to build the graph
        model_fn = model_fn_builder(
            bert_config=bert_config,
            num_labels=len(self.label),
            init_checkpoint=self.init_checkpoint,
            learning_rate=FLAGS.learning_rate,
            num_train_steps=0,
            num_warmup_steps=0,
            use_tpu=FLAGS.use_tpu,
            use_one_hot_embeddings=FLAGS.use_tpu)
        return tf.estimator.Estimator(
            model_fn=model_fn,
            config=run_config,
            warm_start_from=self.init_checkpoint)  # warm start from the pre-trained weights

    def get_feature(self, index, text):
        # The label is only a placeholder; it is not used for prediction
        example = InputExample(f"text_{index}", text, None, self.label[0])
        feature = convert_single_example(index, example, self.label, self.seq_length, self.tokenizer)
        return feature.input_ids, feature.input_mask, feature.segment_ids, feature.label_id

    def create_generator(self):
        """Endless generator: every pull turns the current self.text into one batch."""
        while not self.closed:
            features = (self.get_feature(*f) for f in enumerate(self.text))
            yield dict(zip(("input_ids", "input_mask", "segment_ids", "label_ids"), zip(*features)))

    def input_fn_builder(self):
        """input_fn for prediction: data comes from the generator, not from files."""
        return tf.data.Dataset.from_generator(
            self.create_generator,
            output_types={'input_ids': tf.int32,
                          'input_mask': tf.int32,
                          'segment_ids': tf.int32,
                          'label_ids': tf.int32},
            output_shapes={'input_ids': (None, None),
                           'input_mask': (None, None),
                           'segment_ids': (None, None),
                           'label_ids': (None,)})

    def predict(self, text):
        """Classify a list of strings and return one label per string."""
        self.text = text
        if self.first_run:
            # Built once: the graph is created and the checkpoint restored only here.
            # yield_single_examples=False returns whole batches, so a call with
            # several texts cannot leave stale results for the next call.
            self.predictions = self.estimator.predict(
                input_fn=self.input_fn_builder, yield_single_examples=False)
            self.first_run = False
        probabilities = next(self.predictions)["probabilities"]  # shape [len(text), num_labels]
        return [self.label[i] for i in np.argmax(probabilities, axis=1)]

    def close(self):
        self.closed = True  # lets the generator, and thus the prediction loop, finish


# Paths are relative to the directory containing this script
pardir = os.path.dirname(os.path.realpath(__file__))
labelfile = os.path.join(pardir, 'ckpt/label.csv')
init_checkpoint = os.path.join(pardir, 'chinese_L-12_H-768_A-12/bert_model.ckpt')
vocabfile = os.path.join(pardir, 'chinese_L-12_H-768_A-12/vocab.txt')
bert_config_file = os.path.join(pardir, 'chinese_L-12_H-768_A-12/bert_config.json')
ckptdir = os.path.join(pardir, 'ckpt/')
max_seq_length = 128


def getModel():
    bert = model(labelfile, vocabfile, bert_config_file, init_checkpoint, ckptdir, max_seq_length)
    bert.predict([""])  # warm-up call: loads the model once
    return bert


if __name__ == "__main__":
    bert = getModel()
    for i in range(1000):
        labels = bert.predict(["dial the wrong number"])
    bert.close()
```
2. ONNX Runtime inference
ONNX Runtime is a high-performance inference engine for machine learning models. It is compatible with PyTorch, TensorFlow and many other frameworks and tools that support the ONNX standard. ONNX Runtime has an open, extensible architecture and uses built-in graph optimizations and hardware acceleration on CPUs, GPUs and edge devices to optimize and speed up inference. It fits easily into an existing technology stack, since it runs on Linux, Windows, Mac and Android and provides convenient APIs for Python, C#, C++, C and Java.
To reduce BERT's inference time when deploying it to a server, we can use ONNX Runtime to accelerate inference.
1. Convert the checkpoint to SavedModel format

```python
import csv

import tensorflow as tf  # TensorFlow 1.x (google-research/bert)

from modeling import BertConfig
from run_classifier import FLAGS, model_fn_builder
from tokenization import FullTokenizer, validate_case_matches_checkpoint


class Fast(object):
    def __init__(self, labelfile, vocabfile, bert_config_file, init_checkpoint, ckptdir, max_seq_length=128):
        self.label = self.loadLabel(labelfile)
        self.bert_config_file = bert_config_file
        self.ckptdir = ckptdir
        self.tokenizer = FullTokenizer(vocab_file=vocabfile, do_lower_case=True)
        self.init_checkpoint = init_checkpoint
        self.seq_length = max_seq_length  # must equal the max_seq_length used for training
        self.estimator = self.get_estimator()

    def loadLabel(self, labelfile):
        # label.csv is tab-separated; the second column is the label name
        with open(labelfile, 'r', newline='', encoding='gbk') as labelf:
            return [line[1] for line in csv.reader(labelf, delimiter="\t")]

    def get_estimator(self):
        validate_case_matches_checkpoint(True, self.init_checkpoint)
        bert_config = BertConfig.from_json_file(self.bert_config_file)  # load the BERT configuration
        if self.seq_length > bert_config.max_position_embeddings:  # sanity-check the sequence length
            raise ValueError(
                "Cannot use sequence length %d because the BERT model "
                "was only trained up to sequence length %d" %
                (self.seq_length, bert_config.max_position_embeddings))
        run_config = tf.estimator.RunConfig(
            model_dir=self.ckptdir,  # directory containing the fine-tuned checkpoint
            save_checkpoints_steps=FLAGS.save_checkpoints_steps)
        # model_fn_builder returns the model_fn the Estimator uses to build the graph
        model_fn = model_fn_builder(
            bert_config=bert_config,
            num_labels=len(self.label),
            init_checkpoint=self.init_checkpoint,
            learning_rate=FLAGS.learning_rate,
            num_train_steps=0,
            num_warmup_steps=0,
            use_tpu=FLAGS.use_tpu,
            use_one_hot_embeddings=FLAGS.use_tpu)
        return tf.estimator.Estimator(
            model_fn=model_fn,
            config=run_config,
            warm_start_from=self.init_checkpoint)  # warm start from the pre-trained weights

    def serving_input_fn(self):
        # The placeholder names become the input names of the SavedModel (and later of the ONNX model).
        # label_ids is needed because the classifier graph also computes the loss.
        receiver_tensors = {
            'input_ids': tf.compat.v1.placeholder(dtype=tf.int64, shape=[None, self.seq_length], name='input_ids'),
            'input_mask': tf.compat.v1.placeholder(dtype=tf.int64, shape=[None, self.seq_length], name='input_mask'),
            'segment_ids': tf.compat.v1.placeholder(dtype=tf.int64, shape=[None, self.seq_length], name='segment_ids'),
            'label_ids': tf.compat.v1.placeholder(dtype=tf.int64, shape=[None], name='label_ids'),
        }
        return tf.estimator.export.ServingInputReceiver(features=receiver_tensors, receiver_tensors=receiver_tensors)

    def transModel(self, export_dir='./savemodel'):
        # Exports the latest checkpoint in ckptdir to <export_dir>/<timestamp>/ and returns that path
        return self.estimator.export_saved_model(export_dir, self.serving_input_fn)


if __name__ == "__main__":
    labelfile = './ckpt/label.csv'
    init_checkpoint = './chinese_L-12_H-768_A-12/bert_model.ckpt'
    vocabfile = './chinese_L-12_H-768_A-12/vocab.txt'
    bert_config_file = './chinese_L-12_H-768_A-12/bert_config.json'
    ckptdir = './ckpt/'
    model = Fast(labelfile, vocabfile, bert_config_file, init_checkpoint, ckptdir)
    print(model.transModel())  # e.g. ./savemodel/<timestamp>
```
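2. Convert the SavedModel to ONNX format
```python
import os
import subprocess
import sys

# pbdir is the timestamped sub-directory that export_saved_model created under ./savemodel
pbdir = "1631247382"
onnxname = "model.onnx"
os.makedirs("./onnx", exist_ok=True)
subprocess.run(
    [sys.executable, "-m", "tf2onnx.convert",
     "--saved-model", os.path.join("./savemodel", pbdir),
     "--output", os.path.join("./onnx", onnxname)],
    check=True)  # raise an error if the conversion fails
```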
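The timestamp 1631247382 is specific to the author's export. A small sketch that picks the newest export automatically and builds the same tf2onnx command as an argument list; latest_export_dir and tf2onnx_command are hypothetical helper names, and the only CLI flags used are --saved-model and --output from the command above.

```python
import os
import sys


def latest_export_dir(base_dir="./savemodel"):
    """Pick the newest <timestamp> directory written by export_saved_model."""
    stamps = [d for d in os.listdir(base_dir)
              if d.isdigit() and os.path.isdir(os.path.join(base_dir, d))]  # skips temp-* dirs
    if not stamps:
        raise FileNotFoundError("no SavedModel export found in %s" % base_dir)
    return os.path.join(base_dir, max(stamps, key=int))


def tf2onnx_command(saved_model_dir, output_path):
    """Argument list for the tf2onnx CLI, suitable for subprocess.run(..., check=True)."""
    return [sys.executable, "-m", "tf2onnx.convert",
            "--saved-model", saved_model_dir, "--output", output_path]
```

Usage would be subprocess.run(tf2onnx_command(latest_export_dir(), "./onnx/model.onnx"), check=True), so re-exporting the model does not require editing the script.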
3. Run the ONNX model with ONNX Runtime

```python
import csv
import os
import time

import numpy as np
import onnxruntime as ort

from run_classifier import InputExample, convert_single_example
from tokenization import FullTokenizer


class model(object):
    def __init__(self, vocab_file, labelfile, modelfile, max_seq_length):
        self.tokenizer = FullTokenizer(vocab_file=vocab_file, do_lower_case=True)
        self.seq_length = max_seq_length
        self.label = self.loadLabel(labelfile)
        so = ort.SessionOptions()
        so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        # Prefer the GPU; ONNX Runtime falls back to the CPU provider if CUDA is unavailable
        self.model = ort.InferenceSession(
            modelfile, sess_options=so,
            providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])

    def loadLabel(self, labelfile):
        # label.csv is tab-separated; the second column is the label name
        with open(labelfile, 'r', newline='', encoding='gbk') as labelf:
            return [line[1] for line in csv.reader(labelf, delimiter="\t")]

    def get_feature(self, index, text):
        # The label is only a placeholder; it is not used for prediction
        example = InputExample(f"text_{index}", text, None, self.label[0])
        feature = convert_single_example(index, example, self.label, self.seq_length, self.tokenizer)
        return feature.input_ids, feature.input_mask, feature.segment_ids, feature.label_id

    def predict(self, text):
        """Classify a list of strings and return one label per string."""
        features = [self.get_feature(i, t) for i, t in enumerate(text)]
        input_ids, input_mask, segment_ids, label_ids = zip(*features)
        # Input names follow the placeholders of the exported SavedModel
        data = {
            "input_ids:0": np.array(input_ids, dtype=np.int64),
            "input_mask:0": np.array(input_mask, dtype=np.int64),
            "label_ids:0": np.array(label_ids, dtype=np.int64),
            "segment_ids:0": np.array(segment_ids, dtype=np.int64),
        }
        # "loss/Softmax:0" is the softmax output of the classifier: shape [len(text), num_labels]
        probabilities = self.model.run(output_names=["loss/Softmax:0"], input_feed=data)[0]
        return [self.label[i] for i in np.argmax(probabilities, axis=1)]


# Paths are relative to the directory containing this script
pardir = os.path.dirname(os.path.realpath(__file__))
datadir = os.path.join(pardir, 'zhaopin')
labelfile = os.path.join(datadir, 'label.csv')
modelfile = os.path.join(datadir, 'model.onnx')
vocabfile = os.path.join(pardir, 'vocab.txt')
max_seq_length = 128


def getModel():
    return model(vocabfile, labelfile, modelfile, max_seq_length)


if __name__ == "__main__":
    bert = getModel()
    runs = 1000
    start = time.time()
    for i in range(runs):
        bert.predict(["I've already called"])
    print("average latency: %.2f ms" % ((time.time() - start) * 1000 / runs))
```
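The predict method above returns only the most likely label. If you also need confidence scores, for example to route low-confidence answers to a fallback, a small numpy helper over the softmax output could look like this. top_k is a hypothetical name, and probabilities is assumed to be the [batch, num_labels] array that self.model.run(...)[0] returns.

```python
import numpy as np


def top_k(probabilities, labels, k=3):
    """Return the k most likely (label, score) pairs for each row of a softmax output."""
    probabilities = np.asarray(probabilities)
    order = np.argsort(-probabilities, axis=1)[:, :k]  # indices sorted by descending probability
    return [[(labels[j], float(row[j])) for j in idx]
            for row, idx in zip(probabilities, order)]


labels = ["Wrong number", "no", "Busy"]  # hypothetical label list
print(top_k([[0.1, 0.7, 0.2]], labels, k=2))  # [[('no', 0.7), ('Busy', 0.2)]]
```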