2022-03-01 10:17:01 +00:00
|
|
|
import torch
|
2022-03-02 10:28:29 +00:00
|
|
|
import torch.distributed as dist
|
2022-03-04 07:49:23 +00:00
|
|
|
from colossalai.zero.sharded_param import ShardedTensor
|
2022-03-28 07:01:21 +00:00
|
|
|
from typing import Optional, Tuple
|
2022-03-04 07:49:23 +00:00
|
|
|
|
|
|
|
|
|
|
|
class ShardedParamV2(object):
    """Wrapper around a ``torch.nn.Parameter`` whose data is held in a ``ShardedTensor``.

    Besides the sharded data tensor, an instance carries optional fp16/fp32
    gradient holders, an ``offload_grad`` flag and a backward counter used for
    local gradient accumulation.
    """

    def __init__(self,
                 param: torch.nn.Parameter,
                 process_group: Optional[dist.ProcessGroup] = None,
                 rm_torch_payload: bool = False) -> None:
        """
        Args:
            param: the torch parameter to wrap; its ``.data`` is handed to ``ShardedTensor``.
            process_group: process group forwarded to ``ShardedTensor`` (``None`` by default).
            rm_torch_payload: if True, immediately replace ``param.data`` with a
                placeholder via ``remove_torch_payload``.
        """
        self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data, process_group)

        # Gradient holders; they start empty and are populated outside this class.
        self.fp16_grad: Optional[torch.Tensor] = None
        self.fp32_grad: Optional[torch.Tensor] = None

        # This attribute must be initialized in ShardedModel
        self.offload_grad: bool = False

        # Make sure the sharded param is the only owner of the payload.
        # However, param.data may still be used to initialize other parts of the model, e.g.
        #   File "resnet.py", line 190, in __init__
        #     nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        # so .data is only emptied here when the caller explicitly asks for it
        # (rm_torch_payload=True).
        self.param = param
        if rm_torch_payload:
            self.remove_torch_payload()

        # Backward count for handling local grad accumulation.
        # This value is incremented by 1 in every pre-bwd hook
        # and reset to 0 in every final-bwd hook.
        self.bwd_count = 0

    def remove_torch_payload(self):
        """Replace ``param.data`` with a 0-dim placeholder of the same dtype and device.

        ``torch.empty([])`` is a scalar-shaped tensor holding a single
        uninitialized element, so the torch Parameter stops referencing the
        full-size payload.
        """
        self.param.data = torch.empty([], dtype=self.param.dtype, device=self.param.device)

    @property
    def sharded_data_tensor(self):
        """The ``ShardedTensor`` holding this parameter's data."""
        return self._sharded_data_tensor

    @property
    def param_is_sharded(self):
        """Whether the underlying ``ShardedTensor`` reports itself as sharded."""
        return self._sharded_data_tensor.is_sharded

    def get_memory_usage(self) -> Tuple[int, int]:
        """Get the memory usage of the param, including data and grads.

        The sharded payload, ``fp16_grad``, ``fp32_grad``, ``param.data`` and
        ``param.grad`` are inspected in that order. ``None`` entries are skipped,
        and each distinct ``data_ptr()`` is counted at most once, so aliasing
        tensors (e.g. a payload that shares storage with ``param.data``) are not
        double counted. Tensors on devices other than cpu/cuda are ignored.

        Returns:
            Tuple[int, int]: cuda memory usage in bytes, cpu memory usage in bytes.
        """
        cuda_mem_use, cpu_mem_use = 0, 0
        seen_ptrs = set()

        candidates = (
            self.sharded_data_tensor.payload,
            self.fp16_grad,
            self.fp32_grad,
            self.param.data,
            self.param.grad,
        )
        for t in candidates:
            # Every candidate, including the payload, may be None; skip it
            # uniformly instead of crashing on ``None.data_ptr()``.
            if t is None:
                continue
            assert isinstance(t, torch.Tensor)
            # NOTE: dedup is keyed on the start address only; two views starting
            # at the same address but with different sizes are counted once,
            # using the first one seen.
            ptr = t.data_ptr()
            if ptr in seen_ptrs:
                continue
            seen_ptrs.add(ptr)

            nbytes = t.numel() * t.element_size()
            if t.device.type == 'cpu':
                cpu_mem_use += nbytes
            elif t.device.type == 'cuda':
                cuda_mem_use += nbytes

        return cuda_mem_use, cpu_mem_use
|