ColossalAI/colossalai/zero/utils/zero_hook_v2.py


import torch
from colossalai.tensor import ParamOpHook, ChunkManager, TensorState
from enum import Enum
from typing import List
from contextlib import contextmanager
from functools import partialmethod


class TrainingPhase(Enum):
    FORWARD = 0
    BACKWARD = 1


class ZeROHookV2(ParamOpHook):
    """ParamOpHook that drives ChunkManager state transitions for ZeRO training:
    chunks are fetched before each parameter op and scheduled for lazy release after it."""

    def __init__(self, chunk_manager: ChunkManager) -> None:
        super().__init__()
        self._chunk_manager = chunk_manager
        self._training_phase = TrainingPhase.FORWARD

    def pre_op(self, params):
        # Mark the parameters as in-use, flush any pending releases, then make sure
        # the chunks holding these parameters are materialized before the op runs.
        for p in params:
            self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
        self._chunk_manager.exec_lazy_release()
        # TODO: evict chunks
        for p in params:
            self._chunk_manager.access_chunk(p)

    def post_op(self, params):
        for p in params:
            # After a forward op (or for params that never receive a grad) the tensor
            # returns to HOLD; after a backward op it is held until the grad is consumed.
            if self._training_phase == TrainingPhase.FORWARD or not p.requires_grad:
                tensor_state = TensorState.HOLD
            else:
                tensor_state = TensorState.HOLD_AFTER_BWD
            self._chunk_manager.trans_tensor_state(p, tensor_state)
        self._chunk_manager.add_lazy_release_tensors(params)
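
    # State flow driven by pre_op/post_op above: a parameter entering an op is
    # moved to COMPUTE; after a forward op (or when it has no grad) it goes back
    # to HOLD, and after a backward op it becomes HOLD_AFTER_BWD until released.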

    def pre_forward(self, params: List[torch.Tensor]) -> None:
        self.pre_op(params)

    def post_forward(self, params: List[torch.Tensor]) -> None:
        self.post_op(params)

    def pre_backward(self, params: List[torch.Tensor]) -> None:
        self.pre_op(params)

    def post_backward(self, params: List[torch.Tensor]) -> None:
        self.post_op(params)

    @contextmanager
    def switch_training_phase(self, training_phase: TrainingPhase = TrainingPhase.BACKWARD):
        # Temporarily switch the phase (e.g. around loss.backward()) and always
        # restore the previous phase, even if the wrapped block raises.
        old_training_phase = self._training_phase
        try:
            self._training_phase = training_phase
            yield
        finally:
            self._training_phase = old_training_phase

    switch_to_backward = switch_training_phase
    # NOTE: partialmethod (rather than functools.partial) is required here so that
    # `self` is bound when the alias is called on an instance.
    switch_to_forward = partialmethod(switch_to_backward, training_phase=TrainingPhase.FORWARD)
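

# Minimal usage sketch (not part of the original file). It assumes a ChunkManager
# instance `chunk_manager`, a `model`, and an input `data` are constructed elsewhere,
# and that the hook is activated through colossalai.tensor's parameter-op hook
# machinery; the registration call is omitted because its API is not shown here.
#
#   hook = ZeROHookV2(chunk_manager)
#   # ... register `hook` so pre_forward/post_forward fire around parameter ops ...
#   output = model(data)                 # runs in the default TrainingPhase.FORWARD
#   with hook.switch_to_backward():      # post_op now yields HOLD_AFTER_BWD
#       output.sum().backward()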