ColossalAI/colossalai/tensor/param_op_hook.py

156 lines
5.0 KiB
Python
Raw Normal View History

from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, List, Tuple
import torch
[gemini] improve compatibility and add static placement policy (#4479) * [gemini] remove distributed-related part from colotensor (#4379) * [gemini] remove process group dependency * [gemini] remove tp part from colo tensor * [gemini] patch inplace op * [gemini] fix param op hook and update tests * [test] remove useless tests * [test] remove useless tests * [misc] fix requirements * [test] fix model zoo * [test] fix model zoo * [test] fix model zoo * [test] fix model zoo * [test] fix model zoo * [misc] update requirements * [gemini] refactor gemini optimizer and gemini ddp (#4398) * [gemini] update optimizer interface * [gemini] renaming gemini optimizer * [gemini] refactor gemini ddp class * [example] update gemini related example * [example] update gemini related example * [plugin] fix gemini plugin args * [test] update gemini ckpt tests * [gemini] fix checkpoint io * [example] fix opt example requirements * [example] fix opt example * [example] fix opt example * [example] fix opt example * [gemini] add static placement policy (#4443) * [gemini] add static placement policy * [gemini] fix param offload * [test] update gemini tests * [plugin] update gemini plugin * [plugin] update gemini plugin docstr * [misc] fix flash attn requirement * [test] fix gemini checkpoint io test * [example] update resnet example result (#4457) * [example] update bert example result (#4458) * [doc] update gemini doc (#4468) * [example] update gemini related examples (#4473) * [example] update gpt example * [example] update dreambooth example * [example] update vit * [example] update opt * [example] update palm * [example] update vit and opt benchmark * [hotfix] fix bert in model zoo (#4480) * [hotfix] fix bert in model zoo * [test] remove chatglm gemini test * [test] remove sam gemini test * [test] remove vit gemini test * [hotfix] fix opt tutorial example (#4497) * [hotfix] fix opt tutorial example * [hotfix] fix opt tutorial example
2023-08-24 01:29:25 +00:00
from torch.utils._pytree import TreeSpec, tree_flatten, tree_unflatten
class ColoParamOpHook(ABC):
    """
    Interface of a hook that is fired around every operation whose operands
    contain ColoParameter.

    A concrete hook must subclass this class and override all four callbacks:
    ``pre_forward``, ``post_forward``, ``pre_backward`` and ``post_backward``.
    Each callback receives the list of parameters taking part in the operation.
    """

    @abstractmethod
    def pre_forward(self, params: List[torch.Tensor]) -> None:
        """Callback fired by ``ColoParamOpHookManager.pre_op``.

        Args:
            params: The parameters involved in the operation.
        """

    @abstractmethod
    def post_forward(self, params: List[torch.Tensor]) -> None:
        """Callback fired by ``ColoParamOpHookManager.post_op``.

        Args:
            params: The parameters involved in the operation.
        """

    @abstractmethod
    def pre_backward(self, params: List[torch.Tensor]) -> None:
        """Callback fired from ``PostFwdPreBwd.backward``.

        Args:
            params: The parameters involved in the operation.
        """

    @abstractmethod
    def post_backward(self, params: List[torch.Tensor]) -> None:
        """Callback fired from ``PreFwdPostBwd.backward``.

        Args:
            params: The parameters involved in the operation.
        """
class ColoParamOpHookManager:
    """
    Manage your param op hooks. It only has static methods.
    The only static method you should call is ``use_hooks(*hooks)``.
    """

    # Hooks that are currently active. ``use_hooks`` swaps this tuple for the
    # duration of its ``with`` block. Quoted annotation: forward reference.
    hooks: "Tuple[ColoParamOpHook, ...]" = ()

    @staticmethod
    @contextmanager
    def use_hooks(*hooks: "ColoParamOpHook"):
        """Change the param op hooks you use. Nested calling is allowed.

        The hooks that were active before entering the context are restored
        on exit, including when the body raises.

        Example:
            >>> with ColoParamOpHookManager.use_hooks(*hooks):
            >>>     do_something()
            >>>     with ColoParamOpHookManager.use_hooks():
            >>>         # clear hooks
            >>>         do_something()
        """
        # Save the previous hooks *before* entering ``try`` so that ``finally``
        # always has a bound value to restore.
        previous_hooks = ColoParamOpHookManager.hooks
        ColoParamOpHookManager.hooks = hooks
        try:
            yield
        finally:
            ColoParamOpHookManager.hooks = previous_hooks

    @staticmethod
    def _trigger_pre_forward(params: List[torch.Tensor]) -> None:
        """Call ``pre_forward(params)`` on every active hook, in order."""
        for hook in ColoParamOpHookManager.hooks:
            hook.pre_forward(params)

    @staticmethod
    def _trigger_post_forward(params: List[torch.Tensor]) -> None:
        """Call ``post_forward(params)`` on every active hook, in order."""
        for hook in ColoParamOpHookManager.hooks:
            hook.post_forward(params)

    @staticmethod
    def _trigger_pre_backward(params: List[torch.Tensor]) -> None:
        """Call ``pre_backward(params)`` on every active hook, in order."""
        for hook in ColoParamOpHookManager.hooks:
            hook.pre_backward(params)

    @staticmethod
    def _trigger_post_backward(params: List[torch.Tensor]) -> None:
        """Call ``post_backward(params)`` on every active hook, in order."""
        for hook in ColoParamOpHookManager.hooks:
            hook.post_backward(params)

    @staticmethod
    def pre_op(params: List[torch.Tensor], *args: Any) -> tuple:
        """Fire the pre-forward hooks and route the op inputs through ``PreFwdPostBwd``.

        Args:
            params: The parameters involved in the operation.
            *args: The op inputs; may be arbitrarily nested containers.

        Returns:
            The inputs rebuilt with their original structure. Because they are
            collected through ``*args``, the top-level container is a tuple.
            Tensors selected by ``_is_grad_tensor`` are replaced by the outputs
            of ``PreFwdPostBwd.apply``; everything else is passed back as is.
        """
        ColoParamOpHookManager._trigger_pre_forward(params)
        # auto grad function can only recognize torch.Tensor, thus we have to flatten the input
        # if one of the input requires grad, all the output will be treated as requires grad
        # and will have grad fn even the corresponding input does not require grad
        # we have to extract tensors requiring grad into flat list and then merge them back
        grad_args, other_args, grad_flags, spec = _flatten_grad_args(args)
        new_grad_args = PreFwdPostBwd.apply(params, *grad_args)
        return _merge_args(new_grad_args, other_args, grad_flags, spec)

    @staticmethod
    def post_op(params: List[torch.Tensor], arg: Any) -> Any:
        """Fire the post-forward hooks and route the op output through ``PostFwdPreBwd``.

        Args:
            params: The parameters involved in the operation.
            arg: The op output; may be a single value or a nested container.

        Returns:
            ``arg`` rebuilt with its original structure. Tensors selected by
            ``_is_grad_tensor`` are replaced by the outputs of
            ``PostFwdPreBwd.apply``; everything else is passed back as is.
        """
        ColoParamOpHookManager._trigger_post_forward(params)
        # in case the output is a tuple, we have to flatten it
        grad_args, other_args, grad_flags, spec = _flatten_grad_args(arg)
        new_grad_args = PostFwdPreBwd.apply(params, *grad_args)
        return _merge_args(new_grad_args, other_args, grad_flags, spec)

    @staticmethod
    def has_hook() -> bool:
        """Return ``True`` when at least one hook is currently active."""
        return bool(ColoParamOpHookManager.hooks)
class PreFwdPostBwd(torch.autograd.Function):
    """Pass-through autograd function that fires the post-backward hooks.

    ``forward`` hands its tensor arguments back untouched; ``backward`` calls
    ``ColoParamOpHookManager._trigger_post_backward`` with the parameters
    recorded during ``forward`` and forwards the incoming gradients.
    """

    @staticmethod
    def forward(ctx, params, *args):
        # Remember the parameters so ``backward`` can hand them to the hooks.
        ctx.params = params
        return args

    @staticmethod
    def backward(ctx, *grads):
        recorded_params = ctx.params
        ColoParamOpHookManager._trigger_post_backward(recorded_params)
        # The leading ``None`` is the gradient slot of the ``params`` input.
        return (None, *grads)
class PostFwdPreBwd(torch.autograd.Function):
    """Pass-through autograd function that fires the pre-backward hooks.

    ``forward`` hands its tensor arguments back untouched; ``backward`` calls
    ``ColoParamOpHookManager._trigger_pre_backward`` with the parameters
    recorded during ``forward`` and forwards the incoming gradients.
    """

    @staticmethod
    def forward(ctx, params, *args):
        # Remember the parameters so ``backward`` can hand them to the hooks.
        ctx.params = params
        return args

    @staticmethod
    def backward(ctx, *grads):
        recorded_params = ctx.params
        ColoParamOpHookManager._trigger_pre_backward(recorded_params)
        # The leading ``None`` is the gradient slot of the ``params`` input.
        return (None, *grads)
def _is_grad_tensor(obj) -> bool:
if torch.is_tensor(obj):
if obj.grad_fn is not None or obj.requires_grad:
return True
return False
[gemini] improve compatibility and add static placement policy (#4479) * [gemini] remove distributed-related part from colotensor (#4379) * [gemini] remove process group dependency * [gemini] remove tp part from colo tensor * [gemini] patch inplace op * [gemini] fix param op hook and update tests * [test] remove useless tests * [test] remove useless tests * [misc] fix requirements * [test] fix model zoo * [test] fix model zoo * [test] fix model zoo * [test] fix model zoo * [test] fix model zoo * [misc] update requirements * [gemini] refactor gemini optimizer and gemini ddp (#4398) * [gemini] update optimizer interface * [gemini] renaming gemini optimizer * [gemini] refactor gemini ddp class * [example] update gemini related example * [example] update gemini related example * [plugin] fix gemini plugin args * [test] update gemini ckpt tests * [gemini] fix checkpoint io * [example] fix opt example requirements * [example] fix opt example * [example] fix opt example * [example] fix opt example * [gemini] add static placement policy (#4443) * [gemini] add static placement policy * [gemini] fix param offload * [test] update gemini tests * [plugin] update gemini plugin * [plugin] update gemini plugin docstr * [misc] fix flash attn requirement * [test] fix gemini checkpoint io test * [example] update resnet example result (#4457) * [example] update bert example result (#4458) * [doc] update gemini doc (#4468) * [example] update gemini related examples (#4473) * [example] update gpt example * [example] update dreambooth example * [example] update vit * [example] update opt * [example] update palm * [example] update vit and opt benchmark * [hotfix] fix bert in model zoo (#4480) * [hotfix] fix bert in model zoo * [test] remove chatglm gemini test * [test] remove sam gemini test * [test] remove vit gemini test * [hotfix] fix opt tutorial example (#4497) * [hotfix] fix opt tutorial example * [hotfix] fix opt tutorial example
2023-08-24 01:29:25 +00:00
def _flatten_grad_args(args) -> Tuple[list, list, List[bool], TreeSpec]:
    """Flatten ``args`` and split its leaves by whether they carry gradient.

    Args:
        args: An arbitrarily nested structure accepted by ``tree_flatten``.

    Returns:
        A 4-tuple ``(grad_args, other_args, grad_flags, spec)``:

        * ``grad_args``: leaves for which ``_is_grad_tensor`` is true, in
          flattening order.
        * ``other_args``: all remaining leaves, in flattening order.
        * ``grad_flags``: one boolean per leaf, ``True`` where the leaf went
          into ``grad_args``.
        * ``spec``: the tree spec returned by ``tree_flatten``.
    """
    leaves, spec = tree_flatten(args)
    # One flag per leaf; the flags are what lets the two lists be merged back.
    grad_flags = [_is_grad_tensor(leaf) for leaf in leaves]
    grad_args = [leaf for leaf, is_grad in zip(leaves, grad_flags) if is_grad]
    other_args = [leaf for leaf, is_grad in zip(leaves, grad_flags) if not is_grad]
    return grad_args, other_args, grad_flags, spec
def _merge_args(grad_args, other_args, grad_flags, spec):
grad_iter = iter(grad_args)
other_iter = iter(other_args)
flat_args = [next(grad_iter) if flag else next(other_iter) for flag in grad_flags]
return tree_unflatten(flat_args, spec)