diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py
index 28a10c4b6..a742946f4 100644
--- a/colossalai/nn/parallel/data_parallel.py
+++ b/colossalai/nn/parallel/data_parallel.py
@@ -12,6 +12,7 @@ from colossalai.gemini.memory_tracer import OrderedParamGenerator
 from colossalai.logging import get_dist_logger
 from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
 from colossalai.tensor import ProcessGroup as ColoProcessGroup
+from colossalai.tensor import ReplicaSpec
 from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.utils import get_current_device, is_ddp_ignored
@@ -200,14 +201,18 @@ class ZeroDDP(ColoDDP):
         gemini_manager (GeminiManager): Manages the chunk manager and heterogeneous momery space.
             For more details, see the API reference of ``GeminiManager``.
         pin_memory (bool): Chunks on CPU Memory use pin-memory.
-        force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16. Defaults to False.
+        force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16.
+            Defaults to False.
+        strict_ddp_mode (bool): If set to True, there is no tensor sharding, each tensor is replicated.
+            Defaults to False. Users can set it to True, when they clearly know that they only need DDP.
     """
 
     def __init__(self,
                  module: torch.nn.Module,
                  gemini_manager: GeminiManager,
                  pin_memory: bool = False,
-                 force_outputs_fp32: bool = False) -> None:
+                 force_outputs_fp32: bool = False,
+                 strict_ddp_mode: bool = False) -> None:
         super().__init__(module, process_group=ColoProcessGroup())
         self.gemini_manager = gemini_manager
         self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
@@ -232,6 +237,9 @@ class ZeroDDP(ColoDDP):
         for p in param_order.generate():
             assert isinstance(p, ColoParameter)
 
+            if strict_ddp_mode and not p.is_replicate():
+                p.set_dist_spec(ReplicaSpec())
+
             if is_ddp_ignored(p):
                 p.data = p.data.to(device=get_current_device(), dtype=torch.float16)
                 continue
diff --git a/colossalai/nn/parallel/gemini_parallel.py b/colossalai/nn/parallel/gemini_parallel.py
index cd5ef424a..868a3960f 100644
--- a/colossalai/nn/parallel/gemini_parallel.py
+++ b/colossalai/nn/parallel/gemini_parallel.py
@@ -17,6 +17,7 @@ class GeminiDDP(ZeroDDP):
                  placement_policy: str = "cpu",
                  pin_memory: bool = False,
                  force_outputs_fp32: bool = False,
+                 strict_ddp_mode: bool = False,
                  search_range_mb: int = 32,
                  hidden_dim: Optional[int] = None,
                  min_chunk_size_mb: Optional[float] = None,
@@ -54,4 +55,4 @@ class GeminiDDP(ZeroDDP):
                                            search_range_mb=search_range_mb,
                                            min_chunk_size_mb=min_chunk_size_mb)
         gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
-        super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32)
+        super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode)
diff --git a/examples/language/gpt/gemini/commons/model_zoo.py b/examples/language/gpt/gemini/commons/model_zoo.py
index c31b3fa6d..65124d9e4 100644
--- a/examples/language/gpt/gemini/commons/model_zoo.py
+++ b/examples/language/gpt/gemini/commons/model_zoo.py
@@ -53,6 +53,14 @@ def gpt2_24b(checkpoint=True):
     return GPTLMModel(hidden_size=8192, num_layers=30, num_attention_heads=16, checkpoint=checkpoint)
 
 
+def gpt2_30b(checkpoint=True):
+    return GPTLMModel(hidden_size=8192, num_layers=37, num_attention_heads=16, checkpoint=checkpoint)
+
+
+def gpt2_40b(checkpoint=True):
+    return GPTLMModel(hidden_size=8192, num_layers=50, num_attention_heads=16, checkpoint=checkpoint)
+
+
 def model_builder(model_size: str) -> callable:
     if model_size == "gpt2_medium":
         return gpt2_medium
@@ -66,6 +74,10 @@ def model_builder(model_size: str) -> callable:
         return gpt2_20b
     elif model_size == "gpt2_24b":
         return gpt2_24b
+    elif model_size == "gpt2_30b":
+        return gpt2_30b
+    elif model_size == "gpt2_40b":
+        return gpt2_40b
     else:
         raise TypeError(f"model_builder {model_size}")
 
diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py
index 713de6f9f..285706596 100644
--- a/examples/language/gpt/gemini/train_gpt_demo.py
+++ b/examples/language/gpt/gemini/train_gpt_demo.py
@@ -187,17 +187,18 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
 
 
 # Gemini + ZeRO DDP
-def build_gemini(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"):
+def build_gemini(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto", ddp_flag: bool = True):
     fp16_init_scale = 2**5
     gpu_margin_mem_ratio_for_auto = 0
 
     if version.parse(CAI_VERSION) > version.parse("0.1.10"):
         model = GeminiDDP(model,
+                          strict_ddp_mode=ddp_flag,
                           device=get_current_device(),
                           placement_policy=placement_policy,
                           pin_memory=True,
                           hidden_dim=model.config.n_embd,
-                          search_range_mb=64)
+                          search_range_mb=128)
         # configure the const policy
         if placement_policy == 'const':
             model.gemini_manager._placement_policy.set_const_memory_boundary(2 * 1024)
@@ -279,11 +280,12 @@ def main():
         tp_pg = ProcessGroup(tp_degree=args.tp_degree)
         # Tensor Parallelism (TP)
         # You should notice that v0.1.10 is not compatible with TP degree > 1
-        tensor_parallelize(model, tp_pg)
+        if args.tp_degree > 1:
+            tensor_parallelize(model, tp_pg)
 
         # build a Gemini model and a highly optimized cpu optimizer
         # Gemini + ZeRO DP, Note it must be used after TP
-        model, optimizer = build_gemini(model, tp_pg, args.placement)
+        model, optimizer = build_gemini(model, tp_pg, args.placement, args.tp_degree == 1)
 
         logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
     else:
diff --git a/tests/test_tensor/test_tp_with_zero.py b/tests/test_tensor/test_tp_with_zero.py
index 7e611e8a1..83645bc6e 100644
--- a/tests/test_tensor/test_tp_with_zero.py
+++ b/tests/test_tensor/test_tp_with_zero.py
@@ -93,7 +93,7 @@ def run_gpt(placement_policy, tp_init_spec_func=None):
     else:
         init_device = None
 
-    model = GeminiDDP(model, init_device, placement_policy, True, False, 32)
+    model = GeminiDDP(model, init_device, placement_policy, True, False)
     # The same as the following 3 lines
     # chunk_manager = ChunkManager(config_dict, init_device=init_device)
     # gemini_manager = GeminiManager(placement_policy, chunk_manager)
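
For context, a minimal usage sketch of the new strict_ddp_mode flag added by this diff. Only the GeminiDDP keyword arguments are taken from the change above; the toy model, the ColoInitContext/launch_from_torch setup, and the launch method are illustrative assumptions rather than part of this patch.

# Sketch: wrapping a model with GeminiDDP in strict DDP mode (assumed setup).
import torch
import colossalai
from colossalai.nn.parallel import GeminiDDP
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

# Assumes the process group is set up externally, e.g. via torchrun.
colossalai.launch_from_torch(config={})

# Build the model under ColoInitContext so its parameters become ColoParameters.
with ColoInitContext(device=get_current_device()):
    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024),
                                torch.nn.GELU(),
                                torch.nn.Linear(1024, 1024))

# strict_ddp_mode=True keeps every parameter replicated (pure ZeRO DDP, no tensor
# sharding), mirroring what the demo selects when args.tp_degree == 1.
model = GeminiDDP(model,
                  device=get_current_device(),
                  placement_policy='auto',
                  pin_memory=True,
                  strict_ddp_mode=True,
                  search_range_mb=128)

When tensor parallelism is in use (args.tp_degree > 1), the demo leaves strict_ddp_mode off so that sharded distribution specs set by tensor_parallelize are preserved instead of being reset to ReplicaSpec.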