From 789cad301b74a164f83b48db4ca42bcc1e281c0a Mon Sep 17 00:00:00 2001 From: ver217 Date: Fri, 17 Jun 2022 16:12:05 +0800 Subject: [PATCH] [hotfix] fix param op hook (#1131) * fix param op hook * update zero tp test * fix bugs --- colossalai/tensor/colo_parameter.py | 8 +----- colossalai/tensor/param_op_hook.py | 43 +++++++++++++++++++++++----- tests/test_tensor/test_zero_optim.py | 43 ++++++++++++++++++++++++---- 3 files changed, 74 insertions(+), 20 deletions(-) diff --git a/colossalai/tensor/colo_parameter.py b/colossalai/tensor/colo_parameter.py index 02e7bc45e..54c044ebf 100644 --- a/colossalai/tensor/colo_parameter.py +++ b/colossalai/tensor/colo_parameter.py @@ -11,17 +11,11 @@ def filter_args(func, *args): return [arg for arg in args if func(arg)] -def unpack_args(*args): - if len(args) == 1: - return args[0] - return args - - def replace_args(args, kwargs, new_args): args = new_args[:len(args)] for k, v in zip(kwargs.keys(), new_args[len(args):]): kwargs[k] = v - return unpack_args(args), kwargs + return tuple(args), kwargs class ColoParameter(ColoTensor, torch.nn.Parameter): diff --git a/colossalai/tensor/param_op_hook.py b/colossalai/tensor/param_op_hook.py index 3741dbf67..fee6a0a6b 100644 --- a/colossalai/tensor/param_op_hook.py +++ b/colossalai/tensor/param_op_hook.py @@ -2,6 +2,7 @@ import torch from contextlib import contextmanager from abc import ABC, abstractmethod from typing import List, Tuple, Any +from colossalai.tensor.colo_tensor import ColoTensor class ParamOpHook(ABC): @@ -74,14 +75,18 @@ class ParamOpHookManager: hook.post_backward(params) @staticmethod - def pre_op(params: List[torch.Tensor], *args: Any) -> Any: + def pre_op(params: List[torch.Tensor], *args: Any) -> list: ParamOpHookManager._trigger_pre_forward(params) - return PreFwdPostBwd.apply(params, *args) + args_info = _get_colo_tensors_info(*args) + rets = PreFwdPostBwd.apply(params, *args) + return _update_colo_tensors(args_info, *rets) @staticmethod - def post_op(params: List[torch.Tensor], args: Any) -> Any: + def post_op(params: List[torch.Tensor], arg: Any) -> Any: ParamOpHookManager._trigger_post_forward(params) - return PostFwdPreBwd.apply(params, args) + arg_info = _get_colo_tensors_info(arg) + ret = PostFwdPreBwd.apply(params, arg) + return _unpack_args(_update_colo_tensors(arg_info, ret)) @staticmethod def has_hook() -> bool: @@ -93,9 +98,7 @@ class PreFwdPostBwd(torch.autograd.Function): @staticmethod def forward(ctx, params, *args): ctx.params = params - if len(args) == 1: - return args[0] - return args + return _unpack_args(args) @staticmethod def backward(ctx, *grads): @@ -114,3 +117,29 @@ class PostFwdPreBwd(torch.autograd.Function): def backward(ctx, *grads): ParamOpHookManager._trigger_pre_backward(ctx.params) return (None,) + grads + + +def _unpack_args(args): + if len(args) == 1: + return args[0] + return args + + +def _get_colo_tensors_info(*args) -> list: + info = [] + for arg in args: + if isinstance(arg, ColoTensor): + info.append((arg.__class__, arg.spec)) + else: + info.append(None) + return info + + +def _update_colo_tensors(info, *args) -> list: + ret = [] + for t_info, arg in zip(info, args): + if t_info is not None: + t_cls, spec = t_info + arg = t_cls.from_torch_tensor(arg, spec=spec) + ret.append(arg) + return ret diff --git a/tests/test_tensor/test_zero_optim.py b/tests/test_tensor/test_zero_optim.py index a86735815..cdcfc4641 100644 --- a/tests/test_tensor/test_zero_optim.py +++ b/tests/test_tensor/test_zero_optim.py @@ -10,7 +10,7 @@ from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.tensor import ChunkManager from colossalai.core import global_context as gpc from functools import partial -from _utils import tensor_equal, set_seed +from _utils import tensor_equal, set_seed, tensor_shard_equal from tests.components_to_test.registry import non_distributed_component_funcs from torch.nn.parallel import DistributedDataParallel as DDP from colossalai.nn.parallel import ColoDDPV2 @@ -19,19 +19,20 @@ from colossalai.zero import ZeroOptimizer from colossalai.testing import parameterize from colossalai.amp import convert_to_apex_amp from colossalai.gemini.gemini_mgr import GeminiManager +from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, distspec def check_param_equal(model, torch_model): for p, torch_p in zip(model.parameters(), torch_model.parameters()): if p.storage().size() > 0: assert p.dtype == torch.half - assert tensor_equal(torch_p.to(dtype=p.dtype, device=p.device), p), f'{torch_p} vs {p}' + assert tensor_shard_equal(torch_p.to(dtype=p.dtype, device=p.device), p), f'{torch_p} vs {p}' def check_grad_equal(model, torch_model): for p, torch_p in zip(model.parameters(), torch_model.parameters()): if p.grad is not None: - assert tensor_equal(torch_p.grad.to(dtype=p.grad.dtype, device=p.grad.device), p.grad) + assert tensor_shard_equal(torch_p.grad.to(dtype=p.grad.dtype, device=p.grad.device), p.grad) def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask): @@ -43,10 +44,30 @@ def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask): return logits +def init_1d_row_spec(model): + spec = TensorSpec( + distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]), + ParallelAction(ComputePattern.TP1D)) + with DistSpecManager.no_grad(): + for n, p in model.named_parameters(): + if 'weight' in n and 'ln' not in n: + p.set_spec(spec) + + +def init_1d_col_spec(model): + spec = TensorSpec( + distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]), + ParallelAction(ComputePattern.TP1D)) + with DistSpecManager.no_grad(): + for n, p in model.named_parameters(): + if 'ln' not in n and ('weight' in n or 'bias' in n): + p.set_spec(spec) + + @parameterize('use_chunk', [False, True]) @parameterize('use_zero', [False, True]) @parameterize('placement_policy', ['cuda', 'cpu']) -def run_gpt(use_chunk, use_zero, placement_policy): +def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None): set_seed(42) get_components_func = non_distributed_component_funcs.get_callable('gpt2') model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() @@ -58,6 +79,9 @@ def run_gpt(use_chunk, use_zero, placement_policy): for torch_p, p in zip(torch_model.parameters(), model.parameters()): torch_p.data.copy_(p) + if tp_init_spec_func: + tp_init_spec_func(model) + chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero, @@ -90,8 +114,15 @@ def run_gpt(use_chunk, use_zero, placement_policy): def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_gpt() + config = {} + if world_size == 4: + config['parallel'] = {'tensor': {'mode': '1d', 'size': 2}} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + if world_size == 4: + run_gpt(tp_init_spec_func=init_1d_col_spec) + run_gpt(tp_init_spec_func=init_1d_row_spec) + else: + run_gpt() @pytest.mark.dist