# ColossalAI/examples/language/roberta/preprocessing/sentence_split.py
#
# 164 lines, 5.7 KiB, Python

import multiprocessing
import os
import re
from tqdm import tqdm
from typing import List
import json
import time
import argparse
import functools
def split_sentence(document: str, flag: str = "all", limit: int = 510) -> List[str]:
    """Split a document into sentences on sentence-ending punctuation.

    A line break is inserted after every sentence terminator (a closing
    quotation mark that directly follows the terminator stays attached to its
    sentence), the text is split on those breaks, and every sentence longer
    than ``limit`` is cut into consecutive pieces of at most ``limit``
    characters.

    Args:
        document: Text to split.
        flag: Punctuation set to split on: ``"zh"`` for Chinese punctuation,
            ``"en"`` for English punctuation, any other value (default
            ``"all"``) for both.
        limit: Maximum length, in characters, of a returned sentence
            (510 by default).

    Returns:
        The sentences in document order. Sentences that are two characters or
        shorter after stripping are dropped. If splitting raises, the list
        holds only the (possibly partially processed) document as a
        best-effort fallback.

    Raises:
        ValueError: If ``limit`` is smaller than 1.
    """
    if limit < 1:
        # With a non-positive limit the chunking loop below never shortens
        # the remainder and would spin forever.
        raise ValueError(f"limit must be a positive integer, got {limit!r}")

    # Each flag has two patterns:
    #   bare_pattern   - a terminator that is NOT followed by a closing quote
    #   quoted_pattern - a terminator followed by a closing quote; the break
    #                    goes after the quote
    # The Chinese sets accept both full-width and ASCII question/exclamation
    # marks.
    if flag == "zh":
        bare_pattern = '(?P<quotation_mark>([。？！?!…](?![”’"\'])))'
        quoted_pattern = '(?P<quotation_mark>([。？！?!]|…{1,2})[”’"\'])'
    elif flag == "en":
        bare_pattern = '(?P<quotation_mark>([.?!](?![”’"\'])))'
        quoted_pattern = '(?P<quotation_mark>([?!.]["\']))'
    else:
        bare_pattern = '(?P<quotation_mark>([。？！….?!](?![”’"\'])))'
        quoted_pattern = '(?P<quotation_mark>(([。？！.!?]|…{1,2})[”’"\']))'

    sentences: List[str] = []
    try:
        document = re.sub(bare_pattern, r'\g<quotation_mark>\n', document)
        document = re.sub(quoted_pattern, r'\g<quotation_mark>\n', document)
        for line in document.splitlines():
            sentence = line.strip()
            if len(sentence) <= 2:
                # Empty lines and one/two-character fragments are discarded.
                continue
            while len(sentence) > limit:
                sentences.append(sentence[:limit])
                sentence = sentence[limit:]
            sentences.append(sentence)
    except Exception:
        # Best effort: hand the document back unsplit rather than failing the
        # whole batch. KeyboardInterrupt/SystemExit are deliberately not caught.
        sentences = [document]
    return sentences
def get_sent(output_path,
             input_path,
             fin_list=None, host=-1, seq_len=512, workers=32) -> None:
    """Split every document of one shard into sentences and write them out.

    The result is written to ``<output_path>/<host>.txt``: one sentence per
    line, and a line containing ``]]`` after the sentences of each document.

    Args:
        output_path: Existing directory that receives the output file.
        input_path: Directory containing the input JSON files.
        fin_list: Entries whose first element is a file name relative to
            ``input_path``. Entries whose file does not exist or whose name
            does not contain ``.json`` are skipped. Each file must hold a JSON
            sequence of objects with a ``content`` key. ``None`` means empty.
        host: Identifier used as the output file name.
        seq_len: Sequence length; sentences are cut to ``seq_len - 2``
            characters.
        workers: Number of worker processes used for splitting.
    """
    if fin_list is None:
        fin_list = []
    # endswith (rather than indexing) also tolerates an empty input_path.
    if input_path.endswith('/'):
        input_path = input_path[:-1]
    cur_path = os.path.join(output_path, str(host) + '.txt')
    new_split_sentence = functools.partial(split_sentence, limit=seq_len - 2)

    # One pool serves every file of the shard; leaving the ``with`` block
    # tears the workers down even when an error interrupts processing.
    with multiprocessing.Pool(workers) as pool, open(cur_path, 'w', encoding='utf-8') as f:
        for fi, fin_path in enumerate(fin_list):
            fin_name = fin_path[0]
            full_path = os.path.join(input_path, fin_name)
            if not os.path.exists(full_path):
                continue
            if '.json' not in fin_name:
                continue

            print("Processing ", fin_name, " ", fi)
            # Read explicitly as UTF-8 so the result does not depend on the
            # platform's default encoding.
            with open(full_path, 'r', encoding='utf-8') as fin:
                f_data = [record['content'] for record in json.load(fin)]

            # Documents are handed to the workers in chunks of 32; results
            # arrive in completion order, not input order.
            all_sent = pool.imap_unordered(new_split_sentence, f_data, 32)
            for doc_sentences in tqdm(all_sent):
                for sentence in doc_sentences:
                    f.write(sentence.strip() + '\n')
                f.write(']]' + '\n')
            print('finished..')
