From aba3db464db85e0ffa25ff30dfb2f5e3fc8d361c Mon Sep 17 00:00:00 2001
From: Zihao <804673818@qq.com>
Date: Thu, 24 Nov 2022 15:22:51 +0800
Subject: [PATCH] [Gemini] ParamMemHook (#2008)

---
 colossalai/gemini/ophooks/param_trace_hook.py | 81 +++++++++++++++++++
 1 file changed, 81 insertions(+)
 create mode 100644 colossalai/gemini/ophooks/param_trace_hook.py

diff --git a/colossalai/gemini/ophooks/param_trace_hook.py b/colossalai/gemini/ophooks/param_trace_hook.py
new file mode 100644
index 000000000..7b369bea9
--- /dev/null
+++ b/colossalai/gemini/ophooks/param_trace_hook.py
@@ -0,0 +1,81 @@
+from contextlib import contextmanager
+from enum import Enum
+from functools import partialmethod
+from typing import List
+
+import torch
+
+from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
+from colossalai.tensor.param_op_hook import ParamOpHook
+
+
+class TrainingPhase(Enum):
+    FORWARD = 0
+    BACKWARD = 1
+
+
+class ParamMemHook(ParamOpHook):
+
+    def __init__(self) -> None:
+        super().__init__()
+        self._training_phase = TrainingPhase.FORWARD
+        self.mem_monitor = SyncCudaMemoryMonitor()
+        self._non_model_data_list = []
+        self._model_data_list = []
+
+    def _move_params_to_dev(self, params, dev: str) -> int:
+        assert isinstance(dev, str), f"device should be a str, got {type(dev)}"
+        comm_volume = 0
+        for p in params:
+            if p.data.device.type != dev:
+                p.data = p.data.to(dev)
+                comm_volume += p.data.numel() * p.data.element_size()
+            if p.grad is not None:
+                if p.grad.device.type != dev:
+                    p.grad = p.grad.to(dev)
+                    comm_volume += p.grad.numel() * p.grad.element_size()
+        return comm_volume
+
+    def sample_model_data(self, params):
+        data_volume = 0
+        for p in params:
+            data_volume += p.data.numel() * p.data.element_size()
+        if self._training_phase == TrainingPhase.BACKWARD:
+            # double to account for param.grad, which is still None (not yet allocated) here
+            data_volume *= 2
+        self._model_data_list.append(data_volume)
+
+    def pre_op(self, params):
+        cuda_volume = self.mem_monitor.finish()  # peak CUDA memory used since the previous sampling
+        if self._model_data_list:
+            self._non_model_data_list.append(cuda_volume - self._model_data_list[-1])
+        self._move_params_to_dev(params, 'cuda')  # upload the params this op needs
+        self.sample_model_data(params)
+        self.mem_monitor.start()
+
+    def post_op(self, params):
+        self._move_params_to_dev(params, 'cpu')  # offload params back to CPU once the op is done
+
+    def pre_forward(self, params: List[torch.Tensor]) -> None:
+        self.pre_op(params)
+
+    def post_forward(self, params: List[torch.Tensor]) -> None:
+        self.post_op(params)
+
+    def pre_backward(self, params: List[torch.Tensor]) -> None:
+        self.pre_op(params)
+
+    def post_backward(self, params: List[torch.Tensor]) -> None:
+        self.post_op(params)
+
+    @contextmanager
+    def switch_training_phase(self, training_phase: TrainingPhase = TrainingPhase.BACKWARD):
+        old_training_phase = self._training_phase
+        try:
+            self._training_phase = training_phase
+            yield
+        finally:
+            self._training_phase = old_training_phase
+
+    switch_to_backward = switch_training_phase
+    switch_to_forward = partialmethod(switch_training_phase, training_phase=TrainingPhase.FORWARD)
\ No newline at end of file
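
For context, a minimal usage sketch of how the hook added above might be driven (not part of the patch). It assumes that `colossalai.tensor.param_op_hook` also exposes a `ParamOpHookManager.use_hooks` context manager under that name (which may differ between ColossalAI versions), that the model's parameters are ColoTensor-based so the per-op callbacks fire, and that `model` and `data` stand in for the user's own module and batch:

    # Hypothetical sketch only; ParamOpHookManager.use_hooks, `model` and `data`
    # are assumptions and placeholders, not part of this patch.
    from colossalai.gemini.ophooks.param_trace_hook import ParamMemHook
    from colossalai.tensor.param_op_hook import ParamOpHookManager

    hook = ParamMemHook()
    with ParamOpHookManager.use_hooks(hook):
        output = model(data)                # pre_forward/post_forward fire around each param op
        loss = output.sum()
        with hook.switch_to_backward():     # model-data samples now also count grads
            loss.backward()                 # pre_backward/post_backward fire around each param op

    # per-op model and non-model data volumes (in bytes) collected while the hook was active
    print(hook._model_data_list)
    print(hook._non_model_data_list)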