Using Hugging Face Transformers for Text Classification
2022-06-26 08:26:00 【xuanningmeng】
Text classification
Text classification is a very common task in practical work, usually either multi-class or multi-label classification. For multi-label classification, see the blog post https://blog.csdn.net/weixin_42223207/article/details/115036283. This article uses Hugging Face Transformers to implement text classification, with tensorflow==2.4.0 as the framework. It covers the following:
- Data processing
- Model
- Model training
- Model prediction
Data processing
Use a BERT tokenizer to turn each word into token ids. Note that `tokenizer.encode(...).ids` below implies the tokenizer comes from the standalone `tokenizers` library rather than `transformers.BertTokenizer`. The code is as follows:
```python
import numpy as np


def create_inputs_targets(sentences, labels, max_len, tokenizer):
    dataset_dict = {
        "input_ids": [],
        "attention_mask": [],
        "labels": []
    }
    assert len(sentences) == len(labels)
    for i in range(len(sentences)):
        input_ids = []
        for idx, word in enumerate(sentences[i]):
            ids = tokenizer.encode(word, add_special_tokens=False)
            input_ids.extend(ids.ids)
        # truncate to leave room for the special tokens, then wrap the
        # sentence with '[CLS]' (id 101) and '[SEP]' (id 102)
        input_ids = input_ids[:max_len - 2]
        input_ids = [101] + input_ids + [102]
        attention_mask = [1] * len(input_ids)
        # pad up to max_len; '[PAD]' has id 0 in the vocab
        padding_len = max_len - len(input_ids)
        input_ids = input_ids + ([0] * padding_len)
        attention_mask = attention_mask + ([0] * padding_len)
        dataset_dict["input_ids"].append(input_ids)
        dataset_dict["attention_mask"].append(attention_mask)
        dataset_dict["labels"].append(labels[i])
    for key in dataset_dict:
        dataset_dict[key] = np.array(dataset_dict[key])
    x = [
        dataset_dict["input_ids"],
        dataset_dict["attention_mask"],
    ]
    y = dataset_dict["labels"]
    return x, y
```
Here we take the input_ids and attention_mask produced by the BERT tokenizer as the model inputs.
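For completeness, a minimal setup sketch for the tokenizer and a call to the function above, assuming a local `vocab.txt` from a Chinese BERT checkpoint (the path and the toy data are hypothetical):

```python
from tokenizers import BertWordPieceTokenizer

# hypothetical vocab path; point it at the vocab.txt of your BERT checkpoint
tokenizer = BertWordPieceTokenizer("bert-base-chinese/vocab.txt", lowercase=True)

sentences = [["我", "喜欢", "看", "体育", "新闻"]]  # pre-split words, one list per sentence
labels = [[1, 0, 0, 0, 0]]  # one-hot, matching the categorical_crossentropy loss used later
train_x, train_y = create_inputs_targets(sentences, labels, max_len=128, tokenizer=tokenizer)
```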
Model
We fine-tune a BERT model for classification. The code is as follows:
```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from transformers import TFBertModel


class BertTextClassifier(object):
    def __init__(self, bert_model_name, label_num):
        self.label_num = label_num
        self.bert_model_name = bert_model_name

    def get_model(self):
        bert = TFBertModel.from_pretrained(self.bert_model_name)
        input_ids = keras.Input(shape=(None,), dtype=tf.int32, name="input_ids")
        attention_mask = keras.Input(shape=(None,), dtype=tf.int32, name="attention_mask")
        # index 1 is the pooled [CLS] output, used as the sentence representation
        outputs = bert(input_ids, attention_mask=attention_mask)[1]
        cla_outputs = layers.Dense(self.label_num, activation='softmax')(outputs)
        model = keras.Model(
            inputs=[input_ids, attention_mask],
            outputs=[cla_outputs])
        return model


def create_model(bert_model_name, label_nums):
    model = BertTextClassifier(bert_model_name, label_nums).get_model()
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5)
    model.compile(optimizer=optimizer, loss='categorical_crossentropy',
                  metrics=['accuracy', tf.keras.metrics.Precision(),
                           tf.keras.metrics.Recall(),
                           tf.keras.metrics.AUC()])
    return model
```
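A quick usage sketch (the checkpoint name and label count are placeholders):

```python
model = create_model("bert-base-chinese", label_nums=5)
model.summary()
```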
Model training
Here we train the model with Keras, the high-level API of TensorFlow 2.x. The code is as follows:
```python
import os

model = create_model(args["bert_model_name"], len(tag2id))
# model.summary()
model.fit(train_x,
          train_y,
          epochs=epoch,
          verbose=1,
          batch_size=batch_size,
          validation_data=(dev_x, dev_y),
          validation_batch_size=batch_size)
# save the weights
model_path = os.path.join(args["output_path"], "classification_model.h5")
model.save_weights(model_path, overwrite=True)
# also export a SavedModel (pb) for TensorFlow Serving
tf.keras.models.save_model(model, args["pb_path"],
                           save_format="tf",
                           overwrite=True)
```
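One detail worth noting: `categorical_crossentropy` expects one-hot labels. If your labels are integer class ids, convert them before training (a sketch; `raw_train_labels` is a hypothetical list of integer ids):

```python
from tensorflow.keras.utils import to_categorical

# integer class ids -> one-hot vectors matching the softmax output
train_y = to_categorical(raw_train_labels, num_classes=len(tag2id))
```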
The results of training on the Sogou dataset are as follows:
```
              precision    recall  f1-score   support

      sports       1.00      1.00      1.00       209
      health       0.94      0.98      0.96       180
    military       0.99      0.99      0.99       208
   education       0.98      0.94      0.96       197
  automobile       0.98      0.99      0.99       202

    accuracy                           0.98       996
   macro avg       0.98      0.98      0.98       996
weighted avg       0.98      0.98      0.98       996
```
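A report in this format can be generated with scikit-learn (a sketch; `dev_x`, `dev_y`, and `id2tag` come from the surrounding code, and `dev_y` is assumed to be one-hot):

```python
import numpy as np
from sklearn.metrics import classification_report

pred = np.argmax(model.predict(dev_x), axis=1)
true = np.argmax(dev_y, axis=1)
print(classification_report(true, pred,
                            target_names=[id2tag[i] for i in range(len(id2tag))]))
```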
Model prediction
Process the data into the model's input format: use the tokenizer to build the input_ids and attention_mask features. The code is as follows:
```python
def create_infer_inputs(sentences, max_len, tokenizer):
    dataset_dict = {
        "input_ids": [],
        "attention_mask": [],
    }
    for i in range(len(sentences)):
        input_ids = []
        for idx, word in enumerate(sentences[i]):
            ids = tokenizer.encode(word, add_special_tokens=False)
            input_ids.extend(ids.ids)
        # truncate to leave room for the special tokens, then wrap the
        # sentence with '[CLS]' (id 101) and '[SEP]' (id 102)
        input_ids = input_ids[:max_len - 2]
        input_ids = [101] + input_ids + [102]
        attention_mask = [1] * len(input_ids)
        # pad up to max_len; '[PAD]' has id 0 in the vocab
        padding_len = max_len - len(input_ids)
        input_ids = input_ids + ([0] * padding_len)
        attention_mask = attention_mask + ([0] * padding_len)
        dataset_dict["input_ids"].append(input_ids)
        dataset_dict["attention_mask"].append(attention_mask)
    for key in dataset_dict:
        dataset_dict[key] = np.array(dataset_dict[key])
    x = [
        dataset_dict["input_ids"],
        dataset_dict["attention_mask"],
    ]
    return x
```
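Since the function returns a two-element list, it unpacks directly into the two feature arrays, e.g. (toy input; max_len assumed to match training):

```python
input_ids, attention_mask = create_infer_inputs([["体育", "新闻"]], max_len=128, tokenizer=tokenizer)
```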
Here a Flask service implements model prediction by forwarding requests to TensorFlow Serving. The code is as follows:
```python
import json

import numpy as np
import requests
from flask import Flask, request, jsonify

app = Flask(__name__)


@app.route("/classification", methods=['POST'])
def classification_predict():
    data = json.loads(request.get_data())
    sentence = data["context"]
    url = data["url"]
    input_ids, attention_mask = create_infer_inputs(sentence, max_len, tokenizer)
    # numpy arrays are not JSON serializable, so convert them to lists
    data = json.dumps({
        "signature_name": "serving_default",
        "inputs": {
            "input_ids": input_ids.tolist(),
            "attention_mask": attention_mask.tolist()}})
    headers = {"content-type": "application/json"}
    result = requests.post(url, data=data, headers=headers)
    if result.status_code == 200:
        result = json.loads(result.text)
        logits = np.array(result["outputs"])
        pred = np.argmax(logits, axis=1).tolist()
        pred_label = id2tag[pred[0]]
        return_result = {
            "code": 200,
            "context": sentence,
            "label": pred_label}
    else:
        return_result = {
            "code": result.status_code,
            "context": sentence,
            "label": None}
    return jsonify(return_result)
```
Here url points to a model service deployed with Docker + TensorFlow Serving. If there are any problems, corrections are welcome.
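For reference, a minimal deployment sketch (the export path and the model name `classification` are placeholders; bind-mount the SavedModel directory written to args["pb_path"] above):

```bash
docker run -p 8501:8501 \
  --mount type=bind,source=/path/to/pb_model,target=/models/classification \
  -e MODEL_NAME=classification -t tensorflow/serving
# the Flask service would then POST to
# http://localhost:8501/v1/models/classification:predict
```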