[example] integrate seq-parallel tutorial with CI (#2463)

pull/2476/head
Frank Lee 2023-01-13 14:40:05 +08:00 committed by GitHub
parent 8e85d2440a
commit 8b7495dd54
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 72 additions and 170 deletions

View File

@ -114,6 +114,13 @@ class FusedScaleMaskSoftmax(nn.Module):
self.softmax_in_fp32 = softmax_in_fp32 self.softmax_in_fp32 = softmax_in_fp32
self.scale = scale self.scale = scale
try:
from colossalai._C import scaled_masked_softmax
except ImportError:
from colossalai.kernel.op_builder.scaled_masked_softmax import ScaledMaskedSoftmaxBuilder
scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load()
self.scaled_masked_softmax = scaled_masked_softmax
assert (self.scale is None or softmax_in_fp32), "softmax should be in fp32 when scaled" assert (self.scale is None or softmax_in_fp32), "softmax should be in fp32 when scaled"
def forward(self, input, mask): def forward(self, input, mask):
@ -178,11 +185,5 @@ class FusedScaleMaskSoftmax(nn.Module):
return probs return probs
@staticmethod def get_batch_per_block(self, sq, sk, b, np):
def get_batch_per_block(sq, sk, b, np): return self.scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)
try:
import colossalai._C.scaled_masked_softmax
except ImportError:
raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions')
return colossalai._C.scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)

View File

