import torch
import torch.distributed as dist
from colossalai.zero.sharded_param import ShardedTensor
from typing import Optional


class ShardedParamV2(object):
    def __init__(self,
                 param: torch.nn.Parameter,
                 process_group: Optional[dist.ProcessGroup] = None,
                 rm_torch_payload=False) -> None:
        self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data, process_group)
        # Gradient placeholders in half and full precision; populated outside this class.
        self.fp16_grad: Optional[torch.Tensor] = None
        self.fp32_grad: Optional[torch.Tensor] = None

        # Make sure this ShardedParamV2 is the only owner of the parameter's payload.
        # param.data may still be used to initialize other parts of the model,
        # e.g. File "resnet.py", line 190, in __init__
        #      nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        # so we cannot empty .data unconditionally at this point.
        self.param = param
        if rm_torch_payload:
            self.remove_torch_payload()
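
        # A minimal sketch of the intended call pattern (hypothetical module `m`,
        # shown only for illustration; the real call sites live elsewhere in the
        # ZeRO code):
        #
        #   sp = ShardedParamV2(m.weight, rm_torch_payload=False)
        #   nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        #   sp.remove_torch_payload()   # safe to drop the payload after init is done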

        # Backward count used to handle local gradient accumulation.
        # It is incremented by 1 in every pre-backward hook
        # and reset to 0 in the final backward hook.
        self.bwd_count = 0
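
        # Sketch of how the hooks described above interact with this counter
        # (the real hooks are registered by the ZeRO engine elsewhere; `sp` is
        # this ShardedParamV2, shown only for illustration):
        #
        #   def pre_backward_hook():        # runs before each backward pass
        #       sp.bwd_count += 1
        #
        #   def final_backward_hook():      # runs once backward has finished
        #       sp.bwd_count = 0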

    def remove_torch_payload(self):
        # Replace the torch-side payload with an empty tensor so that the
        # ShardedTensor wrapper holds the only copy of the data.
        self.param.data = torch.empty([], dtype=self.param.dtype, device=self.param.device)

    @property
    def sharded_data_tensor(self):
        return self._sharded_data_tensor

    @property
    def param_is_sharded(self):
        return self._sharded_data_tensor.is_sharded
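

if __name__ == "__main__":
    # Illustrative single-process smoke test, not part of the original module.
    # It assumes a gloo backend is available; `fc` is a hypothetical layer used
    # only for this example.
    import torch.nn as nn

    if not dist.is_initialized():
        dist.init_process_group(backend="gloo",
                                init_method="tcp://127.0.0.1:29500",
                                rank=0,
                                world_size=1)

    fc = nn.Linear(128, 256)
    sharded_w = ShardedParamV2(fc.weight, process_group=None)

    # The payload is now wrapped by a ShardedTensor; whether it is actually
    # sharded is decided later by the ZeRO machinery.
    print("is sharded:", sharded_w.param_is_sharded)

    # Once weight initialization is finished, the torch-side copy can be dropped.
    sharded_w.remove_torch_payload()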