[tensor] refactor param op hook (#1097)

* refactor param op hook

* add docstr

* fix bug
ver217 2022-06-13 16:11:53 +08:00 committed by GitHub
parent 1e9f9c227f
commit 895c1c5ee7
4 changed files with 76 additions and 31 deletions


@@ -4,8 +4,8 @@ from colossalai.core import global_context as gpc
 from colossalai.context import ParallelMode
 from functools import partial
 from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
-from colossalai.tensor.chunk import ChunkManager, TensorState, Chunk
-from colossalai.tensor.param_op_hook import use_param_op_hooks
+from colossalai.tensor.chunk import TensorState, Chunk
+from colossalai.tensor.param_op_hook import ParamOpHookManager
 from colossalai.gemini.gemini_mgr import GeminiManager
 from typing import Dict
 from colossalai.logging import get_dist_logger
@@ -113,7 +113,7 @@ class ColoDDPV2(ColoDDP):
     def forward(self, *args, **kwargs):
         self.module.zero_grad(set_to_none=True)
         self.gemini_manager.pre_iter()
-        with use_param_op_hooks(self.param_op_hook):
+        with ParamOpHookManager.use_hooks(self.param_op_hook):
             outputs = self.module(*args, **kwargs)
         self.chunk_manager.exec_lazy_release()
         return outputs
@@ -134,12 +134,12 @@ class ColoDDPV2(ColoDDP):
         self.gemini_manager.post_iter()
 
     def backward(self, loss: torch.Tensor):
-        with self.param_op_hook.switch_to_backward(), use_param_op_hooks(self.param_op_hook):
+        with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
             loss.backward()
         self._post_backward()
 
     def backward_by_grad(self, tensor, grad):
-        with self.param_op_hook.switch_to_backward(), use_param_op_hooks(self.param_op_hook):
+        with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
             torch.autograd.backward(tensor, grad)
         self._post_backward()
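For callers of ColoDDPV2 nothing changes: the wrapper still activates the hook around the forward pass and re-enters it, switched to backward mode, around the backward pass. A rough sketch of that pattern with the renamed API (here param_op_hook stands for the ZeROHookV2 instance the wrapper holds, and model, data, loss are placeholders, not library names):

    from colossalai.tensor import ParamOpHookManager

    # forward: the hook fires for every op whose operands contain a ColoParameter
    with ParamOpHookManager.use_hooks(param_op_hook):
        outputs = model(data)

    # backward: the same hook, switched to its backward behaviour first
    with param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(param_op_hook):
        loss.backward()

The second use_hooks call is needed because the context manager restores the previous (empty) hook tuple as soon as the forward block exits.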


@@ -5,10 +5,10 @@ from .colo_parameter import ColoParameter
 from .utils import convert_parameter, named_params_with_colotensor
 from . import distspec
 from .dist_spec_mgr import DistSpecManager
-from .param_op_hook import ParamOpHook, use_param_op_hooks
+from .param_op_hook import ParamOpHook, ParamOpHookManager
 from .chunk import ChunkManager, TensorState
 
 __all__ = [
     'ColoTensor', 'convert_parameter', 'ComputePattern', 'TensorSpec', 'ParallelAction', 'named_params_with_colotensor',
-    'ColoParameter', 'distspec', 'DistSpecManager', 'ParamOpHook', 'use_param_op_hooks', 'ChunkManager', 'TensorState'
+    'ColoParameter', 'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ChunkManager', 'TensorState'
 ]
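Downstream code that imported the old context manager from the package root now imports the manager class instead; the ParamOpHook base class is exported unchanged:

    # before this commit
    # from colossalai.tensor import ParamOpHook, use_param_op_hooks
    # after this commit
    from colossalai.tensor import ParamOpHook, ParamOpHookManager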


@@ -3,7 +3,7 @@ from colossalai.tensor.const import TensorType
 import torch
 from colossalai.tensor import TensorSpec, distspec
 from copy import copy
-from colossalai.tensor.param_op_hook import _ParamOpHookWrapper, PreFwdPostBwd, PostFwdPreBwd
+from colossalai.tensor.param_op_hook import ParamOpHookManager
 from typing import Optional
@@ -48,17 +48,17 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
     @classmethod
     def __torch_function__(cls, func, types, args=..., kwargs=None):
-        if len(_ParamOpHookWrapper.hooks) > 0:
+        if ParamOpHookManager.has_hook():
             if not func.__name__.startswith('__'):
                 params = list(filter(lambda arg: isinstance(arg, ColoParameter), args))
                 if kwargs is not None:
                     params.extend(list(filter(lambda arg: isinstance(arg, ColoParameter), kwargs.values())))
                 if len(params) > 0:
                     with torch._C.DisableTorchFunction():
-                        args = PreFwdPostBwd.apply(params, *args)
+                        args = ParamOpHookManager.pre_op(params, *args)
                     ret = super().__torch_function__(func, types, args, kwargs)
                     with torch._C.DisableTorchFunction():
-                        ret = PostFwdPreBwd.apply(params, ret)
+                        ret = ParamOpHookManager.post_op(params, ret)
                     return ret
         return super().__torch_function__(func, types, args, kwargs)
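The dispatch logic itself is unchanged; only the entry points move behind the manager. Roughly, for an op such as torch.nn.functional.linear(x, w, b) whose operands include ColoParameters, the branch above amounts to the sequence in the sketch below, where plain tensors stand in for the collected ColoParameter operands (all variable names are illustrative):

    import torch
    import torch.nn.functional as F
    from colossalai.tensor import ParamOpHookManager

    x = torch.randn(2, 8)
    w = torch.randn(4, 8, requires_grad=True)   # stand-in for a ColoParameter operand
    b = torch.zeros(4, requires_grad=True)      # stand-in for a ColoParameter operand

    params = [w, b]                                          # the ColoParameters found in args/kwargs
    new_args = ParamOpHookManager.pre_op(params, x, w, b)    # triggers pre_forward on every active hook
    out = F.linear(*new_args)                                # the real op (run under DisableTorchFunction in the dispatch)
    out = ParamOpHookManager.post_op(params, out)            # triggers post_forward on every active hook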


@@ -1,10 +1,15 @@
 import torch
 from contextlib import contextmanager
 from abc import ABC, abstractmethod
-from typing import List, Tuple
+from typing import List, Tuple, Any
 
 
 class ParamOpHook(ABC):
+    """Hook which is triggered by each operation when operands contain ColoParameter.
+    To customize it, you must inherit this abstract class, and implement ``pre_forward``,
+    ``post_forward``, ``pre_backward`` and ``post_backward``. These four methods take a list
+    of ColoParameter.
+    """
 
     @abstractmethod
     def pre_forward(self, params: List[torch.Tensor]) -> None:
@@ -23,25 +28,78 @@ class ParamOpHook(ABC):
         pass
 
 
-class _ParamOpHookWrapper:
+class ParamOpHookManager:
+    """Manage your param op hooks. It only has static methods.
+    The only static method you should call is ``use_hooks(*hooks)``.
+    """
     hooks: Tuple[ParamOpHook, ...] = tuple()
 
+    @staticmethod
+    @contextmanager
+    def use_hooks(*hooks: ParamOpHook):
+        """Change the param op hooks you use. Nested calling is allowed.
+
+        Example::
+            >>> with ParamOpHookManager.use_hooks(*hooks):
+            >>>     do_something()
+            >>>     with ParamOpHookManager.use_hooks():
+            >>>         # clear hooks
+            >>>         do_something()
+        """
+        try:
+            old_param_op_hooks = ParamOpHookManager.hooks
+            ParamOpHookManager.hooks = hooks
+            yield
+        finally:
+            ParamOpHookManager.hooks = old_param_op_hooks
+
+    @staticmethod
+    def _trigger_pre_forward(params: List[torch.Tensor]) -> None:
+        for hook in ParamOpHookManager.hooks:
+            hook.pre_forward(params)
+
+    @staticmethod
+    def _trigger_post_forward(params: List[torch.Tensor]) -> None:
+        for hook in ParamOpHookManager.hooks:
+            hook.post_forward(params)
+
+    @staticmethod
+    def _trigger_pre_backward(params: List[torch.Tensor]) -> None:
+        for hook in ParamOpHookManager.hooks:
+            hook.pre_backward(params)
+
+    @staticmethod
+    def _trigger_post_backward(params: List[torch.Tensor]) -> None:
+        for hook in ParamOpHookManager.hooks:
+            hook.post_backward(params)
+
+    @staticmethod
+    def pre_op(params: List[torch.Tensor], *args: Any) -> Any:
+        ParamOpHookManager._trigger_pre_forward(params)
+        return PreFwdPostBwd.apply(params, *args)
+
+    @staticmethod
+    def post_op(params: List[torch.Tensor], args: Any) -> Any:
+        ParamOpHookManager._trigger_post_forward(params)
+        return PostFwdPreBwd.apply(params, args)
+
+    @staticmethod
+    def has_hook() -> bool:
+        return len(ParamOpHookManager.hooks) > 0
+
 
 class PreFwdPostBwd(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, params, *args):
         ctx.params = params
-        for hook in _ParamOpHookWrapper.hooks:
-            hook.pre_forward(ctx.params)
         if len(args) == 1:
             return args[0]
         return args
 
     @staticmethod
     def backward(ctx, *grads):
-        for hook in _ParamOpHookWrapper.hooks:
-            hook.post_backward(ctx.params)
+        ParamOpHookManager._trigger_post_backward(ctx.params)
         return (None,) + grads
@@ -50,22 +108,9 @@ class PostFwdPreBwd(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, params, args):
         ctx.params = params
-        for hook in _ParamOpHookWrapper.hooks:
-            hook.post_forward(params)
         return args
 
     @staticmethod
     def backward(ctx, *grads):
-        for hook in _ParamOpHookWrapper.hooks:
-            hook.pre_backward(ctx.params)
+        ParamOpHookManager._trigger_pre_backward(ctx.params)
         return (None,) + grads
-
-
-@contextmanager
-def use_param_op_hooks(*hooks: ParamOpHook):
-    try:
-        old_param_op_hooks = _ParamOpHookWrapper.hooks
-        _ParamOpHookWrapper.hooks = hooks
-        yield
-    finally:
-        _ParamOpHookWrapper.hooks = old_param_op_hooks
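Taken together, the manager, the two autograd functions, and the ColoParameter dispatch give the full trigger order: pre_forward and post_forward fire around the op in the forward pass, and the autograd nodes they plant replay pre_backward and post_backward during the backward pass. The self-contained sketch below drives the manager directly with plain tensors (standing in for ColoParameter operands, which would normally go through __torch_function__) to show that order; the CountingHook class and all variable names are illustrative, not part of the library:

    import torch
    from typing import List
    from colossalai.tensor import ParamOpHook, ParamOpHookManager

    class CountingHook(ParamOpHook):
        """Toy hook that records which phase fired and in what order."""

        def __init__(self):
            self.events = []

        def pre_forward(self, params: List[torch.Tensor]) -> None:
            self.events.append('pre_forward')

        def post_forward(self, params: List[torch.Tensor]) -> None:
            self.events.append('post_forward')

        def pre_backward(self, params: List[torch.Tensor]) -> None:
            self.events.append('pre_backward')

        def post_backward(self, params: List[torch.Tensor]) -> None:
            self.events.append('post_backward')

    hook = CountingHook()
    w = torch.randn(4, 4, requires_grad=True)   # stand-in for a ColoParameter
    x = torch.randn(2, 4, requires_grad=True)

    with ParamOpHookManager.use_hooks(hook):
        # mimic what ColoParameter.__torch_function__ does around a single op
        new_x = ParamOpHookManager.pre_op([w], x)   # fires pre_forward, plants a PreFwdPostBwd node
        y = new_x @ w                               # the real op
        y = ParamOpHookManager.post_op([w], y)      # fires post_forward, plants a PostFwdPreBwd node

    # the hooks must also be active during backward, just as ColoDDPV2.backward re-enters use_hooks
    with ParamOpHookManager.use_hooks(hook):
        y.sum().backward()                          # fires pre_backward, then post_backward

    print(hook.events)  # ['pre_forward', 'post_forward', 'pre_backward', 'post_backward']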