From cefc29ff066aa79d2f2254177a04ce0828489558 Mon Sep 17 00:00:00 2001
From: ver217
Date: Sat, 21 May 2022 13:52:04 +0800
Subject: [PATCH] [tensor] impl ColoDDP for ColoTensor (#1009)

* impl ColoDDP for ColoTensor

* polish code
---
 colossalai/nn/parallel.py     | 78 +++++++++++++++++++++++++++++++++++
 tests/test_tensor/test_gpt.py | 33 +++++++++++----
 2 files changed, 102 insertions(+), 9 deletions(-)
 create mode 100644 colossalai/nn/parallel.py

diff --git a/colossalai/nn/parallel.py b/colossalai/nn/parallel.py
new file mode 100644
index 000000000..9b1fa81a2
--- /dev/null
+++ b/colossalai/nn/parallel.py
@@ -0,0 +1,78 @@
+import torch
+import torch.distributed as dist
+from colossalai.core import global_context as gpc
+from colossalai.context import ParallelMode
+from functools import partial
+
+__all__ = ['ColoDDP']
+
+
+def free_storage(data: torch.Tensor) -> None:
+    """Free underlying storage of a Tensor."""
+    if data.storage().size() > 0:
+        # Since we're modifying the Tensor's Storage directly, make sure the Tensor
+        # is the sole occupant of the Storage.
+        assert data.storage_offset() == 0
+        data.storage().resize_(0)
+
+
+class ColoDDP(torch.nn.Module):
+
+    def __init__(self, module: torch.nn.Module) -> None:
+        super().__init__()
+        self.module = module
+        self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
+        self.dp_world_size = gpc.get_world_size(ParallelMode.DATA)
+        for p in module.parameters():
+            if p.requires_grad:
+                p.register_hook(partial(self.grad_handle, p))
+
+    def parameters(self, recurse: bool = True):
+        return self.module.parameters(recurse)
+
+    def named_parameters(self, prefix: str = '', recurse: bool = True):
+        return self.module.named_parameters(prefix, recurse)
+
+    def forward(self, *args, **kwargs):
+        self.module.zero_grad(set_to_none=True)
+        return self.module(*args, **kwargs)
+
+    def backward(self, loss: torch.Tensor):
+        loss.backward()
+        torch.cuda.current_stream().wait_stream(self.comm_stream)
+        for p in self.module.parameters():
+            p.grad = p._saved_grad
+
+    def grad_handle(self, p, grad):
+        empty_grad = torch.empty_like(grad)
+        free_storage(empty_grad)
+        if self.dp_world_size > 1:
+            grad = grad / self.dp_world_size
+            self.comm_stream.wait_stream(torch.cuda.current_stream())
+            with torch.cuda.stream(self.comm_stream):
+                dist.all_reduce(grad, group=gpc.get_group(ParallelMode.DATA))
+                ColoDDP._save_grad(p, grad)
+            grad.record_stream(self.comm_stream)
+        else:
+            ColoDDP._save_grad(p, grad)
+        return empty_grad
+
+    @staticmethod
+    def _save_grad(p, grad):
+        if hasattr(p, '_saved_grad'):
+            p._saved_grad.add_(grad)
+        else:
+            p._saved_grad = grad
+
+    def zero_grad(self, set_to_none: bool = False) -> None:
+        self.module.zero_grad(set_to_none=True)
+        for p in self.module.parameters():
+            if getattr(p, '_saved_grad', None) is not None:
+                if set_to_none:
+                    p._saved_grad = None
+                else:
+                    if p._saved_grad.grad_fn is not None:
+                        p._saved_grad.detach_()
+                    else:
+                        p._saved_grad.requires_grad_(False)
+                    p._saved_grad.zero_()
diff --git a/tests/test_tensor/test_gpt.py b/tests/test_tensor/test_gpt.py
index 781e36c25..0ee490f5b 100644
--- a/tests/test_tensor/test_gpt.py
+++ b/tests/test_tensor/test_gpt.py
@@ -9,8 +9,10 @@ from colossalai.utils import ColoInitContext
 from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, distspec
 from colossalai.core import global_context as gpc
 from functools import partial
-from _utils import tensor_equal, tensor_shard_equal
+from _utils import tensor_equal, tensor_shard_equal, set_seed
 from tests.components_to_test.registry import non_distributed_component_funcs
+from torch.nn.parallel import DistributedDataParallel as DDP
+from colossalai.nn.parallel import ColoDDP
 
 
 def init_1d_row_spec(model):
@@ -43,7 +45,7 @@ def check_grad_equal(model, torch_model):
         assert tensor_shard_equal(torch_p.grad, p.grad)
 
 
-def run_gpt(init_spec_func):
+def run_gpt(init_spec_func, use_ddp):
     get_components_func = non_distributed_component_funcs.get_callable('gpt2')
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
 
@@ -51,37 +53,50 @@ def run_gpt(init_spec_func):
     model = model_builder()
     model = model.cuda()
     torch_model = model_builder().cuda()
+    if use_ddp:
+        model = ColoDDP(model)
+        torch_model = DDP(torch_model,
+                          device_ids=[gpc.get_global_rank()],
+                          process_group=gpc.get_group(ParallelMode.DATA))
     for torch_p, p in zip(torch_model.parameters(), model.parameters()):
         torch_p.data.copy_(p)
     init_spec_func(model)
     check_param_equal(model, torch_model)
     model.train()
     torch_model.train()
+    set_seed(gpc.get_local_rank(ParallelMode.DATA))
     for i, (input_ids, attn_mask) in enumerate(train_dataloader):
         logits = model(input_ids, attn_mask)
         torch_logits = torch_model(input_ids, attn_mask)
         assert tensor_equal(torch_logits, logits)
         loss = criterion(logits, input_ids)
         torch_loss = criterion(torch_logits, input_ids)
-        loss.backward()
+        if use_ddp:
+            model.backward(loss)
+        else:
+            loss.backward()
         torch_loss.backward()
         check_grad_equal(model, torch_model)
         if i > 0:
             break
 
 
-def run_dist(rank, world_size, port):
-    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
+def run_dist(rank, world_size, port, use_ddp):
+    if use_ddp and world_size == 1:
+        return
+    tp_world_size = world_size // 2 if use_ddp else world_size
+    config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_gpt(init_1d_row_spec)
-    run_gpt(init_1d_col_spec)
+    run_gpt(init_1d_row_spec, use_ddp)
+    run_gpt(init_1d_col_spec, use_ddp)
 
 
 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 4])
+@pytest.mark.parametrize('use_ddp', [False, True])
 @rerun_if_address_is_in_use()
-def test_gpt(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
+def test_gpt(world_size, use_ddp):
+    run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp)
     mp.spawn(run_func, nprocs=world_size)
 
 
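---

Note (illustrative, not part of the patch): a minimal usage sketch of the new wrapper. It assumes a Colossal-AI context has already been launched (e.g. via colossalai.launch as in the test above) so that ParallelMode.DATA is initialized; the toy Linear model and the hyperparameters are placeholders.

    import torch
    from colossalai.nn.parallel import ColoDDP

    # Hypothetical toy model; any torch.nn.Module works.
    model = ColoDDP(torch.nn.Linear(16, 16).cuda())
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

    x = torch.randn(4, 16, device='cuda')
    loss = model(x).sum()  # forward() first zeroes the wrapped module's grads
    model.backward(loss)   # use this instead of loss.backward(): it waits on
                           # the comm stream, then exposes the all-reduced
                           # _saved_grad buffers as p.grad for the optimizer
    optimizer.step()
    model.zero_grad()      # also resets the _saved_grad accumulators

The sketch only relies on methods defined in the diff (ColoDDP.forward, ColoDDP.backward, ColoDDP.zero_grad). Since grad_handle averages gradients over ParallelMode.DATA, each rank is expected to feed its own shard of the data, as the test does by seeding with the data-parallel local rank.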