mirror of https://github.com/hpcaitech/ColossalAI
[example] simplify the GPT2 huggingface example (#1826)
parent
cd5a0d56fa
commit
b1263d32ba
|
@ -1,242 +1,16 @@
|
|||
# Run GPT With Colossal-AI
|
||||
|
||||
## Overview
|
||||
This example shows how to use ColossalAI to run huggingface GPT training in a distributed manner.
|
||||
|
||||
In Colossal-AI, there are many ways to run GPT in a distributed manner. The `train_gpt.py` script runs training with the specified configuration scripts in `gpt2_configs/` for different parallelisms of GPT-2. We provide some example configuration files for GPT-2, and you can modify them to suit your own use case.
|
||||
## GPT
|
||||
We use the huggingface transformers GPT2 model. The input data is randomly generated.
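As a rough sketch of what this looks like, the snippet below feeds randomly generated token ids into a Hugging Face GPT-2 model; the batch size and sequence length here are placeholder values, not the ones used by `train_gpt.py`.

```Python
import torch
from transformers import GPT2Config, GPT2LMHeadModel

# Placeholder sizes; the real example takes these from the config files in gpt2_configs/.
batch_size, seq_len = 2, 1024

config = GPT2Config()            # GPT-2 small sized model
model = GPT2LMHeadModel(config)

# Random token ids stand in for real text; labels == inputs for language modeling.
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
outputs = model(input_ids=input_ids, labels=input_ids)
print(outputs.loss)
```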
|
||||
|
||||
## How to Prepare Webtext Dataset
|
||||
## Our Modifications
|
||||
We adapt the OPT training code to ColossalAI by leveraging Gemini and ZeRO DDP.
|
||||
|
||||
We do not host any datasets for GPT or BERT training; however, we provide a detailed guide on how to prepare the dataset so that our results can be reproduced.
|
||||
|
||||
### Overview
|
||||
|
||||
We utilize the publicly available [OpenWebText](https://github.com/eukaryote31/openwebtext) work by [jcpeterson](https://github.com/jcpeterson/openwebtext) and [eukaryote31](https://github.com/eukaryote31/openwebtext) to download the urls of different web pages. We then filter, clean, and deduplicate all downloaded content according to the procedure described in the following section.
|
||||
|
||||
### Install necessary packages
|
||||
|
||||
**Note: LSH requires an early version of GCC. We have tested that version 9.3.0 works, but version 10.3.0 does not.**
|
||||
## Quick Start
|
||||
You can launch training using the following bash script:
|
||||
|
||||
```bash
|
||||
pip install ftfy langdetect numpy torch pandas nltk sentencepiece boto3 tqdm regex bs4 newspaper3k htmlmin tldextract cached-path
|
||||
git clone https://github.com/mattilyra/LSH.git
|
||||
cd LSH
|
||||
python setup.py install
|
||||
pip install -r requirements.txt
|
||||
bash run.sh
|
||||
```
|
||||
|
||||
If the installation fails, you can try replacing the `cMinhash.cpp` in `LSH/lsh` with ours, which is provided in `tools/lsh/cMinhash.cpp`.
|
||||
|
||||
### Download Data
|
||||
|
||||
1. Download the deduplicated URLs from [jcpeterson](https://mega.nz/#F!EZZD0YwJ!9_PlEQzdMVLaNdKv_ICNVQ!cc4RgQQZ).
|
||||
|
||||
2. Unzip the zip file and you will get a folder `URLs` that contains many txt files of urls.
|
||||
|
||||
3. Remove blacklisted URLs.
|
||||
|
||||
*We appreciate Megatron-LM for making the data preprocessing code public. We have forked Megatron-LM and fixed some bugs. For your convenience, we have collated the needed files in `tools/Megatron`. Click [here](https://github.com/NVIDIA/Megatron-LM.git) to check the source code of Megatron-LM.*
|
||||
|
||||
```bash
|
||||
cd path/to/tools
|
||||
python Megatron/blacklist_urls.py <path/to/URLs> <path/to/clean_urls.txt>
|
||||
```
|
||||
|
||||
4. Download the content from the clean urls and merge the contents into one loose JSON file, with one JSON object per line in the format `{'text': text, 'url': unique_url}` (a minimal reading sketch follows the command below).
|
||||
|
||||
*We have forked and modified [openwebtext](https://github.com/yet-another-account/openwebtext) as there are some bugs in it. For your convenience, we provide our modified version in `tools/download`.*
|
||||
|
||||
```bash
|
||||
python download/download.py <path/to/clean_urls.txt> --n_procs 50 --output <path/to/raw.json>
|
||||
```
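For reference, a minimal sketch of reading this loose-JSON format is shown below; the file name is a placeholder and each line is expected to hold one JSON object.

```Python
import json

# Placeholder path; one JSON object per line, as produced by download.py.
with open('/path/to/raw.json') as f:
    for line in f:
        doc = json.loads(line)
        text, url = doc['text'], doc['url']
        # ... filter, clean, and deduplicate the document in the next steps ...
```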
|
||||
|
||||
### Prepare Data for GPT Training
|
||||
|
||||
1. Perform ftfy text fixing and English detection, and remove documents with fewer than 128 tokens. This step can be sharded and run on shards.
|
||||
|
||||
```bash
|
||||
python Megatron/cleanup_dataset.py <path/to/raw.json> <path/to/clean.json>
|
||||
```
|
||||
|
||||
Additional cleanup (e.g. removing documents shorter than 512 characters, or dataset-specific cleaning for the stories and realnews datasets) can be done using `cleanup_fix_dataset.py`. More details can be found by running `python cleanup_fix_dataset.py --help`.
|
||||
|
||||
2. Using LSH, find possible duplicates and store them in a file for later processing. The code supports saving and loading fingerprints for recurrent deduplications, and is also multithreaded for faster processing. More details can be found by running `python find_duplicates.py --help`. A short sketch of the underlying similarity measure follows the command below.
|
||||
|
||||
```bash
|
||||
python Megatron/find_duplicates.py --inputs <path/to/clean.json> url --output <path/to/process_stage_one.json>
|
||||
```
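The deduplication is based on the Jaccard similarity of character shingles; the sketch below mirrors the `shingles` and `jaccard` helpers used by `find_duplicates.py` and is only meant to illustrate the measure.

```Python
def shingles(text, char_ngram=5):
    # Overlapping character n-grams of the document text.
    return set(text[head:head + char_ngram] for head in range(0, len(text) - char_ngram))


def jaccard(set_a, set_b):
    # Ratio of shared shingles to all shingles; 1.0 means identical shingle sets.
    if len(set_a) < 1 or len(set_b) < 1:
        return 0.0
    return len(set_a & set_b) / len(set_a | set_b)


sim = jaccard(shingles('the quick brown fox jumps'), shingles('the quick brown dog jumps'))
print(sim)  # documents scoring above the threshold are treated as duplicates
```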
|
||||
|
||||
3. Based on the similarity measure defined inside the function `is_similar` (default threshold: 0.9), group urls that are similar. For each group, keep only one url and remove the rest.
|
||||
|
||||
```bash
|
||||
python Megatron/group_duplicate_url.py <path/to/process_stage_one.json> <path/to/process_stage_two.json>
|
||||
```
|
||||
|
||||
4. Remove the similar documents that were detected in the last step. The resulting `dedup.json` contains the deduplicated data.
|
||||
|
||||
```bash
|
||||
python Megatron/remove_group_duplicates.py <path/to/process_stage_two.json> <path/to/clean.json> <path/to/dedup.json>
|
||||
```
|
||||
|
||||
5. Shuffle the dataset.
|
||||
|
||||
```bash
|
||||
shuf <path/to/dedup.json> -o <path/to/train_data.json>
|
||||
```
|
||||
|
||||
## How to Prepare Yuan Dataset
|
||||
|
||||
### Overview
|
||||
|
||||
The Yuan dataset is a large-scale Chinese dataset with 1 TB of high-quality text released by Inspur. You can apply at https://air.inspur.com/home to get access to the dataset. Download and load the data according to the procedure described in the following section.
|
||||
|
||||
### Download
|
||||
|
||||
The dataset can be downloaded from the website once your application is approved.
|
||||
|
||||
You also need to download the vocab file from https://github.com/Shawn-Inspur/Yuan-1.0/blob/main/src/vocab.txt
|
||||
|
||||
The final data dir should be organized as:
|
||||
|
||||
```
|
||||
|--dataset
|
||||
| |--001.txt
|
||||
| |--002.txt
|
||||
| |--...
|
||||
|--vocab.txt
|
||||
```
|
||||
|
||||
### Process & Load
|
||||
|
||||
Before you run the code, you should replace line 44 in `train_gpt.py` with:
|
||||
|
||||
```Python
|
||||
from dataset.yuan import YuanDataset
|
||||
train_ds = YuanDataset(os.environ['DATA'], vocab_path='/path/to/data/vocab.txt', seq_len=gpc.config.SEQ_LEN)
|
||||
```
|
||||
|
||||
Then you can run the model following the Usage section. The dataset will be processed and cached the first time you run it; afterwards, the data is loaded from the cache automatically.
|
||||
|
||||
## **Usage**
|
||||
|
||||
```Bash
|
||||
#!/usr/bin/env sh
|
||||
export DATA=/path/to/train_data.json
|
||||
|
||||
colossalai run --nproc_per_node=<num_gpus> train_gpt.py --config=gpt2_configs/<config_file>
|
||||
```
|
||||
|
||||
You can copy it and save it as `run.sh`. Then use `bash ./run.sh` to run the script in your terminal.
|
||||
|
||||
Please modify `DATA`, `num_gpus` and `config_file` with the path to your dataset, the number of GPUs and the config file path, respectively.
|
||||
If you are going to train GPT-3, just replace `gpt2_configs` with `gpt3_configs`.
|
||||
|
||||
## GPT-2
|
||||
|
||||
Here are the default parameters of the GPT-2 configs:
|
||||
|
||||
| config | scale | GPU* | batch size | memory per GPU (MiB) | TP | PP | DP |
|
||||
| ------------ | ----- | ---- | ----------- | --------------- | --- | --- | --- |
|
||||
| gpt2-vanilla | small | 1 | 1 | 6071 | 1 | 1 | 1 |
|
||||
| gpt2-vanilla | small | 2 | 1 | 6449*2 | 1 | 1 | 2 |
|
||||
| gpt2-1d | small | 2 | 1 | 5287*2 | 2 | 1 | 1 |
|
||||
| gpt2-2d | small | 4 | 1 | 4590*4 | 4 | 1 | 1 |
|
||||
| gpt-2.5d | small | 8 | 1 | 4815*8 | 8 | 1 | 1 |
|
||||
| gpt2-3d | small | 8 | 1 | 4901*8 | 8 | 1 | 1 |
|
||||
| gpt2-pp | small | 2 | 1 | 5877*2 | 1 | 2 | 1 |
|
||||
| gpt2-zero2 | small | 1 | 1 | 5459 | 1 | 1 | 1 |
|
||||
| gpt2-zero3 | small | 1 | 1 | 6577 | 1 | 1 | 1 |
|
||||
| gpt2-nvme | small | 1 | 1 | 5067 | 1 | 1 | 1 |
|
||||
| gpt2-pp1d | small | 8 | 8 | 5411*8 | 2 | 2 | 2 |
|
||||
|
||||
*\*Note: For GPUs, we use Nvidia A100 80G.*
|
||||
*\*Note: Results of ZeRO are outdated, we will update them soon.*
|
||||
|
||||
**We set** `TENSOR_PARALLEL`, `PIPELINE_PARALLEL` **and** `DATA_PARALLEL` **as small as possible so that every demo can run with the least number of GPUs.**
|
||||
|
||||
### **Modify the config file**
|
||||
|
||||
#### **General**
|
||||
|
||||
There are some **general rules** when modifying the config files.
|
||||
|
||||
```text
|
||||
TP denotes Tensor Parallel
|
||||
PP denotes Pipeline Parallel
|
||||
DP denotes Data Parallel
|
||||
|
||||
GPUS = TP * PP * DP
|
||||
where DP is set automatically
|
||||
```
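In other words, once you pick `TENSOR_PARALLEL` and `PIPELINE`, the data parallel size follows from the number of GPUs. A small illustrative check (with placeholder values):

```Python
GPUS = 8
TENSOR_PARALLEL = 2    # TP
PIPELINE = 2           # PP

# DP is derived automatically; the product of the three must equal the GPU count.
assert GPUS % (TENSOR_PARALLEL * PIPELINE) == 0
DATA_PARALLEL = GPUS // (TENSOR_PARALLEL * PIPELINE)
print(DATA_PARALLEL)   # 2
```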
|
||||
|
||||
You can set the **batch size** and the number of **epochs** by changing the values of
|
||||
`BATCH_SIZE` and `NUM_EPOCHS`, respectively. Then, we will introduce the config file of each mode.
|
||||
|
||||
Please note that in `gpt2_zero3.py`, only `BATCH_SIZE` and `NUM_EPOCHS` can be changed.
|
||||
|
||||
#### **Vanilla & Data Parallel**
|
||||
|
||||
`Vanilla` is the basic mode of GPT-2 with no parallelism at all. However, if you use more than 1 GPU and TP * PP is less than the number of GPUs, Colossal-AI will **set DP for you automatically**.
|
||||
|
||||
#### **1D, 2D, 2.5D, 3D**
|
||||
|
||||
In files `gpt2_1d.py, gpt2_2d.py, gpt2_2p5d.py, gpt2_3d.py`, there is a line:
|
||||
|
||||
```Python
|
||||
TENSOR_PARALLEL = 2
|
||||
```
|
||||
|
||||
You can increase the tensor parallel size, as long as the general rules are satisfied.
|
||||
In particular, `TENSOR_PARALLEL` should be a square number for 2D and a cubic number for 3D,
|
||||
respectively, and `TENSOR_PARALLEL / DEPTH` should be a square number for 2.5D.
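An illustrative check of these constraints (the values below are placeholders):

```Python
import math


def tensor_parallel_is_valid(size, mode, depth=1):
    # 2D: size must be a perfect square; 3D: a perfect cube;
    # 2.5D: size / depth must be a perfect square; 1D has no extra constraint.
    if mode == '2d':
        return round(math.sqrt(size)) ** 2 == size
    if mode == '3d':
        return round(size ** (1 / 3)) ** 3 == size
    if mode == '2.5d':
        quotient = size / depth
        return quotient.is_integer() and round(math.sqrt(quotient)) ** 2 == quotient
    return True


print(tensor_parallel_is_valid(4, '2d'))             # True
print(tensor_parallel_is_valid(8, '3d'))             # True
print(tensor_parallel_is_valid(8, '2.5d', depth=2))  # True
```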
|
||||
|
||||
#### **Pipeline Parallel**
|
||||
|
||||
To use pipeline parallel training, you should install colossalai from the **latest** main branch.
|
||||
|
||||
In `gpt2_pp.py`, there are lines:
|
||||
|
||||
```Python
|
||||
# BATCH_SIZE / NUM_MICRO_BATCHES should be an integer
|
||||
NUM_MICRO_BATCHES = 1
|
||||
PIPELINE = 2
|
||||
```
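The only extra requirement is that the batch splits evenly into micro-batches, i.e. each pipeline stage processes `BATCH_SIZE // NUM_MICRO_BATCHES` samples at a time (placeholder values below):

```Python
BATCH_SIZE = 8
NUM_MICRO_BATCHES = 4

assert BATCH_SIZE % NUM_MICRO_BATCHES == 0, 'BATCH_SIZE / NUM_MICRO_BATCHES must be an integer'
micro_batch_size = BATCH_SIZE // NUM_MICRO_BATCHES  # samples per micro-batch: 2
```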
|
||||
|
||||
#### **Pipeline + 1D + Data Parallel**
|
||||
|
||||
In `gpt2_pp1d.py`, we have
|
||||
|
||||
```Python
|
||||
BATCH_SIZE = 8
|
||||
NUM_EPOCHS = 60
|
||||
NUM_MICRO_BATCHES = 1
|
||||
HIDDEN_SIZE = 768
|
||||
PIPELINE = 2
|
||||
TENSOR_PARALLEL = 2
|
||||
MODE = '1d'
|
||||
TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LEN, HIDDEN_SIZE)
|
||||
```
|
||||
|
||||
We have introduced `BATCH_SIZE`, `NUM_EPOCHS`, `NUM_MICRO_BATCHES`, `PIPELINE`, `TENSOR_PARALLEL` as discussed above.
|
||||
`HIDDEN_SIZE` refers to the hidden dimension of the model; for `gpt2_small` it is 768.
|
||||
You can choose `None, '1d', '2d', '2.5d', '3d'` for `MODE`.
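As a quick worked example with the values above (and `SEQ_LEN = 1024` as in the other GPT-2 configs), the hidden-state tensor exchanged between pipeline stages has shape `(BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LEN, HIDDEN_SIZE)`:

```Python
BATCH_SIZE, NUM_MICRO_BATCHES, SEQ_LEN, HIDDEN_SIZE = 8, 1, 1024, 768

TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LEN, HIDDEN_SIZE)
print(TENSOR_SHAPE)  # (8, 1024, 768)
```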
|
||||
|
||||
## GPT-3
|
||||
|
||||
GPT-3 is a huge model that cannot be trained with a small number of GPUs. Therefore, we choose some common sets of parameters instead of the smallest possible ones.
|
||||
|
||||
Here are the default parameters of the GPT-3 configs:
|
||||
|
||||
| config | GPU* | batch size | TP | PP | DP |
|
||||
| -------------- | ---- | ---------- | --- | --- | --- |
|
||||
| gpt3_pp1d_min | 96 | 192 | 4 | 24 | 1 |
|
||||
| gpt3_pp1d | 128 | 192 | 4 | 32 | 1 |
|
||||
| gpt3_pp2d | 96 | 2*48 | 4 | 24 | 1 |
|
||||
| gpt3_pp2p5d | 96 | 2*48 | 4 | 24 | 1 |
|
||||
| gpt3_zero3_min | 64 | 3 | 1 | 1 | 64 |
|
||||
| gpt3_zero3 | 96 | 2 | 1 | 1 | 96 |
|
||||
|
||||
*\*Note: we use Nvidia A100 40G GPUs*
|
||||
*\*Note: Results of ZeRO are outdated, we will update them soon.*
|
||||
|
||||
In the table above, the suffix `_min` indicates the set of hyper-parameters that requires the fewest GPUs for the same mode.
|
||||
|
||||
GPT-3 and GPT-2 have the same set of hyper-parameters.
|
||||
|
|
|
@ -1,39 +0,0 @@
|
|||
import json
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from colossalai.registry import DATASETS
|
||||
from transformers import GPT2Tokenizer
|
||||
|
||||
|
||||
@DATASETS.register_module
|
||||
class WebtextDataset(Dataset):
|
||||
|
||||
def __init__(self, path, seq_len=1024) -> None:
|
||||
super().__init__()
|
||||
root = os.path.dirname(path)
|
||||
encoded_data_cache_path = os.path.join(root, f'gpt_webtext_{seq_len}.pt')
|
||||
if os.path.isfile(encoded_data_cache_path):
|
||||
seq_len_, data, attention_mask = torch.load(encoded_data_cache_path)
|
||||
if seq_len_ == seq_len:
|
||||
self.data = data
|
||||
self.attention_mask = attention_mask
|
||||
return
|
||||
raw_data = []
|
||||
with open(path) as f:
|
||||
for line in f.readlines():
|
||||
raw_data.append(json.loads(line)['text'])
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
tokenizer.pad_token = tokenizer.unk_token
|
||||
encoded_data = tokenizer(raw_data, padding=True, truncation=True, max_length=seq_len, return_tensors='pt')
|
||||
self.data = encoded_data['input_ids']
|
||||
self.attention_mask = encoded_data['attention_mask']
|
||||
torch.save((seq_len, self.data, self.attention_mask), encoded_data_cache_path)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return {'input_ids': self.data[index], 'attention_mask': self.attention_mask[index]}, self.data[index]
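# Illustrative usage sketch: assuming a JSON-lines file produced by the webtext
# pipeline described in the README, the dataset can be wrapped in a standard
# PyTorch DataLoader. The path below is a placeholder.
if __name__ == '__main__':
    from torch.utils.data import DataLoader
    dataset = WebtextDataset('/path/to/train_data.json', seq_len=1024)
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    inputs, labels = next(iter(loader))
    print(inputs['input_ids'].shape, labels.shape)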
|
|
@ -1,329 +0,0 @@
|
|||
import collections
|
||||
import glob
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import sys
|
||||
|
||||
import jieba
|
||||
import six
|
||||
import torch
|
||||
from tools.tokenization_enc_dec import EncDecTokenizer
|
||||
from torch.utils.data import Dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from colossalai.registry import DATASETS
|
||||
|
||||
try:
|
||||
import nltk
|
||||
|
||||
nltk_available = True
|
||||
except ImportError:
|
||||
nltk_available = False
|
||||
|
||||
jieba.setLogLevel(logging.INFO)
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
|
||||
def is_contain_chinese(check_str):
|
||||
for ch in check_str:
|
||||
if u'\u4e00' <= ch <= u'\u9fff':
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def convert_to_unicode(text):
|
||||
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
|
||||
if six.PY3:
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
elif isinstance(text, bytes):
|
||||
return text.decode("utf-8", "ignore")
|
||||
else:
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
else:
|
||||
raise ValueError("Should be running on Python 3")
|
||||
|
||||
|
||||
class WordpieceTokenizer(object):
|
||||
|
||||
def __init__(self, vocab, unk_token="<unk>", max_input_chars_per_word=200):
|
||||
self.vocab = vocab
|
||||
self.unk_token = unk_token
|
||||
self.max_input_chars_per_word = max_input_chars_per_word
|
||||
|
||||
def tokenize(self, token):
|
||||
|
||||
token = convert_to_unicode(token)
|
||||
|
||||
chars = list(token)
|
||||
if len(chars) > self.max_input_chars_per_word:
|
||||
return [self.unk_token]
|
||||
|
||||
start = 0
|
||||
sub_tokens = []
|
||||
while start < len(chars):
|
||||
end = len(chars)
|
||||
cur_substr = None
|
||||
while start < end:
|
||||
substr = "".join(chars[start:end])
|
||||
if is_contain_chinese(substr):
|
||||
if substr in self.vocab:
|
||||
cur_substr = substr
|
||||
break
|
||||
else:
|
||||
if start > 0:
|
||||
substr = "##" + substr
|
||||
if substr in self.vocab:
|
||||
cur_substr = substr
|
||||
break
|
||||
end -= 1
|
||||
if cur_substr is None:
|
||||
sub_tokens.append(self.unk_token)
|
||||
start += 1
|
||||
continue
|
||||
sub_tokens.append(cur_substr)
|
||||
start = end
|
||||
|
||||
return sub_tokens
|
||||
|
||||
|
||||
def load_vocab(vocab_file):
|
||||
"""Loads a vocabulary file into a dictionary."""
|
||||
vocab = collections.OrderedDict()
|
||||
index = 0
|
||||
with open(vocab_file, "r", encoding='utf-8') as reader:
|
||||
while True:
|
||||
token = convert_to_unicode(reader.readline())
|
||||
if not token:
|
||||
break
|
||||
token = token.strip()
|
||||
vocab[token] = index
|
||||
index += 1
|
||||
return vocab
|
||||
|
||||
|
||||
class EncDecTokenizer(object):
|
||||
|
||||
def __init__(self, vocab_file, max_len=None, max_sentinels=190):
|
||||
self.max_len = max_len if max_len is not None else int(1e12)
|
||||
self.encoder = load_vocab(vocab_file)
|
||||
self.decoder = {v: k for k, v in self.encoder.items()}
|
||||
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.encoder)
|
||||
|
||||
self.translator = str.maketrans(" \n", "\u2582\u2583")
|
||||
|
||||
self.sentinel_list = [self.encoder['<s_{}>'.format(i)] for i in range(max_sentinels)]
|
||||
|
||||
self.en_vocab = {}
|
||||
for k, v in self.encoder.items():
|
||||
if is_contain_chinese(k):
|
||||
self.en_vocab[v] = False
|
||||
else:
|
||||
self.en_vocab[v] = True
|
||||
self.en_vocab[10] = False
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return len(self.encoder)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.encoder)
|
||||
|
||||
@property
|
||||
def eod_id(self):
|
||||
return self.encoder[self.eod_token]
|
||||
|
||||
@property
|
||||
def pad_id(self):
|
||||
return self.encoder[self.pad_token]
|
||||
|
||||
@property
|
||||
def eod_token(self):
|
||||
return '<eod>'
|
||||
|
||||
@property
|
||||
def pad_token(self):
|
||||
return '<pad>'
|
||||
|
||||
def get_sentinel_num(self):
|
||||
return len(self.sentinel_list)
|
||||
|
||||
def get_sentinel_id(self, idx):
|
||||
return self.sentinel_list[idx]
|
||||
|
||||
def tokenize(self, text):
|
||||
""" Tokenize a string. """
|
||||
output_tokens = []
|
||||
for x in jieba.cut(text, cut_all=False):
|
||||
x = x.translate(self.translator)
|
||||
output_tokens.extend(self.wordpiece_tokenizer.tokenize(x))
|
||||
|
||||
# print(output_tokens)
|
||||
|
||||
return output_tokens
|
||||
|
||||
def encode(self, text):
|
||||
output_tokens = [self.encoder[x] for x in self.tokenize(text)]
|
||||
|
||||
# filter space
|
||||
new_output_tokens = [output_tokens[0]]
|
||||
for i, x in enumerate(output_tokens[1:-1]):
|
||||
if x == 10:
|
||||
if self.en_vocab[output_tokens[i]] and self.en_vocab[output_tokens[i + 2]]:
|
||||
continue
|
||||
new_output_tokens.append(x)
|
||||
new_output_tokens.append(output_tokens[-1])
|
||||
|
||||
return new_output_tokens
|
||||
|
||||
def decode(self, tokens):
|
||||
new_tokens = []
|
||||
for i, x in enumerate(tokens[:-1]):
|
||||
if self.en_vocab[x] and self.en_vocab[tokens[i + 1]]:
|
||||
new_tokens.append(x)
|
||||
new_tokens.append(10)
|
||||
else:
|
||||
new_tokens.append(x)
|
||||
new_tokens.append(tokens[-1])
|
||||
|
||||
# text = ''.join([self.decoder[x] for x in new_tokens])
|
||||
# text = text.replace('\u2582', ' ').replace('\u2583', '\n')
|
||||
# return text
|
||||
return [self.decoder[x] for x in tokens]
|
||||
|
||||
|
||||
class IdentitySplitter(object):
|
||||
|
||||
@staticmethod
|
||||
def tokenize(*text):
|
||||
return text
|
||||
|
||||
|
||||
class Encoder(object):
|
||||
|
||||
def __init__(self, vocab_path, length, sentence_splitter):
|
||||
self.vocab_path = vocab_path
|
||||
self.length = length
|
||||
self.sentence_splitter = sentence_splitter
|
||||
self.tokenizer = EncDecTokenizer(os.path.join(self.vocab_path))
|
||||
self.splitter = IdentitySplitter()
|
||||
|
||||
def initializer(self):
|
||||
# Use Encoder class as a container for global data
|
||||
pass
|
||||
|
||||
def encode(self, line):
|
||||
# end with <eod>
|
||||
if len(line) > 20000:
|
||||
return None, 0
|
||||
if len(line) < 10:
|
||||
return None, 0
|
||||
data = line.strip().strip('<n>')
|
||||
data = data.replace("<n>", "\n")
|
||||
doc_ids = self.tokenizer.encode(data)
|
||||
doc_ids.append(self.tokenizer.eod_id)
|
||||
return doc_ids, len(line)
|
||||
|
||||
|
||||
@DATASETS.register_module
|
||||
class YuanDataset(Dataset):
|
||||
"""
|
||||
Yuan is an open source Chinese dataset, which can be accessed on https://github.com/Shawn-Inspur/Yuan-1.0.
|
||||
|
||||
Args:
|
||||
path(str): Path to dataset's folder, raw data should be organized under the folder as 001.txt, 002.txt...
|
||||
eg:/path/yuan/dataset
|
||||
vocab_path(str): Path to the vocab file. eg:/path/yuan/vocab.txt
|
||||
seq_len(int): Sequence length of the transformer, defaults to 2048.
|
||||
"""
|
||||
|
||||
def __init__(self, path, vocab_path, seq_len=2048) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.input_path = path
|
||||
workers = 16
|
||||
sentence_splitter = None
|
||||
self.vocab_path = vocab_path
|
||||
self.pad_id = EncDecTokenizer(os.path.join(self.vocab_path)).pad_id
|
||||
self.length = seq_len
|
||||
|
||||
if self.input_path[-1] == '/':
|
||||
self.input_path = self.input_path[:-1]
|
||||
if os.path.exists(os.path.join(self.input_path, 'data_list.pt')):
|
||||
self.data_path = torch.load(os.path.join(self.input_path, 'data_list.pt'))
|
||||
return
|
||||
|
||||
fin_list = glob.glob(self.input_path + '/0[0-9][0-9].txt')
|
||||
self.data_path = []
|
||||
for fin_path in fin_list:
|
||||
if not os.path.exists(fin_path):
|
||||
continue
|
||||
if '.txt' not in fin_path:
|
||||
continue
|
||||
|
||||
all_data = []
|
||||
print("Processing ", fin_path)
|
||||
with open(fin_path, 'r', encoding='utf-8', errors='ignore') as fin:
|
||||
|
||||
encoder = Encoder(self.vocab_path, seq_len, sentence_splitter)
|
||||
pool = multiprocessing.Pool(workers, initializer=encoder.initializer)
|
||||
encoded_docs = pool.imap_unordered(encoder.encode, fin, 30)
|
||||
|
||||
for i, (no_noise_tokens, bytes_processed) in tqdm(enumerate(encoded_docs, start=1)):
|
||||
if no_noise_tokens is None:
|
||||
continue
|
||||
all_data.append(no_noise_tokens)
|
||||
|
||||
pool.close()
|
||||
|
||||
print('Saving ', fin_path)
|
||||
base_path = fin_path.replace('.txt', '')
|
||||
if not os.path.exists(base_path):
|
||||
os.mkdir(base_path)
|
||||
idx = 0
|
||||
for d in tqdm(all_data):
|
||||
idx += 1
|
||||
cur_path = os.path.join(base_path, str(idx) + '.txt')
|
||||
with open(cur_path, 'w+', encoding='utf-8') as f:
|
||||
for i in d:
|
||||
f.write(str(i) + ' ')
|
||||
f.write('\n')
|
||||
self.data_path.append(cur_path.replace(self.input_path + '/', ''))
|
||||
|
||||
torch.save(self.data_path, os.path.join(self.input_path, 'data_list.pt'))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data_path)
|
||||
|
||||
def __getitem__(self, index):
|
||||
path = self.data_path[index]
|
||||
root = os.path.join(self.input_path, path)
|
||||
with open(root, "r") as f:
|
||||
data = f.readlines()
|
||||
assert len(data) == 1
|
||||
data = data[0][:-2].split(' ')
|
||||
try:
|
||||
data = list(map(int, data))
|
||||
except:
|
||||
while '' in data:
|
||||
data.remove('')
|
||||
data = list(map(int, data))
|
||||
if len(data) > self.length:
|
||||
data = data[:self.length - 1] + [data[-1]]
|
||||
mask = [1] * self.length
|
||||
else:
|
||||
data += [self.pad_id] * (self.length - len(data))
|
||||
mask = [1] * len(data) + [0] * (self.length - len(data))
|
||||
|
||||
data = torch.tensor(data)
|
||||
mask = torch.tensor(mask)
|
||||
return {'input_ids': data, 'attention_mask': mask}, data
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
dataset = YuanDataset('/data/gpt-yuan/ASC22/dataset', vocab_path='/data/gpt-yuan/ASC22/vocab.txt', seq_len=2048)
|
||||
test = dataset.__getitem__(0)
|
||||
print(test)
|
|
@ -1,31 +0,0 @@
|
|||
from titans.loss.lm_loss import GPTLMLoss
|
||||
from titans.model.gpt import gpt2_small
|
||||
from torch.optim import Adam
|
||||
|
||||
from colossalai.amp import AMP_TYPE
|
||||
|
||||
BATCH_SIZE = 1
|
||||
SEQ_LEN = 1024
|
||||
NUM_EPOCHS = 60
|
||||
|
||||
TENSOR_PARALLEL = 2
|
||||
|
||||
optimizer = dict(
|
||||
type=Adam,
|
||||
lr=0.00015,
|
||||
weight_decay=1e-2,
|
||||
)
|
||||
|
||||
fp16 = dict(mode=AMP_TYPE.NAIVE)
|
||||
|
||||
loss = dict(type=GPTLMLoss,)
|
||||
|
||||
model = dict(
|
||||
type=gpt2_small,
|
||||
checkpoint=True,
|
||||
)
|
||||
|
||||
parallel = dict(
|
||||
pipeline=1,
|
||||
tensor=dict(size=TENSOR_PARALLEL, mode='1d'),
|
||||
)
|
|
@ -1,30 +0,0 @@
|
|||
from titans.loss.lm_loss import GPTLMLoss
|
||||
from titans.model.gpt import gpt2_small
|
||||
from torch.optim import Adam
|
||||
|
||||
from colossalai.amp import AMP_TYPE
|
||||
|
||||
BATCH_SIZE = 4
|
||||
SEQ_LEN = 1024
|
||||
NUM_EPOCHS = 60
|
||||
TENSOR_PARALLEL = 4
|
||||
|
||||
optimizer = dict(
|
||||
type=Adam,
|
||||
lr=0.00015,
|
||||
weight_decay=1e-2,
|
||||
)
|
||||
|
||||
fp16 = dict(mode=AMP_TYPE.NAIVE)
|
||||
|
||||
loss = dict(type=GPTLMLoss,)
|
||||
|
||||
model = dict(
|
||||
type=gpt2_small,
|
||||
checkpoint=True,
|
||||
)
|
||||
|
||||
parallel = dict(
|
||||
pipeline=1,
|
||||
tensor=dict(size=TENSOR_PARALLEL, mode='2d'),
|
||||
)
|
|
@ -1,31 +0,0 @@
|
|||
from titans.loss.lm_loss import GPTLMLoss
|
||||
from titans.model.gpt import gpt2_small
|
||||
from torch.optim import Adam
|
||||
|
||||
from colossalai.amp import AMP_TYPE
|
||||
|
||||
BATCH_SIZE = 4
|
||||
SEQ_LEN = 1024
|
||||
NUM_EPOCHS = 60
|
||||
TENSOR_PARALLEL = 8
|
||||
DEPTH = 2
|
||||
|
||||
optimizer = dict(
|
||||
type=Adam,
|
||||
lr=0.00015,
|
||||
weight_decay=1e-2,
|
||||
)
|
||||
|
||||
fp16 = dict(mode=AMP_TYPE.NAIVE)
|
||||
|
||||
loss = dict(type=GPTLMLoss,)
|
||||
|
||||
model = dict(
|
||||
type=gpt2_small,
|
||||
checkpoint=True,
|
||||
)
|
||||
|
||||
parallel = dict(
|
||||
pipeline=1,
|
||||
tensor=dict(size=TENSOR_PARALLEL, depth=DEPTH, mode='2.5d'),
|
||||
)
|
|
@ -1,30 +0,0 @@
|
|||
from titans.loss.lm_loss import GPTLMLoss
|
||||
from titans.model.gpt import gpt2_small
|
||||
from torch.optim import Adam
|
||||
|
||||
from colossalai.amp import AMP_TYPE
|
||||
|
||||
BATCH_SIZE = 4
|
||||
SEQ_LEN = 1024
|
||||
NUM_EPOCHS = 60
|
||||
TENSOR_PARALLEL = 8
|
||||
|
||||
optimizer = dict(
|
||||
type=Adam,
|
||||
lr=0.00015,
|
||||
weight_decay=1e-2,
|
||||
)
|
||||
|
||||
fp16 = dict(mode=AMP_TYPE.NAIVE)
|
||||
|
||||
loss = dict(type=GPTLMLoss,)
|
||||
|
||||
model = dict(
|
||||
type=gpt2_small,
|
||||
checkpoint=True,
|
||||
)
|
||||
|
||||
parallel = dict(
|
||||
pipeline=1,
|
||||
tensor=dict(size=TENSOR_PARALLEL, mode='3d'),
|
||||
)
|
|
@ -1,33 +0,0 @@
|
|||
from titans.loss.lm_loss import GPTLMLoss
|
||||
from titans.model.gpt import gpt2_small
|
||||
#from model_zoo.gpt.gpt import gpt2_small_pipeline
|
||||
from torch.optim import Adam
|
||||
|
||||
from colossalai.amp import AMP_TYPE
|
||||
|
||||
BATCH_SIZE = 8
|
||||
SEQ_LEN = 1024
|
||||
NUM_EPOCHS = 60
|
||||
HIDDEN_SIZE = 768
|
||||
NUM_MICRO_BATCHES = 4
|
||||
PIPELINE = 2
|
||||
|
||||
optimizer = dict(
|
||||
type=Adam,
|
||||
lr=0.00015,
|
||||
weight_decay=1e-2,
|
||||
)
|
||||
|
||||
fp16 = dict(mode=AMP_TYPE.NAIVE)
|
||||
|
||||
loss = dict(type=GPTLMLoss,)
|
||||
|
||||
model = dict(
|
||||
type=gpt2_small,
|
||||
checkpoint=True,
|
||||
)
|
||||
|
||||
parallel = dict(
|
||||
pipeline=PIPELINE,
|
||||
tensor=dict(size=1, mode=None),
|
||||
)
|
|
@ -1,35 +0,0 @@
|
|||
import torch
|
||||
from titans.loss.lm_loss import GPTLMLoss
|
||||
from titans.loss.vocab_cross_entropy import vocab_parallel_cross_entropy
|
||||
from titans.model.gpt import gpt2_small
|
||||
from torch.optim import Adam
|
||||
|
||||
from colossalai.amp import AMP_TYPE
|
||||
|
||||
BATCH_SIZE = 8
|
||||
NUM_EPOCHS = 60
|
||||
SEQ_LEN = 1024
|
||||
|
||||
NUM_MICRO_BATCHES = 4
|
||||
HIDDEN_SIZE = 768
|
||||
PIPELINE = 2
|
||||
TENSOR_PARALLEL = 2
|
||||
MODE = '1d'
|
||||
|
||||
fp16 = dict(mode=AMP_TYPE.NAIVE)
|
||||
|
||||
parallel = dict(pipeline=PIPELINE, tensor=dict(mode=MODE, size=TENSOR_PARALLEL))
|
||||
|
||||
optimizer = dict(
|
||||
type=Adam,
|
||||
lr=0.00015,
|
||||
weight_decay=1e-2,
|
||||
)
|
||||
|
||||
model = dict(
|
||||
type=gpt2_small,
|
||||
checkpoint=True,
|
||||
dtype=torch.half,
|
||||
)
|
||||
|
||||
loss_fn = dict(type=vocab_parallel_cross_entropy)
|
|
@ -1,26 +0,0 @@
|
|||
from titans.model.gpt import gpt2_small
|
||||
from torch.optim import Adam
|
||||
|
||||
from colossalai.amp import AMP_TYPE
|
||||
|
||||
BATCH_SIZE = 1
|
||||
NUM_EPOCHS = 60
|
||||
SEQ_LEN = 1024
|
||||
|
||||
optimizer = dict(
|
||||
type=Adam,
|
||||
lr=0.00015,
|
||||
weight_decay=1e-2,
|
||||
)
|
||||
|
||||
fp16 = dict(mode=AMP_TYPE.NAIVE)
|
||||
|
||||
model = dict(
|
||||
type=gpt2_small,
|
||||
checkpoint=True,
|
||||
)
|
||||
|
||||
parallel = dict(
|
||||
pipeline=1,
|
||||
tensor=dict(size=1, mode=None),
|
||||
)
|
|
@ -1,24 +0,0 @@
|
|||
from titans.model.gpt import gpt2_small
|
||||
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.zero.shard_utils import TensorShardStrategy
|
||||
|
||||
BATCH_SIZE = 2
|
||||
NUM_EPOCHS = 60
|
||||
SEQ_LEN = 1024
|
||||
|
||||
zero = dict(model_config=dict(tensor_placement_policy='auto',
|
||||
shard_strategy=TensorShardStrategy(),
|
||||
reuse_fp16_shard=True),
|
||||
optimizer_config=dict())
|
||||
|
||||
optimizer = dict(
|
||||
type=HybridAdam,
|
||||
lr=0.00015,
|
||||
weight_decay=1e-2,
|
||||
)
|
||||
|
||||
model = dict(
|
||||
type=gpt2_small,
|
||||
checkpoint=True,
|
||||
)
|
|
@ -1,26 +0,0 @@
|
|||
from model import GPT2_small_pipeline_hybrid
|
||||
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.zero.shard_utils import BucketTensorShardStrategy, TensorShardStrategy
|
||||
|
||||
BATCH_SIZE = 8
|
||||
NUM_EPOCHS = 60
|
||||
SEQ_LEN = 1024
|
||||
NUM_MICRO_BATCHES = 4
|
||||
HIDDEN_SIZE = 768
|
||||
TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LEN, HIDDEN_SIZE)
|
||||
zero = dict(model_config=dict(tensor_placement_policy='cpu', shard_strategy=BucketTensorShardStrategy()),
|
||||
optimizer_config=dict())
|
||||
|
||||
optimizer = dict(
|
||||
type=HybridAdam,
|
||||
lr=0.00015,
|
||||
weight_decay=1e-2,
|
||||
)
|
||||
|
||||
model = dict(type=GPT2_small_pipeline_hybrid, checkpoint=True, num_chunks=1)
|
||||
|
||||
parallel = dict(
|
||||
pipeline=2,
|
||||
tensor=dict(size=2, mode='1d'),
|
||||
)
|
|
@ -1,30 +0,0 @@
|
|||
import torch
|
||||
from titans.loss.vocab_cross_entropy import vocab_parallel_cross_entropy
|
||||
from titans.model.gpt import gpt3
|
||||
from torch.optim import Adam
|
||||
|
||||
from colossalai.amp import AMP_TYPE
|
||||
|
||||
BATCH_SIZE = 192
|
||||
NUM_EPOCHS = 60
|
||||
SEQ_LEN = 2048
|
||||
NUM_MICRO_BATCHES = 192
|
||||
TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LEN, 12288)
|
||||
|
||||
fp16 = dict(mode=AMP_TYPE.NAIVE)
|
||||
|
||||
parallel = dict(pipeline=32, tensor=dict(mode='1d', size=4))
|
||||
|
||||
optimizer = dict(
|
||||
type=Adam,
|
||||
lr=0.00015,
|
||||
weight_decay=1e-2,
|
||||
)
|
||||
|
||||
model = dict(
|
||||
type=gpt3,
|
||||
checkpoint=True,
|
||||
dtype=torch.half,
|
||||
)
|
||||
|
||||
loss_fn = dict(type=vocab_parallel_cross_entropy)
|
|
@ -1,30 +0,0 @@
|
|||
import torch
|
||||
from titans.loss.vocab_cross_entropy import vocab_parallel_cross_entropy
|
||||
from titans.model.gpt import gpt3
|
||||
from torch.optim import Adam
|
||||
|
||||
from colossalai.amp import AMP_TYPE
|
||||
|
||||
BATCH_SIZE = 192
|
||||
NUM_EPOCHS = 60
|
||||
SEQ_LEN = 2048
|
||||
NUM_MICRO_BATCHES = 192
|
||||
TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LEN, 12288)
|
||||
|
||||
fp16 = dict(mode=AMP_TYPE.NAIVE)
|
||||
|
||||
parallel = dict(pipeline=24, tensor=dict(mode='1d', size=4))
|
||||
|
||||
optimizer = dict(
|
||||
type=Adam,
|
||||
lr=0.00015,
|
||||
weight_decay=1e-2,
|
||||
)
|
||||
|
||||
model = dict(
|
||||
type=gpt3,
|
||||
checkpoint=True,
|
||||
dtype=torch.half,
|
||||
)
|
||||
|
||||
loss_fn = dict(type=vocab_parallel_cross_entropy)
|
|
@ -1,27 +0,0 @@
|
|||
import torch
|
||||
from titans.model.gpt import gpt3
|
||||
from torch.optim import Adam
|
||||
|
||||
from colossalai.amp import AMP_TYPE
|
||||
|
||||
BATCH_SIZE = 2 * 48
|
||||
NUM_EPOCHS = 60
|
||||
SEQ_LEN = 2048
|
||||
NUM_MICRO_BATCHES = 48
|
||||
TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES // 2, SEQ_LEN, 12288 // 2)
|
||||
|
||||
fp16 = dict(mode=AMP_TYPE.NAIVE)
|
||||
|
||||
parallel = dict(pipeline=24, tensor=dict(mode='2d', size=4))
|
||||
|
||||
optimizer = dict(
|
||||
type=Adam,
|
||||
lr=0.00015,
|
||||
weight_decay=1e-2,
|
||||
)
|
||||
|
||||
model = dict(
|
||||
type=gpt3,
|
||||
checkpoint=True,
|
||||
dtype=torch.half,
|
||||
)
|
|
@ -1,27 +0,0 @@
|
|||
import torch
|
||||
from titans.model.gpt import gpt3
|
||||
from torch.optim import Adam
|
||||
|
||||
from colossalai.amp import AMP_TYPE
|
||||
|
||||
BATCH_SIZE = 2 * 48
|
||||
NUM_EPOCHS = 60
|
||||
SEQ_LEN = 2048
|
||||
NUM_MICRO_BATCHES = 48
|
||||
TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES // 2, SEQ_LEN, 12288 // 2)
|
||||
|
||||
fp16 = dict(mode=AMP_TYPE.NAIVE)
|
||||
|
||||
parallel = dict(pipeline=24, tensor=dict(mode='2.5d', depth=1, size=4))
|
||||
|
||||
optimizer = dict(
|
||||
type=Adam,
|
||||
lr=0.00015,
|
||||
weight_decay=1e-2,
|
||||
)
|
||||
|
||||
model = dict(
|
||||
type=gpt3,
|
||||
checkpoint=True,
|
||||
dtype=torch.half,
|
||||
)
|
|
@ -0,0 +1,3 @@
|
|||
colossalai >= 0.1.10
|
||||
torch >= 1.8.1
|
||||
transformers >= 4.23.1
|
|
@ -1,7 +1 @@
|
|||
export DATA=/data/scratch/gpt_data/small-gpt-dataset.json
|
||||
|
||||
export NODE_RANK=${NODE_RANK:-0}
|
||||
export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
|
||||
export MASTER_PORT=${MASTER_PORT:-"12345"}
|
||||
|
||||
env OMP_NUM_THREADS=16 torchrun --standalone --nproc_per_node=2 train_gpt.py --config=gpt2_configs/gpt2_zero3.py --from_torch 2>&1 | tee logs/log
|
||||
env OMP_NUM_THREADS=16 torchrun --standalone --nproc_per_node=2 train_gpt_demo.py 2>&1 | tee run.log
|
||||
|
|
File diff suppressed because it is too large
|
@ -1,307 +0,0 @@
|
|||
# coding=utf-8
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import glob
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
|
||||
import tldextract
|
||||
|
||||
# List of the domains to blacklist.
|
||||
domain_blacklist = set([
|
||||
'500px',
|
||||
'aapks',
|
||||
'akamaihd',
|
||||
'amazon',
|
||||
'apple',
|
||||
'artifactfire',
|
||||
'artstation',
|
||||
'awwni',
|
||||
'bandcamp',
|
||||
'battleforthenet',
|
||||
'coinscalendar',
|
||||
'dailymotion',
|
||||
'deviantart',
|
||||
'discord',
|
||||
'discordapp',
|
||||
'dlapkandroid',
|
||||
'dropbox',
|
||||
'e621',
|
||||
'ebay',
|
||||
'edealinfo',
|
||||
'erome',
|
||||
'eroshare',
|
||||
'explosm',
|
||||
'facebook',
|
||||
'fbcdn',
|
||||
'flickr',
|
||||
'furaffinity',
|
||||
'futhead',
|
||||
'gatopardo',
|
||||
'gfycat',
|
||||
'gifsound',
|
||||
'gifsoup',
|
||||
'giphy',
|
||||
'github',
|
||||
'google',
|
||||
'gunprime',
|
||||
'gyazo',
|
||||
'hotdealstar',
|
||||
'imagefap',
|
||||
'imageshack',
|
||||
'imgflip',
|
||||
'imgur',
|
||||
'instagram',
|
||||
'karmadecay',
|
||||
'kryptocal',
|
||||
'kym-cdn',
|
||||
'liveleak',
|
||||
'livememe',
|
||||
'lmgtfy',
|
||||
'magaimg',
|
||||
'memegenerator',
|
||||
'minorplanetcenter',
|
||||
'minus',
|
||||
'mobafire',
|
||||
'morejpeg',
|
||||
'nocookie',
|
||||
'pcpartpicker',
|
||||
'photobucket',
|
||||
'pinimg',
|
||||
'pinterest',
|
||||
'pixiv',
|
||||
'pornhub',
|
||||
'prntscr',
|
||||
'puu',
|
||||
'qkme',
|
||||
'quickmeme',
|
||||
'radd',
|
||||
'redd',
|
||||
'reddit',
|
||||
'reddit-stream',
|
||||
'redditlog',
|
||||
'redditmedia',
|
||||
'reddituploads',
|
||||
'redtube',
|
||||
'reupp',
|
||||
'reverb',
|
||||
'roanoke',
|
||||
'rollingstone',
|
||||
'sli',
|
||||
'soundcloud',
|
||||
'soundgasm',
|
||||
'spankbang',
|
||||
'spotify',
|
||||
'strawpoll',
|
||||
'streamable',
|
||||
'timeanddate',
|
||||
'tinypic',
|
||||
'touhouradio',
|
||||
'tumblr',
|
||||
'twimg',
|
||||
'twitch',
|
||||
'twitter',
|
||||
'vid',
|
||||
'vimeo',
|
||||
'vine',
|
||||
'vkaao',
|
||||
'vocaroo',
|
||||
'voyagefusion',
|
||||
'walmart',
|
||||
'wciu',
|
||||
'wikimedia',
|
||||
'wikipedia',
|
||||
'xhamster',
|
||||
'xkcd',
|
||||
'xvideos',
|
||||
'youtu',
|
||||
'youtube',
|
||||
'youtubedoubler',
|
||||
'ytimg',
|
||||
'zillexplorer',
|
||||
])
|
||||
|
||||
|
||||
def domain_is_in_blacklist(url):
|
||||
domain = tldextract.extract(url).domain
|
||||
return domain in domain_blacklist
|
||||
|
||||
|
||||
# List of extentions to blacklist.
|
||||
extentions_blacklist = (
|
||||
'.3gp',
|
||||
'.7z',
|
||||
'.ai',
|
||||
'.aif',
|
||||
'.apk',
|
||||
'.app',
|
||||
'.avi',
|
||||
'.bin',
|
||||
'.bmp',
|
||||
'.bz2',
|
||||
'.css',
|
||||
'.csv',
|
||||
'.dat',
|
||||
'.deb',
|
||||
'.dmg',
|
||||
'.doc',
|
||||
'.docx',
|
||||
'.exe',
|
||||
'.gif',
|
||||
'.gifv',
|
||||
'.gz',
|
||||
'.iso',
|
||||
'.jar',
|
||||
'.jpeg',
|
||||
'.jpg',
|
||||
'.js',
|
||||
'.log',
|
||||
'.mid',
|
||||
'.midi',
|
||||
'.mkv',
|
||||
'.mov',
|
||||
'.mp3',
|
||||
'.mp4',
|
||||
'.mpeg',
|
||||
'.mpg',
|
||||
'.ogg',
|
||||
'.ogv',
|
||||
'.otf',
|
||||
'.pdf',
|
||||
'.pkg',
|
||||
'.png',
|
||||
'.pps',
|
||||
'.ppt',
|
||||
'.pptx',
|
||||
'.psd',
|
||||
'.py',
|
||||
'.qt',
|
||||
'.ram',
|
||||
'.rar',
|
||||
'.sql',
|
||||
'.svg',
|
||||
'.swf',
|
||||
'.tar.gz',
|
||||
'.tar',
|
||||
'.tgz',
|
||||
'.tiff',
|
||||
'.ttf',
|
||||
'.txt',
|
||||
'.wav',
|
||||
'.webm',
|
||||
'.wma',
|
||||
'.wmv',
|
||||
'.xls',
|
||||
'.xlsx',
|
||||
'.xml',
|
||||
'.xz',
|
||||
'.zip',
|
||||
)
|
||||
|
||||
|
||||
def extention_is_in_blacklist(url):
|
||||
if url.split('?')[0].lower().endswith(extentions_blacklist):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# Malformed urls.
|
||||
# This function is adapted from:
|
||||
# https://stackoverflow.com/questions/7160737/python-how-to-validate-a-url-in-python-malformed-or-not
|
||||
url_regex = re.compile(
|
||||
r'^(?:http)s?://' # http:// or https://
|
||||
r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' #domain...
|
||||
r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip
|
||||
r'(?::\d+)?' # optional port
|
||||
r'(?:/?|[/?]\S+)$',
|
||||
re.IGNORECASE)
|
||||
|
||||
|
||||
def url_is_malformed(url):
|
||||
return re.match(url_regex, url) is None
|
||||
|
||||
|
||||
def print_progress(prefix, start_time, urls_counter, domain_blacklist_counter, extention_blacklist_counter,
|
||||
short_url_counter, malformed_url_counter, duplicate_url_counter):
|
||||
string = prefix + ' | '
|
||||
string += 'time elapsed (s): {:.2f} | '.format(time.time() - start_time)
|
||||
string += 'number of urls: {} | '.format(urls_counter)
|
||||
string += 'domain blacklisted: {} | '.format(domain_blacklist_counter)
|
||||
string += 'extention blacklisted: {} | '.format(extention_blacklist_counter)
|
||||
string += 'short urls (<=8): {} | '.format(short_url_counter)
|
||||
string += 'malformed urls: {} | '.format(malformed_url_counter)
|
||||
string += 'duplicate urls: {}'.format(duplicate_url_counter)
|
||||
print(string, flush=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
print('remove blacklisted urls ..')
|
||||
|
||||
# Path to the url files.
|
||||
path = sys.argv[1]
|
||||
# Output url file.
|
||||
output = sys.argv[2]
|
||||
|
||||
# Get the list of url files.
|
||||
files = glob.glob(path + '/*.txt')
|
||||
print('> found {} files'.format(len(files)))
|
||||
|
||||
urls = set()
|
||||
urls_counter = 0
|
||||
domain_blacklist_counter = 0
|
||||
extention_blacklist_counter = 0
|
||||
short_url_counter = 0
|
||||
malformed_url_counter = 0
|
||||
duplicate_url_counter = 0
|
||||
start_time = time.time()
|
||||
for filename in files:
|
||||
with open(filename, 'r') as f:
|
||||
for line in f:
|
||||
url = line.strip()
|
||||
urls_counter += 1
|
||||
if domain_is_in_blacklist(url):
|
||||
print('[DOMAIN BLACKLIST]: {}'.format(url), flush=True)
|
||||
domain_blacklist_counter += 1
|
||||
elif extention_is_in_blacklist(url):
|
||||
print('[EXTENTION BLACKLIST]: {}'.format(url), flush=True)
|
||||
extention_blacklist_counter += 1
|
||||
elif len(url) <= 8:
|
||||
print('[SHORT URL]: {}'.format(url), flush=True)
|
||||
short_url_counter += 1
|
||||
elif url_is_malformed(url):
|
||||
print('[MALFORMED URL]: {}'.format(url), flush=True)
|
||||
malformed_url_counter += 1
|
||||
elif url in urls:
|
||||
print('[DUPLICATE URL]: {}'.format(url), flush=True)
|
||||
duplicate_url_counter += 1
|
||||
else:
|
||||
urls.add(url)
|
||||
if urls_counter % 100000 == 0:
|
||||
print_progress('PROGRESS', start_time, urls_counter, domain_blacklist_counter,
|
||||
extention_blacklist_counter, short_url_counter, malformed_url_counter,
|
||||
duplicate_url_counter)
|
||||
|
||||
print_progress('FINAL', start_time, urls_counter, domain_blacklist_counter, extention_blacklist_counter,
|
||||
short_url_counter, malformed_url_counter, duplicate_url_counter)
|
||||
|
||||
# Write the final set of urls.
|
||||
print('> writing cleaned up url list to {}'.format(output))
|
||||
with open(output, 'w') as f:
|
||||
for url in urls:
|
||||
f.write(url + '\n')
|
||||
|
||||
print('done :-)')
|
|
@ -1,107 +0,0 @@
|
|||
# coding=utf-8
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
import ftfy
|
||||
import numpy as np
|
||||
from langdetect import detect
|
||||
from tokenizer import Tokenizer
|
||||
|
||||
MIN_DOCUMENT_LENGTH = 128
|
||||
|
||||
|
||||
def print_progress(prefix, start_time, num_docs, num_fixed_text, num_non_english_docs, chars_non_english_docs,
|
||||
num_small_docs, chars_small_docs):
|
||||
|
||||
string = prefix + ' | '
|
||||
string += 'elapsed time: {:.2f} | '.format(time.time() - start_time)
|
||||
string += 'documents: {} | '.format(num_docs)
|
||||
string += 'fixed text: {} | '.format(num_fixed_text)
|
||||
string += 'non-english: {} | '.format(num_non_english_docs)
|
||||
string += 'non-english chars: {} | '.format(chars_non_english_docs)
|
||||
string += 'small docs: {} | '.format(num_small_docs)
|
||||
string += 'small docs chars: {}'.format(chars_small_docs)
|
||||
print(string, flush=True)
|
||||
|
||||
|
||||
def filter_corpus(filename, out_filename, print_interval=10000):
|
||||
|
||||
print(' > filtering {}'.format(filename))
|
||||
|
||||
tokenizer = Tokenizer(cache_dir='./cache')
|
||||
|
||||
num_docs = 0
|
||||
num_written_docs = 0
|
||||
num_small_docs = 0
|
||||
num_fixed_text = 0
|
||||
num_non_english_docs = 0
|
||||
chars_non_english_docs = 0
|
||||
chars_small_docs = 0
|
||||
start_time = time.time()
|
||||
with open(out_filename, 'wb') as f:
|
||||
with open(filename, 'r') as fin:
|
||||
for line in fin:
|
||||
try:
|
||||
num_docs += 1
|
||||
myjson = json.loads(line)
|
||||
# Fix text
|
||||
text = ftfy.fix_text(myjson['text'])
|
||||
if text != myjson['text']:
|
||||
num_fixed_text += 1
|
||||
myjson['text'] = text
|
||||
# Detect language.
|
||||
if detect(text) != 'en':
|
||||
print('[non-english text]', myjson)
|
||||
num_non_english_docs += 1
|
||||
chars_non_english_docs += len(text)
|
||||
continue
|
||||
# On average each token is 5 characters so 8 is an
|
||||
# upper bound.
|
||||
if len(text) < (8 * MIN_DOCUMENT_LENGTH):
|
||||
tokens = tokenizer.tokenize_document(text)
|
||||
if len(tokens) < MIN_DOCUMENT_LENGTH:
|
||||
print('[small document, skipping]:', myjson)
|
||||
num_small_docs += 1
|
||||
chars_small_docs += len(text)
|
||||
continue
|
||||
myjson = json.dumps(myjson, ensure_ascii=False)
|
||||
f.write(myjson.encode('utf-8'))
|
||||
f.write('\n'.encode('utf-8'))
|
||||
num_written_docs += 1
|
||||
if num_docs % print_interval == 0:
|
||||
print_progress('[PROGRESS]', start_time, num_docs, num_fixed_text, num_non_english_docs,
|
||||
chars_non_english_docs, num_small_docs, chars_small_docs)
|
||||
except Exception as e:
|
||||
print(' skipping ', line, e)
|
||||
|
||||
print_progress('[FINAL]', start_time, num_docs, num_fixed_text, num_non_english_docs, chars_non_english_docs,
|
||||
num_small_docs, chars_small_docs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
print('building gpt2 dataset ...')
|
||||
|
||||
input_filename = sys.argv[1]
|
||||
output_filename = sys.argv[2]
|
||||
|
||||
print('will be reading {}'.format(input_filename))
|
||||
print('and will write the results to {}'.format(output_filename))
|
||||
|
||||
filter_corpus(input_filename, output_filename)
|
|
@ -1,191 +0,0 @@
|
|||
# coding=utf-8
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Filter and clean documents:
|
||||
Capable to clean docs with less than 512 characters, less than
|
||||
256 characters and contains javascript, fix text and dataset specific
|
||||
cleaning like stories and realnews datasets.
|
||||
Program arguments have the details.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import json
|
||||
import multiprocessing
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import ftfy
|
||||
from langdetect import detect
|
||||
|
||||
|
||||
def process_doc(json_line, args):
|
||||
|
||||
# Read the line.
|
||||
document = json.loads(json_line)
|
||||
text = document['text']
|
||||
|
||||
output = {'remove_512': False, 'remove_256_javascript': False, \
|
||||
'remove_512_non_english': False, 'ftfy_fix_text': False, \
|
||||
'general_cleaning': False}
|
||||
|
||||
try:
|
||||
# Remove all docs with less than 512 characters
|
||||
if "remove_512" in args.tasks:
|
||||
if len(text) < 512:
|
||||
output['remove_512'] = True
|
||||
return output, text, document, True
|
||||
|
||||
# Remove docs if less than 256 character length and contains Javascript
|
||||
if "remove_256_javascript" in args.tasks:
|
||||
if len(text) < 256 and 'javascript' in text.lower():
|
||||
output['remove_256_javascript'] = True
|
||||
return output, text, document, True
|
||||
|
||||
# Remove docs < 512 and nonenglish
|
||||
if "remove_512_non_english" in args.tasks:
|
||||
if len(text) < 512 and detect(text) != 'en':
|
||||
output['remove_512_non_english'] = True
|
||||
return output, text, document, True
|
||||
|
||||
# Fix the text using ftfy, don't remove the text, hence return False
|
||||
if "ftfy_fix_text" in args.tasks:
|
||||
fixed_text = ftfy.fix_text(text)
|
||||
output['ftfy_fix_text'] = True
|
||||
return output, fixed_text, document, False
|
||||
|
||||
# Cleaning extra spaces and newlines
|
||||
if "general_cleaning" in args.tasks:
|
||||
cleaned_text = re.sub(r" +|\b\n+ |\b\n+", " ", text)
|
||||
#cleaned_text = re.sub(r"\n\n+", "\n\n", text) # used this for Gutenberg dataset
|
||||
#cleaned_text = re.sub(r"\n", "\n\n", text) # Used this for realnews
|
||||
|
||||
# stories datasets
|
||||
#cleaned_text = re.sub(r" \'", "'", text)
|
||||
#cleaned_text = re.sub(r" \!", "!", cleaned_text)
|
||||
#cleaned_text = re.sub(r" \.", ".", cleaned_text)
|
||||
#cleaned_text = re.sub(r" \?", "?", cleaned_text)
|
||||
#cleaned_text = re.sub(r" - ", "-", cleaned_text)
|
||||
##cleaned_text = re.sub(r"\" ", "\"", cleaned_text)
|
||||
#cleaned_text = re.sub(r" @ ", "@", cleaned_text)
|
||||
|
||||
output['general_cleaning'] = True
|
||||
return output, cleaned_text, document, False
|
||||
|
||||
except Exception as e:
|
||||
print('Error: *************************\n{}\ntext: {}'.format(e, \
|
||||
text), flush=True)
|
||||
return output, text, document, True
|
||||
|
||||
# don't remove
|
||||
return output, text, document, False
|
||||
|
||||
|
||||
def process_set(args, input_file, output_f_cleaned, output_f_filtered):
|
||||
|
||||
print(' > working on {} ...'.format(input_file), flush=True)
|
||||
|
||||
num_docs = num_remove_512 = num_remove_java = num_remove_512_non_english \
|
||||
= num_ftfy_fix_text = num_general_cleaning = 0
|
||||
|
||||
# Output file and counters.
|
||||
output_cleaned = open(output_f_cleaned, 'wb')
|
||||
output_filtered = open(output_f_filtered, 'wb')
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Setup multi-processing.
|
||||
num_workers = 40
|
||||
fin = open(input_file, 'r', encoding='utf-8')
|
||||
pool = multiprocessing.Pool(num_workers)
|
||||
process_doc_partial = partial(process_doc, args=args)
|
||||
processed_docs = pool.imap(process_doc_partial, fin, 500)
|
||||
|
||||
# Process documents.
|
||||
for output, text, document, to_filter in processed_docs:
|
||||
num_docs += 1
|
||||
|
||||
num_remove_512 += 1 if output['remove_512'] else 0
|
||||
num_remove_java += 1 if output['remove_256_javascript'] else 0
|
||||
num_remove_512_non_english += 1 if output['remove_512_non_english'] \
|
||||
else 0
|
||||
num_ftfy_fix_text += 1 if output['ftfy_fix_text'] else 0
|
||||
num_general_cleaning += 1 if output['general_cleaning'] else 0
|
||||
|
||||
document['text'] = text
|
||||
myjson = json.dumps(document, ensure_ascii=False)
|
||||
|
||||
if to_filter:
|
||||
output_filtered.write(myjson.encode('utf-8'))
|
||||
output_filtered.write('\n'.encode('utf-8'))
|
||||
else:
|
||||
output_cleaned.write(myjson.encode('utf-8'))
|
||||
output_cleaned.write('\n'.encode('utf-8'))
|
||||
|
||||
if num_docs % args.log_interval == 0:
|
||||
print(' processed {:9d} documents in {:.2f} seconds ...'.format(num_docs,
|
||||
time.time() - start_time),
|
||||
flush=True)
|
||||
|
||||
# Close the file.
|
||||
output_cleaned.close()
|
||||
output_filtered.close()
|
||||
fin.close()
|
||||
|
||||
# Print stats.
|
||||
print(' >> total docs: {} remove_512 {} remove_256_javascript {} '\
|
||||
'remove_512_non_english {} ftfy_fix_text {} general_cleaning {}'.\
|
||||
format(num_docs, num_remove_512, num_remove_java,\
|
||||
num_remove_512_non_english, num_ftfy_fix_text, \
|
||||
num_general_cleaning), flush=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
print('parsing the arguments ...')
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--input-files', nargs = '*', required=True, default=\
|
||||
None, help = 'Input json files that needs to be'\
|
||||
' cleaned')
|
||||
parser.add_argument('--tasks', nargs = '*', required=True, default=None,\
|
||||
help = 'Tasks to perform on the input files, ' \
|
||||
'such as remove_512, remove_256_javascript, ' \
|
||||
'remove_512_non_english, ftfy_fix_text, and ' \
|
||||
'general_cleaning. 256 or 512 means the number' \
|
||||
' of characters.')
|
||||
|
||||
parser.add_argument('--output-path', type=str, default=None, help='Directory where the output should go')
|
||||
parser.add_argument('--log-interval', type=int, default=100, help='Log interval')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print('cleanup dataset ...')
|
||||
|
||||
for input_file in args.input_files:
|
||||
input_filename, input_filename_ext = os.path.splitext(Path(input_file)\
|
||||
.name)
|
||||
|
||||
output_f_cleaned = os.path.join(args.output_path, input_filename + \
|
||||
"_cleaned" + input_filename_ext)
|
||||
output_f_filtered = os.path.join(args.output_path, input_filename + \
|
||||
"_filtered" + input_filename_ext)
|
||||
|
||||
process_set(args, input_file, output_f_cleaned, output_f_filtered)
|
||||
|
||||
print('done :-)', flush=True)
|
|
@ -1,314 +0,0 @@
|
|||
# coding=utf-8
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import itertools
|
||||
import json
|
||||
import multiprocessing
|
||||
import os
|
||||
import pickle
|
||||
import sys
|
||||
import time
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
from lsh import cache, minhash
|
||||
|
||||
|
||||
# This function is adapted from:
|
||||
# https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
|
||||
def shingles(text, char_ngram=5):
|
||||
return set(text[head:head + char_ngram] for head in range(0, len(text) - char_ngram))
|
||||
|
||||
|
||||
# This function is adapted from:
|
||||
# https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
|
||||
def jaccard(set_a, set_b, args):
|
||||
if len(set_a) < 1 or len(set_b) < 1:
|
||||
return 0.0
|
||||
|
||||
intersection = set_a & set_b
|
||||
union = set_a | set_b
|
||||
|
||||
if args.jaccard == 'min':
|
||||
return len(intersection) / min(len(set_a), len(set_b))
|
||||
elif args.jaccard == 'max':
|
||||
return len(intersection) / max(len(set_a), len(set_b))
|
||||
else:
|
||||
return len(intersection) / len(union)
|
||||
|
||||
|
||||
def compute_fingerprint(line, key):
|
||||
try:
|
||||
myjson = json.loads(line)
|
||||
url = myjson[key]
|
||||
text = myjson['text']
|
||||
fingerprint = hasher.fingerprint(text)
|
||||
except Exception as e:
|
||||
print('Error:', e)
|
||||
return None, None, None, False
|
||||
|
||||
return url, text, fingerprint, True
|
||||
|
||||
|
||||
def url_pairs_to_remove(args, bucket_urls, url_doc):
|
||||
remove_urls_list = []
|
||||
deduped_local, counter_local = 0, 0
|
||||
iteration = 0
|
||||
while len(bucket_urls) > 1:
|
||||
if args.heuristic_iter != -1 and \
|
||||
iteration == args.heuristic_iter:
|
||||
break
|
||||
|
||||
items = list(bucket_urls)
|
||||
remove_urls = []
|
||||
main_url = items[np.random.randint(0, len(items))]
|
||||
main_shingles = shingles(url_doc[main_url])
|
||||
|
||||
for i in range(0, len(items)):
|
||||
counter_local += 1
|
||||
other_url = items[i]
|
||||
if other_url == main_url:
|
||||
continue
|
||||
other_shingles = shingles(url_doc[other_url])
|
||||
try:
|
||||
jaccard_sim = jaccard(main_shingles, other_shingles, args)
|
||||
except Exception as e:
|
||||
print('Error:', e)
|
||||
jaccard_sim = 0.0
|
||||
if jaccard_sim > 0.5:
|
||||
remove_urls.append({other_url: jaccard_sim})
|
||||
deduped_local += 1
|
||||
bucket_urls.remove(other_url)
|
||||
|
||||
bucket_urls.remove(main_url)
|
||||
if len(remove_urls) > 0:
|
||||
remove_urls_list.append({main_url: remove_urls})
|
||||
iteration += 1
|
||||
return remove_urls_list, deduped_local, counter_local
|
||||
|
||||
|
||||
def write_remove_urls_list(remove_urls_list, f_out):
|
||||
if len(remove_urls_list) > 0:
|
||||
for each_url_remove in remove_urls_list:
|
||||
myjson = json.dumps(each_url_remove, ensure_ascii=False)
|
||||
f_out.write(myjson.encode('utf-8'))
|
||||
f_out.write('\n'.encode('utf-8'))
|
||||
|
||||
|
||||
def compute_jaccard(each_bin, num_bins, start_time_local):
|
||||
|
||||
remove_urls_list = []
|
||||
deduped_local, counter_local, bucket_local = 0, 0, 0
|
||||
|
||||
for bucket_id in each_bin:
|
||||
bucket_local += 1
|
||||
if os.getpid() % num_bins == 0 and bucket_local % 100000 == 0:
|
||||
print("Counter {}, progress {:.2f} time {:.2f}".\
|
||||
format(bucket_local, float(bucket_local)/float(len(each_bin)),\
|
||||
time.time() - start_time_local), flush=True)
|
||||
|
||||
if len(each_bin[bucket_id]) <= 1:
|
||||
continue
|
||||
|
||||
bucket_urls = each_bin[bucket_id].copy()
|
||||
remove_urls_list_sub, deduped_local_sub, counter_local_sub = \
|
||||
url_pairs_to_remove(args, bucket_urls, url_doc)
|
||||
|
||||
deduped_local += deduped_local_sub
|
||||
counter_local += counter_local_sub
|
||||
if len(remove_urls_list_sub) > 0:
|
||||
remove_urls_list.extend(remove_urls_list_sub)
|
||||
|
||||
return remove_urls_list, deduped_local, counter_local
|
||||
|
||||
|
||||
def find_pair_urls_parallel(args, lshcache, url_doc):
|
||||
start_time = time.time()
|
||||
f_out = open(args.output, 'wb')
|
||||
deduped, counter = 0, 0
|
||||
|
||||
# compute jaccards of buckets in bin in parallel (parallelism
|
||||
# limited to # of bins)
|
||||
num_bins = len(lshcache.bins)
|
||||
pool = multiprocessing.Pool(num_bins)
|
||||
compute_jaccard_partial = partial(compute_jaccard, num_bins=num_bins, \
|
||||
start_time_local=start_time)
|
||||
# don't need to pass args and url_doc as they are already shared
|
||||
compute_jaccard_iter = pool.imap(compute_jaccard_partial, lshcache.bins)
|
||||
|
||||
print("multiprocessing init took {:.2f}".format(time.time() - start_time),\
|
||||
flush=True)
|
||||
for remove_urls_list, deduped_local, counter_local in compute_jaccard_iter:
|
||||
deduped += deduped_local
|
||||
counter += counter_local
|
||||
write_remove_urls_list(remove_urls_list, f_out)
|
||||
print(' [write]> processed {} documents in {:.2f} '
|
||||
'seconds and deduped {} documents ...'.format(counter, time.time()\
|
||||
- start_time, deduped), flush=True)
|
||||
|
||||
pool.close()
|
||||
pool.join()
|
||||
f_out.close()
|
||||
|
||||
    print(' Time taken for Jaccard similarities: {:.2f} seconds'.format(\
|
||||
time.time() - start_time), flush=True)
|
||||
|
||||
|
||||
def find_pair_urls_sequential(args, lshcache, url_doc):
|
||||
start_time = time.time()
|
||||
f_out = open(args.output, 'wb')
|
||||
deduped, counter = 0, 0
|
||||
for b in lshcache.bins:
|
||||
for bucket_id in b:
|
||||
if len(b[bucket_id]) <= 1:
|
||||
continue
|
||||
|
||||
bucket_urls = b[bucket_id].copy()
|
||||
remove_urls_list_sub, deduped_local_sub, counter_local_sub = \
|
||||
url_pairs_to_remove(args, bucket_urls, url_doc)
|
||||
|
||||
deduped += deduped_local_sub
|
||||
counter += counter_local_sub
|
||||
write_remove_urls_list(remove_urls_list_sub, f_out)
|
||||
if counter % 10000 == 0:
|
||||
print(' [write]> processed {} documents in {:.2f} '
|
||||
'seconds and deduped {} documents ...'.format(counter,
|
||||
time.time() - start_time, deduped),
|
||||
flush=True)
|
||||
f_out.close()
|
||||
print(' [write]> processed {} documents in {:.2f} '
|
||||
'seconds and deduped {} documents ...'.format(counter,
|
||||
time.time() - start_time, deduped),
|
||||
flush=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
print('parsing the arguments ...')
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--seed', type=int, default=1234, help='Random seed used for python, numpy')
|
||||
parser.add_argument('--inputs', nargs = '*', default=None, help = \
|
||||
'Pairwise list of the input files and keys, '
|
||||
'e.g. --inputs cc.json cc_id news.json news_id')
|
||||
parser.add_argument('--load-fingerprints',
|
||||
nargs='*',
|
||||
default=None,
|
||||
help='Load fingerprints from a list of pickle files,'
|
||||
' e.g. cc.pkl news.pkl')
|
||||
parser.add_argument('--save-fingerprints', type=str, default=None, help='Save the fingerprints of the inputs.')
|
||||
parser.add_argument('--output',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Output file name that consists of all ids'
|
||||
' with matching similarities')
|
||||
parser.add_argument('--jaccard', type=str, default='union',
|
||||
choices=['union', 'min', 'max'], help='Jaccard'\
|
||||
' similarity computation')
|
||||
parser.add_argument('--heuristic-iter',
|
||||
type=int,
|
||||
default=1,
|
||||
help='Number of iterations to run the heuristics'
|
||||
': use -1 for exact')
|
||||
parser.add_argument('--num-bands', type=int, default=10, help='Number of bands to use in cache')
|
||||
parser.add_argument('--num-seeds',
|
||||
type=int,
|
||||
default=100,
|
||||
help='Number of seeds to use for minhash. Note that'
|
||||
' this value should be divisible by num-bands')
|
||||
parser.add_argument('--jaccard-parallel',
|
||||
action='store_true',
|
||||
help='Use this to process large number of documents.')
|
||||
args = parser.parse_args()
|
||||
|
||||
print('finding possible duplicate content ...')
|
||||
|
||||
    # set the random seed and draw an array of num_seeds integer seeds for minhashing
|
||||
np.random.seed(args.seed)
|
||||
seeds = np.random.randint(0, 1e6, size=args.num_seeds)
|
||||
|
||||
# initialize minhash and lsh cache
|
||||
hasher = minhash.MinHasher(seeds=seeds, char_ngram=5, hashbytes=4)
|
||||
lshcache = cache.Cache(num_bands=args.num_bands, hasher=hasher)
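    # Explanatory note (added): each document gets num-seeds minhash values, which the
    # cache splits into num-bands bands (with the defaults, 100 seeds / 10 bands = 10
    # hashes per band). Documents that agree on every hash within at least one band
    # fall into the same bucket and are later compared with the shingle-based Jaccard
    # similarity defined above.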
|
||||
|
||||
url_doc = {}
|
||||
|
||||
# load fingerprints from pickle file if needed
|
||||
if args.load_fingerprints is not None:
|
||||
for count_fp, fp_file_name in enumerate(args.load_fingerprints):
|
||||
print("Loading fingerprints from pickle file {}".format(fp_file_name), flush=True)
|
||||
fp = open(fp_file_name, "rb")
|
||||
if count_fp == 0:
|
||||
                # for the first pickle, load lshcache and url_doc directly
|
||||
lshcache = pickle.load(fp)
|
||||
url_doc = pickle.load(fp)
|
||||
else:
|
||||
# append these to lshcache and url_doc
|
||||
local_lshcache = pickle.load(fp)
|
||||
local_url_doc = pickle.load(fp)
|
||||
for url in local_lshcache.fingerprints.keys():
|
||||
url_doc[url] = local_url_doc[url]
|
||||
lshcache.add_fingerprint(local_lshcache.fingerprints[url], url)
|
||||
fp.close()
|
||||
|
||||
counter = 0
|
||||
start_time = time.time()
|
||||
|
||||
    # compute fingerprints of the inputs, if any, each given as an
|
||||
# input file and the key to use as id
|
||||
if args.inputs is not None:
|
||||
print("Computing fingerprints", flush=True)
|
||||
assert len(args.inputs) % 2 == 0
|
||||
for input_file, key in zip(args.inputs[::2], args.inputs[1::2]):
|
||||
print(' document processing {} with key {}'.format(input_file, key), flush=True)
|
||||
|
||||
# compute fingerprints in parallel
|
||||
num_workers = 40
|
||||
pool = multiprocessing.Pool(num_workers)
|
||||
fin = open(input_file, 'r', encoding='utf-8')
|
||||
compute_fingerprint_partial = partial(compute_fingerprint, key=key)
|
||||
compute_fingerprint_iter = pool.imap(compute_fingerprint_partial, fin, 512)
|
||||
# traverse all the texts and add fingerprints
|
||||
for url, text, fingerprint, flag in compute_fingerprint_iter:
|
||||
counter += 1
|
||||
if flag:
|
||||
url_doc[url] = text
|
||||
lshcache.add_fingerprint(fingerprint, url)
|
||||
if counter % 10000 == 0:
|
||||
print(' [read]> processed {} documents in {:.2f} '
|
||||
'seconds ...'.format(counter, time.time() - \
|
||||
start_time), flush=True)
|
||||
|
||||
fin.close()
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
# Save the fingerprints if needed
|
||||
if args.save_fingerprints is not None:
|
||||
print("Saving fingerprints to pickle file {}".format(args.save_fingerprints), flush=True)
|
||||
with open(args.save_fingerprints, 'wb') as f_save:
|
||||
pickle.dump(lshcache, f_save)
|
||||
pickle.dump(url_doc, f_save)
|
||||
|
||||
# compute jaccard index of the input texts and write to file if needed
|
||||
if args.output is not None:
|
||||
print("Compute jaccard similarity", flush=True)
|
||||
if args.jaccard_parallel:
|
||||
find_pair_urls_parallel(args, lshcache, url_doc)
|
||||
else:
|
||||
find_pair_urls_sequential(args, lshcache, url_doc)
|
||||
|
||||
print('done :-)')
|
|
@ -1,305 +0,0 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tokenization classes for OpenAI GPT."""
|
||||
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from io import open
|
||||
|
||||
import regex as re
|
||||
|
||||
try:
|
||||
from functools import lru_cache
|
||||
except ImportError:
|
||||
# Just a dummy decorator to get the checks to run on python2
|
||||
# because honestly I don't want to support a byte-level unicode BPE
|
||||
# tokenizer on python 2 right now.
|
||||
def lru_cache():
|
||||
return lambda func: func
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PRETRAINED_VOCAB_ARCHIVE_MAP = {
|
||||
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
|
||||
}
|
||||
PRETRAINED_MERGES_ARCHIVE_MAP = {
|
||||
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
|
||||
}
|
||||
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
|
||||
'gpt2': 1024,
|
||||
}
|
||||
VOCAB_NAME = 'vocab.json'
|
||||
MERGES_NAME = 'merges.txt'
|
||||
SPECIAL_TOKENS_NAME = 'special_tokens.txt'
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def bytes_to_unicode():
|
||||
"""
|
||||
    Returns a dict mapping utf-8 bytes to corresponding unicode strings.
|
||||
The reversible bpe codes work on unicode strings.
|
||||
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
||||
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
||||
This is a significant percentage of your normal, say, 32K bpe vocab.
|
||||
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
||||
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
||||
"""
|
||||
_chr = unichr if sys.version_info[0] == 2 else chr
|
||||
bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + \
|
||||
list(range(ord("®"), ord("ÿ") + 1))
|
||||
cs = bs[:]
|
||||
n = 0
|
||||
for b in range(2**8):
|
||||
if b not in bs:
|
||||
bs.append(b)
|
||||
cs.append(2**8 + n)
|
||||
n += 1
|
||||
cs = [_chr(n) for n in cs]
|
||||
return dict(zip(bs, cs))
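# Illustrative note (not in the original file): printable bytes map to themselves,
# while e.g. the space byte (0x20) maps to 'Ġ' and the newline byte (0x0A) maps to
# 'Ċ', which is why those characters show up in GPT-2 BPE vocabularies.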
|
||||
|
||||
|
||||
def get_pairs(word):
|
||||
"""Return set of symbol pairs in a word.
|
||||
|
||||
Word is represented as tuple of symbols (symbols being variable-length strings).
|
||||
"""
|
||||
pairs = set()
|
||||
prev_char = word[0]
|
||||
for char in word[1:]:
|
||||
pairs.add((prev_char, char))
|
||||
prev_char = char
|
||||
return pairs
|
||||
|
||||
|
||||
class GPT2Tokenizer(object):
|
||||
"""
|
||||
GPT-2 BPE tokenizer. Peculiarities:
|
||||
- Byte-level BPE
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
|
||||
"""
|
||||
        Instantiate a GPT2Tokenizer from a pre-trained vocabulary file.
|
||||
Download and cache the pre-trained model file if needed.
|
||||
"""
|
||||
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
|
||||
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||
merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||
special_tokens_file = None
|
||||
else:
|
||||
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
|
||||
merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
|
||||
special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
|
||||
if not os.path.exists(special_tokens_file):
|
||||
special_tokens_file = None
|
||||
else:
|
||||
logger.info("loading special tokens file {}".format(special_tokens_file))
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
from cached_path import cached_path
|
||||
resolved_vocab_file = cached_path(vocab_file)
|
||||
resolved_merges_file = cached_path(merges_file)
|
||||
except EnvironmentError:
|
||||
logger.error("Model name '{}' was not found in model name list ({}). "
|
||||
"We assumed '{}' was a path or url but couldn't find files {} and {} "
|
||||
"at this path or url.".format(pretrained_model_name_or_path,
|
||||
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
|
||||
pretrained_model_name_or_path, vocab_file, merges_file))
|
||||
return None
|
||||
if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
|
||||
logger.info("loading vocabulary file {}".format(vocab_file))
|
||||
logger.info("loading merges file {}".format(merges_file))
|
||||
else:
|
||||
logger.info("loading vocabulary file {} from cache at {}".format(vocab_file, resolved_vocab_file))
|
||||
logger.info("loading merges file {} from cache at {}".format(merges_file, resolved_merges_file))
|
||||
if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
|
||||
            # if we're using a pretrained model, ensure the tokenizer won't index sequences longer
|
||||
# than the number of positional embeddings
|
||||
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
|
||||
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
|
||||
# Instantiate tokenizer.
|
||||
if special_tokens_file and 'special_tokens' not in kwargs:
|
||||
special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
|
||||
else:
|
||||
special_tokens = kwargs.pop('special_tokens', [])
|
||||
tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
|
||||
return tokenizer
|
||||
|
||||
def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None):
|
||||
self.max_len = max_len if max_len is not None else int(1e12)
|
||||
self.encoder = json.load(open(vocab_file))
|
||||
self.decoder = {v: k for k, v in self.encoder.items()}
|
||||
self.errors = errors # how to handle errors in decoding
|
||||
self.byte_encoder = bytes_to_unicode()
|
||||
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
||||
bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
|
||||
bpe_merges = [tuple(merge.split()) for merge in bpe_data]
|
||||
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
|
||||
self.cache = {}
|
||||
|
||||
        # Should have added re.IGNORECASE so BPE merges can happen for
|
||||
# capitalized versions of contractions
|
||||
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
|
||||
|
||||
self.special_tokens = {}
|
||||
self.special_tokens_decoder = {}
|
||||
self.set_special_tokens(special_tokens)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.encoder) + len(self.special_tokens)
|
||||
|
||||
def set_special_tokens(self, special_tokens):
|
||||
""" Add a list of additional tokens to the encoder.
|
||||
The additional tokens are indexed starting from the last index of the
|
||||
current vocabulary in the order of the `special_tokens` list.
|
||||
"""
|
||||
if not special_tokens:
|
||||
self.special_tokens = {}
|
||||
self.special_tokens_decoder = {}
|
||||
return
|
||||
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
|
||||
self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}
|
||||
logger.info("Special tokens {}".format(self.special_tokens))
|
||||
|
||||
def bpe(self, token):
|
||||
if token in self.cache:
|
||||
return self.cache[token]
|
||||
word = tuple(token)
|
||||
pairs = get_pairs(word)
|
||||
|
||||
if not pairs:
|
||||
return token
|
||||
|
||||
while True:
|
||||
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
||||
if bigram not in self.bpe_ranks:
|
||||
break
|
||||
first, second = bigram
|
||||
new_word = []
|
||||
i = 0
|
||||
while i < len(word):
|
||||
try:
|
||||
j = word.index(first, i)
|
||||
new_word.extend(word[i:j])
|
||||
i = j
|
||||
except BaseException:
|
||||
new_word.extend(word[i:])
|
||||
break
|
||||
|
||||
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
|
||||
new_word.append(first + second)
|
||||
i += 2
|
||||
else:
|
||||
new_word.append(word[i])
|
||||
i += 1
|
||||
new_word = tuple(new_word)
|
||||
word = new_word
|
||||
if len(word) == 1:
|
||||
break
|
||||
else:
|
||||
pairs = get_pairs(word)
|
||||
word = ' '.join(word)
|
||||
self.cache[token] = word
|
||||
return word
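    # Example of the greedy merge loop above (hypothetical merge table, added for
    # illustration): for the token "lower", the initial symbols are ('l','o','w','e','r');
    # if ('e','r') is the highest-ranked pair in bpe_ranks the word becomes
    # ('l','o','w','er'), and merging continues until no remaining pair appears in bpe_ranks.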
|
||||
|
||||
def tokenize(self, text):
|
||||
""" Tokenize a string. """
|
||||
bpe_tokens = []
|
||||
for token in re.findall(self.pat, text):
|
||||
if sys.version_info[0] == 2:
|
||||
token = ''.join(self.byte_encoder[ord(b)] for b in token)
|
||||
else:
|
||||
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
||||
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
|
||||
return bpe_tokens
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
""" Converts a sequence of tokens into ids using the vocab. """
|
||||
ids = []
|
||||
if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
|
||||
if tokens in self.special_tokens:
|
||||
return self.special_tokens[tokens]
|
||||
else:
|
||||
return self.encoder.get(tokens, 0)
|
||||
for token in tokens:
|
||||
if token in self.special_tokens:
|
||||
ids.append(self.special_tokens[token])
|
||||
else:
|
||||
ids.append(self.encoder.get(token, 0))
|
||||
if len(ids) > self.max_len:
|
||||
logger.warning("Token indices sequence length is longer than the specified maximum "
|
||||
" sequence length for this OpenAI GPT model ({} > {}). Running this"
|
||||
" sequence through the model will result in indexing errors".format(len(ids), self.max_len))
|
||||
return ids
|
||||
|
||||
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
|
||||
"""Converts a sequence of ids in BPE tokens using the vocab."""
|
||||
tokens = []
|
||||
for i in ids:
|
||||
if i in self.special_tokens_decoder:
|
||||
if not skip_special_tokens:
|
||||
tokens.append(self.special_tokens_decoder[i])
|
||||
else:
|
||||
tokens.append(self.decoder[i])
|
||||
return tokens
|
||||
|
||||
def encode(self, text):
|
||||
return self.convert_tokens_to_ids(self.tokenize(text))
|
||||
|
||||
def decode(self, tokens):
|
||||
text = ''.join([self.decoder[token] for token in tokens])
|
||||
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
|
||||
return text
|
||||
|
||||
def save_vocabulary(self, vocab_path):
|
||||
"""Save the tokenizer vocabulary and merge files to a directory."""
|
||||
if not os.path.isdir(vocab_path):
|
||||
logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
|
||||
return
|
||||
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
|
||||
merge_file = os.path.join(vocab_path, MERGES_NAME)
|
||||
special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
|
||||
|
||||
with open(vocab_file, 'w', encoding='utf-8') as f:
|
||||
f.write(json.dumps(self.encoder, ensure_ascii=False))
|
||||
|
||||
index = 0
|
||||
with open(merge_file, "w", encoding="utf-8") as writer:
|
||||
writer.write(u'#version: 0.2\n')
|
||||
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
|
||||
if index != token_index:
|
||||
logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
|
||||
" Please check that the tokenizer is not corrupted!".format(merge_file))
|
||||
index = token_index
|
||||
writer.write(' '.join(bpe_tokens) + u'\n')
|
||||
index += 1
|
||||
|
||||
index = len(self.encoder)
|
||||
with open(special_tokens_file, 'w', encoding='utf-8') as writer:
|
||||
for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
|
||||
if index != token_index:
|
||||
logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
|
||||
" Please check that the tokenizer is not corrupted!".format(special_tokens_file))
|
||||
index = token_index
|
||||
writer.write(token + u'\n')
|
||||
index += 1
|
||||
|
||||
return vocab_file, merge_file, special_tokens_file
|
|
@ -1,85 +0,0 @@
|
|||
# coding=utf-8
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
print('grouping duplicate urls ...')
|
||||
|
||||
input = sys.argv[1]
|
||||
output = sys.argv[2]
|
||||
if len(sys.argv) > 3:
|
||||
jaccard_similarity_threshold = float(sys.argv[3])
|
||||
else:
|
||||
jaccard_similarity_threshold = 0.7
|
||||
|
||||
url_to_index = {}
|
||||
index_to_urls = []
|
||||
counter = 0
|
||||
start_time = time.time()
|
||||
with open(input, 'r') as f:
|
||||
for line in f:
|
||||
counter += 1
|
||||
myjson = json.loads(line)
|
||||
urls = []
|
||||
for main_url in myjson.keys():
|
||||
urls.append(main_url)
|
||||
for value in myjson[main_url]:
|
||||
for other_url, js in value.items():
|
||||
if js >= jaccard_similarity_threshold:
|
||||
urls.append(other_url)
|
||||
current_index = -1
|
||||
other_indices = set()
|
||||
for url in urls:
|
||||
if url in url_to_index:
|
||||
if current_index == -1:
|
||||
current_index = url_to_index[url]
|
||||
elif current_index != url_to_index[url]:
|
||||
other_indices.add(url_to_index[url])
|
||||
if current_index == -1:
|
||||
current_index = len(index_to_urls)
|
||||
index_to_urls.append(set())
|
||||
for url in urls:
|
||||
url_to_index[url] = current_index
|
||||
index_to_urls[current_index].add(url)
|
||||
for index in other_indices:
|
||||
for url in index_to_urls[index]:
|
||||
index_to_urls[current_index].add(url)
|
||||
url_to_index[url] = current_index
|
||||
index_to_urls[index] = None
|
||||
|
||||
if counter % 100000 == 0:
|
||||
print(' > processed {} lines in {} seconds ...'.format(counter, time.time() - start_time))
|
||||
|
||||
total_remove = 0
|
||||
total_remain = 0
|
||||
for urls in index_to_urls:
|
||||
if urls is not None:
|
||||
if len(urls) > 1:
|
||||
total_remove += (len(urls) - 1)
|
||||
total_remain += 1
|
||||
print('out of {} urls, only {} are unique and {} should be removed'.format(total_remove + total_remain,
|
||||
total_remain, total_remove))
|
||||
|
||||
with open(output, 'wb') as f:
|
||||
for i, urls in enumerate(index_to_urls):
|
||||
if urls is not None:
|
||||
if len(urls) > 1:
|
||||
myjson = json.dumps({str(i): list(urls)}, ensure_ascii=False)
|
||||
f.write(myjson.encode('utf-8'))
|
||||
f.write('\n'.encode('utf-8'))
|
|
@ -1,64 +0,0 @@
|
|||
# coding=utf-8
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
url_filename = sys.argv[1]
|
||||
data_filename = sys.argv[2]
|
||||
output_filename = sys.argv[3]
|
||||
|
||||
urls = set()
|
||||
with open(url_filename, 'r') as f:
|
||||
for line in f:
|
||||
myjson = json.loads(line)
|
||||
for key in myjson:
|
||||
this_urls = myjson[key]
|
||||
for i in range(1, len(this_urls)):
|
||||
urls.add(this_urls[i])
|
||||
print('will be removing {} urls'.format(len(urls)), flush=True)
|
||||
|
||||
written_docs = 0
|
||||
removed_docs = 0
|
||||
removed_chars = 0
|
||||
start_time = time.time()
|
||||
with open(output_filename, 'wb') as fout:
|
||||
with open(data_filename, 'r') as fin:
|
||||
for line in fin:
|
||||
try:
|
||||
myjson = json.loads(line)
|
||||
url = myjson['url']
|
||||
if url in urls:
|
||||
print('removing', myjson)
|
||||
removed_docs += 1
|
||||
removed_chars += len(myjson['text'])
|
||||
continue
|
||||
myjson = json.dumps(myjson, ensure_ascii=False)
|
||||
fout.write(myjson.encode('utf-8'))
|
||||
fout.write('\n'.encode('utf-8'))
|
||||
written_docs += 1
|
||||
if written_docs % 10000 == 0:
|
||||
print(' [PROCESSED] time (s): {:.2f} | written: {} '
|
||||
'| removed: {} (char: {})'.format(time.time() - start_time, written_docs, removed_docs,
|
||||
removed_chars))
|
||||
except Exception as e:
|
||||
print('[SKIPPING]', line, e)
|
||||
|
||||
print(' [PROCESSED] time (s): {:.2f} | written: {} '
|
||||
'| removed: {} (char: {})'.format(time.time() - start_time, written_docs, removed_docs, removed_chars))
|
||||
print('done :-)')
|
|
@ -1,36 +0,0 @@
|
|||
# coding=utf-8
|
||||
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
|
||||
sys.path.append('..')
|
||||
|
||||
from gpt2_tokenization import GPT2Tokenizer
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
|
||||
def __init__(self, cache_dir=None):
|
||||
self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2', cache_dir=cache_dir)
|
||||
self.tokenizer.max_len = int(1e12)
|
||||
self.eod_token = self.tokenizer.encoder['<|endoftext|>']
|
||||
assert self.eod_token < 65535, 'vocab size will not fit in uint16'
|
||||
print('> GPT2 tokenizer with {} vocab size and eod token {} ...'.format(len(self.tokenizer.encoder),
|
||||
self.eod_token))
|
||||
|
||||
def tokenize_document(self, document):
|
||||
tokens = self.tokenizer.encode(document)
|
||||
tokens.append(self.eod_token)
|
||||
return tokens
|
|
@ -1,347 +0,0 @@
|
|||
# Code taken in large part from https://github.com/jcpeterson/openwebtext
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import io
|
||||
import json
|
||||
import multiprocessing as mpl
|
||||
import os
|
||||
import os.path as op
|
||||
import sqlite3
|
||||
import tarfile
|
||||
import time
|
||||
import warnings
|
||||
from glob import glob
|
||||
from hashlib import sha256
|
||||
|
||||
import tldextract
|
||||
from scrapers import bs4_scraper, newspaper_scraper, raw_scraper
|
||||
# for backward compatibility
|
||||
from six.moves.urllib.request import urlopen
|
||||
from tqdm import tqdm
|
||||
from utils import chunks, extract_month, linecount, mkdir
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("url_file", type=str)
|
||||
parser.add_argument(
|
||||
"--save_uncompressed",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="whether to save the raw txt files to disk",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
default='raw.json',
|
||||
help="where to save the output json",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="scraped",
|
||||
help="which folder in the working directory to use for output",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_procs",
|
||||
type=int,
|
||||
default=10,
|
||||
help="how many processes (cores) to use for parallel scraping",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timeout",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="maximum scrape time for a single URL; -1 means no limit",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_urls",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="maximum # of URLs to scrape; mostly for debugging",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--chunk_size",
|
||||
type=int,
|
||||
default=100,
|
||||
help="how many URLs to scrape before saving to archive",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scraper",
|
||||
type=str,
|
||||
default="newspaper",
|
||||
choices=["raw", "bs4", "newspaper"],
|
||||
help="which text/content scraper to use; raw is html",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--compress",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="whether to output scraped content as compressed archives",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--compress_fmt",
|
||||
type=str,
|
||||
default="xz",
|
||||
choices=["xz", "bz2", "gz"],
|
||||
help="which archive format to use",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scraper_memoize",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="whether to use cache for newspaper",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--show_warnings",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="whether to show warnings in general during scraping",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sqlite_meta",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="whether to use sqlite for storing meta. if false, json will be used instead",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.show_warnings:
|
||||
# avoid lots of datetime warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
|
||||
def load_urls(fh, max_urls=-1):
|
||||
url_entries = enumerate(fh)
|
||||
if max_urls != -1:
|
||||
url_entries = list(url_entries)[:max_urls]
|
||||
return url_entries
|
||||
|
||||
|
||||
def vet_link(link):
|
||||
# check if server responds with non-200 status code or link points to a
|
||||
# non-html file
|
||||
link_type, link_status = "", -1
|
||||
try:
|
||||
info = urlopen(link)
|
||||
link_type = info.headers["Content-Type"]
|
||||
link_status = info.status
|
||||
except:
|
||||
pass
|
||||
|
||||
# we want "text/html" only!
|
||||
is_good_link = False
|
||||
if "text/html" in link_type and link_status == 200:
|
||||
is_good_link = True
|
||||
|
||||
return is_good_link, link_type
|
||||
|
||||
|
||||
def download(url_entry,
|
||||
scraper=args.scraper,
|
||||
save_uncompressed=args.save_uncompressed,
|
||||
memoize=args.scraper_memoize,
|
||||
arch_meta=not args.sqlite_meta):
|
||||
|
||||
uid, url = url_entry
|
||||
url = url.strip()
|
||||
fid = "{:07d}-{}".format(uid, sha256(url.encode()).hexdigest())
|
||||
|
||||
data_dir = mkdir(op.join(args.output_dir, "data"))
|
||||
text_fp = op.join(data_dir, "{}.txt".format(fid))
|
||||
|
||||
if arch_meta:
|
||||
meta_dir = mkdir(op.join(args.output_dir, "meta"))
|
||||
meta_fp = op.join(meta_dir, "{}.json".format(fid))
|
||||
|
||||
# already downloaded!
|
||||
if op.exists(text_fp):
|
||||
return
|
||||
|
||||
# is_good_link, link_type = vet_link(url)
|
||||
# if not is_good_link:
|
||||
# return
|
||||
|
||||
if scraper == "bs4":
|
||||
scrape = bs4_scraper
|
||||
elif scraper == "newspaper":
|
||||
scrape = newspaper_scraper
|
||||
elif scraper == "raw":
|
||||
scrape = raw_scraper
|
||||
|
||||
text, meta = scrape(url, memoize)
|
||||
|
||||
ext = tldextract.extract(url)
|
||||
domain = '.'.join([x for x in ext if x])
|
||||
meta["domain"] = domain
|
||||
|
||||
if text is None or text.strip() == "":
|
||||
return ("", meta, fid, uid)
|
||||
|
||||
if save_uncompressed:
|
||||
with open(text_fp, "w") as out:
|
||||
out.write(text)
|
||||
if arch_meta:
|
||||
with open(meta_fp, "w") as out:
|
||||
json.dump(meta, out)
|
||||
|
||||
return (text, meta, fid, uid)
|
||||
|
||||
|
||||
def archive_chunk(cid, cdata, out_dir, fmt, arch_meta):
|
||||
mkdir(out_dir)
|
||||
texts, metas, fids, uids = zip(*cdata)
|
||||
|
||||
data_tar = op.join(out_dir, "{}_data.{}".format(cid, fmt))
|
||||
if arch_meta:
|
||||
meta_tar = op.join(out_dir, "{}_meta.{}".format(cid, fmt))
|
||||
tar_fps, texts, exts = [data_tar, meta_tar], [texts, metas], ["txt", "json"]
|
||||
else:
|
||||
tar_fps, texts, exts = [data_tar], [texts], ["txt"]
|
||||
|
||||
doc_count = 0
|
||||
docs_counted = False
|
||||
for tar_fp, txts, ext in zip(tar_fps, texts, exts):
|
||||
with tarfile.open(tar_fp, "w:" + fmt) as tar:
|
||||
for f, fid in zip(txts, fids):
|
||||
if f == "":
|
||||
continue
|
||||
else:
|
||||
if not docs_counted:
|
||||
doc_count += 1
|
||||
|
||||
if ext == "json":
|
||||
f = json.dumps(f)
|
||||
|
||||
f = f.encode("utf-8")
|
||||
t = tarfile.TarInfo("{}.{}".format(fid, ext))
|
||||
t.size = len(f)
|
||||
tar.addfile(t, io.BytesIO(f))
|
||||
docs_counted = True
|
||||
|
||||
return doc_count
|
||||
|
||||
|
||||
def load_state(url_file):
|
||||
ckptfile = url_file + '.ckpt'
|
||||
if op.exists(ckptfile):
|
||||
with open(ckptfile) as fp:
|
||||
r = fp.read()
|
||||
if r == '':
|
||||
return 0
|
||||
else:
|
||||
return int(r)
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
def save_state(url_file, cid):
|
||||
ckptfile = url_file + '.ckpt'
|
||||
with open(ckptfile, 'w') as fp:
|
||||
fp.write(str(cid))
|
||||
|
||||
|
||||
def sqlite_conn():
|
||||
conn = sqlite3.connect('metadata.db')
|
||||
conn.execute('''
|
||||
CREATE TABLE IF NOT EXISTS metadata (
|
||||
fid char(64) not null primary key,
|
||||
url varchar(2048) not null,
|
||||
domain varchar(255) not null,
|
||||
word_count int null,
|
||||
elapsed int null,
|
||||
scraper varchar(255) not null,
|
||||
success boolean not null
|
||||
);
|
||||
''')
|
||||
conn.execute('''
|
||||
CREATE INDEX IF NOT EXISTS ix_meta_url ON metadata(url);
|
||||
''')
|
||||
conn.execute('''
|
||||
CREATE INDEX IF NOT EXISTS ix_meta_domain ON metadata(domain);
|
||||
''')
|
||||
|
||||
return conn
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if args.sqlite_meta:
|
||||
conn = sqlite_conn()
|
||||
cur = conn.cursor()
|
||||
|
||||
start_elem = load_state(args.url_file)
|
||||
start_chnk = start_elem // args.chunk_size
|
||||
|
||||
f_json = open(args.output, "w")
|
||||
|
||||
# URLs we haven't scraped yet (if first run, all URLs in file)
|
||||
with open(args.url_file) as fh:
|
||||
url_entries = load_urls(fh, args.max_urls)
|
||||
|
||||
pool = mpl.Pool(args.n_procs)
|
||||
total = linecount(args.url_file) // args.chunk_size
|
||||
print('Total chunks: ', total)
|
||||
chunk_iterator = tqdm(enumerate(chunks(url_entries, args.chunk_size, start_elem)), total=total)
|
||||
|
||||
# display already-downloaded chunks on progress bar
|
||||
chunk_iterator.update(start_chnk)
|
||||
|
||||
# process one "chunk" of args.chunk_size URLs at a time
|
||||
for i, chunk in chunk_iterator:
|
||||
cid = start_chnk + i + 1
|
||||
|
||||
tqdm.write("Downloading chunk {}".format(cid))
|
||||
t1 = time.time()
|
||||
|
||||
if args.timeout > 0:
|
||||
# imap as iterator allows .next() w/ timeout.
|
||||
# ordered version doesn't seem to work correctly.
|
||||
# for some reason, you CANNOT track j or chunk[j] in the loop,
|
||||
# so don't add anything else to the loop below!
|
||||
# confusingly, chunksize below is unrelated to our chunk_size
|
||||
chunk_iter = pool.imap_unordered(download, chunk, chunksize=1)
|
||||
cdata = []
|
||||
for j in range(len(chunk)):
|
||||
try:
|
||||
result = chunk_iter.next(timeout=args.timeout)
|
||||
cdata.append(result)
|
||||
except mpl.TimeoutError:
|
||||
tqdm.write(" --- Timeout Error --- ")
|
||||
else:
|
||||
cdata = list(pool.imap(download, chunk, chunksize=1))
|
||||
|
||||
tqdm.write("{} / {} downloads timed out".format(len(chunk) - len(cdata), len(chunk)))
|
||||
tqdm.write("Chunk time: {} seconds".format(time.time() - t1))
|
||||
|
||||
# write metadata to sqlite
|
||||
if args.sqlite_meta:
|
||||
for text, meta, fid, _ in filter(lambda x: x, cdata):
|
||||
if text:
|
||||
params = (fid, meta["url"], meta["domain"], meta["elapsed"], meta["word_count"],
|
||||
meta["scraper"], True)
|
||||
else:
|
||||
params = (fid, meta["url"], meta["domain"], None, None, meta["scraper"], False)
|
||||
cur.execute(
|
||||
"insert or ignore into metadata (fid, url, domain, elapsed, word_count, scraper, success) values (?, ?, ?, ?, ?, ?, ?)",
|
||||
params)
|
||||
conn.commit()
|
||||
|
||||
dump_chunk = []
|
||||
for text, meta, fid, _ in filter(lambda x: x, cdata):
|
||||
if text:
|
||||
line_json = {"text": text, "url": meta["url"]}
|
||||
dump_chunk.append(json.dumps(line_json) + '\n')
|
||||
f_json.writelines(dump_chunk)
|
||||
|
||||
# archive and save this chunk to file
|
||||
if args.compress:
|
||||
tqdm.write("Compressing...")
|
||||
t2 = time.time()
|
||||
count = archive_chunk(cid, cdata, args.output_dir, args.compress_fmt, not args.sqlite_meta)
|
||||
tqdm.write("Archive created in {} seconds".format(time.time() - t2))
|
||||
tqdm.write("{} out of {} URLs yielded content\n".format(len(list(filter(lambda x: x and x[0], cdata))),
|
||||
len(chunk)))
|
||||
|
||||
save_state(args.url_file, cid * args.chunk_size)
|
||||
f_json.close()
|
||||
print("Done!")
|
|
@ -1,58 +0,0 @@
|
|||
import hashlib
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import traceback
|
||||
|
||||
import newspaper
|
||||
import tldextract
|
||||
import tqdm
|
||||
from filter import should_exclude
|
||||
|
||||
hash = hashlib.sha256
|
||||
|
||||
try:
|
||||
os.mkdir('data')
|
||||
except FileExistsError:
|
||||
pass
|
||||
|
||||
|
||||
def dl(url):
|
||||
url = url.strip()
|
||||
|
||||
if should_exclude(url):
|
||||
return
|
||||
|
||||
ext = tldextract.extract(url)
|
||||
domain = '.'.join([x for x in ext if x])
|
||||
|
||||
fname = 'data/{}-{}.txt'.format(domain, hash(url.encode()).hexdigest())
|
||||
if os.path.isfile(fname):
|
||||
return
|
||||
# print('Downloading', url)
|
||||
try:
|
||||
article = newspaper.Article(url, fetch_images=False)
|
||||
article.download()
|
||||
article.parse()
|
||||
except newspaper.article.ArticleException:
|
||||
# print('Dead link:', url)
|
||||
return
|
||||
|
||||
|
||||
# traceback.print_exc()
|
||||
|
||||
text = article.text
|
||||
|
||||
if text.strip() == '':
|
||||
# print('Empty')
|
||||
return
|
||||
|
||||
with open(fname, 'w') as out:
|
||||
out.write(text)
|
||||
|
||||
if __name__ == '__main__':
|
||||
p = mp.Pool(100) # num of download threads
|
||||
with open('urls.txt') as fh:
|
||||
urls = list(fh)
|
||||
|
||||
list(tqdm.tqdm(p.imap(dl, urls), total=len(urls)))
|
||||
print('Done!')
|
|
@ -1,110 +0,0 @@
|
|||
import re
|
||||
|
||||
import tldextract
|
||||
import tqdm
|
||||
from utils import linecount
|
||||
|
||||
# https://stackoverflow.com/questions/7160737/python-how-to-validate-a-url-in-python-malformed-or-not
|
||||
url_regex = re.compile(
|
||||
r'^(?:http)s?://' # http:// or https://
|
||||
r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' #domain...
|
||||
r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip
|
||||
r'(?::\d+)?' # optional port
|
||||
r'(?:/?|[/?]\S+)$',
|
||||
re.IGNORECASE)
|
||||
|
||||
# domains that aren't scraper friendly. do not include subdomains!
|
||||
exclude_domains = set([
|
||||
# image & video hosting sites
|
||||
'imgur.com',
|
||||
'redd.it',
|
||||
'instagram.com',
|
||||
'discord.gg',
|
||||
'gfycat.com',
|
||||
'giphy.com',
|
||||
'reddituploads.com',
|
||||
'redditmedia.com',
|
||||
'twimg.com',
|
||||
'sli.mg',
|
||||
'magaimg.net',
|
||||
'flickr.com',
|
||||
'imgflip.com',
|
||||
'youtube.com',
|
||||
'youtu.be',
|
||||
'youtubedoubler.com',
|
||||
'vimeo.com',
|
||||
'twitch.tv',
|
||||
'streamable.com',
|
||||
'bandcamp.com',
|
||||
'soundcloud.com',
|
||||
|
||||
# not scraper friendly
|
||||
'reddit.com',
|
||||
'gyazo.com',
|
||||
'github.com',
|
||||
'xkcd.com',
|
||||
'twitter.com',
|
||||
'spotify.com',
|
||||
'itunes.apple.com',
|
||||
'facebook.com',
|
||||
'gunprime.com',
|
||||
'strawpoll.me',
|
||||
'voyagefusion.com',
|
||||
'rollingstone.com',
|
||||
'google.com',
|
||||
'timeanddate.com',
|
||||
'walmart.com',
|
||||
'roanoke.com',
|
||||
'spotrac.com',
|
||||
|
||||
# original paper excluded wikipedia
|
||||
'wikipedia.org',
|
||||
|
||||
# lots of top posts for this one
|
||||
'battleforthenet.com',
|
||||
])
|
||||
|
||||
exclude_extensions = ('.png', '.jpg', '.jpeg', '.gif', '.gifv', '.pdf', '.mp4', '.mp3', '.ogv', '.webm', '.doc',
|
||||
'.docx', '.log', '.csv', '.dat', '.iso', '.bin', '.exe', '.apk', '.jar', '.app', '.ppt', '.pps',
|
||||
'.pptx', '.xml', '.gz', '.xz', '.bz2', '.tgz', '.tar', '.zip', '.wma', '.mov', '.wmv', '.3gp',
|
||||
'.svg', '.rar', '.wav', '.avi', '.7z')
|
||||
|
||||
|
||||
def should_exclude(url):
|
||||
|
||||
ext = tldextract.extract(url)
|
||||
domain = '.'.join([x for x in ext if x])
|
||||
basedomain = '.'.join(ext[-2:])
|
||||
|
||||
# Ignore non-URLs
|
||||
if len(url) <= 8 or ' ' in url or re.match(url_regex, url) is None:
|
||||
return True
|
||||
|
||||
# Ignore excluded domains
|
||||
if basedomain in exclude_domains or domain in exclude_domains:
|
||||
return True
|
||||
|
||||
# Ignore case-insensitive matches for excluded extensions
|
||||
if url.lower().split('?')[0].endswith(exclude_extensions):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
url_file = 'urls.txt'
|
||||
filtered_file = 'urls-filtered.txt'
|
||||
|
||||
with open(url_file) as urls, open(filtered_file, 'w') as out:
|
||||
url_len = linecount(url_file)
|
||||
print("URL file is", url_len, "URLs long.")
|
||||
url_set = set()
|
||||
for line in tqdm.tqdm(urls, total=url_len):
|
||||
if len(line.strip()) == 0:
|
||||
continue # Skip whitespace-only lines
|
||||
line = line.strip().split()[0] # Drop any components following whitespace
|
||||
if should_exclude(line):
|
||||
continue
|
||||
url_set.add(line)
|
||||
for line in tqdm.tqdm(url_set):
|
||||
out.write(line + '\n')
|
|
@ -1,32 +0,0 @@
|
|||
import datetime
|
||||
|
||||
import praw
|
||||
import psaw
|
||||
import tqdm
|
||||
|
||||
api = psaw.PushshiftAPI()
|
||||
|
||||
# all posts until the end of 2017
|
||||
end_time = int(datetime.datetime(2018, 1, 1).timestamp())
|
||||
|
||||
query = api.search_submissions(before=end_time,
|
||||
filter=['url', 'score'],
|
||||
sort='desc',
|
||||
score='>2',
|
||||
is_self=False,
|
||||
over_18=False)
|
||||
|
||||
with tqdm.tqdm() as pbar:
|
||||
# download links from submissions
|
||||
with open('urls.txt', 'w') as fh:
|
||||
for subm in query:
|
||||
url = subm.url
|
||||
|
||||
# weird issue with psaw/pushshift that breaks score=">2"
|
||||
if subm.score < 3:
|
||||
continue
|
||||
#print(subm.score)
|
||||
# pbar.write(str(datetime.datetime.fromtimestamp(subm.created_utc)))
|
||||
pbar.update(1)
|
||||
fh.write(url + '\n')
|
||||
fh.flush()
|
|
@ -1,121 +0,0 @@
|
|||
# Code taken in large part from https://github.com/jcpeterson/openwebtext
|
||||
|
||||
import time
|
||||
import unicodedata
|
||||
|
||||
import bs4
|
||||
import newspaper
|
||||
from filter import should_exclude
|
||||
from htmlmin import minify
|
||||
from lxml.html.clean import Cleaner
|
||||
|
||||
|
||||
def find_and_filter_tag(tag, soup):
|
||||
"""tag specific filter logic"""
|
||||
|
||||
candidates = soup.find_all(tag)
|
||||
candidates = [unicodedata.normalize("NFKD", x.string) for x in candidates if x.string is not None]
|
||||
|
||||
if tag == "p":
|
||||
candidates = [y.strip() for y in candidates if len(y.split(" ")) >= 4]
|
||||
count = sum(len(y.split(" ")) for y in candidates)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return (candidates, count)
|
||||
|
||||
|
||||
def raw_scraper(url, memoize):
|
||||
t1 = time.time()
|
||||
if should_exclude(url):
|
||||
# heuristic to make downloading faster
|
||||
return None, {
|
||||
"url": url,
|
||||
"scraper": "raw",
|
||||
}
|
||||
|
||||
try:
|
||||
cleaner = Cleaner()
|
||||
cleaner.javascript = True
|
||||
cleaner.style = True
|
||||
article = newspaper.Article(url, fetch_images=False, memoize_articles=memoize)
|
||||
article.download()
|
||||
html = minify(article.html)
|
||||
html = cleaner.clean_html(html)
|
||||
article.parse()
|
||||
except:
|
||||
return None, {
|
||||
"url": url,
|
||||
"scraper": "raw",
|
||||
}
|
||||
if article.text == "":
|
||||
return None, {
|
||||
"url": url,
|
||||
"scraper": "raw",
|
||||
}
|
||||
|
||||
metadata = {"url": url, "elapsed": time.time() - t1, "scraper": "raw"}
|
||||
return html, metadata
|
||||
|
||||
|
||||
def newspaper_scraper(url, memoize):
|
||||
t1 = time.time()
|
||||
if should_exclude(url):
|
||||
# heuristic to make downloading faster
|
||||
return None, {
|
||||
"url": url,
|
||||
"scraper": "newspaper",
|
||||
}
|
||||
|
||||
try:
|
||||
article = newspaper.Article(url, fetch_images=False, memoize_articles=memoize)
|
||||
article.download()
|
||||
article.parse()
|
||||
text = article.text
|
||||
count = len(text.split())
|
||||
except:
|
||||
return None, {
|
||||
"url": url,
|
||||
"scraper": "newspaper",
|
||||
}
|
||||
|
||||
metadata = {
|
||||
"url": url,
|
||||
"word_count": count,
|
||||
"elapsed": time.time() - t1,
|
||||
"scraper": "newspaper",
|
||||
}
|
||||
return text, metadata
|
||||
|
||||
|
||||
def bs4_scraper(url, memoize):
|
||||
t1 = time.time()
|
||||
if should_exclude(url):
|
||||
# heuristic to make downloading faster
|
||||
return None, {
|
||||
"url": url,
|
||||
"scraper": "bs4",
|
||||
}
|
||||
|
||||
try:
|
||||
article = newspaper.Article(url, fetch_images=False, memoize_articles=memoize)
|
||||
article.download()
|
||||
html = article.html
|
||||
soup = bs4.BeautifulSoup(html, "lxml")
|
||||
text, count = find_and_filter_tag("p", soup)
|
||||
# DDB: keep text as a single string for consistency with
|
||||
# newspaper_scraper
|
||||
text = " ".join(text)
|
||||
except:
|
||||
return None, {
|
||||
"url": url,
|
||||
"scraper": "bs4",
|
||||
}
|
||||
|
||||
metadata = {
|
||||
"url": url,
|
||||
"word_count": count,
|
||||
"elapsed": time.time() - t1,
|
||||
"scraper": "bs4",
|
||||
}
|
||||
return text, metadata
|
|
@ -1,62 +0,0 @@
|
|||
# Code taken in large part from https://github.com/jcpeterson/openwebtext
|
||||
|
||||
import collections.abc
|
||||
import os
|
||||
import os.path as op
|
||||
import re
|
||||
import tarfile
|
||||
|
||||
|
||||
def extract_month(url_file_name):
|
||||
month_re = r"(RS_.*2\d{3}-\d{2})"
|
||||
month = op.split(url_file_name)[-1]
|
||||
month = re.match(month_re, month).group()
|
||||
return month
|
||||
|
||||
|
||||
def chunks(l, n, s=0):
|
||||
"""Yield successive n-sized chunks from l, skipping the first s chunks."""
|
||||
    if isinstance(l, collections.abc.Iterable):
|
||||
chnk = []
|
||||
for i, elem in enumerate(l):
|
||||
if i < s:
|
||||
continue
|
||||
|
||||
chnk.append(elem)
|
||||
if len(chnk) == n:
|
||||
yield chnk
|
||||
chnk = []
|
||||
if len(chnk) != 0:
|
||||
yield chnk
|
||||
|
||||
else:
|
||||
for i in range(s, len(l), n):
|
||||
yield l[i:i + n]
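# Usage example (added): list(chunks([0, 1, 2, 3, 4], 2, 1)) == [[1, 2], [3, 4]]
# -- the first element is skipped and the remainder is grouped into pairs.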
|
||||
|
||||
|
||||
def extract_archive(archive_fp, outdir="."):
|
||||
with tarfile.open(archive_fp, "r") as tar:
|
||||
tar.extractall(outdir)
|
||||
return outdir
|
||||
|
||||
|
||||
def mkdir(fp):
|
||||
try:
|
||||
os.makedirs(fp)
|
||||
except FileExistsError:
|
||||
pass
|
||||
return fp
|
||||
|
||||
|
||||
def linecount(filename):
|
||||
f = open(filename, 'rb')
|
||||
lines = 0
|
||||
buf_size = 1024 * 1024
|
||||
read_f = f.raw.read
|
||||
|
||||
buf = read_f(buf_size)
|
||||
while buf:
|
||||
lines += buf.count(b'\n')
|
||||
buf = read_f(buf_size)
|
||||
|
||||
return lines
|
|
@ -1,143 +0,0 @@
|
|||
import contextlib
|
||||
import os
|
||||
|
||||
import torch
|
||||
from dataset.webtext import WebtextDataset
|
||||
from titans.loss.lm_loss import GPTLMLoss
|
||||
|
||||
import colossalai
|
||||
import colossalai.utils as utils
|
||||
from colossalai import nn as col_nn
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.nn import LinearWarmupLR
|
||||
from colossalai.pipeline.pipelinable import PipelinableContext
|
||||
from colossalai.trainer import Trainer, hooks
|
||||
from colossalai.utils import is_using_pp
|
||||
from colossalai.utils.timer import MultiTimer
|
||||
from colossalai.zero.init_ctx import ZeroInitContext
|
||||
|
||||
|
||||
def calc_local_model_size(model: torch.nn.Module):
|
||||
numel_per_device = 0
|
||||
for p in model.parameters():
|
||||
numel_per_device += p.numel()
|
||||
return numel_per_device
|
||||
|
||||
|
||||
def main():
|
||||
parser = colossalai.get_default_parser()
|
||||
parser.add_argument('--from_torch', default=False, action='store_true')
|
||||
args = parser.parse_args()
|
||||
disable_existing_loggers()
|
||||
if args.from_torch:
|
||||
colossalai.launch_from_torch(config=args.config)
|
||||
else:
|
||||
colossalai.launch_from_slurm(config=args.config, host=args.host, port=29500, seed=42)
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
logger.info('Build data loader', ranks=[0])
|
||||
train_ds = WebtextDataset(os.environ['DATA'], seq_len=gpc.config.SEQ_LEN)
|
||||
train_dataloader = utils.get_dataloader(train_ds,
|
||||
seed=42,
|
||||
batch_size=gpc.config.BATCH_SIZE,
|
||||
pin_memory=True,
|
||||
shuffle=True,
|
||||
drop_last=True)
|
||||
|
||||
logger.info('Build model', ranks=[0])
|
||||
use_pipeline = is_using_pp()
|
||||
use_interleaved = hasattr(gpc.config.model, 'num_chunks')
|
||||
num_chunks = getattr(gpc.config.model, 'num_chunks', 1)
|
||||
use_zero3 = hasattr(gpc.config, 'zero')
|
||||
|
||||
if not use_pipeline:
|
||||
ctx = contextlib.nullcontext()
|
||||
if use_zero3:
|
||||
ctx = ZeroInitContext(target_device=torch.cuda.current_device(),
|
||||
shard_strategy=gpc.config.zero.model_config.shard_strategy,
|
||||
shard_param=True)
|
||||
with ctx:
|
||||
model = gpc.config.model.pop('type')(**gpc.config.model)
|
||||
else:
|
||||
pipelinable = PipelinableContext()
|
||||
with pipelinable:
|
||||
model = gpc.config.model.pop('type')(**gpc.config.model)
|
||||
|
||||
def mask_function(attention_mask=None):
|
||||
if attention_mask is not None:
|
||||
batch_size = gpc.config.BATCH_SIZE // gpc.config.NUM_MICRO_BATCHES
|
||||
attention_mask = attention_mask.view(batch_size, -1)
|
||||
attention_mask = col_nn.partition_batch(attention_mask)
|
||||
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
||||
attention_mask = (1.0 - attention_mask) * -10000.0
|
||||
return attention_mask
|
||||
|
||||
# GPT2_small exec_seq
|
||||
        # (lyl)TODO: The exec_seq for gpt3 will be added here, and to_layer_list should be made easier to use.
|
||||
exec_seq = ['embed', mask_function, 'blocks.0', 'blocks.1', 'blocks.2', 'blocks.3', 'blocks.4', 'blocks.5', (mask_function, "front"), \
|
||||
'blocks.6', 'blocks.7', 'blocks.8', 'blocks.9', 'blocks.10', 'blocks.11', 'norm', 'head']
|
||||
pipelinable.to_layer_list(exec_seq)
|
||||
ctx = contextlib.nullcontext()
|
||||
# (lyl)TODO: Zero context and pipelinable context should be integrated into one context.
|
||||
if use_zero3:
|
||||
ctx = ZeroInitContext(target_device=torch.cuda.current_device(),
|
||||
shard_strategy=gpc.config.zero.model_config.shard_strategy,
|
||||
shard_param=True)
|
||||
with ctx:
|
||||
model = pipelinable.partition(num_chunks, gpc.pipeline_parallel_size,
|
||||
gpc.get_local_rank(ParallelMode.PIPELINE))
|
||||
|
||||
if use_zero3:
|
||||
numel = ctx.model_numel_tensor.item()
|
||||
else:
|
||||
numel = calc_local_model_size(model)
|
||||
|
||||
tflop = numel * gpc.config.BATCH_SIZE * gpc.config.SEQ_LEN \
|
||||
* gpc.get_world_size(ParallelMode.MODEL) * gpc.get_world_size(ParallelMode.DATA) * 8 / (1024 ** 4)
|
||||
|
||||
criterion = getattr(gpc.config, 'loss_fn', None)
|
||||
if criterion is not None:
|
||||
criterion = criterion.type()
|
||||
else:
|
||||
criterion = GPTLMLoss()
|
||||
|
||||
logger.info('Build optimizer', ranks=[0])
|
||||
optimizer = gpc.config.optimizer.pop('type')(model.parameters(), **gpc.config.optimizer)
|
||||
|
||||
lr_scheduler = LinearWarmupLR(optimizer, total_steps=gpc.config.NUM_EPOCHS, warmup_steps=5)
|
||||
|
||||
engine, train_dataloader, _, lr_scheduler = colossalai.initialize(model,
|
||||
optimizer,
|
||||
criterion,
|
||||
train_dataloader=train_dataloader,
|
||||
lr_scheduler=lr_scheduler)
|
||||
global_batch_size = gpc.config.BATCH_SIZE * \
|
||||
gpc.get_world_size(ParallelMode.DATA) * getattr(gpc.config, "gradient_accumulation", 1)
|
||||
logger.info(f'Init done, global batch size = {global_batch_size}', ranks=[0])
|
||||
|
||||
    timer = MultiTimer()
|
||||
|
||||
    trainer = Trainer(engine=engine, logger=logger, timer=timer)
|
||||
|
||||
hook_list = [
|
||||
hooks.LossHook(),
|
||||
hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True),
|
||||
hooks.LogMetricByEpochHook(logger),
|
||||
hooks.ThroughputHook(ignored_steps=10, tflop_per_step=tflop),
|
||||
hooks.LogMetricByStepHook(),
|
||||
hooks.LogMemoryByEpochHook(logger),
|
||||
]
|
||||
|
||||
trainer.fit(train_dataloader=train_dataloader,
|
||||
epochs=gpc.config.NUM_EPOCHS,
|
||||
test_interval=1,
|
||||
hooks=hook_list,
|
||||
display_progress=True,
|
||||
return_output_label=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,161 @@
|
|||
from functools import partial
|
||||
from time import time
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from packaging import version
|
||||
|
||||
import colossalai
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.nn.parallel import ZeroDDP
|
||||
from colossalai.tensor import ProcessGroup
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from colossalai.zero import ZeroOptimizer
|
||||
from transformers import GPT2Config, GPT2LMHeadModel
|
||||
|
||||
|
||||
class GPTLMModel(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
hidden_size=768,
|
||||
num_layers=12,
|
||||
num_attention_heads=12,
|
||||
max_seq_len=1024,
|
||||
vocab_size=50257,
|
||||
checkpoint=False):
|
||||
super().__init__()
|
||||
self.checkpoint = checkpoint
|
||||
self.model = GPT2LMHeadModel(
|
||||
GPT2Config(n_embd=hidden_size,
|
||||
n_layer=num_layers,
|
||||
n_head=num_attention_heads,
|
||||
n_positions=max_seq_len,
|
||||
n_ctx=max_seq_len,
|
||||
vocab_size=vocab_size))
|
||||
if checkpoint:
|
||||
self.model.gradient_checkpointing_enable()
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
# Only return lm_logits
|
||||
return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]
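        # Note (added): use_cache is disabled when gradient checkpointing is on, since
        # HuggingFace transformers cannot combine cached key/value states with
        # checkpointed forward passes during training.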
|
||||
|
||||
|
||||
class GPTLMLoss(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.loss_fn = nn.CrossEntropyLoss()
|
||||
|
||||
def forward(self, logits, labels):
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
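        # The shift above implements next-token prediction: the logit at position i is
        # scored against the label at position i + 1, so the last logit and the first
        # label are dropped before computing the cross entropy.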
|
||||
|
||||
|
||||
def get_data(batch_size, seq_len, vocab_size):
|
||||
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
return input_ids, attention_mask
|
||||
|
||||
|
||||
def gpt2_medium(checkpoint=False):
|
||||
return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint)
|
||||
|
||||
|
||||
def gpt2_xl(checkpoint=True):
|
||||
return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32, checkpoint=checkpoint)
|
||||
|
||||
|
||||
def gpt2_10b(checkpoint=True):
|
||||
return GPTLMModel(hidden_size=4096, num_layers=50, num_attention_heads=16, checkpoint=checkpoint)
|
||||
|
||||
|
||||
def get_cpu_mem():
|
||||
return psutil.Process().memory_info().rss / 1024**2
|
||||
|
||||
|
||||
def get_gpu_mem():
|
||||
return torch.cuda.memory_allocated() / 1024**2
|
||||
|
||||
|
||||
def get_mem_info(prefix=''):
|
||||
return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB'
|
||||
|
||||
|
||||
def get_tflops(model_numel, batch_size, seq_len, step_time):
|
||||
return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
|
||||
|
||||
|
||||
def main():
|
||||
BATCH_SIZE = 8
|
||||
SEQ_LEN = 1024
|
||||
VOCAB_SIZE = 50257
|
||||
NUM_STEPS = 10
|
||||
PLACEMENT_POLICY = 'auto'
|
||||
disable_existing_loggers()
|
||||
colossalai.launch_from_torch(config={})
|
||||
pg = ProcessGroup()
|
||||
logger = get_dist_logger()
|
||||
|
||||
logger.info(get_mem_info(), ranks=[0])
|
||||
# build GPT model
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = gpt2_medium(checkpoint=True)
|
||||
numel = sum([p.numel() for p in model.parameters()])
|
||||
logger.info(f'Model numel: {numel}', ranks=[0])
|
||||
get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)
|
||||
|
||||
cai_version = colossalai.__version__
|
||||
logger.info(f'using Colossal-AI version {cai_version}')
|
||||
if version.parse(cai_version) > version.parse("0.1.10"):
|
||||
from colossalai.nn.parallel import GeminiDDP
|
||||
model = GeminiDDP(model,
|
||||
device=get_current_device(),
|
||||
placement_policy=PLACEMENT_POLICY,
|
||||
pin_memory=True,
|
||||
search_range_mb=32)
|
||||
elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
|
||||
from colossalai.gemini import ChunkManager, GeminiManager
|
||||
chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
|
||||
gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager)
|
||||
chunk_manager = ChunkManager(chunk_size,
|
||||
pg,
|
||||
enable_distributed_storage=True,
|
||||
init_device=GeminiManager.get_default_device(PLACEMENT_POLICY))
|
||||
model = ZeroDDP(model, gemini_manager)
|
||||
|
||||
logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
|
||||
|
||||
# build criterion
|
||||
criterion = GPTLMLoss()
|
||||
|
||||
# optimizer
|
||||
optimizer = HybridAdam(model.parameters(), lr=1e-3)
|
||||
optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
|
||||
logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
|
||||
|
||||
model.train()
|
||||
for n in range(NUM_STEPS):
|
||||
# we just use randomly generated data here
|
||||
input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
|
||||
optimizer.zero_grad()
|
||||
start = time()
|
||||
outputs = model(input_ids, attn_mask)
|
||||
loss = criterion(outputs, input_ids)
|
||||
logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Forward '), ranks=[0])
|
||||
optimizer.backward(loss)
|
||||
logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Backward '), ranks=[0])
|
||||
optimizer.step()
|
||||
logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Optimizer step '), ranks=[0])
|
||||
step_time = time() - start
|
||||
logger.info(
|
||||
f'[{n+1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}',
|
||||
ranks=[0])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
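
Since `main()` initializes the distributed state with `colossalai.launch_from_torch`, the demo has to be started through the PyTorch distributed launcher. A minimal launch sketch, where the script file name `train_gpt_demo.py` and the GPU count are assumptions rather than part of this diff:

```bash
# Hypothetical launch command: the script name and GPU count are assumptions.
torchrun --standalone --nproc_per_node 2 train_gpt_demo.py
```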
@ -22,6 +22,9 @@ The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI)
We use the pre-trained weights of the OPT model provided by the Hugging Face Hub and train on the raw WikiText-2 dataset (no tokens were replaced before tokenization). This training script is adapted from the [HuggingFace Language Modelling examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling).

## Our Modifications
We adapt the OPT training code to ColossalAI by leveraging Gemini and ZeRO DDP, as sketched below.
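
A minimal sketch of that wrapping, assuming a recent Colossal-AI release where `GeminiDDP` is available; the `facebook/opt-125m` checkpoint and the hyperparameters are purely illustrative, and the example's actual script embeds this in a full Hugging Face training loop:

```python
# Minimal sketch, not the example's actual code: wrap a Hugging Face OPT model
# with Gemini and a ZeRO optimizer. Checkpoint name and hyperparameters are
# illustrative assumptions.
import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import GeminiDDP
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer
from transformers import OPTForCausalLM

colossalai.launch_from_torch(config={})
# materialize parameters as ColoTensors so Gemini can manage their placement
with ColoInitContext(device=get_current_device()):
    model = OPTForCausalLM.from_pretrained('facebook/opt-125m')
# Gemini moves parameter chunks between GPU and CPU automatically
model = GeminiDDP(model, device=get_current_device(), placement_policy='auto',
                  pin_memory=True, search_range_mb=32)
# the ZeRO optimizer shards optimizer states and handles loss scaling
optimizer = ZeroOptimizer(HybridAdam(model.parameters(), lr=5e-5), model, initial_scale=2**5)
```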

## Quick Start
You can launch training with the following bash script:
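
The launcher might look roughly like this sketch; the entry-point name, checkpoint, and arguments are assumptions rather than the example's actual contents:

```bash
# Hypothetical launcher: the example's real script may use a different entry
# point, GPU count, and arguments.
torchrun --standalone --nproc_per_node 2 run_clm.py \
    --model_name_or_path facebook/opt-350m \
    --output_dir ./opt_checkpoints
```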