[gemini] add arguments (#2046)

* [zero] fix testing parameters

* [gemini] add arguments

* add docstrings
pull/2050/head
HELSON 2022-11-30 16:40:13 +08:00 committed by GitHub
parent 6a9158f1fa
commit e37f3db40c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 16 additions and 2 deletions

View File

@ -1,3 +1,5 @@
from typing import Optional
import torch
from colossalai.gemini.chunk import init_chunk_manager
@ -14,7 +16,9 @@ class GeminiDDP(ZeroDDP):
placement_policy: str = "cpu",
pin_memory: bool = False,
force_outputs_fp32: bool = False,
search_range_mb: int = 32) -> None:
search_range_mb: int = 32,
hidden_dim: Optional[int] = None,
min_chunk_size_mb: Optional[float] = None) -> None:
"""
A torch.Module warpper using ZeRO-DP and Genimi.
ZeRO is for parallel. Gemini is for memory management.
@ -34,7 +38,17 @@ class GeminiDDP(ZeroDDP):
pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
search_range_mb (int, optional): chunk size searching range in MegaByte. Defaults to 32.
hidden_dim (int, optional): the hidden dimension of DNN.
Users can provide this argument to speed up searching.
If users do not know this argument before training, it is ok. We will use a default value 1024.
min_chunk_size_mb (float, optional): the minimum chunk size in MegaByte.
If the aggregate size of parameters is still samller than the minimum chunk size,
all parameters will be compacted into one small chunk.
"""
chunk_manager = init_chunk_manager(model=module, init_device=device, search_range_mb=search_range_mb)
chunk_manager = init_chunk_manager(model=module,
init_device=device,
hidden_dim=hidden_dim,
search_range_mb=search_range_mb,
min_chunk_size_mb=min_chunk_size_mb)
gemini_manager = GeminiManager(placement_policy, chunk_manager, module)
super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32)