diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py
index 8ad5b8ba2..38091cd1f 100644
--- a/colossalai/nn/parallel/data_parallel.py
+++ b/colossalai/nn/parallel/data_parallel.py
@@ -4,8 +4,8 @@ from colossalai.core import global_context as gpc
 from colossalai.context import ParallelMode
 from functools import partial
 from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
-from colossalai.tensor.chunk import ChunkManager, TensorState, Chunk
-from colossalai.tensor.param_op_hook import use_param_op_hooks
+from colossalai.tensor.chunk import TensorState, Chunk
+from colossalai.tensor.param_op_hook import ParamOpHookManager
 from colossalai.gemini.gemini_mgr import GeminiManager
 from typing import Dict
 from colossalai.logging import get_dist_logger
@@ -113,7 +113,7 @@ class ColoDDPV2(ColoDDP):
     def forward(self, *args, **kwargs):
         self.module.zero_grad(set_to_none=True)
         self.gemini_manager.pre_iter()
-        with use_param_op_hooks(self.param_op_hook):
+        with ParamOpHookManager.use_hooks(self.param_op_hook):
             outputs = self.module(*args, **kwargs)
         self.chunk_manager.exec_lazy_release()
         return outputs
@@ -134,12 +134,12 @@ class ColoDDPV2(ColoDDP):
         self.gemini_manager.post_iter()
 
     def backward(self, loss: torch.Tensor):
-        with self.param_op_hook.switch_to_backward(), use_param_op_hooks(self.param_op_hook):
+        with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
             loss.backward()
         self._post_backward()
 
     def backward_by_grad(self, tensor, grad):
-        with self.param_op_hook.switch_to_backward(), use_param_op_hooks(self.param_op_hook):
+        with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
             torch.autograd.backward(tensor, grad)
         self._post_backward()
 
diff --git a/colossalai/tensor/__init__.py b/colossalai/tensor/__init__.py
index a13cfbec1..b30562044 100644
--- a/colossalai/tensor/__init__.py
+++ b/colossalai/tensor/__init__.py
@@ -5,10 +5,10 @@ from .colo_parameter import ColoParameter
 from .utils import convert_parameter, named_params_with_colotensor
 from . import distspec
 from .dist_spec_mgr import DistSpecManager
-from .param_op_hook import ParamOpHook, use_param_op_hooks
+from .param_op_hook import ParamOpHook, ParamOpHookManager
 from .chunk import ChunkManager, TensorState
 
 __all__ = [
     'ColoTensor', 'convert_parameter', 'ComputePattern', 'TensorSpec', 'ParallelAction', 'named_params_with_colotensor',
-    'ColoParameter', 'distspec', 'DistSpecManager', 'ParamOpHook', 'use_param_op_hooks', 'ChunkManager', 'TensorState'
+    'ColoParameter', 'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ChunkManager', 'TensorState'
 ]
diff --git a/colossalai/tensor/colo_parameter.py b/colossalai/tensor/colo_parameter.py
index ea040aca3..8e3e5f5d0 100644
--- a/colossalai/tensor/colo_parameter.py
+++ b/colossalai/tensor/colo_parameter.py
@@ -3,7 +3,7 @@ from colossalai.tensor.const import TensorType
 import torch
 from colossalai.tensor import TensorSpec, distspec
 from copy import copy
-from colossalai.tensor.param_op_hook import _ParamOpHookWrapper, PreFwdPostBwd, PostFwdPreBwd
+from colossalai.tensor.param_op_hook import ParamOpHookManager
 from typing import Optional
 
 
@@ -48,17 +48,17 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
 
     @classmethod
     def __torch_function__(cls, func, types, args=..., kwargs=None):
-        if len(_ParamOpHookWrapper.hooks) > 0:
+        if ParamOpHookManager.has_hook():
             if not func.__name__.startswith('__'):
                 params = list(filter(lambda arg: isinstance(arg, ColoParameter), args))
                 if kwargs is not None:
                     params.extend(list(filter(lambda arg: isinstance(arg, ColoParameter), kwargs.values())))
                 if len(params) > 0:
                     with torch._C.DisableTorchFunction():
-                        args = PreFwdPostBwd.apply(params, *args)
+                        args = ParamOpHookManager.pre_op(params, *args)
                     ret = super().__torch_function__(func, types, args, kwargs)
                     with torch._C.DisableTorchFunction():
-                        ret = PostFwdPreBwd.apply(params, ret)
+                        ret = ParamOpHookManager.post_op(params, ret)
                     return ret
         return super().__torch_function__(func, types, args, kwargs)
 
diff --git a/colossalai/tensor/param_op_hook.py b/colossalai/tensor/param_op_hook.py
index 7522b62c2..3741dbf67 100644
--- a/colossalai/tensor/param_op_hook.py
+++ b/colossalai/tensor/param_op_hook.py
@@ -1,10 +1,15 @@
 import torch
 from contextlib import contextmanager
 from abc import ABC, abstractmethod
-from typing import List, Tuple
+from typing import List, Tuple, Any
 
 
 class ParamOpHook(ABC):
+    """Hook which is triggered by each operation when operands contain ColoParameter.
+    To customize it, you must inherit this abstract class, and implement ``pre_forward``,
+    ``post_forward``, ``pre_backward`` and ``post_backward``. These four methods take a list
+    of ColoParameter.
+    """
 
     @abstractmethod
     def pre_forward(self, params: List[torch.Tensor]) -> None:
@@ -23,25 +28,78 @@
         pass
 
 
-class _ParamOpHookWrapper:
+class ParamOpHookManager:
+    """Manage your param op hooks. It only has static methods.
+    The only static method you should call is ``use_hooks(*hooks)``.
+    """
     hooks: Tuple[ParamOpHook, ...] = tuple()
 
+    @staticmethod
+    @contextmanager
+    def use_hooks(*hooks: ParamOpHook):
+        """Change the param op hooks you use. Nested calling is allowed.
+
+        Example::
+            >>> with ParamOpHookManager.use_hooks(*hooks):
+            >>>     do_something()
+            >>>     with ParamOpHookManager.use_hooks():
+            >>>         # clear hooks
+            >>>         do_something()
+        """
+        try:
+            old_param_op_hooks = ParamOpHookManager.hooks
+            ParamOpHookManager.hooks = hooks
+            yield
+        finally:
+            ParamOpHookManager.hooks = old_param_op_hooks
+
+    @staticmethod
+    def _trigger_pre_forward(params: List[torch.Tensor]) -> None:
+        for hook in ParamOpHookManager.hooks:
+            hook.pre_forward(params)
+
+    @staticmethod
+    def _trigger_post_forward(params: List[torch.Tensor]) -> None:
+        for hook in ParamOpHookManager.hooks:
+            hook.post_forward(params)
+
+    @staticmethod
+    def _trigger_pre_backward(params: List[torch.Tensor]) -> None:
+        for hook in ParamOpHookManager.hooks:
+            hook.pre_backward(params)
+
+    @staticmethod
+    def _trigger_post_backward(params: List[torch.Tensor]) -> None:
+        for hook in ParamOpHookManager.hooks:
+            hook.post_backward(params)
+
+    @staticmethod
+    def pre_op(params: List[torch.Tensor], *args: Any) -> Any:
+        ParamOpHookManager._trigger_pre_forward(params)
+        return PreFwdPostBwd.apply(params, *args)
+
+    @staticmethod
+    def post_op(params: List[torch.Tensor], args: Any) -> Any:
+        ParamOpHookManager._trigger_post_forward(params)
+        return PostFwdPreBwd.apply(params, args)
+
+    @staticmethod
+    def has_hook() -> bool:
+        return len(ParamOpHookManager.hooks) > 0
+
 
 class PreFwdPostBwd(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, params, *args):
         ctx.params = params
-        for hook in _ParamOpHookWrapper.hooks:
-            hook.pre_forward(ctx.params)
         if len(args) == 1:
             return args[0]
         return args
 
     @staticmethod
     def backward(ctx, *grads):
-        for hook in _ParamOpHookWrapper.hooks:
-            hook.post_backward(ctx.params)
+        ParamOpHookManager._trigger_post_backward(ctx.params)
         return (None,) + grads
 
 
@@ -50,22 +108,9 @@ class PostFwdPreBwd(torch.autograd.Function):
     @staticmethod
     def forward(ctx, params, args):
         ctx.params = params
-        for hook in _ParamOpHookWrapper.hooks:
-            hook.post_forward(params)
         return args
 
     @staticmethod
     def backward(ctx, *grads):
-        for hook in _ParamOpHookWrapper.hooks:
-            hook.pre_backward(ctx.params)
+        ParamOpHookManager._trigger_pre_backward(ctx.params)
         return (None,) + grads
-
-
-@contextmanager
-def use_param_op_hooks(*hooks: ParamOpHook):
-    try:
-        old_param_op_hooks = _ParamOpHookWrapper.hooks
-        _ParamOpHookWrapper.hooks = hooks
-        yield
-    finally:
-        _ParamOpHookWrapper.hooks = old_param_op_hooks
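Reviewer note (not part of the patch): a minimal usage sketch of the refactored hook API. Only ParamOpHook, ParamOpHookManager.use_hooks and the trigger points come from the diff above; the CountingHook class, the tensor shapes, and the direct ColoParameter construction are illustrative assumptions. In a real run the parameters would normally be created by ColossalAI's own init path, and ColoDDPV2 installs its own hook via ParamOpHookManager.use_hooks as shown in the first file of this diff.

    from typing import List

    import torch

    from colossalai.tensor import ColoParameter, ParamOpHook, ParamOpHookManager


    class CountingHook(ParamOpHook):
        """Counts hook invocations per phase for ops whose operands include a ColoParameter."""

        def __init__(self) -> None:
            self.counts = {'pre_fwd': 0, 'post_fwd': 0, 'pre_bwd': 0, 'post_bwd': 0}

        def pre_forward(self, params: List[torch.Tensor]) -> None:
            self.counts['pre_fwd'] += 1

        def post_forward(self, params: List[torch.Tensor]) -> None:
            self.counts['post_fwd'] += 1

        def pre_backward(self, params: List[torch.Tensor]) -> None:
            self.counts['pre_bwd'] += 1

        def post_backward(self, params: List[torch.Tensor]) -> None:
            self.counts['post_bwd'] += 1


    # Assumption: constructing a ColoParameter directly like this works in your setup.
    # Hooks fire only for ops whose operands include a ColoParameter and whose function
    # name is not a dunder (see ColoParameter.__torch_function__ in this diff).
    p = ColoParameter(torch.randn(4, 4, requires_grad=True))
    hook = CountingHook()

    with ParamOpHookManager.use_hooks(hook):
        # pre_op/post_op run the forward triggers and insert PreFwdPostBwd/PostFwdPreBwd,
        # whose backward passes later fire post_backward/pre_backward during .backward().
        out = torch.matmul(p, torch.randn(4, 4))
        out.sum().backward()

    print(hook.counts)

Nesting use_hooks replaces the active hook tuple for the inner block and restores the outer one on exit, so calling it with no arguments temporarily disables all hooks, as the docstring example above shows.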