You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/examples/community/roberta/preprocessing/get_mask.py

265 lines
9.4 KiB

import collections
import logging
import random
import jieba
jieba.setLogLevel(logging.CRITICAL)
import re
import mask
import numpy as np
# Fill value appended to input ids, segment ids and the input mask when a
# sequence is shorter than the configured maximum length.
PAD = 0
# Record of one masked position: the token index and the original token (label).
MaskedLMInstance = collections.namedtuple("MaskedLMInstance", ["index", "label"])
def map_to_numpy(data):
    """Return ``data`` as a NumPy array (delegates to ``np.asarray``)."""
    array = np.asarray(data)
    return array
class PreTrainingDataset:
    """Build masked-language-model training instances with whole-word masking.

    A sentence (a list of tokens) is re-segmented with ``jieba`` so that the
    characters belonging to one Chinese word are marked with a ``"##"``
    prefix, then roughly ``masked_lm_prob`` of the tokens are selected for
    prediction, word by word.
    """

    def __init__(
        self,
        tokenizer,
        max_seq_length,
        backend="python",
        max_predictions_per_seq: int = 80,
        do_whole_word_mask: bool = True,
    ):
        """Store the masking configuration.

        :param tokenizer: object exposing ``vocab`` (token -> id mapping),
            ``tokenize`` and ``convert_tokens_to_ids``.
        :param max_seq_length: sequences shorter than this are padded up to it.
        :param backend: ``"python"`` or ``"c++"``; selects the masking
            implementation used by :meth:`create_training_instance`.
        :param max_predictions_per_seq: upper bound on masked positions.
        :param do_whole_word_mask: when true, ``"##"``-prefixed tokens are
            masked together with the token they continue.
        """
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.masked_lm_prob = 0.15
        self.backend = backend
        self.do_whole_word_mask = do_whole_word_mask
        self.max_predictions_per_seq = max_predictions_per_seq
        self.vocab_words = list(tokenizer.vocab.keys())
        # Matches a single CJK unified ideograph.
        self.rec = re.compile("[\u4E00-\u9FA5]")
        # Matches the "##" continuation marker followed by a CJK ideograph.
        self.whole_rec = re.compile("##[\u4E00-\u9FA5]")
        # NOTE(review): the four attributes below are not read by any method
        # of this class; they are kept because external code may access them.
        self.mlm_p = 0.15
        self.mlm_mask_p = 0.8
        self.mlm_tamper_p = 0.05
        self.mlm_maintain_p = 0.1

    def tokenize(self, doc):
        """Tokenize every element of ``doc`` and return the list of results."""
        return [self.tokenizer.tokenize(d) for d in doc]

    def create_training_instance(self, instance):
        """Turn one tokenized sentence into padded model inputs.

        :param instance: sequence of tokens for a single sentence.
        :return: list of five NumPy arrays: input ids, input mask, segment
            ids, masked-LM labels (``-1`` where nothing is predicted) and a
            one-element array holding the constant ``is_next`` label ``1``.
        :raises ValueError: if ``self.backend`` is neither ``"c++"`` nor
            ``"python"``.
        """
        # Fail fast: previously an unknown backend surfaced later as an
        # UnboundLocalError on ``output_tokens``.
        if self.backend not in ("c++", "python"):
            raise ValueError(f"Unsupported backend {self.backend!r}; expected 'c++' or 'python'")

        is_next = 1
        tokens_a = self.get_new_segment(instance)
        assert len(tokens_a) == len(instance)

        # Single-segment input: [CLS] sentence [SEP], all with segment id 0.
        tokens = ["[CLS]", *tokens_a, "[SEP]"]
        original_tokens = ["[CLS]", *instance, "[SEP]"]
        segment_ids = [0] * len(tokens)

        # Get masked LM predictions.
        if self.backend == "c++":
            output_tokens, masked_lm_output = mask.create_whole_masked_lm_predictions(
                tokens,
                original_tokens,
                self.vocab_words,
                self.tokenizer.vocab,
                self.max_predictions_per_seq,
                self.masked_lm_prob,
            )
        else:
            output_tokens, masked_lm_output = self.create_whole_masked_lm_predictions(tokens)

        # Convert to ids and pad every field up to max_seq_length.
        # Sequences that are already longer are left untruncated.
        input_ids = self.tokenizer.convert_tokens_to_ids(output_tokens)
        input_mask = [1] * len(input_ids)
        pad_len = self.max_seq_length - len(input_ids)
        if pad_len > 0:
            input_ids.extend([PAD] * pad_len)
            segment_ids.extend([PAD] * pad_len)
            input_mask.extend([PAD] * pad_len)
            masked_lm_output.extend([-1] * pad_len)

        return [
            map_to_numpy(input_ids),
            map_to_numpy(input_mask),
            map_to_numpy(segment_ids),
            map_to_numpy(masked_lm_output),
            map_to_numpy([is_next]),
        ]

    def create_masked_lm_predictions(self, tokens):
        """Create masked-LM predictions, leaving ``"##"`` markers in place.

        Unlike :meth:`create_whole_masked_lm_predictions`, output tokens and
        labels keep their ``"##"`` prefix, so labels are looked up in the
        vocabulary under the prefixed form.

        :param tokens: token list including ``[CLS]`` / ``[SEP]``.
        :return: ``(output_tokens, masked_lm_output)`` where
            ``masked_lm_output[i]`` is the vocabulary id of the original token
            at a masked position and ``-1`` elsewhere.
        """
        return self._mask_tokens(tokens, list(tokens))

    def get_new_segment(self, segment):
        """Mark word-internal Chinese characters with a ``"##"`` prefix.

        The sentence is segmented with ``jieba``; for every word of up to
        three tokens found at the current position, the first token is kept
        and the following ones are prefixed with ``"##"`` so later stages
        know which tokens belong to the same word.

        :param segment: a sentence as a sequence of tokens.
        :return: a new list of the same length as ``segment``.
        """
        words = set(jieba.lcut("".join(segment)))
        total = len(segment)
        new_segment = []
        i = 0
        while i < total:
            # Tokens without any Chinese character are copied through as-is.
            if self.rec.search(segment[i]) is None:
                new_segment.append(segment[i])
                i += 1
                continue
            # Prefer the longest word (at most 3 tokens) starting here.
            for length in range(3, 0, -1):
                if i + length > total:
                    continue
                if "".join(segment[i : i + length]) in words:
                    new_segment.append(segment[i])
                    new_segment.extend("##" + piece for piece in segment[i + 1 : i + length])
                    i += length
                    break
            else:
                # No segmented word starts here; keep the token unmarked.
                new_segment.append(segment[i])
                i += 1
        return new_segment

    def create_whole_masked_lm_predictions(self, tokens):
        """Create masked-LM predictions with ``"##"`` markers stripped.

        Tokens carrying the ``"##"`` + Chinese character marker are grouped
        with the preceding token for masking, but the marker is removed from
        the returned tokens and from the labels.

        :param tokens: token list including ``[CLS]`` / ``[SEP]``.
        :return: ``(output_tokens, masked_lm_output)`` where
            ``masked_lm_output[i]`` is the vocabulary id of the original
            (marker-free) token at a masked position and ``-1`` elsewhere.
        """
        return self._mask_tokens(tokens, [self._strip_marker(t) for t in tokens])

    def _strip_marker(self, token):
        """Drop the first two characters if the token contains the "##" + Chinese marker."""
        return token[2:] if self.whole_rec.search(token) else token

    def _candidate_groups(self, tokens):
        """Group maskable positions; "##" tokens join the previous group."""
        groups = []
        for i, token in enumerate(tokens):
            if token in ("[CLS]", "[SEP]"):
                continue
            # Whole-word masking masks all pieces of a word together. Each
            # piece is still predicted independently over the vocabulary.
            if self.do_whole_word_mask and groups and token.startswith("##"):
                groups[-1].append(i)
            else:
                groups.append([i])
        return groups

    def _mask_tokens(self, tokens, label_tokens):
        """Select word groups at random and mask them.

        ``tokens`` drives grouping; ``label_tokens`` supplies, per position,
        the text used for unmasked outputs, "keep original" and labels.
        """
        cand_indexes = self._candidate_groups(tokens)
        random.shuffle(cand_indexes)

        output_tokens = list(label_tokens)
        masked_lm_output = [-1] * len(output_tokens)
        num_to_predict = min(self.max_predictions_per_seq, max(1, int(round(len(tokens) * self.masked_lm_prob))))

        num_masked = 0
        # Groups are disjoint by construction, so no position is visited twice.
        for index_set in cand_indexes:
            if num_masked >= num_to_predict:
                break
            # Skip a word that would push us past the prediction budget.
            if num_masked + len(index_set) > num_to_predict:
                continue
            for index in index_set:
                if random.random() < 0.8:
                    # 80% of the time, replace with [MASK].
                    masked_token = "[MASK]"
                elif random.random() < 0.5:
                    # 10% of the time, keep the original token.
                    masked_token = label_tokens[index]
                else:
                    # 10% of the time, replace with a random vocabulary word.
                    masked_token = self.vocab_words[random.randint(0, len(self.vocab_words) - 1)]
                output_tokens[index] = masked_token
                masked_lm_output[index] = self.tokenizer.vocab[label_tokens[index]]
                num_masked += 1
        return (output_tokens, masked_lm_output)