From b3b89865e2f35a8aaefc4cbb66747c060f352851 Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Mon, 5 Dec 2022 17:11:06 +0800
Subject: [PATCH] [Gemini] ParamOpHook -> ColoParamOpHook (#2080)

---
 .../memory_tracer/runtime_mem_tracer.py        |  6 ++--
 .../gemini/ophooks/runtime_mem_tracer_hook.py  |  4 +--
 colossalai/nn/parallel/data_parallel.py        |  8 ++---
 colossalai/tensor/__init__.py                  |  7 ++--
 colossalai/tensor/colo_parameter.py            |  8 ++---
 colossalai/tensor/param_op_hook.py             | 36 +++++++++----------
 colossalai/zero/utils/gemini_hook.py           |  4 +--
 7 files changed, 37 insertions(+), 36 deletions(-)

diff --git a/colossalai/gemini/memory_tracer/runtime_mem_tracer.py b/colossalai/gemini/memory_tracer/runtime_mem_tracer.py
index ead95535e..277371a36 100644
--- a/colossalai/gemini/memory_tracer/runtime_mem_tracer.py
+++ b/colossalai/gemini/memory_tracer/runtime_mem_tracer.py
@@ -3,7 +3,7 @@ import torch.nn
 from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
 from colossalai.gemini.ophooks.runtime_mem_tracer_hook import GradMemTracerHook, ParamMemTracerHook
 from colossalai.nn.parallel.data_parallel import _cast_float
-from colossalai.tensor.param_op_hook import ParamOpHookManager
+from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 
 __all__ = ['RuntimeMemTracer']
 
@@ -53,12 +53,12 @@ class RuntimeMemTracer():
         args, kwargs = _cast_float(args, self.dtype), _cast_float(kwargs, self.dtype)
         self.module.zero_grad(set_to_none=True)
         self._pre_forward()
-        with ParamOpHookManager.use_hooks(self.param_op_hook):
+        with ColoParamOpHookManager.use_hooks(self.param_op_hook):
             outputs = self.module(*args, **kwargs)
         return outputs
 
     def backward(self, loss):
-        with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
+        with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
             loss.backward()
         self._post_backward()
 
diff --git a/colossalai/gemini/ophooks/runtime_mem_tracer_hook.py b/colossalai/gemini/ophooks/runtime_mem_tracer_hook.py
index 5f155f085..5d8382ed0 100644
--- a/colossalai/gemini/ophooks/runtime_mem_tracer_hook.py
+++ b/colossalai/gemini/ophooks/runtime_mem_tracer_hook.py
@@ -8,7 +8,7 @@ import torch
 from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
 from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
 from colossalai.gemini.tensor_utils import alloc_storage, free_storage
-from colossalai.tensor.param_op_hook import ParamOpHook
+from colossalai.tensor.param_op_hook import ColoParamOpHook
 
 
 class TrainingPhase(Enum):
@@ -39,7 +39,7 @@ class GradMemTracerHook():
             hook.remove()
 
 
-class ParamMemTracerHook(ParamOpHook):
+class ParamMemTracerHook(ColoParamOpHook):
 
     def __init__(self) -> None:
         super().__init__()
diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py
index 78b6b499e..175146ebb 100644
--- a/colossalai/nn/parallel/data_parallel.py
+++ b/colossalai/nn/parallel/data_parallel.py
@@ -12,7 +12,7 @@ from colossalai.logging import get_dist_logger
 from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
 from colossalai.tensor import ProcessGroup as ColoProcessGroup
 from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
-from colossalai.tensor.param_op_hook import ParamOpHookManager
+from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.utils import get_current_device
 from colossalai.zero.utils.gemini_hook import GeminiZeROHook
 
@@ -259,7 +259,7 @@ class ZeroDDP(ColoDDP):
         args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
         self.module.zero_grad(set_to_none=True)
         self.gemini_manager.pre_iter(*args)
-        with ParamOpHookManager.use_hooks(self.param_op_hook):
+        with ColoParamOpHookManager.use_hooks(self.param_op_hook):
             outputs = self.module(*args, **kwargs)
         if self.force_outputs_fp32:
             return _cast_float(outputs, torch.float)
@@ -280,12 +280,12 @@ class ZeroDDP(ColoDDP):
         self.gemini_manager.post_iter()
 
     def backward(self, loss: torch.Tensor):
-        with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
+        with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
             loss.backward()
         self._post_backward()
 
     def backward_by_grad(self, tensor, grad):
-        with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
+        with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
             torch.autograd.backward(tensor, grad)
         self._post_backward()
 
diff --git a/colossalai/tensor/__init__.py b/colossalai/tensor/__init__.py
index ebccf7e18..b2da64e6c 100644
--- a/colossalai/tensor/__init__.py
+++ b/colossalai/tensor/__init__.py
@@ -5,13 +5,14 @@ from .comm_spec import CollectiveCommPattern, CommSpec
 from .compute_spec import ComputePattern, ComputeSpec
 from .dist_spec_mgr import DistSpecManager
 from .distspec import ReplicaSpec, ShardSpec
-from .param_op_hook import ParamOpHook, ParamOpHookManager
+from .param_op_hook import ColoParamOpHook, ColoParamOpHookManager
 from .process_group import ProcessGroup
 from .tensor_spec import ColoTensorSpec
 from .utils import convert_dim_partition_dict, convert_parameter, merge_same_dim_mesh_list, named_params_with_colotensor
 
 __all__ = [
     'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
-    'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec', 'ShardSpec',
-    'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern', 'convert_dim_partition_dict', 'merge_same_dim_mesh_list'
+    'distspec', 'DistSpecManager', 'ColoParamOpHook', 'ColoParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec',
+    'ShardSpec', 'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern', 'convert_dim_partition_dict',
+    'merge_same_dim_mesh_list'
 ]
diff --git a/colossalai/tensor/colo_parameter.py b/colossalai/tensor/colo_parameter.py
index 7247ef966..3e4c8ce69 100644
--- a/colossalai/tensor/colo_parameter.py
+++ b/colossalai/tensor/colo_parameter.py
@@ -4,7 +4,7 @@ import torch
 
 from colossalai.tensor.colo_tensor import ColoTensor
 from colossalai.tensor.const import TensorType
-from colossalai.tensor.param_op_hook import ParamOpHookManager
+from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.tensor.tensor_spec import ColoTensorSpec
 
 
@@ -58,18 +58,18 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
 
     @classmethod
     def __torch_function__(cls, func, types, args=..., kwargs=None):
-        if ParamOpHookManager.has_hook():
+        if ColoParamOpHookManager.has_hook():
             if not func.__name__.startswith('__'):
                 if kwargs is None:
                     kwargs = {}
                 params = filter_args(lambda arg: isinstance(arg, ColoParameter), *args, *kwargs.values())
                 if len(params) > 0:
                     with torch._C.DisableTorchFunction():
-                        new_args = ParamOpHookManager.pre_op(params, *args, *kwargs.values())
+                        new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())
                     args, kwargs = replace_args(args, kwargs, new_args)
                     ret = super().__torch_function__(func, types, args, kwargs)
                     with torch._C.DisableTorchFunction():
-                        ret = ParamOpHookManager.post_op(params, ret)
+                        ret = ColoParamOpHookManager.post_op(params, ret)
                     return ret
         return super().__torch_function__(func, types, args, kwargs)
 
diff --git a/colossalai/tensor/param_op_hook.py b/colossalai/tensor/param_op_hook.py
index 23fad971c..3b2cf7673 100644
--- a/colossalai/tensor/param_op_hook.py
+++ b/colossalai/tensor/param_op_hook.py
@@ -8,7 +8,7 @@ from colossalai.tensor.colo_tensor import ColoTensor
 from colossalai.tensor.tensor_spec import ColoTensorSpec
 
 
-class ParamOpHook(ABC):
+class ColoParamOpHook(ABC):
     """Hook which is triggered by each operation when operands contain ColoParameter.
     To customize it, you must inherit this abstract class, and implement ``pre_forward``,
     ``post_forward``, ``pre_backward`` and ``post_backward``. These four methods take a list
@@ -32,68 +32,68 @@ class ParamOpHook(ABC):
         pass
 
 
-class ParamOpHookManager:
+class ColoParamOpHookManager:
     """Manage your param op hooks. It only has static methods.
     The only static method you should call is ``use_hooks(*hooks)``.
     """
-    hooks: Tuple[ParamOpHook, ...] = tuple()
+    hooks: Tuple[ColoParamOpHook, ...] = tuple()
 
     @staticmethod
     @contextmanager
-    def use_hooks(*hooks: ParamOpHook):
+    def use_hooks(*hooks: ColoParamOpHook):
         """Change the param op hooks you use. Nested calling is allowed.
 
         Example:
-            >>> with ParamOpHookManager.use_hooks(*hooks):
+            >>> with ColoParamOpHookManager.use_hooks(*hooks):
             >>>     do_something()
-            >>>     with ParamOpHookManager.use_hooks():
+            >>>     with ColoParamOpHookManager.use_hooks():
             >>>         // clear hooks
             >>>         do_something()
         """
         try:
-            old_param_op_hooks = ParamOpHookManager.hooks
-            ParamOpHookManager.hooks = hooks
+            old_param_op_hooks = ColoParamOpHookManager.hooks
+            ColoParamOpHookManager.hooks = hooks
             yield
         finally:
-            ParamOpHookManager.hooks = old_param_op_hooks
+            ColoParamOpHookManager.hooks = old_param_op_hooks
 
     @staticmethod
     def _trigger_pre_forward(params: List[torch.Tensor]) -> None:
-        for hook in ParamOpHookManager.hooks:
+        for hook in ColoParamOpHookManager.hooks:
             hook.pre_forward(params)
 
     @staticmethod
     def _trigger_post_forward(params: List[torch.Tensor]) -> None:
-        for hook in ParamOpHookManager.hooks:
+        for hook in ColoParamOpHookManager.hooks:
             hook.post_forward(params)
 
     @staticmethod
     def _trigger_pre_backward(params: List[torch.Tensor]) -> None:
-        for hook in ParamOpHookManager.hooks:
+        for hook in ColoParamOpHookManager.hooks:
             hook.pre_backward(params)
 
     @staticmethod
     def _trigger_post_backward(params: List[torch.Tensor]) -> None:
-        for hook in ParamOpHookManager.hooks:
+        for hook in ColoParamOpHookManager.hooks:
             hook.post_backward(params)
 
     @staticmethod
     def pre_op(params: List[torch.Tensor], *args: Any) -> list:
-        ParamOpHookManager._trigger_pre_forward(params)
+        ColoParamOpHookManager._trigger_pre_forward(params)
         args_info = _get_colo_tensors_info(*args)
         rets = PreFwdPostBwd.apply(params, *args)
         return _update_colo_tensors(args_info, *rets)
 
     @staticmethod
    def post_op(params: List[torch.Tensor], arg: Any) -> Any:
-        ParamOpHookManager._trigger_post_forward(params)
+        ColoParamOpHookManager._trigger_post_forward(params)
         arg_info = _get_colo_tensors_info(arg)
         ret = PostFwdPreBwd.apply(params, arg)
         return _unpack_args(_update_colo_tensors(arg_info, ret))
 
     @staticmethod
     def has_hook() -> bool:
-        return len(ParamOpHookManager.hooks) > 0
+        return len(ColoParamOpHookManager.hooks) > 0
 
 
 class PreFwdPostBwd(torch.autograd.Function):
@@ -105,7 +105,7 @@ class PreFwdPostBwd(torch.autograd.Function):
 
     @staticmethod
     def backward(ctx, *grads):
-        ParamOpHookManager._trigger_post_backward(ctx.params)
+        ColoParamOpHookManager._trigger_post_backward(ctx.params)
         return (None,) + grads
 
 
@@ -118,7 +118,7 @@ class PostFwdPreBwd(torch.autograd.Function):
 
     @staticmethod
     def backward(ctx, *grads):
-        ParamOpHookManager._trigger_pre_backward(ctx.params)
+        ColoParamOpHookManager._trigger_pre_backward(ctx.params)
         return (None,) + grads
 
 
diff --git a/colossalai/zero/utils/gemini_hook.py b/colossalai/zero/utils/gemini_hook.py
index 4fbbcf376..99ca38495 100644
--- a/colossalai/zero/utils/gemini_hook.py
+++ b/colossalai/zero/utils/gemini_hook.py
@@ -7,7 +7,7 @@ import torch
 
 from colossalai.gemini import TensorState
 from colossalai.gemini.gemini_mgr import GeminiManager
-from colossalai.tensor.param_op_hook import ParamOpHook
+from colossalai.tensor.param_op_hook import ColoParamOpHook
 
 
 class TrainingPhase(Enum):
@@ -15,7 +15,7 @@ class TrainingPhase(Enum):
     BACKWARD = 1
 
 
-class GeminiZeROHook(ParamOpHook):
+class GeminiZeROHook(ColoParamOpHook):
 
     def __init__(self, gemini_manager: GeminiManager) -> None:
         super().__init__()
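
Below is a minimal usage sketch of the renamed API, for illustration only and not part of the patch above. It assumes ColossalAI with this change applied; the hook class CountingHook and the model/input names in the commented usage are hypothetical. A custom hook subclasses ColoParamOpHook, implements the four callbacks named in its docstring, and is activated through ColoParamOpHookManager.use_hooks so that any op whose operands include a ColoParameter triggers it.

    # Illustrative sketch only -- assumes colossalai with this patch is installed;
    # CountingHook is a hypothetical example hook, not part of the library.
    from typing import List

    import torch

    from colossalai.tensor import ColoParamOpHook, ColoParamOpHookManager


    class CountingHook(ColoParamOpHook):
        """Reports how many ColoParameters each triggered phase receives."""

        def pre_forward(self, params: List[torch.Tensor]) -> None:
            print(f"pre_forward: {len(params)} params")

        def post_forward(self, params: List[torch.Tensor]) -> None:
            print(f"post_forward: {len(params)} params")

        def pre_backward(self, params: List[torch.Tensor]) -> None:
            print(f"pre_backward: {len(params)} params")

        def post_backward(self, params: List[torch.Tensor]) -> None:
            print(f"post_backward: {len(params)} params")


    # Hypothetical usage: while the context is active, any op on ColoParameters
    # (forward via pre_op/post_op, backward via the autograd Functions above)
    # fires the corresponding callbacks.
    # with ColoParamOpHookManager.use_hooks(CountingHook()):
    #     loss = model(inputs).sum()
    #     loss.backward()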