2023-01-28 06:35:25 +00:00
|
|
|
import math
|
2022-05-13 07:13:52 +00:00
|
|
|
from copy import copy
|
2022-07-11 03:41:29 +00:00
|
|
|
from functools import lru_cache
|
2022-11-02 08:11:34 +00:00
|
|
|
from typing import Callable, Optional, Set
|
2022-06-21 10:28:38 +00:00
|
|
|
|
2022-11-02 08:11:34 +00:00
|
|
|
import torch
|
|
|
|
|
2022-05-13 07:13:52 +00:00
|
|
|
from colossalai.tensor.dist_spec_mgr import DistSpecManager
|
2022-11-08 09:03:50 +00:00
|
|
|
from colossalai.tensor.distspec import DistPlacementPattern, ReplicaSpec, _DistSpec
|
|
|
|
from colossalai.tensor.process_group import ProcessGroup
|
|
|
|
from colossalai.tensor.tensor_spec import ColoTensorSpec
|
2022-11-02 08:11:34 +00:00
|
|
|
|
|
|
|
from .const import TensorType
|
|
|
|
from .op_wrapper import _COLOSSAL_OPS
|
2022-07-11 03:41:29 +00:00
|
|
|
|
|
|
|
|
|
|
|
@lru_cache(None)
|
|
|
|
def _get_my_nowrap_functions() -> Set[Callable]:
|
|
|
|
Tensor = torch.Tensor
|
|
|
|
return {
|
|
|
|
Tensor._base.__get__,
|
|
|
|
Tensor.grad.__get__,
|
|
|
|
Tensor._grad.__get__,
|
2022-08-16 01:21:05 +00:00
|
|
|
Tensor.data.__get__, # make .data returns torch.Tensor rather than ColoTensor
|
2022-07-11 03:41:29 +00:00
|
|
|
}
|
2022-04-28 06:43:22 +00:00
|
|
|
|
2022-04-26 07:10:47 +00:00
|
|
|
|
2022-07-21 02:53:15 +00:00
|
|
|
def _convert_output(output, colo_spec: ColoTensorSpec):
|
2022-07-07 10:09:18 +00:00
|
|
|
if type(output) == torch.Tensor:
|
2022-07-21 02:53:15 +00:00
|
|
|
return ColoTensor.from_torch_tensor(output, colo_spec)
|
2022-05-19 04:44:59 +00:00
|
|
|
elif isinstance(output, (list, tuple)):
|
2022-07-21 02:53:15 +00:00
|
|
|
return type(output)(_convert_output(o, colo_spec) for o in output)
|
2022-07-07 10:09:18 +00:00
|
|
|
else:
|
|
|
|
return output
|
|
|
|
|
|
|
|
|
2022-07-21 02:53:15 +00:00
|
|
|
def _get_spec_from_args(args, kwargs) -> ColoTensorSpec:
|
2022-07-07 10:09:18 +00:00
|
|
|
for elem in args:
|
|
|
|
if isinstance(elem, ColoTensor):
|
|
|
|
pg = elem.get_process_group()
|
2022-07-21 02:53:15 +00:00
|
|
|
dp = elem.dist_spec
|
|
|
|
return ColoTensorSpec(pg, dp)
|
2022-07-07 10:09:18 +00:00
|
|
|
elif isinstance(elem, (list, tuple)):
|
2022-07-21 02:53:15 +00:00
|
|
|
spec = _get_spec_from_args(elem, {})
|
|
|
|
if spec is not None:
|
|
|
|
return spec
|
2022-07-12 12:46:31 +00:00
|
|
|
for k, v in kwargs.items():
|
2022-07-07 10:09:18 +00:00
|
|
|
if isinstance(v, ColoTensor):
|
|
|
|
pg = v.get_process_group()
|
2022-07-21 02:53:15 +00:00
|
|
|
dp = v.dist_spec
|
|
|
|
return ColoTensorSpec(pg, dp)
|
2022-07-07 10:09:18 +00:00
|
|
|
return None
|
2022-05-19 04:44:59 +00:00
|
|
|
|
|
|
|
|
|
|
|
class ColoTensor(torch.Tensor):
    """Data structure for tensors in Colossal-AI. It is a subclass of torch.Tensor.

    A ColoTensor can be initialized with a PyTorch tensor in the following ways.

    >>> pg = ProcessGroup()
    >>> colo_t1 = ColoTensor(torch.randn(2, 3), spec=ColoTensorSpec(pg, ReplicaSpec()))
    >>> # The tensor passed in is a tensor after sharding but not a global tensor.
    >>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size),
    ...                        dims=[0],
    ...                        num_partitions=[world_size])
    >>> tensor_spec = ColoTensorSpec(pg, shard_spec)
    >>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)

    Args:
        data (torch.Tensor): a torch tensor used as the payload of the ColoTensor.
        spec (ColoTensorSpec, optional): the tensor spec of initialization. When omitted,
            the tensor gets a ``ReplicaSpec()`` dist spec, no compute spec and a
            default ``ProcessGroup()``.
    """

    # Major/minor version of the installed torch, used to gate version-specific behavior.
    torch_major = int(torch.__version__.split('.')[0])
    torch_minor = int(torch.__version__.split('.')[1])

    def __new__(cls, data: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor':
        """Create the tensor object that wraps ``data``.

        The signature of ``__new__`` has to be consistent with ``__init__``, so
        ``spec`` is optional here as well.

        Args:
            data (torch.Tensor): a torch tensor used as the payload of the ColoTensor.
                ``None`` is replaced by an empty tensor.
            spec (ColoTensorSpec, optional): the tensor spec of initialization. It is
                not used here; ``__init__`` consumes it.

        Returns:
            ColoTensor: a ColoTensor wrapping the data.
        """
        if data is None:
            data = torch.empty(0)
        return torch.Tensor._make_subclass(cls, data, data.requires_grad)

    def __init__(self, data: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> None:
        """Attach the distributed attributes to the tensor.

        Args:
            data (torch.Tensor): the payload (already bound by ``__new__``).
            spec (ColoTensorSpec, optional): the tensor spec. Defaults to None.
        """
        # If spec is not set, use a default process group and a replicated dist spec
        if spec is None:
            self.has_initialized = False
            self.dist_spec = ReplicaSpec()
            self.compute_spec = None
            self.process_group = ProcessGroup()
        else:
            self.has_initialized = True
            self.dist_spec = spec.dist_attr
            self.compute_spec = spec.compute_attr
            if spec.pg is None:
                self.process_group = ProcessGroup()
            else:
                self.process_group = spec.pg

        self._type = TensorType.NONMODEL

    def has_compute_spec(self) -> bool:
        """Return True when a compute spec is attached to the tensor."""
        return self.compute_spec is not None

    def is_model_data(self) -> bool:
        """Return True when the tensor type is ``TensorType.MODEL``."""
        return self._type == TensorType.MODEL

    def get_process_group(self) -> 'ProcessGroup':
        """Return the process group currently attached to the tensor."""
        return self.process_group

    def set_process_group(self, pg: ProcessGroup):
        """set_process_group

        Change the pg of the ColoTensor. Note that the valid use cases are limited:
        the current process group must have a TP world size or a DP world size of 1,
        and the current dist spec of the tensor must be Replica.

        Args:
            pg (ProcessGroup): target pg
        """
        assert isinstance(pg, ProcessGroup), f"pg as type {type(pg)} is invalid"
        # if the new pg is the same as the old pg, just return
        if self.process_group == pg:
            return
        assert self.process_group.tp_world_size() == 1 or self.process_group.dp_world_size() == 1, \
            "Can not set_process_group on a ColoTensor whose process_group is both tp > 1 and world group > 1"
        assert self.dist_spec.placement.value == 'r', \
            "Can not set_process_group on a ColoTensor whose dist spec is not Replica"

        self.process_group = pg

    def get_tp_world_size(self) -> int:
        """Return the TP world size reported by the tensor's process group."""
        return self.process_group.tp_world_size()

    def set_dist_spec(self, dist_spec: _DistSpec):
        """set_dist_spec

        Set the dist spec and change the payload accordingly (in place).

        Args:
            dist_spec (_DistSpec): target dist spec.
        """
        assert isinstance(dist_spec, _DistSpec)
        assert self.process_group is not None
        self._redistribute(dist_spec)

    def set_tensor_spec(self, dist_spec, compute_spec):
        """Update the dist spec and/or the compute spec; a ``None`` argument is skipped.

        Args:
            dist_spec (_DistSpec, optional): new dist spec, applied via ``set_dist_spec``.
            compute_spec: new compute spec, stored as is.
        """
        if dist_spec is not None:
            assert isinstance(dist_spec, _DistSpec), f"{type(dist_spec)}"
            self.set_dist_spec(dist_spec)
        if compute_spec is not None:
            self.compute_spec = compute_spec

    def has_compute_pattern(self, compute_pattern):
        """Return True when a compute spec is set and its pattern equals ``compute_pattern``.

        A tensor without a compute spec has no compute pattern, so False is returned.
        """
        if self.compute_spec is None:
            return False
        return self.compute_spec.compute_pattern == compute_pattern

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Intercept torch functions called with ColoTensor arguments.

        Functions registered in ``_COLOSSAL_OPS`` are replaced by their registered
        implementation. The call runs with torch-function dispatch disabled and the
        result is converted back to ``ColoTensor`` unless ``func`` belongs to
        ``_get_my_nowrap_functions()``.
        """
        if kwargs is None:
            kwargs = {}

        if not all(issubclass(cls, t) for t in types):
            return NotImplemented

        if func in _COLOSSAL_OPS:
            func = _COLOSSAL_OPS[func]

        if cls.torch_major > 1 or (cls.torch_major == 1 and cls.torch_minor >= 12):
            # in order to trigger pre-op hook in the forward of checkpoint module
            # we have to capture the `backward` function
            # and make sure that it is not run in the `torch._C.DisableTorchFunction()` context
            if func is torch.Tensor.backward:
                assert len(args) == 1    # only has 1 parameter
                backward_tensor = torch.Tensor(args[0])
                tensor_kwargs = {k: torch.Tensor(v) if torch.is_tensor(v) else v for k, v in kwargs.items()}
                return backward_tensor.backward(**tensor_kwargs)

        with torch._C.DisableTorchFunction():
            ret = func(*args, **kwargs)
            if func in _get_my_nowrap_functions():
                return ret
            else:
                colo_spec = _get_spec_from_args(args, kwargs)
                return _convert_output(ret, colo_spec)

    def __repr__(self):
        return f'ColoTensor:\n{super().__repr__()}\n{self.dist_spec}\n{self.process_group}\n{self.compute_spec}'

    def _redistribute(self, dist_spec: _DistSpec) -> None:
        """_redistribute

        Note the function will not handle the logic of backward propagation!
        It is used during model tensor initializations as an internal function.

        Args:
            dist_spec (_DistSpec): the target dist spec.
        """
        assert self.grad_fn is None, "Current tensor has grad_fn and it can't get converted"
        with DistSpecManager.no_grad():
            self.data = DistSpecManager.handle_trans_spec(self.data, self.dist_spec, dist_spec, self.process_group)
        self.dist_spec = dist_spec

    def redistribute(self, dist_spec: _DistSpec, pg: Optional[ProcessGroup] = None) -> 'ColoTensor':
        """redistribute

        Redistribute the tensor among processes. The rule is like this:

        1. If the pg is None (or equal to the current process group), redistribute the
           tensor payload within the current process group.

        2. If the pg is not None and not equal to the current process group:
           first, convert the tensor to replicated within the current process group;
           second, convert the replicated tensor to the new dist_spec using the new pg.

        Args:
            dist_spec (_DistSpec): the new dist spec.
            pg (Optional[ProcessGroup], optional): the new process group. Defaults to None.

        Returns:
            ColoTensor: a redistributed ColoTensor
        """
        if pg is not None and pg != self.get_process_group():
            # if the pg is not equal, convert the current tensor to replicated
            handled = self.redistribute(ReplicaSpec())
        else:
            handled = self
            pg = self.process_group

        ret = DistSpecManager.handle_trans_spec(handled, handled.dist_spec, dist_spec, pg)
        return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(pg=pg, dist_attr=dist_spec))

    def to_replicate_(self):
        """to_replicate_

        An in-place member function, converting the dist spec of the tensor to ReplicaSpec().
        """
        self._redistribute(dist_spec=ReplicaSpec())

    def to_replicate(self) -> 'ColoTensor':
        """to_replicate

        Return a new ColoTensor whose dist spec is converted to ReplicaSpec().
        """
        return self.redistribute(ReplicaSpec())

    @staticmethod
    def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor':
        """from_torch_tensor

        A static method that builds a `ColoTensor` from a PyTorch tensor.

        Args:
            tensor (torch.Tensor): the pytorch tensor, which is a local tensor for this rank not a global tensor.
            spec (Optional[ColoTensorSpec], optional): tensor spec. Defaults to None.

        Returns:
            ColoTensor: a ColoTensor
        """
        tensor = tensor.as_subclass(ColoTensor)
        # as_subclass does not run __init__, so the distributed attributes are set explicitly
        tensor.__init__(tensor, spec=spec)
        return tensor

    def __deepcopy__(self, memo):
        """Return a copy backed by a clone of the payload; the copy is cached in ``memo``."""
        if id(self) in memo:
            return memo[id(self)]
        else:
            with torch._C.DisableTorchFunction():
                data = self.data.clone()
            tensor = ColoTensor(data, spec=copy(ColoTensorSpec(self.process_group, self.dist_spec, self.compute_spec)))
            memo[id(self)] = tensor
            return tensor

    # override builtin functions which must use tensor in replicate placement #

    def size_local(self, *args) -> torch.Size:
        """Return the size of the local payload, bypassing torch-function dispatch."""
        with torch._C.DisableTorchFunction():
            return super().size(*args)

    def size_global(self, *args) -> torch.Size:
        """size_global

        Counterpart of the torch builtin size() for the global tensor: every sharded
        dimension of the local size is multiplied by its number of partitions.

        Args:
            *args: optionally a single dimension index, as accepted by ``size()``.

        Returns:
            torch.Size: the global tensor shape, or the global extent of the requested
            dimension when a dimension index is given.
        """
        if self.is_replicate():
            return self.size_local(*args)
        spec = self.dist_spec
        dims = spec.dims
        num_partitions = spec.num_partitions
        size_list = list(self.size_local())
        for dim, num_partition in zip(dims, num_partitions):
            size_list[dim] *= num_partition
        if not args:
            return torch.Size(size_list)
        else:
            return size_list[args[0]]

    def numel_global(self):
        """Returns the number of elements in the tensor when it's replicated.
        """
        return math.prod(self.size_global())

    # Some API for dist spec check

    def is_replicate(self):
        """Return True when the tensor is effectively replicated.

        That is: the placement is REPLICATE, or there is a single partition spec with
        one partition, or the TP world size of the process group is 1.
        """
        return self.dist_spec.placement == DistPlacementPattern.REPLICATE \
            or (len(self.dist_spec.num_partitions) == 1
                and self.dist_spec.num_partitions[0] == 1) \
            or (self.process_group.tp_world_size() == 1)

    def is_shard_1dcol(self):
        """Return True when the tensor is sharded along exactly one dimension, dim -1."""
        return self.dist_spec.placement == DistPlacementPattern.SHARD \
            and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1

    def is_shard_1drow(self):
        """Return True when the tensor is sharded along exactly one dimension, dim 0."""
        return self.dist_spec.placement == DistPlacementPattern.SHARD \
            and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0

    def is_sharded(self):
        """Return True when the placement of the dist spec is SHARD."""
        return self.dist_spec.placement == DistPlacementPattern.SHARD
|