mirror of https://github.com/hpcaitech/ColossalAI
[example] Modify palm example with the new booster API (#3913)
* Modify torch version requirement to adapt to torch 2.0
* Modify the PaLM example to use the new Booster API
* Roll back
* Fix port
* Polish
parent a55fb00c18
commit b306cecf28
@@ -43,6 +43,9 @@ palm = PaLM(
 )
 ```
 
+## New API
+We have migrated our previous PaLM implementation to our new Booster API, which offers a more flexible and efficient way to train your model. The new usage can be found in train.py. We also provide a shell script, test_ci.sh, that runs the example with each of the booster plugins. For more information about the Booster API, please refer to https://colossalai.org/docs/basics/booster_api/.
+
 ## Test on Enwik8
 
 ```bash
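For orientation, here is a minimal, self-contained sketch of the Booster workflow that the new README section above describes. The toy linear model, random batch, and hyperparameters are placeholders rather than anything from this PR, and the script is assumed to be launched with torchrun:

```python
import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.nn import HybridAdam

# Sketch only: launch with `torchrun --standalone --nproc_per_node=1 sketch.py`.
colossalai.launch_from_torch(config={})

# Toy model/data stand in for the real PaLM model and Enwik8 loader.
model = nn.Linear(128, 2).cuda()
optimizer = HybridAdam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Pick a plugin, build the booster, and let it wrap model and optimizer.
booster = Booster(plugin=TorchDDPPlugin())
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

x = torch.randn(8, 128).cuda()
y = torch.randint(0, 2, (8,)).cuda()
loss = criterion(model(x), y)
booster.backward(loss, optimizer)  # use the booster's backward instead of loss.backward()
optimizer.step()
optimizer.zero_grad()
```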
@@ -3,9 +3,11 @@ export DISTPAN="colossalai"
 
 # The following options only valid when DISTPAN="colossalai"
 export TPDEGREE=1
-export GPUNUM=1
+export GPUNUM=4
 export PLACEMENT='cpu'
 export USE_SHARD_INIT=False
-export BATCH_SIZE=4
+export BATCH_SIZE=1
 
-env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --master_port 29501 train.py --tp_degree=${TPDEGREE} --batch_size=${BATCH_SIZE} --placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee run.log
+env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --master_port 29501 train.py \
+--dummy_data=True --tp_degree=${TPDEGREE} --batch_size=${BATCH_SIZE} --plugin='gemini' \
+--placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee run.log
@@ -4,6 +4,6 @@ for BATCH_SIZE in 2
 do
 for GPUNUM in 1 4
 do
-env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --master_port 29501 train.py --dummy_data=True --batch_size=${BATCH_SIZE} 2>&1 | tee run.log
+env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} train.py --dummy_data=True --batch_size=${BATCH_SIZE} --plugin='gemini' 2>&1 | tee run.log
 done
 done
@@ -9,6 +9,8 @@ import torch.nn as nn
 import torch.optim as optim
 import tqdm
 from packaging import version
+
+from colossalai.nn import HybridAdam
 from palm_pytorch import PaLM
 from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
 from torch.utils.data import DataLoader, Dataset
@@ -18,6 +20,8 @@ from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
 from colossalai.utils import MultiTimer, get_current_device
 from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, ZeroDDP
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 
 # constants
 
@@ -58,6 +62,12 @@ def parse_args():
         help=
         "Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
     )
+    parser.add_argument('-p',
+                        '--plugin',
+                        type=str,
+                        default='torch_ddp',
+                        choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'],
+                        help="plugin to use")
     parser.add_argument(
         "--batch_size",
         type=int,
@@ -101,28 +111,6 @@ def get_model_size(model: nn.Module):
     return total_numel
 
 
-# Gemini + ZeRO DDP
-def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"):
-    cai_version = colossalai.__version__
-    if version.parse(cai_version) > version.parse("0.1.10"):
-        from colossalai.nn.parallel import GeminiDDP
-        model = GeminiDDP(model,
-                          device=get_current_device(),
-                          placement_policy=placement_policy,
-                          pin_memory=True,
-                          search_range_mb=32)
-    elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
-        from colossalai.gemini import ChunkManager, GeminiManager
-        chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
-        gemini_manager = GeminiManager(placement_policy, chunk_manager)
-        chunk_manager = ChunkManager(chunk_size,
-                                     pg,
-                                     enable_distributed_storage=True,
-                                     init_device=GeminiManager.get_default_device(placement_policy))
-        model = ZeroDDP(model, gemini_manager)
-    else:
-        raise NotImplemented(f"CAI version {cai_version} is not supported")
-    return model
 
 
 # Parameter Sharding Strategies for Tensor Parallelism
@@ -218,6 +206,18 @@ val_loader = cycle(DataLoader(val_dataset, batch_size=args.batch_size))
 if args.distplan == "colossalai":
     # instantiate GPT-like decoder model
 
+    booster_kwargs = {}
+    if args.plugin == 'torch_ddp_fp16':
+        booster_kwargs['mixed_precision'] = 'fp16'
+    if args.plugin.startswith('torch_ddp'):
+        plugin = TorchDDPPlugin()
+    elif args.plugin == 'gemini':
+        plugin = GeminiPlugin(placement_policy=args.placement, strict_ddp_mode=True, initial_scale=2 ** 5)
+    elif args.plugin == 'low_level_zero':
+        plugin = LowLevelZeroPlugin(initial_scale=2 ** 5)
+    logger.info(f"plugin: {plugin}")
+    booster = Booster(plugin=plugin, **booster_kwargs)
+
     default_pg = ProcessGroup(tp_degree=args.tp_degree)
     default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None
     ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg)
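As a side note on the `torch_ddp_fp16` branch above: the extra keyword simply ends up on the `Booster` constructor. A minimal illustration of the same idea, with an arbitrary plugin choice that is not tied to this diff:

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

# Equivalent to the booster_kwargs path in the diff: mixed_precision='fp16'
# enables AMP for plugins that do not manage precision themselves.
booster = Booster(plugin=TorchDDPPlugin(), mixed_precision='fp16')
```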
@@ -228,12 +228,12 @@ if args.distplan == "colossalai":
 
     pg = default_pg
     tensor_parallelize(model, pg)
-    model = gemini_zero_dpp(model, pg, args.placement)
 
     # optimizer
 
-    #optimizer = GeminiAdamOptimizer(model, lr=1e-7, initial_scale=2**5)
-    optimizer = GeminiAdamOptimizer(model, lr=LEARNING_RATE, initial_scale=2**5)
+    optimizer = HybridAdam(model.parameters(), lr=LEARNING_RATE, initial_scale=2**5)
+    model, optimizer, _, _, _ = booster.boost(model, optimizer)
+
 else:
     model = PaLM(num_tokens=256, dim=512, depth=8)
     model = AutoregressiveWrapper(model, max_seq_len=2048)
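Putting the pieces of this diff together, the new ColossalAI code path roughly reduces to the sketch below. The learning rate and the launch call are assumptions for illustration; the model construction, optimizer, and plugin arguments mirror lines shown in the diff:

```python
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn import HybridAdam
from colossalai.zero import ColoInitContext
from palm_pytorch import PaLM
from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper

# Assumed launch call; the example's train.py is started via torchrun.
colossalai.launch_from_torch(config={})

# Build the model under ColoInitContext, as in the diff's context lines.
with ColoInitContext(device='cpu'):
    model = PaLM(num_tokens=256, dim=512, depth=8)
    model = AutoregressiveWrapper(model, max_seq_len=2048)

# HybridAdam replaces GeminiAdamOptimizer; the lr value here is a placeholder.
optimizer = HybridAdam(model.parameters(), lr=2e-4, initial_scale=2**5)

# GeminiPlugin takes over the wrapping that the removed gemini_zero_dpp() used to do.
plugin = GeminiPlugin(placement_policy='cpu', strict_ddp_mode=True, initial_scale=2**5)
booster = Booster(plugin=plugin)
model, optimizer, _, _, _ = booster.boost(model, optimizer)
```

From there, the training loop calls booster.backward(loss, optimizer) in place of loss.backward(), as usual with the Booster API.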