mirror of https://github.com/hpcaitech/ColossalAI
[zero] improve the accuracy of get_memory_usage of sharded param (#538)
parent
37cb70feec
commit
a590ed0ba3
|
@ -60,8 +60,24 @@ class ShardedParamV2(object):
|
|||
elif t.device.type == 'cuda':
|
||||
cuda_mem_use += t.numel() * t.element_size()
|
||||
|
||||
address_set = set()
|
||||
_update_mem_use(self.sharded_data_tensor.payload)
|
||||
address_set.add(self.sharded_data_tensor.payload.data_ptr())
|
||||
|
||||
if self.fp16_grad is not None and self.fp16_grad.data_ptr() not in address_set:
|
||||
_update_mem_use(self.fp16_grad)
|
||||
address_set.add(self.fp16_grad.data_ptr())
|
||||
|
||||
if self.fp32_grad is not None and self.fp32_grad.data_ptr() not in address_set:
|
||||
_update_mem_use(self.fp32_grad)
|
||||
address_set.add(self.fp32_grad.data_ptr())
|
||||
|
||||
if self.param.data is not None and self.param.data.data_ptr() not in address_set:
|
||||
_update_mem_use(self.param.data)
|
||||
address_set.add(self.param.data.data_ptr())
|
||||
|
||||
if self.param.grad is not None and self.param.grad.data_ptr() not in address_set:
|
||||
_update_mem_use(self.param.grad)
|
||||
address_set.add(self.param.grad.data_ptr())
|
||||
|
||||
return cuda_mem_use, cpu_mem_use
|
||||
|
|
|
@ -51,26 +51,41 @@ def _run_shard_param_v2(rank, world_size, port):
|
|||
|
||||
allclose(sparam.sharded_data_tensor.payload, param_ref.data)
|
||||
|
||||
sparam.remove_torch_payload()
|
||||
assert (param.data.numel() == 1)
|
||||
|
||||
# Test get memory usage
|
||||
sparam.fp32_grad = torch.randn(2, 3)
|
||||
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
|
||||
assert cpu_mem_use == 2 * 3 * 4 * 2
|
||||
assert cpu_mem_use == 2 * 3 * 4 * 2, f"cpu_mem_use: {cpu_mem_use}"
|
||||
|
||||
sparam.remove_torch_payload()
|
||||
assert (param.data.numel() == 1)
|
||||
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
|
||||
# the extra 4 bytes are the size of the one-element dummy tensor held in param.data
|
||||
assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
|
||||
|
||||
sparam.fp16_grad = torch.randn(2, 3).cuda().half()
|
||||
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
|
||||
assert cpu_mem_use == 2 * 3 * 4 * 2
|
||||
assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
|
||||
assert cuda_mem_use == 2 * 3 * 2
|
||||
|
||||
sparam.fp16_grad = None
|
||||
sparam.fp32_grad = torch.randn(2, 3)
|
||||
sparam.remove_torch_payload()
|
||||
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
|
||||
assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
|
||||
assert cuda_mem_use == 0
|
||||
|
||||
# append a grad to torch param
|
||||
param.data = sparam.sharded_data_tensor.payload
|
||||
param.grad = torch.randn(2, 3)
|
||||
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
|
||||
assert cpu_mem_use == 2 * 3 * 4 * 2 + 2 * 3 * 4, f"cpu_mem_use {cpu_mem_use}"
|
||||
assert cuda_mem_use == 0
|
||||
|
||||
# reuse the torch grad as sparam.fp32_grad; storage shared by both must be counted only once
|
||||
sparam.fp32_grad = param.grad
|
||||
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
|
||||
assert cpu_mem_use == 2 * 3 * 4 * 2
|
||||
assert cuda_mem_use == 0
|
||||
print(f'cuda_mem_use {cuda_mem_use} cpu_mem_use {cpu_mem_use}')
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
|
|
Loading…
Reference in New Issue