mirror of https://github.com/hpcaitech/ColossalAI
[example] update gemini examples (#3868)
* [example] update gemini examples
* [example] update gemini examples
parent 2506e275b8
commit 5f79008c4a
@@ -3,7 +3,7 @@ $(cd `dirname $0`;pwd)
 export TRAIN_STEP=4
 
 for MODEL_TYPE in "gpt2_medium"; do
-  for DISTPLAN in "colossalai"; do
+  for DISTPLAN in "CAI_Gemini"; do
     for BATCH_SIZE in 2; do
       for GPUNUM in 1 4; do
         for TPDEGREE in 1 2; do
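In run_gemini.sh, each DISTPLAN value is read back by train_gpt_demo.py as args.distplan and selects one of the booster plugins introduced in the hunks below. A minimal sketch of that mapping (the helper name and any plan names not visible in this diff are assumptions):

# Hypothetical helper mirroring the args.distplan branches in train_gpt_demo.py;
# plan names not visible in this diff are assumptions.
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin


def pick_plugin(distplan: str, zero_stage: int = 1):
    if distplan.startswith("CAI_ZeRO"):
        # CAI_ZeRO1 / CAI_ZeRO2 -> low-level ZeRO with the matching stage
        return LowLevelZeroPlugin(stage=zero_stage)
    elif distplan == "CAI_Gemini":
        # Gemini: chunk-based ZeRO-DP with dynamic CPU/GPU placement
        return GeminiPlugin(pin_memory=True)
    elif distplan.startswith("Pytorch"):
        # plain PyTorch DDP baseline; the optimizer is chosen separately
        return TorchDDPPlugin()
    raise RuntimeError(f"unknown distplan: {distplan}")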
@@ -11,11 +11,13 @@ from packaging import version
 from torch.nn.parallel import DistributedDataParallel as DDP
 
 import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
 from colossalai.utils import get_current_device
-from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper
+from colossalai.zero import ColoInitContext
 
 CAI_VERSION = colossalai.__version__
 
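The two new imports carry the whole migration: a booster plugin encodes the distributed strategy, and Booster applies it to the model, optimizer and criterion. A toy end-to-end sketch of the flow (not the GPT example itself), assuming colossalai has already been launched on this process:

# Toy end-to-end sketch of the Booster flow; assumes colossalai is launched.
import torch
import torch.nn as nn

from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

model = nn.Linear(128, 128).cuda()
criterion = nn.MSELoss()
optimizer = HybridAdam(model.parameters(), lr=1e-3)

plugin = GeminiPlugin(pin_memory=True)            # the strategy lives in the plugin
booster = Booster(plugin=plugin)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

x = torch.randn(4, 128, device="cuda")
loss = criterion(model(x), x)
booster.backward(loss, optimizer)                 # replaces loss.backward()/optimizer.backward()
optimizer.step()
optimizer.zero_grad()

Only the plugin construction changes between the CAI_ZeRO, CAI_Gemini and Pytorch plans; the boost and backward calls stay the same.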
@@ -236,23 +238,6 @@ def main():
             tensor_parallelize(model, tp_pg)
 
-        # asign running configurations
-        gemini_config = None
-        if args.distplan.startswith("CAI_ZeRO"):
-            optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True)
-        elif args.distplan == "CAI_Gemini":
-            gemini_config = dict(strict_ddp_mode=args.tp_degree == 1,
-                                 device=get_current_device(),
-                                 placement_policy=args.placement,
-                                 pin_memory=True,
-                                 hidden_dim=model.config.n_embd,
-                                 search_range_mb=128)
-            optim_config = dict(gpu_margin_mem_ratio=0.)
-        else:
-            raise RuntimeError
-
-        # build a highly optimized gpu/cpu optimizer
-        optimizer = HybridAdam(model.parameters(), lr=1e-3)
 
         if args.distplan == "CAI_ZeRO1":
             zero_stage = 1
         elif args.distplan == "CAI_ZeRO2":
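The deleted gemini_config/optim_config dicts existed only to feed the legacy zero_model_wrapper/zero_optim_wrapper calls that the next hunk removes. For comparison, a rough sketch of the retired path (toy model; config values mirror the deleted lines, and zero_stage=3 matches the old CAI_Gemini branch):

# Retired pre-Booster path, shown only for comparison (toy model).
import torch.nn as nn

from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper

with ColoInitContext(device=get_current_device()):
    model = nn.Linear(128, 128)

gemini_config = dict(strict_ddp_mode=True,        # tp_degree == 1 in the example
                     device=get_current_device(),
                     pin_memory=True,
                     search_range_mb=128)
optim_config = dict(gpu_margin_mem_ratio=0.)

optimizer = HybridAdam(model.parameters(), lr=1e-3)
model = zero_model_wrapper(model, 3, gemini_config)
optimizer = zero_optim_wrapper(model, optimizer, optim_config=optim_config)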
@@ -262,22 +247,42 @@ def main():
         else:
             raise RuntimeError
 
-        # wrap your model and optimizer
-        model = zero_model_wrapper(model, zero_stage, gemini_config)
-        optimizer = zero_optim_wrapper(model, optimizer, optim_config=optim_config)
+        plugin = None
+        if args.distplan.startswith("CAI_ZeRO"):
+            plugin = LowLevelZeroPlugin(stage=zero_stage,
+                                        reduce_bucket_size_in_m=12 * 1024 * 1024,
+                                        overlap_communication=True,
+                                        verbose=True)
+        elif args.distplan == "CAI_Gemini":
+            plugin = GeminiPlugin(device=get_current_device(),
+                                  placement_policy=args.placement,
+                                  pin_memory=True,
+                                  strict_ddp_mode=args.tp_degree == 1,
+                                  search_range_mb=128,
+                                  hidden_dim=model.config.n_embd,
+                                  gpu_margin_mem_ratio=0.)
+        else:
+            raise RuntimeError
+
+        # build a highly optimized gpu/cpu optimizer
+        optimizer = HybridAdam(model.parameters(), lr=1e-3)
+
         logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
     elif args.distplan.startswith("Pytorch"):
         assert args.tp_degree == 1, "The degree of TP should be 1 for DDP examples."
         model = model_builder(args.model_type)(checkpoint=True).cuda()
-        model = DDP(model)
+        plugin = TorchDDPPlugin()
         if args.distplan.endswith("DDP"):
             optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
         elif args.distplan.endswith("ZeRO"):
             from torch.distributed.optim import ZeroRedundancyOptimizer
             optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3)
 
     else:
         raise RuntimeError
+    # wrap your model and optimizer
+    booster = Booster(plugin=plugin)
+    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
 
     # model is shared after TP
     numel = get_model_size(model)
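booster.boost() hands back a 5-tuple, which is why the example discards the last two slots; when a dataloader and an LR scheduler are also managed by the booster they occupy those positions. A sketch under the assumption that boost() accepts them positionally in the same order as the returned tuple:

# Sketch of the fuller boost() call; the dataloader/scheduler arguments are an
# assumption based on the 5-value unpacking used in the example.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

model = nn.Linear(16, 16).cuda()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
loader = DataLoader(TensorDataset(torch.randn(32, 16)), batch_size=4)

booster = Booster(plugin=TorchDDPPlugin())
model, optimizer, criterion, loader, scheduler = booster.boost(
    model, optimizer, criterion, loader, scheduler)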
@@ -305,13 +310,7 @@ def main():
         fwd_end = time()
         fwd_time = fwd_end - start
         logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Forward '), ranks=[0])
-
-        if args.distplan.startswith("CAI"):
-            optimizer.backward(loss)
-        elif args.distplan.startswith("Pytorch"):
-            loss.backward()
-        else:
-            raise RuntimeError
+        booster.backward(loss, optimizer)
 
         torch.cuda.synchronize()
         bwd_end = time()
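After this change a single booster.backward(loss, optimizer) call covers every distplan, so the per-plan branching disappears from the timed step. A condensed sketch of one benchmark step; model, criterion, optimizer, booster and the batch tensors (input_ids, attn_mask) are assumed to be set up as elsewhere in the script:

# Condensed benchmark step; names are assumed from the surrounding script.
from time import time

import torch

optimizer.zero_grad()
start = time()

outputs = model(input_ids, attn_mask)
loss = criterion(outputs, input_ids)
torch.cuda.synchronize()
fwd_time = time() - start

booster.backward(loss, optimizer)       # one call for every distplan
torch.cuda.synchronize()
bwd_time = time() - start - fwd_time

optimizer.step()
torch.cuda.synchronize()
optim_time = time() - start - fwd_time - bwd_time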