From 87775a068288de30fd8235e9c91ec8bfe5fb7eee Mon Sep 17 00:00:00 2001
From: HELSON
Date: Tue, 26 Jul 2022 14:13:38 +0800
Subject: [PATCH] [colotensor] use cpu memory to store state_dict (#1367)

---
 colossalai/nn/parallel/data_parallel.py       |  3 ++-
 colossalai/utils/checkpoint/utils.py          | 18 ++++++++++++++++--
 tests/test_ddp/test_ddp_state_dict.py         |  8 +++++++-
 .../test_tensor/test_colo_checkpoint_tools.py |  2 +-
 4 files changed, 26 insertions(+), 5 deletions(-)

diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py
index ab5a9fea7..da05df1cb 100644
--- a/colossalai/nn/parallel/data_parallel.py
+++ b/colossalai/nn/parallel/data_parallel.py
@@ -318,7 +318,8 @@ class ZeroDDP(ColoDDP):
             self.chunk_manager.access_chunk(chunk)
         for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
             if p is not None:
-                destination[prefix + name] = fp32_p.clone() if keep_vars else fp32_p.clone().detach()
+                rec_p = fp32_p.clone() if fp32_p.device.type == 'cpu' else fp32_p.cpu()
+                destination[prefix + name] = rec_p if keep_vars else rec_p.detach()
         for chunk in chunks:
             self.chunk_manager.release_chunk(chunk)
         for name, buf in self.named_buffers():
diff --git a/colossalai/utils/checkpoint/utils.py b/colossalai/utils/checkpoint/utils.py
index 3b8b83c15..a9e0e7edd 100644
--- a/colossalai/utils/checkpoint/utils.py
+++ b/colossalai/utils/checkpoint/utils.py
@@ -4,6 +4,20 @@ from colossalai.tensor import ColoTensor, ColoTensorSpec
 from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
 
 
+def robust_broadcast(tensor):
+    with torch.no_grad():
+        is_cpu_ten = tensor.device.type == 'cpu'
+        if is_cpu_ten:
+            b_data = tensor.cuda()
+        else:
+            b_data = tensor
+
+        dist.broadcast(b_data, 0)
+
+        if is_cpu_ten:
+            tensor.copy_(b_data)
+
+
 def gather_tensor(colo_tensor: ColoTensor) -> None:
     """Make colo_tensor replicated when the rank is 0
     """
@@ -27,7 +41,7 @@ def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None:
     """Reversal operation of `gather_tensor`.
     """
     if dist_spec.placement == DistPlacementPattern.REPLICATE:
-        dist.broadcast(colo_tensor.data, 0)
+        robust_broadcast(colo_tensor.data)
     else:
         global_size = colo_tensor.size_global()
 
@@ -35,7 +49,7 @@ def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None:
             entire_data = colo_tensor.data
         else:
             entire_data = torch.empty(global_size, device=colo_tensor.device)
-        dist.broadcast(entire_data, 0)
+        robust_broadcast(entire_data)
 
         if dist.get_rank() == 0:
             colo_tensor.set_dist_spec(dist_spec)
diff --git a/tests/test_ddp/test_ddp_state_dict.py b/tests/test_ddp/test_ddp_state_dict.py
index 121a8f44e..3f3c316a8 100644
--- a/tests/test_ddp/test_ddp_state_dict.py
+++ b/tests/test_ddp/test_ddp_state_dict.py
@@ -19,7 +19,13 @@ from colossalai.tensor import ProcessGroup, ColoParameter
 def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDict):
     for (k1, t1), (k2, t2) in zip(state_dict.items(), other_state_dict.items()):
         assert k1 == k2
-        assert torch.allclose(t1, t2, atol=1e-3, rtol=1e-3)
+
+        if t1.device != t2.device:
+            temp_t2 = t2.to(t1.device)
+        else:
+            temp_t2 = t2
+
+        assert torch.allclose(t1, temp_t2, atol=1e-3, rtol=1e-3)
 
 
 def init_ddp(module: torch.nn.Module) -> ColoDDP:
diff --git a/tests/test_tensor/test_colo_checkpoint_tools.py b/tests/test_tensor/test_colo_checkpoint_tools.py
index ec80c06e8..aa333d552 100644
--- a/tests/test_tensor/test_colo_checkpoint_tools.py
+++ b/tests/test_tensor/test_colo_checkpoint_tools.py
@@ -17,7 +17,7 @@ from tests.test_tensor.common_utils import tensor_shard_equal
 def run_dist(rank, world_size, port, dp_degree, tp_degree):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     pg = ProcessGroup(dp_degree=dp_degree, tp_degree=tp_degree)
-    x = torch.randn(4, 4, device=get_current_device())
+    x = torch.randn(4, 4)
     param = ColoTensor(torch.nn.Parameter(x), spec=ColoTensorSpec(pg))
     spec = ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)
     param.set_tensor_spec(*spec)