Multi-label text classification with Hugging Face Transformers
2022-06-26 08:26:00 【xuanningmeng】
Multi-label classification
Text classification is one of the basic tasks of natural language processing. Most text classification is multi-class: each sample belongs to exactly one of several categories. In real work and projects you also run into multi-label text, where one sample can carry several labels at once. Here I use Hugging Face Transformers to implement multi-label text classification. The author's TensorFlow version is 2.4.0 and the transformers version is 4.2.0.
Data processing
Use the BertTokenizer from transformers to tokenize the data. The code is as follows:
import numpy as np
from tensorflow.keras.preprocessing.sequence import pad_sequences
from transformers import BertTokenizer


def get_model_data(data, labels, max_seq_len=128):
    tokenizer = BertTokenizer.from_pretrained("bert-base-chinese", do_lower_case=True)
    dataset_dict = {
        "input_ids": [],
        "attention_mask": [],
        "label": []
    }
    assert len(data) == len(labels)
    for i in range(len(data)):
        sentence = data[i]
        input_ids = tokenizer.encode(
            sentence,  # Sentence to encode.
            add_special_tokens=True,  # Add '[CLS]' and '[SEP]'
            max_length=max_seq_len,  # Truncate all sentences.
            truncation=True,  # Truncate explicitly so len(input_ids) <= max_seq_len.
        )
        sentence_length = len(input_ids)
        # Pad on the right with 0 up to max_seq_len.
        input_ids = pad_sequences([input_ids],
                                  maxlen=max_seq_len,
                                  dtype="long",
                                  value=0,
                                  truncating="post",
                                  padding="post")
        input_ids = input_ids.tolist()[0]
        # 1 for real tokens, 0 for padding positions.
        attention_mask = [1] * sentence_length + [0] * (max_seq_len - sentence_length)
        dataset_dict["input_ids"].append(input_ids)
        # dataset_dict["token_type_ids"].append(token_type_ids)
        dataset_dict["attention_mask"].append(attention_mask)
        dataset_dict["label"].append(labels[i])
    for key in dataset_dict:
        dataset_dict[key] = np.array(dataset_dict[key])
    x = [
        dataset_dict["input_ids"],
        dataset_dict["attention_mask"],
    ]
    y = dataset_dict["label"]
    return x, y
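Because the model is trained with BinaryCrossentropy over one output unit per label, each entry of `labels` needs to be a fixed-length 0/1 vector. The post does not show how these vectors are built. A minimal sketch, assuming each sample comes with a list of label names (`build_label_index`, `to_multi_hot` and the label names are hypothetical, for illustration only), might look like this:

```python
from typing import Dict, List, Sequence


def build_label_index(label_names: Sequence[str]) -> Dict[str, int]:
    """Map each label name to a fixed column index."""
    return {name: idx for idx, name in enumerate(label_names)}


def to_multi_hot(sample_labels: Sequence[str], label_index: Dict[str, int]) -> List[int]:
    """Return a 0/1 vector with a 1 in the column of every label the sample has."""
    vector = [0] * len(label_index)
    for name in sample_labels:
        vector[label_index[name]] = 1
    return vector


# Hypothetical label set and samples, for illustration only.
label_index = build_label_index(["sports", "finance", "tech"])
labels = [to_multi_hot(s, label_index) for s in (["sports"], ["finance", "tech"], [])]
print(labels)  # [[1, 0, 0], [0, 1, 1], [0, 0, 0]]
```

A sample with no labels becomes an all-zero vector, which is valid input for a multi-label model.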
Multi-label classification model
Build the multi-label classification model with Transformers. For multi-label classification the last layer of the model uses a sigmoid activation, so every label gets its own independent probability; multi-class classification uses softmax instead. The loss function for multi-label classification is BinaryCrossentropy. The code is as follows:
import tensorflow as tf
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model
from transformers import TFBertModel


class BertMultiClassifier(object):
    def __init__(self, bert_model_name, label_num):
        self.label_num = label_num
        self.bert_model_name = bert_model_name

    def get_model(self):
        bert = TFBertModel.from_pretrained(self.bert_model_name)
        input_ids = Input(shape=(None,), dtype=tf.int32, name="input_ids")
        attention_mask = Input(shape=(None,), dtype=tf.int32, name="attention_mask")
        # Index 1 of the BERT output is the pooled [CLS] representation.
        outputs = bert(input_ids, attention_mask=attention_mask)[1]
        # One sigmoid unit per label.
        cla_outputs = Dense(self.label_num, activation='sigmoid')(outputs)
        model = Model(
            inputs=[input_ids, attention_mask],
            outputs=[cla_outputs])
        return model


def create_model(bert_model_name, label_nums):
    model = BertMultiClassifier(bert_model_name, label_nums).get_model()
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5)
    # The model already applies sigmoid, so the loss receives probabilities, not logits.
    loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=False)
    model.compile(optimizer=optimizer, loss=loss_object,
                  metrics=['accuracy', tf.keras.metrics.Precision(),
                           tf.keras.metrics.Recall(),
                           tf.keras.metrics.AUC()])  # metrics=['accuracy']
    return model
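The difference between sigmoid and softmax described above can be checked numerically without TensorFlow. The following is a minimal pure-Python sketch, not the Keras implementation: the logits are made up, and the helper functions are written here for illustration only.

```python
import math
from typing import List


def sigmoid(logits: List[float]) -> List[float]:
    """Independent probability per label; the values need not sum to 1."""
    return [1.0 / (1.0 + math.exp(-z)) for z in logits]


def softmax(logits: List[float]) -> List[float]:
    """One distribution over all classes; the values always sum to 1."""
    largest = max(logits)
    exps = [math.exp(z - largest) for z in logits]  # subtract the max for numerical stability
    total = sum(exps)
    return [e / total for e in exps]


def binary_cross_entropy(y_true: List[int], y_prob: List[float], eps: float = 1e-7) -> float:
    """Mean over labels of -[y*log(p) + (1-y)*log(1-p)]."""
    losses = []
    for y, p in zip(y_true, y_prob):
        p = min(max(p, eps), 1.0 - eps)  # clip to avoid log(0)
        losses.append(-(y * math.log(p) + (1 - y) * math.log(1.0 - p)))
    return sum(losses) / len(losses)


logits = [2.0, 1.5, -3.0]  # made-up scores for three labels
probs = sigmoid(logits)    # the first two labels end up above 0.5
print([round(p, 3) for p in probs])
print(round(sum(softmax(logits)), 6))
print(binary_cross_entropy([1, 1, 0], probs))
```

With sigmoid, two labels can both receive a high probability for the same sample, which is exactly what multi-label classification needs. Softmax forces the classes to compete for a total probability of 1.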
Model training
Train the model with Keras, TensorFlow's high-level API. Save the weights in h5 format, and also export the model in pb (SavedModel) format. The training code is as follows:
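import os

# args, callbacks, train_x/train_y and val_x/val_y are not shown here.
model.fit(train_x, train_y, epochs=args["epoch"], verbose=1,
          batch_size=args["batch_size"],
          callbacks=callbacks,
          validation_data=(val_x, val_y),
          validation_batch_size=args["batch_size"])
# Save the weights only, in h5 format.
model_path = os.path.join("./output/model/", "mulclassifition.h5")
model.save_weights(model_path)
# Export the full model in SavedModel (pb) format for TensorFlow Serving.
tf.keras.models.save_model(model, args["pbmodel_path"], save_format="tf", overwrite=True)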
Loading the model for prediction
A trained model can be loaded directly to make predictions, or it can be deployed with TensorFlow Serving to provide an HTTP service. The author introduces both methods. The code that loads the model directly for prediction is as follows:
def predict(test_data, args, label_num):
    # test_steps_per_epoch = len(test_data) // args["batch_size"]
    # get_model_data builds its own tokenizer and expects labels,
    # so pass placeholder labels at inference time and ignore the returned y.
    placeholder_labels = [0] * len(test_data)
    testdata, _ = get_model_data(test_data, placeholder_labels, args["max_length"])
    print("testdata: ", testdata)
    model = create_model(args['bert_model_name'], label_num)
    model.load_weights("./output/model/mulclassifition.h5")
    # The model ends in a sigmoid, so these are probabilities rather than logits.
    pred_probs = model.predict(testdata, batch_size=args["batch_size"])
    # A label is predicted when its probability is at least 0.5.
    pred = np.where(pred_probs >= 0.5, 1, 0).tolist()
    return pred
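The thresholding step in `predict` yields 0/1 vectors. To report readable results, these are usually mapped back to label names. A minimal sketch of that step, assuming the column order of the labels is known (`decode_predictions` and the label names are hypothetical and not part of the original code):

```python
from typing import List, Sequence


def decode_predictions(probs: Sequence[Sequence[float]],
                       label_names: Sequence[str],
                       threshold: float = 0.5) -> List[List[str]]:
    """Return, for each sample, the names of all labels whose probability reaches the threshold."""
    decoded = []
    for row in probs:
        decoded.append([name for name, p in zip(label_names, row) if p >= threshold])
    return decoded


# Hypothetical label names and probabilities, for illustration only.
names = ["sports", "finance", "tech"]
print(decode_predictions([[0.91, 0.12, 0.50], [0.05, 0.20, 0.30]], names))
# [['sports', 'tech'], []]
```

Note that a sample can end up with no label at all when every probability is below the threshold. Whether to accept that or fall back to the highest-scoring label is a design choice that depends on the task.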
HTTP service
Use TensorFlow Serving and Flask to provide the HTTP service. The code is as follows:
import json

import numpy as np
import requests
from flask import Flask, jsonify, request

app = Flask(__name__)


# label_encoder and label (which map 0/1 predictions back to label names) are not shown here.
@app.route("/multiclassfier", methods=['POST'])
def multiclassifier_pred():
    data_para = json.loads(request.get_data())
    sentence = data_para["sent"]
    print("sentence: ", sentence)
    # get model input; "sent" is assumed to be a single string, the label is a placeholder
    test_x, _ = get_model_data([sentence], [0], 256)
    # NumPy arrays are not JSON serializable, so convert them to lists.
    input_ids = test_x[0].tolist()
    attention_mask = test_x[1].tolist()
    data = json.dumps({
        "signature_name": "serving_default",
        "inputs": {
            "input_ids": input_ids,
            "attention_mask": attention_mask}})
    headers = {
        "content-type": "application/json"}
    result = requests.post("http://ip:port/v1/models/multiclass:predict", data=data, headers=headers)
    if result.status_code == 200:
        result = json.loads(result.text)
        pred_probs = np.array(result["outputs"])
        pred = np.where(pred_probs >= 0.5, 1, 0).tolist()
        pred_encoder = label_encoder(pred, label)
        return_result = {
            "code": 200, "sent": sentence, "label": pred_encoder[0]}
        return jsonify(return_result)
    else:
        # No exception is active here, so return the TensorFlow Serving response body.
        return jsonify({
            "code": result.status_code,
            "message": result.text})
In the code, http://ip:port/v1/models/multiclass:predict is the prediction endpoint that TensorFlow Serving exposes after it loads the model. TensorFlow Serving itself is deployed with Docker.
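The request and response handling for that endpoint can be separated from Flask and tested on its own. The sketch below assumes the model has a single output tensor, so that the `outputs` field of the response is a plain nested list, as the code above also assumes. `build_predict_payload` and `parse_predict_response` are hypothetical helper names, and the token ids and response body are made up.

```python
import json
from typing import List


def build_predict_payload(input_ids: List[List[int]],
                          attention_mask: List[List[int]],
                          signature_name: str = "serving_default") -> str:
    """Serialize the two model inputs in the columnar ("inputs") request format."""
    if len(input_ids) != len(attention_mask):
        raise ValueError("input_ids and attention_mask must have the same batch size")
    return json.dumps({
        "signature_name": signature_name,
        "inputs": {"input_ids": input_ids, "attention_mask": attention_mask},
    })


def parse_predict_response(body: str, threshold: float = 0.5) -> List[List[int]]:
    """Turn the "outputs" probabilities of a response body into 0/1 predictions."""
    outputs = json.loads(body)["outputs"]
    return [[1 if p >= threshold else 0 for p in row] for row in outputs]


# Made-up token ids and response body, for illustration only.
payload = build_predict_payload([[101, 2769, 102, 0]], [[1, 1, 1, 0]])
print(payload)
print(parse_predict_response('{"outputs": [[0.83, 0.07, 0.61]]}'))  # [[1, 0, 1]]
```

Keeping these two functions free of network calls makes it possible to check the payload shape before pointing the service at a real TensorFlow Serving instance.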