mirror of https://github.com/hpcaitech/ColossalAI
[zero] add strict ddp mode (#2508)
* [zero] add strict ddp mode
* [polish] add comments for strict ddp mode
* [zero] fix test error
parent c04f183237
commit 2d1a7dfe5f
@@ -12,6 +12,7 @@ from colossalai.gemini.memory_tracer import OrderedParamGenerator
 from colossalai.logging import get_dist_logger
 from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
 from colossalai.tensor import ProcessGroup as ColoProcessGroup
+from colossalai.tensor import ReplicaSpec
 from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.utils import get_current_device, is_ddp_ignored

@@ -200,14 +201,18 @@ class ZeroDDP(ColoDDP):
         gemini_manager (GeminiManager): Manages the chunk manager and heterogeneous memory space.
             For more details, see the API reference of ``GeminiManager``.
         pin_memory (bool): Chunks on CPU Memory use pin-memory.
-        force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16. Defaults to False.
+        force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16.
+            Defaults to False.
+        strict_ddp_mode (bool): If set to True, there is no tensor sharding; each tensor is replicated.
+            Defaults to False. Users can set it to True when they clearly know that they only need DDP.
     """

     def __init__(self,
                  module: torch.nn.Module,
                  gemini_manager: GeminiManager,
                  pin_memory: bool = False,
-                 force_outputs_fp32: bool = False) -> None:
+                 force_outputs_fp32: bool = False,
+                 strict_ddp_mode: bool = False) -> None:
         super().__init__(module, process_group=ColoProcessGroup())
         self.gemini_manager = gemini_manager
         self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
@@ -232,6 +237,9 @@ class ZeroDDP(ColoDDP):
         for p in param_order.generate():
             assert isinstance(p, ColoParameter)

+            if strict_ddp_mode and not p.is_replicate():
+                p.set_dist_spec(ReplicaSpec())
+
             if is_ddp_ignored(p):
                 p.data = p.data.to(device=get_current_device(), dtype=torch.float16)
                 continue

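The replication branch above is the heart of strict ddp mode. For intuition, a small sketch of the two tensor layouts involved (ShardSpec/ReplicaSpec are the ColoTensor dist specs this codebase uses; the concrete shard numbers are illustrative):

from colossalai.tensor import ReplicaSpec, ShardSpec

# A shard spec splits a tensor across ranks: here, dim 0 over 2 ranks.
tp_spec = ShardSpec([0], [2])
# A replica spec keeps one full copy of the tensor on every rank.
dp_spec = ReplicaSpec()

# Strict ddp mode rewrites any sharded parameter back to a replica
# before chunks are built, so the chunk manager never sees a shard:
#     if strict_ddp_mode and not p.is_replicate():
#         p.set_dist_spec(ReplicaSpec())
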
@@ -17,6 +17,7 @@ class GeminiDDP(ZeroDDP):
                  placement_policy: str = "cpu",
                  pin_memory: bool = False,
                  force_outputs_fp32: bool = False,
+                 strict_ddp_mode: bool = False,
                  search_range_mb: int = 32,
                  hidden_dim: Optional[int] = None,
                  min_chunk_size_mb: Optional[float] = None,
@@ -54,4 +55,4 @@ class GeminiDDP(ZeroDDP):
                                              search_range_mb=search_range_mb,
                                              min_chunk_size_mb=min_chunk_size_mb)
         gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
-        super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32)
+        super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode)

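For reference, a minimal caller-side sketch of the new flag (the wrapped module and import path are illustrative, not part of this commit; the keyword is the one threaded through above, and the device keyword matches the GeminiDDP call in the benchmark hunk below):

import torch
from colossalai.nn.parallel import GeminiDDP
from colossalai.utils import get_current_device

model = torch.nn.Linear(1024, 1024)
# strict_ddp_mode=True: no tensor sharding, every tensor replicated,
# i.e. plain DDP semantics on top of the Gemini machinery.
model = GeminiDDP(model,
                  device=get_current_device(),
                  placement_policy="cpu",
                  pin_memory=True,
                  strict_ddp_mode=True)
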
@@ -53,6 +53,14 @@ def gpt2_24b(checkpoint=True):
     return GPTLMModel(hidden_size=8192, num_layers=30, num_attention_heads=16, checkpoint=checkpoint)


+def gpt2_30b(checkpoint=True):
+    return GPTLMModel(hidden_size=8192, num_layers=37, num_attention_heads=16, checkpoint=checkpoint)
+
+
+def gpt2_40b(checkpoint=True):
+    return GPTLMModel(hidden_size=8192, num_layers=50, num_attention_heads=16, checkpoint=checkpoint)
+
+
 def model_builder(model_size: str) -> callable:
     if model_size == "gpt2_medium":
         return gpt2_medium
@@ -66,6 +74,10 @@ def model_builder(model_size: str) -> callable:
         return gpt2_20b
     elif model_size == "gpt2_24b":
         return gpt2_24b
+    elif model_size == "gpt2_30b":
+        return gpt2_30b
+    elif model_size == "gpt2_40b":
+        return gpt2_40b
     else:
         raise TypeError(f"model_builder {model_size}")

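The builder is a plain string-to-constructor dispatch; a hypothetical lookup of one of the new sizes (names and configurations taken from the branches above):

# Picks the 30B configuration added in this commit:
# hidden_size=8192, 37 layers, activation checkpointing on.
build_fn = model_builder("gpt2_30b")
model = build_fn(checkpoint=True)
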
@@ -187,17 +187,18 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):


 # Gemini + ZeRO DDP
-def build_gemini(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"):
+def build_gemini(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto", ddp_flag: bool = True):
     fp16_init_scale = 2**5
     gpu_margin_mem_ratio_for_auto = 0

     if version.parse(CAI_VERSION) > version.parse("0.1.10"):
         model = GeminiDDP(model,
+                          strict_ddp_mode=ddp_flag,
                           device=get_current_device(),
                           placement_policy=placement_policy,
                           pin_memory=True,
                           hidden_dim=model.config.n_embd,
-                          search_range_mb=64)
+                          search_range_mb=128)
         # configure the const policy
         if placement_policy == 'const':
             model.gemini_manager._placement_policy.set_const_memory_boundary(2 * 1024)

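Note that ddp_flag defaults to True, yet strict replication contradicts tensor-parallel sharding: forcing ReplicaSpec on TP-sharded weights would undo the layout that tensor_parallelize set up. A defensive wrapper one could add (the helper and its assertion are hypothetical, not part of this commit; tp_world_size() is assumed to be the ProcessGroup query):

def build_gemini_checked(model, pg, placement_policy="auto", ddp_flag=True):
    # Refuse the strict-DDP + TP combination outright rather than
    # silently re-replicating sharded parameters.
    assert not (ddp_flag and pg.tp_world_size() > 1), \
        "strict ddp mode requires tp_degree == 1"
    return build_gemini(model, pg, placement_policy, ddp_flag)
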
@@ -279,11 +280,12 @@ def main():
         tp_pg = ProcessGroup(tp_degree=args.tp_degree)
         # Tensor Parallelism (TP)
         # You should notice that v0.1.10 is not compatible with TP degree > 1
-        tensor_parallelize(model, tp_pg)
+        if args.tp_degree > 1:
+            tensor_parallelize(model, tp_pg)

         # build a Gemini model and a highly optimized cpu optimizer
         # Gemini + ZeRO DP, Note it must be used after TP
-        model, optimizer = build_gemini(model, tp_pg, args.placement)
+        model, optimizer = build_gemini(model, tp_pg, args.placement, args.tp_degree == 1)

         logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
     else:

@@ -93,7 +93,7 @@ def run_gpt(placement_policy, tp_init_spec_func=None):
     else:
         init_device = None

-    model = GeminiDDP(model, init_device, placement_policy, True, False, 32)
+    model = GeminiDDP(model, init_device, placement_policy, True, False)
     # The same as the following 3 lines
     # chunk_manager = ChunkManager(config_dict, init_device=init_device)
     # gemini_manager = GeminiManager(placement_policy, chunk_manager)
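The test fix follows directly from the signature change: the trailing positional 32 used to bind to search_range_mb, but with strict_ddp_mode inserted ahead of it, 32 would now land in the new boolean slot. Dropping it works because search_range_mb already defaults to 32; passing keywords avoids the trap entirely (a sketch using only parameters shown in the hunks above):

model = GeminiDDP(model,
                  device=init_device,
                  placement_policy=placement_policy,
                  pin_memory=True,
                  force_outputs_fp32=False,
                  search_range_mb=32)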