[example] update gemini examples (#3868)

* [example] update gemini examples

* [example] update gemini examples
pull/3847/head^2
jiangmingyan 2023-05-30 18:41:41 +08:00 committed by GitHub
parent 2506e275b8
commit 5f79008c4a
2 changed files with 29 additions and 30 deletions


@@ -3,7 +3,7 @@ $(cd `dirname $0`;pwd)
 export TRAIN_STEP=4
 
 for MODEL_TYPE in "gpt2_medium"; do
-for DISTPLAN in "colossalai"; do
+for DISTPLAN in "CAI_Gemini"; do
 for BATCH_SIZE in 2; do
 for GPUNUM in 1 4; do
 for TPDEGREE in 1 2; do


@@ -11,11 +11,13 @@ from packaging import version
 from torch.nn.parallel import DistributedDataParallel as DDP
 
 import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
 from colossalai.utils import get_current_device
-from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper
+from colossalai.zero import ColoInitContext
 
 CAI_VERSION = colossalai.__version__
@@ -236,23 +238,6 @@ def main():
             tensor_parallelize(model, tp_pg)
 
         # asign running configurations
-        gemini_config = None
-        if args.distplan.startswith("CAI_ZeRO"):
-            optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True)
-        elif args.distplan == "CAI_Gemini":
-            gemini_config = dict(strict_ddp_mode=args.tp_degree == 1,
-                                 device=get_current_device(),
-                                 placement_policy=args.placement,
-                                 pin_memory=True,
-                                 hidden_dim=model.config.n_embd,
-                                 search_range_mb=128)
-            optim_config = dict(gpu_margin_mem_ratio=0.)
-        else:
-            raise RuntimeError
-
-        # build a highly optimized gpu/cpu optimizer
-        optimizer = HybridAdam(model.parameters(), lr=1e-3)
-
         if args.distplan == "CAI_ZeRO1":
             zero_stage = 1
         elif args.distplan == "CAI_ZeRO2":
@@ -262,22 +247,42 @@ def main():
         else:
             raise RuntimeError
 
-        # wrap your model and optimizer
-        model = zero_model_wrapper(model, zero_stage, gemini_config)
-        optimizer = zero_optim_wrapper(model, optimizer, optim_config=optim_config)
+        plugin = None
+        if args.distplan.startswith("CAI_ZeRO"):
+            plugin = LowLevelZeroPlugin(stage=zero_stage,
+                                        reduce_bucket_size_in_m=12 * 1024 * 1024,
+                                        overlap_communication=True,
+                                        verbose=True)
+        elif args.distplan == "CAI_Gemini":
+            plugin = GeminiPlugin(device=get_current_device(),
+                                  placement_policy=args.placement,
+                                  pin_memory=True,
+                                  strict_ddp_mode=args.tp_degree == 1,
+                                  search_range_mb=128,
+                                  hidden_dim=model.config.n_embd,
+                                  gpu_margin_mem_ratio=0.)
+        else:
+            raise RuntimeError
+
+        # build a highly optimized gpu/cpu optimizer
+        optimizer = HybridAdam(model.parameters(), lr=1e-3)
 
         logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
     elif args.distplan.startswith("Pytorch"):
         assert args.tp_degree == 1, "The degree of TP should be 1 for DDP examples."
         model = model_builder(args.model_type)(checkpoint=True).cuda()
-        model = DDP(model)
+        plugin = TorchDDPPlugin()
         if args.distplan.endswith("DDP"):
             optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
         elif args.distplan.endswith("ZeRO"):
             from torch.distributed.optim import ZeroRedundancyOptimizer
             optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3)
         else:
             raise RuntimeError
+
+    # wrap your model and optimizer
+    booster = Booster(plugin=plugin)
+    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
 
     # model is shared after TP
     numel = get_model_size(model)
@@ -305,13 +310,7 @@ def main():
         fwd_end = time()
         fwd_time = fwd_end - start
         logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Forward '), ranks=[0])
-
-        if args.distplan.startswith("CAI"):
-            optimizer.backward(loss)
-        elif args.distplan.startswith("Pytorch"):
-            loss.backward()
-        else:
-            raise RuntimeError
+        booster.backward(loss, optimizer)
 
         torch.cuda.synchronize()
         bwd_end = time()