[tutorial] added synthetic data for sequence parallel (#1927)

* [tutorial] added synthetic data for sequence parallel

* polish code
pull/1943/head
Frank Lee 2022-11-13 03:24:02 +08:00 committed by GitHub
parent abf4c27f6a
commit 807cbdb87d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 74 additions and 47 deletions

View File

@ -133,7 +133,7 @@ machine setting.
start your script. A sample command is like below:
```bash
python -m torch.distributed.launch --nproc_per_node <num_gpus_on_this_machine> --master_addr localhost --master_port 29500 train.py
colossalai run --nproc_per_node <num_gpus_on_this_machine> --master_addr localhost --master_port 29500 train.py
```
- If you are using multiple machines with multiple GPUs, we suggest that you refer to `colossalai

View File

@ -31,10 +31,8 @@ SEED = 1234
NUM_MICRO_BATCHES = 4
# colossalai config
parallel = dict(pipeline=1, tensor=dict(size=4, mode='sequence'))
parallel = dict(pipeline=1, tensor=dict(size=2, mode='sequence'))
fp16 = dict(mode=AMP_TYPE.NAIVE, verbose=True)
clip_grad_norm = 1.0
gradient_handler = [dict(type='SequenceParallelGradientHandler')]

View File

@ -14,19 +14,30 @@
# limitations under the License.
"""BERT Style dataset."""
from colossalai.logging import get_dist_logger
import os
import time
import numpy as np
import torch
from torch.utils.data import Dataset
from ..tokenizer import get_tokenizer
from .dataset_utils import (get_a_and_b_segments, truncate_segments, create_tokens_and_tokentypes,
create_masked_lm_predictions, pad_and_convert_to_numpy)
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
import time
import os
from . import helpers
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from ..tokenizer import get_tokenizer
from .dataset_utils import (
create_masked_lm_predictions,
create_tokens_and_tokentypes,
get_a_and_b_segments,
pad_and_convert_to_numpy,
truncate_segments,
)
try:
from . import helpers
except:
print("helper is not built, ignore this message if you are using synthetic data.")
class BertDataset(Dataset):

View File

@ -1,21 +1,22 @@
import argparse
import torch
from data import build_train_valid_test_data_iterators
from data.bert_helper import SequenceParallelDataIterator, get_batch_for_sequence_parallel
from data.tokenizer import get_padded_vocab_size, initialize_tokenizer
from loss_func.bert_loss import BertLoss
from lr_scheduler import AnnealingLR
from model.bert import BertForPretrain, build_pipeline_bert
import colossalai
from colossalai.amp import AMP_TYPE
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from data import build_train_valid_test_data_iterators
from data.tokenizer import initialize_tokenizer, get_padded_vocab_size
from data.bert_helper import get_batch_for_sequence_parallel, SequenceParallelDataIterator
from colossalai.amp import AMP_TYPE
from colossalai.logging import get_dist_logger
from colossalai.utils import MultiTimer, is_using_pp
from model.bert import BertForPretrain
from lr_scheduler import AnnealingLR
from loss_func.bert_loss import BertLoss
import torch
from colossalai.engine.schedule import PipelineSchedule
from colossalai.amp import AMP_TYPE
from colossalai.nn.optimizer import FusedAdam
from colossalai.kernel import LayerNorm
from model.bert import build_pipeline_bert
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import FusedAdam
from colossalai.utils import MultiTimer, is_using_pp
def process_batch_data(batch_data):
@ -28,30 +29,49 @@ def process_batch_data(batch_data):
return data, label
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('-s', '--synthetic', action="store_true", help="whether use synthetic data")
return parser.parse_args()
def main():
# initialize
args = parse_args()
colossalai.launch_from_torch(config='./config.py', seed=1234, backend='nccl')
logger = get_dist_logger()
# build dataloader
initialize_tokenizer(gpc.config.VOCAB_FILE_PATH, tokenizer_type='BertWordPieceLowerCase')
VOCAB_SIZE = get_padded_vocab_size()
trainloader, validloader, testloader = build_train_valid_test_data_iterators(
train_iters=gpc.config.TRAIN_ITERS,
global_batch_size=gpc.config.GLOBAL_BATCH_SIZE,
eval_interval=gpc.config.EVAL_INTERVAL,
eval_iters=gpc.config.EVAL_ITERS,
data_prefix=[gpc.config.DATA_PATH],
data_impl='mmap',
splits_string='949,50,1',
max_seq_length=gpc.config.SEQ_LENGTH,
masked_lm_prob=0.15,
short_seq_prob=0.1,
seed=1234,
skip_warmup=True,
binary_head=False,
)
if not args.synthetic:
initialize_tokenizer(gpc.config.VOCAB_FILE_PATH, tokenizer_type='BertWordPieceLowerCase')
VOCAB_SIZE = get_padded_vocab_size()
trainloader, validloader, testloader = build_train_valid_test_data_iterators(
train_iters=gpc.config.TRAIN_ITERS,
global_batch_size=gpc.config.GLOBAL_BATCH_SIZE,
eval_interval=gpc.config.EVAL_INTERVAL,
eval_iters=gpc.config.EVAL_ITERS,
data_prefix=[gpc.config.DATA_PATH],
data_impl='mmap',
splits_string='949,50,1',
max_seq_length=gpc.config.SEQ_LENGTH,
masked_lm_prob=0.15,
short_seq_prob=0.1,
seed=1234,
skip_warmup=True,
binary_head=False,
)
else:
from data.dummy_dataloader import DummyDataloader
BATCH_SIZE_PER_GPUS = gpc.config.GLOBAL_BATCH_SIZE // gpc.get_world_size(ParallelMode.DATA)
VOCAB_SIZE = 30528
trainloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS,
vocab_size=VOCAB_SIZE,
seq_length=gpc.config.SEQ_LENGTH)
validloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS,
vocab_size=VOCAB_SIZE,
seq_length=gpc.config.SEQ_LENGTH)
logger.info("Dataloaders are built", ranks=[0])
@ -121,11 +141,7 @@ def main():
logger.info(f"LR Scheduler is built with {warmup_steps} warmup steps and {gpc.config.DECAY_ITERS} decay steps")
# # init
engine, *dummy = colossalai.initialize(
model,
optimizer,
criterion,
)
engine, *dummy = colossalai.initialize(model, optimizer, criterion, verbose=True)
# build timer
timer = MultiTimer()
@ -140,6 +156,8 @@ def main():
train_data_iter = SequenceParallelDataIterator(trainloader)
valid_data_iter = SequenceParallelDataIterator(validloader)
logger.info("start training")
for step in range(1, gpc.config.TRAIN_ITERS + 1):
timer.start('train-iterations')
engine.train()