import argparse
import multiprocessing
import os
import socket
import time
from random import shuffle

import h5py
import numpy as np
import psutil
from tqdm import tqdm
from transformers import AutoTokenizer

from get_mask import PreTrainingDataset


def get_raw_instance(document, max_sequence_length=512):
    """
    Split one tokenized document into chunks of at most `max_sequence_length`
    tokens and return them as preliminary training instances.

    :param document: a list of tokenized sentences belonging to one document
    :param max_sequence_length: maximum instance length, including the two
        positions reserved for special tokens
    :return: a list; each element is a sequence of tokens
    """
    max_sequence_length_allowed = max_sequence_length - 2
    sizes = [len(seq) for seq in document]

    result_list = []
    curr_seq = []    # the sequence currently being packed
    sz_idx = 0
    while sz_idx < len(sizes):
        if len(curr_seq) + sizes[sz_idx] <= max_sequence_length_allowed:
            # The next sentence still fits: merge it into the current sequence.
            curr_seq += document[sz_idx]
            sz_idx += 1
        elif sizes[sz_idx] >= max_sequence_length_allowed:
            # The next sentence alone exceeds the budget: flush the current
            # sequence and truncate the long sentence to the allowed length.
            if len(curr_seq) > 0:
                result_list.append(curr_seq)
            curr_seq = []
            result_list.append(document[sz_idx][:max_sequence_length_allowed])
            sz_idx += 1
        else:
            # The next sentence does not fit: emit the current sequence and
            # start packing a new one.
            result_list.append(curr_seq)
            curr_seq = []

    # Keep the last sequence only if it is longer than half the budget;
    # otherwise discard it as too short.
    if len(curr_seq) > max_sequence_length_allowed / 2:
        result_list.append(curr_seq)

    return result_list
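# Example (sketch, not executed): how get_raw_instance packs sentences. With
# max_sequence_length=512 the packing budget is 510 tokens, since two positions
# are reserved for special tokens. Assuming `doc` is a list of token-id lists
# as produced by PreTrainingDataset.tokenize:
#
#     doc = [[0] * 300, [0] * 200, [0] * 400]
#     chunks = get_raw_instance(doc, max_sequence_length=512)
#     # chunks[0] merges the first two sentences (300 + 200 = 500 <= 510);
#     # chunks[1] holds the last sentence alone, kept because 400 > 510 / 2.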
def split_numpy_chunk(path, tokenizer, pretrain_data, host):
    documents = []
    instances = []

    s = time.time()
    with open(path, encoding='utf-8') as fd:
        document = []
        for i, line in enumerate(tqdm(fd)):
            line = line.strip()
            if len(line) > 0 and line[:2] == "]]":    # end of document marker
                documents.append(document)
                document = []
            elif len(line) >= 2:
                document.append(line)
        if len(document) > 0:
            documents.append(document)
    print('read_file ', time.time() - s)

    ans = []
    for docs in tqdm(documents):
        ans.append(pretrain_data.tokenize(docs))
    print(time.time() - s)
    del documents

    instances = []
    for a in tqdm(ans):
        raw_ins = get_raw_instance(a)
        instances.extend(raw_ins)
    del ans

    print('len instance', len(instances))

    sen_num = len(instances)
    seq_len = 512
    input_ids = np.zeros([sen_num, seq_len], dtype=np.int32)
    input_mask = np.zeros([sen_num, seq_len], dtype=np.int32)
    segment_ids = np.zeros([sen_num, seq_len], dtype=np.int32)
    masked_lm_output = np.zeros([sen_num, seq_len], dtype=np.int32)

    for index, ins in tqdm(enumerate(instances)):
        mask_dict = pretrain_data.create_training_instance(ins)
        input_ids[index] = mask_dict[0]
        input_mask[index] = mask_dict[1]
        segment_ids[index] = mask_dict[2]
        masked_lm_output[index] = mask_dict[3]

    with h5py.File(f'/output/{host}.h5', 'w') as hf:
        hf.create_dataset("input_ids", data=input_ids)
        hf.create_dataset("input_mask", data=input_mask)
        hf.create_dataset("segment_ids", data=segment_ids)
        hf.create_dataset("masked_lm_positions", data=masked_lm_output)

    del instances


def split_numpy_chunk_pool(input_path, output_path, pretrain_data, worker, dupe_factor, seq_len, file_name):
    if os.path.exists(os.path.join(output_path, f'{file_name}.h5')):
        print(f'{file_name}.h5 exists')
        return

    documents = []
    instances = []

    s = time.time()
    with open(input_path, 'r', encoding='utf-8') as fd:
        document = []
        for i, line in enumerate(tqdm(fd)):
            line = line.strip()
            if len(line) > 0 and line[:2] == "]]":    # end of document marker
                documents.append(document)
                document = []
            elif len(line) >= 2:
                document.append(line)
        if len(document) > 0:
            documents.append(document)
    print(f'read_file cost {time.time() - s}, length is {len(documents)}')

    # Tokenize every document in a worker pool.
    ans = []
    s = time.time()
    pool = multiprocessing.Pool(worker)
    encoded_doc = pool.imap_unordered(pretrain_data.tokenize, documents, 100)
    for index, res in tqdm(enumerate(encoded_doc, start=1), total=len(documents), colour='cyan'):
        ans.append(res)
    pool.close()
    print((time.time() - s) / 60)
    del documents

    # Pack tokenized documents into fixed-length raw instances.
    instances = []
    for a in tqdm(ans, colour='magenta'):
        raw_ins = get_raw_instance(a, max_sequence_length=seq_len)
        instances.extend(raw_ins)
    del ans

    print('len instance', len(instances))

    # Duplicate the instances dupe_factor times so that different masks are
    # sampled for the same text, then shuffle.
    new_instances = []
    for _ in range(dupe_factor):
        for ins in instances:
            new_instances.append(ins)

    shuffle(new_instances)
    instances = new_instances
    print('after dupe_factor, len instance', len(instances))

    sentence_num = len(instances)
    input_ids = np.zeros([sentence_num, seq_len], dtype=np.int32)
    input_mask = np.zeros([sentence_num, seq_len], dtype=np.int32)
    segment_ids = np.zeros([sentence_num, seq_len], dtype=np.int32)
    masked_lm_output = np.zeros([sentence_num, seq_len], dtype=np.int32)

    # Create masked-LM training instances in parallel and collect the arrays.
    s = time.time()
    pool = multiprocessing.Pool(worker)
    encoded_docs = pool.imap_unordered(pretrain_data.create_training_instance, instances, 32)
    for index, mask_dict in tqdm(enumerate(encoded_docs), total=len(instances), colour='blue'):
        input_ids[index] = mask_dict[0]
        input_mask[index] = mask_dict[1]
        segment_ids[index] = mask_dict[2]
        masked_lm_output[index] = mask_dict[3]
    pool.close()
    print((time.time() - s) / 60)

    with h5py.File(os.path.join(output_path, f'{file_name}.h5'), 'w') as hf:
        hf.create_dataset("input_ids", data=input_ids)
        hf.create_dataset("input_mask", data=input_mask)
        hf.create_dataset("segment_ids", data=segment_ids)
        hf.create_dataset("masked_lm_positions", data=masked_lm_output)

    del instances
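# Sketch of how to read a generated shard back (dataset names and shapes follow
# the writer above; the path is a placeholder, and the exact meaning of the
# mask/segment/label rows depends on what PreTrainingDataset.create_training_instance
# returns in positions 1-3):
#
#     with h5py.File('/data/h5/0.h5', 'r') as hf:
#         input_ids = hf['input_ids'][:]          # shape [num_instances, seq_len], int32
#         input_mask = hf['input_mask'][:]        # attention mask row per instance
#         segment_ids = hf['segment_ids'][:]      # token type ids
#         labels = hf['masked_lm_positions'][:]   # masked-LM targets written from mask_dict[3]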
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--tokenizer_path', type=str, required=True, help='path of tokenizer')
    parser.add_argument('--seq_len', type=int, default=512, help='sequence length')
    parser.add_argument('--max_predictions_per_seq', type=int, default=80,
                        help='maximum number of masked tokens per sequence')
    parser.add_argument('--input_path', type=str, required=True, help='input path of shard which has split sentence')
    parser.add_argument('--output_path', type=str, required=True, help='output path of h5 contains token id')
    parser.add_argument('--backend', type=str, default='python',
                        help='backend of mask token, python, c++, numpy respectively')
    parser.add_argument('--dupe_factor', type=int, default=1,
                        help='specifies how many times the preprocessor repeats to create the input from the same article/document')
    parser.add_argument('--worker', type=int, default=32, help='number of worker processes')
    parser.add_argument('--server_num', type=int, default=10, help='number of servers')
    args = parser.parse_args()

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
    pretrain_data = PreTrainingDataset(tokenizer,
                                       args.seq_len,
                                       args.backend,
                                       max_predictions_per_seq=args.max_predictions_per_seq)

    data_len = len(os.listdir(args.input_path))
    for i in range(data_len):
        input_path = os.path.join(args.input_path, f'{i}.txt')
        if os.path.exists(input_path):
            start = time.time()
            print(f'process {input_path}')
            split_numpy_chunk_pool(input_path,
                                   args.output_path,
                                   pretrain_data,
                                   args.worker,
                                   args.dupe_factor,
                                   args.seq_len,
                                   i)
            end_ = time.time()
            print(u'memory:%.4f GB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024))
            print(f'has cost {(end_ - start) / 60}')
            print('-' * 100)
            print('')

    # If you have multiple servers, you can use the code below or adapt it to OpenMPI.
    # host = int(socket.gethostname().split('GPU')[-1])
    # for i in range(data_len // args.server_num + 1):
    #     h = args.server_num * i + host - 1
    #     input_path = os.path.join(args.input_path, f'{h}.txt')
    #     if os.path.exists(input_path):
    #         start = time.time()
    #         print(f'I am server {host}, process {input_path}')
    #         split_numpy_chunk_pool(input_path,
    #                                args.output_path,
    #                                pretrain_data,
    #                                args.worker,
    #                                args.dupe_factor,
    #                                args.seq_len,
    #                                h)
    #         end_ = time.time()
    #         print(u'memory:%.4f GB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024))
    #         print(f'has cost {(end_ - start) / 60}')
    #         print('-' * 100)
    #         print('')
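# Example invocation (sketch; the paths are placeholders and the script file
# name depends on where this module lives in your repo):
#
#     python <this_script>.py \
#         --tokenizer_path /path/to/tokenizer \
#         --input_path /data/shards \
#         --output_path /data/h5 \
#         --seq_len 512 \
#         --max_predictions_per_seq 80 \
#         --backend python \
#         --dupe_factor 1 \
#         --worker 32
#
# --input_path is expected to contain 0.txt, 1.txt, ... where documents are
# separated by lines starting with "]]"; one <i>.h5 shard is written per input file.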