[Gemini] ZeROHookV2 -> GeminiZeROHook (#1972)

pull/1974/head
Jiarui Fang 2022-11-17 14:43:49 +08:00 committed by GitHub
parent f8a7148dec
commit cc0ed7cf33
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 12 additions and 10 deletions

View File

@@ -14,7 +14,7 @@ from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
from colossalai.tensor.param_op_hook import ParamOpHookManager
from colossalai.utils import get_current_device
from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
from colossalai.zero.utils.gemini_hook import GeminiZeROHook
from .reducer import Reducer
@@ -210,7 +210,7 @@ class ZeroDDP(ColoDDP):
self.gemini_manager = gemini_manager
self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
self.force_outputs_fp32 = force_outputs_fp32
self.param_op_hook = ZeROHookV2(gemini_manager)
self.param_op_hook = GeminiZeROHook(gemini_manager)
self.fp32_params: List[ColoTensor] = []
self.overflow_counter = 0
self.grads_device: Dict[torch.Tensor, torch.device] = {}

View File

@@ -1,11 +1,13 @@
import torch
from colossalai.tensor.param_op_hook import ParamOpHook
from colossalai.gemini import TensorState
from enum import Enum
from typing import List
from contextlib import contextmanager
from enum import Enum
from functools import partial
from typing import List
import torch
from colossalai.gemini import TensorState
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.tensor.param_op_hook import ParamOpHook
class TrainingPhase(Enum):
@@ -13,7 +15,7 @@ class TrainingPhase(Enum):
BACKWARD = 1
class ZeROHookV2(ParamOpHook):
class GeminiZeROHook(ParamOpHook):
def __init__(self, gemini_manager: GeminiManager) -> None:
super().__init__()

View File

@@ -9,4 +9,4 @@ colossalai.zero.utils
:maxdepth: 2
colossalai.zero.utils.zero_hook
colossalai.zero.utils.zero_hook_v2
colossalai.zero.utils.gemini_hook

View File

@@ -1,5 +1,5 @@
colossalai.zero.utils.zero\_hook\_v2
====================================
.. automodule:: colossalai.zero.utils.zero_hook_v2
.. automodule:: colossalai.zero.utils.gemini_hook
:members: