2022-03-04 02:46:13 +00:00
|
|
|
import torch
|
|
|
|
import torch.distributed as dist
|
|
|
|
from typing import Optional
|
|
|
|
|
|
|
|
|
|
|
|
class ShardedTensor(object):
    """Wrapper around a local tensor ("payload") plus metadata of the tensor it was built from.

    The origin shape / numel / dtype are captured once at construction and are
    never updated afterwards, even when the payload is swapped via
    ``reset_payload``. Presumably external sharding code replaces the payload
    with a shard and flips ``is_sharded`` -- that logic is not part of this class.
    """

    def __init__(self, tensor: torch.Tensor, process_group: Optional[dist.ProcessGroup] = None) -> None:
        r"""
        A tensor sharded in multiple processes. Constructed from an existing torch.Tensor instance.

        Args:
            tensor: initial payload; its shape, numel and dtype are recorded as
                the "origin" metadata.
            process_group: group queried for world size and rank; ``None``
                means the default group. torch.distributed must already be
                initialized, otherwise ``dist.get_world_size`` raises.
        """
        self._payload = tensor
        self.process_group = process_group
        self.world_size = dist.get_world_size(self.process_group)
        self.local_rank = dist.get_rank(self.process_group)
        self._is_sharded = False

        # Snapshot of the original tensor; kept even after the payload changes.
        self._origin_shape = tensor.shape
        self._origin_numel = tensor.numel()
        self._origin_dtype = tensor.dtype

    @property
    def origin_numel(self):
        """Number of elements of the tensor passed to ``__init__``."""
        return self._origin_numel

    @property
    def origin_shape(self):
        """Shape of the tensor passed to ``__init__``."""
        return self._origin_shape

    @property
    def is_sharded(self):
        """Flag recording whether the payload currently holds a shard (set externally)."""
        return self._is_sharded

    @is_sharded.setter
    def is_sharded(self, flag: bool):
        self._is_sharded = flag

    @property
    def payload(self):
        """The current local tensor."""
        return self._payload

    def copy_payload(self, tensor):
        """Copy ``tensor``'s elements, in flattened order, into the existing payload in place.

        The payload must be viewable as 1-D (e.g. contiguous) so the write
        lands in its storage. The source may be non-contiguous: ``reshape``
        returns a view when possible and a copy otherwise, which is fine
        because the source is only read. ``copy_`` applies its usual
        broadcasting rules, so the flattened source must be broadcastable to
        the payload's numel.
        """
        self._payload.view(-1).copy_(tensor.reshape(-1))

    def reset_payload(self, tensor):
        """Replace the payload object with ``tensor``; origin metadata is left unchanged."""
        # Rebinding drops this wrapper's reference to the old payload.
        self._payload = tensor

    @property
    def device(self):
        """Device of the current payload."""
        return self._payload.device

    @property
    def dtype(self):
        """Origin dtype; the payload is expected to still have that dtype."""
        # Internal invariant check (stripped under ``python -O``).
        assert self._payload.dtype == self._origin_dtype, (
            f"payload dtype {self._payload.dtype} differs from origin dtype {self._origin_dtype}"
        )
        return self._origin_dtype

    def to(self, device: torch.device) -> "ShardedTensor":
        """Move the payload to ``device`` and return ``self`` to allow chaining."""
        self._payload = self._payload.to(device)
        return self

    @property
    def shape(self):
        """Shape of the current payload (may differ from ``origin_shape`` once sharded)."""
        return self._payload.shape
|