import torch

from colossalai.gemini.chunk import init_chunk_manager
from colossalai.gemini.gemini_mgr import GeminiManager

from .data_parallel import ZeroDDP


class GeminiDDP(ZeroDDP):
    """ZeroDDP subclass that builds its own chunk manager and Gemini manager."""

    def __init__(
        self,
        module: torch.nn.Module,
        device: torch.device,
        placement_policy: str = "cpu",
        pin_memory: bool = False,
        force_outputs_fp32: bool = False,
        search_range_mb: int = 32,
    ) -> None:
        """
        A torch.nn.Module wrapper using ZeroDDP and Gemini.
        ZeRO is for parallelism. Gemini is for memory management.

        Example:
            The model is initialized under the context of ColoInitContext.

            >>> model = GeminiDDP(model, torch.cuda.current_device(), "cuda")
            >>> logits = model(x)
            >>> loss = criterion(logits, labels)
            >>> model.backward(loss)

        Args:
            module (torch.nn.Module): the model to be wrapped.
            device (torch.device): device to place the model. It is passed to
                ``init_chunk_manager`` as ``init_device``.
            placement_policy (str, optional): "cpu", "cuda" or "auto".
                Defaults to "cpu".
            pin_memory (bool, optional): use pinned memory on CPU.
                Defaults to False.
            force_outputs_fp32 (bool, optional): force outputs to be fp32.
                Defaults to False.
            search_range_mb (int, optional): chunk size searching range in
                megabytes. Defaults to 32.
        """
        # The chunk manager has to exist first: the Gemini manager is
        # constructed on top of it.
        chunk_mgr = init_chunk_manager(
            model=module,
            init_device=device,
            search_range_mb=search_range_mb,
        )
        gemini_mgr = GeminiManager(placement_policy, chunk_mgr, module)

        # Hand the wrapped module and the memory manager over to ZeroDDP.
        super().__init__(module, gemini_mgr, pin_memory, force_outputs_fp32)