mirror of https://github.com/hpcaitech/ColossalAI
[zero] fix error for BEiT models (#2169)
* [zero] fix error for BEiT models
* [ColoParameter] add unpack operation for tuple arguments
* fix bugs
* fix chunkv2 unit testing
* add assertion for gradient state
parent 4363ff3e41
commit 2458659919
@@ -18,9 +18,9 @@ class TensorState(Enum):
 STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE),
-               (TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE),
-               (TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
-               (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
+               (TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE), (TensorState.COMPUTE,
+                                                                                               TensorState.HOLD),
+               (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
                (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), (TensorState.READY_FOR_REDUCE,
                                                                             TensorState.HOLD))
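Note: this hunk removes the direct COMPUTE -> READY_FOR_REDUCE transition, so a tensor now has to pass through HOLD_AFTER_BWD before it can be reduced. A minimal, self-contained sketch (an illustration, not repo code) of how such a transition table gates state changes:

from enum import Enum

class TensorState(Enum):
    FREE = 0
    COMPUTE = 1
    HOLD = 2
    HOLD_AFTER_BWD = 3
    READY_FOR_REDUCE = 4

# the post-commit table, flattened for readability
STATE_TRANS = (
    (TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE),
    (TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE),
    (TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
    (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE), (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE),
    (TensorState.READY_FOR_REDUCE, TensorState.HOLD),
)

def trans_state(cur: TensorState, nxt: TensorState) -> TensorState:
    # reject transitions outside the table, e.g. COMPUTE -> READY_FOR_REDUCE
    if (cur, nxt) not in STATE_TRANS:
        raise RuntimeError(f"illegal transition: {cur.name} -> {nxt.name}")
    return nxt

assert trans_state(TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD)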
@@ -1,11 +1,13 @@
-import torch.nn.functional as F
-from typing import Optional
-from ._utils import GeneralTensor, convert_to_colo_tensor
-from colossalai.tensor.op_wrapper import colo_op_impl
-from ._utils import reduce_input, reduce_grad
-from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ReplicaSpec, ColoTensorSpec
-from colossalai.tensor.sharding_spec import ShardingSpec
 from copy import deepcopy
+from typing import Optional
+
+import torch.nn.functional as F
+
+from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec
+from colossalai.tensor.op_wrapper import colo_op_impl
+from colossalai.tensor.sharding_spec import ShardingSpec
+
+from ._utils import GeneralTensor, convert_to_colo_tensor, reduce_grad, reduce_input


 def colo_linear_1drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
@@ -162,10 +164,8 @@ def _has_sharding_spec(tensor):

 @colo_op_impl(F.linear)
-def colo_linear(input_tensor: GeneralTensor,
-                weight: GeneralTensor,
-                bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
+def colo_linear(input: GeneralTensor, weight: GeneralTensor, bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
     if _has_sharding_spec(weight):
-        return _new_colo_linear_imp(input_tensor, weight, bias)
+        return _new_colo_linear_imp(input, weight, bias)
     else:
-        return colo_linear_imp(input_tensor, weight, bias)
+        return colo_linear_imp(input, weight, bias)
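Note: the rename from input_tensor to input likely matters because torch.nn.functional.linear's own first parameter is named input, so keyword-style calls (as timm-style attention code tends to issue) bind correctly after dispatch. A plain-PyTorch illustration, no colossalai needed:

import torch
import torch.nn.functional as F

x = torch.randn(4, 8)
w = torch.randn(16, 8)
y = F.linear(input=x, weight=w)   # keyword call matches the new parameter name
assert y.shape == (4, 16)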
@@ -283,7 +283,9 @@ class ZeroDDP(ColoDDP):
             p.grad = None

     def _post_backward(self):
-        assert self.chunk_manager.accessed_mem == 0
+        if self.chunk_manager.accessed_mem != 0:
+            raise RuntimeError("ZERO DDP error: the synchronization of gradients doesn't exit properly.",
+                               "The most possible reason is that the model is not compatible with ZeroDDP.")
         self._setup_grads_ptr()
         self._logger.debug(
             f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}'
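Note: replacing the bare assert has two practical effects: asserts vanish under python -O, and the exception can carry a user-facing hint. (The commit passes two separate strings to RuntimeError, which renders as a tuple; a single concatenated string, as in this sketch with a hypothetical helper name, prints more cleanly.)

def check_all_chunks_released(accessed_mem: int) -> None:
    # explicit raise survives `python -O`, unlike a bare assert
    if accessed_mem != 0:
        raise RuntimeError("gradient synchronization did not finish; "
                           "the model may be incompatible with ZeroDDP")

check_all_chunks_released(0)   # passes silently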
@@ -304,8 +306,9 @@ class ZeroDDP(ColoDDP):
             empty_grad = torch.empty_like(grad)
             free_storage(empty_grad)
             with torch._C.DisableTorchFunction():
-                self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
                 chunk = self.chunk_manager.get_chunk(p)
+                assert chunk.tensors_info[p].state == TensorState.HOLD_AFTER_BWD
+                self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
                 chunk.copy_tensor_to_chunk_slice(p, grad)
                 reduced = self.chunk_manager.reduce_chunk(chunk)
                 if reduced:
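Note: this is the "assertion for gradient state" from the commit message. The grad hook now requires the parameter to already be HOLD_AFTER_BWD (set when backward produced its gradient) before marking it READY_FOR_REDUCE, matching the tightened STATE_TRANS table. Assuming the TensorState/STATE_TRANS sketch above is in scope, the required backward path can be checked directly:

backward_path = [TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE]
for cur, nxt in zip(backward_path, backward_path[1:]):
    assert (cur, nxt) in STATE_TRANS, f"{cur.name} -> {nxt.name} not allowed"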
@@ -8,8 +8,25 @@ from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.tensor.tensor_spec import ColoTensorSpec


-def filter_args(func, *args):
-    return [arg for arg in args if func(arg)]
+def filter_colo_parameters(*args, **kwargs):
+    param_list = []
+
+    def get_colo_parameters(element) -> None:
+        if isinstance(element, list) or isinstance(element, tuple):
+            for e in element:
+                get_colo_parameters(e)
+        elif isinstance(element, dict):
+            raise RuntimeError("Found Dict: ColoParameter can't deal with complicated arguments.")
+        elif isinstance(element, ColoParameter):
+            param_list.append(element)
+        return
+
+    for a in args:
+        get_colo_parameters(a)
+    for v in kwargs.values():
+        get_colo_parameters(v)
+
+    return param_list


 def replace_args(args, kwargs, new_args):
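Note: the new collector recurses into lists and tuples, which the old flat filter_args did not. A likely motivating case for BEiT is torch.cat((q_bias, k_bias, v_bias)) in the attention block, where the parameters arrive wrapped in a tuple. A standalone mimic (torch.nn.Parameter standing in for ColoParameter) of the recursive walk:

import torch

def collect_params(*args, **kwargs):
    found = []
    def walk(element):
        if isinstance(element, (list, tuple)):
            for e in element:
                walk(e)
        elif isinstance(element, dict):
            raise RuntimeError("dict arguments are not supported")
        elif isinstance(element, torch.nn.Parameter):
            found.append(element)
    for a in args:
        walk(a)
    for v in kwargs.values():
        walk(v)
    return found

q_b, k_b, v_b = (torch.nn.Parameter(torch.zeros(4)) for _ in range(3))
# parameters hidden inside a tuple, as in torch.cat((q_b, k_b, v_b)), are found
out = collect_params((q_b, k_b, v_b))
assert len(out) == 3 and out[0] is q_b and out[2] is v_b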
@@ -62,7 +79,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
         if not func.__name__.startswith('__'):
             if kwargs is None:
                 kwargs = {}
-            params = filter_args(lambda arg: isinstance(arg, ColoParameter), *args, *kwargs.values())
+            params = filter_colo_parameters(*args, **kwargs)
             if len(params) > 0:
                 with torch._C.DisableTorchFunction():
                     new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())
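Note: this call-site change is the crux of the BEiT fix. The old lambda-based filter only inspected top-level arguments, so a ColoParameter nested in a tuple produced an empty params list and the op hook never fired. A quick standalone demonstration of the blind spot (torch.nn.Parameter standing in again):

import torch

def filter_args(func, *args):
    # the pre-commit filter: no recursion into containers
    return [arg for arg in args if func(arg)]

p = torch.nn.Parameter(torch.zeros(2))
top_level = filter_args(lambda a: isinstance(a, torch.nn.Parameter), p)
nested = filter_args(lambda a: isinstance(a, torch.nn.Parameter), (p,))
assert len(top_level) == 1 and len(nested) == 0   # the tuple case was missed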
@@ -57,7 +57,7 @@ class ColoTensor(torch.Tensor):
     The Colotensor can be initialized with a PyTorch tensor in the following ways.

         >>> pg = ProcessGroup()
-        >>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec())
+        >>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec()))
         >>> # The tensor passed in is a tensor after sharding but not a global tensor.
         >>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size),
         >>>                        dims=[0],
@@ -82,16 +82,26 @@ class ColoParamOpHookManager:
     @staticmethod
     def pre_op(params: List[torch.Tensor], *args: Any) -> list:
         ColoParamOpHookManager._trigger_pre_forward(params)
-        args_info = _get_colo_tensors_info(*args)
-        rets = PreFwdPostBwd.apply(params, *args)
-        return _update_colo_tensors(args_info, *rets)
+        grad_args, rear_args = _get_grad_args(*args)
+        colo_info = _get_colo_tensors_info(*grad_args)
+        rets = PreFwdPostBwd.apply(params, *grad_args)
+        update_args = _update_colo_tensors(colo_info, *rets)
+        if rear_args is None:
+            return update_args
+        else:
+            arg_zero = (tuple(update_args),)
+            return arg_zero + rear_args

     @staticmethod
     def post_op(params: List[torch.Tensor], arg: Any) -> Any:
         ColoParamOpHookManager._trigger_post_forward(params)
-        arg_info = _get_colo_tensors_info(arg)
+        colo_info = _get_colo_tensors_info(arg)
         ret = PostFwdPreBwd.apply(params, arg)
-        return _unpack_args(_update_colo_tensors(arg_info, ret))
+        res = _update_colo_tensors(colo_info, ret)
+        if len(res) == 1:
+            return res[0]
+        else:
+            return res

     @staticmethod
     def has_hook() -> bool:
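Note: in the rear_args branch, pre_op re-wraps the hook-processed tensors as a tuple and re-appends the untouched trailing arguments, so the caller sees the original call shape. A shape-only sketch with plain stand-ins:

# Stand-ins only: strings in place of tensors, to show the repacking shape.
update_args = ["t1*", "t2*"]      # hook-processed grad tensors
rear_args = (0,)                  # e.g. the trailing `dim` argument of torch.cat
new_args = (tuple(update_args),) + rear_args
assert new_args == (("t1*", "t2*"), 0)   # same shape as the original call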
@@ -103,7 +113,7 @@ class PreFwdPostBwd(torch.autograd.Function):
     @staticmethod
     def forward(ctx, params, *args):
         ctx.params = params
-        return _unpack_args(args)
+        return args

     @staticmethod
     def backward(ctx, *grads):
@@ -124,10 +134,29 @@ class PostFwdPreBwd(torch.autograd.Function):
         return (None,) + grads


-def _unpack_args(args):
-    if len(args) == 1:
-        return args[0]
-    return args
+def _is_grad_tensor(obj) -> bool:
+    if torch.is_tensor(obj):
+        if obj.grad_fn is not None or obj.requires_grad:
+            return True
+    return False
+
+
+def _get_grad_args(*args):
+    # returns the identical args if there is a grad tensor
+    for obj in args:
+        if _is_grad_tensor(obj):
+            return args, None
+    # otherwise, the first argument should be a tuple of grad tensors
+    # if there is no grad tensor, the backward of PreFwdPostBwd can't be triggered
+    arg_zero = args[0]
+    if not isinstance(arg_zero, tuple):
+        raise NotImplementedError("Some torch function is incompatible because of its complicated inputs.")
+    check_grad_flag = False
+    for obj in arg_zero:
+        check_grad_flag |= _is_grad_tensor(obj)
+    if not check_grad_flag:
+        raise NotImplementedError("Some torch function is incompatible because of its complicated inputs.")
+    return arg_zero, args[1:]


 def _get_colo_tensors_info(*args) -> list:
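Note: assuming the two helpers above are in scope (e.g. pasted into the same module), their contract can be exercised directly; the tuple branch mirrors calls like torch.cat((t1, t2), dim) where no top-level argument itself requires grad:

import torch

t1 = torch.randn(2, requires_grad=True)
t2 = torch.randn(2, requires_grad=True)

# linear-style call: a top-level grad tensor exists, args pass through unchanged
grad_args, rear_args = _get_grad_args(t1, 0)
assert rear_args is None and grad_args[0] is t1

# cat-style call: the grad tensors hide inside the first (tuple) argument
grad_args, rear_args = _get_grad_args((t1, t2), 0)
assert grad_args[0] is t1 and grad_args[1] is t2 and rear_args == (0,)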
@@ -90,6 +90,7 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory):

     for param in param_list:
         my_chunk.tensor_trans_state(param, TensorState.COMPUTE)
+        my_chunk.tensor_trans_state(param, TensorState.HOLD_AFTER_BWD)
         my_chunk.tensor_trans_state(param, TensorState.READY_FOR_REDUCE)

     assert my_chunk.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == 4