mirror of https://github.com/hpcaitech/ColossalAI
[tensor] refactor param op hook (#1097)
* refactor param op hook
* add docstr
* fix bug

Branch: pull/1103/head
parent 1e9f9c227f
commit 895c1c5ee7
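In short: this commit folds the module-level use_param_op_hooks context manager and the internal _ParamOpHookWrapper holder into a single ParamOpHookManager class that owns the hook tuple and exposes static helpers (use_hooks, pre_op, post_op, has_hook). A rough before/after sketch of caller code follows; my_hook, model and inputs are placeholders, not names taken from this diff.

    # before this commit
    from colossalai.tensor import use_param_op_hooks

    with use_param_op_hooks(my_hook):
        outputs = model(inputs)

    # after this commit
    from colossalai.tensor import ParamOpHookManager

    with ParamOpHookManager.use_hooks(my_hook):
        outputs = model(inputs)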
@@ -4,8 +4,8 @@ from colossalai.core import global_context as gpc
 from colossalai.context import ParallelMode
 from functools import partial
 from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
-from colossalai.tensor.chunk import ChunkManager, TensorState, Chunk
-from colossalai.tensor.param_op_hook import use_param_op_hooks
+from colossalai.tensor.chunk import TensorState, Chunk
+from colossalai.tensor.param_op_hook import ParamOpHookManager
 from colossalai.gemini.gemini_mgr import GeminiManager
 from typing import Dict
 from colossalai.logging import get_dist_logger
@@ -113,7 +113,7 @@ class ColoDDPV2(ColoDDP):
     def forward(self, *args, **kwargs):
         self.module.zero_grad(set_to_none=True)
         self.gemini_manager.pre_iter()
-        with use_param_op_hooks(self.param_op_hook):
+        with ParamOpHookManager.use_hooks(self.param_op_hook):
             outputs = self.module(*args, **kwargs)
         self.chunk_manager.exec_lazy_release()
         return outputs
@@ -134,12 +134,12 @@ class ColoDDPV2(ColoDDP):
         self.gemini_manager.post_iter()

     def backward(self, loss: torch.Tensor):
-        with self.param_op_hook.switch_to_backward(), use_param_op_hooks(self.param_op_hook):
+        with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
             loss.backward()
         self._post_backward()

     def backward_by_grad(self, tensor, grad):
-        with self.param_op_hook.switch_to_backward(), use_param_op_hooks(self.param_op_hook):
+        with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
             torch.autograd.backward(tensor, grad)
         self._post_backward()

@@ -5,10 +5,10 @@ from .colo_parameter import ColoParameter
 from .utils import convert_parameter, named_params_with_colotensor
 from . import distspec
 from .dist_spec_mgr import DistSpecManager
-from .param_op_hook import ParamOpHook, use_param_op_hooks
+from .param_op_hook import ParamOpHook, ParamOpHookManager
 from .chunk import ChunkManager, TensorState

 __all__ = [
     'ColoTensor', 'convert_parameter', 'ComputePattern', 'TensorSpec', 'ParallelAction', 'named_params_with_colotensor',
-    'ColoParameter', 'distspec', 'DistSpecManager', 'ParamOpHook', 'use_param_op_hooks', 'ChunkManager', 'TensorState'
+    'ColoParameter', 'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ChunkManager', 'TensorState'
 ]
@@ -3,7 +3,7 @@ from colossalai.tensor.const import TensorType
 import torch
 from colossalai.tensor import TensorSpec, distspec
 from copy import copy
-from colossalai.tensor.param_op_hook import _ParamOpHookWrapper, PreFwdPostBwd, PostFwdPreBwd
+from colossalai.tensor.param_op_hook import ParamOpHookManager
 from typing import Optional


@@ -48,17 +48,17 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):

     @classmethod
     def __torch_function__(cls, func, types, args=..., kwargs=None):
-        if len(_ParamOpHookWrapper.hooks) > 0:
+        if ParamOpHookManager.has_hook():
             if not func.__name__.startswith('__'):
                 params = list(filter(lambda arg: isinstance(arg, ColoParameter), args))
                 if kwargs is not None:
                     params.extend(list(filter(lambda arg: isinstance(arg, ColoParameter), kwargs.values())))
                 if len(params) > 0:
                     with torch._C.DisableTorchFunction():
-                        args = PreFwdPostBwd.apply(params, *args)
+                        args = ParamOpHookManager.pre_op(params, *args)
                     ret = super().__torch_function__(func, types, args, kwargs)
                     with torch._C.DisableTorchFunction():
-                        ret = PostFwdPreBwd.apply(params, ret)
+                        ret = ParamOpHookManager.post_op(params, ret)
                     return ret
         return super().__torch_function__(func, types, args, kwargs)

@@ -1,10 +1,15 @@
 import torch
 from contextlib import contextmanager
 from abc import ABC, abstractmethod
-from typing import List, Tuple
+from typing import List, Tuple, Any


 class ParamOpHook(ABC):
+    """Hook which is triggered by each operation when operands contain ColoParameter.
+    To customize it, you must inherit this abstract class, and implement ``pre_forward``,
+    ``post_forward``, ``pre_backward`` and ``post_backward``. These four methods take a list
+    of ColoParameter.
+    """

     @abstractmethod
     def pre_forward(self, params: List[torch.Tensor]) -> None:
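The docstring added above spells out the contract: subclass ParamOpHook and implement the four callbacks. A minimal concrete hook might look like the sketch below; PrintingHook is an invented name for illustration, and the snippet assumes the colossalai.tensor package from this PR is importable.

    from typing import List

    import torch

    from colossalai.tensor import ParamOpHook


    class PrintingHook(ParamOpHook):
        """Toy hook: report how many parameters each intercepted op touches."""

        def pre_forward(self, params: List[torch.Tensor]) -> None:
            print(f'pre_forward:   {len(params)} param(s)')

        def post_forward(self, params: List[torch.Tensor]) -> None:
            print(f'post_forward:  {len(params)} param(s)')

        def pre_backward(self, params: List[torch.Tensor]) -> None:
            print(f'pre_backward:  {len(params)} param(s)')

        def post_backward(self, params: List[torch.Tensor]) -> None:
            print(f'post_backward: {len(params)} param(s)')

Such a hook only fires for operations whose operands include ColoParameter, through the __torch_function__ override shown in the previous hunk.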
@@ -23,25 +28,78 @@ class ParamOpHook(ABC):
         pass


-class _ParamOpHookWrapper:
+class ParamOpHookManager:
+    """Manage your param op hooks. It only has static methods.
+    The only static method you should call is ``use_hooks(*hooks)``.
+    """
     hooks: Tuple[ParamOpHook, ...] = tuple()

+    @staticmethod
+    @contextmanager
+    def use_hooks(*hooks: ParamOpHook):
+        """Change the param op hooks you use. Nested calling is allowed.
+
+        Example::
+            >>> with ParamOpHookManager.use_hooks(*hooks):
+            >>>     do_something()
+            >>>     with ParamOpHookManager.use_hooks():
+            >>>         # clear hooks
+            >>>         do_something()
+        """
+        try:
+            old_param_op_hooks = ParamOpHookManager.hooks
+            ParamOpHookManager.hooks = hooks
+            yield
+        finally:
+            ParamOpHookManager.hooks = old_param_op_hooks
+
+    @staticmethod
+    def _trigger_pre_forward(params: List[torch.Tensor]) -> None:
+        for hook in ParamOpHookManager.hooks:
+            hook.pre_forward(params)
+
+    @staticmethod
+    def _trigger_post_forward(params: List[torch.Tensor]) -> None:
+        for hook in ParamOpHookManager.hooks:
+            hook.post_forward(params)
+
+    @staticmethod
+    def _trigger_pre_backward(params: List[torch.Tensor]) -> None:
+        for hook in ParamOpHookManager.hooks:
+            hook.pre_backward(params)
+
+    @staticmethod
+    def _trigger_post_backward(params: List[torch.Tensor]) -> None:
+        for hook in ParamOpHookManager.hooks:
+            hook.post_backward(params)
+
+    @staticmethod
+    def pre_op(params: List[torch.Tensor], *args: Any) -> Any:
+        ParamOpHookManager._trigger_pre_forward(params)
+        return PreFwdPostBwd.apply(params, *args)
+
+    @staticmethod
+    def post_op(params: List[torch.Tensor], args: Any) -> Any:
+        ParamOpHookManager._trigger_post_forward(params)
+        return PostFwdPreBwd.apply(params, args)
+
+    @staticmethod
+    def has_hook() -> bool:
+        return len(ParamOpHookManager.hooks) > 0


 class PreFwdPostBwd(torch.autograd.Function):

     @staticmethod
     def forward(ctx, params, *args):
         ctx.params = params
-        for hook in _ParamOpHookWrapper.hooks:
-            hook.pre_forward(ctx.params)
         if len(args) == 1:
             return args[0]
         return args

     @staticmethod
     def backward(ctx, *grads):
-        for hook in _ParamOpHookWrapper.hooks:
-            hook.post_backward(ctx.params)
+        ParamOpHookManager._trigger_post_backward(ctx.params)
         return (None,) + grads

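Because use_hooks saves and restores the class-level hooks tuple in a try/finally block, nested calls behave like a stack: the inner context installs its own hook set and the outer set comes back on exit, as the docstring example suggests. A small sketch, again assuming the colossalai.tensor package from this PR is importable; NoopHook is an invented placeholder hook.

    from typing import List

    import torch

    from colossalai.tensor import ParamOpHook, ParamOpHookManager


    class NoopHook(ParamOpHook):
        """Placeholder hook whose callbacks do nothing; used only to inspect the manager."""

        def pre_forward(self, params: List[torch.Tensor]) -> None: ...
        def post_forward(self, params: List[torch.Tensor]) -> None: ...
        def pre_backward(self, params: List[torch.Tensor]) -> None: ...
        def post_backward(self, params: List[torch.Tensor]) -> None: ...


    print(ParamOpHookManager.has_hook())            # False: nothing installed yet
    with ParamOpHookManager.use_hooks(NoopHook(), NoopHook()):
        print(len(ParamOpHookManager.hooks))        # 2
        with ParamOpHookManager.use_hooks():        # nested call clears the hooks
            print(ParamOpHookManager.has_hook())    # False
        print(len(ParamOpHookManager.hooks))        # 2 again: the outer set is restored
    print(ParamOpHookManager.has_hook())            # False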
@@ -50,22 +108,9 @@ class PostFwdPreBwd(torch.autograd.Function):
     @staticmethod
     def forward(ctx, params, args):
         ctx.params = params
-        for hook in _ParamOpHookWrapper.hooks:
-            hook.post_forward(params)
         return args

     @staticmethod
     def backward(ctx, *grads):
-        for hook in _ParamOpHookWrapper.hooks:
-            hook.pre_backward(ctx.params)
+        ParamOpHookManager._trigger_pre_backward(ctx.params)
         return (None,) + grads
-
-
-@contextmanager
-def use_param_op_hooks(*hooks: ParamOpHook):
-    try:
-        old_param_op_hooks = _ParamOpHookWrapper.hooks
-        _ParamOpHookWrapper.hooks = hooks
-        yield
-    finally:
-        _ParamOpHookWrapper.hooks = old_param_op_hooks
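Taken together, pre_op and post_op run the forward-time callbacks immediately and then thread the tensors through PreFwdPostBwd and PostFwdPreBwd, two identity autograd functions whose backward methods fire post_backward and pre_backward. During the backward pass the callbacks therefore run in mirrored order around the op's own gradient computation. The standalone sketch below reproduces that pattern with plain PyTorch only; PreOpMarker and PostOpMarker are invented names, not part of ColossalAI.

    import torch


    class PreOpMarker(torch.autograd.Function):
        """Identity in forward; fires a callback once gradients flow back out of the op."""

        @staticmethod
        def forward(ctx, x):
            return x

        @staticmethod
        def backward(ctx, grad):
            print('post-backward callback (mirror of the pre-forward side)')
            return grad


    class PostOpMarker(torch.autograd.Function):
        """Identity in forward; fires a callback before gradients reach the op."""

        @staticmethod
        def forward(ctx, x):
            return x

        @staticmethod
        def backward(ctx, grad):
            print('pre-backward callback (mirror of the post-forward side)')
            return grad


    w = torch.randn(4, requires_grad=True)
    x = PreOpMarker.apply(w)     # a pre_forward hook would run just before this
    y = (x * 2).sum()            # the real op
    y = PostOpMarker.apply(y)    # a post_forward hook would run just after this
    y.backward()                 # prints the pre-backward line first, then the post-backward line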