diff --git a/examples/language/gpt/README.md b/examples/language/gpt/README.md index d1e307e05..e0e1dc5c1 100644 --- a/examples/language/gpt/README.md +++ b/examples/language/gpt/README.md @@ -1,14 +1,15 @@ ## Overview -This example shows how to use ColossalAI to run huggingface GPT training in distributed manners. +This example shows how to use Colossal-AI to run huggingface GPT training in distributed manners. ## GPT -We use the huggingface transformers GPT2 model. The input data is randonly generated. +We use the GPT2 model from huggingface transformers. The input data is randonly generated. ## Our Modifications -We adapt the OPT training code to ColossalAI by leveraging Gemini and ZeRO DDP. +The `train_gpt_demo.py` provides three distributed plans, i.e. Colossal-AI, PyTorch DDP and ZeRO. +The Colossal-AI leverages Tensor Parallel and Gemini. ## Quick Start -You can launch training by using the following bash script +You can launch training by using the following bash script. ```bash pip install -r requirements.txt diff --git a/examples/language/gpt/run.sh b/examples/language/gpt/run.sh index 1ff2a4eed..6a4b5ce14 100644 --- a/examples/language/gpt/run.sh +++ b/examples/language/gpt/run.sh @@ -1 +1,10 @@ -env OMP_NUM_THREADS=16 torchrun --standalone --nproc_per_node=4 train_gpt_demo.py --tp_degree=2 --placement='cpu' 2>&1 | tee run.log +# distplan in ["colossalai", "zero", "ddp"] +export DISTPAN="colossalai" + +# The following options only valid when DISTPAN="colossalai" +export TPDEGREE=2 +export GPUNUM=4 +export PLACEMENT='cpu' +export USE_SHARD_INIT=False + +env OMP_NUM_THREADS=16 torchrun --standalone --nproc_per_node=${GPUNUM} train_gpt_demo.py --tp_degree=${TPDEGREE} --placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee run.log diff --git a/examples/language/gpt/train_gpt_demo.py b/examples/language/gpt/train_gpt_demo.py index cdf7c41b2..99de40e5f 100644 --- a/examples/language/gpt/train_gpt_demo.py +++ b/examples/language/gpt/train_gpt_demo.py @@ -5,12 +5,13 @@ import psutil import torch import torch.nn as nn from packaging import version +from torch.nn.parallel import DistributedDataParallel as DDP import colossalai from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam from colossalai.nn.parallel import ZeroDDP -from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec +from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec from colossalai.utils import get_current_device from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.zero import ZeroOptimizer @@ -19,17 +20,30 @@ from transformers import GPT2Config, GPT2LMHeadModel def parse_args(): parser = colossalai.get_default_parser() + parser.add_argument( + "--distplan", + type=str, + default='colossalai', + help="The distributed plan [colossalai, ddp, zero].", + ) parser.add_argument( "--tp_degree", type=int, default=1, - help="Tensor Parallelism Degree.", + help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.", ) parser.add_argument( "--placement", type=str, default='cpu', - help="Placement Policy for Gemini.", + help="Placement Policy for Gemini. Valid when using colossalai as dist plan.", + ) + parser.add_argument( + "--shardinit", + type=bool, + default=False, + help= + "Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.", ) args = parser.parse_args() return args @@ -38,8 +52,6 @@ def parse_args(): ## Parameter Sharding Strategies for Tensor Parallelism def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup): spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - if param.process_group.tp_world_size() == 1: - param.set_process_group(pg) param.set_tensor_spec(*spec) @@ -136,21 +148,30 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup): """ for mn, module in model.named_modules(): for pn, param in module.named_parameters(recurse=False): - # set process group for all parameters - param.set_process_group(pg) - + # NOTE() a param maybe shared by tow modules + if hasattr(param, 'visited'): + continue + param.set_dist_spec(ReplicaSpec()) if 'mlp.c_fc' in mn: if 'weight' in pn or 'bias' in pn: split_param_col_tp1d(param, pg) # colmn slice # keep the shape of the output from c_fc param.compute_spec.set_output_replicate(False) + else: + param.set_dist_spec(ReplicaSpec()) elif 'mlp.c_proj' in mn: if 'weight' in pn: split_param_row_tp1d(param, pg) # row slice + else: + param.set_dist_spec(ReplicaSpec()) elif 'wte' in mn or 'wpe' in mn: split_param_col_tp1d(param, pg) # colmn slice elif 'c_attn' in mn or 'c_proj' in mn: split_param_col_tp1d(param, pg) # colmn slice + else: + param.set_dist_spec(ReplicaSpec()) + + param.visited = True # Gemini + ZeRO DDP @@ -188,32 +209,49 @@ def main(): disable_existing_loggers() colossalai.launch_from_torch(config={}) - pg = ProcessGroup(tp_degree=args.tp_degree) - logger = get_dist_logger() - logger.info(get_mem_info(), ranks=[0]) - - # build GPT model - with ColoInitContext(device=get_current_device()): - model = gpt2_medium(checkpoint=True) - - numel = sum([p.numel() for p in model.parameters()]) - logger.info(f'Model numel: {numel}', ranks=[0]) - get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN) - - # Tensor Parallelism (TP) - tensor_parallelize(model, pg) - # Gemini + ZeRO DP, Note it must be used after TP - model = gemini_zero_dpp(model, pg, args.placement) - logger.info(get_mem_info(prefix='After init model, '), ranks=[0]) + logger.info(f"using dist plan {args.distplan}", ranks=[0]) # build criterion criterion = GPTLMLoss() - # build optimizer - optimizer = HybridAdam(model.parameters(), lr=1e-3) - optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5) - logger.info(get_mem_info(prefix='After init optim, '), ranks=[0]) + torch.manual_seed(123) + if args.distplan == "colossalai": + # all param must use the same process group. + default_pg = ProcessGroup(tp_degree=args.tp_degree) + default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None + + # build GPT model + with ColoInitContext(device='cuda', default_dist_spec=default_dist_spec, default_pg=default_pg): + model = gpt2_medium(checkpoint=True) + + pg = default_pg + # Tensor Parallelism (TP) + tensor_parallelize(model, pg) + # Gemini + ZeRO DP, Note it must be used after TP + model = gemini_zero_dpp(model, pg, args.placement) + + # build optimizer + optimizer = HybridAdam(model.parameters(), lr=1e-3) + optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5) + logger.info(get_mem_info(prefix='After init optim, '), ranks=[0]) + + elif args.distplan == "ddp": + model = gpt2_medium(checkpoint=True).cuda() + ddp_model = DDP(model) + optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.01) + + elif args.distplan == "zero": + from torch.distributed.optim import ZeroRedundancyOptimizer + model = gpt2_medium(checkpoint=True).cuda() + ddp_model = DDP(model) + optimizer = ZeroRedundancyOptimizer(ddp_model.parameters(), optimizer_class=torch.optim.Adam, lr=0.01) + else: + raise TypeError(f"{args.distplan} is error") + + numel = sum([p.numel() for p in model.parameters()]) + logger.info(get_mem_info(prefix='After init model, '), ranks=[0]) + get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN) torch.cuda.synchronize() model.train() @@ -225,7 +263,11 @@ def main(): outputs = model(input_ids, attn_mask) loss = criterion(outputs, input_ids) logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Forward '), ranks=[0]) - optimizer.backward(loss) + if args.distplan == "colossalai": + optimizer.backward(loss) + elif args.distplan in ["ddp", "zero"]: + loss.backward() + logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Backward '), ranks=[0]) optimizer.step() logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Optimizer step '), ranks=[0])