diff --git a/colossalai/zero/sharded_param/sharded_param.py b/colossalai/zero/sharded_param/sharded_param.py
index 6334799d0..e73f6cc7c 100644
--- a/colossalai/zero/sharded_param/sharded_param.py
+++ b/colossalai/zero/sharded_param/sharded_param.py
@@ -60,8 +60,24 @@ class ShardedParamV2(object):
             elif t.device.type == 'cuda':
                 cuda_mem_use += t.numel() * t.element_size()
 
+        address_set = set()
         _update_mem_use(self.sharded_data_tensor.payload)
-        _update_mem_use(self.fp16_grad)
-        _update_mem_use(self.fp32_grad)
+        address_set.add(self.sharded_data_tensor.payload.data_ptr())
+
+        if self.fp16_grad is not None and self.fp16_grad.data_ptr() not in address_set:
+            _update_mem_use(self.fp16_grad)
+            address_set.add(self.fp16_grad.data_ptr())
+
+        if self.fp32_grad is not None and self.fp32_grad.data_ptr() not in address_set:
+            _update_mem_use(self.fp32_grad)
+            address_set.add(self.fp32_grad.data_ptr())
+
+        if self.param.data is not None and self.param.data.data_ptr() not in address_set:
+            _update_mem_use(self.param.data)
+            address_set.add(self.param.data.data_ptr())
+
+        if self.param.grad is not None and self.param.grad.data_ptr() not in address_set:
+            _update_mem_use(self.param.grad)
+            address_set.add(self.param.grad.data_ptr())
 
         return cuda_mem_use, cpu_mem_use
diff --git a/tests/test_zero_data_parallel/test_shard_param.py b/tests/test_zero_data_parallel/test_shard_param.py
index de48f1f9c..e2694baf7 100644
--- a/tests/test_zero_data_parallel/test_shard_param.py
+++ b/tests/test_zero_data_parallel/test_shard_param.py
@@ -51,26 +51,41 @@ def _run_shard_param_v2(rank, world_size, port):
 
     allclose(sparam.sharded_data_tensor.payload, param_ref.data)
 
-    sparam.remove_torch_payload()
-    assert (param.data.numel() == 1)
-
     # Test get memory usage
     sparam.fp32_grad = torch.randn(2, 3)
     cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
-    assert cpu_mem_use == 2 * 3 * 4 * 2
+    assert cpu_mem_use == 2 * 3 * 4 * 2, f"cpu_mem_use: {cpu_mem_use}"
+
+    sparam.remove_torch_payload()
+    assert (param.data.numel() == 1)
+    cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
+    # 4 is size of dummy tensor of param.data
+    assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
 
     sparam.fp16_grad = torch.randn(2, 3).cuda().half()
    cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
-    assert cpu_mem_use == 2 * 3 * 4 * 2
+    assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
     assert cuda_mem_use == 2 * 3 * 2
 
     sparam.fp16_grad = None
     sparam.fp32_grad = torch.randn(2, 3)
     sparam.remove_torch_payload()
     cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
+    assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
+    assert cuda_mem_use == 0
+
+    # append a grad to torch param
+    param.data = sparam.sharded_data_tensor.payload
+    param.grad = torch.randn(2, 3)
+    cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
+    assert cpu_mem_use == 2 * 3 * 4 * 2 + 2 * 3 * 4, f"cpu_mem_use {cpu_mem_use}"
+    assert cuda_mem_use == 0
+
+    # reuse torch grad for sparam
+    sparam.fp32_grad = param.grad
+    cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
     assert cpu_mem_use == 2 * 3 * 4 * 2
     assert cuda_mem_use == 0
-    print(f'cuda_mem_use {cuda_mem_use} cpu_mem_use {cpu_mem_use}')
 
 
 @pytest.mark.dist
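
Note on the change to ShardedParamV2.get_memory_usage(): the patch de-duplicates tensors that share storage by tracking data_ptr() values in a set, so a tensor that aliases param.data or param.grad is only counted once. Below is a minimal standalone sketch of that idea, assuming only stock PyTorch; the helper name count_unique_tensor_memory and the demo tensors are illustrative and not part of the patch.

# Sketch: count CUDA/CPU bytes across tensors, counting each storage once.
from typing import Iterable, Optional, Tuple

import torch


def count_unique_tensor_memory(tensors: Iterable[Optional[torch.Tensor]]) -> Tuple[int, int]:
    cuda_bytes, cpu_bytes = 0, 0
    seen_ptrs = set()
    for t in tensors:
        # Skip missing tensors and tensors whose storage was already counted.
        if t is None or t.data_ptr() in seen_ptrs:
            continue
        seen_ptrs.add(t.data_ptr())
        nbytes = t.numel() * t.element_size()
        if t.device.type == 'cuda':
            cuda_bytes += nbytes
        else:
            cpu_bytes += nbytes
    return cuda_bytes, cpu_bytes


if __name__ == '__main__':
    payload = torch.randn(2, 3)      # 24 bytes of fp32 on CPU
    grad = torch.randn(2, 3)         # separate 24-byte fp32 buffer
    alias = payload                  # shares storage with payload, not re-counted
    print(count_unique_tensor_memory([payload, grad, alias]))  # -> (0, 48)

This mirrors the arithmetic in the updated test: each torch.randn(2, 3) fp32 tensor occupies 2 * 3 * 4 = 24 bytes, and only distinct storages are summed, which is why reusing param.grad as sparam.fp32_grad leaves cpu_mem_use at 2 * 3 * 4 * 2.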