mirror of https://github.com/hpcaitech/ColossalAI
[Gemini] ZeROHookV2 -> GeminiZeROHook (#1972)
parent
f8a7148dec
commit
cc0ed7cf33
|
@ -14,7 +14,7 @@ from colossalai.tensor import ProcessGroup as ColoProcessGroup
|
|||
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
|
||||
from colossalai.tensor.param_op_hook import ParamOpHookManager
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
|
||||
from colossalai.zero.utils.gemini_hook import GeminiZeROHook
|
||||
|
||||
from .reducer import Reducer
|
||||
|
||||
|
@ -210,7 +210,7 @@ class ZeroDDP(ColoDDP):
|
|||
self.gemini_manager = gemini_manager
|
||||
self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
|
||||
self.force_outputs_fp32 = force_outputs_fp32
|
||||
self.param_op_hook = ZeROHookV2(gemini_manager)
|
||||
self.param_op_hook = GeminiZeROHook(gemini_manager)
|
||||
self.fp32_params: List[ColoTensor] = []
|
||||
self.overflow_counter = 0
|
||||
self.grads_device: Dict[torch.Tensor, torch.device] = {}
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
import torch
|
||||
from colossalai.tensor.param_op_hook import ParamOpHook
|
||||
from colossalai.gemini import TensorState
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.gemini import TensorState
|
||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||
from colossalai.tensor.param_op_hook import ParamOpHook
|
||||
|
||||
|
||||
class TrainingPhase(Enum):
|
||||
|
@ -13,7 +15,7 @@ class TrainingPhase(Enum):
|
|||
BACKWARD = 1
|
||||
|
||||
|
||||
class ZeROHookV2(ParamOpHook):
|
||||
class GeminiZeROHook(ParamOpHook):
|
||||
|
||||
def __init__(self, gemini_manager: GeminiManager) -> None:
|
||||
super().__init__()
|
|
@ -9,4 +9,4 @@ colossalai.zero.utils
|
|||
:maxdepth: 2
|
||||
|
||||
colossalai.zero.utils.zero_hook
|
||||
colossalai.zero.utils.zero_hook_v2
|
||||
colossalai.zero.utils.gemini_hook
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
colossalai.zero.utils.zero\_hook\_v2
|
||||
====================================
|
||||
|
||||
.. automodule:: colossalai.zero.utils.zero_hook_v2
|
||||
.. automodule:: colossalai.zero.utils.gemini_hook
|
||||
:members:
|
||||
|
|
Loading…
Reference in New Issue