ColossalAI/colossalai/utils/memory_tracer/commons.py

12 lines
311 B
Python

from colossalai.zero.sharded_param import ShardedTensor
from typing import Union
import torch
def col_tensor_mem_usage(t: Union[torch.Tensor, ShardedTensor]) -> int:
    """Return the size in bytes of the elements held by ``t``.

    When ``t`` is a ``ShardedTensor`` the measurement is taken on its
    ``payload`` attribute; any other argument is measured directly.

    Args:
        t: A ``torch.Tensor`` or a ``ShardedTensor``.

    Returns:
        ``numel() * element_size()`` of the measured object, i.e. the
        element count multiplied by the per-element byte width.
    """
    # A ShardedTensor is a wrapper: size its payload, not the wrapper itself.
    measured = t.payload if isinstance(t, ShardedTensor) else t
    elem_count = measured.numel()
    bytes_per_elem = measured.element_size()
    return elem_count * bytes_per_elem