mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix param op hook (#1131)
* fix param op hook * update zero tp test * fix bugspull/1133/head
parent
a1a7899cae
commit
789cad301b
|
@ -11,17 +11,11 @@ def filter_args(func, *args):
|
||||||
return [arg for arg in args if func(arg)]
|
return [arg for arg in args if func(arg)]
|
||||||
|
|
||||||
|
|
||||||
def unpack_args(*args):
|
|
||||||
if len(args) == 1:
|
|
||||||
return args[0]
|
|
||||||
return args
|
|
||||||
|
|
||||||
|
|
||||||
def replace_args(args, kwargs, new_args):
|
def replace_args(args, kwargs, new_args):
|
||||||
args = new_args[:len(args)]
|
args = new_args[:len(args)]
|
||||||
for k, v in zip(kwargs.keys(), new_args[len(args):]):
|
for k, v in zip(kwargs.keys(), new_args[len(args):]):
|
||||||
kwargs[k] = v
|
kwargs[k] = v
|
||||||
return unpack_args(args), kwargs
|
return tuple(args), kwargs
|
||||||
|
|
||||||
|
|
||||||
class ColoParameter(ColoTensor, torch.nn.Parameter):
|
class ColoParameter(ColoTensor, torch.nn.Parameter):
|
||||||
|
|
|
@ -2,6 +2,7 @@ import torch
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List, Tuple, Any
|
from typing import List, Tuple, Any
|
||||||
|
from colossalai.tensor.colo_tensor import ColoTensor
|
||||||
|
|
||||||
|
|
||||||
class ParamOpHook(ABC):
|
class ParamOpHook(ABC):
|
||||||
|
@ -74,14 +75,18 @@ class ParamOpHookManager:
|
||||||
hook.post_backward(params)
|
hook.post_backward(params)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def pre_op(params: List[torch.Tensor], *args: Any) -> Any:
|
def pre_op(params: List[torch.Tensor], *args: Any) -> list:
|
||||||
ParamOpHookManager._trigger_pre_forward(params)
|
ParamOpHookManager._trigger_pre_forward(params)
|
||||||
return PreFwdPostBwd.apply(params, *args)
|
args_info = _get_colo_tensors_info(*args)
|
||||||
|
rets = PreFwdPostBwd.apply(params, *args)
|
||||||
|
return _update_colo_tensors(args_info, *rets)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def post_op(params: List[torch.Tensor], args: Any) -> Any:
|
def post_op(params: List[torch.Tensor], arg: Any) -> Any:
|
||||||
ParamOpHookManager._trigger_post_forward(params)
|
ParamOpHookManager._trigger_post_forward(params)
|
||||||
return PostFwdPreBwd.apply(params, args)
|
arg_info = _get_colo_tensors_info(arg)
|
||||||
|
ret = PostFwdPreBwd.apply(params, arg)
|
||||||
|
return _unpack_args(_update_colo_tensors(arg_info, ret))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def has_hook() -> bool:
|
def has_hook() -> bool:
|
||||||
|
@ -93,9 +98,7 @@ class PreFwdPostBwd(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, params, *args):
|
def forward(ctx, params, *args):
|
||||||
ctx.params = params
|
ctx.params = params
|
||||||
if len(args) == 1:
|
return _unpack_args(args)
|
||||||
return args[0]
|
|
||||||
return args
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def backward(ctx, *grads):
|
def backward(ctx, *grads):
|
||||||
|
@ -114,3 +117,29 @@ class PostFwdPreBwd(torch.autograd.Function):
|
||||||
def backward(ctx, *grads):
|
def backward(ctx, *grads):
|
||||||
ParamOpHookManager._trigger_pre_backward(ctx.params)
|
ParamOpHookManager._trigger_pre_backward(ctx.params)
|
||||||
return (None,) + grads
|
return (None,) + grads
|
||||||
|
|
||||||
|
|
||||||
|
def _unpack_args(args):
|
||||||
|
if len(args) == 1:
|
||||||
|
return args[0]
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def _get_colo_tensors_info(*args) -> list:
|
||||||
|
info = []
|
||||||
|
for arg in args:
|
||||||
|
if isinstance(arg, ColoTensor):
|
||||||
|
info.append((arg.__class__, arg.spec))
|
||||||
|
else:
|
||||||
|
info.append(None)
|
||||||
|
return info
|
||||||
|
|
||||||
|
|
||||||
|
def _update_colo_tensors(info, *args) -> list:
|
||||||
|
ret = []
|
||||||
|
for t_info, arg in zip(info, args):
|
||||||
|
if t_info is not None:
|
||||||
|
t_cls, spec = t_info
|
||||||
|
arg = t_cls.from_torch_tensor(arg, spec=spec)
|
||||||
|
ret.append(arg)
|
||||||
|
return ret
|
||||||
|
|
|
@ -10,7 +10,7 @@ from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||||
from colossalai.tensor import ChunkManager
|
from colossalai.tensor import ChunkManager
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from _utils import tensor_equal, set_seed
|
from _utils import tensor_equal, set_seed, tensor_shard_equal
|
||||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from colossalai.nn.parallel import ColoDDPV2
|
from colossalai.nn.parallel import ColoDDPV2
|
||||||
|
@ -19,19 +19,20 @@ from colossalai.zero import ZeroOptimizer
|
||||||
from colossalai.testing import parameterize
|
from colossalai.testing import parameterize
|
||||||
from colossalai.amp import convert_to_apex_amp
|
from colossalai.amp import convert_to_apex_amp
|
||||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||||
|
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, distspec
|
||||||
|
|
||||||
|
|
||||||
def check_param_equal(model, torch_model):
|
def check_param_equal(model, torch_model):
|
||||||
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
|
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
|
||||||
if p.storage().size() > 0:
|
if p.storage().size() > 0:
|
||||||
assert p.dtype == torch.half
|
assert p.dtype == torch.half
|
||||||
assert tensor_equal(torch_p.to(dtype=p.dtype, device=p.device), p), f'{torch_p} vs {p}'
|
assert tensor_shard_equal(torch_p.to(dtype=p.dtype, device=p.device), p), f'{torch_p} vs {p}'
|
||||||
|
|
||||||
|
|
||||||
def check_grad_equal(model, torch_model):
|
def check_grad_equal(model, torch_model):
|
||||||
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
|
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
|
||||||
if p.grad is not None:
|
if p.grad is not None:
|
||||||
assert tensor_equal(torch_p.grad.to(dtype=p.grad.dtype, device=p.grad.device), p.grad)
|
assert tensor_shard_equal(torch_p.grad.to(dtype=p.grad.dtype, device=p.grad.device), p.grad)
|
||||||
|
|
||||||
|
|
||||||
def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
|
def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
|
||||||
|
@ -43,10 +44,30 @@ def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
def init_1d_row_spec(model):
|
||||||
|
spec = TensorSpec(
|
||||||
|
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||||
|
ParallelAction(ComputePattern.TP1D))
|
||||||
|
with DistSpecManager.no_grad():
|
||||||
|
for n, p in model.named_parameters():
|
||||||
|
if 'weight' in n and 'ln' not in n:
|
||||||
|
p.set_spec(spec)
|
||||||
|
|
||||||
|
|
||||||
|
def init_1d_col_spec(model):
|
||||||
|
spec = TensorSpec(
|
||||||
|
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||||
|
ParallelAction(ComputePattern.TP1D))
|
||||||
|
with DistSpecManager.no_grad():
|
||||||
|
for n, p in model.named_parameters():
|
||||||
|
if 'ln' not in n and ('weight' in n or 'bias' in n):
|
||||||
|
p.set_spec(spec)
|
||||||
|
|
||||||
|
|
||||||
@parameterize('use_chunk', [False, True])
|
@parameterize('use_chunk', [False, True])
|
||||||
@parameterize('use_zero', [False, True])
|
@parameterize('use_zero', [False, True])
|
||||||
@parameterize('placement_policy', ['cuda', 'cpu'])
|
@parameterize('placement_policy', ['cuda', 'cpu'])
|
||||||
def run_gpt(use_chunk, use_zero, placement_policy):
|
def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
|
||||||
set_seed(42)
|
set_seed(42)
|
||||||
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
|
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
|
||||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||||
|
@ -58,6 +79,9 @@ def run_gpt(use_chunk, use_zero, placement_policy):
|
||||||
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
|
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
|
||||||
torch_p.data.copy_(p)
|
torch_p.data.copy_(p)
|
||||||
|
|
||||||
|
if tp_init_spec_func:
|
||||||
|
tp_init_spec_func(model)
|
||||||
|
|
||||||
chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None
|
chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None
|
||||||
chunk_manager = ChunkManager(chunk_size,
|
chunk_manager = ChunkManager(chunk_size,
|
||||||
enable_distributed_storage=use_zero,
|
enable_distributed_storage=use_zero,
|
||||||
|
@ -90,8 +114,15 @@ def run_gpt(use_chunk, use_zero, placement_policy):
|
||||||
|
|
||||||
|
|
||||||
def run_dist(rank, world_size, port):
|
def run_dist(rank, world_size, port):
|
||||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
config = {}
|
||||||
run_gpt()
|
if world_size == 4:
|
||||||
|
config['parallel'] = {'tensor': {'mode': '1d', 'size': 2}}
|
||||||
|
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||||
|
if world_size == 4:
|
||||||
|
run_gpt(tp_init_spec_func=init_1d_col_spec)
|
||||||
|
run_gpt(tp_init_spec_func=init_1d_row_spec)
|
||||||
|
else:
|
||||||
|
run_gpt()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
|
|
Loading…
Reference in New Issue