mirror of https://github.com/hpcaitech/ColossalAI
* benchmark gpt2
* [doc] fix typo in Colossal-LLaMA-2/README.md (#5247)
* [workflow] fixed build CI (#5240)
* [ci] fixed booster test (#5251)
* [ci] fixed ddp test (#5254)
* fix typo in applications/ColossalEval/README.md (#5250)
* [ci] fix shardformer tests: add enable_metadata_cache option, re-enable t5 tests (#5255)
* [doc] fix doc typo: fix annotation display, fix llama2 doc (#5256)
* [hotfix] add pp sanity check and fix misleading mbs arg (#5268)
* [workflow] fixed incomplete bash command (#5272)
* [workflow] fixed oom tests (#5275)
* [ci] fix test_hybrid_parallel_plugin_checkpoint_io.py (#5276)
* [shardformer] HybridParallelPlugin supports gradient accumulation (#5246)
* [hotfix] fix ShardFormer test execution path when using sequence parallelism (#5230)
* fix auto loading gpt2 tokenizer (#5279)
* [doc] add llama2-13B display (#5285)
* fix llama pretrain (#5287)
* Update shardformer.py

Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: Wenhao Chen <cwher@outlook.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
Co-authored-by: Michelle <97082656+MichelleMa8@users.noreply.github.com>
Co-authored-by: Desperado-Jia <502205863@qq.com>
flybird11111 committed 9 months ago (via GitHub)
22 changed files with 421 additions and 48 deletions
@@ -0,0 +1,228 @@
import argparse
import resource
from contextlib import nullcontext

import torch
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision
from torch.optim import Adam
from tqdm import tqdm
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel

import colossalai

# import colossalai.utils.device as device_utils
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchFSDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.utils import get_current_device
from examples.language.data_utils import RandomDataset
from examples.language.model_utils import format_numel_str, get_model_numel
from examples.language.performance_evaluator import PerformanceEvaluator

# ==============================
# Constants
# ==============================
MODEL_CONFIGS = {
    "118M": GPT2Config(activation_function="gelu"),
    "338M": GPT2Config(n_embd=1024, n_head=16, n_layer=24, activation_function="gelu"),
    "738M": GPT2Config(n_embd=1280, n_head=20, n_layer=36, activation_function="gelu"),
    "6.21B": GPT2Config(n_embd=4096, n_head=32, n_layer=32, n_positions=4096, activation_function="gelu"),
}
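# The keys above name approximate parameter counts; a config is selected at
# runtime via -c/--config (default "6.21B").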


def main():
    # ==============================
    # Parse Arguments
    # ==============================
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", type=str, default="6.21B", help="Model configuration")
    parser.add_argument(
        "-p",
        "--plugin",
        choices=["gemini", "gemini_auto", "fsdp", "fsdp_cpu", "3d", "3d_cpu"],
        default="gemini",
        help="Choose which plugin to use",
    )
    parser.add_argument("-b", "--batch_size", type=int, default=2, help="Batch size")
    parser.add_argument("-s", "--num_steps", type=int, default=200, help="Number of steps to run")
    parser.add_argument("-i", "--ignore_steps", type=int, default=3, help="Number of steps to ignore")
    parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing")
    parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length")
    parser.add_argument(
        "-w", "--warmup_ratio", type=float, default=0.8, help="Warm-up ratio of non-model data. Only for gemini_auto"
    )
    parser.add_argument("-m", "--memory_limit", type=int, help="Gemini memory limit in MB")
    parser.add_argument("--shard_param_frac", type=float, default=1.0, help="Shard param fraction. Only for gemini")
    parser.add_argument("--offload_optim_frac", type=float, default=0.0, help="Offload optim fraction. Only for gemini")
    parser.add_argument("--offload_param_frac", type=float, default=0.0, help="Offload param fraction. Only for gemini")
    parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size")
    parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini")
    parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size")
    parser.add_argument("--mbs", type=int, default=1, help="Number of microbatches. Only for 3d plugins")
    parser.add_argument("--zero", type=int, default=0, help="ZeRO stage. Only for 3d plugins")
    parser.add_argument("--pp_style", type=str, default="1f1b", help="Pipeline schedule style. Only for 3d plugins")
    parser.add_argument("--num_model_chunks", type=int, default=2, help="Number of model chunks for interleaved pipeline")
    parser.add_argument("--cpu_offload", action="store_true", help="Use CPU offload. Only for the 3d plugin")
    args = parser.parse_args()

    colossalai.launch_from_torch({})
    coordinator = DistCoordinator()

    def empty_init():
        pass
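
    # Note: `empty_init()` returns None, so `param_init_fn=empty_init()` in the
    # FSDP branches below passes None, i.e. FSDP keeps its default parameter
    # initialization; the `use_empty_init` switch is effectively a placeholder.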

    # ==============================
    # Initialize Booster
    # ==============================
    use_empty_init = True
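    # Plugin summary (a rough guide; see the ColossalAI docs for details):
    #   gemini      - ZeRO-style sharding with chunk-based heterogeneous memory management
    #   gemini_auto - Gemini with automatic GPU/CPU parameter placement
    #   fsdp        - PyTorch FullyShardedDataParallel with fp16 mixed precision
    #   fsdp_cpu    - FSDP plus CPU parameter offload
    #   3d          - HybridParallelPlugin combining tensor/pipeline parallelism and ZeRO
    #   3d_cpu      - HybridParallelPlugin with CPU offload enabled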
    if args.plugin == "gemini":
        plugin = GeminiPlugin(
            precision="bf16",
            shard_param_frac=args.shard_param_frac,
            offload_optim_frac=args.offload_optim_frac,
            offload_param_frac=args.offload_param_frac,
            tp_size=args.tp,
            extra_dp_size=args.extra_dp,
        )
    elif args.plugin == "gemini_auto":
        plugin = GeminiPlugin(
            placement_policy="auto",
            precision="bf16",
            warmup_non_model_data_ratio=args.warmup_ratio,
            tp_size=args.tp,
            extra_dp_size=args.extra_dp,
        )
    elif args.plugin == "fsdp":
        if use_empty_init:
            plugin = TorchFSDPPlugin(
                mixed_precision=MixedPrecision(
                    param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
                ),
                param_init_fn=empty_init(),
            )
        else:
            plugin = TorchFSDPPlugin(
                mixed_precision=MixedPrecision(
                    param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
                )
            )
    elif args.plugin == "fsdp_cpu":
        if use_empty_init:
            plugin = TorchFSDPPlugin(
                mixed_precision=MixedPrecision(
                    param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
                ),
                cpu_offload=CPUOffload(offload_params=True),
                param_init_fn=empty_init(),
            )
        else:
            plugin = TorchFSDPPlugin(
                mixed_precision=MixedPrecision(
                    param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
                ),
                cpu_offload=CPUOffload(offload_params=True),
            )
    elif args.plugin == "3d":
        plugin = HybridParallelPlugin(
            tp_size=args.tp,
            pp_size=args.pp,
            pp_style=args.pp_style,
            zero_stage=args.zero,
            num_model_chunks=args.num_model_chunks,
            enable_all_optimization=True,
            num_microbatches=args.mbs,
            cpu_offload=args.cpu_offload,
            precision="bf16",
        )
    elif args.plugin == "3d_cpu":
        plugin = HybridParallelPlugin(
            tp_size=args.tp,
            pp_size=args.pp,
            zero_stage=args.zero,
            cpu_offload=True,
            enable_fused_normalization=torch.cuda.is_available(),
            num_microbatches=args.mbs,
            initial_scale=2**8,
            precision="bf16",
        )
    else:
        raise ValueError(f"Unknown plugin {args.plugin}")

    booster = Booster(plugin=plugin)

    # ==============================
    # Initialize Dataset and Dataloader
    # ==============================
    dp_size = plugin.dp_size if isinstance(plugin, HybridParallelPlugin) else coordinator.world_size

    config = MODEL_CONFIGS[args.config]
    dataset = RandomDataset(
        num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
    )
    dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)
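    # num_samples = batch_size * num_steps * dp_size, so after the plugin shards
    # the dataset across data-parallel ranks, each rank iterates exactly
    # args.num_steps batches.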

    # ==============================
    # Initialize Model and Optimizer
    # ==============================
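    # LazyInitContext defers parameter materialization until booster.boost(),
    # so large models can be constructed without first allocating full weights
    # on a single device; it is only used for the Gemini and hybrid plugins here.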
    init_ctx = (
        LazyInitContext(default_device=get_current_device())
        if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
        else nullcontext()
    )

    with init_ctx:
        model = GPT2LMHeadModel(config)

    if args.grad_checkpoint:
        model.gradient_checkpointing_enable()

    model_numel = get_model_numel(model)
    coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
    performance_evaluator = PerformanceEvaluator(
        model_numel,
        model.config.n_layer,
        model.config.n_embd,
        model.config.vocab_size,
        args.grad_checkpoint,
        args.ignore_steps,
        dp_world_size=dp_size,
    )
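
    # A hypothetical back-of-envelope helper (not part of the original script,
    # and not necessarily what PerformanceEvaluator computes internally): the
    # common 6*N*T training-FLOPs estimate for decoder-only transformers.
    def approx_tflops_per_step(numel: int, batch_size: int, seq_len: int, grad_ckpt: bool) -> float:
        # ~6 FLOPs per parameter per trained token (forward + backward);
        # gradient checkpointing replays the forward pass, adding ~2 more.
        flops_per_token = (8 if grad_ckpt else 6) * numel
        return flops_per_token * batch_size * seq_len / 1e12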

    optimizer = Adam(model.parameters())
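    # The default dtype is switched to bf16 so floating-point tensors created
    # inside booster.boost() (e.g. re-materialized or sharded parameters) come
    # out in bf16, then it is restored to fp32 afterwards.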
    torch.set_default_dtype(torch.bfloat16)
    model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
    torch.set_default_dtype(torch.float)
    coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
    coordinator.print_on_master(
        f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB"
    )
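
    # Two training paths: with pipeline parallelism, booster.execute_pipeline()
    # drives forward and backward over microbatches internally (the benchmark
    # does not need the loss value, hence return_loss=False and the placeholder
    # tensor passed to on_step_end); otherwise a standard forward/backward loop
    # is used.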
    if isinstance(plugin, HybridParallelPlugin) and args.pp > 1:
        data_iter = iter(dataloader)
        for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()):
            performance_evaluator.on_step_start(step)
            booster.execute_pipeline(
                data_iter, model, criterion=lambda outputs, inputs: outputs[0], optimizer=optimizer, return_loss=False
            )
            optimizer.step()
            optimizer.zero_grad()
            performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length))
    else:
        for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not coordinator.is_master())):
            performance_evaluator.on_step_start(step)
            outputs = model(**batch)
            loss = outputs[0]
            booster.backward(loss, optimizer)
            optimizer.step()
            optimizer.zero_grad()
            performance_evaluator.on_step_end(**batch)
        coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")

    performance_evaluator.on_fit_end()
    coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")


if __name__ == "__main__":
    main()
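
# Example launch (a sketch, not from the original commit; assumes this file is
# saved as benchmark.py and run from a checkout where the examples.language
# utilities are importable):
#
#   torchrun --standalone --nproc_per_node 8 benchmark.py -c 338M -p gemini -b 8 -g
#   torchrun --standalone --nproc_per_node 8 benchmark.py -p 3d --tp 2 --pp 2 --mbs 4
#
# colossalai.launch_from_torch reads the rank/world-size environment variables
# that torchrun sets, so no extra launcher configuration is needed.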