Browse Source

[gemini] add arguments (#2046)

* [zero] fix testing parameters

* [gemini] add arguments

* add docstrings
pull/2050/head
HELSON 2 years ago committed by GitHub
parent
commit
e37f3db40c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 18
      colossalai/nn/parallel/gemini_parallel.py

18
colossalai/nn/parallel/gemini_parallel.py

@ -1,3 +1,5 @@
from typing import Optional
import torch import torch
from colossalai.gemini.chunk import init_chunk_manager from colossalai.gemini.chunk import init_chunk_manager
@ -14,7 +16,9 @@ class GeminiDDP(ZeroDDP):
placement_policy: str = "cpu", placement_policy: str = "cpu",
pin_memory: bool = False, pin_memory: bool = False,
force_outputs_fp32: bool = False, force_outputs_fp32: bool = False,
search_range_mb: int = 32) -> None: search_range_mb: int = 32,
hidden_dim: Optional[int] = None,
min_chunk_size_mb: Optional[float] = None) -> None:
""" """
A torch.Module warpper using ZeRO-DP and Genimi. A torch.Module warpper using ZeRO-DP and Genimi.
ZeRO is for parallel. Gemini is for memory management. ZeRO is for parallel. Gemini is for memory management.
@ -34,7 +38,17 @@ class GeminiDDP(ZeroDDP):
pin_memory (bool, optional): use pin memory on CPU. Defaults to False. pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False. force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
search_range_mb (int, optional): chunk size searching range in MegaByte. Defaults to 32. search_range_mb (int, optional): chunk size searching range in MegaByte. Defaults to 32.
hidden_dim (int, optional): the hidden dimension of DNN.
Users can provide this argument to speed up searching.
If users do not know this argument before training, it is ok. We will use a default value 1024.
min_chunk_size_mb (float, optional): the minimum chunk size in MegaByte.
If the aggregate size of parameters is still samller than the minimum chunk size,
all parameters will be compacted into one small chunk.
""" """
chunk_manager = init_chunk_manager(model=module, init_device=device, search_range_mb=search_range_mb) chunk_manager = init_chunk_manager(model=module,
init_device=device,
hidden_dim=hidden_dim,
search_range_mb=search_range_mb,
min_chunk_size_mb=min_chunk_size_mb)
gemini_manager = GeminiManager(placement_policy, chunk_manager, module) gemini_manager = GeminiManager(placement_policy, chunk_manager, module)
super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32) super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32)

Loading…
Cancel
Save