diff --git a/examples/language/gpt/gemini/test_ci.sh b/examples/language/gpt/gemini/test_ci.sh
index 6079d5ed6..0ddfd3a62 100644
--- a/examples/language/gpt/gemini/test_ci.sh
+++ b/examples/language/gpt/gemini/test_ci.sh
@@ -3,7 +3,7 @@
 $(cd `dirname $0`;pwd)
 export TRAIN_STEP=4
 for MODEL_TYPE in "gpt2_medium"; do
-  for DISTPLAN in "colossalai"; do
+  for DISTPLAN in "CAI_Gemini"; do
     for BATCH_SIZE in 2; do
       for GPUNUM in 1 4; do
         for TPDEGREE in 1 2; do
diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py
index b2a7fa36d..92751c7e2 100644
--- a/examples/language/gpt/gemini/train_gpt_demo.py
+++ b/examples/language/gpt/gemini/train_gpt_demo.py
@@ -11,11 +11,13 @@
 from packaging import version
 from torch.nn.parallel import DistributedDataParallel as DDP
 
 import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
 from colossalai.utils import get_current_device
-from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper
+from colossalai.zero import ColoInitContext
 
 CAI_VERSION = colossalai.__version__
@@ -236,23 +238,6 @@ def main():
             tensor_parallelize(model, tp_pg)
 
         # asign running configurations
-        gemini_config = None
-        if args.distplan.startswith("CAI_ZeRO"):
-            optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True)
-        elif args.distplan == "CAI_Gemini":
-            gemini_config = dict(strict_ddp_mode=args.tp_degree == 1,
-                                 device=get_current_device(),
-                                 placement_policy=args.placement,
-                                 pin_memory=True,
-                                 hidden_dim=model.config.n_embd,
-                                 search_range_mb=128)
-            optim_config = dict(gpu_margin_mem_ratio=0.)
-        else:
-            raise RuntimeError
-
-        # build a highly optimized gpu/cpu optimizer
-        optimizer = HybridAdam(model.parameters(), lr=1e-3)
-
         if args.distplan == "CAI_ZeRO1":
             zero_stage = 1
         elif args.distplan == "CAI_ZeRO2":
@@ -262,22 +247,42 @@ def main():
         else:
             raise RuntimeError
 
-        # wrap your model and optimizer
-        model = zero_model_wrapper(model, zero_stage, gemini_config)
-        optimizer = zero_optim_wrapper(model, optimizer, optim_config=optim_config)
+        plugin = None
+        if args.distplan.startswith("CAI_ZeRO"):
+            plugin = LowLevelZeroPlugin(stage=zero_stage,
+                                        reduce_bucket_size_in_m=12 * 1024 * 1024,
+                                        overlap_communication=True,
+                                        verbose=True)
+        elif args.distplan == "CAI_Gemini":
+            plugin = GeminiPlugin(device=get_current_device(),
+                                  placement_policy=args.placement,
+                                  pin_memory=True,
+                                  strict_ddp_mode=args.tp_degree == 1,
+                                  search_range_mb=128,
+                                  hidden_dim=model.config.n_embd,
+                                  gpu_margin_mem_ratio=0.)
+        else:
+            raise RuntimeError
+
+        # build a highly optimized gpu/cpu optimizer
+        optimizer = HybridAdam(model.parameters(), lr=1e-3)
 
         logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
     elif args.distplan.startswith("Pytorch"):
         assert args.tp_degree == 1, "The degree of TP should be 1 for DDP examples."
         model = model_builder(args.model_type)(checkpoint=True).cuda()
-        model = DDP(model)
+        plugin = TorchDDPPlugin()
         if args.distplan.endswith("DDP"):
             optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
         elif args.distplan.endswith("ZeRO"):
             from torch.distributed.optim import ZeroRedundancyOptimizer
             optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3)
+
     else:
         raise RuntimeError
 
+    # wrap your model and optimizer
+    booster = Booster(plugin=plugin)
+    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
+
     # model is shared after TP
     numel = get_model_size(model)
@@ -305,13 +310,7 @@ def main():
         fwd_end = time()
         fwd_time = fwd_end - start
         logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Forward '), ranks=[0])
-
-        if args.distplan.startswith("CAI"):
-            optimizer.backward(loss)
-        elif args.distplan.startswith("Pytorch"):
-            loss.backward()
-        else:
-            raise RuntimeError
+        booster.backward(loss, optimizer)
 
         torch.cuda.synchronize()
         bwd_end = time()
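For reference, the Booster-based flow that this patch migrates the example to looks roughly like the sketch below. This is a minimal, hypothetical example rather than the real `train_gpt_demo.py`: the `nn.Linear` model, MSE loss, and random batch are stand-ins, and the exact `GeminiPlugin` / `launch_from_torch` arguments may differ between ColossalAI releases (the keyword names simply mirror the ones used in the patch).

```python
# Minimal sketch of the plugin + Booster pattern introduced by this diff.
# Assumptions: a toy nn.Linear model, MSE loss, and random data stand in for the
# GPT demo; run with `torchrun --nproc_per_node=1 sketch.py` on a CUDA machine.
import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam


def main():
    # older ColossalAI releases expect a (possibly empty) config dict here
    colossalai.launch_from_torch(config={})

    model = nn.Linear(1024, 1024).cuda()    # placeholder for the GPT-2 model
    criterion = nn.MSELoss()                # placeholder for the LM loss
    optimizer = HybridAdam(model.parameters(), lr=1e-3)

    # pick a plugin instead of calling zero_model_wrapper / zero_optim_wrapper
    plugin = GeminiPlugin(placement_policy='cpu', pin_memory=True)
    booster = Booster(plugin=plugin)

    # boost() returns (model, optimizer, criterion, dataloader, lr_scheduler);
    # slots that were not passed in come back as None
    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

    for _ in range(2):
        data = torch.randn(8, 1024, device='cuda')
        loss = criterion(model(data), data)
        booster.backward(loss, optimizer)   # replaces optimizer.backward() / loss.backward()
        optimizer.step()
        optimizer.zero_grad()


if __name__ == '__main__':
    main()
```

The point of the change is visible in the last loop: the per-distplan branching on how to run the backward pass collapses into a single `booster.backward(loss, optimizer)` call, with the plugin chosen up front deciding whether Gemini, low-level ZeRO, or plain DDP semantics apply.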