mirror of https://github.com/hpcaitech/ColossalAI
[zero] improve the accuracy of get_memory_usage of sharded param (#538)
parent
37cb70feec
commit
a590ed0ba3
|
@ -60,8 +60,24 @@ class ShardedParamV2(object):
|
||||||
elif t.device.type == 'cuda':
|
elif t.device.type == 'cuda':
|
||||||
cuda_mem_use += t.numel() * t.element_size()
|
cuda_mem_use += t.numel() * t.element_size()
|
||||||
|
|
||||||
|
address_set = set()
|
||||||
_update_mem_use(self.sharded_data_tensor.payload)
|
_update_mem_use(self.sharded_data_tensor.payload)
|
||||||
_update_mem_use(self.fp16_grad)
|
address_set.add(self.sharded_data_tensor.payload.data_ptr())
|
||||||
_update_mem_use(self.fp32_grad)
|
|
||||||
|
if self.fp16_grad is not None and self.fp16_grad.data_ptr() not in address_set:
|
||||||
|
_update_mem_use(self.fp16_grad)
|
||||||
|
address_set.add(self.fp16_grad.data_ptr())
|
||||||
|
|
||||||
|
if self.fp32_grad is not None and self.fp32_grad.data_ptr() not in address_set:
|
||||||
|
_update_mem_use(self.fp32_grad)
|
||||||
|
address_set.add(self.fp32_grad.data_ptr())
|
||||||
|
|
||||||
|
if self.param.data is not None and self.param.data.data_ptr() not in address_set:
|
||||||
|
_update_mem_use(self.param.data)
|
||||||
|
address_set.add(self.param.data.data_ptr())
|
||||||
|
|
||||||
|
if self.param.grad is not None and self.param.grad.data_ptr() not in address_set:
|
||||||
|
_update_mem_use(self.param.grad)
|
||||||
|
address_set.add(self.param.grad.data_ptr())
|
||||||
|
|
||||||
return cuda_mem_use, cpu_mem_use
|
return cuda_mem_use, cpu_mem_use
|
||||||
|
|
|
@ -51,26 +51,41 @@ def _run_shard_param_v2(rank, world_size, port):
|
||||||
|
|
||||||
allclose(sparam.sharded_data_tensor.payload, param_ref.data)
|
allclose(sparam.sharded_data_tensor.payload, param_ref.data)
|
||||||
|
|
||||||
sparam.remove_torch_payload()
|
|
||||||
assert (param.data.numel() == 1)
|
|
||||||
|
|
||||||
# Test get memory usage
|
# Test get memory usage
|
||||||
sparam.fp32_grad = torch.randn(2, 3)
|
sparam.fp32_grad = torch.randn(2, 3)
|
||||||
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
|
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
|
||||||
assert cpu_mem_use == 2 * 3 * 4 * 2
|
assert cpu_mem_use == 2 * 3 * 4 * 2, f"cpu_mem_use: {cpu_mem_use}"
|
||||||
|
|
||||||
|
sparam.remove_torch_payload()
|
||||||
|
assert (param.data.numel() == 1)
|
||||||
|
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
|
||||||
|
# 4 is size of dummy tensor of param.data
|
||||||
|
assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
|
||||||
|
|
||||||
sparam.fp16_grad = torch.randn(2, 3).cuda().half()
|
sparam.fp16_grad = torch.randn(2, 3).cuda().half()
|
||||||
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
|
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
|
||||||
assert cpu_mem_use == 2 * 3 * 4 * 2
|
assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
|
||||||
assert cuda_mem_use == 2 * 3 * 2
|
assert cuda_mem_use == 2 * 3 * 2
|
||||||
|
|
||||||
sparam.fp16_grad = None
|
sparam.fp16_grad = None
|
||||||
sparam.fp32_grad = torch.randn(2, 3)
|
sparam.fp32_grad = torch.randn(2, 3)
|
||||||
sparam.remove_torch_payload()
|
sparam.remove_torch_payload()
|
||||||
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
|
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
|
||||||
|
assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
|
||||||
|
assert cuda_mem_use == 0
|
||||||
|
|
||||||
|
# append a grad to torch param
|
||||||
|
param.data = sparam.sharded_data_tensor.payload
|
||||||
|
param.grad = torch.randn(2, 3)
|
||||||
|
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
|
||||||
|
assert cpu_mem_use == 2 * 3 * 4 * 2 + 2 * 3 * 4, f"cpu_mem_use {cpu_mem_use}"
|
||||||
|
assert cuda_mem_use == 0
|
||||||
|
|
||||||
|
# reuse torch grad for sparam
|
||||||
|
sparam.fp32_grad = param.grad
|
||||||
|
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
|
||||||
assert cpu_mem_use == 2 * 3 * 4 * 2
|
assert cpu_mem_use == 2 * 3 * 4 * 2
|
||||||
assert cuda_mem_use == 0
|
assert cuda_mem_use == 0
|
||||||
print(f'cuda_mem_use {cuda_mem_use} cpu_mem_use {cpu_mem_use}')
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
|
|
Loading…
Reference in New Issue