[example] enhance GPT demo (#1959)

* [example] enhance GPT demo

* Update README.md

Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Jiarui Fang committed 2 years ago via GitHub · parent acba142929 · commit 60abd86d6a

@@ -1,14 +1,15 @@
 ## Overview
-This example shows how to use ColossalAI to run huggingface GPT training in distributed manners.
+This example shows how to use Colossal-AI to run huggingface GPT training in distributed manners.
 
 ## GPT
-We use the huggingface transformers GPT2 model. The input data is randonly generated.
+We use the GPT2 model from huggingface transformers. The input data is randonly generated.
 
 ## Our Modifications
-We adapt the OPT training code to ColossalAI by leveraging Gemini and ZeRO DDP.
+The `train_gpt_demo.py` provides three distributed plans, i.e. Colossal-AI, PyTorch DDP and ZeRO.
+The Colossal-AI leverages Tensor Parallel and Gemini.
 
 ## Quick Start
-You can launch training by using the following bash script
+You can launch training by using the following bash script.
 ```bash
 pip install -r requirements.txt

@@ -1 +1,10 @@
-env OMP_NUM_THREADS=16 torchrun --standalone --nproc_per_node=4 train_gpt_demo.py --tp_degree=2 --placement='cpu' 2>&1 | tee run.log
+# distplan in ["colossalai", "zero", "ddp"]
+export DISTPAN="colossalai"
+
+# The following options only valid when DISTPAN="colossalai"
+export TPDEGREE=2
+export GPUNUM=4
+export PLACEMENT='cpu'
+export USE_SHARD_INIT=False
+
+env OMP_NUM_THREADS=16 torchrun --standalone --nproc_per_node=${GPUNUM} train_gpt_demo.py --tp_degree=${TPDEGREE} --placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee run.log

@@ -5,12 +5,13 @@ import psutil
 import torch
 import torch.nn as nn
 from packaging import version
+from torch.nn.parallel import DistributedDataParallel as DDP
 
 import colossalai
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.nn.parallel import ZeroDDP
-from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
+from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
 from colossalai.zero import ZeroOptimizer
@@ -19,17 +20,30 @@ from transformers import GPT2Config, GPT2LMHeadModel
 
 def parse_args():
     parser = colossalai.get_default_parser()
+    parser.add_argument(
+        "--distplan",
+        type=str,
+        default='colossalai',
+        help="The distributed plan [colossalai, ddp, zero].",
+    )
     parser.add_argument(
         "--tp_degree",
         type=int,
         default=1,
-        help="Tensor Parallelism Degree.",
+        help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.",
     )
     parser.add_argument(
         "--placement",
         type=str,
         default='cpu',
-        help="Placement Policy for Gemini.",
+        help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
+    )
+    parser.add_argument(
+        "--shardinit",
+        type=bool,
+        default=False,
+        help=
+        "Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
     )
     args = parser.parse_args()
     return args
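
A note on the `--shardinit` flag introduced above: `argparse`'s `type=bool` simply calls `bool()` on the raw string, so any non-empty value (including the literal `False` exported by the launch script) parses as `True`. A minimal standalone sketch of that stock-`argparse` behaviour, unrelated to Colossal-AI itself:

```python
# Standalone sketch of argparse's type=bool behaviour (plain argparse, not repo code).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--shardinit", type=bool, default=False)

print(parser.parse_args([]).shardinit)                        # False: default is used
print(parser.parse_args(["--shardinit", "False"]).shardinit)  # True: bool("False") is truthy
print(parser.parse_args(["--shardinit", ""]).shardinit)       # False: empty string is falsy
```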
@@ -38,8 +52,6 @@ def parse_args():
 
 ## Parameter Sharding Strategies for Tensor Parallelism
 def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
     spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
-    if param.process_group.tp_world_size() == 1:
-        param.set_process_group(pg)
     param.set_tensor_spec(*spec)
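
For context, the `split_param_row_tp1d` / `split_param_col_tp1d` helpers used by `tensor_parallelize` below are presumably thin wrappers over `split_param_single_dim_tp1d`; a sketch of what they would look like (assumed, not part of this diff):

```python
# Sketch (assumed, not shown in this diff): 1D row/column tensor-parallel splits
# expressed via split_param_single_dim_tp1d from the hunk above.
def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
    split_param_single_dim_tp1d(0, param, pg)     # shard along dim 0 (rows)


def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
    split_param_single_dim_tp1d(-1, param, pg)    # shard along the last dim (columns)
```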
@@ -136,21 +148,30 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
     """
     for mn, module in model.named_modules():
         for pn, param in module.named_parameters(recurse=False):
-            # set process group for all parameters
-            param.set_process_group(pg)
+            # NOTE() a param maybe shared by tow modules
+            if hasattr(param, 'visited'):
+                continue
+            param.set_dist_spec(ReplicaSpec())
             if 'mlp.c_fc' in mn:
                 if 'weight' in pn or 'bias' in pn:
                     split_param_col_tp1d(param, pg)    # colmn slice
                     # keep the shape of the output from c_fc
                     param.compute_spec.set_output_replicate(False)
+                else:
+                    param.set_dist_spec(ReplicaSpec())
             elif 'mlp.c_proj' in mn:
                 if 'weight' in pn:
                     split_param_row_tp1d(param, pg)    # row slice
+                else:
+                    param.set_dist_spec(ReplicaSpec())
             elif 'wte' in mn or 'wpe' in mn:
                 split_param_col_tp1d(param, pg)    # colmn slice
             elif 'c_attn' in mn or 'c_proj' in mn:
                 split_param_col_tp1d(param, pg)    # colmn slice
+            else:
+                param.set_dist_spec(ReplicaSpec())
+            param.visited = True
 
 
 # Gemini + ZeRO DDP
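
The new `visited` guard exists because a parameter can be shared by two modules: in GPT-2 the token embedding `wte.weight` is tied to the LM head, so the same tensor is yielded once per owning module by `named_parameters(recurse=False)`. A toy, plain-PyTorch illustration of the pattern (hypothetical two-module model, not the demo code):

```python
# Toy illustration: a tied weight appears under both owning modules, and the
# `visited` attribute ensures it is only processed once.
import torch.nn as nn

wte = nn.Embedding(10, 4)
lm_head = nn.Linear(4, 10, bias=False)
lm_head.weight = wte.weight                      # weight tying, as in GPT-2

model = nn.ModuleDict({"wte": wte, "lm_head": lm_head})
for mn, module in model.named_modules():
    for pn, param in module.named_parameters(recurse=False):
        if getattr(param, "visited", False):
            print(f"skip {mn}.{pn}: already handled via the tied copy")
            continue
        print(f"process {mn}.{pn}")
        param.visited = True
```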
@@ -188,32 +209,49 @@ def main():
     disable_existing_loggers()
     colossalai.launch_from_torch(config={})
 
-    pg = ProcessGroup(tp_degree=args.tp_degree)
-
     logger = get_dist_logger()
-    logger.info(get_mem_info(), ranks=[0])
+    logger.info(f"using dist plan {args.distplan}", ranks=[0])
 
-    # build GPT model
-    with ColoInitContext(device=get_current_device()):
-        model = gpt2_medium(checkpoint=True)
-    numel = sum([p.numel() for p in model.parameters()])
-    logger.info(f'Model numel: {numel}', ranks=[0])
-    get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)
-
-    # Tensor Parallelism (TP)
-    tensor_parallelize(model, pg)
-
-    # Gemini + ZeRO DP, Note it must be used after TP
-    model = gemini_zero_dpp(model, pg, args.placement)
-    logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
-
     # build criterion
     criterion = GPTLMLoss()
 
-    # build optimizer
-    optimizer = HybridAdam(model.parameters(), lr=1e-3)
-    optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
-    logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
+    torch.manual_seed(123)
+    if args.distplan == "colossalai":
+        # all param must use the same process group.
+        default_pg = ProcessGroup(tp_degree=args.tp_degree)
+        default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None
+
+        # build GPT model
+        with ColoInitContext(device='cuda', default_dist_spec=default_dist_spec, default_pg=default_pg):
+            model = gpt2_medium(checkpoint=True)
+        pg = default_pg
+
+        # Tensor Parallelism (TP)
+        tensor_parallelize(model, pg)
+
+        # Gemini + ZeRO DP, Note it must be used after TP
+        model = gemini_zero_dpp(model, pg, args.placement)
+
+        # build optimizer
+        optimizer = HybridAdam(model.parameters(), lr=1e-3)
+        optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
+        logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
+    elif args.distplan == "ddp":
+        model = gpt2_medium(checkpoint=True).cuda()
+        ddp_model = DDP(model)
+        optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.01)
+    elif args.distplan == "zero":
+        from torch.distributed.optim import ZeroRedundancyOptimizer
+        model = gpt2_medium(checkpoint=True).cuda()
+        ddp_model = DDP(model)
+        optimizer = ZeroRedundancyOptimizer(ddp_model.parameters(), optimizer_class=torch.optim.Adam, lr=0.01)
+    else:
+        raise TypeError(f"{args.distplan} is error")
+
+    numel = sum([p.numel() for p in model.parameters()])
+    logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
+    get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)
 
     torch.cuda.synchronize()
     model.train()
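
The `ddp` and `zero` branches above rely only on stock PyTorch. A self-contained single-process sketch of that setup, with a toy `nn.Linear` standing in for `gpt2_medium` and the `gloo` backend so it runs on CPU (hypothetical stand-ins, not the demo script):

```python
# Minimal single-process sketch of the "ddp" / "zero" plans using stock PyTorch.
import os
import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = torch.nn.Linear(8, 8)    # stand-in for gpt2_medium(checkpoint=True)
ddp_model = DDP(model)           # wraps the model for gradient all-reduce

distplan = "zero"
if distplan == "ddp":
    optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.01)
else:  # "zero": shard optimizer states across ranks
    optimizer = ZeroRedundancyOptimizer(ddp_model.parameters(), optimizer_class=torch.optim.Adam, lr=0.01)

loss = ddp_model(torch.randn(4, 8)).sum()
loss.backward()
optimizer.step()
dist.destroy_process_group()
```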
@@ -225,7 +263,11 @@ def main():
         outputs = model(input_ids, attn_mask)
         loss = criterion(outputs, input_ids)
         logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Forward '), ranks=[0])
-        optimizer.backward(loss)
+        if args.distplan == "colossalai":
+            optimizer.backward(loss)
+        elif args.distplan in ["ddp", "zero"]:
+            loss.backward()
         logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Backward '), ranks=[0])
         optimizer.step()
         logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Optimizer step '), ranks=[0])
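
The branch above is needed because under the `colossalai` plan the `ZeroOptimizer` wrapper owns loss scaling (note the `initial_scale=2**5` in `main()`), so the backward pass is driven through `optimizer.backward(loss)`, while the plain DDP/ZeRO plans use ordinary autograd. A hypothetical helper capturing that dispatch:

```python
# Hypothetical helper (not in the demo) summarizing the backward dispatch above.
def backward_step(loss, optimizer, distplan: str):
    if distplan == "colossalai":
        optimizer.backward(loss)    # ZeroOptimizer scales the loss, then runs backward
    elif distplan in ("ddp", "zero"):
        loss.backward()             # plain autograd; DDP all-reduces grads in hooks
    else:
        raise ValueError(f"unknown distplan: {distplan}")
```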
