diff --git a/colossalai/tensor/dist_spec_mgr.py b/colossalai/tensor/dist_spec_mgr.py
index c82524836..a3e229899 100644
--- a/colossalai/tensor/dist_spec_mgr.py
+++ b/colossalai/tensor/dist_spec_mgr.py
@@ -34,7 +34,7 @@ class DistSpecManager:
             chunk_size = divide(tensor.size(dim), dist_spec.num_partitions[i])
             chunk = chunk.narrow(dim, idx // num_parts * chunk_size, chunk_size)
             idx %= num_parts
-        return chunk.detach().contiguous()
+        return chunk.clone().detach().contiguous()
 
     @staticmethod
     def _gather(tensor: torch.Tensor, old_dist_spec: _DistSpec) -> torch.Tensor:
diff --git a/tests/test_tensor/test_dist_spec_mgr.py b/tests/test_tensor/test_dist_spec_mgr.py
index ada77faef..f21790da1 100644
--- a/tests/test_tensor/test_dist_spec_mgr.py
+++ b/tests/test_tensor/test_dist_spec_mgr.py
@@ -33,8 +33,25 @@ def run():
     assert torch.equal(x, DistSpecManager._gather(mat_shard, mat_spec))
 
 
+def check_mem():
+    group = _get_default_group()
+    size = dist.get_world_size()
+    assert torch.cuda.memory_allocated() == 0
+    x = torch.rand(32, 32).cuda()
+    orig_mem = x.numel() * x.element_size()
+    assert torch.cuda.memory_allocated() == orig_mem
+    old_dist_spec = distspec.replicate()
+    row_spec = distspec.shard(group, [0], [size])
+    x.data = DistSpecManager._shard_as(x, old_dist_spec, row_spec)
+    assert x.size(0) == 32 // size and x.size(1) == 32
+    assert torch.cuda.memory_allocated() == orig_mem // size
+    x.data = DistSpecManager._gather(x, row_spec)
+    assert torch.cuda.memory_allocated() == orig_mem
+
+
 def run_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    check_mem()
     run()
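
Why the one-line change in _shard frees memory: torch.Tensor.narrow() returns a view that shares storage with the original tensor, so returning only narrow(...).detach() keeps the full replicated buffer alive after sharding, while clone() copies the shard into its own, smaller allocation and lets the original storage be released. Below is a minimal single-process CPU sketch, not part of the patch, illustrating the difference; it checks sizes via Tensor.storage(), which may emit a deprecation warning on recent PyTorch versions.

import torch

x = torch.rand(32, 32)        # stands in for the replicated tensor
full_elems = x.numel()        # 1024 elements backing x

# A narrowed view shares x's storage: keeping only the view does not shrink
# the allocation backing it.
view_shard = x.narrow(0, 0, 8).detach()
assert view_shard.storage().size() == full_elems

# Cloning first copies the shard (8 x 32 = 256 elements) into its own buffer,
# so dropping x afterwards actually releases the full-size allocation.
owned_shard = x.narrow(0, 0, 8).clone().detach().contiguous()
assert owned_shard.storage().size() == owned_shard.numel()

This mirrors what check_mem asserts on each rank: after _shard_as, torch.cuda.memory_allocated() drops to orig_mem // size once the replicated tensor is released, and _gather brings it back to orig_mem.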