diff --git a/colossalai/zero/sharded_param/sharded_param.py b/colossalai/zero/sharded_param/sharded_param.py
index 5826dde96..6334799d0 100644
--- a/colossalai/zero/sharded_param/sharded_param.py
+++ b/colossalai/zero/sharded_param/sharded_param.py
@@ -1,7 +1,7 @@
 import torch
 import torch.distributed as dist
 from colossalai.zero.sharded_param import ShardedTensor
-from typing import Optional
+from typing import Optional, Tuple
 
 
 class ShardedParamV2(object):
@@ -40,3 +40,28 @@ class ShardedParamV2(object):
     @property
     def param_is_sharded(self):
         return self._sharded_data_tensor.is_sharded
+
+    def get_memory_usage(self) -> Tuple[int, int]:
+        """
+        Get the memory usage of the param, including data and grad.
+        Returns:
+            Tuple[int, int]: CUDA memory usage in bytes, CPU memory usage in bytes
+        """
+        cuda_mem_use, cpu_mem_use = 0, 0
+
+        def _update_mem_use(t: Optional[torch.Tensor]):
+            if t is None:
+                return
+            assert isinstance(t, torch.Tensor)
+            nonlocal cuda_mem_use
+            nonlocal cpu_mem_use
+            if t.device.type == 'cpu':
+                cpu_mem_use += t.numel() * t.element_size()
+            elif t.device.type == 'cuda':
+                cuda_mem_use += t.numel() * t.element_size()
+
+        _update_mem_use(self.sharded_data_tensor.payload)
+        _update_mem_use(self.fp16_grad)
+        _update_mem_use(self.fp32_grad)
+
+        return cuda_mem_use, cpu_mem_use
diff --git a/tests/test_zero_data_parallel/test_shard_param.py b/tests/test_zero_data_parallel/test_shard_param.py
index 92d39a4e6..de48f1f9c 100644
--- a/tests/test_zero_data_parallel/test_shard_param.py
+++ b/tests/test_zero_data_parallel/test_shard_param.py
@@ -54,6 +54,24 @@ def _run_shard_param_v2(rank, world_size, port):
     sparam.remove_torch_payload()
     assert (param.data.numel() == 1)
 
+    # Test get memory usage
+    sparam.fp32_grad = torch.randn(2, 3)
+    cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
+    assert cpu_mem_use == 2 * 3 * 4 * 2
+
+    sparam.fp16_grad = torch.randn(2, 3).cuda().half()
+    cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
+    assert cpu_mem_use == 2 * 3 * 4 * 2
+    assert cuda_mem_use == 2 * 3 * 2
+
+    sparam.fp16_grad = None
+    sparam.fp32_grad = torch.randn(2, 3)
+    sparam.remove_torch_payload()
+    cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
+    assert cpu_mem_use == 2 * 3 * 4 * 2
+    assert cuda_mem_use == 0
+    print(f'cuda_mem_use {cuda_mem_use} cpu_mem_use {cpu_mem_use}')
+
 
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
@@ -64,5 +82,5 @@ def test_shard_param_v2(world_size):
 
 
 if __name__ == '__main__':
-    test_shard_tensor(2)
+    # test_shard_tensor(2)
     test_shard_param_v2(2)
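
A minimal sketch (not part of the patch) of the byte accounting behind get_memory_usage() and the constants asserted in the test: each tensor contributes t.numel() * t.element_size() bytes, bucketed by the device it lives on. The helper name below is illustrative only.

import torch

def tensor_bytes(t: torch.Tensor) -> int:
    # Same accounting as _update_mem_use: element count times element width.
    return t.numel() * t.element_size()

# fp32 tensor of shape (2, 3) on CPU: 2 * 3 * 4 = 24 bytes
print(tensor_bytes(torch.randn(2, 3)))

# fp16 tensor of shape (2, 3) on CUDA: 2 * 3 * 2 = 12 bytes (only if a GPU is present)
if torch.cuda.is_available():
    print(tensor_bytes(torch.randn(2, 3).cuda().half()))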