Browse Source

[example] enhance GPT demo (#1959)

* [example] enhence GPT demo

* Update README.md

Co-authored-by: binmakeswell <binmakeswell@gmail.com>
pull/1960/head^2
Jiarui Fang 2 years ago committed by GitHub
parent
commit
60abd86d6a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 9
      examples/language/gpt/README.md
  2. 11
      examples/language/gpt/run.sh
  3. 102
      examples/language/gpt/train_gpt_demo.py

9
examples/language/gpt/README.md

@ -1,14 +1,15 @@
## Overview
This example shows how to use ColossalAI to run huggingface GPT training in distributed manners.
This example shows how to use Colossal-AI to run huggingface GPT training in distributed manners.
## GPT
We use the huggingface transformers GPT2 model. The input data is randonly generated.
We use the GPT2 model from huggingface transformers. The input data is randonly generated.
## Our Modifications
We adapt the OPT training code to ColossalAI by leveraging Gemini and ZeRO DDP.
The `train_gpt_demo.py` provides three distributed plans, i.e. Colossal-AI, PyTorch DDP and ZeRO.
The Colossal-AI leverages Tensor Parallel and Gemini.
## Quick Start
You can launch training by using the following bash script
You can launch training by using the following bash script.
```bash
pip install -r requirements.txt

11
examples/language/gpt/run.sh

@ -1 +1,10 @@
env OMP_NUM_THREADS=16 torchrun --standalone --nproc_per_node=4 train_gpt_demo.py --tp_degree=2 --placement='cpu' 2>&1 | tee run.log
# distplan in ["colossalai", "zero", "ddp"]
export DISTPAN="colossalai"
# The following options only valid when DISTPAN="colossalai"
export TPDEGREE=2
export GPUNUM=4
export PLACEMENT='cpu'
export USE_SHARD_INIT=False
env OMP_NUM_THREADS=16 torchrun --standalone --nproc_per_node=${GPUNUM} train_gpt_demo.py --tp_degree=${TPDEGREE} --placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee run.log

102
examples/language/gpt/train_gpt_demo.py

@ -5,12 +5,13 @@ import psutil
import torch
import torch.nn as nn
from packaging import version
from torch.nn.parallel import DistributedDataParallel as DDP
import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer
@ -19,17 +20,30 @@ from transformers import GPT2Config, GPT2LMHeadModel
def parse_args():
parser = colossalai.get_default_parser()
parser.add_argument(
"--distplan",
type=str,
default='colossalai',
help="The distributed plan [colossalai, ddp, zero].",
)
parser.add_argument(
"--tp_degree",
type=int,
default=1,
help="Tensor Parallelism Degree.",
help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.",
)
parser.add_argument(
"--placement",
type=str,
default='cpu',
help="Placement Policy for Gemini.",
help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
)
parser.add_argument(
"--shardinit",
type=bool,
default=False,
help=
"Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
)
args = parser.parse_args()
return args
@ -38,8 +52,6 @@ def parse_args():
## Parameter Sharding Strategies for Tensor Parallelism
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
if param.process_group.tp_world_size() == 1:
param.set_process_group(pg)
param.set_tensor_spec(*spec)
@ -136,21 +148,30 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
"""
for mn, module in model.named_modules():
for pn, param in module.named_parameters(recurse=False):
# set process group for all parameters
param.set_process_group(pg)
# NOTE() a param maybe shared by tow modules
if hasattr(param, 'visited'):
continue
param.set_dist_spec(ReplicaSpec())
if 'mlp.c_fc' in mn:
if 'weight' in pn or 'bias' in pn:
split_param_col_tp1d(param, pg) # colmn slice
# keep the shape of the output from c_fc
param.compute_spec.set_output_replicate(False)
else:
param.set_dist_spec(ReplicaSpec())
elif 'mlp.c_proj' in mn:
if 'weight' in pn:
split_param_row_tp1d(param, pg) # row slice
else:
param.set_dist_spec(ReplicaSpec())
elif 'wte' in mn or 'wpe' in mn:
split_param_col_tp1d(param, pg) # colmn slice
elif 'c_attn' in mn or 'c_proj' in mn:
split_param_col_tp1d(param, pg) # colmn slice
else:
param.set_dist_spec(ReplicaSpec())
param.visited = True
# Gemini + ZeRO DDP
@ -188,32 +209,49 @@ def main():
disable_existing_loggers()
colossalai.launch_from_torch(config={})
pg = ProcessGroup(tp_degree=args.tp_degree)
logger = get_dist_logger()
logger.info(get_mem_info(), ranks=[0])
# build GPT model
with ColoInitContext(device=get_current_device()):
model = gpt2_medium(checkpoint=True)
numel = sum([p.numel() for p in model.parameters()])
logger.info(f'Model numel: {numel}', ranks=[0])
get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)
# Tensor Parallelism (TP)
tensor_parallelize(model, pg)
# Gemini + ZeRO DP, Note it must be used after TP
model = gemini_zero_dpp(model, pg, args.placement)
logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
logger.info(f"using dist plan {args.distplan}", ranks=[0])
# build criterion
criterion = GPTLMLoss()
# build optimizer
optimizer = HybridAdam(model.parameters(), lr=1e-3)
optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
torch.manual_seed(123)
if args.distplan == "colossalai":
# all param must use the same process group.
default_pg = ProcessGroup(tp_degree=args.tp_degree)
default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None
# build GPT model
with ColoInitContext(device='cuda', default_dist_spec=default_dist_spec, default_pg=default_pg):
model = gpt2_medium(checkpoint=True)
pg = default_pg
# Tensor Parallelism (TP)
tensor_parallelize(model, pg)
# Gemini + ZeRO DP, Note it must be used after TP
model = gemini_zero_dpp(model, pg, args.placement)
# build optimizer
optimizer = HybridAdam(model.parameters(), lr=1e-3)
optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
elif args.distplan == "ddp":
model = gpt2_medium(checkpoint=True).cuda()
ddp_model = DDP(model)
optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.01)
elif args.distplan == "zero":
from torch.distributed.optim import ZeroRedundancyOptimizer
model = gpt2_medium(checkpoint=True).cuda()
ddp_model = DDP(model)
optimizer = ZeroRedundancyOptimizer(ddp_model.parameters(), optimizer_class=torch.optim.Adam, lr=0.01)
else:
raise TypeError(f"{args.distplan} is error")
numel = sum([p.numel() for p in model.parameters()])
logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)
torch.cuda.synchronize()
model.train()
@ -225,7 +263,11 @@ def main():
outputs = model(input_ids, attn_mask)
loss = criterion(outputs, input_ids)
logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Forward '), ranks=[0])
optimizer.backward(loss)
if args.distplan == "colossalai":
optimizer.backward(loss)
elif args.distplan in ["ddp", "zero"]:
loss.backward()
logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Backward '), ranks=[0])
optimizer.step()
logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Optimizer step '), ranks=[0])

Loading…
Cancel
Save