Browse Source

[zero] improve the accuracy of get_memory_usage of sharded param (#538)

pull/535/head
Jiarui Fang 3 years ago committed by GitHub
parent
commit
a590ed0ba3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 20
      colossalai/zero/sharded_param/sharded_param.py
  2. 27
      tests/test_zero_data_parallel/test_shard_param.py

20
colossalai/zero/sharded_param/sharded_param.py

@ -60,8 +60,24 @@ class ShardedParamV2(object):
elif t.device.type == 'cuda':
cuda_mem_use += t.numel() * t.element_size()
address_set = set()
_update_mem_use(self.sharded_data_tensor.payload)
_update_mem_use(self.fp16_grad)
_update_mem_use(self.fp32_grad)
address_set.add(self.sharded_data_tensor.payload.data_ptr())
if self.fp16_grad is not None and self.fp16_grad.data_ptr() not in address_set:
_update_mem_use(self.fp16_grad)
address_set.add(self.fp16_grad.data_ptr())
if self.fp32_grad is not None and self.fp32_grad.data_ptr() not in address_set:
_update_mem_use(self.fp32_grad)
address_set.add(self.fp32_grad.data_ptr())
if self.param.data is not None and self.param.data.data_ptr() not in address_set:
_update_mem_use(self.param.data)
address_set.add(self.param.data.data_ptr())
if self.param.grad is not None and self.param.grad.data_ptr() not in address_set:
_update_mem_use(self.param.grad)
address_set.add(self.param.grad.data_ptr())
return cuda_mem_use, cpu_mem_use

27
tests/test_zero_data_parallel/test_shard_param.py

@ -51,26 +51,41 @@ def _run_shard_param_v2(rank, world_size, port):
allclose(sparam.sharded_data_tensor.payload, param_ref.data)
sparam.remove_torch_payload()
assert (param.data.numel() == 1)
# Test get memory usage
sparam.fp32_grad = torch.randn(2, 3)
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
assert cpu_mem_use == 2 * 3 * 4 * 2
assert cpu_mem_use == 2 * 3 * 4 * 2, f"cpu_mem_use: {cpu_mem_use}"
sparam.remove_torch_payload()
assert (param.data.numel() == 1)
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
# 4 is size of dummy tensor of param.data
assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
sparam.fp16_grad = torch.randn(2, 3).cuda().half()
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
assert cpu_mem_use == 2 * 3 * 4 * 2
assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
assert cuda_mem_use == 2 * 3 * 2
sparam.fp16_grad = None
sparam.fp32_grad = torch.randn(2, 3)
sparam.remove_torch_payload()
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
assert cuda_mem_use == 0
# append a grad to torch param
param.data = sparam.sharded_data_tensor.payload
param.grad = torch.randn(2, 3)
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
assert cpu_mem_use == 2 * 3 * 4 * 2 + 2 * 3 * 4, f"cpu_mem_use {cpu_mem_use}"
assert cuda_mem_use == 0
# reuse torch grad for sparam
sparam.fp32_grad = param.grad
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
assert cpu_mem_use == 2 * 3 * 4 * 2
assert cuda_mem_use == 0
print(f'cuda_mem_use {cuda_mem_use} cpu_mem_use {cpu_mem_use}')
@pytest.mark.dist

Loading…
Cancel
Save