# Mirror of https://github.com/hpcaitech/ColossalAI
# 276 lines · 10 KiB · Python
import time
|
||
import os
|
||
import psutil
|
||
import h5py
|
||
import socket
|
||
import argparse
|
||
import numpy as np
|
||
import multiprocessing
|
||
from tqdm import tqdm
|
||
from random import shuffle
|
||
from transformers import AutoTokenizer
|
||
from get_mask import PreTrainingDataset
|
||
|
||
|
||
def get_raw_instance(document, max_sequence_length=512):
    """Pack the tokenized sentences of one document into training sequences.

    Sentences are concatenated greedily while the running sequence stays within
    a budget of ``max_sequence_length - 2`` tokens (two positions are held back,
    presumably for special tokens added later — confirm against the masker).

    - A sentence that does not fit next to the pending tokens, but fits on its
      own, closes the pending sequence and starts a new one.
    - A sentence that alone reaches the budget flushes the pending sequence (if
      any) and is emitted by itself, truncated to the budget.
    - The final pending sequence is kept only if it is longer than half the
      budget; shorter leftovers are discarded.

    :param document: the document as a list of tokenized sentences.
    :param max_sequence_length: maximum length of a final training sequence,
        including the two reserved positions.
    :return: a list of packed token lists.
    """
    budget = max_sequence_length - 2
    packed = []
    pending = []

    for sentence in document:
        if len(pending) + len(sentence) <= budget:
            pending += sentence
        elif len(sentence) >= budget:
            # Over-long sentence: emit what we have, then the truncated sentence alone.
            if pending:
                packed.append(pending)
                pending = []
            packed.append(sentence[:budget])
        else:
            # Fits on its own but not alongside the pending tokens: start afresh.
            packed.append(pending)
            pending = list(sentence)

    # Drop a trailing fragment that is too short to be a useful training sample.
    if len(pending) > budget / 2:
        packed.append(pending)
    return packed
|
||
|
||
|
||
def split_numpy_chunk(path, tokenizer, pretrain_data, host, seq_len=512):
    """Serially convert one text shard into masked-LM arrays at ``/output/{host}.h5``.

    Input format: one sentence per line. A line starting with ``]]`` closes the
    current document, and lines shorter than two characters (after stripping)
    are ignored.

    Processing steps:
    - Tokenize each document with ``pretrain_data.tokenize``.
    - Pack the result into sequences with ``get_raw_instance``.
    - Turn each sequence into arrays with ``pretrain_data.create_training_instance``.

    :param path: path of the UTF-8 text shard.
    :param tokenizer: unused; kept so existing callers keep working
        (tokenization goes through ``pretrain_data``).
    :param pretrain_data: object providing ``tokenize`` and
        ``create_training_instance`` (a ``PreTrainingDataset`` in this script).
    :param host: identifier used as the output file name.
    :param seq_len: row length of the output arrays, also used as the packing
        limit for ``get_raw_instance``. Defaults to 512.
    """
    start = time.time()
    documents = []
    with open(path, encoding='utf-8') as fd:
        document = []
        for line in tqdm(fd):
            line = line.strip()
            if line.startswith(']]'):  # end-of-document marker
                # Consecutive markers would otherwise produce empty documents.
                if document:
                    documents.append(document)
                document = []
            elif len(line) >= 2:
                document.append(line)
        if document:
            documents.append(document)
    print('read_file ', time.time() - start)

    tokenized = [pretrain_data.tokenize(doc) for doc in tqdm(documents)]
    print(time.time() - start)
    del documents

    instances = []
    for doc_tokens in tqdm(tokenized):
        instances.extend(get_raw_instance(doc_tokens, max_sequence_length=seq_len))
    del tokenized
    print('len instance', len(instances))

    num_instances = len(instances)
    input_ids = np.zeros([num_instances, seq_len], dtype=np.int32)
    input_mask = np.zeros([num_instances, seq_len], dtype=np.int32)
    segment_ids = np.zeros([num_instances, seq_len], dtype=np.int32)
    masked_lm_output = np.zeros([num_instances, seq_len], dtype=np.int32)

    # create_training_instance is indexed as a 4-element sequence
    # (ids, mask, segment ids, masked-LM output); each element presumably has
    # length seq_len — confirm against PreTrainingDataset.
    for index, ins in enumerate(tqdm(instances)):
        masked = pretrain_data.create_training_instance(ins)
        input_ids[index] = masked[0]
        input_mask[index] = masked[1]
        segment_ids[index] = masked[2]
        masked_lm_output[index] = masked[3]

    with h5py.File(f'/output/{host}.h5', 'w') as hf:
        hf.create_dataset("input_ids", data=input_ids)
        hf.create_dataset("input_mask", data=input_mask)
        hf.create_dataset("segment_ids", data=segment_ids)
        hf.create_dataset("masked_lm_positions", data=masked_lm_output)
