From 2dd01e3a1430f223b9ef8e61b73cf17f60fccb07 Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Sun, 4 Feb 2024 11:58:26 +0800
Subject: [PATCH] [gemini] fix param op hook when output is tuple (#5355)

* [gemini] fix param op hook when output is tuple

* [gemini] fix param op hook
---
 colossalai/tensor/colo_parameter.py | 5 +++--
 colossalai/tensor/param_op_hook.py  | 8 +++++---
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/colossalai/tensor/colo_parameter.py b/colossalai/tensor/colo_parameter.py
index 5301c87b9..acb9fc4ae 100644
--- a/colossalai/tensor/colo_parameter.py
+++ b/colossalai/tensor/colo_parameter.py
@@ -7,11 +7,12 @@ from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 
 from .colo_tensor import _convert_output
 
-WHITE_LIST_FUNCS = {torch.Tensor.__getitem__, torch.Tensor.is_floating_point}
+WHITE_LIST_FUNCS = {torch.Tensor.__getitem__}
+NO_HOOK_FUNCS = {torch.Tensor.is_floating_point}
 
 
 def is_no_hook_op(func) -> bool:
-    return func.__name__.startswith("__") and func not in WHITE_LIST_FUNCS
+    return (func.__name__.startswith("__") and func not in WHITE_LIST_FUNCS) or func in NO_HOOK_FUNCS
 
 
 def filter_colo_parameters(*args, **kwargs):
diff --git a/colossalai/tensor/param_op_hook.py b/colossalai/tensor/param_op_hook.py
index 1fe99cd89..40de43c43 100644
--- a/colossalai/tensor/param_op_hook.py
+++ b/colossalai/tensor/param_op_hook.py
@@ -92,7 +92,10 @@ class ColoParamOpHookManager:
     @staticmethod
     def post_op(params: List[torch.Tensor], arg: Any) -> Any:
         ColoParamOpHookManager._trigger_post_forward(params)
-        return PostFwdPreBwd.apply(params, arg)
+        # in case the output is a tuple, we have to flatten it
+        grad_args, other_args, grad_flags, spec = _flatten_grad_args(arg)
+        new_grad_args = PostFwdPreBwd.apply(params, *grad_args)
+        return _merge_args(new_grad_args, other_args, grad_flags, spec)
 
     @staticmethod
     def has_hook() -> bool:
@@ -113,7 +116,7 @@ class PreFwdPostBwd(torch.autograd.Function):
 
 class PostFwdPreBwd(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, params, args):
+    def forward(ctx, params, *args):
         ctx.params = params
         return args
 
@@ -142,7 +145,6 @@ def _flatten_grad_args(args) -> Tuple[list, list, List[bool], TreeSpec]:
             grad_args.append(arg)
         else:
             other_args.append(arg)
-    assert len(grad_args) > 0
     return grad_args, other_args, grad_flags, spec
 
 
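
Note on why the post_op change flattens its argument: torch.autograd.Function.apply only builds gradient history for tensors passed and returned as top-level positional arguments, so a tuple output has to be flattened before it is handed to PostFwdPreBwd and reassembled afterwards. The following is a minimal, self-contained sketch of that flatten/apply/merge pattern, not the ColossalAI implementation; _PassThrough and post_op_sketch are hypothetical names, and it uses torch.utils._pytree the same way _flatten_grad_args and _merge_args do.

    # Illustrative sketch only; assumes PyTorch with torch.utils._pytree available.
    import torch
    from torch.utils._pytree import tree_flatten, tree_unflatten


    class _PassThrough(torch.autograd.Function):
        # Identity autograd.Function. apply() only tracks gradients for tensors
        # passed/returned at the top level, which is why the patched hook unpacks
        # the flattened tensors with *grad_args instead of passing a tuple.
        @staticmethod
        def forward(ctx, *args):
            return args

        @staticmethod
        def backward(ctx, *grads):
            return grads


    def post_op_sketch(output):
        # Flatten the (possibly tuple) output into leaves plus a TreeSpec.
        leaves, spec = tree_flatten(output)
        grad_flags = [isinstance(x, torch.Tensor) and x.requires_grad for x in leaves]
        grad_leaves = [x for x, f in zip(leaves, grad_flags) if f]
        other_leaves = [x for x, f in zip(leaves, grad_flags) if not f]

        # Run only the grad-requiring tensors through the autograd.Function.
        new_grad_leaves = list(_PassThrough.apply(*grad_leaves))

        # Re-interleave processed and untouched leaves in their original order,
        # then rebuild the original container structure.
        merged, grad_it, other_it = [], iter(new_grad_leaves), iter(other_leaves)
        for f in grad_flags:
            merged.append(next(grad_it) if f else next(other_it))
        return tree_unflatten(merged, spec)


    if __name__ == "__main__":
        x = torch.randn(2, requires_grad=True)
        out = post_op_sketch((x * 2, "metadata", 3))
        print(out)  # tuple structure preserved; the tensor element keeps its grad history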