Sequence labeling with transformers
2022-06-26 08:27:00 【xuanningmeng】
I have been studying the transformers module recently and used it to implement sequence labeling. transformers can load most pretrained models, including bert-base-uncased, bert-base-chinese, hfl/chinese-roberta-wwm-ext, hfl/chinese-roberta-wwm-ext-large and so on. I have recently tried loading different pretrained models for sequence labeling and comparing their results; the results of the different pretrained models on the same dataset will be added later. The environment is tensorflow==2.4.0 and transformers==4.2.0. This post mainly covers loading roberta with transformers for sequence labeling; the roberta pretrained model is Harbin Institute of Technology's hfl/chinese-roberta-wwm-ext. The post is organized as follows (the imports assumed by the code are sketched right after this outline):
(1) Tokenizer
(2) Model input features
(3) Model training
(4) Model evaluation
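For reference, here is a minimal sketch of the imports the snippets below appear to assume; the exact module paths are my own assumption and may need adjusting for your environment (note that BertWordPieceTokenizer ships in the separate huggingface tokenizers package):

import os
import logging

import numpy as np
import tensorflow as tf
from tokenizers import BertWordPieceTokenizer
from transformers import TFBertForTokenClassification
from seqeval.metrics import precision_score, recall_score, f1_score, classification_report

logger = logging.getLogger(__name__)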
Tokenizer
Use BertWordPieceTokenizer (from the huggingface tokenizers library) to obtain roberta's tokenizer.
def get_tokenizer(model_path):
    vocab_file = os.path.join(model_path, "vocab.txt")
    tokenizer = BertWordPieceTokenizer(vocab_file, lowercase=True)
    return tokenizer
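As a quick sanity check, the tokenizer can be exercised roughly like this (the local model directory is a hypothetical path of mine, not from the original post):

# Hypothetical local directory holding vocab.txt from hfl/chinese-roberta-wwm-ext
tokenizer = get_tokenizer("./chinese-roberta-wwm-ext")
encoding = tokenizer.encode("今天天气不错", add_special_tokens=False)
print(encoding.tokens)  # per-character WordPiece tokens
print(encoding.ids)     # the corresponding vocab ids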
Model input features
Use tokenizer.encode() to encode the text. There are many ways to write this step; the tokenization approach below is mine and also handles Korean characters, where a single word can be split into several tokens:
def create_inputs_targets(sentences, tags, tag2id, max_len, tokenizer):
    dataset_dict = {
        "input_ids": [],
        "token_type_ids": [],
        "attention_mask": [],
        "tags": []
    }
    for sentence, tag in zip(sentences, tags):
        input_ids = []
        target_tags = []
        for idx, word in enumerate(sentence):
            ids = tokenizer.encode(word, add_special_tokens=False)
            input_ids.extend(ids.ids)
            # Checking the number of sub-tokens avoids many errors: the tokenizer can
            # return several tokens for a single word (e.g. a Korean character), and the
            # corresponding label must be repeated for each of them
            num_tokens = len(ids.ids)
            target_tags.extend([tag[idx]] * num_tokens)
        # Pad / truncate, then add '[CLS]' and '[SEP]' around the sentence
        input_ids = input_ids[:max_len - 2]
        target_tags = target_tags[:max_len - 2]
        input_ids = [101] + input_ids + [102]
        # In my tag2id, 'O' maps to 16; whether [CLS]/[SEP] should get their own
        # entries in tag2id instead of 'O' is debatable
        target_tags = [tag2id['O']] + target_tags + [tag2id['O']]
        token_type_ids = [0] * len(input_ids)
        attention_mask = [1] * len(input_ids)
        padding_len = max_len - len(input_ids)
        # [PAD] is encoded as 0 in the vocab
        input_ids = input_ids + ([0] * padding_len)
        attention_mask = attention_mask + ([0] * padding_len)
        token_type_ids = token_type_ids + ([0] * padding_len)
        # The labels appended here correspond to [SEP]/[CLS] or padding positions;
        # they could be dedicated labels, or simply 'O' as done here
        target_tags = target_tags + ([tag2id['O']] * padding_len)
        dataset_dict["input_ids"].append(input_ids)
        dataset_dict["token_type_ids"].append(token_type_ids)
        dataset_dict["attention_mask"].append(attention_mask)
        dataset_dict["tags"].append(target_tags)
        assert len(target_tags) == max_len, f'{len(input_ids)}, {len(target_tags)}'
    for key in dataset_dict:
        dataset_dict[key] = np.array(dataset_dict[key])
    x = [
        dataset_dict["input_ids"],
        dataset_dict["token_type_ids"],
        dataset_dict["attention_mask"],
    ]
    y = dataset_dict["tags"]
    return x, y
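To illustrate the expected input format, here is a hedged sketch of how the function might be called; the toy sentences, tag set and max_len below are my own placeholders, not values from the original post. Note that the tags passed in must already be ids, since the function mixes them with tag2id['O']:

# Each sentence is a list of characters/words with a parallel list of tag ids
train_sentences = [["我", "在", "哈", "尔", "滨"], ["他", "去", "北", "京"]]
train_tag_strs = [["O", "O", "B-LOC", "I-LOC", "I-LOC"], ["O", "O", "B-LOC", "I-LOC"]]

tag2id = {"O": 0, "B-LOC": 1, "I-LOC": 2}
train_tags = [[tag2id[t] for t in sent] for sent in train_tag_strs]

max_len = 32
tokenizer = get_tokenizer("./chinese-roberta-wwm-ext")  # hypothetical local path
train_x, train_y = create_inputs_targets(train_sentences, train_tags, tag2id, max_len, tokenizer)
# train_x = [input_ids, token_type_ids, attention_mask], each of shape (num_sentences, max_len)
# train_y has shape (num_sentences, max_len); these feed model.fit() in the next section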
Model training
TFBertForTokenClassification is used here to load the roberta pretrained model; TFRobertaForTokenClassification is not recommended, since hfl/chinese-roberta-wwm-ext uses the BERT architecture and vocabulary. The code is as follows:
model = TFBertForTokenClassification.from_pretrained(args["pretrain_model_path"],
                                                      from_pt=True,
                                                      num_labels=len(list(tag2id.keys())))
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5, epsilon=1e-08)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
model.summary()
model.fit(train_x,
          train_y,
          epochs=epoch,
          verbose=1,
          batch_size=batch_size,
          validation_data=(dev_x, dev_y),
          validation_batch_size=batch_size)
The learning rate here is relatively small; it can be tuned according to the specific task.
Model evaluation
The seqeval module is used to evaluate the model, computing precision, recall and F1. One convenient property of seqeval: when computing the metrics, true_label and predict_label can be the actual string labels rather than ids. The evaluation code is as follows:
def model_evaluate_roberta(model, data, label, tag2id, batch_size, seq_len_list):
    id2tag = {value: key for key, value in tag2id.items()}
    pred_logits = model.predict(data, batch_size=batch_size)[0]
    # preds shape: [batch_size, max_len]
    preds = np.argmax(pred_logits, axis=2).tolist()
    assert len(preds) == len(seq_len_list)
    # get predicted labels
    predict_label = []
    target_label = []
    for i in range(len(preds)):
        pred = preds[i][1:]
        temp = []
        true_label = label[i][:min(seq_len_list[i], len(pred))]
        for j in range(min(seq_len_list[i], len(pred))):
            temp.append(id2tag[pred[j]])
        assert len(temp) == len(true_label)
        target_label.append(true_label)
        predict_label.append(temp)
    # compute precision, recall, f1_score
    precision = precision_score(target_label, predict_label, average="macro", zero_division=0)
    recall = recall_score(target_label, predict_label, average="macro", zero_division=0)
    f1 = f1_score(target_label, predict_label, average="macro", zero_division=0)
    logger.info(classification_report(target_label, predict_label))
    return precision, recall, f1
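For completeness, a hedged sketch of how the evaluation might be wired up; dev_sentences, dev_tag_strs, dev_x and batch_size are my own placeholders. seq_len_list simply records the original (untruncated) length of each dev sentence, and label is passed as string tags, which seqeval consumes directly:

# dev_x built with create_inputs_targets; dev_tag_strs holds the string tags per sentence
seq_len_list = [len(sent) for sent in dev_sentences]
precision, recall, f1 = model_evaluate_roberta(model,
                                               dev_x,
                                               dev_tag_strs,
                                               tag2id,
                                               batch_size=32,
                                               seq_len_list=seq_len_list)
print(f"precision={precision:.4f}, recall={recall:.4f}, f1={f1:.4f}")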
While experimenting with loading roberta through transformers for sequence labeling, there was a stretch where the model performed very poorly; it turned out there was a problem with the model, and the learning rate had also been set too large. Results on the MSRA dataset will be added later. If anything here is wrong, corrections are welcome.