[example] fix gpt example with 0.1.10 (#2265)

HELSON 2023-01-03 13:38:14 +08:00 committed by GitHub
parent 89f048a88a
commit 09c0102fe6
2 changed files with 53 additions and 28 deletions
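
The whole fix hinges on gating code paths by the installed ColossalAI release. A minimal sketch of that gate, using only the version numbers that appear in this diff (standalone, assuming colossalai and packaging are installed):

import colossalai
from packaging import version

CAI_VERSION = colossalai.__version__

# the updated example targets 0.1.9 <= version < 0.2.0
assert version.parse("0.1.9") <= version.parse(CAI_VERSION) < version.parse("0.2.0")

if version.parse(CAI_VERSION) > version.parse("0.1.10"):
    # releases after 0.1.10 ship GeminiDDP and GeminiAdamOptimizer
    path = "GeminiDDP + GeminiAdamOptimizer"
else:
    # 0.1.9 and 0.1.10 fall back to ChunkManager + ZeroDDP + HybridAdam/ZeroOptimizer
    path = "ChunkManager + ZeroDDP + ZeroOptimizer"
print(f"colossalai {CAI_VERSION}: using the {path} code path")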


@@ -4,7 +4,7 @@ export DISTPAN=${DISTPAN:-"colossalai"}
 # The following options only valid when DISTPAN="colossalai"
 export GPUNUM=${GPUNUM:-1}
 export TPDEGREE=${TPDEGREE:-1}
-export PLACEMENT=${PLACEMENT:-"const"}
+export PLACEMENT=${PLACEMENT:-"cpu"}
 export USE_SHARD_INIT=${USE_SHARD_INIT:-False}
 export BATCH_SIZE=${BATCH_SIZE:-16}
 export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"}

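The only change in this file is the default placement policy, which flips from "const" to "cpu". For orientation, a hedged reference of the three policies that appear in this commit; the one-line descriptions paraphrase ColossalAI's Gemini documentation and are not part of the diff:

# hedged reference only; the keys come from this commit, the descriptions from Gemini docs
PLACEMENT_POLICIES = {
    "cpu": "keep parameter chunks in CPU memory and move them to GPU on demand",            # new default
    "const": "cap non-model GPU memory at a fixed boundary (the trainer sets 2 * 1024 MB)",  # old default
    "auto": "shuttle chunks between CPU and GPU based on runtime memory statistics",         # build_gemini default
}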

@@ -5,18 +5,24 @@ from time import time
 import psutil
 import torch
 import torch.nn as nn
+from model_zoo import model_builder
 from packaging import version
 from torch.nn.parallel import DistributedDataParallel as DDP

 import colossalai
 from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
 from colossalai.nn.parallel import ZeroDDP
 from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.zero.sharded_optim import LowLevelZeroOptimizer
-from model_zoo import model_builder
+
+CAI_VERSION = colossalai.__version__
+
+if version.parse(CAI_VERSION) > version.parse("0.1.10"):
+    # These are added after 0.1.10
+    from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
+    from colossalai.nn.parallel import GeminiDDP
+    from colossalai.zero.sharded_optim import LowLevelZeroOptimizer


 def parse_args():
@@ -62,7 +68,7 @@ def parse_args():
     return args


-## Parameter Sharding Strategies for Tensor Parallelism
+# Parameter Sharding Strategies for Tensor Parallelism
 def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
     spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     param.set_tensor_spec(*spec)
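
split_param_single_dim_tp1d above only tags a parameter with a 1D shard spec; the row/column wrappers that tensor_parallelize applies are outside this hunk. A hedged sketch of how such wrappers are typically written around it (the wrapper names are illustrative and may not match the file exactly):

# illustrative wrappers around the helper shown above
def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
    # shard along dim 0 (rows) across the tensor-parallel group
    split_param_single_dim_tp1d(0, param, pg)


def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
    # shard along the last dim (columns) across the tensor-parallel group
    split_param_single_dim_tp1d(-1, param, pg)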
@@ -179,34 +185,52 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):

 # Gemini + ZeRO DDP
-def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
-    cai_version = colossalai.__version__
-    from colossalai.gemini import ChunkManager, GeminiManager
-    if version.parse(cai_version) > version.parse("0.1.10"):
-        from colossalai.nn.parallel import GeminiDDP
+def build_gemini(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"):
+    fp16_init_scale = 2**5
+    gpu_margin_mem_ratio_for_auto = 0
+
+    if version.parse(CAI_VERSION) > version.parse("0.1.10"):
         model = GeminiDDP(model,
                           device=get_current_device(),
-                          placement_policy=placememt_policy,
+                          placement_policy=placement_policy,
                           pin_memory=True,
                           hidden_dim=model.config.n_embd,
                           search_range_mb=64)
-        if placememt_policy == 'const':
+        # configure the const policy
+        if placement_policy == 'const':
             model.gemini_manager._placement_policy.set_const_memory_boundary(2 * 1024)
-    elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
+        # build a highly optimized cpu optimizer
+        optimizer = GeminiAdamOptimizer(model,
+                                        lr=1e-3,
+                                        initial_scale=fp16_init_scale,
+                                        gpu_margin_mem_ratio=gpu_margin_mem_ratio_for_auto)
+    elif version.parse("0.1.9") <= version.parse(CAI_VERSION) <= version.parse("0.1.10"):
         from colossalai.gemini import ChunkManager, GeminiManager
-        chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
-        gemini_manager = GeminiManager(placememt_policy, chunk_manager)
+        from colossalai.nn.optimizer import HybridAdam
+        from colossalai.zero import ZeroOptimizer
+        chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 1024, filter_exlarge_params=True)
         chunk_manager = ChunkManager(chunk_size,
                                      pg,
                                      enable_distributed_storage=True,
-                                     init_device=GeminiManager.get_default_device(placememt_policy))
+                                     init_device=GeminiManager.get_default_device(placement_policy))
+        gemini_manager = GeminiManager(placement_policy, chunk_manager)
         model = ZeroDDP(model, gemini_manager)
+        optimizer = HybridAdam(model.parameters(), lr=1e-3)
+        optimizer = ZeroOptimizer(optimizer,
+                                  model,
+                                  initial_scale=fp16_init_scale,
+                                  gpu_margin_mem_ratio=gpu_margin_mem_ratio_for_auto)
     else:
-        raise NotImplemented(f"CAI version {cai_version} is not supported")
-    return model
+        raise NotImplemented(f"CAI version {CAI_VERSION} is not supported")
+    return model, optimizer


 def main():
+    # version check
+    # this example is supposed to work for versions less than 0.2.0 but greater than 0.1.9
+    assert version.parse(CAI_VERSION) < version.parse("0.2.0")
+    assert version.parse(CAI_VERSION) >= version.parse("0.1.9")
+
     set_cpu_maximum_parallelism()
     args = parse_args()
@@ -239,21 +263,24 @@ def main():
         default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None

         # build GPT model
-        with ColoInitContext(device=get_current_device(),
-                             dtype=torch.half,
-                             default_dist_spec=default_dist_spec,
-                             default_pg=default_pg):
-            model = model_builder(args.model_type)(checkpoint=True)
+        if version.parse(CAI_VERSION) > version.parse("0.1.10"):
+            with ColoInitContext(device=get_current_device(),
+                                 dtype=torch.half,
+                                 default_dist_spec=default_dist_spec,
+                                 default_pg=default_pg):
+                model = model_builder(args.model_type)(checkpoint=True)
+        else:
+            with ColoInitContext(device=get_current_device()):
+                model = model_builder(args.model_type)(checkpoint=True)

         pg = default_pg
         # Tensor Parallelism (TP)
         tensor_parallelize(model, pg)

+        # build a Gemini model and a highly optimized cpu optimizer
         # Gemini + ZeRO DP, Note it must be used after TP
-        model = gemini_zero_dpp(model, pg, args.placement)
-
-        # build highly optimized cpu optimizer
-        optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5, gpu_margin_mem_ratio=0.6)
+        model, optimizer = build_gemini(model, pg, args.placement)
         logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
     else:
         model = model_builder(args.model_type)(checkpoint=True).cuda()
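
build_gemini now returns the optimizer together with the wrapped model, so main() no longer builds GeminiAdamOptimizer itself. The training loop is outside this diff; a hedged sketch of how the returned pair is driven, assuming the usual ColossalAI ZeRO optimizer interface and placeholder batch/loss helpers:

# hedged sketch; input_ids, attn_mask and criterion are placeholders, not taken from this diff
model, optimizer = build_gemini(model, pg, args.placement)

optimizer.zero_grad()
logits = model(input_ids, attn_mask)    # forward signature assumed from the GPT example
loss = criterion(logits, input_ids)     # language-modeling loss helper, assumed
optimizer.backward(loss)                # ZeRO/Gemini optimizers own the fp16 backward and loss scaling
optimizer.step()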
@@ -324,8 +351,6 @@ def main():
         if n >= WARMUP_STEPS:
             tflops_list.append(step_tflops)

-    logger.info(f"max memory {torch.cuda.max_memory_allocated() / 1024**2} MB", ranks=[0])
-
     tflops_list.sort()
     median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS
     logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")