From 24586599193d4f18fbaf66174f6fd669d59bc9d2 Mon Sep 17 00:00:00 2001
From: HELSON
Date: Mon, 26 Dec 2022 15:03:54 +0800
Subject: [PATCH] [zero] fix error for BEiT models (#2169)

* [zero] fix error for BEiT models

* [ColoParameter] add unpack operation for tuple arguments

* fix bugs

* fix chunkv2 unit testing

* add assertion for gradient state
---
 colossalai/gemini/chunk/chunk.py          |  6 +--
 colossalai/nn/_ops/linear.py              | 26 ++++++-------
 colossalai/nn/parallel/data_parallel.py   |  7 +++-
 colossalai/tensor/colo_parameter.py       | 23 +++++++++--
 colossalai/tensor/colo_tensor.py          |  2 +-
 colossalai/tensor/param_op_hook.py        | 49 +++++++++++++++++++-----
 tests/test_gemini/update/test_chunkv2.py  |  1 +
 7 files changed, 82 insertions(+), 32 deletions(-)

diff --git a/colossalai/gemini/chunk/chunk.py b/colossalai/gemini/chunk/chunk.py
index a0b274197..a7682eaf6 100644
--- a/colossalai/gemini/chunk/chunk.py
+++ b/colossalai/gemini/chunk/chunk.py
@@ -18,9 +18,9 @@ class TensorState(Enum):
 
 
 STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE),
-               (TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE),
-               (TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
-               (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
+               (TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE), (TensorState.COMPUTE,
+                                                                                               TensorState.HOLD),
+               (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
                (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), (TensorState.READY_FOR_REDUCE,
                                                                             TensorState.HOLD))
 
diff --git a/colossalai/nn/_ops/linear.py b/colossalai/nn/_ops/linear.py
index 8835574de..2f2088c61 100644
--- a/colossalai/nn/_ops/linear.py
+++ b/colossalai/nn/_ops/linear.py
@@ -1,11 +1,13 @@
-import torch.nn.functional as F
-from typing import Optional
-from ._utils import GeneralTensor, convert_to_colo_tensor
-from colossalai.tensor.op_wrapper import colo_op_impl
-from ._utils import reduce_input, reduce_grad
-from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ReplicaSpec, ColoTensorSpec
-from colossalai.tensor.sharding_spec import ShardingSpec
 from copy import deepcopy
+from typing import Optional
+
+import torch.nn.functional as F
+
+from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec
+from colossalai.tensor.op_wrapper import colo_op_impl
+from colossalai.tensor.sharding_spec import ShardingSpec
+
+from ._utils import GeneralTensor, convert_to_colo_tensor, reduce_grad, reduce_input
 
 
 def colo_linear_1drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
@@ -155,17 +157,15 @@ def _new_colo_linear_imp(input_tensor: GeneralTensor,
 
 def _has_sharding_spec(tensor):
     """
-    A tentative function to check whether the tensor is using the new sharding spec API. We assume that the sharding spec object is 
+    A tentative function to check whether the tensor is using the new sharding spec API. We assume that the sharding spec object is
     set as the attribute `sharding_spec` on a tensor.
""" return hasattr(tensor, 'sharding_spec') @colo_op_impl(F.linear) -def colo_linear(input_tensor: GeneralTensor, - weight: GeneralTensor, - bias: Optional[GeneralTensor] = None) -> 'ColoTensor': +def colo_linear(input: GeneralTensor, weight: GeneralTensor, bias: Optional[GeneralTensor] = None) -> 'ColoTensor': if _has_sharding_spec(weight): - return _new_colo_linear_imp(input_tensor, weight, bias) + return _new_colo_linear_imp(input, weight, bias) else: - return colo_linear_imp(input_tensor, weight, bias) + return colo_linear_imp(input, weight, bias) diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py index 54f6eb9b7..8bd91050f 100644 --- a/colossalai/nn/parallel/data_parallel.py +++ b/colossalai/nn/parallel/data_parallel.py @@ -283,7 +283,9 @@ class ZeroDDP(ColoDDP): p.grad = None def _post_backward(self): - assert self.chunk_manager.accessed_mem == 0 + if self.chunk_manager.accessed_mem != 0: + raise RuntimeError("ZERO DDP error: the synchronization of gradients doesn't exit properly.", + "The most possible reason is that the model is not compatible with ZeroDDP.") self._setup_grads_ptr() self._logger.debug( f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}' @@ -304,8 +306,9 @@ class ZeroDDP(ColoDDP): empty_grad = torch.empty_like(grad) free_storage(empty_grad) with torch._C.DisableTorchFunction(): - self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE) chunk = self.chunk_manager.get_chunk(p) + assert chunk.tensors_info[p].state == TensorState.HOLD_AFTER_BWD + self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE) chunk.copy_tensor_to_chunk_slice(p, grad) reduced = self.chunk_manager.reduce_chunk(chunk) if reduced: diff --git a/colossalai/tensor/colo_parameter.py b/colossalai/tensor/colo_parameter.py index 3e4c8ce69..92220d9e2 100644 --- a/colossalai/tensor/colo_parameter.py +++ b/colossalai/tensor/colo_parameter.py @@ -8,8 +8,25 @@ from colossalai.tensor.param_op_hook import ColoParamOpHookManager from colossalai.tensor.tensor_spec import ColoTensorSpec -def filter_args(func, *args): - return [arg for arg in args if func(arg)] +def filter_colo_parameters(*args, **kwargs): + param_list = [] + + def get_colo_parameters(element) -> None: + if isinstance(element, list) or isinstance(element, tuple): + for e in element: + get_colo_parameters(e) + elif isinstance(element, dict): + raise RuntimeError("Found Dict: ColoParameter can't deal with complicated arguments.") + elif isinstance(element, ColoParameter): + param_list.append(element) + return + + for a in args: + get_colo_parameters(a) + for v in kwargs.values(): + get_colo_parameters(v) + + return param_list def replace_args(args, kwargs, new_args): @@ -62,7 +79,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter): if not func.__name__.startswith('__'): if kwargs is None: kwargs = {} - params = filter_args(lambda arg: isinstance(arg, ColoParameter), *args, *kwargs.values()) + params = filter_colo_parameters(*args, **kwargs) if len(params) > 0: with torch._C.DisableTorchFunction(): new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values()) diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py index 7ecb407b5..670c210e3 100644 --- a/colossalai/tensor/colo_tensor.py +++ b/colossalai/tensor/colo_tensor.py @@ 
@@ -57,7 +57,7 @@ class ColoTensor(torch.Tensor):
     The Colotensor can be initialized with a PyTorch tensor in the following ways.
 
         >>> pg = ProcessGroup()
-        >>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec())
+        >>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec()))
         >>> # The tensor passed in is a tensor after sharding but not a global tensor.
         >>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size),
         >>>                        dims=[0],
diff --git a/colossalai/tensor/param_op_hook.py b/colossalai/tensor/param_op_hook.py
index 2320d98bc..7c73bc220 100644
--- a/colossalai/tensor/param_op_hook.py
+++ b/colossalai/tensor/param_op_hook.py
@@ -82,16 +82,26 @@ class ColoParamOpHookManager:
     @staticmethod
     def pre_op(params: List[torch.Tensor], *args: Any) -> list:
         ColoParamOpHookManager._trigger_pre_forward(params)
-        args_info = _get_colo_tensors_info(*args)
-        rets = PreFwdPostBwd.apply(params, *args)
-        return _update_colo_tensors(args_info, *rets)
+        grad_args, rear_args = _get_grad_args(*args)
+        colo_info = _get_colo_tensors_info(*grad_args)
+        rets = PreFwdPostBwd.apply(params, *grad_args)
+        update_args = _update_colo_tensors(colo_info, *rets)
+        if rear_args is None:
+            return update_args
+        else:
+            arg_zero = (tuple(update_args),)
+            return arg_zero + rear_args
 
     @staticmethod
     def post_op(params: List[torch.Tensor], arg: Any) -> Any:
         ColoParamOpHookManager._trigger_post_forward(params)
-        arg_info = _get_colo_tensors_info(arg)
+        colo_info = _get_colo_tensors_info(arg)
         ret = PostFwdPreBwd.apply(params, arg)
-        return _unpack_args(_update_colo_tensors(arg_info, ret))
+        res = _update_colo_tensors(colo_info, ret)
+        if len(res) == 1:
+            return res[0]
+        else:
+            return res
 
     @staticmethod
     def has_hook() -> bool:
@@ -103,7 +113,7 @@ class PreFwdPostBwd(torch.autograd.Function):
     @staticmethod
     def forward(ctx, params, *args):
         ctx.params = params
-        return _unpack_args(args)
+        return args
 
     @staticmethod
     def backward(ctx, *grads):
@@ -124,10 +134,29 @@ class PostFwdPreBwd(torch.autograd.Function):
         return (None,) + grads
 
 
-def _unpack_args(args):
-    if len(args) == 1:
-        return args[0]
-    return args
+def _is_grad_tensor(obj) -> bool:
+    if torch.is_tensor(obj):
+        if obj.grad_fn is not None or obj.requires_grad:
+            return True
+    return False
+
+
+def _get_grad_args(*args):
+    # returns the identical args if there is a grad tensor
+    for obj in args:
+        if _is_grad_tensor(obj):
+            return args, None
+    # otherwise, the first argument should be a tuple of grad tensors
+    # if there is no grad tensor, the backward of PreFwdPostBwd can't be triggered
+    arg_zero = args[0]
+    if not isinstance(arg_zero, tuple):
+        raise NotImplementedError("Some torch function is incompatible because of its complicated inputs.")
+    check_grad_flag = False
+    for obj in arg_zero:
+        check_grad_flag |= _is_grad_tensor(obj)
+    if not check_grad_flag:
+        raise NotImplementedError("Some torch function is incompatible because of its complicated inputs.")
+    return arg_zero, args[1:]
 
 
 def _get_colo_tensors_info(*args) -> list:
diff --git a/tests/test_gemini/update/test_chunkv2.py b/tests/test_gemini/update/test_chunkv2.py
index 48cae94e1..96855410b 100644
--- a/tests/test_gemini/update/test_chunkv2.py
+++ b/tests/test_gemini/update/test_chunkv2.py
@@ -90,6 +90,7 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory):
 
     for param in param_list:
         my_chunk.tensor_trans_state(param, TensorState.COMPUTE)
+        my_chunk.tensor_trans_state(param, TensorState.HOLD_AFTER_BWD)
         my_chunk.tensor_trans_state(param, TensorState.READY_FOR_REDUCE)
 
     assert my_chunk.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == 4
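
A minimal sketch (not part of the patch) of the tuple-unpacking behaviour that `_get_grad_args` introduces in param_op_hook.py: when no grad tensor appears among the top-level arguments, the first argument is treated as a tuple that holds them. Using `torch.cat`-style inputs as the motivating case is an assumption here, not something named by the commit; the helper bodies below mirror the patched code but are simplified and standalone.

    # Standalone illustration of the unpacking rule added in param_op_hook.py.
    # Helper names mirror the patch; torch.cat as the example op is an assumption.
    import torch


    def _is_grad_tensor(obj) -> bool:
        # a tensor that participates in autograd, mirroring the patched helper
        return torch.is_tensor(obj) and (obj.grad_fn is not None or obj.requires_grad)


    def _get_grad_args(*args):
        # return args unchanged if a grad tensor already appears at the top level
        if any(_is_grad_tensor(obj) for obj in args):
            return args, None
        # otherwise the first argument must be a tuple holding the grad tensors,
        # e.g. the tensor sequence handed to torch.cat
        arg_zero = args[0]
        if not isinstance(arg_zero, tuple) or not any(_is_grad_tensor(obj) for obj in arg_zero):
            raise NotImplementedError("incompatible inputs: no grad tensor found")
        return arg_zero, args[1:]


    if __name__ == '__main__':
        w1 = torch.randn(2, 2, requires_grad=True)
        w2 = torch.randn(2, 2, requires_grad=True)

        # F.linear-style call: grad tensors are top-level, nothing to unpack
        print(_get_grad_args(torch.randn(2, 2), w1, None)[1])    # None

        # torch.cat-style call: grad tensors are packed inside a tuple argument
        grad_args, rear_args = _get_grad_args((w1, w2), 0)
        print(len(grad_args), rear_args)                         # 2 (0,)

This is also why the patched `pre_op` returns `(tuple(update_args),) + rear_args` when it had to unpack: the processed tensors are re-packed into the tuple slot so the original call signature of the wrapped torch function is preserved.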