ColossalAI/colossalai/zero/sharded_param/sharded_tensor.py


import torch
import torch.distributed as dist
from typing import Optional


class ShardedTensor(object):

    def __init__(self, tensor: torch.Tensor, process_group: Optional[dist.ProcessGroup] = None) -> None:
        r"""
        A tensor sharded across multiple processes, constructed from an existing torch.Tensor instance.
        """
        self._payload = tensor
        self.process_group = process_group
        self.world_size = dist.get_world_size(self.process_group)
        self.local_rank = dist.get_rank(self.process_group)
        self._is_sharded = False

        # Record the full (unsharded) shape, element count and dtype so this
        # metadata survives after the payload has been sharded.
        self._origin_shape = tensor.shape
        self._origin_numel = tensor.numel()
        self._origin_dtype = tensor.dtype
    @property
    def origin_numel(self):
        return self._origin_numel

    @property
    def origin_shape(self):
        return self._origin_shape

    @property
    def is_sharded(self):
        return self._is_sharded

    @is_sharded.setter
    def is_sharded(self, flag: bool):
        self._is_sharded = flag
    @property
    def payload(self):
        return self._payload

    def copy_payload(self, tensor):
        # In-place copy into the existing payload storage; shapes may differ,
        # but the number of elements must match.
        self._payload.view(-1).copy_(tensor.view(-1))

    def reset_payload(self, tensor):
        # Drop the old payload and take the given tensor as the new storage.
        del self._payload
        self._payload = tensor
    @property
    def device(self):
        return self._payload.device

    @property
    def dtype(self):
        assert self._payload.dtype == self._origin_dtype
        return self._origin_dtype
    def to(self, device: torch.device):
        raise RuntimeError("Use colo_model_tensor_move instead of calling .to() on a ShardedTensor")

    def to_(self, device: torch.device):
        raise RuntimeError("Use colo_model_tensor_move instead of calling .to_() on a ShardedTensor")
    @property
    def shape(self):
        return self._payload.shape
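
A minimal usage sketch, assuming a single-process gloo process group so that dist.get_world_size() and dist.get_rank() in __init__ resolve. The chunking step below is only a hypothetical illustration of what a shard strategy elsewhere in the library might do with the payload; it is not part of this class.

import os
import torch
import torch.distributed as dist

from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor

# Single-process group so the rank/world-size queries in __init__ succeed.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

param = torch.randn(4, 4)
st = ShardedTensor(param)
print(st.is_sharded)     # False until a shard strategy flips the flag
print(st.origin_shape)   # torch.Size([4, 4])

# Hypothetical sharding step: keep only this rank's flattened chunk.
chunk = param.view(-1).chunk(st.world_size)[st.local_rank].clone()
st.reset_payload(chunk)
st.is_sharded = True
print(st.payload.shape)  # torch.Size([16]) when world_size == 1

dist.destroy_process_group()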