fix dist spec mgr (#1045)

ver217 authored 2022-05-31 12:14:39 +08:00, committed by GitHub
parent 9492a561c3
commit 7faef93326
2 changed files with 18 additions and 1 deletion


@@ -34,7 +34,7 @@ class DistSpecManager:
             chunk_size = divide(tensor.size(dim), dist_spec.num_partitions[i])
             chunk = chunk.narrow(dim, idx // num_parts * chunk_size, chunk_size)
             idx %= num_parts
-        return chunk.detach().contiguous()
+        return chunk.clone().detach().contiguous()

     @staticmethod
     def _gather(tensor: torch.Tensor, old_dist_spec: _DistSpec) -> torch.Tensor:
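
Why the one-line change matters: for a shard along dim 0, narrow() returns a contiguous view, so .contiguous() is a no-op and detach() still shares storage with the full tensor; the "shard" therefore pins the entire replicated tensor in GPU memory. Inserting clone() copies just the shard into its own storage. A minimal single-GPU sketch of the effect (not part of the diff; assumes a fresh process with nothing else allocated on the device):

    import torch

    x = torch.rand(32, 32, device='cuda')
    full_bytes = x.numel() * x.element_size()  # 32 * 32 * 4 = 4096 bytes

    # Old behavior: a dim-0 narrow is already contiguous, so .contiguous()
    # returns the view itself and the shard keeps the 32x32 storage alive.
    view_shard = x.narrow(0, 0, 8).detach().contiguous()
    del x
    assert torch.cuda.memory_allocated() == full_bytes  # nothing was freed

    # New behavior: clone() copies the 8x32 slice into its own storage,
    # so dropping the view releases the rest.
    own_shard = view_shard.clone().detach().contiguous()
    del view_shard
    assert torch.cuda.memory_allocated() == own_shard.numel() * own_shard.element_size()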


@@ -33,8 +33,25 @@ def run():
     assert torch.equal(x, DistSpecManager._gather(mat_shard, mat_spec))


+def check_mem():
+    group = _get_default_group()
+    size = dist.get_world_size()
+    assert torch.cuda.memory_allocated() == 0
+    x = torch.rand(32, 32).cuda()
+    orig_mem = x.numel() * x.element_size()
+    assert torch.cuda.memory_allocated() == orig_mem
+    old_dist_spec = distspec.replicate()
+    row_spec = distspec.shard(group, [0], [size])
+    x.data = DistSpecManager._shard_as(x, old_dist_spec, row_spec)
+    assert x.size(0) == 32 // size and x.size(1) == 32
+    assert torch.cuda.memory_allocated() == orig_mem // size
+    x.data = DistSpecManager._gather(x, row_spec)
+    assert torch.cuda.memory_allocated() == orig_mem
+
+
 def run_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    check_mem()
     run()
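
The new check_mem walks the memory accounting end to end: a replicated 32x32 fp32 tensor takes 32 * 32 * 4 = 4096 bytes per rank, sharding rows across `size` ranks must drop torch.cuda.memory_allocated() to 4096 // size, and gathering must restore the full 4096, which only holds once _shard_as stops pinning the whole tensor. A sketch of how a test like this is typically launched (the world size, port, and mp.spawn harness below are assumptions, not part of the diff):

    from functools import partial

    import torch.multiprocessing as mp

    def test_dist_spec_mgr(world_size: int = 4):
        # mp.spawn passes the rank as the first positional argument,
        # so fix the remaining arguments of run_dist with partial().
        run_func = partial(run_dist, world_size=world_size, port=29500)
        mp.spawn(run_func, nprocs=world_size)

    if __name__ == '__main__':
        test_dist_spec_mgr()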