mirror of https://github.com/hpcaitech/ColossalAI
[tutorial] added synthetic data for sequence parallel (#1927)
* [tutorial] added synthetic data for sequence parallel
* polish code
parent abf4c27f6a
commit 807cbdb87d
@@ -133,7 +133,7 @@ machine setting.
 start your script. A sample command is like below:
 
 ```bash
-python -m torch.distributed.launch --nproc_per_node <num_gpus_on_this_machine> --master_addr localhost --master_port 29500 train.py
+colossalai run --nproc_per_node <num_gpus_on_this_machine> --master_addr localhost --master_port 29500 train.py
 ```
 
 - If you are using multiple machines with multiple GPUs, we suggest that you refer to `colossalai
@@ -31,10 +31,8 @@ SEED = 1234
 NUM_MICRO_BATCHES = 4
 
 # colossalai config
-parallel = dict(pipeline=1, tensor=dict(size=4, mode='sequence'))
+parallel = dict(pipeline=1, tensor=dict(size=2, mode='sequence'))
 
 fp16 = dict(mode=AMP_TYPE.NAIVE, verbose=True)
 
-clip_grad_norm = 1.0
-
 gradient_handler = [dict(type='SequenceParallelGradientHandler')]
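For reference (not part of the commit): these config values surface at runtime through `gpc.config` once `colossalai.launch_from_torch(config='./config.py')` has run, which is how `train.py` reads them below, and with `tensor=dict(size=2, mode='sequence')` the tutorial now needs as few as two GPUs. A minimal sketch, assuming the same launcher as `train.py`; the `inspect_config` helper is our illustration, not the repo's code:

```python
# Hypothetical helper, not from the repo: shows how top-level names defined in
# the tutorial's config.py become attributes of gpc.config after launch.
import colossalai
from colossalai.core import global_context as gpc


def inspect_config():
    # parallel was changed above to a sequence-parallel group of size 2
    print(gpc.config.parallel)           # dict(pipeline=1, tensor=dict(size=2, mode='sequence'))
    print(gpc.config.NUM_MICRO_BATCHES)  # 4
    print(gpc.config.fp16)               # dict(mode=AMP_TYPE.NAIVE, verbose=True)


if __name__ == '__main__':
    # e.g. launched with: colossalai run --nproc_per_node 2 inspect_config.py
    colossalai.launch_from_torch(config='./config.py', seed=1234, backend='nccl')
    inspect_config()
```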
@@ -14,19 +14,30 @@
 # limitations under the License.
 """BERT Style dataset."""
 
-from colossalai.logging import get_dist_logger
+import os
+import time
 
 import numpy as np
 import torch
 from torch.utils.data import Dataset
 
-from ..tokenizer import get_tokenizer
-from .dataset_utils import (get_a_and_b_segments, truncate_segments, create_tokens_and_tokentypes,
-                            create_masked_lm_predictions, pad_and_convert_to_numpy)
-from colossalai.core import global_context as gpc
 from colossalai.context import ParallelMode
-import time
-import os
-from . import helpers
+from colossalai.core import global_context as gpc
+from colossalai.logging import get_dist_logger
+
+from ..tokenizer import get_tokenizer
+from .dataset_utils import (
+    create_masked_lm_predictions,
+    create_tokens_and_tokentypes,
+    get_a_and_b_segments,
+    pad_and_convert_to_numpy,
+    truncate_segments,
+)
+
+try:
+    from . import helpers
+except:
+    print("helper is not built, ignore this message if you are using synthetic data.")
+
 
 class BertDataset(Dataset):
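The `try`/`except` guard above lets the dataset module be imported even when the compiled `helpers` extension has not been built, which is exactly the situation for synthetic-data runs. Below is a small standalone sketch of the same pattern with made-up names (`fused_helpers` and `build_sample_index` are ours, not the repo's); catching `ImportError` specifically, rather than the bare `except:` used in the patch, avoids hiding unrelated failures:

```python
# Standalone illustration of the guarded-import pattern; `fused_helpers`
# stands in for a compiled extension and does not exist in the repo.
try:
    import fused_helpers  # only needed when indexing a real corpus
    HELPERS_AVAILABLE = True
except ImportError:
    HELPERS_AVAILABLE = False
    print("helper is not built, ignore this message if you are using synthetic data.")


def build_sample_index(*args, **kwargs):
    # Hypothetical call site: fail only when the helper is actually required.
    if not HELPERS_AVAILABLE:
        raise RuntimeError("compiled helpers are required for real data; use synthetic data instead")
    return fused_helpers.build_mapping(*args, **kwargs)
```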
@@ -1,21 +1,22 @@
+import argparse
+
+import torch
+from data import build_train_valid_test_data_iterators
+from data.bert_helper import SequenceParallelDataIterator, get_batch_for_sequence_parallel
+from data.tokenizer import get_padded_vocab_size, initialize_tokenizer
+from loss_func.bert_loss import BertLoss
+from lr_scheduler import AnnealingLR
+from model.bert import BertForPretrain, build_pipeline_bert
+
 import colossalai
+from colossalai.amp import AMP_TYPE
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from data import build_train_valid_test_data_iterators
-from data.tokenizer import initialize_tokenizer, get_padded_vocab_size
-from data.bert_helper import get_batch_for_sequence_parallel, SequenceParallelDataIterator
-from colossalai.amp import AMP_TYPE
-from colossalai.logging import get_dist_logger
-from colossalai.utils import MultiTimer, is_using_pp
-from model.bert import BertForPretrain
-from lr_scheduler import AnnealingLR
-from loss_func.bert_loss import BertLoss
-import torch
 from colossalai.engine.schedule import PipelineSchedule
-from colossalai.amp import AMP_TYPE
-from colossalai.nn.optimizer import FusedAdam
 from colossalai.kernel import LayerNorm
-from model.bert import build_pipeline_bert
+from colossalai.logging import get_dist_logger
+from colossalai.nn.optimizer import FusedAdam
+from colossalai.utils import MultiTimer, is_using_pp
 
 
 def process_batch_data(batch_data):
@@ -28,30 +29,49 @@ def process_batch_data(batch_data):
     return data, label
 
 
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-s', '--synthetic', action="store_true", help="whether use synthetic data")
+    return parser.parse_args()
+
+
 def main():
     # initialize
+    args = parse_args()
     colossalai.launch_from_torch(config='./config.py', seed=1234, backend='nccl')
 
     logger = get_dist_logger()
 
     # build dataloader
-    initialize_tokenizer(gpc.config.VOCAB_FILE_PATH, tokenizer_type='BertWordPieceLowerCase')
-    VOCAB_SIZE = get_padded_vocab_size()
-    trainloader, validloader, testloader = build_train_valid_test_data_iterators(
-        train_iters=gpc.config.TRAIN_ITERS,
-        global_batch_size=gpc.config.GLOBAL_BATCH_SIZE,
-        eval_interval=gpc.config.EVAL_INTERVAL,
-        eval_iters=gpc.config.EVAL_ITERS,
-        data_prefix=[gpc.config.DATA_PATH],
-        data_impl='mmap',
-        splits_string='949,50,1',
-        max_seq_length=gpc.config.SEQ_LENGTH,
-        masked_lm_prob=0.15,
-        short_seq_prob=0.1,
-        seed=1234,
-        skip_warmup=True,
-        binary_head=False,
-    )
+    if not args.synthetic:
+        initialize_tokenizer(gpc.config.VOCAB_FILE_PATH, tokenizer_type='BertWordPieceLowerCase')
+        VOCAB_SIZE = get_padded_vocab_size()
+        trainloader, validloader, testloader = build_train_valid_test_data_iterators(
+            train_iters=gpc.config.TRAIN_ITERS,
+            global_batch_size=gpc.config.GLOBAL_BATCH_SIZE,
+            eval_interval=gpc.config.EVAL_INTERVAL,
+            eval_iters=gpc.config.EVAL_ITERS,
+            data_prefix=[gpc.config.DATA_PATH],
+            data_impl='mmap',
+            splits_string='949,50,1',
+            max_seq_length=gpc.config.SEQ_LENGTH,
+            masked_lm_prob=0.15,
+            short_seq_prob=0.1,
+            seed=1234,
+            skip_warmup=True,
+            binary_head=False,
+        )
+    else:
+        from data.dummy_dataloader import DummyDataloader
+
+        BATCH_SIZE_PER_GPUS = gpc.config.GLOBAL_BATCH_SIZE // gpc.get_world_size(ParallelMode.DATA)
+        VOCAB_SIZE = 30528
+        trainloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS,
+                                      vocab_size=VOCAB_SIZE,
+                                      seq_length=gpc.config.SEQ_LENGTH)
+        validloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS,
+                                      vocab_size=VOCAB_SIZE,
+                                      seq_length=gpc.config.SEQ_LENGTH)
 
     logger.info("Dataloaders are built", ranks=[0])
 
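`DummyDataloader` is imported from `data/dummy_dataloader.py`, which is not shown in this diff. Purely as a hypothetical sketch of what such a synthetic generator could look like (the class name and batch keys below are guesses, not the repo's code), it only has to yield random token ids of the configured shape indefinitely:

```python
# Hypothetical stand-in for the tutorial's DummyDataloader; the real class may differ.
import torch


class SyntheticBertLoader:
    """Endless iterator over random BERT-style batches (illustrative only)."""

    def __init__(self, batch_size, vocab_size, seq_length):
        self.batch_size = batch_size
        self.vocab_size = vocab_size
        self.seq_length = seq_length

    def __iter__(self):
        return self

    def __next__(self):
        shape = (self.batch_size, self.seq_length)
        tokens = torch.randint(0, self.vocab_size, shape)
        return {
            'text': tokens,                                   # key names are guesses
            'types': torch.zeros(shape, dtype=torch.long),
            'labels': tokens.clone(),
            'loss_mask': torch.ones(shape, dtype=torch.long),
            'padding_mask': torch.ones(shape, dtype=torch.long),
        }
```

With the `else:` branch above in place, the tutorial can be exercised end to end by passing the new `-s`/`--synthetic` flag to `train.py`, with no downloaded or preprocessed BERT corpus required.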
@@ -121,11 +141,7 @@ def main():
     logger.info(f"LR Scheduler is built with {warmup_steps} warmup steps and {gpc.config.DECAY_ITERS} decay steps")
 
     # # init
-    engine, *dummy = colossalai.initialize(
-        model,
-        optimizer,
-        criterion,
-    )
+    engine, *dummy = colossalai.initialize(model, optimizer, criterion, verbose=True)
 
     # build timer
     timer = MultiTimer()
@@ -140,6 +156,8 @@ def main():
     train_data_iter = SequenceParallelDataIterator(trainloader)
     valid_data_iter = SequenceParallelDataIterator(validloader)
 
+    logger.info("start training")
+
     for step in range(1, gpc.config.TRAIN_ITERS + 1):
         timer.start('train-iterations')
         engine.train()