ColossalAI/examples/community/roberta/preprocessing/tokenize_mask.py

268 lines
9.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import argparse
import multiprocessing
import os
import socket
import time
from random import shuffle
import h5py
import numpy as np
import psutil
from get_mask import PreTrainingDataset
from tqdm import tqdm
from transformers import AutoTokenizer
def get_raw_instance(document, max_sequence_length=512):
"""
Get the initial training instances, split the whole segment into multiple parts according to the max_sequence_length, and return as multiple processed instances.
:param document: document
:param max_sequence_length:
:return: a list. each element is a sequence of text
"""
# document = self.documents[index]
max_sequence_length_allowed = max_sequence_length - 2
# document = [seq for seq in document if len(seq)<max_sequence_length_allowed]
sizes = [len(seq) for seq in document]
result_list = []
curr_seq = []
sz_idx = 0
while sz_idx < len(sizes):
if len(curr_seq) + sizes[sz_idx] <= max_sequence_length_allowed: # or len(curr_seq)==0:
curr_seq += document[sz_idx]
sz_idx += 1
elif sizes[sz_idx] >= max_sequence_length_allowed:
if len(curr_seq) > 0:
result_list.append(curr_seq)
curr_seq = []
result_list.append(document[sz_idx][:max_sequence_length_allowed])
sz_idx += 1
else:
result_list.append(curr_seq)
curr_seq = []
if len(curr_seq) > max_sequence_length_allowed / 2: # /2
result_list.append(curr_seq)
# num_instance=int(len(big_list)/max_sequence_length_allowed)+1
# print("num_instance:",num_instance)
# result_list=[]
# for j in range(num_instance):
# index=j*max_sequence_length_allowed
# end_index=index+max_sequence_length_allowed if j!=num_instance-1 else -1
# result_list.append(big_list[index:end_index])
return result_list
def split_numpy_chunk(path, tokenizer, pretrain_data, host):
documents = []
instances = []
s = time.time()
with open(path, encoding='utf-8') as fd:
document = []
for i, line in enumerate(tqdm(fd)):
line = line.strip()
# document = line
# if len(document.split("<sep>")) <= 3:
# continue
if len(line) > 0 and line[:2] == "]]": # This is end of document
documents.append(document)
document = []
elif len(line) >= 2:
document.append(line)
if len(document) > 0:
documents.append(document)
print('read_file ', time.time() - s)
# documents = [x for x in documents if x]
# print(len(documents))
# print(len(documents[0]))
# print(documents[0][0:10])
import multiprocessing
from typing import List
ans = []
for docs in tqdm(documents):
ans.append(pretrain_data.tokenize(docs))
print(time.time() - s)
del documents
instances = []
for a in tqdm(ans):
raw_ins = get_raw_instance(a)
instances.extend(raw_ins)
del ans
print('len instance', len(instances))
sen_num = len(instances)
seq_len = 512
input_ids = np.zeros([sen_num, seq_len], dtype=np.int32)
input_mask = np.zeros([sen_num, seq_len], dtype=np.int32)
segment_ids = np.zeros([sen_num, seq_len], dtype=np.int32)
masked_lm_output = np.zeros([sen_num, seq_len], dtype=np.int32)
for index, ins in tqdm(enumerate(instances)):
mask_dict = pretrain_data.create_training_instance(ins)
input_ids[index] = mask_dict[0]
input_mask[index] = mask_dict[1]
segment_ids[index] = mask_dict[2]
masked_lm_output[index] = mask_dict[3]
with h5py.File(f'/output/{host}.h5', 'w') as hf:
hf.create_dataset("input_ids", data=input_ids)
hf.create_dataset("input_mask", data=input_ids)
hf.create_dataset("segment_ids", data=segment_ids)
hf.create_dataset("masked_lm_positions", data=masked_lm_output)
del instances
def split_numpy_chunk_pool(input_path, output_path, pretrain_data, worker, dupe_factor, seq_len, file_name):
if os.path.exists(os.path.join(output_path, f'{file_name}.h5')):
print(f'{file_name}.h5 exists')
return
documents = []
instances = []
s = time.time()
with open(input_path, 'r', encoding='utf-8') as fd:
document = []
for i, line in enumerate(tqdm(fd)):
line = line.strip()
if len(line) > 0 and line[:2] == "]]": # This is end of document
documents.append(document)
document = []
elif len(line) >= 2:
document.append(line)
if len(document) > 0:
documents.append(document)
print(f'read_file cost {time.time() - s}, length is {len(documents)}')
ans = []
s = time.time()
pool = multiprocessing.Pool(worker)
encoded_doc = pool.imap_unordered(pretrain_data.tokenize, documents, 100)
for index, res in tqdm(enumerate(encoded_doc, start=1), total=len(documents), colour='cyan'):
ans.append(res)
pool.close()
print((time.time() - s) / 60)
del documents
instances = []
for a in tqdm(ans, colour='MAGENTA'):
raw_ins = get_raw_instance(a, max_sequence_length=seq_len)
instances.extend(raw_ins)
del ans
print('len instance', len(instances))
new_instances = []
for _ in range(dupe_factor):
for ins in instances:
new_instances.append(ins)
shuffle(new_instances)
instances = new_instances
print('after dupe_factor, len instance', len(instances))
sentence_num = len(instances)
input_ids = np.zeros([sentence_num, seq_len], dtype=np.int32)
input_mask = np.zeros([sentence_num, seq_len], dtype=np.int32)
segment_ids = np.zeros([sentence_num, seq_len], dtype=np.int32)
masked_lm_output = np.zeros([sentence_num, seq_len], dtype=np.int32)
s = time.time()
pool = multiprocessing.Pool(worker)
encoded_docs = pool.imap_unordered(pretrain_data.create_training_instance, instances, 32)
for index, mask_dict in tqdm(enumerate(encoded_docs), total=len(instances), colour='blue'):
input_ids[index] = mask_dict[0]
input_mask[index] = mask_dict[1]
segment_ids[index] = mask_dict[2]
masked_lm_output[index] = mask_dict[3]
pool.close()
print((time.time() - s) / 60)
with h5py.File(os.path.join(output_path, f'{file_name}.h5'), 'w') as hf:
hf.create_dataset("input_ids", data=input_ids)
hf.create_dataset("input_mask", data=input_mask)
hf.create_dataset("segment_ids", data=segment_ids)
hf.create_dataset("masked_lm_positions", data=masked_lm_output)
del instances
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--tokenizer_path', type=str, required=True, default=10, help='path of tokenizer')
parser.add_argument('--seq_len', type=int, default=512, help='sequence length')
parser.add_argument('--max_predictions_per_seq',
type=int,
default=80,
help='number of shards, e.g., 10, 50, or 100')
parser.add_argument('--input_path', type=str, required=True, help='input path of shard which has split sentence')
parser.add_argument('--output_path', type=str, required=True, help='output path of h5 contains token id')
parser.add_argument('--backend',
type=str,
default='python',
help='backend of mask token, python, c++, numpy respectively')
parser.add_argument(
'--dupe_factor',
type=int,
default=1,
help='specifies how many times the preprocessor repeats to create the input from the same article/document')
parser.add_argument('--worker', type=int, default=32, help='number of process')
parser.add_argument('--server_num', type=int, default=10, help='number of servers')
args = parser.parse_args()
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
pretrain_data = PreTrainingDataset(tokenizer,
args.seq_len,
args.backend,
max_predictions_per_seq=args.max_predictions_per_seq)
data_len = len(os.listdir(args.input_path))
for i in range(data_len):
input_path = os.path.join(args.input_path, f'{i}.txt')
if os.path.exists(input_path):
start = time.time()
print(f'process {input_path}')
split_numpy_chunk_pool(input_path, args.output_path, pretrain_data, args.worker, args.dupe_factor,
args.seq_len, i)
end_ = time.time()
print(u'memory%.4f GB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024))
print(f'has cost {(end_ - start) / 60}')
print('-' * 100)
print('')
# if you have multiple server, you can use code below or modify code to openmpi
# host = int(socket.gethostname().split('GPU')[-1])
# for i in range(data_len // args.server_num + 1):
# h = args.server_num * i + host - 1
# input_path = os.path.join(args.input_path, f'{h}.txt')
# if os.path.exists(input_path):
# start = time.time()
# print(f'I am server {host}, process {input_path}')
# split_numpy_chunk_pool(input_path,
# args.output_path,
# pretrain_data,
# args.worker,
# args.dupe_factor,
# args.seq_len,
# h)
# end_ = time.time()
# print(u'memory%.4f GB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024) )
# print(f'has cost {(end_ - start) / 60}')
# print('-' * 100)
# print('')