[gemini] update the gpt example (#2527)

Branch: pull/2532/head
Author: HELSON, 2023-01-30 17:58:05 +08:00 (committed by GitHub)
Parent: ecbad93b65
Commit: 66dfcf5281
4 changed files with 75 additions and 98 deletions

File 1 of 4

@@ -32,16 +32,19 @@ def zero_model_wrapper(model: nn.Module, zero_stage: int = 1, gemini_config: Opt
     >>> config_dict = dict(device=torch.cuda.current_device(), hidden_dim=1024, placement_policy='auto')
     >>> model = zero_model_wrapper(model, zero_stage=3, gemini_config=config_dict)
     """
-    setattr(model, "_colo_zero_stage", zero_stage)
     assert zero_stage in [1, 2, 3], "The stage of ZeRO should be 1, 2 or 3"
     if gemini_config is None:
         gemini_config = dict()
     if zero_stage in [1, 2]:
-        return model
+        wrapped_model = model
     else:
-        return GeminiDDP(model, **gemini_config)
+        wrapped_model = GeminiDDP(model, **gemini_config)
+
+    setattr(wrapped_model, "_colo_zero_stage", zero_stage)
+
+    return wrapped_model
 
 
 def zero_optim_wrapper(model: nn.Module,
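
Not part of the commit: the sketch below shows how the two wrappers are meant to be combined after this fix, mirroring the updated training script further down. It assumes ColossalAI >= 0.2.0 and an already-initialized distributed environment; wrap_for_zero is a made-up helper name, and the config values are the ones the demo itself uses. The point of the fix is that _colo_zero_stage is now set on whatever object is actually returned (the GeminiDDP wrapper for stage 3), which is presumably what zero_optim_wrapper inspects when choosing the matching optimizer wrapper.

import torch.nn as nn

from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
from colossalai.utils import get_current_device


def wrap_for_zero(model: nn.Module, zero_stage: int = 3):
    # Stage 3 is the only stage that consumes gemini_config (it routes through GeminiDDP);
    # stages 1 and 2 hand the module back as-is, now tagged with _colo_zero_stage.
    if zero_stage == 3:
        gemini_config = dict(device=get_current_device(), placement_policy='auto', pin_memory=True)
        optim_config = dict(gpu_margin_mem_ratio=0.)
    else:
        gemini_config = None
        optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True)

    # Same order as the updated demo: build the optimizer on the raw parameters,
    # wrap the model, then wrap the optimizer around the wrapped model.
    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    model = zero_model_wrapper(model, zero_stage, gemini_config)
    optimizer = zero_optim_wrapper(model, optimizer, optim_config=optim_config)
    return model, optimizer
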

File 2 of 4

@@ -1,5 +1,5 @@
 for MODEL_TYPE in "gpt2_medium"; do
-  for DISTPLAN in "colossalai"; do
+  for DISTPLAN in "CAI_Gemini"; do
     for BATCH_SIZE in 16; do
       for GPUNUM in 1 2 4 8; do
         for TPDEGREE in 1 2 4 8; do

File 3 of 4

@@ -1,6 +1,6 @@
 set -x
-# distplan in ["colossalai", "zero1", "zero2", "torch_ddp", "torch_zero"]
-export DISTPLAN=${DISTPLAN:-"colossalai"}
+# distplan in ["CAI_ZeRO1", "CAI_ZeRO2", "CAI_Gemini", "Pytorch_DDP", "Pytorch_ZeRO"]
+export DISTPLAN=${DISTPLAN:-"CAI_Gemini"}
 
 # The following options only valid when DISTPLAN="colossalai"
 export GPUNUM=${GPUNUM:-1}
@@ -12,6 +12,12 @@ export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"}
 export TRAIN_STEP=${TRAIN_STEP:-10}
 # export PYTHONPATH=$PWD:$PYTHONPATH
 
+if [ ${USE_SHARD_INIT} = "True" ]; then
+  USE_SHARD_INIT="--shardinit"
+else
+  USE_SHARD_INIT=""
+fi
+
 mkdir -p gemini_logs
 
 torchrun --standalone --nproc_per_node=${GPUNUM} ./train_gpt_demo.py \
@@ -19,7 +25,7 @@ torchrun --standalone --nproc_per_node=${GPUNUM} ./train_gpt_demo.py \
 --model_type=${MODEL_TYPE} \
 --batch_size=${BATCH_SIZE} \
 --placement=${PLACEMENT} \
---shardinit=${USE_SHARD_INIT} \
+${USE_SHARD_INIT} \
 --distplan=${DISTPLAN} \
 --train_step=${TRAIN_STEP} \
 2>&1 | tee ./gemini_logs/${MODEL_TYPE}_${DISTPLAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}_${PLACEMENT}.log

File 4 of 4

@@ -12,26 +12,21 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 import colossalai
 from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.nn.parallel import ZeroDDP
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
 from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
 
 CAI_VERSION = colossalai.__version__
 
-if version.parse(CAI_VERSION) > version.parse("0.1.10"):
-    # These are added after 0.1.10
-    from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
-    from colossalai.nn.parallel import GeminiDDP
-
-from colossalai.zero.sharded_optim import LowLevelZeroOptimizer
 
 
 def parse_args():
     parser = colossalai.get_default_parser()
     parser.add_argument(
         "--distplan",
         type=str,
-        default='colossalai',
+        default='CAI_Gemini',
         help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].",
     )
     parser.add_argument(
@@ -48,8 +43,7 @@ def parse_args():
     )
     parser.add_argument(
         "--shardinit",
-        type=bool,
-        default=False,
+        action='store_true',
         help=
         "Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
     )
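
The --shardinit change is behavioral, not cosmetic: with type=bool, argparse feeds the raw string through bool(), and bool("False") is True, so --shardinit False still enabled sharded init. action='store_true' turns it into a real on/off flag. A small self-contained illustration (not from the commit):

import argparse

# Old behaviour: type=bool converts the string via bool(), and bool("False") is True.
broken = argparse.ArgumentParser()
broken.add_argument("--shardinit", type=bool, default=False)
assert broken.parse_args(["--shardinit", "False"]).shardinit is True

# New behaviour: a flag that is False when omitted and True when present.
fixed = argparse.ArgumentParser()
fixed.add_argument("--shardinit", action='store_true')
assert fixed.parse_args([]).shardinit is False
assert fixed.parse_args(["--shardinit"]).shardinit is True

This is also why the run script above now converts USE_SHARD_INIT into either the literal --shardinit flag or an empty string, instead of passing --shardinit=${USE_SHARD_INIT}.
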
@@ -186,57 +180,16 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
             param.visited = True
 
 
-# Gemini + ZeRO DDP
-def build_gemini(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto", ddp_flag: bool = True):
-    fp16_init_scale = 2**5
-    gpu_margin_mem_ratio_for_auto = 0
-
-    if version.parse(CAI_VERSION) > version.parse("0.1.10"):
-        model = GeminiDDP(model,
-                          strict_ddp_mode=ddp_flag,
-                          device=get_current_device(),
-                          placement_policy=placement_policy,
-                          pin_memory=True,
-                          hidden_dim=model.config.n_embd,
-                          search_range_mb=128)
-        # configure the const policy
-        if placement_policy == 'const':
-            model.gemini_manager._placement_policy.set_const_memory_boundary(2 * 1024)
-        # build a highly optimized cpu optimizer
-        optimizer = GeminiAdamOptimizer(model,
-                                        lr=1e-3,
-                                        initial_scale=fp16_init_scale,
-                                        gpu_margin_mem_ratio=gpu_margin_mem_ratio_for_auto)
-    elif version.parse("0.1.9") <= version.parse(CAI_VERSION) <= version.parse("0.1.10"):
-        from colossalai.gemini import ChunkManager, GeminiManager
-        from colossalai.nn.optimizer import HybridAdam
-        from colossalai.zero import ZeroOptimizer
-        chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 1024, filter_exlarge_params=True)
-        chunk_manager = ChunkManager(chunk_size,
-                                     pg,
-                                     enable_distributed_storage=True,
-                                     init_device=GeminiManager.get_default_device(placement_policy))
-        gemini_manager = GeminiManager(placement_policy, chunk_manager)
-        model = ZeroDDP(model, gemini_manager)
-        optimizer = HybridAdam(model.parameters(), lr=1e-3)
-        optimizer = ZeroOptimizer(optimizer,
-                                  model,
-                                  initial_scale=fp16_init_scale,
-                                  gpu_margin_mem_ratio=gpu_margin_mem_ratio_for_auto)
-    else:
-        raise NotImplemented(f"CAI version {CAI_VERSION} is not supported")
-    return model, optimizer
-
-
 def main():
     # version check
-    # this example is supposed to work for versions greater than 0.1.9
-    assert version.parse(CAI_VERSION) >= version.parse("0.1.9")
+    # this example is supposed to work for versions greater than 0.2.0
+    assert version.parse(CAI_VERSION) >= version.parse("0.2.0")
 
     set_cpu_maximum_parallelism()
     args = parse_args()
 
-    if args.distplan not in ["colossalai", "torch_ddp", "torch_zero", "zero1", "zero2"]:
+    # if args.distplan not in ["colossalai", "torch_ddp", "torch_zero", "zero1", "zero2"]:
+    if args.distplan not in ["CAI_ZeRO1", "CAI_ZeRO2", "CAI_Gemini", "Pytorch_DDP", "Pytorch_ZeRO"]:
         raise TypeError(f"{args.distplan} is error")
 
     # batch size per DP degree
@@ -260,22 +213,21 @@ def main():
     criterion = GPTLMLoss()
 
     torch.manual_seed(123)
-    if args.distplan == "colossalai":
+    if args.distplan.startswith("CAI"):
         # all param must use the same process group.
         world_size = torch.distributed.get_world_size()
         shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None
         default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None
 
+        if args.shardinit and args.distplan != "CAI_Gemini":
+            raise RuntimeError("You can only use shardinit with CAI_Gemini")
+
         # build GPT model
-        if version.parse(CAI_VERSION) > version.parse("0.1.10"):
-            with ColoInitContext(device=get_current_device(),
-                                 dtype=torch.half,
-                                 default_dist_spec=default_dist_spec,
-                                 default_pg=shard_pg):
-                model = model_builder(args.model_type)(checkpoint=True)
-        else:
-            with ColoInitContext(device=get_current_device()):
-                model = model_builder(args.model_type)(checkpoint=True)
+        with ColoInitContext(device=get_current_device(),
+                             dtype=torch.half,
+                             default_dist_spec=default_dist_spec,
+                             default_pg=shard_pg):
+            model = model_builder(args.model_type)(checkpoint=True)
 
         tp_pg = ProcessGroup(tp_degree=args.tp_degree)
         # Tensor Parallelism (TP)
@@ -283,34 +235,49 @@
         if args.tp_degree > 1:
             tensor_parallelize(model, tp_pg)
 
-        # build a Gemini model and a highly optimized cpu optimizer
-        # Gemini + ZeRO DP, Note it must be used after TP
-        model, optimizer = build_gemini(model, tp_pg, args.placement, args.tp_degree == 1)
+        # asign running configurations
+        gemini_config = None
+        if args.distplan.startswith("CAI_ZeRO"):
+            optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True)
+        elif args.distplan == "CAI_Gemini":
+            gemini_config = dict(strict_ddp_mode=args.tp_degree == 1,
+                                 device=get_current_device(),
+                                 placement_policy=args.placement,
+                                 pin_memory=True,
+                                 hidden_dim=model.config.n_embd,
+                                 search_range_mb=128)
+            optim_config = dict(gpu_margin_mem_ratio=0.)
+        else:
+            raise RuntimeError
+
+        # build a highly optimized gpu/cpu optimizer
+        optimizer = HybridAdam(model.parameters(), lr=1e-3)
+
+        if args.distplan == "CAI_ZeRO1":
+            zero_stage = 1
+        elif args.distplan == "CAI_ZeRO2":
+            zero_stage = 2
+        elif args.distplan == "CAI_Gemini":
+            zero_stage = 3
+        else:
+            raise RuntimeError
+
+        # wrap your model and optimizer
+        model = zero_model_wrapper(model, zero_stage, gemini_config)
+        optimizer = zero_optim_wrapper(model, optimizer, optim_config=optim_config)
 
         logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
-    else:
+    elif args.distplan.startswith("Pytorch"):
         assert args.tp_degree == 1, "The degree of TP should be 1 for DDP examples."
         model = model_builder(args.model_type)(checkpoint=True).cuda()
-
-    if args.distplan.startswith("torch"):
         model = DDP(model)
-        if args.distplan.endswith("ddp"):
-            optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
-        elif args.distplan.endswith("zero"):
+        if args.distplan.endswith("DDP"):
+            optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+        elif args.distplan.endswith("ZeRO"):
             from torch.distributed.optim import ZeroRedundancyOptimizer
-            optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=0.01)
-    elif args.distplan.startswith("zero"):
-        model = model.half()
-        partition_flag = (args.distplan == "zero2")
-        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
-        optimizer = LowLevelZeroOptimizer(
-            optimizer,
-            reduce_bucket_size=12 * 1024 * 1024,
-            overlap_communication=True,
-            partition_grad=partition_flag,
-            verbose=True,
-        )
+            optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3)
+    else:
+        raise RuntimeError
 
     # model is shared after TP
     numel = get_model_size(model)
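
For readers skimming the renamed plans: the plan name now encodes which ZeRO stage gets passed to zero_model_wrapper, with CAI_Gemini being the chunk-based stage-3 path wrapped in GeminiDDP. A tiny self-contained sketch of that mapping (the helper itself is hypothetical; the names come from the run-script comment):

# Hypothetical helper, not in the demo: plan name -> ZeRO stage for zero_model_wrapper.
PLAN_TO_STAGE = {"CAI_ZeRO1": 1, "CAI_ZeRO2": 2, "CAI_Gemini": 3}


def zero_stage_for(distplan: str) -> int:
    if distplan not in PLAN_TO_STAGE:
        # Pytorch_DDP / Pytorch_ZeRO bypass the ColossalAI wrappers entirely.
        raise RuntimeError(f"{distplan} does not use the ColossalAI ZeRO wrappers")
    return PLAN_TO_STAGE[distplan]


assert zero_stage_for("CAI_Gemini") == 3
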
@@ -338,17 +305,18 @@ def main():
         fwd_time = fwd_end - start
         logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Forward '), ranks=[0])
 
-        if args.distplan in ["colossalai", "zero1", "zero2"]:
+        if args.distplan.startswith("CAI"):
             optimizer.backward(loss)
-        elif args.distplan in ["torch_ddp", "torch_zero"]:
+        elif args.distplan.startswith("Pytorch"):
             loss.backward()
+        else:
+            raise RuntimeError
 
         torch.cuda.synchronize()
         bwd_end = time()
         bwd_time = bwd_end - fwd_end
         logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Backward '), ranks=[0])
 
-        if args.distplan in ["zero1", "zero2"]:
-            optimizer.sync_grad()
         optimizer.step()
         torch.cuda.synchronize()
         optim_time = time() - bwd_end