Realizing sequence annotation with transformers
2022-06-26 08:27:00 【xuanningmeng】
Using transformers for sequence labeling
I have recently been studying the transformers module and using it for sequence labeling. transformers can load most pre-trained models, including bert-base-uncased, bert-base-chinese, hfl/chinese-roberta-wwm-ext, hfl/chinese-roberta-wwm-ext-large, and so on. I have been loading different pre-trained models for sequence labeling and comparing their results; the results of the different pre-trained models on the same dataset will be added later. The environment is tensorflow==2.4.0 and transformers==4.2.0. This post mainly covers sequence labeling with transformers loading a roberta model; the pre-trained roberta model is Harbin Institute of Technology's hfl/chinese-roberta-wwm-ext. The post is organized as follows:
(1) Tokenizer
(2) Model input features
(3) Model training
(4) Model evaluation
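Before the individual steps, here is a minimal setup sketch of the imports the snippets below rely on; the original post omits its imports, and the local model directory path is my own assumption.

import os

import numpy as np
import tensorflow as tf
from tokenizers import BertWordPieceTokenizer
from transformers import TFBertForTokenClassification
from seqeval.metrics import classification_report, f1_score, precision_score, recall_score

# Assumed local directory holding the downloaded hfl/chinese-roberta-wwm-ext files
# (config.json, vocab.txt, pytorch_model.bin); adjust to your own layout.
model_path = "./chinese-roberta-wwm-ext"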
Tokenizer
Use BertWordPieceTokenizer (from the companion tokenizers package) to build the tokenizer for roberta.
def get_tokenizer(model_path):
    vocab_file = os.path.join(model_path, "vocab.txt")
    tokenizer = BertWordPieceTokenizer(vocab_file, lowercase=True)
    return tokenizer
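A quick sanity check (my own toy example, not from the original post) shows why the feature-building code below repeats a tag per wordpiece: a single input token can be split into several ids.

tokenizer = get_tokenizer(model_path)
encoding = tokenizer.encode("北京", add_special_tokens=False)
print(encoding.tokens)  # one wordpiece per Chinese character here, i.e. two tokens
print(encoding.ids)     # the matching vocab ids; a tag would be repeated len(encoding.ids) times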
Model input features
Use tokenizer.encode() to encode the text. There are many ways to write this; below is my approach, which also handles cases such as Korean characters where a single input token is split into several wordpieces:
def create_inputs_targets(sentences, tags, tag2id, max_len, tokenizer):
    dataset_dict = {
        "input_ids": [],
        "token_type_ids": [],
        "attention_mask": [],
        "tags": []
    }
    for sentence, tag in zip(sentences, tags):
        input_ids = []
        target_tags = []
        for idx, word in enumerate(sentence):
            ids = tokenizer.encode(word, add_special_tokens=False)
            input_ids.extend(ids.ids)
            # Counting the wordpiece ids avoids many errors: the tokenizer may return
            # several ids for one input token (e.g. a Korean character), and the
            # corresponding label has to be repeated once per id.
            num_tokens = len(ids.ids)
            target_tags.extend([tag[idx]] * num_tokens)
        # Truncate, leaving room for '[CLS]' and '[SEP]' around the sentence
        input_ids = input_ids[:max_len - 2]
        target_tags = target_tags[:max_len - 2]
        input_ids = [101] + input_ids + [102]  # 101 = [CLS], 102 = [SEP] in the BERT vocab
        # 'O' (id 16 in my tag2id) stands in for the tags of [CLS] and [SEP]
        target_tags = [tag2id['O']] + target_tags + [tag2id['O']]
        token_type_ids = [0] * len(input_ids)
        attention_mask = [1] * len(input_ids)
        padding_len = max_len - len(input_ids)
        # [PAD] has id 0 in the vocab
        input_ids = input_ids + ([0] * padding_len)
        attention_mask = attention_mask + ([0] * padding_len)
        token_type_ids = token_type_ids + ([0] * padding_len)
        # Padded positions also need a label; a dedicated label for [CLS]/[SEP]/[PAD]
        # could be used instead, here they are simply padded with 'O'
        target_tags = target_tags + ([tag2id['O']] * padding_len)
        dataset_dict["input_ids"].append(input_ids)
        dataset_dict["token_type_ids"].append(token_type_ids)
        dataset_dict["attention_mask"].append(attention_mask)
        dataset_dict["tags"].append(target_tags)
        assert len(target_tags) == max_len, f'{len(input_ids)}, {len(target_tags)}'
    for key in dataset_dict:
        dataset_dict[key] = np.array(dataset_dict[key])
    x = [
        dataset_dict["input_ids"],
        dataset_dict["token_type_ids"],
        dataset_dict["attention_mask"],
    ]
    y = dataset_dict["tags"]
    return x, y
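A small usage sketch (the toy BIO tag set, sentence, and max_len are my own assumptions, not from the original post) showing how the features would be built from pre-tokenized data; it reuses the tokenizer created above:

# Toy example: per-character sentences with aligned tag ids.
tag2id = {"O": 0, "B-PER": 1, "I-PER": 2, "B-LOC": 3, "I-LOC": 4}
sentences = [["张", "三", "在", "北", "京"]]
tags = [[1, 2, 0, 3, 4]]

train_x, train_y = create_inputs_targets(sentences, tags, tag2id, max_len=16, tokenizer=tokenizer)
# train_x = [input_ids, token_type_ids, attention_mask], each of shape (1, 16); train_y is (1, 16)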
Model training
Here TFBertForTokenClassification is used to load the roberta pre-trained model; TFRobertaForTokenClassification is not recommended, since hfl/chinese-roberta-wwm-ext uses the BERT model architecture and ships BERT-style weights. The code is as follows:
# from_pt=True converts the published PyTorch weights to TensorFlow;
# epoch, batch_size, dev_x, dev_y and args are defined elsewhere in my project.
model = TFBertForTokenClassification.from_pretrained(args["pretrain_model_path"],
                                                     from_pt=True,
                                                     num_labels=len(list(tag2id.keys())))
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5, epsilon=1e-08)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
model.summary()
model.fit(train_x,
          train_y,
          epochs=epoch,
          verbose=1,
          batch_size=batch_size,
          validation_data=(dev_x, dev_y),
          validation_batch_size=batch_size)
The learning rate here is relatively small; it can be tuned for the specific task.
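As one possible way to tune it (my own sketch, not from the original post), a decaying learning-rate schedule can be plugged into the same optimizer:

# Sketch only: an assumed alternative learning-rate setting with a polynomial decay schedule.
lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=3e-5,
    decay_steps=1000,
    end_learning_rate=1e-6)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule, epsilon=1e-08)
# Re-compile with the new optimizer before calling model.fit again.
model.compile(optimizer=optimizer, loss=loss, metrics=[metric])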
Model evaluation
The seqeval module is used to evaluate the model, computing its precision, recall, and F1 score. One convenient property of seqeval: when computing the metrics, true_label and predict_label can be passed as the actual label strings (e.g. 'B-PER', 'O') rather than ids. The evaluation code is as follows:
def model_evaluate_roberta(model, data, label, tag2id, batch_size, seq_len_list):
    id2tag = {value: key for key, value in tag2id.items()}
    pred_logits = model.predict(data, batch_size=batch_size)[0]
    # preds shape: [num_samples, max_len]
    preds = np.argmax(pred_logits, axis=2).tolist()
    assert len(preds) == len(seq_len_list)
    # get predicted labels
    predict_label = []
    target_label = []
    for i in range(len(preds)):
        pred = preds[i][1:]  # drop the [CLS] position
        temp = []
        true_label = label[i][:min(seq_len_list[i], len(pred))]
        for j in range(min(seq_len_list[i], len(pred))):
            temp.append(id2tag[pred[j]])
        assert len(temp) == len(true_label)
        target_label.append(true_label)
        predict_label.append(temp)
    # compute precision, recall, f1_score
    precision = precision_score(target_label, predict_label, average="macro", zero_division=0)
    recall = recall_score(target_label, predict_label, average="macro", zero_division=0)
    f1 = f1_score(target_label, predict_label, average="macro", zero_division=0)
    logger.info(classification_report(target_label, predict_label))
    return precision, recall, f1
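As a minimal illustration of the point above (a toy example of mine, not from the post), seqeval's metrics work directly on nested lists of label strings, one inner list per sentence:

y_true = [["B-PER", "I-PER", "O", "B-LOC", "I-LOC"]]
y_pred = [["B-PER", "I-PER", "O", "B-LOC", "O"]]
print(classification_report(y_true, y_pred))
print(f1_score(y_true, y_pred, average="macro", zero_division=0))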
While studying sequence labeling with transformers and roberta, I ran into a stretch where the model performed very poorly; it turned out there was a problem with the model, and the learning rate was also set too large. I will add the model's results on the MSRA dataset later. If you spot any mistakes, corrections are welcome.