|
||
|
||
|
||
def split_numpy_chunk_pool(input_path,
                           output_path,
                           pretrain_data,
                           worker,
                           dupe_factor,
                           seq_len,
                           file_name):
    """Convert one text shard into masked-LM training arrays using worker processes.

    Pipeline:
    1. Read documents from *input_path*.
    2. Tokenize them in parallel with ``pretrain_data.tokenize``.
    3. Pack each document into sequences with ``get_raw_instance``.
    4. Repeat the packed sequences ``dupe_factor`` times and shuffle them.
    5. Turn each one into arrays in parallel with
       ``pretrain_data.create_training_instance``.
    6. Write the arrays to ``<output_path>/<file_name>.h5``.

    If that file already exists, the shard is skipped. Both parallel stages use
    ``imap_unordered``, so results are stored in completion order, not input order.

    Input format: one sentence per line. A line starting with ``]]`` closes the
    current document, and lines shorter than two characters are ignored.

    :param input_path: path of the UTF-8 text shard.
    :param output_path: directory that receives ``<file_name>.h5``.
    :param pretrain_data: object providing ``tokenize`` and
        ``create_training_instance``. Its bound methods are sent to worker
        processes, so it must be picklable.
    :param worker: number of worker processes per pool.
    :param dupe_factor: number of copies of each packed sequence. Each copy is
        passed to ``create_training_instance`` separately, which presumably
        applies a fresh random mask.
    :param seq_len: row length of the output arrays and packing limit.
    :param file_name: stem of the output file.
    """
    final_path = os.path.join(output_path, f'{file_name}.h5')
    if os.path.exists(final_path):
        print(f'{file_name}.h5 exists')
        return

    start = time.time()
    documents = []
    with open(input_path, 'r', encoding='utf-8') as fd:
        document = []
        for line in tqdm(fd):
            line = line.strip()
            if line.startswith(']]'):  # end-of-document marker
                # Consecutive markers would otherwise produce empty documents,
                # which cost a worker round-trip and presumably yield no instances.
                if document:
                    documents.append(document)
                document = []
            elif len(line) >= 2:
                document.append(line)
        if document:
            documents.append(document)
    print(f'read_file cost {time.time() - start}, length is {len(documents)}')

    # Stage 1: parallel tokenization. The context manager shuts the workers
    # down even if a task raises.
    start = time.time()
    with multiprocessing.Pool(worker) as pool:
        tokenized = list(tqdm(pool.imap_unordered(pretrain_data.tokenize, documents, chunksize=100),
                              total=len(documents), colour='cyan'))
    print((time.time() - start) / 60)
    del documents

    # Stage 2: pack every tokenized document into sequences of at most seq_len.
    instances = []
    for doc_tokens in tqdm(tokenized, colour='MAGENTA'):
        instances.extend(get_raw_instance(doc_tokens, max_sequence_length=seq_len))
    del tokenized
    print('len instance', len(instances))

    instances = instances * dupe_factor
    shuffle(instances)
    print('after dupe_factor, len instance', len(instances))

    num_instances = len(instances)
    input_ids = np.zeros([num_instances, seq_len], dtype=np.int32)
    input_mask = np.zeros([num_instances, seq_len], dtype=np.int32)
    segment_ids = np.zeros([num_instances, seq_len], dtype=np.int32)
    masked_lm_output = np.zeros([num_instances, seq_len], dtype=np.int32)

    # Stage 3: parallel masking. Each result is indexed as a 4-element sequence
    # (ids, mask, segment ids, masked-LM output).
    start = time.time()
    with multiprocessing.Pool(worker) as pool:
        masked_iter = pool.imap_unordered(pretrain_data.create_training_instance, instances, chunksize=32)
        for index, masked in enumerate(tqdm(masked_iter, total=num_instances, colour='blue')):
            input_ids[index] = masked[0]
            input_mask[index] = masked[1]
            segment_ids[index] = masked[2]
            masked_lm_output[index] = masked[3]
    print((time.time() - start) / 60)
    del instances

    # Write under a temporary name and rename once complete. Otherwise a file
    # truncated by a crash would pass the exists() check above and be skipped
    # on every later run.
    tmp_path = final_path + '.tmp'
    with h5py.File(tmp_path, 'w') as hf:
        hf.create_dataset("input_ids", data=input_ids)
        hf.create_dataset("input_mask", data=input_mask)
        hf.create_dataset("segment_ids", data=segment_ids)
        hf.create_dataset("masked_lm_positions", data=masked_lm_output)
    os.replace(tmp_path, final_path)