@ -1,9 +1,11 @@
# Comparison of Large Batch Training Optimization # Large Batch Training Optimization
## Table of contents ## Table of contents
- [Overview](#-overview) - [Large Batch Training Optimization](#large-batch-training-optimization)
- [Quick Start](#-quick-start) - [Table of contents](#table-of-contents)
- [📚 Overview](#-overview)
- [🚀 Quick Start](#-quick-start)
## 📚 Overview ## 📚 Overview

View File

@ -1,139 +1,56 @@
# Sequence Parallelism with BERT # Sequence Parallelism
In this example, we implemented BERT with sequence parallelism. Sequence parallelism splits the input tensor and intermediate ## Table of contents
- [Sequence Parallelism](#sequence-parallelism)
- [Table of contents](#table-of-contents)
- [📚 Overview](#-overview)
- [🚀 Quick Start](#-quick-start)
- [🏎 How to Train with Sequence Parallelism](#-how-to-train-with-sequence-parallelism)
- [Step 1. Configure your parameters](#step-1-configure-your-parameters)
- [Step 2. Invoke parallel training](#step-2-invoke-parallel-training)
## 📚 Overview
In this tutorial, we implemented BERT with sequence parallelism. Sequence parallelism splits the input tensor and intermediate
activation along the sequence dimension. This method can achieve better memory efficiency and allows us to train with larger batch size and longer sequence length. activation along the sequence dimension. This method can achieve better memory efficiency and allows us to train with larger batch size and longer sequence length.
Paper: [Sequence Parallelism: Long Sequence Training from System Perspective](https://arxiv.org/abs/2105.13120) Paper: [Sequence Parallelism: Long Sequence Training from System Perspective](https://arxiv.org/abs/2105.13120)
## 🚀Quick Start ## 🚀 Quick Start
1. Run with the following command
1. Install PyTorch
2. Install the dependencies.
```bash
pip install -r requirements.txt
```
3. Run with the following command
```bash ```bash
export PYTHONPATH=$PWD export PYTHONPATH=$PWD
colossalai run --nproc_per_node 4 train.py -s
```
2. The default config is sequence parallel size = 2, pipeline size = 1, lets change pipeline size to be 2 and try it again.
# run with synthetic dataset
## How to Prepare WikiPedia Dataset colossalai run --nproc_per_node 4 train.py
First, let's prepare the WikiPedia dataset from scratch. To generate a preprocessed dataset, we need four items:
1. raw WikiPedia dataset
2. wikipedia extractor (extract data from the raw dataset)
3. vocabulary file
4. preprocessing scripts (generate final data from extracted data)
For the preprocessing script, we thank Megatron-LM for providing a preprocessing script to generate the corpus file.
```python
# download raw data
mkdir data && cd ./data
wget https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles.xml.bz2
# install wiki extractor
git clone https://github.com/FrankLeeeee/wikiextractor.git
pip install ./wikiextractor
# extractmodule
wikiextractor --json enwiki-latest-pages-articles.xml.bz2
cat text/*/* > ./corpus.json
cd ..
# download vocab file
mkdir vocab && cd ./vocab
wget https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt
cd ..
# preprocess some data
git clone https://github.com/NVIDIA/Megatron-LM.git
cd ./Megatron-LM
python tools/preprocess_data.py \
--input ../data/corpus.json \
--output-prefix my-bert \
--vocab ../vocab/bert-large-uncased-vocab.txt \
--dataset-impl mmap \
--tokenizer-type BertWordPieceLowerCase \
--split-sentences \
--workers 24
``` ```
After running the preprocessing scripts, you will obtain two files: > The default config is sequence parallel size = 2, pipeline size = 1, lets change pipeline size to be 2 and try it again.
1. my-bert_text_sentence.bin
2. my-bert_text_sentence.idx
If you happen to encouter `index out of range` problem when running Megatron's script,
this is probably because that a sentence starts with a punctuation and cannot be tokenized. A work-around is to update `Encoder.encode` method with the code below:
```python ## 🏎 How to Train with Sequence Parallelism
class Encoder(object):
def __init__(self, args):
...
def initializer(self):
...
def encode(self, json_line):
data = json.loads(json_line)
ids = {}
for key in self.args.json_keys:
text = data[key]
doc_ids = []
# lsg: avoid sentences which start with a punctuation
# as it cannot be tokenized by splitter
if len(text) > 0 and text[0] in string.punctuation:
text = text[1:]
for sentence in Encoder.splitter.tokenize(text):
sentence_ids = Encoder.tokenizer.tokenize(sentence)
if len(sentence_ids) > 0:
doc_ids.append(sentence_ids)
if len(doc_ids) > 0 and self.args.append_eod:
doc_ids[-1].append(Encoder.tokenizer.eod)
ids[key] = doc_ids
return ids, len(json_line)
```
## How to Train with Sequence Parallelism
We provided `train.py` for you to execute training. Before invoking the script, there are several We provided `train.py` for you to execute training. Before invoking the script, there are several
steps to perform. steps to perform.
### Step 1. Set data path and vocab path ### Step 1. Configure your parameters
At the top of `config.py`, you can see two global variables `DATA_PATH` and `VOCAB_FILE_PATH`.
```python
DATA_PATH = <data-path>
VOCAB_FILE_PATH = <vocab-path>
```
`DATA_PATH` refers to the path to the data file generated by Megatron's script. For example, in the section above, you should get two data files (my-bert_text_sentence.bin and my-bert_text_sentence.idx). You just need to `DATA_PATH` to the path to the bin file without the file extension.
For example, if your my-bert_text_sentence.bin is /home/Megatron-LM/my-bert_text_sentence.bin, then you should set
```python
DATA_PATH = '/home/Megatron-LM/my-bert_text_sentence'
```
The `VOCAB_FILE_PATH` refers to the path to the vocabulary downloaded when you prepare the dataset
(e.g. bert-large-uncased-vocab.txt).
### Step 3. Make Dataset Helper
Build BERT dataset helper. Requirements are `CUDA`, `g++`, `pybind11` and `make`.
```python
cd ./data/datasets
make
```
### Step 3. Configure your parameters
In the `config.py` provided, a set of parameters are defined including training scheme, model, etc. In the `config.py` provided, a set of parameters are defined including training scheme, model, etc.
You can also modify the ColossalAI setting. For example, if you wish to parallelize over the You can also modify the ColossalAI setting. For example, if you wish to parallelize over the
sequence dimension on 8 GPUs. You can change `size=4` to `size=8`. If you wish to use pipeline parallelism, you can set `pipeline=<num_of_pipeline_stages>`. sequence dimension on 8 GPUs. You can change `size=4` to `size=8`. If you wish to use pipeline parallelism, you can set `pipeline=<num_of_pipeline_stages>`.
### Step 4. Invoke parallel training ### Step 2. Invoke parallel training
Lastly, you can start training with sequence parallelism. How you invoke `train.py` depends on your Lastly, you can start training with sequence parallelism. How you invoke `train.py` depends on your
machine setting. machine setting.

View File

@ -1,11 +1,8 @@
from colossalai.amp import AMP_TYPE from colossalai.amp import AMP_TYPE
DATA_PATH = ''
VOCAB_FILE_PATH = ''
# hyper-parameters # hyper-parameters
TRAIN_ITERS = 1000000 TRAIN_ITERS = 10
DECAY_ITERS = 990000 DECAY_ITERS = 4
WARMUP_FRACTION = 0.01 WARMUP_FRACTION = 0.01
GLOBAL_BATCH_SIZE = 32 # dp world size * sentences per GPU GLOBAL_BATCH_SIZE = 32 # dp world size * sentences per GPU
EVAL_ITERS = 10 EVAL_ITERS = 10
@ -13,12 +10,12 @@ EVAL_INTERVAL = 10
LR = 0.0001 LR = 0.0001
MIN_LR = 1e-05 MIN_LR = 1e-05
WEIGHT_DECAY = 0.01 WEIGHT_DECAY = 0.01
SEQ_LENGTH = 512 SEQ_LENGTH = 128
# BERT config # BERT config
DEPTH = 12 DEPTH = 4
NUM_ATTENTION_HEADS = 12 NUM_ATTENTION_HEADS = 4
HIDDEN_SIZE = 768 HIDDEN_SIZE = 128
# model config # model config
ADD_BINARY_HEAD = False ADD_BINARY_HEAD = False

View File

@ -1,2 +1,2 @@
colossalai >= 0.1.12 colossalai
torch >= 1.8.1 torch

View File

@ -0,0 +1,7 @@
#!/bin/bash
set -euxo pipefail
pip install -r requirements.txt
# run test
colossalai run --nproc_per_node 4 train.py

View File

@ -1,9 +1,8 @@
import argparse import argparse
import torch import torch
from data import build_train_valid_test_data_iterators
from data.bert_helper import SequenceParallelDataIterator, get_batch_for_sequence_parallel from data.bert_helper import SequenceParallelDataIterator, get_batch_for_sequence_parallel
from data.tokenizer import get_padded_vocab_size, initialize_tokenizer from data.dummy_dataloader import DummyDataloader
from loss_func.bert_loss import BertLoss from loss_func.bert_loss import BertLoss
from lr_scheduler import AnnealingLR from lr_scheduler import AnnealingLR
from model.bert import BertForPretrain, build_pipeline_bert from model.bert import BertForPretrain, build_pipeline_bert
@ -36,7 +35,7 @@ def parse_args():
def pipeline_data_process_func(stage_output, micro_batch_data): def pipeline_data_process_func(stage_output, micro_batch_data):
tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = micro_batch_data tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = micro_batch_data
if gpc.is_first_rank(ParallelMode.PIPELINE): if gpc.is_first_rank(ParallelMode.PIPELINE):
data = (tokens, padding_mask, types, lm_labels) data = (tokens, padding_mask, types, lm_labels)
label = (loss_mask, sentence_order) label = (loss_mask, sentence_order)
@ -53,36 +52,15 @@ def main():
logger = get_dist_logger() logger = get_dist_logger()
# build dataloader # build synthetic dataloader
if not args.synthetic: BATCH_SIZE_PER_GPUS = gpc.config.GLOBAL_BATCH_SIZE // gpc.get_world_size(ParallelMode.DATA)
initialize_tokenizer(gpc.config.VOCAB_FILE_PATH, tokenizer_type='BertWordPieceLowerCase') VOCAB_SIZE = 30528
VOCAB_SIZE = get_padded_vocab_size() trainloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS,
trainloader, validloader, testloader = build_train_valid_test_data_iterators( vocab_size=VOCAB_SIZE,
train_iters=gpc.config.TRAIN_ITERS, seq_length=gpc.config.SEQ_LENGTH)
global_batch_size=gpc.config.GLOBAL_BATCH_SIZE, validloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS,
eval_interval=gpc.config.EVAL_INTERVAL, vocab_size=VOCAB_SIZE,
eval_iters=gpc.config.EVAL_ITERS, seq_length=gpc.config.SEQ_LENGTH)
data_prefix=[gpc.config.DATA_PATH],
data_impl='mmap',
splits_string='949,50,1',
max_seq_length=gpc.config.SEQ_LENGTH,
masked_lm_prob=0.15,
short_seq_prob=0.1,
seed=1234,
skip_warmup=True,
binary_head=False,
)
else:
from data.dummy_dataloader import DummyDataloader
BATCH_SIZE_PER_GPUS = gpc.config.GLOBAL_BATCH_SIZE // gpc.get_world_size(ParallelMode.DATA)
VOCAB_SIZE = 30528
trainloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS,
vocab_size=VOCAB_SIZE,
seq_length=gpc.config.SEQ_LENGTH)
validloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS,
vocab_size=VOCAB_SIZE,
seq_length=gpc.config.SEQ_LENGTH)
logger.info("Dataloaders are built", ranks=[0]) logger.info("Dataloaders are built", ranks=[0])