
[example] update gemini examples (#3868)

* [example] update gemini examples

* [example] update gemini examples
jiangmingyan committed via GitHub
commit 5f79008c4a
2 changed files:
  1. examples/language/gpt/gemini/test_ci.sh (2 lines changed)
  2. examples/language/gpt/gemini/train_gpt_demo.py (57 lines changed)

examples/language/gpt/gemini/test_ci.sh (2 lines changed)

@@ -3,7 +3,7 @@ $(cd `dirname $0`;pwd)
 export TRAIN_STEP=4
 
 for MODEL_TYPE in "gpt2_medium"; do
-  for DISTPLAN in "colossalai"; do
+  for DISTPLAN in "CAI_Gemini"; do
    for BATCH_SIZE in 2; do
      for GPUNUM in 1 4; do
        for TPDEGREE in 1 2; do

examples/language/gpt/gemini/train_gpt_demo.py (57 lines changed)

@@ -11,11 +11,13 @@ from packaging import version
 from torch.nn.parallel import DistributedDataParallel as DDP
 
 import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
 from colossalai.utils import get_current_device
-from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper
+from colossalai.zero import ColoInitContext
 
 CAI_VERSION = colossalai.__version__
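
The added imports move the demo from the low-level `zero_model_wrapper` / `zero_optim_wrapper` helpers to the Booster API, where a plugin object carries the parallelism strategy. Below is a minimal sketch of the workflow these imports enable, assuming the script is launched via torchrun or colossalai run; the toy model is illustrative, not the demo's GPT builder.

```python
# Minimal sketch of the Booster workflow (not part of the diff).
import torch

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})    # reads rank/world size from the torchrun env

# A toy model stands in for the demo's GPT model_builder.
model = torch.nn.Sequential(torch.nn.Linear(128, 128), torch.nn.ReLU()).cuda()
criterion = torch.nn.MSELoss()
optimizer = HybridAdam(model.parameters(), lr=1e-3)    # Gemini/ZeRO expect a hybrid CPU/GPU Adam

plugin = GeminiPlugin(pin_memory=True)    # the plugin carries the parallel/memory strategy
booster = Booster(plugin=plugin)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
```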
@@ -236,23 +238,6 @@ def main():
             tensor_parallelize(model, tp_pg)
 
-        # asign running configurations
-        gemini_config = None
-        if args.distplan.startswith("CAI_ZeRO"):
-            optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True)
-        elif args.distplan == "CAI_Gemini":
-            gemini_config = dict(strict_ddp_mode=args.tp_degree == 1,
-                                 device=get_current_device(),
-                                 placement_policy=args.placement,
-                                 pin_memory=True,
-                                 hidden_dim=model.config.n_embd,
-                                 search_range_mb=128)
-            optim_config = dict(gpu_margin_mem_ratio=0.)
-        else:
-            raise RuntimeError
-
-        # build a highly optimized gpu/cpu optimizer
-        optimizer = HybridAdam(model.parameters(), lr=1e-3)
 
         if args.distplan == "CAI_ZeRO1":
             zero_stage = 1
         elif args.distplan == "CAI_ZeRO2":
@@ -262,22 +247,42 @@ def main():
         else:
             raise RuntimeError
 
-        # wrap your model and optimizer
-        model = zero_model_wrapper(model, zero_stage, gemini_config)
-        optimizer = zero_optim_wrapper(model, optimizer, optim_config=optim_config)
+        plugin = None
+        if args.distplan.startswith("CAI_ZeRO"):
+            plugin = LowLevelZeroPlugin(stage=zero_stage,
+                                        reduce_bucket_size_in_m=12 * 1024 * 1024,
+                                        overlap_communication=True,
+                                        verbose=True)
+        elif args.distplan == "CAI_Gemini":
+            plugin = GeminiPlugin(device=get_current_device(),
+                                  placement_policy=args.placement,
+                                  pin_memory=True,
+                                  strict_ddp_mode=args.tp_degree == 1,
+                                  search_range_mb=128,
+                                  hidden_dim=model.config.n_embd,
+                                  gpu_margin_mem_ratio=0.)
+        else:
+            raise RuntimeError
+
+        # build a highly optimized gpu/cpu optimizer
+        optimizer = HybridAdam(model.parameters(), lr=1e-3)
 
         logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
     elif args.distplan.startswith("Pytorch"):
         assert args.tp_degree == 1, "The degree of TP should be 1 for DDP examples."
         model = model_builder(args.model_type)(checkpoint=True).cuda()
-        model = DDP(model)
+        plugin = TorchDDPPlugin()
 
         if args.distplan.endswith("DDP"):
             optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
         elif args.distplan.endswith("ZeRO"):
             from torch.distributed.optim import ZeroRedundancyOptimizer
             optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3)
     else:
         raise RuntimeError
 
+    # wrap your model and optimizer
+    booster = Booster(plugin=plugin)
+    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
+
     # model is shared after TP
     numel = get_model_size(model)
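
For readers migrating their own scripts, the removed `gemini_config` / `optim_config` dicts map almost one-to-one onto the plugin constructors above. A hedged sketch of that mapping, with kwarg values copied from this diff and the remaining placeholders (placement policy, hidden size) filled with illustrative stand-ins:

```python
# Sketch: how the removed dict-based config translates to plugin constructors.
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
from colossalai.utils import get_current_device

# Old: optim_config = dict(reduce_bucket_size=12 * 1024 * 1024,
#                          overlap_communication=True, verbose=True)
zero_plugin = LowLevelZeroPlugin(stage=2,    # e.g. CAI_ZeRO2
                                 reduce_bucket_size_in_m=12 * 1024 * 1024,
                                 overlap_communication=True,
                                 verbose=True)

# Old: gemini_config = dict(strict_ddp_mode=..., device=..., placement_policy=...,
#                           pin_memory=True, hidden_dim=..., search_range_mb=128)
#      optim_config  = dict(gpu_margin_mem_ratio=0.)
# New: both dicts collapse into a single GeminiPlugin.
gemini_plugin = GeminiPlugin(device=get_current_device(),
                             placement_policy="cpu",   # stands in for args.placement
                             pin_memory=True,
                             strict_ddp_mode=True,     # stands in for args.tp_degree == 1
                             search_range_mb=128,
                             hidden_dim=768,           # stands in for model.config.n_embd
                             gpu_margin_mem_ratio=0.)
```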
@@ -305,13 +310,7 @@ def main():
         fwd_end = time()
         fwd_time = fwd_end - start
         logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Forward '), ranks=[0])
 
-        if args.distplan.startswith("CAI"):
-            optimizer.backward(loss)
-        elif args.distplan.startswith("Pytorch"):
-            loss.backward()
-        else:
-            raise RuntimeError
-
+        booster.backward(loss, optimizer)
         torch.cuda.synchronize()
         bwd_end = time()
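
After this change a single code path handles the backward pass for every `--distplan`, because the boosted optimizer decides how to run it (Gemini/ZeRO need `optimizer.backward(loss)`, plain DDP uses `loss.backward()`). A compact sketch of one training step under the new API; the function name and arguments are illustrative and mirror the demo's loop:

```python
def train_step(model, optimizer, criterion, booster, input_ids, attn_mask):
    # One step of the demo's loop, rewritten around the Booster API (illustrative).
    optimizer.zero_grad()
    outputs = model(input_ids, attn_mask)
    loss = criterion(outputs, input_ids)
    # booster.backward delegates to the plugin-wrapped optimizer, so the old
    # if/elif branching on args.distplan is no longer needed here.
    booster.backward(loss, optimizer)
    optimizer.step()
    return loss
```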
