[example] fix gpt example with 0.1.10 (#2265)

HELSON authored on 2023-01-03 13:38:14 +08:00, committed by GitHub
parent 89f048a88a
commit 09c0102fe6
2 changed files with 53 additions and 28 deletions


@@ -4,7 +4,7 @@ export DISTPAN=${DISTPAN:-"colossalai"}
 # The following options only valid when DISTPAN="colossalai"
 export GPUNUM=${GPUNUM:-1}
 export TPDEGREE=${TPDEGREE:-1}
-export PLACEMENT=${PLACEMENT:-"const"}
+export PLACEMENT=${PLACEMENT:-"cpu"}
 export USE_SHARD_INIT=${USE_SHARD_INIT:-False}
 export BATCH_SIZE=${BATCH_SIZE:-16}
 export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"}
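Every knob in this script uses the `${VAR:-default}` form, so each one can be overridden from the environment without editing the file, e.g. `GPUNUM=4 TPDEGREE=2 PLACEMENT=const bash run.sh` (the file name `run.sh` is an assumption here; substitute the example's actual launch script). The hunk above only changes the default placement policy from "const" to "cpu".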


@@ -5,18 +5,24 @@ from time import time
 import psutil
 import torch
 import torch.nn as nn
+from model_zoo import model_builder
 from packaging import version
 from torch.nn.parallel import DistributedDataParallel as DDP

 import colossalai
 from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
 from colossalai.nn.parallel import ZeroDDP
 from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.zero.sharded_optim import LowLevelZeroOptimizer
-from model_zoo import model_builder
+
+CAI_VERSION = colossalai.__version__
+
+if version.parse(CAI_VERSION) > version.parse("0.1.10"):
+    # These are added after 0.1.10
+    from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
+    from colossalai.nn.parallel import GeminiDDP
+    from colossalai.zero.sharded_optim import LowLevelZeroOptimizer


 def parse_args():
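The `version.parse` guard matters because release segments compare numerically, not lexicographically; a plain string comparison would rank "0.1.10" below "0.1.9". A minimal sketch of the check used above:

    from packaging import version

    # numeric comparison: 10 > 9, so 0.1.10 is the newer release
    assert version.parse("0.1.10") > version.parse("0.1.9")
    # a raw string comparison gets this wrong ('1' < '9' at the fifth character)
    assert not ("0.1.10" > "0.1.9")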
@@ -62,7 +68,7 @@ def parse_args():
     return args


-## Parameter Sharding Strategies for Tensor Parallelism
+# Parameter Sharding Strategies for Tensor Parallelism
 def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
     spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     param.set_tensor_spec(*spec)
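`ShardSpec([dim], [pg.tp_world_size()])` declares that the parameter is split along a single dimension, with each rank of the tensor-parallel group holding one slice. A plain-PyTorch illustration of what that layout means (illustration only, not the ColossalAI API):

    import torch

    tp_world_size = 2                            # assumed tensor-parallel group size
    full_weight = torch.randn(8, 4)              # the logical parameter
    shards = torch.chunk(full_weight, tp_world_size, dim=0)
    assert shards[0].shape == (4, 4)             # each rank stores one slice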
@@ -179,34 +185,52 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):


 # Gemini + ZeRO DDP
-def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
-    cai_version = colossalai.__version__
-    from colossalai.gemini import ChunkManager, GeminiManager
-    if version.parse(cai_version) > version.parse("0.1.10"):
-        from colossalai.nn.parallel import GeminiDDP
+def build_gemini(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"):
+    fp16_init_scale = 2**5
+    gpu_margin_mem_ratio_for_auto = 0
+
+    if version.parse(CAI_VERSION) > version.parse("0.1.10"):
         model = GeminiDDP(model,
                           device=get_current_device(),
-                          placement_policy=placememt_policy,
+                          placement_policy=placement_policy,
                           pin_memory=True,
                           hidden_dim=model.config.n_embd,
                           search_range_mb=64)
-        if placememt_policy == 'const':
+        # configure the const policy
+        if placement_policy == 'const':
             model.gemini_manager._placement_policy.set_const_memory_boundary(2 * 1024)
-    elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
-        chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
-        gemini_manager = GeminiManager(placememt_policy, chunk_manager)
+        # build a highly optimized cpu optimizer
+        optimizer = GeminiAdamOptimizer(model,
+                                        lr=1e-3,
+                                        initial_scale=fp16_init_scale,
+                                        gpu_margin_mem_ratio=gpu_margin_mem_ratio_for_auto)
+    elif version.parse("0.1.9") <= version.parse(CAI_VERSION) <= version.parse("0.1.10"):
+        from colossalai.gemini import ChunkManager, GeminiManager
+        from colossalai.nn.optimizer import HybridAdam
+        from colossalai.zero import ZeroOptimizer
+        chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 1024, filter_exlarge_params=True)
         chunk_manager = ChunkManager(chunk_size,
                                      pg,
                                      enable_distributed_storage=True,
-                                     init_device=GeminiManager.get_default_device(placememt_policy))
+                                     init_device=GeminiManager.get_default_device(placement_policy))
+        gemini_manager = GeminiManager(placement_policy, chunk_manager)
         model = ZeroDDP(model, gemini_manager)
+        optimizer = HybridAdam(model.parameters(), lr=1e-3)
+        optimizer = ZeroOptimizer(optimizer,
+                                  model,
+                                  initial_scale=fp16_init_scale,
+                                  gpu_margin_mem_ratio=gpu_margin_mem_ratio_for_auto)
     else:
-        raise NotImplemented(f"CAI version {cai_version} is not supported")
-    return model
+        raise NotImplementedError(f"CAI version {CAI_VERSION} is not supported")
+    return model, optimizer


 def main():
+    # version check
+    # this example is supposed to work for versions no earlier than 0.1.9 and earlier than 0.2.0
+    assert version.parse(CAI_VERSION) < version.parse("0.2.0")
+    assert version.parse(CAI_VERSION) >= version.parse("0.1.9")
+
     set_cpu_maximum_parallelism()
     args = parse_args()
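Both branches of `build_gemini` now return a (model, optimizer) pair with fp16 loss scaling already wired in, so the training loop drives backward through the optimizer rather than through `loss.backward()`. A hypothetical step, with `input_ids`, `attn_mask` and `criterion` standing in for the demo's real data pipeline:

    optimizer.zero_grad()
    outputs = model(input_ids, attn_mask)
    loss = criterion(outputs, input_ids)
    optimizer.backward(loss)    # the ZeRO/Gemini optimizer applies fp16 loss scaling
    optimizer.step()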
@@ -239,21 +263,24 @@ def main():
     default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None

     # build GPT model
-    with ColoInitContext(device=get_current_device(),
-                         dtype=torch.half,
-                         default_dist_spec=default_dist_spec,
-                         default_pg=default_pg):
-        model = model_builder(args.model_type)(checkpoint=True)
+    if version.parse(CAI_VERSION) > version.parse("0.1.10"):
+        with ColoInitContext(device=get_current_device(),
+                             dtype=torch.half,
+                             default_dist_spec=default_dist_spec,
+                             default_pg=default_pg):
+            model = model_builder(args.model_type)(checkpoint=True)
+    else:
+        with ColoInitContext(device=get_current_device()):
+            model = model_builder(args.model_type)(checkpoint=True)

     pg = default_pg
     # Tensor Parallelism (TP)
     tensor_parallelize(model, pg)

+    # build a Gemini model and a highly optimized cpu optimizer
     # Gemini + ZeRO DP, Note it must be used after TP
-    model = gemini_zero_dpp(model, pg, args.placement)
-
-    # build highly optimized cpu optimizer
-    optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5, gpu_margin_mem_ratio=0.6)
+    model, optimizer = build_gemini(model, pg, args.placement)

     logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
 else:
     model = model_builder(args.model_type)(checkpoint=True).cuda()
@@ -324,8 +351,6 @@ def main():
         if n >= WARMUP_STEPS:
             tflops_list.append(step_tflops)

     logger.info(f"max memory {torch.cuda.max_memory_allocated() / 1024**2} MB", ranks=[0])
     tflops_list.sort()
     median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS
     logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")