
Transformers DataCollatorWithPadding class

2022-06-26 14:32:00 Live up to your youth

Constructor

DataCollatorWithPadding(tokenizer: PreTrainedTokenizerBase,
                        padding: typing.Union[bool, str, transformers.utils.generic.PaddingStrategy] = True,
                        max_length: typing.Optional[int] = None,
                        pad_to_multiple_of: typing.Optional[int] = None,
                        return_tensors: str = 'pt')

In transformers, a data collator packs individual elements of a dataset into a batch. DataCollatorWithPadding is one of these data collators: while it builds each batch, it dynamically pads the inputs so that every sequence in that batch has the same length.
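To make "dynamic padding" concrete, here is a minimal pure-Python sketch of the idea, not the library's implementation: each batch is padded only to the length of its own longest sequence, and an attention mask marks real tokens (1) versus padding (0). The helper name pad_to_longest is hypothetical, and the pad id 0 and right-side padding are assumptions matching the bert-base-uncased tokenizer defaults.

```python
from typing import Dict, List

PAD_TOKEN_ID = 0  # assumption: [PAD] has id 0 in the bert-base-uncased vocabulary


def pad_to_longest(
    features: List[Dict[str, List[int]]], pad_id: int = PAD_TOKEN_ID
) -> Dict[str, List[List[int]]]:
    """Hypothetical helper: right-pad every input_ids list to the longest one in this batch."""
    longest = max(len(f["input_ids"]) for f in features)
    batch: Dict[str, List[List[int]]] = {"input_ids": [], "attention_mask": []}
    for f in features:
        ids = f["input_ids"]
        n_pad = longest - len(ids)
        batch["input_ids"].append(ids + [pad_id] * n_pad)
        # 1 marks a real token, 0 marks a padding position
        batch["attention_mask"].append([1] * len(ids) + [0] * n_pad)
    return batch


# 101 and 102 are the [CLS] and [SEP] ids in bert-base-uncased
short_batch = pad_to_longest([{"input_ids": [101, 7592, 102]}, {"input_ids": [101, 102]}])
print(short_batch["input_ids"])  # [[101, 7592, 102], [101, 102, 0]]
```

Because each call only looks at its own batch, two batches drawn from the same dataset can come out with different widths, which is what distinguishes this from padding the whole dataset up front.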

The parameter tokenizer is the tokenizer used to encode (and pad) the inputs. The parameter padding can be a bool: True means pad (equivalent to "longest") and False means do not pad. It can also be a string naming the padding strategy: "longest" pads to the longest sequence in the batch, "max_length" pads to the length given by the max_length parameter (or to the model's maximum input length if max_length is not set), and "do_not_pad" means no padding. The parameter pad_to_multiple_of, if set, rounds the padded length up to a multiple of that value. The parameter return_tensors sets the type of the returned data: "pt" for PyTorch tensors, "tf" for TensorFlow tensors, or "np" for NumPy arrays.
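The rules above for choosing the padded length can be sketched as follows. This is a simplified illustration, not the transformers source: target_length is a hypothetical helper that accepts only bool and string values (no PaddingStrategy enum), and unlike the library it does not fall back to the model's maximum length when padding="max_length" is used without max_length.

```python
from typing import List, Optional, Union


def target_length(
    lengths: List[int],
    padding: Union[bool, str] = True,
    max_length: Optional[int] = None,
    pad_to_multiple_of: Optional[int] = None,
) -> Optional[int]:
    """Hypothetical helper: the length a batch is padded to, or None for no padding."""
    if padding is False or padding == "do_not_pad":
        return None
    if padding is True or padding == "longest":
        length = max(lengths)  # longest sequence in this batch
    elif padding == "max_length":
        if max_length is None:
            # Simplification: the real library would use the model's maximum
            # input length here; this sketch just requires the argument.
            raise ValueError("max_length is required for padding='max_length'")
        length = max_length
    else:
        raise ValueError(f"unknown padding strategy: {padding!r}")
    if pad_to_multiple_of is not None and length % pad_to_multiple_of != 0:
        # round up to the next multiple of pad_to_multiple_of
        length = (length // pad_to_multiple_of + 1) * pad_to_multiple_of
    return length


print(target_length([5, 9, 7], padding="longest", pad_to_multiple_of=8))  # 16
```

For example, with pad_to_multiple_of=8 a batch whose longest sequence has 9 tokens is padded to 16 rather than 9.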

Usage example

>>> import transformers
>>> import datasets
>>> tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
>>> dataset = datasets.load_dataset("glue", "cola", split="train")
>>> dataset = dataset.map(lambda data: tokenizer(data["sentence"], padding=True), batched=True)
>>> dataset
Dataset({
    features: ['sentence', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 8551
})
>>> data_collator = transformers.DataCollatorWithPadding(
...     tokenizer, padding="max_length", max_length=12, return_tensors="tf")
>>> dataset = dataset.to_tf_dataset(columns=["label", "input_ids"], batch_size=16, shuffle=False, collate_fn=data_collator)
>>> dataset
<PrefetchDataset element_spec={'input_ids': TensorSpec(shape=(None, None), dtype=tf.int64, name=None), 'attention_mask': TensorSpec(shape=(None, None), dtype=tf.int64, name=None), 'labels': TensorSpec(shape=(None,), dtype=tf.int64, name=None)}>
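The element_spec above contains 'labels' rather than 'label', plus an 'attention_mask' that was not among the requested columns. The sketch below imitates, in plain Python, what the collator configured above does to one batch: it builds an attention mask, renames label to labels, and pads short rows up to max_length. collate_max_length is a hypothetical stand-in, not the library class, and pad id 0 is an assumption for bert-base-uncased.

```python
from typing import Any, Dict, List


def collate_max_length(
    features: List[Dict[str, Any]], max_length: int = 12, pad_id: int = 0
) -> Dict[str, List[Any]]:
    """Hypothetical stand-in for DataCollatorWithPadding(padding="max_length", max_length=12)."""
    batch: Dict[str, List[Any]] = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        ids = list(f["input_ids"])
        # pad only rows shorter than max_length; longer rows are NOT truncated
        n_pad = max(0, max_length - len(ids))
        batch["input_ids"].append(ids + [pad_id] * n_pad)
        batch["attention_mask"].append([1] * len(ids) + [0] * n_pad)
        # the collator exposes "label" under the key "labels"
        batch["labels"].append(f["label"])
    return batch


out = collate_max_length([
    {"input_ids": [101, 5, 102], "label": 1},
    {"input_ids": list(range(15)), "label": 0},
])
print([len(row) for row in out["input_ids"]])  # [12, 15]
```

The sketch also shows a caveat: padding="max_length" only pads, it does not truncate, so a row longer than 12 tokens keeps its length. If rows in one batch have different lengths above 12, they cannot be stacked into a single tensor; passing truncation=True and max_length=12 to the tokenizer in the map step keeps every row within the limit.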

Copyright notice
This article was written by [Live up to your youth]. Please include a link to the original when reposting. Thanks.
https://yzsam.com/2022/177/202206261327159298.html