|
||
|
||
|
||
def list_shard_ids(input_dir):
    """Return the sorted integer ids of shard files named ``<id>.txt`` in *input_dir*.

    Only regular files whose name is a non-negative integer without leading
    zeros followed by ``.txt`` are accepted. That guarantees the id maps back
    to the exact file name via ``f'{id}.txt'``. Other entries are ignored, and
    gaps in the numbering are allowed.
    """
    import re

    pattern = re.compile(r'(0|[1-9][0-9]*)\.txt')
    shard_ids = []
    for name in os.listdir(input_dir):
        match = pattern.fullmatch(name)
        if match and os.path.isfile(os.path.join(input_dir, name)):
            shard_ids.append(int(match.group(1)))
    return sorted(shard_ids)


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--tokenizer_path', type=str, required=True, help='path of tokenizer')
    parser.add_argument('--seq_len', type=int, default=512, help='sequence length')
    parser.add_argument('--max_predictions_per_seq', type=int, default=80,
                        help='maximum number of masked-LM predictions per sequence')
    parser.add_argument('--input_path', type=str, required=True, help='input path of shard which has split sentence')
    parser.add_argument('--output_path', type=str, required=True, help='output path of h5 contains token id')
    parser.add_argument('--backend', type=str, default='python', help='backend of mask token, python, c++, numpy respectively')
    parser.add_argument('--dupe_factor', type=int, default=1, help='specifies how many times the preprocessor repeats to create the input from the same article/document')
    parser.add_argument('--worker', type=int, default=32, help='number of process')
    parser.add_argument('--server_num', type=int, default=10, help='number of servers')
    args = parser.parse_args()

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
    pretrain_data = PreTrainingDataset(tokenizer,
                                       args.seq_len,
                                       args.backend,
                                       max_predictions_per_seq=args.max_predictions_per_seq)

    # h5py cannot create missing parent directories itself.
    os.makedirs(args.output_path, exist_ok=True)

    # Process every "<id>.txt" shard. Output is written to "<id>.h5".
    for i in list_shard_ids(args.input_path):
        input_path = os.path.join(args.input_path, f'{i}.txt')
        start = time.time()
        print(f'process {input_path}')
        split_numpy_chunk_pool(input_path,
                               args.output_path,
                               pretrain_data,
                               args.worker,
                               args.dupe_factor,
                               args.seq_len,
                               i)
        end_ = time.time()
        rss_gb = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024
        print('memory:%.4f GB' % rss_gb)
        print(f'has cost {(end_ - start) / 60}')
        print('-' * 100)
        print('')

    # If you have multiple servers, you can use the code below instead of the
    # loop above (or adapt it to OpenMPI). The host number is parsed from the
    # hostname suffix after "GPU" and is presumably 1-based. Host k handles
    # every shard whose id satisfies id % server_num == k - 1.
    #
    # host = int(socket.gethostname().split('GPU')[-1])
    # for h in list_shard_ids(args.input_path):
    #     if h % args.server_num != host - 1:
    #         continue
    #     input_path = os.path.join(args.input_path, f'{h}.txt')
    #     start = time.time()
    #     print(f'I am server {host}, process {input_path}')
    #     split_numpy_chunk_pool(input_path,
    #                            args.output_path,
    #                            pretrain_data,
    #                            args.worker,
    #                            args.dupe_factor,
    #                            args.seq_len,
    #                            h)
    #     end_ = time.time()
    #     print(u'memory:%.4f GB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024))
    #     print(f'has cost {(end_ - start) / 60}')
    #     print('-' * 100)
    #     print('')
|
||
|