2022-03-01 10:17:01 +00:00
|
|
|
import torch
|
2022-03-28 07:01:21 +00:00
|
|
|
from typing import Optional, Tuple
|
2022-04-24 05:08:48 +00:00
|
|
|
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
|
|
|
|
from colossalai.gemini.tensor_utils import colo_tensor_mem_usage
|
|
|
|
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
|
2022-04-06 08:18:49 +00:00
|
|
|
from typing import List
|
2022-03-04 07:49:23 +00:00
|
|
|
|
2022-04-13 06:54:26 +00:00
|
|
|
# Cache of zero-element placeholder tensors, keyed by (device, dtype).
EMPTY_TENSOR_DICT = {}


def get_empty_tensor(device: torch.device, dtype: torch.dtype):
    """Return a cached zero-element tensor on ``device`` with ``dtype``.

    The first call for a given ``(device, dtype)`` pair allocates the tensor
    with ``torch.empty(0, ...)`` and stores it in ``EMPTY_TENSOR_DICT``; every
    later call with an equal pair returns that very same tensor object.

    Args:
        device: device the placeholder should live on.
        dtype: dtype of the placeholder.

    Returns:
        torch.Tensor: a tensor with zero elements, shared between callers.
    """
    cache_key = (device, dtype)
    try:
        return EMPTY_TENSOR_DICT[cache_key]
    except KeyError:
        placeholder = torch.empty(0, dtype=dtype, device=device)
        EMPTY_TENSOR_DICT[cache_key] = placeholder
        return placeholder
|
|
|
|
|
2022-03-04 07:49:23 +00:00
|
|
|
|
|
|
|
class ShardedParamV2(object):
    """Pairs a ``torch.nn.Parameter`` with sharded data and a stateful gradient.

    The parameter's data is wrapped in a :class:`ShardedTensor`
    (``_sharded_data_tensor``). The gradient slot is a :class:`StatefulTensor`
    (``saved_grad``) that starts out empty in the ``TensorState.FREE`` state.
    """

    def __init__(self, param: torch.nn.Parameter, set_data_none: bool = False) -> None:
        """
        Args:
            param: the parameter to wrap; its current ``.data`` becomes the
                payload of the sharded data tensor.
            set_data_none: when True, immediately call ``set_data_none()`` so
                ``param.data`` is replaced by a zero-element placeholder.
        """
        self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data)
        # No gradient has been saved yet: begin with a null (FREE) stateful tensor.
        self.saved_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
        # Placeholder value; ShardedModel is responsible for setting this attribute.
        self.offload_grad: bool = False

        # Ideally the sharded tensor would be the sole owner of the payload, but
        # param.data may still be read while the rest of the model is being built,
        # e.g. in a ResNet __init__:
        #     nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        # so .data is only dropped here when the caller explicitly asks for it.
        self.param = param
        if set_data_none:
            self.set_data_none()

    def get_payload_tensors(self) -> List[StatefulTensor]:
        """Return the stateful tensors whose payloads this object manages.

        Only the sharded data tensor is listed; ``saved_grad`` is not included.
        """
        payloads = [self._sharded_data_tensor]
        return payloads

    def set_data_none(self):
        """Replace ``param.data`` with a shared zero-element placeholder.

        The placeholder is obtained from ``get_empty_tensor`` using the device
        and dtype of the sharded data tensor.
        """
        sharded = self.sharded_data_tensor
        placeholder = get_empty_tensor(sharded.device, sharded.dtype)
        self.param.data = placeholder

    def set_grad_none(self):
        """Put the saved gradient into the null state via ``StatefulTensor.set_null``."""
        self.saved_grad.set_null()

    @property
    def sharded_data_tensor(self):
        """The :class:`ShardedTensor` wrapping this parameter's data."""
        return self._sharded_data_tensor

    @property
    def data_payload(self):
        """Underlying tensor of the sharded data; asserts the data is not null."""
        sharded = self.sharded_data_tensor
        assert not sharded.is_null()
        return sharded.payload

    @property
    def grad_payload(self):
        """Underlying tensor of the saved gradient; asserts it is not null."""
        grad = self.saved_grad
        assert not grad.is_null()
        return grad.payload

    @property
    def param_is_sharded(self):
        """Whether the sharded data tensor reports itself as sharded."""
        return self.sharded_data_tensor.is_sharded

    @staticmethod
    def _check_reset_tensor(tensor: torch.Tensor) -> None:
        # A replacement payload must be an exact torch.Tensor (subclasses such
        # as nn.Parameter are rejected) and must not track gradients.
        assert type(tensor) is torch.Tensor
        assert not tensor.requires_grad

    def data_payload_reset(self, tensor: torch.Tensor):
        """Install ``tensor`` as the new payload of the sharded data tensor."""
        ShardedParamV2._check_reset_tensor(tensor)
        self.sharded_data_tensor.payload_reset(tensor)

    def grad_payload_reset(self, tensor: torch.Tensor):
        """Install ``tensor`` as the new payload of the saved gradient."""
        ShardedParamV2._check_reset_tensor(tensor)
        self.saved_grad.payload_reset(tensor)

    def get_memory_usage(self) -> Tuple[int, int]:
        """
        Measure the memory held by this parameter's data and gradient tensors.

        The candidates are checked in this order:
        1. the sharded data payload;
        2. the saved gradient payload, if it is not null;
        3. ``param.data``;
        4. ``param.grad``.
        Each one is sized with ``colo_tensor_mem_usage``. A candidate whose
        ``data_ptr()`` has already been counted is skipped, so aliases of the
        same data pointer contribute only once.

        Returns:
            Tuple[int, int]: cuda mem usage in Byte, cpu memory usage in Byte

        Raises:
            AssertionError: if the sharded data tensor is null.
        """
        usage = [0, 0]  # [cuda bytes, cpu bytes]
        seen_ptrs = set()

        def _account(t: Optional[torch.Tensor]) -> None:
            if t is not None:
                assert isinstance(t, torch.Tensor)
                on_cuda, on_cpu = colo_tensor_mem_usage(t)
                usage[0] += on_cuda
                usage[1] += on_cpu

        data = self.data_payload
        _account(data)
        seen_ptrs.add(data.data_ptr())

        grad = self.saved_grad
        if not grad.is_null() and grad.data_ptr() not in seen_ptrs:
            _account(self.grad_payload)
            seen_ptrs.add(grad.data_ptr())

        param_data = self.param.data
        if param_data is not None and param_data.data_ptr() not in seen_ptrs:
            _account(param_data)
            seen_ptrs.add(param_data.data_ptr())

        param_grad = self.param.grad
        if param_grad is not None and param_grad.data_ptr() not in seen_ptrs:
            _account(param_grad)

        return usage[0], usage[1]
|