[zero] fix error for BEiT models (#2169)

* [zero] fix error for BEiT models

* [ColoParameter] add unpack operation for tuple arguments

* fix bugs

* fix chunkv2 unit testing

* add assertion for gradient state
HELSON 2 years ago committed by GitHub
parent 4363ff3e41
commit 2458659919

@@ -18,9 +18,9 @@ class TensorState(Enum):
-STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE),
-               (TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE),
-               (TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
-               (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
-               (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), (TensorState.READY_FOR_REDUCE,
-                                                                            TensorState.HOLD))
+STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE),
+               (TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE), (TensorState.COMPUTE,
+                                                                                               TensorState.HOLD),
+               (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
+               (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), (TensorState.READY_FOR_REDUCE,
+                                                                            TensorState.HOLD))
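To see what the narrowed table enforces, here is a minimal standalone sketch (illustrative only, not the library's Chunk code): with the direct COMPUTE -> READY_FOR_REDUCE pair removed, a tensor now has to pass through HOLD_AFTER_BWD before it can be marked ready for reduction.

    from enum import Enum, auto

    class TensorState(Enum):
        FREE = auto()
        COMPUTE = auto()
        HOLD = auto()
        HOLD_AFTER_BWD = auto()
        READY_FOR_REDUCE = auto()

    # Updated transition table: no direct COMPUTE -> READY_FOR_REDUCE edge.
    STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE),
                   (TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE),
                   (TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
                   (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
                   (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE),
                   (TensorState.READY_FOR_REDUCE, TensorState.HOLD))

    def trans(cur: TensorState, nxt: TensorState) -> TensorState:
        # Guard used here only for illustration: reject transitions not in the table.
        assert (cur, nxt) in STATE_TRANS, f"illegal transition {cur.name} -> {nxt.name}"
        return nxt

    state = TensorState.COMPUTE
    state = trans(state, TensorState.HOLD_AFTER_BWD)      # allowed
    state = trans(state, TensorState.READY_FOR_REDUCE)    # allowed
    # trans(TensorState.COMPUTE, TensorState.READY_FOR_REDUCE) would now raise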

@@ -1,11 +1,13 @@
-import torch.nn.functional as F
+from copy import deepcopy
 from typing import Optional
-from ._utils import GeneralTensor, convert_to_colo_tensor
+
+import torch.nn.functional as F
+
+from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec
 from colossalai.tensor.op_wrapper import colo_op_impl
-from ._utils import reduce_input, reduce_grad
-from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ReplicaSpec, ColoTensorSpec
 from colossalai.tensor.sharding_spec import ShardingSpec
-from copy import deepcopy
+
+from ._utils import GeneralTensor, convert_to_colo_tensor, reduce_grad, reduce_input


 def colo_linear_1drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
@@ -155,17 +157,15 @@ def _new_colo_linear_imp(input_tensor: GeneralTensor,
 def _has_sharding_spec(tensor):
     """
     A tentative function to check whether the tensor is using the new sharding spec API. We assume that the sharding spec object is
    set as the attribute `sharding_spec` on a tensor.
     """
     return hasattr(tensor, 'sharding_spec')


 @colo_op_impl(F.linear)
-def colo_linear(input_tensor: GeneralTensor,
-                weight: GeneralTensor,
-                bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
+def colo_linear(input: GeneralTensor, weight: GeneralTensor, bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
     if _has_sharding_spec(weight):
-        return _new_colo_linear_imp(input_tensor, weight, bias)
+        return _new_colo_linear_imp(input, weight, bias)
     else:
-        return colo_linear_imp(input_tensor, weight, bias)
+        return colo_linear_imp(input, weight, bias)
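The rename from `input_tensor` to `input` presumably matters because the dispatched implementation ends up being called with the caller's original keyword arguments, so its parameter names have to mirror `F.linear`'s own signature (`input`, `weight`, `bias`). A hypothetical sketch of that constraint (`colo_linear_sketch` is not a library function):

    import torch
    import torch.nn.functional as F

    def colo_linear_sketch(input, weight, bias=None):
        # Hypothetical stand-in for the registered implementation. If the wrapper
        # forwards the original keyword arguments, a call like F.linear(input=x, weight=w)
        # only resolves when the parameter names match F.linear's signature.
        return F.linear(input, weight, bias)

    x, w = torch.randn(4, 8), torch.randn(16, 8)
    kwargs = {'input': x, 'weight': w}
    out = colo_linear_sketch(**kwargs)    # works; a parameter named input_tensor would raise TypeError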

@@ -283,7 +283,9 @@ class ZeroDDP(ColoDDP):
             p.grad = None

     def _post_backward(self):
-        assert self.chunk_manager.accessed_mem == 0
+        if self.chunk_manager.accessed_mem != 0:
+            raise RuntimeError("ZERO DDP error: the synchronization of gradients doesn't exit properly.",
+                               "The most possible reason is that the model is not compatible with ZeroDDP.")
         self._setup_grads_ptr()
         self._logger.debug(
             f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}'
@@ -304,8 +306,9 @@ class ZeroDDP(ColoDDP):
             empty_grad = torch.empty_like(grad)
             free_storage(empty_grad)
             with torch._C.DisableTorchFunction():
-                self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
                 chunk = self.chunk_manager.get_chunk(p)
+                assert chunk.tensors_info[p].state == TensorState.HOLD_AFTER_BWD
+                self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
                 chunk.copy_tensor_to_chunk_slice(p, grad)
                 reduced = self.chunk_manager.reduce_chunk(chunk)
                 if reduced:

@@ -8,8 +8,25 @@ from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.tensor.tensor_spec import ColoTensorSpec


-def filter_args(func, *args):
-    return [arg for arg in args if func(arg)]
+def filter_colo_parameters(*args, **kwargs):
+    param_list = []
+
+    def get_colo_parameters(element) -> None:
+        if isinstance(element, list) or isinstance(element, tuple):
+            for e in element:
+                get_colo_parameters(e)
+        elif isinstance(element, dict):
+            raise RuntimeError("Found Dict: ColoParameter can't deal with complicated arguments.")
+        elif isinstance(element, ColoParameter):
+            param_list.append(element)
+        return
+
+    for a in args:
+        get_colo_parameters(a)
+    for v in kwargs.values():
+        get_colo_parameters(v)
+
+    return param_list


 def replace_args(args, kwargs, new_args):
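The recursive walk is what picks up parameters that arrive nested inside tuples or lists, which the old top-level `filter_args` missed. A standalone sketch of the same logic, using `torch.nn.Parameter` as a stand-in for `ColoParameter` so it runs without the library:

    import torch

    def collect_params_sketch(*args, **kwargs):
        # Stand-in for filter_colo_parameters: same recursion as the diff above,
        # but collecting torch.nn.Parameter so the sketch is self-contained.
        found = []

        def visit(element):
            if isinstance(element, (list, tuple)):
                for e in element:
                    visit(e)
            elif isinstance(element, dict):
                raise RuntimeError("dict arguments are not supported")
            elif isinstance(element, torch.nn.Parameter):
                found.append(element)

        for a in args:
            visit(a)
        for v in kwargs.values():
            visit(v)
        return found

    p = torch.nn.Parameter(torch.randn(3))
    # A parameter hidden inside a tuple is still found.
    found = collect_params_sketch((p, torch.randn(3)), scale=2.0)
    assert len(found) == 1 and found[0] is p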
@@ -62,7 +79,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
         if not func.__name__.startswith('__'):
             if kwargs is None:
                 kwargs = {}
-            params = filter_args(lambda arg: isinstance(arg, ColoParameter), *args, *kwargs.values())
+            params = filter_colo_parameters(*args, **kwargs)
             if len(params) > 0:
                 with torch._C.DisableTorchFunction():
                     new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())

@ -57,7 +57,7 @@ class ColoTensor(torch.Tensor):
The Colotensor can be initialized with a PyTorch tensor in the following ways. The Colotensor can be initialized with a PyTorch tensor in the following ways.
>>> pg = ProcessGroup() >>> pg = ProcessGroup()
>>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec()) >>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec()))
>>> # The tensor passed in is a tensor after sharding but not a global tensor. >>> # The tensor passed in is a tensor after sharding but not a global tensor.
>>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size), >>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size),
>>> dims=[0], >>> dims=[0],

@@ -82,16 +82,26 @@ class ColoParamOpHookManager:
     @staticmethod
     def pre_op(params: List[torch.Tensor], *args: Any) -> list:
         ColoParamOpHookManager._trigger_pre_forward(params)
-        args_info = _get_colo_tensors_info(*args)
-        rets = PreFwdPostBwd.apply(params, *args)
-        return _update_colo_tensors(args_info, *rets)
+        grad_args, rear_args = _get_grad_args(*args)
+        colo_info = _get_colo_tensors_info(*grad_args)
+        rets = PreFwdPostBwd.apply(params, *grad_args)
+        update_args = _update_colo_tensors(colo_info, *rets)
+        if rear_args is None:
+            return update_args
+        else:
+            arg_zero = (tuple(update_args),)
+            return arg_zero + rear_args

     @staticmethod
     def post_op(params: List[torch.Tensor], arg: Any) -> Any:
         ColoParamOpHookManager._trigger_post_forward(params)
-        arg_info = _get_colo_tensors_info(arg)
+        colo_info = _get_colo_tensors_info(arg)
         ret = PostFwdPreBwd.apply(params, arg)
-        return _unpack_args(_update_colo_tensors(arg_info, ret))
+        res = _update_colo_tensors(colo_info, ret)
+        if len(res) == 1:
+            return res[0]
+        else:
+            return res

     @staticmethod
     def has_hook() -> bool:
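The extra unpack/repack in `pre_op` exists because `torch.autograd.Function` only records a backward node when at least one of its direct arguments is a tensor that requires grad; tensors hidden inside a tuple are invisible to it, so the pre-backward hook would never fire. A minimal standalone demonstration of that rule (`Probe` is illustrative, not the library's `PreFwdPostBwd`):

    import torch

    class Probe(torch.autograd.Function):
        # Pass-through function that reports when its backward is reached.

        @staticmethod
        def forward(ctx, *args):
            return args

        @staticmethod
        def backward(ctx, *grads):
            print("Probe.backward fired")
            return grads

    x = torch.randn(2, 2, requires_grad=True)
    y = torch.randn(2, 2, requires_grad=True)

    # Tensors passed as direct arguments: a graph node is created, backward fires.
    a, b = Probe.apply(x, y)
    (a.sum() + b.sum()).backward()              # prints "Probe.backward fired"

    # The same tensors hidden inside a tuple: autograd sees no tensor argument,
    # so no node is created and Probe.backward stays silent, even though
    # gradients still flow into x and y through the later sum.
    pair = Probe.apply((x, y))[0]
    (pair[0].sum() + pair[1].sum()).backward()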
@@ -103,7 +113,7 @@ class PreFwdPostBwd(torch.autograd.Function):
     @staticmethod
     def forward(ctx, params, *args):
         ctx.params = params
-        return _unpack_args(args)
+        return args

     @staticmethod
     def backward(ctx, *grads):
@@ -124,10 +134,29 @@ class PostFwdPreBwd(torch.autograd.Function):
         return (None,) + grads


-def _unpack_args(args):
-    if len(args) == 1:
-        return args[0]
-    return args
+def _is_grad_tensor(obj) -> bool:
+    if torch.is_tensor(obj):
+        if obj.grad_fn is not None or obj.requires_grad:
+            return True
+    return False
+
+
+def _get_grad_args(*args):
+    # returns the identical args if there is a grad tensor
+    for obj in args:
+        if _is_grad_tensor(obj):
+            return args, None
+    # otherwise, the first argument should be a tuple of grad tensors
+    # if there is no grad tensor, the backward of PreFwdPostBwd can't be triggered
+    arg_zero = args[0]
+    if not isinstance(arg_zero, tuple):
+        raise NotImplementedError("Some torch function is incompatible because of its complicated inputs.")
+    check_grad_flag = False
+    for obj in arg_zero:
+        check_grad_flag |= _is_grad_tensor(obj)
+    if not check_grad_flag:
+        raise NotImplementedError("Some torch function is incompatible because of its complicated inputs.")
+    return arg_zero, args[1:]


 def _get_colo_tensors_info(*args) -> list:
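Assuming `_get_grad_args` from the hunk above is copied into a Python session (it is a private helper of the param-op hook module shown here), its two return shapes look like this:

    import torch

    a = torch.randn(2, requires_grad=True)
    b = torch.randn(2, requires_grad=True)

    # A grad tensor already sits at the top level: args are returned unchanged.
    grad_args, rear_args = _get_grad_args(a, b)        # ((a, b), None)

    # No top-level grad tensor, but args[0] is a tuple of grad tensors (the
    # BEiT-style case): the tuple is unpacked and the remaining args kept aside.
    grad_args, rear_args = _get_grad_args((a, b), 0)   # ((a, b), (0,))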

@@ -90,6 +90,7 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory):
     for param in param_list:
         my_chunk.tensor_trans_state(param, TensorState.COMPUTE)
+        my_chunk.tensor_trans_state(param, TensorState.HOLD_AFTER_BWD)
         my_chunk.tensor_trans_state(param, TensorState.READY_FOR_REDUCE)

     assert my_chunk.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == 4
