[example] simplify the GPT2 huggingface example (#1826)

pull/1828/head
Jiarui Fang 2022-11-08 16:14:07 +08:00 committed by GitHub
parent cd5a0d56fa
commit b1263d32ba
37 changed files with 177 additions and 27611 deletions

View File

@ -1,242 +1,16 @@
# Run GPT With Colossal-AI
## Overview
This example shows how to use Colossal-AI to run Hugging Face GPT training in a distributed manner.
In Colossal-AI, there are many ways to run GPT in a distributed manner. The `train_gpt.py` script runs training with the configuration scripts in `gpt2_configs/`, which cover different parallelism strategies for GPT-2. We provide several example configuration files for GPT-2, and you can modify them to fit your own use case.
## GPT
We use the Hugging Face `transformers` GPT2 model. The input data is randomly generated.
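As a rough sketch (not the actual `train_gpt.py`), the model and a randomly generated batch can be set up with Hugging Face `transformers` along these lines; the hyper-parameters below are illustrative assumptions:
```python
import torch
from transformers import GPT2Config, GPT2LMHeadModel

# Hedged sketch: build a GPT2 model and feed it randomly generated token ids.
# The sizes here are illustrative, not the exact values used by train_gpt.py.
config = GPT2Config(n_embd=768, n_layer=12, n_head=12)
model = GPT2LMHeadModel(config)

batch_size, seq_len = 2, 128
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
attention_mask = torch.ones_like(input_ids)

# With labels equal to the inputs, the model returns the language-modeling loss.
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
print(outputs.loss)
```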
## How to Prepare Webtext Dataset
## Our Modifications
We adapt the OPT training code to ColossalAI by leveraging Gemini and ZeRO DDP.
We do not host any datasets for GPT or BERT training; however, we provide a detailed guide on how to prepare the dataset so that our results can be reproduced.
### Overview
We utilize the publicly available [OpenWebText](https://github.com/eukaryote31/openwebtext) library by [jcpeterson](https://github.com/jcpeterson/openwebtext) and [eukaryote31's](https://github.com/eukaryote31/openwebtext) work to download URLs of different web pages. We then filtered, cleaned, and deduplicated all downloaded content according to the procedure described in the following sections.
### Install necessary packages
**Note: LSH requires an older version of GCC. We have verified that GCC 9.3.0 works, but 10.3.0 does not.**
## Quick Start
You can launch training with the following bash script:
```bash
pip install ftfy langdetect numpy torch pandas nltk sentencepiece boto3 tqdm regex bs4 newspaper3k htmlmin tldextract cached-path
git clone https://github.com/mattilyra/LSH.git
cd LSH
python setup.py install
pip install -r requirements.txt
bash run.sh
```
If the installation fails, you can try replacing the `cMinhash.cpp` in `LSH/lsh` with ours, which is provided in `tools/lsh/cMinhash.cpp`.
### Download Data
1. Download the deduplicated URLs from [jcpeterson](https://mega.nz/#F!EZZD0YwJ!9_PlEQzdMVLaNdKv_ICNVQ!cc4RgQQZ).
2. Unzip the file to get a folder `URLs` containing many txt files of URLs.
3. Remove blacklisted URLs.
*We appreciate Megatron-LM for making the data preprocessing code public. We have forked Megatron-LM and fixed some bugs. For your convenience, we have collated the needed files in `tools/Megatron`. Click [here](https://github.com/NVIDIA/Megatron-LM.git) to check the source code of Megatron-LM.*
```bash
cd path/to/tools
python Megatron/blacklist_urls.py <path/to/URLs> <path/to/clean_urls.txt>
```
4. Download the content from the clean URLs and merge the contents into one loose JSON file, with one JSON object per line in the format `{'text': text, 'url': unique_url}` (see the format sketch after this list).
*We have forked and modified [openwebtext](https://github.com/yet-another-account/openwebtext) as there are some bugs in it. For your convenience, we provide our modified version in `tools/download`.*
```bash
python download/download.py <path/to/clean_urls.txt> --n_procs 50 --output <path/to/raw.json>
```
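For reference, here is a minimal sketch of the expected loose JSON format, one JSON object per line; the file name and record contents are hypothetical:
```python
import json

# Hedged sketch of the loose JSON ("JSON lines") format: one object per line,
# each with a 'text' field and a unique 'url'. The records below are made up.
records = [
    {'text': 'First scraped document ...', 'url': 'https://example.com/a'},
    {'text': 'Second scraped document ...', 'url': 'https://example.com/b'},
]
with open('raw.json', 'w', encoding='utf-8') as f:
    for record in records:
        f.write(json.dumps(record, ensure_ascii=False) + '\n')
```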
### Prepare Data for GPT Training
1. Run ftfy, perform English detection, and remove documents with fewer than 128 tokens. This step can be sharded and run on shards (a sketch of this per-document filtering appears after this list).
```bash
python Megatron/cleanup_dataset.py <path/to/raw.json> <path/to/clean.json>
```
Additional cleanup (e.g. removing documents with fewer than 512 characters, or dataset-specific cleaning for the stories and realnews datasets) can be done using `cleanup_fix_dataset.py`. More details can be found by running `python cleanup_fix_dataset.py --help`.
2. Using LSH, find possible duplicates and store them in a file for later processing. The code supports saving and loading fingerprints for recurrent deduplication, and is multithreaded for faster processing. More details can be found by running `python find_duplicates.py --help`.
```bash
python Megatron/find_duplicates.py --inputs <path/to/clean.json> url --output <path/to/process_stage_one.json>
```
3. Based on the similarity measure defined in the `is_similar` function (default threshold: 0.9), group URLs that are similar. For each group, we keep only one URL and remove the rest.
```bash
python Megatron/group_duplicate_url.py <path/to/process_stage_one.json> <path/to/process_stage_two.json>
```
4. Remove the similar documents detected in the last step. `dedup.json` is the data after deduplication.
```bash
python Megatron/remove_group_duplicates.py <path/to/process_stage_two.json> <path/to/clean.json> <path/to/dedup.json>
```
5. Shuffle the dataset.
```bash
shuf <path/to/dedup.json> -o <path/to/train_data.json>
```
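To make step 1 of the preparation concrete, here is a hedged sketch of the per-document filtering it performs; it is not the actual `cleanup_dataset.py`, and the whitespace token count is a crude stand-in for the real tokenizer:
```python
import json

import ftfy
from langdetect import detect

MIN_DOCUMENT_LENGTH = 128  # minimum number of tokens to keep a document


def keep_document(line):
    """Return the cleaned JSON line, or None if the document should be dropped."""
    doc = json.loads(line)
    text = ftfy.fix_text(doc['text'])            # fix unicode artifacts
    if detect(text) != 'en':                     # English detection
        return None
    if len(text.split()) < MIN_DOCUMENT_LENGTH:  # crude whitespace token count
        return None
    doc['text'] = text
    return json.dumps(doc, ensure_ascii=False)


# Usage sketch: stream raw.json and keep only the surviving documents.
with open('raw.json', encoding='utf-8') as fin, open('clean.json', 'w', encoding='utf-8') as fout:
    for line in fin:
        cleaned = keep_document(line)
        if cleaned is not None:
            fout.write(cleaned + '\n')
```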
## How to Prepare Yuan Dataset
### Overview
The Yuan dataset is a large-scale Chinese dataset with 1 TB of high-quality text, released by Inspur. You can apply at https://air.inspur.com/home for access to the dataset. We downloaded and processed the content according to the procedure described in the following sections.
### Download
The dataset can be downloaded from the website once your application is approved.
You also need to download the vocab file from https://github.com/Shawn-Inspur/Yuan-1.0/blob/main/src/vocab.txt
The final data directory should be organized as follows:
```
|--dataset
| |--001.txt
| |--002.txt
| |--...
|--vocab.txt
```
### Process & Load
Before you run the code, you should replace line 44 in `train_gpt.py` with
```python
from dataset.yuan import YuanDataset
train_ds = YuanDataset(os.environ['DATA'], vocab_path='/path/to/data/vocab.txt', seq_len=gpc.config.SEQ_LEN)
```
Then you can run the model following the Usage section. The dataset will be processed and cached the first time you run it; after that, the data is loaded automatically from the cache.
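For reference, here is a minimal sketch (not the actual dataset code) of how each Yuan sample can be padded or truncated to `SEQ_LEN` together with its attention mask; the helper name is hypothetical:
```python
import torch

# Hedged sketch: pad or truncate a token-id list to a fixed seq_len and build
# the matching attention mask (1 for real tokens, 0 for padding).
def pad_to_seq_len(token_ids, seq_len, pad_id):
    if len(token_ids) >= seq_len:
        token_ids = token_ids[:seq_len]
        mask = [1] * seq_len
    else:
        mask = [1] * len(token_ids) + [0] * (seq_len - len(token_ids))
        token_ids = token_ids + [pad_id] * (seq_len - len(token_ids))
    return torch.tensor(token_ids), torch.tensor(mask)

ids, mask = pad_to_seq_len([5, 6, 7], seq_len=8, pad_id=0)
print(ids)   # tensor([5, 6, 7, 0, 0, 0, 0, 0])
print(mask)  # tensor([1, 1, 1, 0, 0, 0, 0, 0])
```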
## **Usage**
```bash
#!/usr/bin/env sh
export DATA=/path/to/train_data.json
colossalai run --nproc_per_node=<num_gpus> train_gpt.py --config=gpt2_configs/<config_file>
```
You can copy it and save it as `run.sh`. Then use `bash ./run.sh` to run the script in your terminal.
Please modify `DATA`, `num_gpus` and `config_file` with the path to your dataset, the number of GPUs and the config file path, respectively.
If you want to train GPT-3, just replace `gpt2_configs` with `gpt3_configs`.
## GPT-2
Here are the default parameters of the GPT-2 configs:
| config | scale | GPUs* | batch size | memory per GPU (MiB) | TP | PP | DP |
| ------------ | ----- | ---- | ----------- | --------------- | --- | --- | --- |
| gpt2-vanilla | small | 1 | 1 | 6071 | 1 | 1 | 1 |
| gpt2-vanilla | small | 2 | 1 | 6449*2 | 1 | 1 | 2 |
| gpt2-1d | small | 2 | 1 | 5287*2 | 2 | 1 | 1 |
| gpt2-2d | small | 4 | 1 | 4590*4 | 4 | 1 | 1 |
| gpt-2.5d | small | 8 | 1 | 4815*8 | 8 | 1 | 1 |
| gpt2-3d | small | 8 | 1 | 4901*8 | 8 | 1 | 1 |
| gpt2-pp | small | 2 | 1 | 5877*2 | 1 | 2 | 1 |
| gpt2-zero2 | small | 1 | 1 | 5459 | 1 | 1 | 1 |
| gpt2-zero3 | small | 1 | 1 | 6577 | 1 | 1 | 1 |
| gpt2-nvme | small | 1 | 1 | 5067 | 1 | 1 | 1 |
| gpt2-pp1d | small | 8 | 8 | 5411*8 | 2 | 2 | 2 |
*\*Note: For GPUs, we use Nvidia A100 80G.*
*\*Note: Results of ZeRO are outdated, we will update them soon.*
**We set** `TENSOR_PARALLEL`, `PIPELINE_PARALLEL` **and** `DATA_PARALLEL` **as small as possible so that every demo can run with the least number of GPUs.**
### **Modify the config file**
#### **General**
There are some **general rules** when modifying the config files.
```text
TP denotes Tensor Parallel
PP denotes Pipeline Parallel
DP denotes Data Parallel
GPUS = TP * PP * DP
where DP is set automatically
```
You can set the **batch size** and the number of **epochs** by changing `BATCH_SIZE` and `NUM_EPOCHS`, respectively. Below, we introduce the config file for each mode.
Please note that `gpt2_zero3.py` has nothing to change other than `BATCH_SIZE` and `NUM_EPOCHS`.
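As a minimal illustration of the `GPUS = TP * PP * DP` rule above (the helper function is hypothetical, not part of Colossal-AI):
```python
# Minimal sketch of the GPUS = TP * PP * DP rule.
def derive_data_parallel_size(world_size, tensor_parallel, pipeline_parallel):
    """Return the data-parallel size that would be set automatically."""
    if world_size % (tensor_parallel * pipeline_parallel) != 0:
        raise ValueError("world_size must be divisible by TP * PP")
    return world_size // (tensor_parallel * pipeline_parallel)

# Example: 8 GPUs with TP=2 and PP=2 leaves DP=2, as in the gpt2-pp1d row above.
assert derive_data_parallel_size(8, 2, 2) == 2
```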
#### **Vanilla & Data Parallel**
`Vanilla` is the basic mode of GPT-2 with no parallelism at all. However, if you use more than 1 GPU and TP * PP is less than the number of GPUs, Colossal-AI will **set DP for you automatically**.
#### **1D, 2D, 2.5D, 3D**
In files `gpt2_1d.py, gpt2_2d.py, gpt2_2p5d.py, gpt2_3d.py`, there is a line:
```Python
TENSOR_PARALLEL = 2
```
You can increase the tensor parallel size as long as the general rules are satisfied.
In particular, `TENSOR_PARALLEL` should be a square number for 2D and a cubic number for 3D,
and `TENSOR_PARALLEL / DEPTH` should be a square number for 2.5D.
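A minimal sketch of these constraints (the checker function is hypothetical, not part of Colossal-AI):
```python
import math

# Hedged sketch of the tensor-parallel size constraints described above.
def check_tensor_parallel(size, mode, depth=1):
    if mode == '1d':
        return size >= 1
    if mode == '2d':
        return math.isqrt(size) ** 2 == size                   # square number
    if mode == '2.5d':
        q = size // depth
        return size % depth == 0 and math.isqrt(q) ** 2 == q   # size / depth is a square
    if mode == '3d':
        return round(size ** (1 / 3)) ** 3 == size             # cubic number
    raise ValueError(f"unknown mode {mode!r}")

assert check_tensor_parallel(4, '2d')       # gpt2_2d.py uses TENSOR_PARALLEL = 4
assert check_tensor_parallel(8, '2.5d', 2)  # gpt2_2p5d.py: 8 / DEPTH 2 = 4, a square
assert check_tensor_parallel(8, '3d')       # gpt2_3d.py uses TENSOR_PARALLEL = 8
```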
#### **Pipeline Parallel**
To use pipeline parallel training, you should install colossalai from the **latest** main branch.
In `gpt2_pp.py`, there are lines:
```Python
# BATCH_SIZE / NUM_MICRO_BATCHES should be an integer
NUM_MICRO_BATCHES = 1
PIPELINE = 2
```
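A minimal sketch of why `BATCH_SIZE / NUM_MICRO_BATCHES` must be an integer: the global batch is split evenly into micro-batches that flow through the pipeline stages (the values below are illustrative).
```python
import torch

# Hedged sketch of the micro-batch constraint for pipeline parallelism.
BATCH_SIZE = 8
NUM_MICRO_BATCHES = 4
SEQ_LEN = 1024

assert BATCH_SIZE % NUM_MICRO_BATCHES == 0, "BATCH_SIZE / NUM_MICRO_BATCHES must be an integer"
batch = torch.randint(0, 50257, (BATCH_SIZE, SEQ_LEN))        # dummy GPT-2 token ids
micro_batches = torch.chunk(batch, NUM_MICRO_BATCHES, dim=0)  # four micro-batches of size 2
print([mb.shape for mb in micro_batches])                     # four tensors of shape (2, 1024)
```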
#### **Pipeline + 1D + Data Parallel**
In `gpt2_pp1d.py`, we have
```Python
BATCH_SIZE = 8
NUM_EPOCHS = 60
NUM_MICRO_BATCHES = 1
HIDDEN_SIZE = 768
PIPELINE = 2
TENSOR_PARALLEL = 2
MODE = '1d'
TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LEN, HIDDEN_SIZE)
```
`BATCH_SIZE`, `NUM_EPOCHS`, `NUM_MICRO_BATCHES`, `PIPELINE` and `TENSOR_PARALLEL` have been introduced above.
`HIDDEN_SIZE` refers to the hidden dimension of the model, e.g. 768 for `gpt2_small`.
`MODE` can be one of `None`, `'1d'`, `'2d'`, `'2.5d'` or `'3d'`.
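With the values above, the per-micro-batch activation shape used for `TENSOR_SHAPE` works out as follows (a quick check, not extra configuration):
```python
# Hedged sketch: recompute TENSOR_SHAPE from the config values shown above.
BATCH_SIZE = 8
NUM_MICRO_BATCHES = 1
SEQ_LEN = 1024
HIDDEN_SIZE = 768  # hidden dimension of gpt2_small

TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LEN, HIDDEN_SIZE)
print(TENSOR_SHAPE)  # (8, 1024, 768)
```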
## GPT-3
GPT-3 is a huge model that cannot be trained with a small number of GPUs. Therefore, we choose some common sets of parameters instead of the smallest possible ones.
Here are our default parameters of GPT-3 configs:
| config | GPUs* | batch size | TP | PP | DP |
| -------------- | ---- | ---------- | --- | --- | --- |
| gpt3_pp1d_min | 96 | 192 | 4 | 24 | 1 |
| gpt3_pp1d | 128 | 192 | 4 | 32 | 1 |
| gpt3_pp2d | 96 | 2*48 | 4 | 24 | 1 |
| gpt3_pp2p5d | 96 | 2*48 | 4 | 24 | 1 |
| gpt3_zero3_min | 64 | 3 | 1 | 1 | 64 |
| gpt3_zero3 | 96 | 2 | 1 | 1 | 96 |
*\*Note: we use Nvidia A100 40G GPUs*
*\*Note: Results of ZeRO are outdated, we will update them soon.*
In the table above, the suffix `_min` denotes the set of hyper-parameters that requires the least number of GPUs for the same mode.
GPT-3 and GPT-2 have the same set of hyper-parameters.
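As a quick sanity check of the table, here is a minimal sketch that recomputes the GPU count for each config from `GPUs = TP * PP * DP`:
```python
# Hedged sketch: verify the GPU counts in the GPT-3 table above.
gpt3_configs = {
    'gpt3_pp1d_min':  dict(tp=4, pp=24, dp=1),
    'gpt3_pp1d':      dict(tp=4, pp=32, dp=1),
    'gpt3_pp2d':      dict(tp=4, pp=24, dp=1),
    'gpt3_pp2p5d':    dict(tp=4, pp=24, dp=1),
    'gpt3_zero3_min': dict(tp=1, pp=1, dp=64),
    'gpt3_zero3':     dict(tp=1, pp=1, dp=96),
}
for name, p in gpt3_configs.items():
    print(name, p['tp'] * p['pp'] * p['dp'])  # 96, 128, 96, 96, 64, 96
```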

View File

@ -1,39 +0,0 @@
import json
import os
import torch
from torch.utils.data import Dataset
from colossalai.registry import DATASETS
from transformers import GPT2Tokenizer
@DATASETS.register_module
class WebtextDataset(Dataset):
def __init__(self, path, seq_len=1024) -> None:
super().__init__()
root = os.path.dirname(path)
encoded_data_cache_path = os.path.join(root, f'gpt_webtext_{seq_len}.pt')
if os.path.isfile(encoded_data_cache_path):
seq_len_, data, attention_mask = torch.load(encoded_data_cache_path)
if seq_len_ == seq_len:
self.data = data
self.attention_mask = attention_mask
return
raw_data = []
with open(path) as f:
for line in f.readlines():
raw_data.append(json.loads(line)['text'])
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.unk_token
encoded_data = tokenizer(raw_data, padding=True, truncation=True, max_length=seq_len, return_tensors='pt')
self.data = encoded_data['input_ids']
self.attention_mask = encoded_data['attention_mask']
torch.save((seq_len, self.data, self.attention_mask), encoded_data_cache_path)
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return {'input_ids': self.data[index], 'attention_mask': self.attention_mask[index]}, self.data[index]

View File

@ -1,329 +0,0 @@
import collections
import glob
import logging
import multiprocessing
import os
import sys
import jieba
import six
import torch
from tools.tokenization_enc_dec import EncDecTokenizer
from torch.utils.data import Dataset
from tqdm import tqdm
from colossalai.registry import DATASETS
try:
import nltk
nltk_available = True
except ImportError:
nltk_available = False
jieba.setLogLevel(logging.INFO)
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def is_contain_chinese(check_str):
for ch in check_str:
if u'\u4e00' <= ch <= u'\u9fff':
return True
return False
def convert_to_unicode(text):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Should be running on Python 3")
class WordpieceTokenizer(object):
def __init__(self, vocab, unk_token="<unk>", max_input_chars_per_word=200):
self.vocab = vocab
self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word
def tokenize(self, token):
token = convert_to_unicode(token)
chars = list(token)
if len(chars) > self.max_input_chars_per_word:
return [self.unk_token]
start = 0
sub_tokens = []
while start < len(chars):
end = len(chars)
cur_substr = None
while start < end:
substr = "".join(chars[start:end])
if is_contain_chinese(substr):
if substr in self.vocab:
cur_substr = substr
break
else:
if start > 0:
substr = "##" + substr
if substr in self.vocab:
cur_substr = substr
break
end -= 1
if cur_substr is None:
sub_tokens.append(self.unk_token)
start += 1
continue
sub_tokens.append(cur_substr)
start = end
return sub_tokens
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
index = 0
with open(vocab_file, "r", encoding='utf-8') as reader:
while True:
token = convert_to_unicode(reader.readline())
if not token:
break
token = token.strip()
vocab[token] = index
index += 1
return vocab
class EncDecTokenizer(object):
def __init__(self, vocab_file, max_len=None, max_sentinels=190):
self.max_len = max_len if max_len is not None else int(1e12)
self.encoder = load_vocab(vocab_file)
self.decoder = {v: k for k, v in self.encoder.items()}
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.encoder)
self.translator = str.maketrans(" \n", "\u2582\u2583")
self.sentinel_list = [self.encoder['<s_{}>'.format(i)] for i in range(max_sentinels)]
self.en_vocab = {}
for k, v in self.encoder.items():
if is_contain_chinese(k):
self.en_vocab[v] = False
else:
self.en_vocab[v] = True
self.en_vocab[10] = False
@property
def vocab_size(self):
return len(self.encoder)
def __len__(self):
return len(self.encoder)
@property
def eod_id(self):
return self.encoder[self.eod_token]
@property
def pad_id(self):
return self.encoder[self.pad_token]
@property
def eod_token(self):
return '<eod>'
@property
def pad_token(self):
return '<pad>'
def get_sentinel_num(self):
return len(self.sentinel_list)
def get_sentinel_id(self, idx):
return self.sentinel_list[idx]
def tokenize(self, text):
""" Tokenize a string. """
output_tokens = []
for x in jieba.cut(text, cut_all=False):
x = x.translate(self.translator)
output_tokens.extend(self.wordpiece_tokenizer.tokenize(x))
# print(output_tokens)
return output_tokens
def encode(self, text):
output_tokens = [self.encoder[x] for x in self.tokenize(text)]
# filter space
new_output_tokens = [output_tokens[0]]
for i, x in enumerate(output_tokens[1:-1]):
if x == 10:
if self.en_vocab[output_tokens[i]] and self.en_vocab[output_tokens[i + 2]]:
continue
new_output_tokens.append(x)
new_output_tokens.append(output_tokens[-1])
return new_output_tokens
def decode(self, tokens):
new_tokens = []
for i, x in enumerate(tokens[:-1]):
if self.en_vocab[x] and self.en_vocab[tokens[i + 1]]:
new_tokens.append(x)
new_tokens.append(10)
else:
new_tokens.append(x)
new_tokens.append(tokens[-1])
# text = ''.join([self.decoder[x] for x in new_tokens])
# text = text.replace('\u2582', ' ').replace('\u2583', '\n')
# return text
return [self.decoder[x] for x in tokens]
class IdentitySplitter(object):
@staticmethod
def tokenize(*text):
return text
class Encoder(object):
def __init__(self, vocab_path, length, sentence_splitter):
self.vocab_path = vocab_path
self.length = length
self.sentence_splitter = sentence_splitter
self.tokenizer = EncDecTokenizer(os.path.join(self.vocab_path))
self.splitter = IdentitySplitter()
def initializer(self):
# Use Encoder class as a container for global data
pass
def encode(self, line):
# end with <eod>
if len(line) > 20000:
return None, 0
if len(line) < 10:
return None, 0
data = line.strip().strip('<n>')
data = data.replace("<n>", "\n")
doc_ids = self.tokenizer.encode(data)
doc_ids.append(self.tokenizer.eod_id)
return doc_ids, len(line)
@DATASETS.register_module
class YuanDataset(Dataset):
"""
Yuan is an open source Chinese dataset, which can be accessed on https://github.com/Shawn-Inspur/Yuan-1.0.
Args:
path(str): Path to dataset's folder, raw data should be organized under the folder as 001.txt, 002.txt...
eg:/path/yuan/dataset
vocab_path(str): Path to the vocab file. eg:/path/yuan/vocab.txt
seq_len(int): Sequence length of the transformer, defaults to 2048.
"""
def __init__(self, path, vocab_path, seq_len=2048) -> None:
super().__init__()
self.input_path = path
workers = 16
sentence_splitter = None
self.vocab_path = vocab_path
self.pad_id = EncDecTokenizer(os.path.join(self.vocab_path)).pad_id
self.length = seq_len
if self.input_path[-1] == '/':
self.input_path = self.input_path[:-1]
if os.path.exists(os.path.join(self.input_path, 'data_list.pt')):
self.data_path = torch.load(os.path.join(self.input_path, 'data_list.pt'))
return
fin_list = glob.glob(self.input_path + '/0[0-9][0-9].txt')
self.data_path = []
for fin_path in fin_list:
if not os.path.exists(fin_path):
continue
if '.txt' not in fin_path:
continue
all_data = []
print("Processing ", fin_path)
with open(fin_path, 'r', encoding='utf-8', errors='ignore') as fin:
encoder = Encoder(self.vocab_path, seq_len, sentence_splitter)
pool = multiprocessing.Pool(workers, initializer=encoder.initializer)
encoded_docs = pool.imap_unordered(encoder.encode, fin, 30)
for i, (no_noise_tokens, bytes_processed) in tqdm(enumerate(encoded_docs, start=1)):
if no_noise_tokens is None:
continue
all_data.append(no_noise_tokens)
pool.close()
print('Saving ', fin_path)
base_path = fin_path.replace('.txt', '')
if not os.path.exists(base_path):
os.mkdir(base_path)
idx = 0
for d in tqdm(all_data):
idx += 1
cur_path = os.path.join(base_path, str(idx) + '.txt')
with open(cur_path, 'w+', encoding='utf-8') as f:
for i in d:
f.write(str(i) + ' ')
f.write('\n')
self.data_path.append(cur_path.replace(self.input_path + '/', ''))
torch.save(self.data_path, os.path.join(self.input_path, 'data_list.pt'))
def __len__(self):
return len(self.data_path)
def __getitem__(self, index):
path = self.data_path[index]
root = os.path.join(self.input_path, path)
with open(root, "r") as f:
data = f.readlines()
assert len(data) == 1
data = data[0][:-2].split(' ')
try:
data = list(map(int, data))
except:
while '' in data:
data.remove('')
data = list(map(int, data))
if len(data) > self.length:
data = data[:self.length - 1] + [data[-1]]
mask = [1] * self.length
else:
data += [self.pad_id] * (self.length - len(data))
mask = [1] * len(data) + [0] * (self.length - len(data))
data = torch.tensor(data)
mask = torch.tensor(mask)
return {'input_ids': data, 'attention_mask': mask}, data
if __name__ == '__main__':
dataset = YuanDataset('/data/gpt-yuan/ASC22/dataset', vocab_path='/data/gpt-yuan/ASC22/vocab.txt', seq_len=2048)
test = dataset.__getitem__(0)
print(test)

View File

@ -1,31 +0,0 @@
from titans.loss.lm_loss import GPTLMLoss
from titans.model.gpt import gpt2_small
from torch.optim import Adam
from colossalai.amp import AMP_TYPE
BATCH_SIZE = 1
SEQ_LEN = 1024
NUM_EPOCHS = 60
TENSOR_PARALLEL = 2
optimizer = dict(
type=Adam,
lr=0.00015,
weight_decay=1e-2,
)
fp16 = dict(mode=AMP_TYPE.NAIVE)
loss = dict(type=GPTLMLoss,)
model = dict(
type=gpt2_small,
checkpoint=True,
)
parallel = dict(
pipeline=1,
tensor=dict(size=TENSOR_PARALLEL, mode='1d'),
)

View File

@ -1,30 +0,0 @@
from titans.loss.lm_loss import GPTLMLoss
from titans.model.gpt import gpt2_small
from torch.optim import Adam
from colossalai.amp import AMP_TYPE
BATCH_SIZE = 4
SEQ_LEN = 1024
NUM_EPOCHS = 60
TENSOR_PARALLEL = 4
optimizer = dict(
type=Adam,
lr=0.00015,
weight_decay=1e-2,
)
fp16 = dict(mode=AMP_TYPE.NAIVE)
loss = dict(type=GPTLMLoss,)
model = dict(
type=gpt2_small,
checkpoint=True,
)
parallel = dict(
pipeline=1,
tensor=dict(size=TENSOR_PARALLEL, mode='2d'),
)

View File

@ -1,31 +0,0 @@
from titans.loss.lm_loss import GPTLMLoss
from titans.model.gpt import gpt2_small
from torch.optim import Adam
from colossalai.amp import AMP_TYPE
BATCH_SIZE = 4
SEQ_LEN = 1024
NUM_EPOCHS = 60
TENSOR_PARALLEL = 8
DEPTH = 2
optimizer = dict(
type=Adam,
lr=0.00015,
weight_decay=1e-2,
)
fp16 = dict(mode=AMP_TYPE.NAIVE)
loss = dict(type=GPTLMLoss,)
model = dict(
type=gpt2_small,
checkpoint=True,
)
parallel = dict(
pipeline=1,
tensor=dict(size=TENSOR_PARALLEL, depth=DEPTH, mode='2.5d'),
)

View File

@ -1,30 +0,0 @@
from titans.loss.lm_loss import GPTLMLoss
from titans.model.gpt import gpt2_small
from torch.optim import Adam
from colossalai.amp import AMP_TYPE
BATCH_SIZE = 4
SEQ_LEN = 1024
NUM_EPOCHS = 60
TENSOR_PARALLEL = 8
optimizer = dict(
type=Adam,
lr=0.00015,
weight_decay=1e-2,
)
fp16 = dict(mode=AMP_TYPE.NAIVE)
loss = dict(type=GPTLMLoss,)
model = dict(
type=gpt2_small,
checkpoint=True,
)
parallel = dict(
pipeline=1,
tensor=dict(size=TENSOR_PARALLEL, mode='3d'),
)

View File

@ -1,33 +0,0 @@
from titans.loss.lm_loss import GPTLMLoss
from titans.model.gpt import gpt2_small
#from model_zoo.gpt.gpt import gpt2_small_pipeline
from torch.optim import Adam
from colossalai.amp import AMP_TYPE
BATCH_SIZE = 8
SEQ_LEN = 1024
NUM_EPOCHS = 60
HIDDEN_SIZE = 768
NUM_MICRO_BATCHES = 4
PIPELINE = 2
optimizer = dict(
type=Adam,
lr=0.00015,
weight_decay=1e-2,
)
fp16 = dict(mode=AMP_TYPE.NAIVE)
loss = dict(type=GPTLMLoss,)
model = dict(
type=gpt2_small,
checkpoint=True,
)
parallel = dict(
pipeline=PIPELINE,
tensor=dict(size=1, mode=None),
)

View File

@ -1,35 +0,0 @@
import torch
from titans.loss.lm_loss import GPTLMLoss
from titans.loss.vocab_cross_entropy import vocab_parallel_cross_entropy
from titans.model.gpt import gpt2_small
from torch.optim import Adam
from colossalai.amp import AMP_TYPE
BATCH_SIZE = 8
NUM_EPOCHS = 60
SEQ_LEN = 1024
NUM_MICRO_BATCHES = 4
HIDDEN_SIZE = 768
PIPELINE = 2
TENSOR_PARALLEL = 2
MODE = '1d'
fp16 = dict(mode=AMP_TYPE.NAIVE)
parallel = dict(pipeline=PIPELINE, tensor=dict(mode=MODE, size=TENSOR_PARALLEL))
optimizer = dict(
type=Adam,
lr=0.00015,
weight_decay=1e-2,
)
model = dict(
type=gpt2_small,
checkpoint=True,
dtype=torch.half,
)
loss_fn = dict(type=vocab_parallel_cross_entropy)

View File

@ -1,26 +0,0 @@
from titans.model.gpt import gpt2_small
from torch.optim import Adam
from colossalai.amp import AMP_TYPE
BATCH_SIZE = 1
NUM_EPOCHS = 60
SEQ_LEN = 1024
optimizer = dict(
type=Adam,
lr=0.00015,
weight_decay=1e-2,
)
fp16 = dict(mode=AMP_TYPE.NAIVE)
model = dict(
type=gpt2_small,
checkpoint=True,
)
parallel = dict(
pipeline=1,
tensor=dict(size=1, mode=None),
)

View File

@ -1,24 +0,0 @@
from titans.model.gpt import gpt2_small
from colossalai.nn.optimizer import HybridAdam
from colossalai.zero.shard_utils import TensorShardStrategy
BATCH_SIZE = 2
NUM_EPOCHS = 60
SEQ_LEN = 1024
zero = dict(model_config=dict(tensor_placement_policy='auto',
shard_strategy=TensorShardStrategy(),
reuse_fp16_shard=True),
optimizer_config=dict())
optimizer = dict(
type=HybridAdam,
lr=0.00015,
weight_decay=1e-2,
)
model = dict(
type=gpt2_small,
checkpoint=True,
)

View File

@ -1,26 +0,0 @@
from model import GPT2_small_pipeline_hybrid
from colossalai.nn.optimizer import HybridAdam
from colossalai.zero.shard_utils import BucketTensorShardStrategy, TensorShardStrategy
BATCH_SIZE = 8
NUM_EPOCHS = 60
SEQ_LEN = 1024
NUM_MICRO_BATCHES = 4
HIDDEN_SIZE = 768
TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LEN, HIDDEN_SIZE)
zero = dict(model_config=dict(tensor_placement_policy='cpu', shard_strategy=BucketTensorShardStrategy()),
optimizer_config=dict())
optimizer = dict(
type=HybridAdam,
lr=0.00015,
weight_decay=1e-2,
)
model = dict(type=GPT2_small_pipeline_hybrid, checkpoint=True, num_chunks=1)
parallel = dict(
pipeline=2,
tensor=dict(size=2, mode='1d'),
)

View File

@ -1,30 +0,0 @@
import torch
from titans.loss.vocab_cross_entropy import vocab_parallel_cross_entropy
from titans.model.gpt import gpt3
from torch.optim import Adam
from colossalai.amp import AMP_TYPE
BATCH_SIZE = 192
NUM_EPOCHS = 60
SEQ_LEN = 2048
NUM_MICRO_BATCHES = 192
TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LEN, 12288)
fp16 = dict(mode=AMP_TYPE.NAIVE)
parallel = dict(pipeline=32, tensor=dict(mode='1d', size=4))
optimizer = dict(
type=Adam,
lr=0.00015,
weight_decay=1e-2,
)
model = dict(
type=gpt3,
checkpoint=True,
dtype=torch.half,
)
loss_fn = dict(type=vocab_parallel_cross_entropy)

View File

@ -1,30 +0,0 @@
import torch
from titans.loss.vocab_cross_entropy import vocab_parallel_cross_entropy
from titans.model.gpt import gpt3
from torch.optim import Adam
from colossalai.amp import AMP_TYPE
BATCH_SIZE = 192
NUM_EPOCHS = 60
SEQ_LEN = 2048
NUM_MICRO_BATCHES = 192
TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LEN, 12288)
fp16 = dict(mode=AMP_TYPE.NAIVE)
parallel = dict(pipeline=24, tensor=dict(mode='1d', size=4))
optimizer = dict(
type=Adam,
lr=0.00015,
weight_decay=1e-2,
)
model = dict(
type=gpt3,
checkpoint=True,
dtype=torch.half,
)
loss_fn = dict(type=vocab_parallel_cross_entropy)

View File

@ -1,27 +0,0 @@
import torch
from titans.model.gpt import gpt3
from torch.optim import Adam
from colossalai.amp import AMP_TYPE
BATCH_SIZE = 2 * 48
NUM_EPOCHS = 60
SEQ_LEN = 2048
NUM_MICRO_BATCHES = 48
TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES // 2, SEQ_LEN, 12288 // 2)
fp16 = dict(mode=AMP_TYPE.NAIVE)
parallel = dict(pipeline=24, tensor=dict(mode='2d', size=4))
optimizer = dict(
type=Adam,
lr=0.00015,
weight_decay=1e-2,
)
model = dict(
type=gpt3,
checkpoint=True,
dtype=torch.half,
)

View File

@ -1,27 +0,0 @@
import torch
from titans.model.gpt import gpt3
from torch.optim import Adam
from colossalai.amp import AMP_TYPE
BATCH_SIZE = 2 * 48
NUM_EPOCHS = 60
SEQ_LEN = 2048
NUM_MICRO_BATCHES = 48
TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES // 2, SEQ_LEN, 12288 // 2)
fp16 = dict(mode=AMP_TYPE.NAIVE)
parallel = dict(pipeline=24, tensor=dict(mode='2.5d', depth=1, size=4))
optimizer = dict(
type=Adam,
lr=0.00015,
weight_decay=1e-2,
)
model = dict(
type=gpt3,
checkpoint=True,
dtype=torch.half,
)

View File

@ -0,0 +1,3 @@
colossalai >= 0.1.10
torch >= 1.8.1
transformers >= 4.23.1

View File

@ -1,7 +1 @@
export DATA=/data/scratch/gpt_data/small-gpt-dataset.json
export NODE_RANK=${NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
export MASTER_PORT=${MASTER_PORT:-"12345"}
env OMP_NUM_THREADS=16 torchrun --standalone --nproc_per_node=2 train_gpt.py --config=gpt2_configs/gpt2_zero3.py --from_torch 2>&1 | tee logs/log
env OMP_NUM_THREADS=16 torchrun --standalone --nproc_per_node=2 train_gpt_demo.py 2>&1 | tee run.log

File diff suppressed because it is too large

View File

@ -1,307 +0,0 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import re
import sys
import time
import tldextract
# List of the domains to blacklist.
domain_blacklist = set([
'500px',
'aapks',
'akamaihd',
'amazon',
'apple',
'artifactfire',
'artstation',
'awwni',
'bandcamp',
'battleforthenet',
'coinscalendar',
'dailymotion',
'deviantart',
'discord',
'discordapp',
'dlapkandroid',
'dropbox',
'e621',
'ebay',
'edealinfo',
'erome',
'eroshare',
'explosm',
'facebook',
'fbcdn',
'flickr',
'furaffinity',
'futhead',
'gatopardo',
'gfycat',
'gifsound',
'gifsoup',
'giphy',
'github',
'google',
'gunprime',
'gyazo',
'hotdealstar',
'imagefap',
'imageshack',
'imgflip',
'imgur',
'instagram',
'karmadecay',
'kryptocal',
'kym-cdn',
'liveleak',
'livememe',
'lmgtfy',
'magaimg',
'memegenerator',
'minorplanetcenter',
'minus',
'mobafire',
'morejpeg',
'nocookie',
'pcpartpicker',
'photobucket',
'pinimg',
'pinterest',
'pixiv',
'pornhub',
'prntscr',
'puu',
'qkme',
'quickmeme',
'radd',
'redd',
'reddit',
'reddit-stream',
'redditlog',
'redditmedia',
'reddituploads',
'redtube',
'reupp',
'reverb',
'roanoke',
'rollingstone',
'sli',
'soundcloud',
'soundgasm',
'spankbang',
'spotify',
'strawpoll',
'streamable',
'timeanddate',
'tinypic',
'touhouradio',
'tumblr',
'twimg',
'twitch',
'twitter',
'vid',
'vimeo',
'vine',
'vkaao',
'vocaroo',
'voyagefusion',
'walmart',
'wciu',
'wikimedia',
'wikipedia',
'xhamster',
'xkcd',
'xvideos',
'youtu',
'youtube',
'youtubedoubler',
'ytimg',
'zillexplorer',
])
def domain_is_in_blacklist(url):
domain = tldextract.extract(url).domain
return domain in domain_blacklist
# List of extentions to blacklist.
extentions_blacklist = (
'.3gp',
'.7z'
'.ai',
'.aif',
'.apk',
'.app',
'.avi',
'.bin',
'.bmp',
'.bz2',
'.css',
'.csv',
'.dat',
'.deb',
'.dmg',
'.doc',
'.docx',
'.exe',
'.gif',
'.gifv',
'.gz',
'.iso',
'.jar',
'.jpeg',
'.jpg',
'.js',
'.log',
'.mid',
'.midi',
'.mkv',
'.mov',
'.mp3',
'.mp4',
'.mpeg',
'.mpg',
'.ogg',
'.ogv',
'.otf',
'.pdf',
'.pkg',
'.png',
'.pps',
'.ppt',
'.pptx',
'.psd',
'.py',
'.qt',
'.ram',
'.rar',
'.sql',
'.svg',
'.swf',
'.tar.gz',
'.tar',
'.tgz',
'.tiff',
'.ttf',
'.txt',
'.wav',
'.webm',
'.wma',
'.wmv',
'.xls',
'.xlsx',
'.xml',
'.xz',
'.zip',
)
def extention_is_in_blacklist(url):
if url.split('?')[0].lower().endswith(extentions_blacklist):
return True
return False
# Malformed urls.
# This function is adapted from:
# https://stackoverflow.com/questions/7160737/python-how-to-validate-a-url-in-python-malformed-or-not
url_regex = re.compile(
r'^(?:http)s?://' # http:// or https://
r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' #domain...
r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip
r'(?::\d+)?' # optional port
r'(?:/?|[/?]\S+)$',
re.IGNORECASE)
def url_is_malformed(url):
return re.match(url_regex, url) is None
def print_progress(prefix, start_time, urls_counter, domain_blacklist_counter, extention_blacklist_counter,
short_url_counter, malformed_url_counter, duplicate_url_counter):
string = prefix + ' | '
string += 'time elapsed (s): {:.2f} | '.format(time.time() - start_time)
string += 'number of urls: {} | '.format(urls_counter)
string += 'domain blacklisted: {} | '.format(domain_blacklist_counter)
string += 'extention blacklisted: {} | '.format(extention_blacklist_counter)
string += 'short urls (<=8): {} | '.format(short_url_counter)
string += 'malformed urls: {} | '.format(malformed_url_counter)
string += 'duplicate urls: {}'.format(duplicate_url_counter)
print(string, flush=True)
if __name__ == '__main__':
print('remove blacklisted urls ..')
# Path to the url files.
path = sys.argv[1]
# Output url file.
output = sys.argv[2]
# Get the list of url files.
files = glob.glob(path + '/*.txt')
print('> found {} files'.format(len(files)))
urls = set()
urls_counter = 0
domain_blacklist_counter = 0
extention_blacklist_counter = 0
short_url_counter = 0
malformed_url_counter = 0
duplicate_url_counter = 0
start_time = time.time()
for filename in files:
with open(filename, 'r') as f:
for line in f:
url = line.strip()
urls_counter += 1
if domain_is_in_blacklist(url):
print('[DOMAIN BLACKLIST]: {}'.format(url), flush=True)
domain_blacklist_counter += 1
elif extention_is_in_blacklist(url):
print('[EXTENTION BLACKLIST]: {}'.format(url), flush=True)
extention_blacklist_counter += 1
elif len(url) <= 8:
print('[SHORT URL]: {}'.format(url), flush=True)
short_url_counter += 1
elif url_is_malformed(url):
print('[MALFORMED URL]: {}'.format(url), flush=True)
malformed_url_counter += 1
elif url in urls:
print('[DUPLICATE URL]: {}'.format(url), flush=True)
duplicate_url_counter += 1
else:
urls.add(url)
if urls_counter % 100000 == 0:
print_progress('PROGRESS', start_time, urls_counter, domain_blacklist_counter,
extention_blacklist_counter, short_url_counter, malformed_url_counter,
duplicate_url_counter)
print_progress('FINAL', start_time, urls_counter, domain_blacklist_counter, extention_blacklist_counter,
short_url_counter, malformed_url_counter, duplicate_url_counter)
# Write the final set of urls.
print('> writing cleaned up url list to {}'.format(output))
with open(output, 'w') as f:
for url in urls:
f.write(url + '\n')
print('done :-)')

View File

@ -1,107 +0,0 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import sys
import time
import ftfy
import numpy as np
from langdetect import detect
from tokenizer import Tokenizer
MIN_DOCUMENT_LENGTH = 128
def print_progress(prefix, start_time, num_docs, num_fixed_text, num_non_english_docs, chars_non_english_docs,
num_small_docs, chars_small_docs):
string = prefix + ' | '
string += 'elapsed time: {:.2f} | '.format(time.time() - start_time)
string += 'documents: {} | '.format(num_docs)
string += 'fixed text: {} | '.format(num_fixed_text)
string += 'non-english: {} | '.format(num_non_english_docs)
string += 'non-english chars: {} | '.format(chars_non_english_docs)
string += 'small docs: {} | '.format(num_small_docs)
string += 'small docs chars: {}'.format(chars_small_docs)
print(string, flush=True)
def filter_corpus(filename, out_filename, print_interval=10000):
print(' > filtering {}'.format(filename))
tokenizer = Tokenizer(cache_dir='./cache')
num_docs = 0
num_written_docs = 0
num_small_docs = 0
num_fixed_text = 0
num_non_english_docs = 0
chars_non_english_docs = 0
chars_small_docs = 0
start_time = time.time()
with open(out_filename, 'wb') as f:
with open(filename, 'r') as fin:
for line in fin:
try:
num_docs += 1
myjson = json.loads(line)
# Fix text
text = ftfy.fix_text(myjson['text'])
if text != myjson['text']:
num_fixed_text += 1
myjson['text'] = text
# Detect language.
if detect(text) != 'en':
print('[non-english text]', myjson)
num_non_english_docs += 1
chars_non_english_docs += len(text)
continue
# On average each token is 5 characters so 8 is an
# upper bound.
if len(text) < (8 * MIN_DOCUMENT_LENGTH):
tokens = tokenizer.tokenize_document(text)
if len(tokens) < MIN_DOCUMENT_LENGTH:
print('[small document, skipping]:', myjson)
num_small_docs += 1
chars_small_docs += len(text)
continue
myjson = json.dumps(myjson, ensure_ascii=False)
f.write(myjson.encode('utf-8'))
f.write('\n'.encode('utf-8'))
num_written_docs += 1
if num_docs % print_interval == 0:
print_progress('[PROGRESS]', start_time, num_docs, num_fixed_text, num_non_english_docs,
chars_non_english_docs, num_small_docs, chars_small_docs)
except Exception as e:
print(' skipping ', line, e)
print_progress('[FINAL]', start_time, num_docs, num_fixed_text, num_non_english_docs, chars_non_english_docs,
num_small_docs, chars_small_docs)
if __name__ == '__main__':
print('building gpt2 dataset ...')
input_filename = sys.argv[1]
output_filename = sys.argv[2]
print('will be reading {}'.format(input_filename))
print('and will write the results to {}'.format(output_filename))
filter_corpus(input_filename, output_filename)

View File

@ -1,191 +0,0 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Filter and clean documents:
Capable to clean docs with less than 512 characters, less than
256 characters and contains javascript, fix text and dataset specific
cleaning like stories and realnews datasets.
Program arguments have the details.
"""
import argparse
import glob
import json
import multiprocessing
import os
import re
import time
from functools import partial
from pathlib import Path
import ftfy
from langdetect import detect
def process_doc(json_line, args):
# Read the line.
document = json.loads(json_line)
text = document['text']
output = {'remove_512': False, 'remove_256_javascript': False, \
'remove_512_non_english': False, 'ftfy_fix_text': False, \
'general_cleaning': False}
try:
# Remove all docs with less than 512 characters
if "remove_512" in args.tasks:
if len(text) < 512:
output['remove_512'] = True
return output, text, document, True
# Remove docs if less than 256 character length and contains Javascript
if "remove_256_javascript" in args.tasks:
if len(text) < 256 and 'javascript' in text.lower():
output['remove_256_javascript'] = True
return output, text, document, True
# Remove docs < 512 and nonenglish
if "remove_512_non_english" in args.tasks:
if len(text) < 512 and detect(text) != 'en':
output['remove_512_non_english'] = True
return output, text, document, True
# Fix the text using ftfy, don't remove the text, hence return False
if "ftfy_fix_text" in args.tasks:
fixed_text = ftfy.fix_text(text)
output['ftfy_fix_text'] = True
return output, fixed_text, document, False
# Cleaning extra spaces and newlines
if "general_cleaning" in args.tasks:
cleaned_text = re.sub(r" +|\b\n+ |\b\n+", " ", text)
#cleaned_text = re.sub(r"\n\n+", "\n\n", text) # used this for Gutenberg dataset
#cleaned_text = re.sub(r"\n", "\n\n", text) # Used this for realnews
# stories datasets
#cleaned_text = re.sub(r" \'", "'", text)
#cleaned_text = re.sub(r" \!", "!", cleaned_text)
#cleaned_text = re.sub(r" \.", ".", cleaned_text)
#cleaned_text = re.sub(r" \?", "?", cleaned_text)
#cleaned_text = re.sub(r" - ", "-", cleaned_text)
##cleaned_text = re.sub(r"\" ", "\"", cleaned_text)
#cleaned_text = re.sub(r" @ ", "@", cleaned_text)
output['general_cleaning'] = True
return output, cleaned_text, document, False
except Exception as e:
print('Error: *************************\n{}\ntext: {}'.format(e, \
text), flush=True)
return output, text, document, True
# don't remove
return output, text, document, False
def process_set(args, input_file, output_f_cleaned, output_f_filtered):
print(' > working on {} ...'.format(input_file), flush=True)
num_docs = num_remove_512 = num_remove_java = num_remove_512_non_english \
= num_ftfy_fix_text = num_general_cleaning = 0
# Output file and counters.
output_cleaned = open(output_f_cleaned, 'wb')
output_filtered = open(output_f_filtered, 'wb')
start_time = time.time()
# Setup multi-processing.
num_workers = 40
fin = open(input_file, 'r', encoding='utf-8')
pool = multiprocessing.Pool(num_workers)
process_doc_partial = partial(process_doc, args=args)
processed_docs = pool.imap(process_doc_partial, fin, 500)
# Process documents.
for output, text, document, to_filter in processed_docs:
num_docs += 1
num_remove_512 += 1 if output['remove_512'] else 0
num_remove_java += 1 if output['remove_256_javascript'] else 0
num_remove_512_non_english += 1 if output['remove_512_non_english'] \
else 0
num_ftfy_fix_text += 1 if output['ftfy_fix_text'] else 0
num_general_cleaning += 1 if output['general_cleaning'] else 0
document['text'] = text
myjson = json.dumps(document, ensure_ascii=False)
if to_filter:
output_filtered.write(myjson.encode('utf-8'))
output_filtered.write('\n'.encode('utf-8'))
else:
output_cleaned.write(myjson.encode('utf-8'))
output_cleaned.write('\n'.encode('utf-8'))
if num_docs % args.log_interval == 0:
print(' processed {:9d} documents in {:.2f} seconds ...'.format(num_docs,
time.time() - start_time),
flush=True)
# Close the file.
output_cleaned.close()
output_filtered.close()
fin.close()
# Print stats.
print(' >> total docs: {} remove_512 {} remove_256_javascript {} '\
'remove_512_non_english {} ftfy_fix_text {} general_cleaning {}'.\
format(num_docs, num_remove_512, num_remove_java,\
num_remove_512_non_english, num_ftfy_fix_text, \
num_general_cleaning), flush=True)
if __name__ == '__main__':
print('parsing the arguments ...')
parser = argparse.ArgumentParser()
parser.add_argument('--input-files', nargs = '*', required=True, default=\
None, help = 'Input json files that needs to be'\
' cleaned')
parser.add_argument('--tasks', nargs = '*', required=True, default=None,\
help = 'Tasks to perform on the input files, ' \
'such as remove_512, remove_256_javascript, ' \
'remove_512_non_english, ftfy_fix_text, and ' \
'general_cleaning. 256 or 512 means the number' \
' of characters.')
parser.add_argument('--output-path', type=str, default=None, help='Directory where the output should go')
parser.add_argument('--log-interval', type=int, default=100, help='Log interval')
args = parser.parse_args()
print('cleanup dataset ...')
for input_file in args.input_files:
input_filename, input_filename_ext = os.path.splitext(Path(input_file)\
.name)
output_f_cleaned = os.path.join(args.output_path, input_filename + \
"_cleaned" + input_filename_ext)
output_f_filtered = os.path.join(args.output_path, input_filename + \
"_filtered" + input_filename_ext)
process_set(args, input_file, output_f_cleaned, output_f_filtered)
print('done :-)', flush=True)

View File

@ -1,314 +0,0 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import itertools
import json
import multiprocessing
import os
import pickle
import sys
import time
from functools import partial
import numpy as np
from lsh import cache, minhash
# This function is adapted from:
# https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
def shingles(text, char_ngram=5):
return set(text[head:head + char_ngram] for head in range(0, len(text) - char_ngram))
# This function is adapted from:
# https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
def jaccard(set_a, set_b, args):
if len(set_a) < 1 or len(set_b) < 1:
return 0.0
intersection = set_a & set_b
union = set_a | set_b
if args.jaccard == 'min':
return len(intersection) / min(len(set_a), len(set_b))
elif args.jaccard == 'max':
return len(intersection) / max(len(set_a), len(set_b))
else:
return len(intersection) / len(union)
def compute_fingerprint(line, key):
try:
myjson = json.loads(line)
url = myjson[key]
text = myjson['text']
fingerprint = hasher.fingerprint(text)
except Exception as e:
print('Error:', e)
return None, None, None, False
return url, text, fingerprint, True
def url_pairs_to_remove(args, bucket_urls, url_doc):
remove_urls_list = []
deduped_local, counter_local = 0, 0
iteration = 0
while len(bucket_urls) > 1:
if args.heuristic_iter != -1 and \
iteration == args.heuristic_iter:
break
items = list(bucket_urls)
remove_urls = []
main_url = items[np.random.randint(0, len(items))]
main_shingles = shingles(url_doc[main_url])
for i in range(0, len(items)):
counter_local += 1
other_url = items[i]
if other_url == main_url:
continue
other_shingles = shingles(url_doc[other_url])
try:
jaccard_sim = jaccard(main_shingles, other_shingles, args)
except Exception as e:
print('Error:', e)
jaccard_sim = 0.0
if jaccard_sim > 0.5:
remove_urls.append({other_url: jaccard_sim})
deduped_local += 1
bucket_urls.remove(other_url)
bucket_urls.remove(main_url)
if len(remove_urls) > 0:
remove_urls_list.append({main_url: remove_urls})
iteration += 1
return remove_urls_list, deduped_local, counter_local
def write_remove_urls_list(remove_urls_list, f_out):
if len(remove_urls_list) > 0:
for each_url_remove in remove_urls_list:
myjson = json.dumps(each_url_remove, ensure_ascii=False)
f_out.write(myjson.encode('utf-8'))
f_out.write('\n'.encode('utf-8'))
def compute_jaccard(each_bin, num_bins, start_time_local):
remove_urls_list = []
deduped_local, counter_local, bucket_local = 0, 0, 0
for bucket_id in each_bin:
bucket_local += 1
if os.getpid() % num_bins == 0 and bucket_local % 100000 == 0:
print("Counter {}, progress {:.2f} time {:.2f}".\
format(bucket_local, float(bucket_local)/float(len(each_bin)),\
time.time() - start_time_local), flush=True)
if len(each_bin[bucket_id]) <= 1:
continue
bucket_urls = each_bin[bucket_id].copy()
remove_urls_list_sub, deduped_local_sub, counter_local_sub = \
url_pairs_to_remove(args, bucket_urls, url_doc)
deduped_local += deduped_local_sub
counter_local += counter_local_sub
if len(remove_urls_list_sub) > 0:
remove_urls_list.extend(remove_urls_list_sub)
return remove_urls_list, deduped_local, counter_local
def find_pair_urls_parallel(args, lshcache, url_doc):
start_time = time.time()
f_out = open(args.output, 'wb')
deduped, counter = 0, 0
# compute jaccards of buckets in bin in parallel (parallelism
# limited to # of bins)
num_bins = len(lshcache.bins)
pool = multiprocessing.Pool(num_bins)
compute_jaccard_partial = partial(compute_jaccard, num_bins=num_bins, \
start_time_local=start_time)
# don't need to pass args and url_doc as they are already shared
compute_jaccard_iter = pool.imap(compute_jaccard_partial, lshcache.bins)
print("multiprocessing init took {:.2f}".format(time.time() - start_time),\
flush=True)
for remove_urls_list, deduped_local, counter_local in compute_jaccard_iter:
deduped += deduped_local
counter += counter_local
write_remove_urls_list(remove_urls_list, f_out)
print(' [write]> processed {} documents in {:.2f} '
'seconds and deduped {} documents ...'.format(counter, time.time()\
- start_time, deduped), flush=True)
pool.close()
pool.join()
f_out.close()
print(' Taken time for jaccard similarities {:.2f} seconds'.format(\
time.time() - start_time), flush=True)
def find_pair_urls_sequential(args, lshcache, url_doc):
start_time = time.time()
f_out = open(args.output, 'wb')
deduped, counter = 0, 0
for b in lshcache.bins:
for bucket_id in b:
if len(b[bucket_id]) <= 1:
continue
bucket_urls = b[bucket_id].copy()
remove_urls_list_sub, deduped_local_sub, counter_local_sub = \
url_pairs_to_remove(args, bucket_urls, url_doc)
deduped += deduped_local_sub
counter += counter_local_sub
write_remove_urls_list(remove_urls_list_sub, f_out)
if counter % 10000 == 0:
print(' [write]> processed {} documents in {:.2f} '
'seconds and deduped {} documents ...'.format(counter,
time.time() - start_time, deduped),
flush=True)
f_out.close()
print(' [write]> processed {} documents in {:.2f} '
'seconds and deduped {} documents ...'.format(counter,
time.time() - start_time, deduped),
flush=True)
if __name__ == '__main__':
print('parsing the arguments ...')
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=1234, help='Random seed used for python, numpy')
parser.add_argument('--inputs', nargs = '*', default=None, help = \
'Pairwise list of the input files and keys, '
'e.g. --inputs cc.json cc_id news.json news_id')
parser.add_argument('--load-fingerprints',
nargs='*',
default=None,
help='Load fingerprints from a list of pickle files,'
' e.g. cc.pkl news.pkl')
parser.add_argument('--save-fingerprints', type=str, default=None, help='Save the fingerprints of the inputs.')
parser.add_argument('--output',
type=str,
default=None,
help='Output file name that consists of all ids'
' with matching similarities')
parser.add_argument('--jaccard', type=str, default='union',
choices=['union', 'min', 'max'], help='Jaccard'\
' similarity computation')
parser.add_argument('--heuristic-iter',
type=int,
default=1,
help='Number of iterations to run the heuristics'
': use -1 for exact')
parser.add_argument('--num-bands', type=int, default=10, help='Number of bands to use in cache')
parser.add_argument('--num-seeds',
type=int,
default=100,
help='Number of seeds to use for minhash. Note that'
' this value should be divisible by num-bands')
parser.add_argument('--jaccard-parallel',
action='store_true',
help='Use this to process large number of documents.')
args = parser.parse_args()
print('finding possible duplicate content ...')
# set seed and get an array of seeds of 100 integers
np.random.seed(args.seed)
seeds = np.random.randint(0, 1e6, size=args.num_seeds)
# initialize minhash and lsh cache
hasher = minhash.MinHasher(seeds=seeds, char_ngram=5, hashbytes=4)
lshcache = cache.Cache(num_bands=args.num_bands, hasher=hasher)
url_doc = {}
# load fingerprints from pickle file if needed
if args.load_fingerprints is not None:
for count_fp, fp_file_name in enumerate(args.load_fingerprints):
print("Loading fingerprints from pickle file {}".format(fp_file_name), flush=True)
fp = open(fp_file_name, "rb")
if count_fp == 0:
# assign directory for the first pkl
lshcache = pickle.load(fp)
url_doc = pickle.load(fp)
else:
# append these to lshcache and url_doc
local_lshcache = pickle.load(fp)
local_url_doc = pickle.load(fp)
for url in local_lshcache.fingerprints.keys():
url_doc[url] = local_url_doc[url]
lshcache.add_fingerprint(local_lshcache.fingerprints[url], url)
fp.close()
counter = 0
start_time = time.time()
# compute finger prints of the inputs if any
# input file and the key to use as id
if args.inputs is not None:
print("Computing fingerprints", flush=True)
assert len(args.inputs) % 2 == 0
for input_file, key in zip(args.inputs[::2], args.inputs[1::2]):
print(' document processing {} with key {}'.format(input_file, key), flush=True)
# compute fingerprints in parallel
num_workers = 40
pool = multiprocessing.Pool(num_workers)
fin = open(input_file, 'r', encoding='utf-8')
compute_fingerprint_partial = partial(compute_fingerprint, key=key)
compute_fingerprint_iter = pool.imap(compute_fingerprint_partial, fin, 512)
# traverse all the texts and add fingerprints
for url, text, fingerprint, flag in compute_fingerprint_iter:
counter += 1
if flag:
url_doc[url] = text
lshcache.add_fingerprint(fingerprint, url)
if counter % 10000 == 0:
print(' [read]> processed {} documents in {:.2f} '
'seconds ...'.format(counter, time.time() - \
start_time), flush=True)
fin.close()
pool.close()
pool.join()
# Save the fingerprints if needed
if args.save_fingerprints is not None:
print("Saving fingerprints to pickle file {}".format(args.save_fingerprints), flush=True)
with open(args.save_fingerprints, 'wb') as f_save:
pickle.dump(lshcache, f_save)
pickle.dump(url_doc, f_save)
# compute jaccard index of the input texts and write to file if needed
if args.output is not None:
print("Compute jaccard similarity", flush=True)
if args.jaccard_parallel:
find_pair_urls_parallel(args, lshcache, url_doc)
else:
find_pair_urls_sequential(args, lshcache, url_doc)
print('done :-)')

View File

@ -1,305 +0,0 @@
# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT."""
from __future__ import absolute_import, division, print_function, unicode_literals
import json
import logging
import os
import sys
from io import open
import regex as re
try:
from functools import lru_cache
except ImportError:
# Just a dummy decorator to get the checks to run on python2
# because honestly I don't want to support a byte-level unicode BPE
# tokenizer on python 2 right now.
def lru_cache():
return lambda func: func
logger = logging.getLogger(__name__)
PRETRAINED_VOCAB_ARCHIVE_MAP = {
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
}
PRETRAINED_MERGES_ARCHIVE_MAP = {
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
}
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
'gpt2': 1024,
}
VOCAB_NAME = 'vocab.json'
MERGES_NAME = 'merges.txt'
SPECIAL_TOKENS_NAME = 'special_tokens.txt'
@lru_cache()
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
_chr = unichr if sys.version_info[0] == 2 else chr
bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + \
list(range(ord("®"), ord("ÿ") + 1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [_chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
class GPT2Tokenizer(object):
"""
GPT-2 BPE tokenizer. Peculiarities:
- Byte-level BPE
"""
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
"""
Instantiate a PreTrainedBertModel from a pre-trained model file.
Download and cache the pre-trained model file if needed.
"""
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
special_tokens_file = None
else:
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
if not os.path.exists(special_tokens_file):
special_tokens_file = None
else:
logger.info("loading special tokens file {}".format(special_tokens_file))
# redirect to the cache, if necessary
try:
from cached_path import cached_path
resolved_vocab_file = cached_path(vocab_file)
resolved_merges_file = cached_path(merges_file)
except EnvironmentError:
logger.error("Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find files {} and {} "
"at this path or url.".format(pretrained_model_name_or_path,
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
pretrained_model_name_or_path, vocab_file, merges_file))
return None
if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
logger.info("loading vocabulary file {}".format(vocab_file))
logger.info("loading merges file {}".format(merges_file))
else:
logger.info("loading vocabulary file {} from cache at {}".format(vocab_file, resolved_vocab_file))
logger.info("loading merges file {} from cache at {}".format(merges_file, resolved_merges_file))
if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
# if we're using a pretrained model, ensure the tokenizer wont index sequences longer
# than the number of positional embeddings
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
# Instantiate tokenizer.
if special_tokens_file and 'special_tokens' not in kwargs:
special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
else:
special_tokens = kwargs.pop('special_tokens', [])
tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
return tokenizer
def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None):
self.max_len = max_len if max_len is not None else int(1e12)
self.encoder = json.load(open(vocab_file))
self.decoder = {v: k for k, v in self.encoder.items()}
self.errors = errors # how to handle errors in decoding
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
bpe_merges = [tuple(merge.split()) for merge in bpe_data]
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {}
# Should have added re.IGNORECASE so BPE merges can happen for
# capitalized versions of contractions
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
self.special_tokens = {}
self.special_tokens_decoder = {}
self.set_special_tokens(special_tokens)
def __len__(self):
return len(self.encoder) + len(self.special_tokens)
def set_special_tokens(self, special_tokens):
""" Add a list of additional tokens to the encoder.
The additional tokens are indexed starting from the last index of the
current vocabulary in the order of the `special_tokens` list.
"""
if not special_tokens:
self.special_tokens = {}
self.special_tokens_decoder = {}
return
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}
logger.info("Special tokens {}".format(self.special_tokens))
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token)
pairs = get_pairs(word)
if not pairs:
return token
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except ValueError:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
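# Worked example with a hypothetical merge table: if self.bpe_ranks ranks
# ('l', 'o') lowest among the pairs present and also contains ('lo', 'w'),
# then bpe('low') rewrites ('l', 'o', 'w') -> ('lo', 'w') -> ('low',) and
# returns 'low'; if ('lo', 'w') were not in the table, the loop would stop
# after the first merge and return 'lo w'.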
def tokenize(self, text):
""" Tokenize a string. """
bpe_tokens = []
for token in re.findall(self.pat, text):
if sys.version_info[0] == 2:
token = ''.join(self.byte_encoder[ord(b)] for b in token)
else:
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
def convert_tokens_to_ids(self, tokens):
""" Converts a sequence of tokens into ids using the vocab. """
ids = []
if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
if tokens in self.special_tokens:
return self.special_tokens[tokens]
else:
return self.encoder.get(tokens, 0)
for token in tokens:
if token in self.special_tokens:
ids.append(self.special_tokens[token])
else:
ids.append(self.encoder.get(token, 0))
if len(ids) > self.max_len:
logger.warning("Token indices sequence length is longer than the specified maximum "
" sequence length for this OpenAI GPT model ({} > {}). Running this"
" sequence through the model will result in indexing errors".format(len(ids), self.max_len))
return ids
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
"""Converts a sequence of ids in BPE tokens using the vocab."""
tokens = []
for i in ids:
if i in self.special_tokens_decoder:
if not skip_special_tokens:
tokens.append(self.special_tokens_decoder[i])
else:
tokens.append(self.decoder[i])
return tokens
def encode(self, text):
return self.convert_tokens_to_ids(self.tokenize(text))
def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
return text
def save_vocabulary(self, vocab_path):
"""Save the tokenizer vocabulary and merge files to a directory."""
if not os.path.isdir(vocab_path):
logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
return
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
merge_file = os.path.join(vocab_path, MERGES_NAME)
special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
with open(vocab_file, 'w', encoding='utf-8') as f:
f.write(json.dumps(self.encoder, ensure_ascii=False))
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
writer.write(u'#version: 0.2\n')
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(merge_file))
index = token_index
writer.write(' '.join(bpe_tokens) + u'\n')
index += 1
index = len(self.encoder)
with open(special_tokens_file, 'w', encoding='utf-8') as writer:
for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(special_tokens_file))
index = token_index
writer.write(token + u'\n')
index += 1
return vocab_file, merge_file, special_tokens_file
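# Minimal usage sketch (assumes the standard GPT-2 vocab.json / merges.txt are
# available, either locally or via the cached 'gpt2' download; the save path
# below is just an illustration):
#   tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
#   ids = tokenizer.encode("Hello world")      # tokenize + convert_tokens_to_ids
#   text = tokenizer.decode(ids)               # byte-level BPE round-trips exactly
#   tokenizer.save_vocabulary('./gpt2_vocab')  # writes vocab.json, merges.txt, special_tokens.txt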


@ -1,85 +0,0 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import sys
import time
if __name__ == '__main__':
print('grouping duplicate urls ...')
input = sys.argv[1]
output = sys.argv[2]
if len(sys.argv) > 3:
jaccard_similarity_threshold = float(sys.argv[3])
else:
jaccard_similarity_threshold = 0.7
url_to_index = {}
index_to_urls = []
counter = 0
start_time = time.time()
with open(input, 'r') as f:
for line in f:
counter += 1
myjson = json.loads(line)
urls = []
for main_url in myjson.keys():
urls.append(main_url)
for value in myjson[main_url]:
for other_url, js in value.items():
if js >= jaccard_similarity_threshold:
urls.append(other_url)
current_index = -1
other_indices = set()
for url in urls:
if url in url_to_index:
if current_index == -1:
current_index = url_to_index[url]
elif current_index != url_to_index[url]:
other_indices.add(url_to_index[url])
if current_index == -1:
current_index = len(index_to_urls)
index_to_urls.append(set())
for url in urls:
url_to_index[url] = current_index
index_to_urls[current_index].add(url)
for index in other_indices:
for url in index_to_urls[index]:
index_to_urls[current_index].add(url)
url_to_index[url] = current_index
index_to_urls[index] = None
if counter % 100000 == 0:
print(' > processed {} lines in {} seconds ...'.format(counter, time.time() - start_time))
total_remove = 0
total_remain = 0
for urls in index_to_urls:
if urls is not None:
if len(urls) > 1:
total_remove += (len(urls) - 1)
total_remain += 1
print('out of {} urls, only {} are unique and {} should be removed'.format(total_remove + total_remain,
total_remain, total_remove))
with open(output, 'wb') as f:
for i, urls in enumerate(index_to_urls):
if urls is not None:
if len(urls) > 1:
myjson = json.dumps({str(i): list(urls)}, ensure_ascii=False)
f.write(myjson.encode('utf-8'))
f.write('\n'.encode('utf-8'))


@ -1,64 +0,0 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import sys
import time
if __name__ == '__main__':
url_filename = sys.argv[1]
data_filename = sys.argv[2]
output_filename = sys.argv[3]
urls = set()
with open(url_filename, 'r') as f:
for line in f:
myjson = json.loads(line)
for key in myjson:
this_urls = myjson[key]
for i in range(1, len(this_urls)):
urls.add(this_urls[i])
print('will be removing {} urls'.format(len(urls)), flush=True)
written_docs = 0
removed_docs = 0
removed_chars = 0
start_time = time.time()
with open(output_filename, 'wb') as fout:
with open(data_filename, 'r') as fin:
for line in fin:
try:
myjson = json.loads(line)
url = myjson['url']
if url in urls:
print('removing', myjson)
removed_docs += 1
removed_chars += len(myjson['text'])
continue
myjson = json.dumps(myjson, ensure_ascii=False)
fout.write(myjson.encode('utf-8'))
fout.write('\n'.encode('utf-8'))
written_docs += 1
if written_docs % 10000 == 0:
print(' [PROCESSED] time (s): {:.2f} | written: {} '
'| removed: {} (char: {})'.format(time.time() - start_time, written_docs, removed_docs,
removed_chars))
except Exception as e:
print('[SKIPPING]', line, e)
print(' [PROCESSED] time (s): {:.2f} | written: {} '
'| removed: {} (char: {})'.format(time.time() - start_time, written_docs, removed_docs, removed_chars))
print('done :-)')


@ -1,36 +0,0 @@
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append('..')
from gpt2_tokenization import GPT2Tokenizer
class Tokenizer:
def __init__(self, cache_dir=None):
self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2', cache_dir=cache_dir)
self.tokenizer.max_len = int(1e12)
self.eod_token = self.tokenizer.encoder['<|endoftext|>']
assert self.eod_token < 65535, 'vocab size will not fit in uint16'
print('> GPT2 tokenizer with {} vocab size and eod token {} ...'.format(len(self.tokenizer.encoder),
self.eod_token))
def tokenize_document(self, document):
tokens = self.tokenizer.encode(document)
tokens.append(self.eod_token)
return tokens


@ -1,347 +0,0 @@
# Code taken in large part from https://github.com/jcpeterson/openwebtext
from __future__ import print_function
import argparse
import io
import json
import multiprocessing as mpl
import os
import os.path as op
import sqlite3
import tarfile
import time
import warnings
from glob import glob
from hashlib import sha256
import tldextract
from scrapers import bs4_scraper, newspaper_scraper, raw_scraper
# for backward compatibility
from six.moves.urllib.request import urlopen
from tqdm import tqdm
from utils import chunks, extract_month, linecount, mkdir
parser = argparse.ArgumentParser()
parser.add_argument("url_file", type=str)
parser.add_argument(
"--save_uncompressed",
action="store_true",
default=False,
help="whether to save the raw txt files to disk",
)
parser.add_argument(
"--output",
type=str,
default='raw.json',
help="where to save the output json",
)
parser.add_argument(
"--output_dir",
type=str,
default="scraped",
help="which folder in the working directory to use for output",
)
parser.add_argument(
"--n_procs",
type=int,
default=10,
help="how many processes (cores) to use for parallel scraping",
)
parser.add_argument(
"--timeout",
type=int,
default=-1,
help="maximum scrape time for a single URL; -1 means no limit",
)
parser.add_argument(
"--max_urls",
type=int,
default=-1,
help="maximum # of URLs to scrape; mostly for debugging",
)
parser.add_argument(
"--chunk_size",
type=int,
default=100,
help="how many URLs to scrape before saving to archive",
)
parser.add_argument(
"--scraper",
type=str,
default="newspaper",
choices=["raw", "bs4", "newspaper"],
help="which text/content scraper to use; raw is html",
)
parser.add_argument(
"--compress",
action="store_true",
default=False,
help="whether to output scraped content as compressed archives",
)
parser.add_argument(
"--compress_fmt",
type=str,
default="xz",
choices=["xz", "bz2", "gz"],
help="which archive format to use",
)
parser.add_argument(
"--scraper_memoize",
action="store_true",
default=False,
help="whether to use cache for newspaper",
)
parser.add_argument(
"--show_warnings",
action="store_true",
default=False,
help="whether to show warnings in general during scraping",
)
parser.add_argument(
"--sqlite_meta",
action="store_true",
default=True,
help="whether to use sqlite for storing meta. if false, json will be used instead",
)
args = parser.parse_args()
if not args.show_warnings:
# avoid lots of datetime warnings
warnings.filterwarnings("ignore")
def load_urls(fh, max_urls=-1):
url_entries = enumerate(fh)
if max_urls != -1:
url_entries = list(url_entries)[:max_urls]
return url_entries
def vet_link(link):
# check if server responds with non-200 status code or link points to a
# non-html file
link_type, link_status = "", -1
try:
info = urlopen(link)
link_type = info.headers["Content-Type"]
link_status = info.status
except:
pass
# we want "text/html" only!
is_good_link = False
if "text/html" in link_type and link_status == 200:
is_good_link = True
return is_good_link, link_type
def download(url_entry,
scraper=args.scraper,
save_uncompressed=args.save_uncompressed,
memoize=args.scraper_memoize,
arch_meta=not args.sqlite_meta):
uid, url = url_entry
url = url.strip()
fid = "{:07d}-{}".format(uid, sha256(url.encode()).hexdigest())
data_dir = mkdir(op.join(args.output_dir, "data"))
text_fp = op.join(data_dir, "{}.txt".format(fid))
if arch_meta:
meta_dir = mkdir(op.join(args.output_dir, "meta"))
meta_fp = op.join(meta_dir, "{}.json".format(fid))
# already downloaded!
if op.exists(text_fp):
return
# is_good_link, link_type = vet_link(url)
# if not is_good_link:
# return
if scraper == "bs4":
scrape = bs4_scraper
elif scraper == "newspaper":
scrape = newspaper_scraper
elif scraper == "raw":
scrape = raw_scraper
text, meta = scrape(url, memoize)
ext = tldextract.extract(url)
domain = '.'.join([x for x in ext if x])
meta["domain"] = domain
if text is None or text.strip() == "":
return ("", meta, fid, uid)
if save_uncompressed:
with open(text_fp, "w") as out:
out.write(text)
if arch_meta:
with open(meta_fp, "w") as out:
json.dump(meta, out)
return (text, meta, fid, uid)
def archive_chunk(cid, cdata, out_dir, fmt, arch_meta):
mkdir(out_dir)
texts, metas, fids, uids = zip(*cdata)
data_tar = op.join(out_dir, "{}_data.{}".format(cid, fmt))
if arch_meta:
meta_tar = op.join(out_dir, "{}_meta.{}".format(cid, fmt))
tar_fps, texts, exts = [data_tar, meta_tar], [texts, metas], ["txt", "json"]
else:
tar_fps, texts, exts = [data_tar], [texts], ["txt"]
doc_count = 0
docs_counted = False
for tar_fp, txts, ext in zip(tar_fps, texts, exts):
with tarfile.open(tar_fp, "w:" + fmt) as tar:
for f, fid in zip(txts, fids):
if f == "":
continue
else:
if not docs_counted:
doc_count += 1
if ext == "json":
f = json.dumps(f)
f = f.encode("utf-8")
t = tarfile.TarInfo("{}.{}".format(fid, ext))
t.size = len(f)
tar.addfile(t, io.BytesIO(f))
docs_counted = True
return doc_count
def load_state(url_file):
ckptfile = url_file + '.ckpt'
if op.exists(ckptfile):
with open(ckptfile) as fp:
r = fp.read()
if r == '':
return 0
else:
return int(r)
else:
return 0
def save_state(url_file, cid):
ckptfile = url_file + '.ckpt'
with open(ckptfile, 'w') as fp:
fp.write(str(cid))
def sqlite_conn():
conn = sqlite3.connect('metadata.db')
conn.execute('''
CREATE TABLE IF NOT EXISTS metadata (
fid char(64) not null primary key,
url varchar(2048) not null,
domain varchar(255) not null,
word_count int null,
elapsed int null,
scraper varchar(255) not null,
success boolean not null
);
''')
conn.execute('''
CREATE INDEX IF NOT EXISTS ix_meta_url ON metadata(url);
''')
conn.execute('''
CREATE INDEX IF NOT EXISTS ix_meta_domain ON metadata(domain);
''')
return conn
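# Example (sketch): once a scrape has finished, success rates can be inspected
# straight from metadata.db, e.g.
#   conn = sqlite_conn()
#   print(conn.execute('SELECT success, COUNT(*) FROM metadata GROUP BY success').fetchall())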
if __name__ == "__main__":
if args.sqlite_meta:
conn = sqlite_conn()
cur = conn.cursor()
start_elem = load_state(args.url_file)
start_chnk = start_elem // args.chunk_size
f_json = open(args.output, "w")
# URLs we haven't scraped yet (if first run, all URLs in file)
with open(args.url_file) as fh:
url_entries = load_urls(fh, args.max_urls)
pool = mpl.Pool(args.n_procs)
total = linecount(args.url_file) // args.chunk_size
print('Total chunks: ', total)
chunk_iterator = tqdm(enumerate(chunks(url_entries, args.chunk_size, start_elem)), total=total)
# display already-downloaded chunks on progress bar
chunk_iterator.update(start_chnk)
# process one "chunk" of args.chunk_size URLs at a time
for i, chunk in chunk_iterator:
cid = start_chnk + i + 1
tqdm.write("Downloading chunk {}".format(cid))
t1 = time.time()
if args.timeout > 0:
# imap as iterator allows .next() w/ timeout.
# ordered version doesn't seem to work correctly.
# for some reason, you CANNOT track j or chunk[j] in the loop,
# so don't add anything else to the loop below!
# confusingly, chunksize below is unrelated to our chunk_size
chunk_iter = pool.imap_unordered(download, chunk, chunksize=1)
cdata = []
for j in range(len(chunk)):
try:
result = chunk_iter.next(timeout=args.timeout)
cdata.append(result)
except mpl.TimeoutError:
tqdm.write(" --- Timeout Error --- ")
else:
cdata = list(pool.imap(download, chunk, chunksize=1))
tqdm.write("{} / {} downloads timed out".format(len(chunk) - len(cdata), len(chunk)))
tqdm.write("Chunk time: {} seconds".format(time.time() - t1))
# write metadata to sqlite
if args.sqlite_meta:
for text, meta, fid, _ in filter(lambda x: x, cdata):
if text:
params = (fid, meta["url"], meta["domain"], meta["elapsed"], meta["word_count"],
meta["scraper"], True)
else:
params = (fid, meta["url"], meta["domain"], None, None, meta["scraper"], False)
cur.execute(
"insert or ignore into metadata (fid, url, domain, elapsed, word_count, scraper, success) values (?, ?, ?, ?, ?, ?, ?)",
params)
conn.commit()
dump_chunk = []
for text, meta, fid, _ in filter(lambda x: x, cdata):
if text:
line_json = {"text": text, "url": meta["url"]}
dump_chunk.append(json.dumps(line_json) + '\n')
f_json.writelines(dump_chunk)
# archive and save this chunk to file
if args.compress:
tqdm.write("Compressing...")
t2 = time.time()
count = archive_chunk(cid, cdata, args.output_dir, args.compress_fmt, not args.sqlite_meta)
tqdm.write("Archive created in {} seconds".format(time.time() - t2))
tqdm.write("{} out of {} URLs yielded content\n".format(len(list(filter(lambda x: x and x[0], cdata))),
len(chunk)))
save_state(args.url_file, cid * args.chunk_size)
f_json.close()
print("Done!")


@ -1,58 +0,0 @@
import hashlib
import multiprocessing as mp
import os
import traceback
import newspaper
import tldextract
import tqdm
from filter import should_exclude
hash = hashlib.sha256
try:
os.mkdir('data')
except FileExistsError:
pass
def dl(url):
url = url.strip()
if should_exclude(url):
return
ext = tldextract.extract(url)
domain = '.'.join([x for x in ext if x])
fname = 'data/{}-{}.txt'.format(domain, hash(url.encode()).hexdigest())
if os.path.isfile(fname):
return
# print('Downloading', url)
try:
article = newspaper.Article(url, fetch_images=False)
article.download()
article.parse()
except newspaper.article.ArticleException:
# print('Dead link:', url)
return
# traceback.print_exc()
text = article.text
if text.strip() == '':
# print('Empty')
return
with open(fname, 'w') as out:
out.write(text)
if __name__ == '__main__':
p = mp.Pool(100) # num of download threads
with open('urls.txt') as fh:
urls = list(fh)
list(tqdm.tqdm(p.imap(dl, urls), total=len(urls)))
print('Done!')


@ -1,110 +0,0 @@
import re
import tldextract
import tqdm
from utils import linecount
# https://stackoverflow.com/questions/7160737/python-how-to-validate-a-url-in-python-malformed-or-not
url_regex = re.compile(
r'^(?:http)s?://' # http:// or https://
r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' #domain...
r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip
r'(?::\d+)?' # optional port
r'(?:/?|[/?]\S+)$',
re.IGNORECASE)
# domains that aren't scraper friendly. do not include subdomains!
exclude_domains = set([
# image & video hosting sites
'imgur.com',
'redd.it',
'instagram.com',
'discord.gg',
'gfycat.com',
'giphy.com',
'reddituploads.com',
'redditmedia.com',
'twimg.com',
'sli.mg',
'magaimg.net',
'flickr.com',
'imgflip.com',
'youtube.com',
'youtu.be',
'youtubedoubler.com',
'vimeo.com',
'twitch.tv',
'streamable.com',
'bandcamp.com',
'soundcloud.com',
# not scraper friendly
'reddit.com',
'gyazo.com',
'github.com',
'xkcd.com',
'twitter.com',
'spotify.com',
'itunes.apple.com',
'facebook.com',
'gunprime.com',
'strawpoll.me',
'voyagefusion.com',
'rollingstone.com',
'google.com',
'timeanddate.com',
'walmart.com',
'roanoke.com',
'spotrac.com',
# original paper excluded wikipedia
'wikipedia.org',
# lots of top posts for this one
'battleforthenet.com',
])
exclude_extensions = ('.png', '.jpg', '.jpeg', '.gif', '.gifv', '.pdf', '.mp4', '.mp3', '.ogv', '.webm', '.doc',
'.docx', '.log', '.csv', '.dat', '.iso', '.bin', '.exe', '.apk', '.jar', '.app', '.ppt', '.pps',
'.pptx', '.xml', '.gz', '.xz', '.bz2', '.tgz', '.tar', '.zip', '.wma', '.mov', '.wmv', '.3gp',
'.svg', '.rar', '.wav', '.avi', '.7z')
def should_exclude(url):
ext = tldextract.extract(url)
domain = '.'.join([x for x in ext if x])
basedomain = '.'.join(ext[-2:])
# Ignore non-URLs
if len(url) <= 8 or ' ' in url or re.match(url_regex, url) is None:
return True
# Ignore excluded domains
if basedomain in exclude_domains or domain in exclude_domains:
return True
# Ignore case-insensitive matches for excluded extensions
if url.lower().split('?')[0].endswith(exclude_extensions):
return True
return False
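# Examples of the filtering behaviour defined above:
# >>> should_exclude('https://imgur.com/abc123')        # blacklisted domain
# True
# >>> should_exclude('https://example.com/photo.jpg')   # blacklisted extension
# True
# >>> should_exclude('https://www.example.org/articles/2019/01/story.html')
# False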
if __name__ == '__main__':
url_file = 'urls.txt'
filtered_file = 'urls-filtered.txt'
with open(url_file) as urls, open(filtered_file, 'w') as out:
url_len = linecount(url_file)
print("URL file is", url_len, "URLs long.")
url_set = set()
for line in tqdm.tqdm(urls, total=url_len):
if len(line.strip()) == 0:
continue # Skip whitespace-only lines
line = line.strip().split()[0] # Drop any components following whitespace
if should_exclude(line):
continue
url_set.add(line)
for line in tqdm.tqdm(url_set):
out.write(line + '\n')


@ -1,32 +0,0 @@
import datetime
import praw
import psaw
import tqdm
api = psaw.PushshiftAPI()
# all posts until the end of 2017
end_time = int(datetime.datetime(2018, 1, 1).timestamp())
query = api.search_submissions(before=end_time,
filter=['url', 'score'],
sort='desc',
score='>2',
is_self=False,
over_18=False)
with tqdm.tqdm() as pbar:
# download links from submissions
with open('urls.txt', 'w') as fh:
for subm in query:
url = subm.url
# weird issue with psaw/pushshift that breaks score=">2"
if subm.score < 3:
continue
#print(subm.score)
# pbar.write(str(datetime.datetime.fromtimestamp(subm.created_utc)))
pbar.update(1)
fh.write(url + '\n')
fh.flush()


@ -1,121 +0,0 @@
# Code taken in large part from https://github.com/jcpeterson/openwebtext
import time
import unicodedata
import bs4
import newspaper
from filter import should_exclude
from htmlmin import minify
from lxml.html.clean import Cleaner
def find_and_filter_tag(tag, soup):
"""tag specific filter logic"""
candidates = soup.find_all(tag)
candidates = [unicodedata.normalize("NFKD", x.string) for x in candidates if x.string is not None]
if tag == "p":
candidates = [y.strip() for y in candidates if len(y.split(" ")) >= 4]
count = sum(len(y.split(" ")) for y in candidates)
else:
raise NotImplementedError
return (candidates, count)
def raw_scraper(url, memoize):
t1 = time.time()
if should_exclude(url):
# heuristic to make downloading faster
return None, {
"url": url,
"scraper": "raw",
}
try:
cleaner = Cleaner()
cleaner.javascript = True
cleaner.style = True
article = newspaper.Article(url, fetch_images=False, memoize_articles=memoize)
article.download()
html = minify(article.html)
html = cleaner.clean_html(html)
article.parse()
except:
return None, {
"url": url,
"scraper": "raw",
}
if article.text == "":
return None, {
"url": url,
"scraper": "raw",
}
metadata = {"url": url, "elapsed": time.time() - t1, "scraper": "raw"}
return html, metadata
def newspaper_scraper(url, memoize):
t1 = time.time()
if should_exclude(url):
# heuristic to make downloading faster
return None, {
"url": url,
"scraper": "newspaper",
}
try:
article = newspaper.Article(url, fetch_images=False, memoize_articles=memoize)
article.download()
article.parse()
text = article.text
count = len(text.split())
except:
return None, {
"url": url,
"scraper": "newspaper",
}
metadata = {
"url": url,
"word_count": count,
"elapsed": time.time() - t1,
"scraper": "newspaper",
}
return text, metadata
def bs4_scraper(url, memoize):
t1 = time.time()
if should_exclude(url):
# heuristic to make downloading faster
return None, {
"url": url,
"scraper": "bs4",
}
try:
article = newspaper.Article(url, fetch_images=False, memoize_articles=memoize)
article.download()
html = article.html
soup = bs4.BeautifulSoup(html, "lxml")
text, count = find_and_filter_tag("p", soup)
# DDB: keep text as a single string for consistency with
# newspaper_scraper
text = " ".join(text)
except:
return None, {
"url": url,
"scraper": "bs4",
}
metadata = {
"url": url,
"word_count": count,
"elapsed": time.time() - t1,
"scraper": "bs4",
}
return text, metadata


@ -1,62 +0,0 @@
# Code taken in large part from https://github.com/jcpeterson/openwebtext
import collections.abc
import os
import os.path as op
import re
import tarfile
def extract_month(url_file_name):
month_re = r"(RS_.*2\d{3}-\d{2})"
month = op.split(url_file_name)[-1]
month = re.match(month_re, month).group()
return month
def chunks(l, n, s=0):
"""Yield successive n-sized chunks from l, skipping the first s chunks."""
if isinstance(l, collections.abc.Iterable):
chnk = []
for i, elem in enumerate(l):
if i < s:
continue
chnk.append(elem)
if len(chnk) == n:
yield chnk
chnk = []
if len(chnk) != 0:
yield chnk
else:
for i in range(s, len(l), n):
yield l[i:i + n]
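# For example:
# >>> list(chunks(list(range(7)), 3))
# [[0, 1, 2], [3, 4, 5], [6]]
# >>> list(chunks(list(range(7)), 3, s=1))    # resume after the first element
# [[1, 2, 3], [4, 5, 6]]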
def extract_archive(archive_fp, outdir="."):
with tarfile.open(archive_fp, "r") as tar:
tar.extractall(outdir)
return outdir
def mkdir(fp):
try:
os.makedirs(fp)
except FileExistsError:
pass
return fp
def linecount(filename):
f = open(filename, 'rb')
lines = 0
buf_size = 1024 * 1024
read_f = f.raw.read
buf = read_f(buf_size)
while buf:
lines += buf.count(b'\n')
buf = read_f(buf_size)
return lines


@ -1,143 +0,0 @@
import contextlib
import os
import torch
from dataset.webtext import WebtextDataset
from titans.loss.lm_loss import GPTLMLoss
import colossalai
import colossalai.utils as utils
from colossalai import nn as col_nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn import LinearWarmupLR
from colossalai.pipeline.pipelinable import PipelinableContext
from colossalai.trainer import Trainer, hooks
from colossalai.utils import is_using_pp
from colossalai.utils.timer import MultiTimer
from colossalai.zero.init_ctx import ZeroInitContext
def calc_local_model_size(model: torch.nn.Module):
numel_per_device = 0
for p in model.parameters():
numel_per_device += p.numel()
return numel_per_device
def main():
parser = colossalai.get_default_parser()
parser.add_argument('--from_torch', default=False, action='store_true')
args = parser.parse_args()
disable_existing_loggers()
if args.from_torch:
colossalai.launch_from_torch(config=args.config)
else:
colossalai.launch_from_slurm(config=args.config, host=args.host, port=29500, seed=42)
logger = get_dist_logger()
logger.info('Build data loader', ranks=[0])
train_ds = WebtextDataset(os.environ['DATA'], seq_len=gpc.config.SEQ_LEN)
train_dataloader = utils.get_dataloader(train_ds,
seed=42,
batch_size=gpc.config.BATCH_SIZE,
pin_memory=True,
shuffle=True,
drop_last=True)
logger.info('Build model', ranks=[0])
use_pipeline = is_using_pp()
use_interleaved = hasattr(gpc.config.model, 'num_chunks')
num_chunks = getattr(gpc.config.model, 'num_chunks', 1)
use_zero3 = hasattr(gpc.config, 'zero')
if not use_pipeline:
ctx = contextlib.nullcontext()
if use_zero3:
ctx = ZeroInitContext(target_device=torch.cuda.current_device(),
shard_strategy=gpc.config.zero.model_config.shard_strategy,
shard_param=True)
with ctx:
model = gpc.config.model.pop('type')(**gpc.config.model)
else:
pipelinable = PipelinableContext()
with pipelinable:
model = gpc.config.model.pop('type')(**gpc.config.model)
def mask_function(attention_mask=None):
if attention_mask is not None:
batch_size = gpc.config.BATCH_SIZE // gpc.config.NUM_MICRO_BATCHES
attention_mask = attention_mask.view(batch_size, -1)
attention_mask = col_nn.partition_batch(attention_mask)
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
attention_mask = (1.0 - attention_mask) * -10000.0
return attention_mask
# GPT2_small exec_seq
# (lyl)TODO: The exec_seq for gpt3 will be added here, and to_layer_list should be made easier to use.
exec_seq = ['embed', mask_function, 'blocks.0', 'blocks.1', 'blocks.2', 'blocks.3', 'blocks.4', 'blocks.5', (mask_function, "front"), \
'blocks.6', 'blocks.7', 'blocks.8', 'blocks.9', 'blocks.10', 'blocks.11', 'norm', 'head']
pipelinable.to_layer_list(exec_seq)
ctx = contextlib.nullcontext()
# (lyl)TODO: Zero context and pipelinable context should be integrated into one context.
if use_zero3:
ctx = ZeroInitContext(target_device=torch.cuda.current_device(),
shard_strategy=gpc.config.zero.model_config.shard_strategy,
shard_param=True)
with ctx:
model = pipelinable.partition(num_chunks, gpc.pipeline_parallel_size,
gpc.get_local_rank(ParallelMode.PIPELINE))
if use_zero3:
numel = ctx.model_numel_tensor.item()
else:
numel = calc_local_model_size(model)
tflop = numel * gpc.config.BATCH_SIZE * gpc.config.SEQ_LEN \
* gpc.get_world_size(ParallelMode.MODEL) * gpc.get_world_size(ParallelMode.DATA) * 8 / (1024 ** 4)
criterion = getattr(gpc.config, 'loss_fn', None)
if criterion is not None:
criterion = criterion.type()
else:
criterion = GPTLMLoss()
logger.info('Build optimizer', ranks=[0])
optimizer = gpc.config.optimizer.pop('type')(model.parameters(), **gpc.config.optimizer)
lr_scheduler = LinearWarmupLR(optimizer, total_steps=gpc.config.NUM_EPOCHS, warmup_steps=5)
engine, train_dataloader, _, lr_scheduler = colossalai.initialize(model,
optimizer,
criterion,
train_dataloader=train_dataloader,
lr_scheduler=lr_scheduler)
global_batch_size = gpc.config.BATCH_SIZE * \
gpc.get_world_size(ParallelMode.DATA) * getattr(gpc.config, "gradient_accumulation", 1)
logger.info(f'Init done, global batch size = {global_batch_size}', ranks=[0])
timer = MultiTimer()
trainer = Trainer(engine=engine, logger=logger, timer=timer)
hook_list = [
hooks.LossHook(),
hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True),
hooks.LogMetricByEpochHook(logger),
hooks.ThroughputHook(ignored_steps=10, tflop_per_step=tflop),
hooks.LogMetricByStepHook(),
hooks.LogMemoryByEpochHook(logger),
]
trainer.fit(train_dataloader=train_dataloader,
epochs=gpc.config.NUM_EPOCHS,
test_interval=1,
hooks=hook_list,
display_progress=True,
return_output_label=False)
if __name__ == '__main__':
main()


@ -0,0 +1,161 @@
from functools import partial
from time import time
import psutil
import torch
import torch.nn as nn
from packaging import version
import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ProcessGroup
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer
from transformers import GPT2Config, GPT2LMHeadModel
class GPTLMModel(nn.Module):
def __init__(self,
hidden_size=768,
num_layers=12,
num_attention_heads=12,
max_seq_len=1024,
vocab_size=50257,
checkpoint=False):
super().__init__()
self.checkpoint = checkpoint
self.model = GPT2LMHeadModel(
GPT2Config(n_embd=hidden_size,
n_layer=num_layers,
n_head=num_attention_heads,
n_positions=max_seq_len,
n_ctx=max_seq_len,
vocab_size=vocab_size))
if checkpoint:
self.model.gradient_checkpointing_enable()
def forward(self, input_ids, attention_mask):
# Only return lm_logits
return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]
class GPTLMLoss(nn.Module):
def __init__(self):
super().__init__()
self.loss_fn = nn.CrossEntropyLoss()
def forward(self, logits, labels):
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
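# For labels [t0, t1, t2, t3] this compares the logits at positions 0..2 with
# targets t1..t3, i.e. standard next-token prediction; in this script the
# input ids themselves are passed in as the labels.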
def get_data(batch_size, seq_len, vocab_size):
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
attention_mask = torch.ones_like(input_ids)
return input_ids, attention_mask
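# Both tensors have shape (batch_size, seq_len); the attention mask is all
# ones because the ids are sampled uniformly at random rather than padded text.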
def gpt2_medium(checkpoint=False):
return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint)
def gpt2_xl(checkpoint=True):
return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32, checkpoint=checkpoint)
def gpt2_10b(checkpoint=True):
return GPTLMModel(hidden_size=4096, num_layers=50, num_attention_heads=16, checkpoint=checkpoint)
def get_cpu_mem():
return psutil.Process().memory_info().rss / 1024**2
def get_gpu_mem():
return torch.cuda.memory_allocated() / 1024**2
def get_mem_info(prefix=''):
return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB'
def get_tflops(model_numel, batch_size, seq_len, step_time):
return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
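# The factor 8 is the usual per-token FLOP estimate: ~2 * numel for the forward
# pass, ~4 * numel for the backward pass, plus ~2 * numel for the extra forward
# recomputation when activation checkpointing is enabled.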
def main():
BATCH_SIZE = 8
SEQ_LEN = 1024
VOCAB_SIZE = 50257
NUM_STEPS = 10
PLACEMENT_POLICY = 'auto'
disable_existing_loggers()
colossalai.launch_from_torch(config={})
pg = ProcessGroup()
logger = get_dist_logger()
logger.info(get_mem_info(), ranks=[0])
# build GPT model
with ColoInitContext(device=get_current_device()):
model = gpt2_medium(checkpoint=True)
numel = sum([p.numel() for p in model.parameters()])
logger.info(f'Model numel: {numel}', ranks=[0])
get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)
cai_version = colossalai.__version__
logger.info(f'using Colossal-AI version {cai_version}')
if version.parse(cai_version) > version.parse("0.1.10"):
from colossalai.nn.parallel import GeminiDDP
model = GeminiDDP(model,
device=get_current_device(),
placement_policy=PLACEMENT_POLICY,
pin_memory=True,
search_range_mb=32)
elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
from colossalai.gemini import ChunkManager, GeminiManager
chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
chunk_manager = ChunkManager(chunk_size,
pg,
enable_distributed_storage=True,
init_device=GeminiManager.get_default_device(PLACEMENT_POLICY))
gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager)
model = ZeroDDP(model, gemini_manager)
logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
# build criterion
criterion = GPTLMLoss()
# optimizer
optimizer = HybridAdam(model.parameters(), lr=1e-3)
optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
model.train()
for n in range(NUM_STEPS):
# we just use randomly generated data here
input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
optimizer.zero_grad()
start = time()
outputs = model(input_ids, attn_mask)
loss = criterion(outputs, input_ids)
logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Forward '), ranks=[0])
optimizer.backward(loss)
logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Backward '), ranks=[0])
optimizer.step()
logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Optimizer step '), ranks=[0])
step_time = time() - start
logger.info(
f'[{n+1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}',
ranks=[0])
if __name__ == '__main__':
main()
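# Typical launch (sketch; the script name and GPU count depend on your setup,
# e.g. whatever run.sh wraps):
#   torchrun --nproc_per_node=2 train_gpt_demo.py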


@ -22,6 +22,9 @@ The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI)
We use the pre-trained weights of the OPT model provided by the Hugging Face Hub and train on the raw WikiText-2 dataset (no tokens were replaced before
the tokenization). This training script is adapted from the [HuggingFace Language Modelling examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling).
## Our Modifications
We adapt the OPT training code to ColossalAI by leveraging Gemini and ZeRO DDP.
## Quick Start
You can launch training with the following bash script: