diff --git a/colossalai/nn/parallel/gemini_parallel.py b/colossalai/nn/parallel/gemini_parallel.py
index 6cc188b4b..9f13cece2 100644
--- a/colossalai/nn/parallel/gemini_parallel.py
+++ b/colossalai/nn/parallel/gemini_parallel.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch
 
 from colossalai.gemini.chunk import init_chunk_manager
@@ -14,7 +16,9 @@ class GeminiDDP(ZeroDDP):
                 placement_policy: str = "cpu",
                 pin_memory: bool = False,
                 force_outputs_fp32: bool = False,
-                 search_range_mb: int = 32) -> None:
+                 search_range_mb: int = 32,
+                 hidden_dim: Optional[int] = None,
+                 min_chunk_size_mb: Optional[float] = None) -> None:
         """
         A torch.Module warpper using ZeRO-DP and Genimi.
         ZeRO is for parallel. Gemini is for memory management.
@@ -34,7 +38,17 @@ class GeminiDDP(ZeroDDP):
             pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
             force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
             search_range_mb (int, optional): chunk size searching range in MegaByte. Defaults to 32.
+            hidden_dim (int, optional): the hidden dimension of the DNN.
+                Users can provide this argument to speed up the chunk size search.
+                If it is not known before training, a default value of 1024 will be used.
+            min_chunk_size_mb (float, optional): the minimum chunk size in MegaByte.
+                If the aggregate size of parameters is still smaller than the minimum chunk size,
+                all parameters will be compacted into one small chunk.
         """
-        chunk_manager = init_chunk_manager(model=module, init_device=device, search_range_mb=search_range_mb)
+        chunk_manager = init_chunk_manager(model=module,
+                                           init_device=device,
+                                           hidden_dim=hidden_dim,
+                                           search_range_mb=search_range_mb,
+                                           min_chunk_size_mb=min_chunk_size_mb)
         gemini_manager = GeminiManager(placement_policy, chunk_manager, module)
         super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32)
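
Below is a minimal usage sketch of how the two new constructor arguments could be passed. The toy model, the value 512 for `hidden_dim`, and the other argument values are purely illustrative, and it is assumed that the distributed environment has already been initialized (e.g. via `colossalai.launch` or `colossalai.launch_from_torch`).

```python
import torch
import torch.nn as nn

from colossalai.nn.parallel import GeminiDDP

# Toy model; in practice this is the user's DNN, and hidden_dim below
# should match its real hidden size (512 here is purely illustrative).
model = nn.Sequential(nn.Linear(512, 512), nn.GELU(), nn.Linear(512, 512))

# Assumes the distributed environment has already been initialized.
gemini_model = GeminiDDP(model,
                         device=torch.device('cuda'),
                         placement_policy='cpu',
                         pin_memory=True,
                         search_range_mb=32,
                         hidden_dim=512,        # narrows the chunk size search
                         min_chunk_size_mb=32)  # below this, one small chunk holds all parameters
```

Both new arguments are optional; omitting them keeps the previous behavior, since `init_chunk_manager` falls back to its own defaults.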