ColossalAI/colossalai/zero/gemini/gemini_hook.py

from contextlib import contextmanager
from enum import Enum
from functools import partialmethod
from typing import Dict, List

import torch
import torch.distributed as dist

from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ColoParamOpHook
from colossalai.utils import is_ddp_ignored
from colossalai.zero.gemini import TensorState
from colossalai.zero.gemini.chunk import Chunk
from colossalai.zero.gemini.gemini_mgr import GeminiManager
from colossalai.zero.gemini.memory_tracer.param_runtime_order import OrderedParamGenerator

class TrainingPhase(Enum):
    FORWARD = 0
    BACKWARD = 1


DEBUG = True  # TODO @botbw: remove

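# GeminiZeROHook drives Gemini's chunk movement around every parameter op: before an op runs
# it fetches the chunks that op needs, and, when max_prefetch > 0, asynchronously prefetches
# the chunks of upcoming params (following the recorded visit order) so that communication
# overlaps with computation.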
class GeminiZeROHook(ColoParamOpHook):
    def __init__(
        self, gemini_manager: GeminiManager, param_order: OrderedParamGenerator, max_prefetch: int = 0
    ) -> None:
        super().__init__()
        self._gemini_manager = gemini_manager
        self._chunk_manager = gemini_manager.chunk_manager
        self._training_phase = TrainingPhase.FORWARD
        self._cur_param = None
        # param_visited_order might be updated somewhere else
        self._param_visited_order = param_order.param_visited_order
        self._max_prefetch = max_prefetch
        # maps each chunk being prefetched to its in-flight communication handle
        self._async_works: Dict[Chunk, dist.Work] = {}
        # used by get_prefetch_chunks to track the current param
        self._cur_param_idx = 0

    def get_prefetch_chunks(self, all_params: List[ColoParameter]) -> List[Chunk]:
        # E.g. with visit order [p0 .. p5], current op params (p2, p3) and max_prefetch=2,
        # the forward branch collects the chunks of p4 and p5.
        chunks_to_prefetch = set()
        if self._training_phase == TrainingPhase.FORWARD:  # forward phase: walk the visit order forwards
            self._cur_param_idx += len(all_params)  # need to update first
            idx = self._cur_param_idx  # index of the first param of the next op
            # still have params and prefetched chunks don't exceed the limit
            while idx < len(self._param_visited_order) and len(chunks_to_prefetch) < self._max_prefetch:
                param = self._param_visited_order[idx]
                if is_ddp_ignored(param):
                    idx += 1
                    continue
                chunk = self._chunk_manager.get_chunk(param)
                chunks_to_prefetch.add(chunk)
                idx += 1
        else:  # backward phase: walk the visit order backwards
            assert self._training_phase == TrainingPhase.BACKWARD
            self._cur_param_idx -= len(all_params)
            idx = self._cur_param_idx - 1  # index of the next param needed in backward
            while idx >= 0 and len(chunks_to_prefetch) < self._max_prefetch:
                param = self._param_visited_order[idx]
                if is_ddp_ignored(param):
                    idx -= 1
                    continue
                chunk = self._chunk_manager.get_chunk(param)
                chunks_to_prefetch.add(chunk)
                idx -= 1
        return list(chunks_to_prefetch)

    def wait_chunks(self, chunks: List[Chunk]) -> List[Chunk]:
        """Wait on pending async fetches among ``chunks``; return those that were not prefetched."""
        non_prefetched_chunks = []
        for chunk in chunks:
            if chunk in self._async_works:
                self._async_works[chunk].wait()
                del self._async_works[chunk]
            else:
                non_prefetched_chunks.append(chunk)
        return non_prefetched_chunks

    def pre_op(self, all_params):
        if DEBUG:  # TODO @botbw: remove
            # sanity check: params of one op should appear contiguously in the visit order
            idxs = list(map(lambda x: self._param_visited_order.index(x), all_params))
            mn = min(idxs)
            idxs = sorted(map(lambda x: x - mn, idxs))
            assert list(range(len(idxs))) == idxs, f"{idxs=}"

        # deal with currently needed chunks
        params = [p for p in all_params if not is_ddp_ignored(p)]
        all_chunks = self._chunk_manager.get_chunks(params)
        chunks_wo_work = self.wait_chunks(all_chunks)
        for p in params:
            self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
        self._gemini_manager.sample_overall_data()
        self._gemini_manager.adjust_layout(chunks_wo_work)

        # deal with chunks that are to be async fetched
        prefetch_chunks = self.get_prefetch_chunks(all_params)

        # deal with chunks that are to be fetched now
        for chunk in chunks_wo_work:
            self._chunk_manager.access_chunk(chunk)

        # deal with chunks that are to be prefetched  TODO @botbw: does the order here matter?
        for chunk in prefetch_chunks:
            self._async_works[chunk] = self._chunk_manager.access_chunk(chunk, async_access=True)

        # record the CUDA model data volume of the current op
        self._gemini_manager.record_model_data_volume()

    def post_op(self, params):
        params = [p for p in params if not is_ddp_ignored(p)]
        for p in params:
            tensor_state = (
                TensorState.HOLD
                if self._training_phase == TrainingPhase.FORWARD or not p.requires_grad
                else TensorState.HOLD_AFTER_BWD
            )
            self._chunk_manager.trans_tensor_state(p, tensor_state)

    def pre_forward(self, params: List[torch.Tensor]) -> None:
        self.pre_op(params)

    def post_forward(self, params: List[torch.Tensor]) -> None:
        self.post_op(params)

    def pre_backward(self, params: List[torch.Tensor]) -> None:
        self.pre_op(params)

    def post_backward(self, params: List[torch.Tensor]) -> None:
        self.post_op(params)

    @contextmanager
    def switch_training_phase(self, training_phase: TrainingPhase = TrainingPhase.BACKWARD):
        old_training_phase = self._training_phase
        try:
            self._training_phase = training_phase
            yield
        finally:
            self._training_phase = old_training_phase

    switch_to_backward = switch_training_phase
    # partialmethod (rather than functools.partial) so the alias still binds ``self`` when called
    switch_to_forward = partialmethod(switch_training_phase, training_phase=TrainingPhase.FORWARD)
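
A minimal usage sketch (not part of this file): `gemini_manager`, `param_order`, `model` and
`data` are assumed placeholders, with `param_order` already populated with the runtime parameter
visit order, roughly the way GeminiDDP wires this hook up.

from colossalai.tensor.param_op_hook import ColoParamOpHookManager

hook = GeminiZeROHook(gemini_manager, param_order=param_order, max_prefetch=2)

# Forward: parameter accesses fire pre_forward/post_forward, fetching the needed chunks
# and launching async prefetches of the upcoming ones.
with ColoParamOpHookManager.use_hooks(hook):
    loss = model(data).sum()

# Backward: switch the phase so the prefetch window walks the visit order in reverse.
with hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(hook):
    loss.backward()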