@@ -10,10 +10,20 @@ from typing import List
 # empty tensor is expected to raise error when get used
 FAKE_EMPTY_TENSOR = torch.BoolTensor([], device='cpu')

+EMPTY_TENSOR_DICT = {}
+
+
+def get_empty_tensor(device: torch.device, dtype: torch.dtype):
+    key = (device, dtype)
+    if key not in EMPTY_TENSOR_DICT:
+        EMPTY_TENSOR_DICT[key] = FAKE_EMPTY_TENSOR.to(device, dtype)
+    return EMPTY_TENSOR_DICT[key]
+

 class ShardedParamV2(object):

-    def __init__(self, param: torch.nn.Parameter, rm_torch_payload=False) -> None:
+    def __init__(self, param: torch.nn.Parameter, set_data_none: bool = False) -> None:
         self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data)
         self.saved_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
         # This attribute must be initialized in ShardedModel
@@ -25,24 +35,47 @@ class ShardedParamV2(object):
         # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
         # So we can not empty the .data at this time
         self.param = param
-        if rm_torch_payload:
-            self.remove_torch_payload()
+        if set_data_none:
+            self.set_data_none()

     def get_payload_tensors(self) -> List[StatefulTensor]:
         """returns stateful tensors kept by this class.
         """
         return [self._sharded_data_tensor]

-    def remove_torch_payload(self):
-        self.param.data = FAKE_EMPTY_TENSOR.to(self._sharded_data_tensor.device, self._sharded_data_tensor.dtype)
+    def set_data_none(self):
+        self.param.data = get_empty_tensor(self.sharded_data_tensor.device, self.sharded_data_tensor.dtype)

     def set_grad_none(self):
         self.saved_grad.set_null()

     @property
     def sharded_data_tensor(self):
         return self._sharded_data_tensor

+    @property
+    def data_payload(self):
+        return self.sharded_data_tensor.payload
+
+    @property
+    def grad_payload(self):
+        assert not self.saved_grad.is_null()
+        return self.saved_grad.payload
+
     @property
     def param_is_sharded(self):
-        return self._sharded_data_tensor.is_sharded
+        return self.sharded_data_tensor.is_sharded

+    def reset_data_payload(self, tensor: torch.Tensor):
+        assert type(tensor) is torch.Tensor
+        assert tensor.requires_grad is False
+        self.sharded_data_tensor.reset_payload(tensor)
+        self.set_data_none()
+
+    def reset_grad_payload(self, tensor: torch.Tensor):
+        assert type(tensor) is torch.Tensor
+        assert tensor.requires_grad is False
+        self.saved_grad.reset_payload(tensor)
+
     def get_memory_usage(self) -> Tuple[int, int]:
         """
@@ -63,11 +96,11 @@ class ShardedParamV2(object):
             cpu_mem_use += t_cpu

         address_set = set()
-        _update_mem_use(self.sharded_data_tensor.payload)
-        address_set.add(self.sharded_data_tensor.payload.data_ptr())
+        _update_mem_use(self.data_payload)
+        address_set.add(self.data_payload.data_ptr())

         if not self.saved_grad.is_null() and self.saved_grad.data_ptr() not in address_set:
-            _update_mem_use(self.saved_grad.payload)
+            _update_mem_use(self.grad_payload)
             address_set.add(self.saved_grad.data_ptr())

         if self.param.data is not None and self.param.data.data_ptr() not in address_set:
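For context, a minimal usage sketch of the refactored interface. The import path and the toy parameter below are illustrative assumptions; only the method and property names come from the patch above.

```python
# Sketch only: the import path and setup here are assumed, not part of the patch.
import torch

from colossalai.zero.sharded_param import ShardedParamV2

param = torch.nn.Parameter(torch.randn(4, 4))
sharded_param = ShardedParamV2(param, set_data_none=False)

# data_payload replaces direct reads of sharded_data_tensor.payload.
print(sharded_param.data_payload.shape)   # torch.Size([4, 4])
print(sharded_param.param_is_sharded)     # whether the payload has been sharded yet

# set_data_none() points param.data at the cached empty tensor for its
# device/dtype, leaving the sharded tensor as the payload's only owner.
sharded_param.set_data_none()
print(param.data.numel())                 # 0

# reset_data_payload() installs a new payload and empties param.data again.
sharded_param.reset_data_payload(torch.zeros(4, 4))
cuda_bytes, cpu_bytes = sharded_param.get_memory_usage()
```

The last hunk switches get_memory_usage() over to the same data_payload / grad_payload accessors, so the class and its callers read the payloads through one path.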