mirror of https://github.com/hpcaitech/ColossalAI
[example] update gemini examples (#3868)
* [example]update gemini examples * [example]update gemini examplespull/3847/head^2
parent
2506e275b8
commit
5f79008c4a
|
@ -3,7 +3,7 @@ $(cd `dirname $0`;pwd)
|
||||||
export TRAIN_STEP=4
|
export TRAIN_STEP=4
|
||||||
|
|
||||||
for MODEL_TYPE in "gpt2_medium"; do
|
for MODEL_TYPE in "gpt2_medium"; do
|
||||||
for DISTPLAN in "colossalai"; do
|
for DISTPLAN in "CAI_Gemini"; do
|
||||||
for BATCH_SIZE in 2; do
|
for BATCH_SIZE in 2; do
|
||||||
for GPUNUM in 1 4; do
|
for GPUNUM in 1 4; do
|
||||||
for TPDEGREE in 1 2; do
|
for TPDEGREE in 1 2; do
|
||||||
|
|
|
@ -11,11 +11,13 @@ from packaging import version
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
|
from colossalai.booster import Booster
|
||||||
|
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
|
||||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||||
from colossalai.nn.optimizer import HybridAdam
|
from colossalai.nn.optimizer import HybridAdam
|
||||||
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
|
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper
|
from colossalai.zero import ColoInitContext
|
||||||
|
|
||||||
CAI_VERSION = colossalai.__version__
|
CAI_VERSION = colossalai.__version__
|
||||||
|
|
||||||
|
@ -236,23 +238,6 @@ def main():
|
||||||
tensor_parallelize(model, tp_pg)
|
tensor_parallelize(model, tp_pg)
|
||||||
|
|
||||||
# asign running configurations
|
# asign running configurations
|
||||||
gemini_config = None
|
|
||||||
if args.distplan.startswith("CAI_ZeRO"):
|
|
||||||
optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True)
|
|
||||||
elif args.distplan == "CAI_Gemini":
|
|
||||||
gemini_config = dict(strict_ddp_mode=args.tp_degree == 1,
|
|
||||||
device=get_current_device(),
|
|
||||||
placement_policy=args.placement,
|
|
||||||
pin_memory=True,
|
|
||||||
hidden_dim=model.config.n_embd,
|
|
||||||
search_range_mb=128)
|
|
||||||
optim_config = dict(gpu_margin_mem_ratio=0.)
|
|
||||||
else:
|
|
||||||
raise RuntimeError
|
|
||||||
|
|
||||||
# build a highly optimized gpu/cpu optimizer
|
|
||||||
optimizer = HybridAdam(model.parameters(), lr=1e-3)
|
|
||||||
|
|
||||||
if args.distplan == "CAI_ZeRO1":
|
if args.distplan == "CAI_ZeRO1":
|
||||||
zero_stage = 1
|
zero_stage = 1
|
||||||
elif args.distplan == "CAI_ZeRO2":
|
elif args.distplan == "CAI_ZeRO2":
|
||||||
|
@ -262,22 +247,42 @@ def main():
|
||||||
else:
|
else:
|
||||||
raise RuntimeError
|
raise RuntimeError
|
||||||
|
|
||||||
# wrap your model and optimizer
|
plugin = None
|
||||||
model = zero_model_wrapper(model, zero_stage, gemini_config)
|
if args.distplan.startswith("CAI_ZeRO"):
|
||||||
optimizer = zero_optim_wrapper(model, optimizer, optim_config=optim_config)
|
plugin = LowLevelZeroPlugin(stage=zero_stage,
|
||||||
|
reduce_bucket_size_in_m=12 * 1024 * 1024,
|
||||||
|
overlap_communication=True,
|
||||||
|
verbose=True)
|
||||||
|
elif args.distplan == "CAI_Gemini":
|
||||||
|
plugin = GeminiPlugin(device=get_current_device(),
|
||||||
|
placement_policy=args.placement,
|
||||||
|
pin_memory=True,
|
||||||
|
strict_ddp_mode=args.tp_degree == 1,
|
||||||
|
search_range_mb=128,
|
||||||
|
hidden_dim=model.config.n_embd,
|
||||||
|
gpu_margin_mem_ratio=0.)
|
||||||
|
else:
|
||||||
|
raise RuntimeError
|
||||||
|
|
||||||
|
# build a highly optimized gpu/cpu optimizer
|
||||||
|
optimizer = HybridAdam(model.parameters(), lr=1e-3)
|
||||||
|
|
||||||
logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
|
logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
|
||||||
elif args.distplan.startswith("Pytorch"):
|
elif args.distplan.startswith("Pytorch"):
|
||||||
assert args.tp_degree == 1, "The degree of TP should be 1 for DDP examples."
|
assert args.tp_degree == 1, "The degree of TP should be 1 for DDP examples."
|
||||||
model = model_builder(args.model_type)(checkpoint=True).cuda()
|
model = model_builder(args.model_type)(checkpoint=True).cuda()
|
||||||
model = DDP(model)
|
plugin = TorchDDPPlugin()
|
||||||
if args.distplan.endswith("DDP"):
|
if args.distplan.endswith("DDP"):
|
||||||
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
|
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
|
||||||
elif args.distplan.endswith("ZeRO"):
|
elif args.distplan.endswith("ZeRO"):
|
||||||
from torch.distributed.optim import ZeroRedundancyOptimizer
|
from torch.distributed.optim import ZeroRedundancyOptimizer
|
||||||
optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3)
|
optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise RuntimeError
|
raise RuntimeError
|
||||||
|
# wrap your model and optimizer
|
||||||
|
booster = Booster(plugin=plugin)
|
||||||
|
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
|
||||||
|
|
||||||
# model is shared after TP
|
# model is shared after TP
|
||||||
numel = get_model_size(model)
|
numel = get_model_size(model)
|
||||||
|
@ -305,13 +310,7 @@ def main():
|
||||||
fwd_end = time()
|
fwd_end = time()
|
||||||
fwd_time = fwd_end - start
|
fwd_time = fwd_end - start
|
||||||
logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Forward '), ranks=[0])
|
logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Forward '), ranks=[0])
|
||||||
|
booster.backward(loss, optimizer)
|
||||||
if args.distplan.startswith("CAI"):
|
|
||||||
optimizer.backward(loss)
|
|
||||||
elif args.distplan.startswith("Pytorch"):
|
|
||||||
loss.backward()
|
|
||||||
else:
|
|
||||||
raise RuntimeError
|
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
bwd_end = time()
|
bwd_end = time()
|
||||||
|
|
Loading…
Reference in New Issue