mirror of https://github.com/hpcaitech/ColossalAI
[example] integrate seq-parallel tutorial with CI (#2463)
parent 8e85d2440a
commit 8b7495dd54
@@ -114,6 +114,13 @@ class FusedScaleMaskSoftmax(nn.Module):
        self.softmax_in_fp32 = softmax_in_fp32
        self.scale = scale

        try:
            from colossalai._C import scaled_masked_softmax
        except ImportError:
            from colossalai.kernel.op_builder.scaled_masked_softmax import ScaledMaskedSoftmaxBuilder
            scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load()
        self.scaled_masked_softmax = scaled_masked_softmax

        assert (self.scale is None or softmax_in_fp32), "softmax should be in fp32 when scaled"

    def forward(self, input, mask):
@@ -178,11 +185,5 @@ class FusedScaleMaskSoftmax(nn.Module):

        return probs

    @staticmethod
    def get_batch_per_block(sq, sk, b, np):
        try:
            import colossalai._C.scaled_masked_softmax
        except ImportError:
            raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions')

        return colossalai._C.scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)

    def get_batch_per_block(self, sq, sk, b, np):
        return self.scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)

@@ -1,9 +1,11 @@
# Comparison of Large Batch Training Optimization
# Large Batch Training Optimization

## Table of contents

- [Overview](#-overview)
- [Quick Start](#-quick-start)
- [Large Batch Training Optimization](#large-batch-training-optimization)
  - [Table of contents](#table-of-contents)
  - [📚 Overview](#-overview)
  - [🚀 Quick Start](#-quick-start)

## 📚 Overview

@@ -1,139 +1,56 @@
# Sequence Parallelism with BERT
# Sequence Parallelism

In this example, we implemented BERT with sequence parallelism. Sequence parallelism splits the input tensor and intermediate
## Table of contents

- [Sequence Parallelism](#sequence-parallelism)
  - [Table of contents](#table-of-contents)
  - [📚 Overview](#-overview)
  - [🚀 Quick Start](#-quick-start)
  - [🏎 How to Train with Sequence Parallelism](#-how-to-train-with-sequence-parallelism)
    - [Step 1. Configure your parameters](#step-1-configure-your-parameters)
    - [Step 2. Invoke parallel training](#step-2-invoke-parallel-training)

## 📚 Overview

In this tutorial, we implemented BERT with sequence parallelism. Sequence parallelism splits the input tensor and intermediate
activation along the sequence dimension. This method can achieve better memory efficiency and allows us to train with larger batch size and longer sequence length.

Paper: [Sequence Parallelism: Long Sequence Training from System Perspective](https://arxiv.org/abs/2105.13120)
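
To make the idea concrete, below is a minimal PyTorch sketch (not part of this example's code) of splitting a `[batch, seq, hidden]` activation along the sequence dimension so that each rank keeps only its own `seq / world_size` slice; `world_size` and `rank` are placeholders for the distributed setting.

```python
import torch

def split_along_sequence(x: torch.Tensor, world_size: int, rank: int) -> torch.Tensor:
    # keep only this rank's slice of the sequence dimension (dim=1);
    # assumes the sequence length is divisible by world_size
    chunks = torch.chunk(x, world_size, dim=1)
    return chunks[rank].contiguous()

# toy usage: a [2, 8, 4] activation split across 4 ranks -> each rank holds [2, 2, 4]
x = torch.randn(2, 8, 4)
local = split_along_sequence(x, world_size=4, rank=0)
print(local.shape)    # torch.Size([2, 2, 4])
```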

## 🚀Quick Start
1. Run with the following command
## 🚀 Quick Start

1. Install PyTorch

2. Install the dependencies.

```bash
pip install -r requirements.txt
```

3. Run with the following command

```bash
export PYTHONPATH=$PWD
colossalai run --nproc_per_node 4 train.py -s
```
2. The default config is sequence parallel size = 2, pipeline size = 1; let's change the pipeline size to 2 and try it again.

## How to Prepare WikiPedia Dataset

First, let's prepare the WikiPedia dataset from scratch. To generate a preprocessed dataset, we need four items:
1. raw WikiPedia dataset
2. wikipedia extractor (extract data from the raw dataset)
3. vocabulary file
4. preprocessing scripts (generate final data from extracted data)

For preprocessing, we thank Megatron-LM for providing the script that generates the corpus file.

```bash
# download raw data
mkdir data && cd ./data
wget https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles.xml.bz2

# install wiki extractor
git clone https://github.com/FrankLeeeee/wikiextractor.git
pip install ./wikiextractor

# extract data
wikiextractor --json enwiki-latest-pages-articles.xml.bz2
cat text/*/* > ./corpus.json
cd ..

# download vocab file
mkdir vocab && cd ./vocab
wget https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt
cd ..

# preprocess some data
git clone https://github.com/NVIDIA/Megatron-LM.git
cd ./Megatron-LM
python tools/preprocess_data.py \
    --input ../data/corpus.json \
    --output-prefix my-bert \
    --vocab ../vocab/bert-large-uncased-vocab.txt \
    --dataset-impl mmap \
    --tokenizer-type BertWordPieceLowerCase \
    --split-sentences \
    --workers 24
# run with synthetic dataset
colossalai run --nproc_per_node 4 train.py
```

After running the preprocessing scripts, you will obtain two files:
1. my-bert_text_sentence.bin
2. my-bert_text_sentence.idx
> The default config is sequence parallel size = 2, pipeline size = 1; let's change the pipeline size to 2 and try it again.

If you happen to encounter an `index out of range` problem when running Megatron's script,
this is probably because a sentence starts with punctuation and cannot be tokenized. A workaround is to update the `Encoder.encode` method with the code below:

```python
class Encoder(object):

    def __init__(self, args):
        ...

    def initializer(self):
        ...

    def encode(self, json_line):
        data = json.loads(json_line)
        ids = {}
        for key in self.args.json_keys:
            text = data[key]
            doc_ids = []

            # lsg: avoid sentences which start with a punctuation
            # as it cannot be tokenized by splitter
            if len(text) > 0 and text[0] in string.punctuation:
                text = text[1:]

            for sentence in Encoder.splitter.tokenize(text):
                sentence_ids = Encoder.tokenizer.tokenize(sentence)
                if len(sentence_ids) > 0:
                    doc_ids.append(sentence_ids)
            if len(doc_ids) > 0 and self.args.append_eod:
                doc_ids[-1].append(Encoder.tokenizer.eod)
            ids[key] = doc_ids
        return ids, len(json_line)
```

## How to Train with Sequence Parallelism
## 🏎 How to Train with Sequence Parallelism

We provide `train.py` for you to execute training. Before invoking the script, there are several
steps to perform.

### Step 1. Set data path and vocab path

At the top of `config.py`, you can see two global variables `DATA_PATH` and `VOCAB_FILE_PATH`.

```python
DATA_PATH = <data-path>
VOCAB_FILE_PATH = <vocab-path>
```

`DATA_PATH` refers to the path to the data file generated by Megatron's script. For example, in the section above, you should get two data files (my-bert_text_sentence.bin and my-bert_text_sentence.idx). You just need to set `DATA_PATH` to the path of the bin file without the file extension.

For example, if your `my-bert_text_sentence.bin` is located at `/home/Megatron-LM/my-bert_text_sentence.bin`, then you should set

```python
DATA_PATH = '/home/Megatron-LM/my-bert_text_sentence'
```

`VOCAB_FILE_PATH` refers to the path to the vocabulary downloaded when you prepared the dataset
(e.g. bert-large-uncased-vocab.txt).
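
For instance, if the vocabulary file downloaded earlier sits in a hypothetical `/home/vocab/` directory, the setting would look like:

```python
VOCAB_FILE_PATH = '/home/vocab/bert-large-uncased-vocab.txt'    # hypothetical location
```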

### Step 2. Make Dataset Helper

Build the BERT dataset helper. Requirements are `CUDA`, `g++`, `pybind11` and `make`.

```bash
cd ./data/datasets
make
```

### Step 3. Configure your parameters
### Step 1. Configure your parameters

In the `config.py` provided, a set of parameters is defined, including the training scheme, model, etc.
You can also modify the ColossalAI setting. For example, if you wish to parallelize over the
sequence dimension on 8 GPUs, you can change `size=4` to `size=8`. If you wish to use pipeline parallelism, you can set `pipeline=<num_of_pipeline_stages>`.
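
For reference, a minimal sketch of what such a setting could look like, assuming the `parallel = dict(...)` convention used by ColossalAI configs; the exact keys in this example's shipped `config.py` may differ.

```python
# illustrative values: sequence parallelism over 8 GPUs, no pipeline stages
parallel = dict(
    pipeline=1,
    tensor=dict(size=8, mode='sequence'),
)
```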

### Step 4. Invoke parallel training
### Step 2. Invoke parallel training

Lastly, you can start training with sequence parallelism. How you invoke `train.py` depends on your
machine setting.
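
For example, on a single node with 4 GPUs you can launch it the same way as in the Quick Start. The multi-node command below is only an illustration: the hostnames are hypothetical, the exact launcher flags depend on your cluster, and passwordless SSH between nodes is assumed.

```bash
# single node, 4 GPUs
colossalai run --nproc_per_node 4 train.py

# two nodes, 4 GPUs each (illustrative hostnames)
colossalai run --nproc_per_node 4 --host node1,node2 --master_addr node1 train.py
```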
@@ -1,11 +1,8 @@
from colossalai.amp import AMP_TYPE

DATA_PATH = ''
VOCAB_FILE_PATH = ''

# hyper-parameters
TRAIN_ITERS = 1000000
DECAY_ITERS = 990000
TRAIN_ITERS = 10
DECAY_ITERS = 4
WARMUP_FRACTION = 0.01
GLOBAL_BATCH_SIZE = 32 # dp world size * sentences per GPU
EVAL_ITERS = 10

@@ -13,12 +10,12 @@ EVAL_INTERVAL = 10
LR = 0.0001
MIN_LR = 1e-05
WEIGHT_DECAY = 0.01
SEQ_LENGTH = 512
SEQ_LENGTH = 128

# BERT config
DEPTH = 12
NUM_ATTENTION_HEADS = 12
HIDDEN_SIZE = 768
DEPTH = 4
NUM_ATTENTION_HEADS = 4
HIDDEN_SIZE = 128

# model config
ADD_BINARY_HEAD = False
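
As a worked example of the `GLOBAL_BATCH_SIZE` comment above (dp world size * sentences per GPU): assuming a hypothetical launch with 4 GPUs, sequence parallel size 2 and no pipeline stages, the data-parallel world size is 2, so each GPU processes 16 sentences per step.

```python
GLOBAL_BATCH_SIZE = 32
num_gpus, seq_parallel_size, pipeline_size = 4, 2, 1    # assumed launch setting
dp_world_size = num_gpus // (seq_parallel_size * pipeline_size)
batch_size_per_gpu = GLOBAL_BATCH_SIZE // dp_world_size
print(dp_world_size, batch_size_per_gpu)    # 2 16
```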
@@ -1,2 +1,2 @@
colossalai >= 0.1.12
torch >= 1.8.1
colossalai
torch
@@ -0,0 +1,7 @@
#!/bin/bash
set -euxo pipefail

pip install -r requirements.txt

# run test
colossalai run --nproc_per_node 4 train.py
@@ -1,9 +1,8 @@
import argparse

import torch
from data import build_train_valid_test_data_iterators
from data.bert_helper import SequenceParallelDataIterator, get_batch_for_sequence_parallel
from data.tokenizer import get_padded_vocab_size, initialize_tokenizer
from data.dummy_dataloader import DummyDataloader
from loss_func.bert_loss import BertLoss
from lr_scheduler import AnnealingLR
from model.bert import BertForPretrain, build_pipeline_bert
@@ -36,7 +35,7 @@ def parse_args():


def pipeline_data_process_func(stage_output, micro_batch_data):
    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = micro_batch_data
    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = micro_batch_data
    if gpc.is_first_rank(ParallelMode.PIPELINE):
        data = (tokens, padding_mask, types, lm_labels)
        label = (loss_mask, sentence_order)
@@ -53,36 +52,15 @@ def main():

    logger = get_dist_logger()

    # build dataloader
    if not args.synthetic:
        initialize_tokenizer(gpc.config.VOCAB_FILE_PATH, tokenizer_type='BertWordPieceLowerCase')
        VOCAB_SIZE = get_padded_vocab_size()
        trainloader, validloader, testloader = build_train_valid_test_data_iterators(
            train_iters=gpc.config.TRAIN_ITERS,
            global_batch_size=gpc.config.GLOBAL_BATCH_SIZE,
            eval_interval=gpc.config.EVAL_INTERVAL,
            eval_iters=gpc.config.EVAL_ITERS,
            data_prefix=[gpc.config.DATA_PATH],
            data_impl='mmap',
            splits_string='949,50,1',
            max_seq_length=gpc.config.SEQ_LENGTH,
            masked_lm_prob=0.15,
            short_seq_prob=0.1,
            seed=1234,
            skip_warmup=True,
            binary_head=False,
        )
    else:
        from data.dummy_dataloader import DummyDataloader

        BATCH_SIZE_PER_GPUS = gpc.config.GLOBAL_BATCH_SIZE // gpc.get_world_size(ParallelMode.DATA)
        VOCAB_SIZE = 30528
        trainloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS,
                                      vocab_size=VOCAB_SIZE,
                                      seq_length=gpc.config.SEQ_LENGTH)
        validloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS,
                                      vocab_size=VOCAB_SIZE,
                                      seq_length=gpc.config.SEQ_LENGTH)
    # build synthetic dataloader
    BATCH_SIZE_PER_GPUS = gpc.config.GLOBAL_BATCH_SIZE // gpc.get_world_size(ParallelMode.DATA)
    VOCAB_SIZE = 30528
    trainloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS,
                                  vocab_size=VOCAB_SIZE,
                                  seq_length=gpc.config.SEQ_LENGTH)
    validloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS,
                                  vocab_size=VOCAB_SIZE,
                                  seq_length=gpc.config.SEQ_LENGTH)

    logger.info("Dataloaders are built", ranks=[0])
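
The CI-oriented path above builds only synthetic dataloaders. For readers unfamiliar with the helper, here is a minimal sketch of what such a synthetic dataloader might look like; it is an assumption for illustration (with a hypothetical class name), and the actual `data/dummy_dataloader.py` in the example may differ in the exact fields it yields.

```python
import torch

class SyntheticDataloader:
    """Endlessly yields random token batches, mimicking a BERT-style dataloader."""

    def __init__(self, batch_size: int, vocab_size: int, seq_length: int):
        self.batch_size = batch_size
        self.vocab_size = vocab_size
        self.seq_length = seq_length

    def __iter__(self):
        while True:
            # random token ids, an all-ones padding mask, and random LM labels
            tokens = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_length))
            padding_mask = torch.ones(self.batch_size, self.seq_length, dtype=torch.long)
            labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_length))
            yield tokens, padding_mask, labels

# usage: fetch one synthetic batch
tokens, padding_mask, labels = next(iter(SyntheticDataloader(4, 30528, 128)))
print(tokens.shape)    # torch.Size([4, 128])
```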