def getFileSize(filepath, shard):
    """Group the entries of a directory into shards of roughly equal size.

    Entries are ordered by size, largest first (equal sizes are ordered by
    name so the result is deterministic), and appended to the current shard
    until its accumulated size exceeds ``total_size / shard``; then a new
    shard is started. The number of shards returned is therefore driven by
    the sizes and may differ from ``shard``.

    Args:
        filepath: Directory whose entries are distributed.
        shard: Target number of shards; must be at least 1.

    Returns:
        A list of shards, each a list of ``[file_name, size_in_bytes]``
        entries. Empty when the directory is empty.

    Raises:
        ValueError: If ``shard`` is smaller than 1.
    """
    if shard < 1:
        raise ValueError(f"shard must be a positive integer, got {shard!r}")

    # Join each name with the directory exactly once so relative directories
    # resolve correctly, and stat every entry a single time.
    sized_files = [[name, os.path.getsize(os.path.join(filepath, name))]
                   for name in sorted(os.listdir(filepath))]
    # The sort is stable, so equally sized files keep their name order.
    sized_files.sort(key=lambda entry: entry[1], reverse=True)

    per_size = sum(size for _, size in sized_files) / shard

    real_shard = []
    current = []
    accu_size = 0
    for entry in sized_files:
        accu_size += entry[1]
        current.append(entry)
        if accu_size > per_size:
            real_shard.append(current)
            current = []
            accu_size = 0
    if current:
        # Whatever is left did not exceed the budget; it forms the last shard.
        real_shard.append(current)
    return real_shard
def get_start_end(real_shard, base=0, server_num=10, server_name='GPU'):
    """Pick the shard this machine should process, based on its hostname.

    The server number is the integer that follows the last occurrence of
    ``server_name`` in the local hostname. The selected shard is
    ``real_shard[server_num * base + host - 1]``.

    Args:
        real_shard: List of shards to choose from.
        base: Round number; each round advances the index by ``server_num``.
        server_num: Number of servers taking part in one round.
        server_name: Hostname prefix that precedes the server number.

    Returns:
        A ``(fin_list, host)`` tuple: the selected shard and the server number
        parsed from the hostname.
    """
    import socket

    hostname = socket.gethostname()
    host = int(hostname.split(server_name)[-1])
    shard_index = server_num * base + host - 1
    fin_list = real_shard[shard_index]

    print(fin_list)
    print(f'I am server {host}, process {shard_index}, len {len(fin_list)}')
    return fin_list, host
if __name__ == '__main__':
    # Command-line entry point: shard the corpus by file size, then split
    # every shard into sentences, writing one output file per shard.
    parser = argparse.ArgumentParser()
    parser.add_argument('--server_num', type=int, default=10, help='number of servers')
    parser.add_argument('--seq_len', type=int, default=512, help='sequence length')
    parser.add_argument('--shard', type=int, default=100, help='number of shards, e.g., 10, 50, or 100')
    parser.add_argument('--input_path', type=str, required=True, help='input path of original corpus')
    parser.add_argument('--output_path', type=str, required=True, help='output path of shard which has split sentence')
    args = parser.parse_args()

    server_num = args.server_num    # only relevant to the multi-server variant below
    seq_len = args.seq_len
    shard = args.shard
    input_path = args.input_path
    output_path = args.output_path

    # get_sent opens files inside output_path but does not create the directory.
    os.makedirs(output_path, exist_ok=True)

    real_shard = getFileSize(input_path, shard)

    start = time.time()
    # A distinct loop variable keeps the integer ``shard`` option intact.
    for index, shard_files in enumerate(real_shard):
        get_sent(output_path,
                 input_path,
                 fin_list=shard_files,
                 host=index,
                 seq_len=seq_len)
    print(f'cost {str(time.time() - start)}')

    # If you have multiple servers, you can use the code below instead, or
    # adapt it to OpenMPI:
    # for i in range(len(real_shard) // server_num + 1):
    #     fin_list, host = get_start_end(real_shard, i)
    #     start = time.time()
    #     get_sent(output_path,
    #              input_path,
    #              fin_list=fin_list, host= 10 * i + host - 1)
    #     print(f'cost {str(time.time() - start)}')