diff --git a/examples/language/palm/README.md b/examples/language/palm/README.md index 486bf240f..3ff3939d6 100644 --- a/examples/language/palm/README.md +++ b/examples/language/palm/README.md @@ -43,6 +43,9 @@ palm = PaLM( ) ``` +## New API +We have modified our previous implementation of PaLM with our new Booster API, which offers a more flexible and efficient way to train your model. The new API is more user-friendly and easy to use. You can find the new API in train.py. We have also offer a shell script test_ci.sh for you to go through all our plugins for the booster. For more information about the booster API you can refer to https://colossalai.org/docs/basics/booster_api/. + ## Test on Enwik8 ```bash diff --git a/examples/language/palm/run.sh b/examples/language/palm/run.sh index 7a533509e..2a846e81a 100644 --- a/examples/language/palm/run.sh +++ b/examples/language/palm/run.sh @@ -3,9 +3,11 @@ export DISTPAN="colossalai" # The following options only valid when DISTPAN="colossalai" export TPDEGREE=1 -export GPUNUM=1 +export GPUNUM=4 export PLACEMENT='cpu' export USE_SHARD_INIT=False -export BATCH_SIZE=4 +export BATCH_SIZE=1 -env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --master_port 29501 train.py --tp_degree=${TPDEGREE} --batch_size=${BATCH_SIZE} --placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee run.log +env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --master_port 29501 train.py \ +--dummy_data=True --tp_degree=${TPDEGREE} --batch_size=${BATCH_SIZE} --plugin='gemini' \ +--placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee run.log diff --git a/examples/language/palm/test_ci.sh b/examples/language/palm/test_ci.sh index f21095578..4de6a44e5 100644 --- a/examples/language/palm/test_ci.sh +++ b/examples/language/palm/test_ci.sh @@ -4,6 +4,6 @@ for BATCH_SIZE in 2 do for GPUNUM in 1 4 do -env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --master_port 29501 train.py --dummy_data=True --batch_size=${BATCH_SIZE} 2>&1 | tee run.log +env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --standalone train.py --dummy_data=True --batch_size=${BATCH_SIZE} --plugin='gemini' 2>&1 | tee run.log done done diff --git a/examples/language/palm/train.py b/examples/language/palm/train.py index b16da1c77..62062e8bd 100644 --- a/examples/language/palm/train.py +++ b/examples/language/palm/train.py @@ -9,6 +9,8 @@ import torch.nn as nn import torch.optim as optim import tqdm from packaging import version + +from colossalai.nn import HybridAdam from palm_pytorch import PaLM from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper from torch.utils.data import DataLoader, Dataset @@ -18,6 +20,8 @@ from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec from colossalai.utils import MultiTimer, get_current_device from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, ZeroDDP +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin # constants @@ -58,6 +62,12 @@ def parse_args(): help= "Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.", ) + parser.add_argument('-p', + '--plugin', + type=str, + default='torch_ddp', + choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'], + help="plugin to use") parser.add_argument( "--batch_size", type=int, @@ -101,28 +111,6 @@ def get_model_size(model: nn.Module): return total_numel -# Gemini + ZeRO DDP -def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"): - cai_version = colossalai.__version__ - if version.parse(cai_version) > version.parse("0.1.10"): - from colossalai.nn.parallel import GeminiDDP - model = GeminiDDP(model, - device=get_current_device(), - placement_policy=placement_policy, - pin_memory=True, - search_range_mb=32) - elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"): - from colossalai.gemini import ChunkManager, GeminiManager - chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32) - gemini_manager = GeminiManager(placement_policy, chunk_manager) - chunk_manager = ChunkManager(chunk_size, - pg, - enable_distributed_storage=True, - init_device=GeminiManager.get_default_device(placement_policy)) - model = ZeroDDP(model, gemini_manager) - else: - raise NotImplemented(f"CAI version {cai_version} is not supported") - return model # Parameter Sharding Strategies for Tensor Parallelism @@ -218,6 +206,18 @@ val_loader = cycle(DataLoader(val_dataset, batch_size=args.batch_size)) if args.distplan == "colossalai": # instantiate GPT-like decoder model + booster_kwargs = {} + if args.plugin == 'torch_ddp_fp16': + booster_kwargs['mixed_precision'] = 'fp16' + if args.plugin.startswith('torch_ddp'): + plugin = TorchDDPPlugin() + elif args.plugin == 'gemini': + plugin = GeminiPlugin(placement_policy=args.placement, strict_ddp_mode=True, initial_scale=2 ** 5) + elif args.plugin == 'low_level_zero': + plugin = LowLevelZeroPlugin(initial_scale=2 ** 5) + logger.info(f"plugin: {plugin}") + booster = Booster(plugin=plugin, **booster_kwargs) + default_pg = ProcessGroup(tp_degree=args.tp_degree) default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg) @@ -228,12 +228,12 @@ if args.distplan == "colossalai": pg = default_pg tensor_parallelize(model, pg) - model = gemini_zero_dpp(model, pg, args.placement) # optimizer - #optimizer = GeminiAdamOptimizer(model, lr=1e-7, initial_scale=2**5) - optimizer = GeminiAdamOptimizer(model, lr=LEARNING_RATE, initial_scale=2**5) + optimizer = HybridAdam(model.parameters(), lr=LEARNING_RATE, initial_scale=2**5) + model, optimizer, _, _, _ = booster.boost(model, optimizer) + else: model = PaLM(num_tokens=256, dim=512, depth=8) model = AutoregressiveWrapper(model, max_seq_len=2048)