mirror of https://github.com/hpcaitech/ColossalAI
[zero] get memory usage for sharded param (#536)
parent
56ad945797
commit
37cb70feec
|
@ -1,7 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from colossalai.zero.sharded_param import ShardedTensor
|
from colossalai.zero.sharded_param import ShardedTensor
|
||||||
from typing import Optional
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
|
||||||
class ShardedParamV2(object):
|
class ShardedParamV2(object):
|
||||||
|
@ -40,3 +40,28 @@ class ShardedParamV2(object):
|
||||||
@property
|
@property
|
||||||
def param_is_sharded(self):
|
def param_is_sharded(self):
|
||||||
return self._sharded_data_tensor.is_sharded
|
return self._sharded_data_tensor.is_sharded
|
||||||
|
|
||||||
|
def get_memory_usage(self) -> Tuple[int, int]:
|
||||||
|
"""
|
||||||
|
get the memory usage of the param, including data and grad
|
||||||
|
Returns:
|
||||||
|
Tuple[int, int]: cuda mem usage in Byte, cpu memory usage in Byte
|
||||||
|
"""
|
||||||
|
cuda_mem_use, cpu_mem_use = 0, 0
|
||||||
|
|
||||||
|
def _update_mem_use(t: Optional[torch.Tensor]):
|
||||||
|
if t is None:
|
||||||
|
return
|
||||||
|
assert isinstance(t, torch.Tensor)
|
||||||
|
nonlocal cuda_mem_use
|
||||||
|
nonlocal cpu_mem_use
|
||||||
|
if t.device.type == 'cpu':
|
||||||
|
cpu_mem_use += t.numel() * t.element_size()
|
||||||
|
elif t.device.type == 'cuda':
|
||||||
|
cuda_mem_use += t.numel() * t.element_size()
|
||||||
|
|
||||||
|
_update_mem_use(self.sharded_data_tensor.payload)
|
||||||
|
_update_mem_use(self.fp16_grad)
|
||||||
|
_update_mem_use(self.fp32_grad)
|
||||||
|
|
||||||
|
return cuda_mem_use, cpu_mem_use
|
||||||
|
|
|
@ -54,6 +54,24 @@ def _run_shard_param_v2(rank, world_size, port):
|
||||||
sparam.remove_torch_payload()
|
sparam.remove_torch_payload()
|
||||||
assert (param.data.numel() == 1)
|
assert (param.data.numel() == 1)
|
||||||
|
|
||||||
|
# Test get memory usage
|
||||||
|
sparam.fp32_grad = torch.randn(2, 3)
|
||||||
|
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
|
||||||
|
assert cpu_mem_use == 2 * 3 * 4 * 2
|
||||||
|
|
||||||
|
sparam.fp16_grad = torch.randn(2, 3).cuda().half()
|
||||||
|
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
|
||||||
|
assert cpu_mem_use == 2 * 3 * 4 * 2
|
||||||
|
assert cuda_mem_use == 2 * 3 * 2
|
||||||
|
|
||||||
|
sparam.fp16_grad = None
|
||||||
|
sparam.fp32_grad = torch.randn(2, 3)
|
||||||
|
sparam.remove_torch_payload()
|
||||||
|
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
|
||||||
|
assert cpu_mem_use == 2 * 3 * 4 * 2
|
||||||
|
assert cuda_mem_use == 0
|
||||||
|
print(f'cuda_mem_use {cuda_mem_use} cpu_mem_use {cpu_mem_use}')
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
@pytest.mark.parametrize("world_size", [1, 2])
|
@pytest.mark.parametrize("world_size", [1, 2])
|
||||||
|
@ -64,5 +82,5 @@ def test_shard_param_v2(world_size):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_shard_tensor(2)
|
# test_shard_tensor(2)
|
||||||
test_shard_param_v2(2)
|
test_shard_param_v2(2)
|
||||||
|
|
Loading…
Reference in New Issue