|
|
|
@ -1,3 +1,5 @@
|
|
|
|
|
from typing import Optional |
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
|
|
|
|
|
|
from colossalai.gemini.chunk import init_chunk_manager |
|
|
|
@ -14,7 +16,9 @@ class GeminiDDP(ZeroDDP):
|
|
|
|
|
placement_policy: str = "cpu", |
|
|
|
|
pin_memory: bool = False, |
|
|
|
|
force_outputs_fp32: bool = False, |
|
|
|
|
search_range_mb: int = 32) -> None: |
|
|
|
|
search_range_mb: int = 32, |
|
|
|
|
hidden_dim: Optional[int] = None, |
|
|
|
|
min_chunk_size_mb: Optional[float] = None) -> None: |
|
|
|
|
""" |
|
|
|
|
A torch.Module warpper using ZeRO-DP and Genimi. |
|
|
|
|
ZeRO is for parallel. Gemini is for memory management. |
|
|
|
@ -34,7 +38,17 @@ class GeminiDDP(ZeroDDP):
|
|
|
|
|
pin_memory (bool, optional): use pin memory on CPU. Defaults to False. |
|
|
|
|
force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False. |
|
|
|
|
search_range_mb (int, optional): chunk size searching range in MegaByte. Defaults to 32. |
|
|
|
|
hidden_dim (int, optional): the hidden dimension of DNN. |
|
|
|
|
Users can provide this argument to speed up searching. |
|
|
|
|
If users do not know this argument before training, it is ok. We will use a default value 1024. |
|
|
|
|
min_chunk_size_mb (float, optional): the minimum chunk size in MegaByte. |
|
|
|
|
If the aggregate size of parameters is still samller than the minimum chunk size, |
|
|
|
|
all parameters will be compacted into one small chunk. |
|
|
|
|
""" |
|
|
|
|
chunk_manager = init_chunk_manager(model=module, init_device=device, search_range_mb=search_range_mb) |
|
|
|
|
chunk_manager = init_chunk_manager(model=module, |
|
|
|
|
init_device=device, |
|
|
|
|
hidden_dim=hidden_dim, |
|
|
|
|
search_range_mb=search_range_mb, |
|
|
|
|
min_chunk_size_mb=min_chunk_size_mb) |
|
|
|
|
gemini_manager = GeminiManager(placement_policy, chunk_manager, module) |
|
|
|
|
super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32) |
|
|
|
|