|
|
|
@ -11,11 +11,13 @@ from packaging import version
|
|
|
|
|
from torch.nn.parallel import DistributedDataParallel as DDP |
|
|
|
|
|
|
|
|
|
import colossalai |
|
|
|
|
from colossalai.booster import Booster |
|
|
|
|
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin |
|
|
|
|
from colossalai.logging import disable_existing_loggers, get_dist_logger |
|
|
|
|
from colossalai.nn.optimizer import HybridAdam |
|
|
|
|
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec |
|
|
|
|
from colossalai.utils import get_current_device |
|
|
|
|
from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper |
|
|
|
|
from colossalai.zero import ColoInitContext |
|
|
|
|
|
|
|
|
|
CAI_VERSION = colossalai.__version__ |
|
|
|
|
|
|
|
|
@ -236,23 +238,6 @@ def main():
|
|
|
|
|
tensor_parallelize(model, tp_pg) |
|
|
|
|
|
|
|
|
|
# asign running configurations |
|
|
|
|
gemini_config = None |
|
|
|
|
if args.distplan.startswith("CAI_ZeRO"): |
|
|
|
|
optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True) |
|
|
|
|
elif args.distplan == "CAI_Gemini": |
|
|
|
|
gemini_config = dict(strict_ddp_mode=args.tp_degree == 1, |
|
|
|
|
device=get_current_device(), |
|
|
|
|
placement_policy=args.placement, |
|
|
|
|
pin_memory=True, |
|
|
|
|
hidden_dim=model.config.n_embd, |
|
|
|
|
search_range_mb=128) |
|
|
|
|
optim_config = dict(gpu_margin_mem_ratio=0.) |
|
|
|
|
else: |
|
|
|
|
raise RuntimeError |
|
|
|
|
|
|
|
|
|
# build a highly optimized gpu/cpu optimizer |
|
|
|
|
optimizer = HybridAdam(model.parameters(), lr=1e-3) |
|
|
|
|
|
|
|
|
|
if args.distplan == "CAI_ZeRO1": |
|
|
|
|
zero_stage = 1 |
|
|
|
|
elif args.distplan == "CAI_ZeRO2": |
|
|
|
@ -262,22 +247,42 @@ def main():
|
|
|
|
|
else: |
|
|
|
|
raise RuntimeError |
|
|
|
|
|
|
|
|
|
# wrap your model and optimizer |
|
|
|
|
model = zero_model_wrapper(model, zero_stage, gemini_config) |
|
|
|
|
optimizer = zero_optim_wrapper(model, optimizer, optim_config=optim_config) |
|
|
|
|
plugin = None |
|
|
|
|
if args.distplan.startswith("CAI_ZeRO"): |
|
|
|
|
plugin = LowLevelZeroPlugin(stage=zero_stage, |
|
|
|
|
reduce_bucket_size_in_m=12 * 1024 * 1024, |
|
|
|
|
overlap_communication=True, |
|
|
|
|
verbose=True) |
|
|
|
|
elif args.distplan == "CAI_Gemini": |
|
|
|
|
plugin = GeminiPlugin(device=get_current_device(), |
|
|
|
|
placement_policy=args.placement, |
|
|
|
|
pin_memory=True, |
|
|
|
|
strict_ddp_mode=args.tp_degree == 1, |
|
|
|
|
search_range_mb=128, |
|
|
|
|
hidden_dim=model.config.n_embd, |
|
|
|
|
gpu_margin_mem_ratio=0.) |
|
|
|
|
else: |
|
|
|
|
raise RuntimeError |
|
|
|
|
|
|
|
|
|
# build a highly optimized gpu/cpu optimizer |
|
|
|
|
optimizer = HybridAdam(model.parameters(), lr=1e-3) |
|
|
|
|
|
|
|
|
|
logger.info(get_mem_info(prefix='After init optim, '), ranks=[0]) |
|
|
|
|
elif args.distplan.startswith("Pytorch"): |
|
|
|
|
assert args.tp_degree == 1, "The degree of TP should be 1 for DDP examples." |
|
|
|
|
model = model_builder(args.model_type)(checkpoint=True).cuda() |
|
|
|
|
model = DDP(model) |
|
|
|
|
plugin = TorchDDPPlugin() |
|
|
|
|
if args.distplan.endswith("DDP"): |
|
|
|
|
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) |
|
|
|
|
elif args.distplan.endswith("ZeRO"): |
|
|
|
|
from torch.distributed.optim import ZeroRedundancyOptimizer |
|
|
|
|
optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3) |
|
|
|
|
|
|
|
|
|
else: |
|
|
|
|
raise RuntimeError |
|
|
|
|
# wrap your model and optimizer |
|
|
|
|
booster = Booster(plugin=plugin) |
|
|
|
|
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) |
|
|
|
|
|
|
|
|
|
# model is shared after TP |
|
|
|
|
numel = get_model_size(model) |
|
|
|
@ -305,13 +310,7 @@ def main():
|
|
|
|
|
fwd_end = time() |
|
|
|
|
fwd_time = fwd_end - start |
|
|
|
|
logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Forward '), ranks=[0]) |
|
|
|
|
|
|
|
|
|
if args.distplan.startswith("CAI"): |
|
|
|
|
optimizer.backward(loss) |
|
|
|
|
elif args.distplan.startswith("Pytorch"): |
|
|
|
|
loss.backward() |
|
|
|
|
else: |
|
|
|
|
raise RuntimeError |
|
|
|
|
booster.backward(loss, optimizer) |
|
|
|
|
|
|
|
|
|
torch.cuda.synchronize() |
|
|
|
|
bwd_end = time() |
|
|
|
|