mirror of https://github.com/hpcaitech/ColossalAI
[gemini] update the gpt example (#2527)
parent ecbad93b65
commit 66dfcf5281
@@ -32,16 +32,19 @@ def zero_model_wrapper(model: nn.Module, zero_stage: int = 1, gemini_config: Opt
     >>> config_dict = dict(device=torch.cuda.current_device(), hidden_dim=1024, placement_policy='auto')
     >>> model = zero_model_wrapper(model, zero_stage=3, gemini_config=config_dict)
     """
-    setattr(model, "_colo_zero_stage", zero_stage)
     assert zero_stage in [1, 2, 3], "The stage of ZeRO should be 1, 2 or 3"
 
     if gemini_config is None:
         gemini_config = dict()
 
     if zero_stage in [1, 2]:
-        return model
+        wrapped_model = model
     else:
-        return GeminiDDP(model, **gemini_config)
+        wrapped_model = GeminiDDP(model, **gemini_config)
+
+    setattr(wrapped_model, "_colo_zero_stage", zero_stage)
+
+    return wrapped_model
 
 
 def zero_optim_wrapper(model: nn.Module,
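Note (not part of the diff): this change moves the `_colo_zero_stage` marker onto the object that zero_model_wrapper actually returns, so a stage-3 wrap in GeminiDDP no longer loses it. Below is a minimal usage sketch stitched together from the docstring in this hunk and imports that appear later in this commit; the small Sequential module and its sizes are stand-ins, not code from the repository.

import torch
from torch import nn

from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

# Assumes a distributed context is already set up, e.g. colossalai.launch_from_torch(config={}).
# Build the module under ColoInitContext so its parameters are ColoParameters,
# the same way train_gpt_demo.py builds the GPT model.
with ColoInitContext(device=get_current_device(), dtype=torch.half):
    model = nn.Sequential(nn.Linear(1024, 1024), nn.GELU(), nn.Linear(1024, 1024))

# zero_stage=3 wraps the module in GeminiDDP with this config;
# zero_stage=1 or 2 returns the module unchanged and only the optimizer differs.
config_dict = dict(device=torch.cuda.current_device(), hidden_dim=1024, placement_policy='auto')
model = zero_model_wrapper(model, zero_stage=3, gemini_config=config_dict)

# `_colo_zero_stage` now lives on the returned (possibly GeminiDDP) model,
# which zero_optim_wrapper reads to pick the matching optimizer wrapper.
optimizer = HybridAdam(model.parameters(), lr=1e-3)
optimizer = zero_optim_wrapper(model, optimizer, optim_config=dict(gpu_margin_mem_ratio=0.))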
@@ -1,5 +1,5 @@
 for MODEL_TYPE in "gpt2_medium"; do
-  for DISTPLAN in "colossalai"; do
+  for DISTPLAN in "CAI_Gemini"; do
     for BATCH_SIZE in 16; do
       for GPUNUM in 1 2 4 8; do
         for TPDEGREE in 1 2 4 8; do
@@ -1,6 +1,6 @@
 set -x
-# distplan in ["colossalai", "zero1", "zero2", "torch_ddp", "torch_zero"]
-export DISTPLAN=${DISTPLAN:-"colossalai"}
+# distplan in ["CAI_ZeRO1", "CAI_ZeRO2", "CAI_Gemini", "Pytorch_DDP", "Pytorch_ZeRO"]
+export DISTPLAN=${DISTPLAN:-"CAI_Gemini"}
 
 # The following options only valid when DISTPLAN="colossalai"
 export GPUNUM=${GPUNUM:-1}
@@ -12,6 +12,12 @@ export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"}
 export TRAIN_STEP=${TRAIN_STEP:-10}
 # export PYTHONPATH=$PWD:$PYTHONPATH
 
+if [ ${USE_SHARD_INIT} = "True" ]; then
+  USE_SHARD_INIT="--shardinit"
+else
+  USE_SHARD_INIT=""
+fi
+
 mkdir -p gemini_logs
 
 torchrun --standalone --nproc_per_node=${GPUNUM} ./train_gpt_demo.py \
@@ -19,7 +25,7 @@ torchrun --standalone --nproc_per_node=${GPUNUM} ./train_gpt_demo.py \
 --model_type=${MODEL_TYPE} \
 --batch_size=${BATCH_SIZE} \
 --placement=${PLACEMENT} \
---shardinit=${USE_SHARD_INIT} \
+${USE_SHARD_INIT} \
 --distplan=${DISTPLAN} \
 --train_step=${TRAIN_STEP} \
 2>&1 | tee ./gemini_logs/${MODEL_TYPE}_${DISTPLAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}_${PLACEMENT}.log
@@ -12,26 +12,21 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 
 import colossalai
 from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.nn.parallel import ZeroDDP
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
 from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
 
 CAI_VERSION = colossalai.__version__
 
-if version.parse(CAI_VERSION) > version.parse("0.1.10"):
-    # These are added after 0.1.10
-    from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
-    from colossalai.nn.parallel import GeminiDDP
-    from colossalai.zero.sharded_optim import LowLevelZeroOptimizer
-
 
 def parse_args():
     parser = colossalai.get_default_parser()
     parser.add_argument(
         "--distplan",
         type=str,
-        default='colossalai',
+        default='CAI_Gemini',
         help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].",
     )
     parser.add_argument(
@@ -48,8 +43,7 @@ def parse_args():
     )
     parser.add_argument(
         "--shardinit",
-        type=bool,
-        default=False,
+        action='store_true',
         help=
         "Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
     )
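Note (not part of the diff): switching from type=bool to action='store_true' fixes a classic argparse pitfall: bool('False') is truthy, so --shardinit=False still turned shard-init on. That is also why run_gemini.sh above now injects a bare --shardinit flag (or nothing) instead of --shardinit=${USE_SHARD_INIT}. A small self-contained illustration:

import argparse

parser = argparse.ArgumentParser()

# Old style: argparse passes the raw string to bool(), and every
# non-empty string (including "False") converts to True.
parser.add_argument("--old_shardinit", type=bool, default=False)

# New style: a plain flag that is False unless it is present.
parser.add_argument("--shardinit", action='store_true')

print(parser.parse_args(["--old_shardinit", "False"]).old_shardinit)   # True, surprisingly
print(parser.parse_args([]).shardinit)                                 # False
print(parser.parse_args(["--shardinit"]).shardinit)                    # True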
@@ -186,57 +180,16 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
             param.visited = True
 
 
-# Gemini + ZeRO DDP
-def build_gemini(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto", ddp_flag: bool = True):
-    fp16_init_scale = 2**5
-    gpu_margin_mem_ratio_for_auto = 0
-
-    if version.parse(CAI_VERSION) > version.parse("0.1.10"):
-        model = GeminiDDP(model,
-                          strict_ddp_mode=ddp_flag,
-                          device=get_current_device(),
-                          placement_policy=placement_policy,
-                          pin_memory=True,
-                          hidden_dim=model.config.n_embd,
-                          search_range_mb=128)
-        # configure the const policy
-        if placement_policy == 'const':
-            model.gemini_manager._placement_policy.set_const_memory_boundary(2 * 1024)
-        # build a highly optimized cpu optimizer
-        optimizer = GeminiAdamOptimizer(model,
-                                        lr=1e-3,
-                                        initial_scale=fp16_init_scale,
-                                        gpu_margin_mem_ratio=gpu_margin_mem_ratio_for_auto)
-    elif version.parse("0.1.9") <= version.parse(CAI_VERSION) <= version.parse("0.1.10"):
-        from colossalai.gemini import ChunkManager, GeminiManager
-        from colossalai.nn.optimizer import HybridAdam
-        from colossalai.zero import ZeroOptimizer
-        chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 1024, filter_exlarge_params=True)
-        chunk_manager = ChunkManager(chunk_size,
-                                     pg,
-                                     enable_distributed_storage=True,
-                                     init_device=GeminiManager.get_default_device(placement_policy))
-        gemini_manager = GeminiManager(placement_policy, chunk_manager)
-        model = ZeroDDP(model, gemini_manager)
-        optimizer = HybridAdam(model.parameters(), lr=1e-3)
-        optimizer = ZeroOptimizer(optimizer,
-                                  model,
-                                  initial_scale=fp16_init_scale,
-                                  gpu_margin_mem_ratio=gpu_margin_mem_ratio_for_auto)
-    else:
-        raise NotImplemented(f"CAI version {CAI_VERSION} is not supported")
-    return model, optimizer
-
-
 def main():
     # version check
-    # this example is supposed to work for versions greater than 0.1.9
-    assert version.parse(CAI_VERSION) >= version.parse("0.1.9")
+    # this example is supposed to work for versions greater than 0.2.0
+    assert version.parse(CAI_VERSION) >= version.parse("0.2.0")
 
     set_cpu_maximum_parallelism()
     args = parse_args()
 
-    if args.distplan not in ["colossalai", "torch_ddp", "torch_zero", "zero1", "zero2"]:
+    # if args.distplan not in ["colossalai", "torch_ddp", "torch_zero", "zero1", "zero2"]:
+    if args.distplan not in ["CAI_ZeRO1", "CAI_ZeRO2", "CAI_Gemini", "Pytorch_DDP", "Pytorch_ZeRO"]:
         raise TypeError(f"{args.distplan} is error")
 
     # batch size per DP degree
@@ -260,22 +213,21 @@ def main():
     criterion = GPTLMLoss()
 
     torch.manual_seed(123)
-    if args.distplan == "colossalai":
+    if args.distplan.startswith("CAI"):
         # all param must use the same process group.
         world_size = torch.distributed.get_world_size()
         shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None
         default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None
 
+        if args.shardinit and args.distplan != "CAI_Gemini":
+            raise RuntimeError("You can only use shardinit with CAI_Gemini")
+
         # build GPT model
-        if version.parse(CAI_VERSION) > version.parse("0.1.10"):
-            with ColoInitContext(device=get_current_device(),
-                                 dtype=torch.half,
-                                 default_dist_spec=default_dist_spec,
-                                 default_pg=shard_pg):
-                model = model_builder(args.model_type)(checkpoint=True)
-        else:
-            with ColoInitContext(device=get_current_device()):
-                model = model_builder(args.model_type)(checkpoint=True)
+        with ColoInitContext(device=get_current_device(),
+                             dtype=torch.half,
+                             default_dist_spec=default_dist_spec,
+                             default_pg=shard_pg):
+            model = model_builder(args.model_type)(checkpoint=True)
 
         tp_pg = ProcessGroup(tp_degree=args.tp_degree)
         # Tensor Parallelism (TP)
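Aside (not part of the commit): with the version branch gone, the model is always built under ColoInitContext in fp16, and --shardinit is only allowed together with CAI_Gemini. The sketch below isolates the two init modes using the same names as the hunk; the Linear module stands in for model_builder(args.model_type)(checkpoint=True), and it assumes torch.distributed is already initialized (e.g. via colossalai.launch_from_torch).

import torch
from torch import nn

from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

shardinit = True    # corresponds to passing --shardinit
world_size = torch.distributed.get_world_size()

# With shard-init, every parameter is created already sharded along its last
# dimension across all ranks, which lowers the peak memory on any single device.
# Without it, both values stay None and parameters are replicated as usual.
shard_pg = ProcessGroup(tp_degree=world_size) if shardinit else None
default_dist_spec = ShardSpec([-1], [world_size]) if shardinit else None

with ColoInitContext(device=get_current_device(),
                     dtype=torch.half,
                     default_dist_spec=default_dist_spec,
                     default_pg=shard_pg):
    model = nn.Linear(1024, 1024)    # stand-in for the GPT model builder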
@@ -283,34 +235,49 @@
         if args.tp_degree > 1:
             tensor_parallelize(model, tp_pg)
 
-        # build a Gemini model and a highly optimized cpu optimizer
-        # Gemini + ZeRO DP, Note it must be used after TP
-        model, optimizer = build_gemini(model, tp_pg, args.placement, args.tp_degree == 1)
+        # asign running configurations
+        gemini_config = None
+        if args.distplan.startswith("CAI_ZeRO"):
+            optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True)
+        elif args.distplan == "CAI_Gemini":
+            gemini_config = dict(strict_ddp_mode=args.tp_degree == 1,
+                                 device=get_current_device(),
+                                 placement_policy=args.placement,
+                                 pin_memory=True,
+                                 hidden_dim=model.config.n_embd,
+                                 search_range_mb=128)
+            optim_config = dict(gpu_margin_mem_ratio=0.)
+        else:
+            raise RuntimeError
+
+        # build a highly optimized gpu/cpu optimizer
+        optimizer = HybridAdam(model.parameters(), lr=1e-3)
+
+        if args.distplan == "CAI_ZeRO1":
+            zero_stage = 1
+        elif args.distplan == "CAI_ZeRO2":
+            zero_stage = 2
+        elif args.distplan == "CAI_Gemini":
+            zero_stage = 3
+        else:
+            raise RuntimeError
+
+        # wrap your model and optimizer
+        model = zero_model_wrapper(model, zero_stage, gemini_config)
+        optimizer = zero_optim_wrapper(model, optimizer, optim_config=optim_config)
 
         logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
-    else:
+    elif args.distplan.startswith("Pytorch"):
         assert args.tp_degree == 1, "The degree of TP should be 1 for DDP examples."
         model = model_builder(args.model_type)(checkpoint=True).cuda()
 
-        if args.distplan.startswith("torch"):
-            model = DDP(model)
-            if args.distplan.endswith("ddp"):
-                optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
-            elif args.distplan.endswith("zero"):
-                from torch.distributed.optim import ZeroRedundancyOptimizer
-                optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=0.01)
-        elif args.distplan.startswith("zero"):
-            model = model.half()
-            partition_flag = (args.distplan == "zero2")
-            optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
-
-            optimizer = LowLevelZeroOptimizer(
-                optimizer,
-                reduce_bucket_size=12 * 1024 * 1024,
-                overlap_communication=True,
-                partition_grad=partition_flag,
-                verbose=True,
-            )
+        model = DDP(model)
+        if args.distplan.endswith("DDP"):
+            optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+        elif args.distplan.endswith("ZeRO"):
+            from torch.distributed.optim import ZeroRedundancyOptimizer
+            optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3)
+        else:
+            raise RuntimeError
 
     # model is shared after TP
     numel = get_model_size(model)
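Illustration (not part of the commit): the if/elif chain above replaces the old build_gemini() helper with generic wiring: pick a config per plan, build a HybridAdam, then hand both to the two wrappers. The helper below regroups that wiring into one function purely for illustration; it assumes a GPT-2-style module exposing config.n_embd, as in the example, and the parameter comments reflect the usual meaning of these Gemini options rather than anything stated in this commit.

from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
from colossalai.utils import get_current_device

# distplan -> ZeRO stage passed to zero_model_wrapper
ZERO_STAGE = {"CAI_ZeRO1": 1, "CAI_ZeRO2": 2, "CAI_Gemini": 3}

def wrap_for_distplan(model, distplan, placement, tp_degree):
    """Regrouped version of the wiring in this hunk (illustration only)."""
    zero_stage = ZERO_STAGE[distplan]

    if zero_stage in (1, 2):
        gemini_config = None
        # settings for the low-level ZeRO optimizer: bucketed, overlapped gradient reduction
        optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True)
    else:
        gemini_config = dict(
            strict_ddp_mode=tp_degree == 1,     # plain ZeRO-DP when there is no tensor parallelism
            device=get_current_device(),
            placement_policy=placement,         # e.g. 'cpu', 'cuda', 'auto', 'const'
            pin_memory=True,
            hidden_dim=model.config.n_embd,     # hint for the chunk-size search
            search_range_mb=128)                # width of that search window, in MB
        optim_config = dict(gpu_margin_mem_ratio=0.)

    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    model = zero_model_wrapper(model, zero_stage, gemini_config)
    optimizer = zero_optim_wrapper(model, optimizer, optim_config=optim_config)
    return model, optimizer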
@@ -338,17 +305,18 @@ def main():
         fwd_time = fwd_end - start
         logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Forward '), ranks=[0])
 
-        if args.distplan in ["colossalai", "zero1", "zero2"]:
+        if args.distplan.startswith("CAI"):
             optimizer.backward(loss)
-        elif args.distplan in ["torch_ddp", "torch_zero"]:
+        elif args.distplan.startswith("Pytorch"):
             loss.backward()
+        else:
+            raise RuntimeError
 
         torch.cuda.synchronize()
         bwd_end = time()
         bwd_time = bwd_end - fwd_end
         logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Backward '), ranks=[0])
 
-        if args.distplan in ["zero1", "zero2"]:
-            optimizer.sync_grad()
         optimizer.step()
         torch.cuda.synchronize()
         optim_time = time() - bwd_end
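Aside (not part of the commit): for every CAI_* plan the backward pass now goes through the wrapped optimizer, which handles loss scaling and gradient reduction for the wrapped model, so the separate optimizer.sync_grad() call for the ZeRO plans disappears. A minimal sketch of one training step under a CAI plan; the tensor names follow the rest of train_gpt_demo.py, which is not shown in this hunk.

# one step inside the training loop, CAI_* plans only (illustration)
optimizer.zero_grad()
outputs = model(input_ids, attn_mask)
loss = criterion(outputs, input_ids)

# the ZeRO/Gemini optimizer wrapper scales the loss and reduces gradients here,
# so optimizer.backward(loss) replaces loss.backward() plus the old sync_grad()
optimizer.backward(loss)
optimizer.step()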