diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py index eaf85f2fb..78b6b499e 100644 --- a/colossalai/nn/parallel/data_parallel.py +++ b/colossalai/nn/parallel/data_parallel.py @@ -14,7 +14,7 @@ from colossalai.tensor import ProcessGroup as ColoProcessGroup from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec from colossalai.tensor.param_op_hook import ParamOpHookManager from colossalai.utils import get_current_device -from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2 +from colossalai.zero.utils.gemini_hook import GeminiZeROHook from .reducer import Reducer @@ -210,7 +210,7 @@ class ZeroDDP(ColoDDP): self.gemini_manager = gemini_manager self.chunk_manager: ChunkManager = gemini_manager.chunk_manager self.force_outputs_fp32 = force_outputs_fp32 - self.param_op_hook = ZeROHookV2(gemini_manager) + self.param_op_hook = GeminiZeROHook(gemini_manager) self.fp32_params: List[ColoTensor] = [] self.overflow_counter = 0 self.grads_device: Dict[torch.Tensor, torch.device] = {} diff --git a/colossalai/zero/utils/zero_hook_v2.py b/colossalai/zero/utils/gemini_hook.py similarity index 98% rename from colossalai/zero/utils/zero_hook_v2.py rename to colossalai/zero/utils/gemini_hook.py index 584a0fe37..4fbbcf376 100644 --- a/colossalai/zero/utils/zero_hook_v2.py +++ b/colossalai/zero/utils/gemini_hook.py @@ -1,11 +1,13 @@ -import torch -from colossalai.tensor.param_op_hook import ParamOpHook -from colossalai.gemini import TensorState -from enum import Enum -from typing import List from contextlib import contextmanager +from enum import Enum from functools import partial +from typing import List + +import torch + +from colossalai.gemini import TensorState from colossalai.gemini.gemini_mgr import GeminiManager +from colossalai.tensor.param_op_hook import ParamOpHook class TrainingPhase(Enum): @@ -13,7 +15,7 @@ class TrainingPhase(Enum): BACKWARD = 1 -class ZeROHookV2(ParamOpHook): +class GeminiZeROHook(ParamOpHook): def __init__(self, gemini_manager: GeminiManager) -> None: super().__init__() diff --git a/docs/colossalai/colossalai.zero.utils.rst b/docs/colossalai/colossalai.zero.utils.rst index 15cf4d70d..50ee9071e 100644 --- a/docs/colossalai/colossalai.zero.utils.rst +++ b/docs/colossalai/colossalai.zero.utils.rst @@ -9,4 +9,4 @@ colossalai.zero.utils :maxdepth: 2 colossalai.zero.utils.zero_hook - colossalai.zero.utils.zero_hook_v2 + colossalai.zero.utils.gemini_hook diff --git a/docs/colossalai/colossalai.zero.utils.zero_hook_v2.rst b/docs/colossalai/colossalai.zero.utils.zero_hook_v2.rst index 6c9af62f1..e6d6673af 100644 --- a/docs/colossalai/colossalai.zero.utils.zero_hook_v2.rst +++ b/docs/colossalai/colossalai.zero.utils.zero_hook_v2.rst @@ -1,5 +1,5 @@ colossalai.zero.utils.zero\_hook\_v2 ==================================== -.. automodule:: colossalai.zero.utils.zero_hook_v2 +.. automodule:: colossalai.zero.utils.gemini_hook :members: