[hotfix] fix param op hook (#1131)

* fix param op hook

* update zero tp test

* fix bugs
ver217 2022-06-17 16:12:05 +08:00 committed by GitHub
parent a1a7899cae
commit 789cad301b
3 changed files with 74 additions and 20 deletions


@@ -11,17 +11,11 @@ def filter_args(func, *args):
     return [arg for arg in args if func(arg)]
 
 
-def unpack_args(*args):
-    if len(args) == 1:
-        return args[0]
-    return args
-
-
 def replace_args(args, kwargs, new_args):
     args = new_args[:len(args)]
     for k, v in zip(kwargs.keys(), new_args[len(args):]):
         kwargs[k] = v
-    return unpack_args(args), kwargs
+    return tuple(args), kwargs
 
 
 class ColoParameter(ColoTensor, torch.nn.Parameter):
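Note on the first file: `replace_args` now always hands back the positional arguments as a plain tuple instead of routing them through the removed `unpack_args` helper (an equivalent `_unpack_args` moves into the hook module, see the next file). Below is a minimal sketch of the new behaviour; the call at the bottom is a hypothetical usage, not code from this PR:

    def replace_args(args, kwargs, new_args):
        # Split the flat sequence produced by the hooks back into
        # positionals and keyword values, keeping the original kwargs keys.
        args = new_args[:len(args)]
        for k, v in zip(kwargs.keys(), new_args[len(args):]):
            kwargs[k] = v
        return tuple(args), kwargs

    # Hypothetical usage: even with a single positional argument the result
    # stays a 1-tuple, so a later `op(*args, **kwargs)` forwards exactly one
    # positional value.
    args, kwargs = replace_args(('x',), {'alpha': 2}, ['x_new', 3])
    assert args == ('x_new',) and kwargs == {'alpha': 3}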


@@ -2,6 +2,7 @@ import torch
 from contextlib import contextmanager
 from abc import ABC, abstractmethod
 from typing import List, Tuple, Any
+from colossalai.tensor.colo_tensor import ColoTensor
 
 
 class ParamOpHook(ABC):
@@ -74,14 +75,18 @@ class ParamOpHookManager:
         hook.post_backward(params)
 
     @staticmethod
-    def pre_op(params: List[torch.Tensor], *args: Any) -> Any:
+    def pre_op(params: List[torch.Tensor], *args: Any) -> list:
         ParamOpHookManager._trigger_pre_forward(params)
-        return PreFwdPostBwd.apply(params, *args)
+        args_info = _get_colo_tensors_info(*args)
+        rets = PreFwdPostBwd.apply(params, *args)
+        return _update_colo_tensors(args_info, *rets)
 
     @staticmethod
-    def post_op(params: List[torch.Tensor], args: Any) -> Any:
+    def post_op(params: List[torch.Tensor], arg: Any) -> Any:
         ParamOpHookManager._trigger_post_forward(params)
-        return PostFwdPreBwd.apply(params, args)
+        arg_info = _get_colo_tensors_info(arg)
+        ret = PostFwdPreBwd.apply(params, arg)
+        return _unpack_args(_update_colo_tensors(arg_info, ret))
 
     @staticmethod
     def has_hook() -> bool:
@@ -93,9 +98,7 @@ class PreFwdPostBwd(torch.autograd.Function):
     @staticmethod
     def forward(ctx, params, *args):
         ctx.params = params
-        if len(args) == 1:
-            return args[0]
-        return args
+        return _unpack_args(args)
 
     @staticmethod
     def backward(ctx, *grads):
@@ -114,3 +117,29 @@ class PostFwdPreBwd(torch.autograd.Function):
     def backward(ctx, *grads):
         ParamOpHookManager._trigger_pre_backward(ctx.params)
         return (None,) + grads
+
+
+def _unpack_args(args):
+    if len(args) == 1:
+        return args[0]
+    return args
+
+
+def _get_colo_tensors_info(*args) -> list:
+    info = []
+    for arg in args:
+        if isinstance(arg, ColoTensor):
+            info.append((arg.__class__, arg.spec))
+        else:
+            info.append(None)
+    return info
+
+
+def _update_colo_tensors(info, *args) -> list:
+    ret = []
+    for t_info, arg in zip(info, args):
+        if t_info is not None:
+            t_cls, spec = t_info
+            arg = t_cls.from_torch_tensor(arg, spec=spec)
+        ret.append(arg)
+    return ret
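Note on the second file (the ParamOpHook changes): `torch.autograd.Function.apply` returns plain `torch.Tensor`s, so a `ColoTensor` passed through `PreFwdPostBwd` or `PostFwdPreBwd` would come back stripped of its subclass and spec. The new `_get_colo_tensors_info` / `_update_colo_tensors` pair records `(class, spec)` before the apply and rebuilds the ColoTensors afterwards via `from_torch_tensor`. Here is a standalone sketch of the same save-and-restore pattern, using a dummy wrapper class rather than the real ColoTensor API:

    import torch

    class FakeColoTensor(torch.Tensor):
        """Stand-in for ColoTensor: a tensor subclass that carries a `spec`."""
        spec = None

        @classmethod
        def from_torch_tensor(cls, tensor, spec=None):
            t = tensor.as_subclass(cls)
            t.spec = spec
            return t

    def get_info(*args):
        # Remember (class, spec) for every wrapped tensor, None for the rest.
        return [(a.__class__, a.spec) if isinstance(a, FakeColoTensor) else None for a in args]

    def restore(info, *args):
        out = []
        for t_info, a in zip(info, args):
            if t_info is not None:
                t_cls, spec = t_info
                a = t_cls.from_torch_tensor(a, spec=spec)
            out.append(a)
        return out

    x = FakeColoTensor.from_torch_tensor(torch.randn(2, 2), spec='1d-row')  # spec value is a placeholder
    info = get_info(x)
    plain = torch.randn(2, 2)       # stands in for the bare tensor returned by apply()
    restored, = restore(info, plain)
    assert isinstance(restored, FakeColoTensor) and restored.spec == '1d-row'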


@@ -10,7 +10,7 @@ from colossalai.utils.model.colo_init_context import ColoInitContext
 from colossalai.tensor import ChunkManager
 from colossalai.core import global_context as gpc
 from functools import partial
-from _utils import tensor_equal, set_seed
+from _utils import tensor_equal, set_seed, tensor_shard_equal
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
 from colossalai.nn.parallel import ColoDDPV2
@@ -19,19 +19,20 @@ from colossalai.zero import ZeroOptimizer
 from colossalai.testing import parameterize
 from colossalai.amp import convert_to_apex_amp
 from colossalai.gemini.gemini_mgr import GeminiManager
+from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, distspec
 
 
 def check_param_equal(model, torch_model):
     for p, torch_p in zip(model.parameters(), torch_model.parameters()):
         if p.storage().size() > 0:
             assert p.dtype == torch.half
-            assert tensor_equal(torch_p.to(dtype=p.dtype, device=p.device), p), f'{torch_p} vs {p}'
+            assert tensor_shard_equal(torch_p.to(dtype=p.dtype, device=p.device), p), f'{torch_p} vs {p}'
 
 
 def check_grad_equal(model, torch_model):
     for p, torch_p in zip(model.parameters(), torch_model.parameters()):
         if p.grad is not None:
-            assert tensor_equal(torch_p.grad.to(dtype=p.grad.dtype, device=p.grad.device), p.grad)
+            assert tensor_shard_equal(torch_p.grad.to(dtype=p.grad.dtype, device=p.grad.device), p.grad)
 
 
 def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
@@ -43,10 +44,30 @@ def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
     return logits
 
 
+def init_1d_row_spec(model):
+    spec = TensorSpec(
+        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
+        ParallelAction(ComputePattern.TP1D))
+    with DistSpecManager.no_grad():
+        for n, p in model.named_parameters():
+            if 'weight' in n and 'ln' not in n:
+                p.set_spec(spec)
+
+
+def init_1d_col_spec(model):
+    spec = TensorSpec(
+        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
+        ParallelAction(ComputePattern.TP1D))
+    with DistSpecManager.no_grad():
+        for n, p in model.named_parameters():
+            if 'ln' not in n and ('weight' in n or 'bias' in n):
+                p.set_spec(spec)
+
+
 @parameterize('use_chunk', [False, True])
 @parameterize('use_zero', [False, True])
 @parameterize('placement_policy', ['cuda', 'cpu'])
-def run_gpt(use_chunk, use_zero, placement_policy):
+def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
     set_seed(42)
     get_components_func = non_distributed_component_funcs.get_callable('gpt2')
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -58,6 +79,9 @@ def run_gpt(use_chunk, use_zero, placement_policy):
     for torch_p, p in zip(torch_model.parameters(), model.parameters()):
         torch_p.data.copy_(p)
 
+    if tp_init_spec_func:
+        tp_init_spec_func(model)
+
     chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None
     chunk_manager = ChunkManager(chunk_size,
                                  enable_distributed_storage=use_zero,
@@ -90,8 +114,15 @@ def run_gpt(use_chunk, use_zero, placement_policy):
 
 def run_dist(rank, world_size, port):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_gpt()
+    config = {}
+    if world_size == 4:
+        config['parallel'] = {'tensor': {'mode': '1d', 'size': 2}}
+    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    if world_size == 4:
+        run_gpt(tp_init_spec_func=init_1d_col_spec)
+        run_gpt(tp_init_spec_func=init_1d_row_spec)
+    else:
+        run_gpt()
 
 
 @pytest.mark.dist
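Note on the updated test: with 4 processes the launch config enables a 1D tensor-parallel group of size 2, and `run_gpt` is exercised once with column sharding and once with row sharding; `tensor_shard_equal` replaces `tensor_equal` so the checks can still pass when only a shard of a parameter (or its gradient) lives on the local rank. In the spec helpers, `[0]` splits a weight along its first dimension and `[-1]` along its last, while LayerNorm parameters ('ln') stay replicated. A rough illustration of the two layouts with plain `torch.chunk` (hypothetical 4x6 weight, TP size 2), not code from this PR:

    import torch

    weight = torch.arange(24.).reshape(4, 6)            # hypothetical GPT weight
    tp_size = 2

    row_shards = torch.chunk(weight, tp_size, dim=0)    # init_1d_row_spec: dims=[0]  -> two 2x6 shards
    col_shards = torch.chunk(weight, tp_size, dim=-1)   # init_1d_col_spec: dims=[-1] -> two 4x3 shards

    assert row_shards[0].shape == (2, 6)
    assert col_shards[0].shape == (4, 3)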