mirror of https://github.com/hpcaitech/ColossalAI
[example] Modify palm example with the new booster API (#3913)
* Modify torch version requirement to adapt torch 2.0
* modify palm example using new booster API
* roll back
* fix port
* polish
* polish

pull/3923/head
parent a55fb00c18
commit b306cecf28
@@ -43,6 +43,9 @@ palm = PaLM(
 )
 ```
 
+## New API
+We have modified our previous implementation of PaLM with our new Booster API, which offers a more flexible and efficient way to train your model. The new API is more user-friendly and easier to use; you can find it in train.py. We also offer a shell script, test_ci.sh, that runs through all of our plugins for the booster. For more information about the Booster API, refer to https://colossalai.org/docs/basics/booster_api/.
+
 ## Test on Enwik8
 
 ```bash
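For readers new to the Booster API mentioned in the README paragraph above, here is a minimal sketch of the usage pattern this commit adopts, assembled from the pieces visible in the diff below (GeminiPlugin, HybridAdam, Booster, booster.boost). The placeholder model, the learning rate, and the `colossalai.launch_from_torch` call are illustrative assumptions, not the exact contents of train.py.

```python
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn import HybridAdam

# assumes the script is started with torchrun, e.g.
#   torchrun --standalone --nproc_per_node=1 this_script.py
colossalai.launch_from_torch(config={})

# placeholder model; the example itself boosts PaLM wrapped in AutoregressiveWrapper
model = torch.nn.Linear(512, 512)
optimizer = HybridAdam(model.parameters(), lr=1e-4, initial_scale=2**5)

# the 'gemini' plugin choice used in run.sh; the diff also shows
# TorchDDPPlugin() and LowLevelZeroPlugin(initial_scale=2**5) as alternatives
plugin = GeminiPlugin(placement_policy='cpu', strict_ddp_mode=True, initial_scale=2**5)
booster = Booster(plugin=plugin)

# boost() wraps model and optimizer for the chosen distributed/ZeRO strategy
model, optimizer, _, _, _ = booster.boost(model, optimizer)
```

For the `torch_ddp_fp16` choice, the train.py hunk further down passes `mixed_precision='fp16'` to `Booster(...)` via `booster_kwargs`.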
@@ -3,9 +3,11 @@ export DISTPAN="colossalai"
 
 # The following options only valid when DISTPAN="colossalai"
 export TPDEGREE=1
-export GPUNUM=1
+export GPUNUM=4
 export PLACEMENT='cpu'
 export USE_SHARD_INIT=False
-export BATCH_SIZE=4
+export BATCH_SIZE=1
 
-env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --master_port 29501 train.py --tp_degree=${TPDEGREE} --batch_size=${BATCH_SIZE} --placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee run.log
+env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --master_port 29501 train.py \
+--dummy_data=True --tp_degree=${TPDEGREE} --batch_size=${BATCH_SIZE} --plugin='gemini' \
+--placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee run.log
@@ -4,6 +4,6 @@ for BATCH_SIZE in 2
 do
 for GPUNUM in 1 4
 do
-env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --master_port 29501 train.py --dummy_data=True --batch_size=${BATCH_SIZE} 2>&1 | tee run.log
+env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --standalone train.py --dummy_data=True --batch_size=${BATCH_SIZE} --plugin='gemini' 2>&1 | tee run.log
 done
 done
@@ -9,6 +9,8 @@ import torch.nn as nn
 import torch.optim as optim
 import tqdm
 from packaging import version
+
+from colossalai.nn import HybridAdam
 from palm_pytorch import PaLM
 from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
 from torch.utils.data import DataLoader, Dataset
@@ -18,6 +20,8 @@ from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
 from colossalai.utils import MultiTimer, get_current_device
 from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, ZeroDDP
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 
 # constants
 
@@ -58,6 +62,12 @@ def parse_args():
         help=
         "Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
     )
+    parser.add_argument('-p',
+                        '--plugin',
+                        type=str,
+                        default='torch_ddp',
+                        choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'],
+                        help="plugin to use")
     parser.add_argument(
         "--batch_size",
         type=int,
@@ -101,28 +111,6 @@ def get_model_size(model: nn.Module):
     return total_numel
 
 
-# Gemini + ZeRO DDP
-def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"):
-    cai_version = colossalai.__version__
-    if version.parse(cai_version) > version.parse("0.1.10"):
-        from colossalai.nn.parallel import GeminiDDP
-        model = GeminiDDP(model,
-                          device=get_current_device(),
-                          placement_policy=placement_policy,
-                          pin_memory=True,
-                          search_range_mb=32)
-    elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
-        from colossalai.gemini import ChunkManager, GeminiManager
-        chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
-        gemini_manager = GeminiManager(placement_policy, chunk_manager)
-        chunk_manager = ChunkManager(chunk_size,
-                                     pg,
-                                     enable_distributed_storage=True,
-                                     init_device=GeminiManager.get_default_device(placement_policy))
-        model = ZeroDDP(model, gemini_manager)
-    else:
-        raise NotImplemented(f"CAI version {cai_version} is not supported")
-    return model
 
 
 # Parameter Sharding Strategies for Tensor Parallelism
@@ -218,6 +206,18 @@ val_loader = cycle(DataLoader(val_dataset, batch_size=args.batch_size))
 if args.distplan == "colossalai":
     # instantiate GPT-like decoder model
 
+    booster_kwargs = {}
+    if args.plugin == 'torch_ddp_fp16':
+        booster_kwargs['mixed_precision'] = 'fp16'
+    if args.plugin.startswith('torch_ddp'):
+        plugin = TorchDDPPlugin()
+    elif args.plugin == 'gemini':
+        plugin = GeminiPlugin(placement_policy=args.placement, strict_ddp_mode=True, initial_scale=2 ** 5)
+    elif args.plugin == 'low_level_zero':
+        plugin = LowLevelZeroPlugin(initial_scale=2 ** 5)
+    logger.info(f"plugin: {plugin}")
+    booster = Booster(plugin=plugin, **booster_kwargs)
+
     default_pg = ProcessGroup(tp_degree=args.tp_degree)
     default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None
     ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg)
@@ -228,12 +228,12 @@ if args.distplan == "colossalai":
 
     pg = default_pg
     tensor_parallelize(model, pg)
-    model = gemini_zero_dpp(model, pg, args.placement)
 
     # optimizer
-    #optimizer = GeminiAdamOptimizer(model, lr=1e-7, initial_scale=2**5)
-    optimizer = GeminiAdamOptimizer(model, lr=LEARNING_RATE, initial_scale=2**5)
+    optimizer = HybridAdam(model.parameters(), lr=LEARNING_RATE, initial_scale=2**5)
+    model, optimizer, _, _, _ = booster.boost(model, optimizer)
+
 
 else:
     model = PaLM(num_tokens=256, dim=512, depth=8)
     model = AutoregressiveWrapper(model, max_seq_len=2048)
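The diff excerpt ends here, before the training loop. As a sketch only (the loop body below follows the Booster API docs linked in the README and reuses names from the original example such as `train_loader`, `NUM_BATCHES`, and `tqdm`, which are assumed rather than quoted from train.py), a boosted training step routes the backward pass through the booster:

```python
# one illustrative training step, assuming `model`, `optimizer`, and `booster`
# are set up as in the hunks above and `train_loader` is a cycled DataLoader
model.train()
for _ in tqdm.tqdm(range(NUM_BATCHES), desc="training"):
    optimizer.zero_grad()
    loss = model(next(train_loader))    # AutoregressiveWrapper's forward returns the LM loss
    booster.backward(loss, optimizer)   # replaces loss.backward() under the Booster API
    optimizer.step()
```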