[example] Modify palm example with the new booster API (#3913)

* Modify torch version requirement to adapt to torch 2.0

* modify palm example using new booster API

* roll back

* fix port

* polish

* polish
Liu Ziming 2023-06-07 16:05:00 +08:00 committed by GitHub
parent a55fb00c18
commit b306cecf28
4 changed files with 34 additions and 29 deletions

README.md

@@ -43,6 +43,9 @@ palm = PaLM(
 )
 ```
+
+## New API
+We have modified our previous implementation of PaLM to use the new Booster API, which offers a more flexible, efficient, and user-friendly way to train your model. You can find the new usage in train.py. We also provide a shell script, test_ci.sh, that runs the example with each of the booster plugins. For more information about the Booster API, see https://colossalai.org/docs/basics/booster_api/.
 ## Test on Enwik8
 ```bash
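
For orientation, here is a minimal sketch of the Booster workflow that this README section points readers to. It is not part of the diff: the toy model, optimizer, and plugin arguments below are illustrative stand-ins for the PaLM setup in train.py.

```python
import colossalai
import torch
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

# Assumes the script is launched with torchrun, which sets the distributed
# environment variables that launch_from_torch reads.
colossalai.launch_from_torch(config={})

# Toy stand-ins for PaLM and HybridAdam, purely for illustration.
model = nn.Linear(512, 256)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Choose a plugin (GeminiPlugin / LowLevelZeroPlugin follow the same pattern),
# build a Booster around it, and let it wrap the model and optimizer.
booster = Booster(plugin=TorchDDPPlugin())
model, optimizer, _, _, _ = booster.boost(model, optimizer)
```

In the updated train.py the same pattern is applied to the PaLM model: the plugin is selected from the `--plugin` argument, HybridAdam serves as the optimizer, and `booster.boost` returns the wrapped pair.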

run.sh

@@ -3,9 +3,11 @@ export DISTPAN="colossalai"
 # The following options only valid when DISTPAN="colossalai"
 export TPDEGREE=1
-export GPUNUM=1
+export GPUNUM=4
 export PLACEMENT='cpu'
 export USE_SHARD_INIT=False
-export BATCH_SIZE=4
+export BATCH_SIZE=1
-env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --master_port 29501 train.py --tp_degree=${TPDEGREE} --batch_size=${BATCH_SIZE} --placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee run.log
+env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --master_port 29501 train.py \
+    --dummy_data=True --tp_degree=${TPDEGREE} --batch_size=${BATCH_SIZE} --plugin='gemini' \
+    --placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee run.log

test_ci.sh

@@ -4,6 +4,6 @@ for BATCH_SIZE in 2
 do
 for GPUNUM in 1 4
 do
-env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --master_port 29501 train.py --dummy_data=True --batch_size=${BATCH_SIZE} 2>&1 | tee run.log
+env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} train.py --dummy_data=True --batch_size=${BATCH_SIZE} --plugin='gemini' 2>&1 | tee run.log
 done
 done

train.py

@@ -9,6 +9,8 @@ import torch.nn as nn
 import torch.optim as optim
 import tqdm
 from packaging import version
+from colossalai.nn import HybridAdam
 from palm_pytorch import PaLM
 from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
 from torch.utils.data import DataLoader, Dataset
@@ -18,6 +20,8 @@ from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
 from colossalai.utils import MultiTimer, get_current_device
 from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, ZeroDDP
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 
 # constants
@@ -58,6 +62,12 @@ def parse_args():
         help=
         "Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
     )
+    parser.add_argument('-p',
+                        '--plugin',
+                        type=str,
+                        default='torch_ddp',
+                        choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'],
+                        help="plugin to use")
     parser.add_argument(
         "--batch_size",
         type=int,
@@ -101,28 +111,6 @@ def get_model_size(model: nn.Module):
     return total_numel
 
-# Gemini + ZeRO DDP
-def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"):
-    cai_version = colossalai.__version__
-    if version.parse(cai_version) > version.parse("0.1.10"):
-        from colossalai.nn.parallel import GeminiDDP
-        model = GeminiDDP(model,
-                          device=get_current_device(),
-                          placement_policy=placement_policy,
-                          pin_memory=True,
-                          search_range_mb=32)
-    elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
-        from colossalai.gemini import ChunkManager, GeminiManager
-        chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
-        gemini_manager = GeminiManager(placement_policy, chunk_manager)
-        chunk_manager = ChunkManager(chunk_size,
-                                     pg,
-                                     enable_distributed_storage=True,
-                                     init_device=GeminiManager.get_default_device(placement_policy))
-        model = ZeroDDP(model, gemini_manager)
-    else:
-        raise NotImplemented(f"CAI version {cai_version} is not supported")
-    return model
 
 # Parameter Sharding Strategies for Tensor Parallelism
@@ -218,6 +206,18 @@ val_loader = cycle(DataLoader(val_dataset, batch_size=args.batch_size))
 if args.distplan == "colossalai":
     # instantiate GPT-like decoder model
+    booster_kwargs = {}
+    if args.plugin == 'torch_ddp_fp16':
+        booster_kwargs['mixed_precision'] = 'fp16'
+    if args.plugin.startswith('torch_ddp'):
+        plugin = TorchDDPPlugin()
+    elif args.plugin == 'gemini':
+        plugin = GeminiPlugin(placement_policy=args.placement, strict_ddp_mode=True, initial_scale=2 ** 5)
+    elif args.plugin == 'low_level_zero':
+        plugin = LowLevelZeroPlugin(initial_scale=2 ** 5)
+    logger.info(f"plugin: {plugin}")
+    booster = Booster(plugin=plugin, **booster_kwargs)
+
     default_pg = ProcessGroup(tp_degree=args.tp_degree)
     default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None
     ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg)
@@ -228,12 +228,12 @@ if args.distplan == "colossalai":
     pg = default_pg
     tensor_parallelize(model, pg)
-    model = gemini_zero_dpp(model, pg, args.placement)
 
     # optimizer
-    #optimizer = GeminiAdamOptimizer(model, lr=1e-7, initial_scale=2**5)
-    optimizer = GeminiAdamOptimizer(model, lr=LEARNING_RATE, initial_scale=2**5)
+    optimizer = HybridAdam(model.parameters(), lr=LEARNING_RATE, initial_scale=2**5)
+    model, optimizer, _, _, _ = booster.boost(model, optimizer)
 else:
     model = PaLM(num_tokens=256, dim=512, depth=8)
     model = AutoregressiveWrapper(model, max_seq_len=2048)
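
The hunks above stop before the training loop, so for completeness here is a hedged sketch of what a single step typically looks like once `booster.boost` has wrapped the model and the HybridAdam optimizer. The function and argument names are illustrative; only `booster.backward`, `optimizer.step`, and `optimizer.zero_grad` come from the Booster API itself.

```python
def train_step(booster, model, optimizer, data_batch):
    """One optimization step with a boosted model (illustrative sketch)."""
    model.train()
    # The AutoregressiveWrapper used in this example returns the LM loss directly.
    loss = model(data_batch)
    # Backward goes through the booster so the active plugin (DDP / Gemini / ZeRO)
    # can handle gradient scaling and synchronization.
    booster.backward(loss, optimizer)
    optimizer.step()
    optimizer.zero_grad()
    return loss
```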