You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/colossalai/tensor/d_tensor/d_tensor.py

143 lines
4.8 KiB

from typing import Optional
import torch
from torch.utils._pytree import tree_map
from .layout import Layout
from .layout_converter import LayoutConverter, to_global
from .sharding_spec import ShardingSpec
# Module-level converter instance; DTensor calls its ``apply`` method whenever a local
# tensor has to be moved from one layout to another.
layout_converter = LayoutConverter()
class DTensor(torch.Tensor):
    '''
    A ``torch.Tensor`` subclass that keeps the tensor held on this rank together with
    the ``Layout`` describing how it is distributed.
    '''

    def __init__(self, local_tensor: torch.Tensor, dist_layout: Layout):
        # Remember the requested layout and the metadata of the incoming tensor,
        # then convert the tensor into that layout.
        self.dist_layout = dist_layout
        self.local_tensor = local_tensor
        self.data_type = local_tensor.dtype
        self.entire_shape = local_tensor.shape
        self._apply_layout()

    @staticmethod
    def __new__(cls, local_tensor, layout):
        # The subclass instance is built on top of ``local_tensor`` and inherits its
        # requires_grad flag; ``layout`` is only consumed later by ``__init__``.
        needs_grad = local_tensor.requires_grad
        return torch.Tensor._make_subclass(cls, local_tensor, needs_grad)

    def __repr__(self):
        # Note: rendering the tensor calls ``to_global()``.
        return f"DTensor({self.to_global()}, {self.dist_layout})"

    def __str__(self):
        return repr(self)

    def layout_convert(self, target_layout):
        '''
        Convert the local tensor from the current layout to ``target_layout`` and
        record ``target_layout`` as the new layout of this tensor.
        '''
        current_layout = self.dist_layout
        self.local_tensor = layout_converter.apply(self.local_tensor, current_layout, target_layout)
        self.dist_layout = target_layout

    def _apply_layout(self):
        '''
        Bring the local tensor into ``self.dist_layout`` while the object is being initialized.

        The starting layout reuses the device mesh and device type of the target layout
        and carries the default sharding spec of the local tensor.
        '''
        target_layout = self.dist_layout
        default_spec = construct_default_sharding_spec(self.local_tensor)
        initial_layout = Layout(
            device_mesh=target_layout.device_mesh,
            device_type=target_layout.device_type,
            sharding_spec=default_spec,
            entire_shape=self.entire_shape,
        )
        self.local_tensor = layout_converter.apply(self.local_tensor, initial_layout, target_layout)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        '''
        Run ``func`` on plain tensors: every DTensor found in ``args`` / ``kwargs`` is
        replaced by its local tensor first. The result is returned exactly as ``func``
        produced it and is not wrapped back into a DTensor.
        '''
        kwargs = {} if kwargs is None else kwargs

        def unwrap(value):
            return value.local_tensor if isinstance(value, DTensor) else value

        plain_args = tree_map(unwrap, args)
        plain_kwargs = tree_map(unwrap, kwargs)
        # if we want to convert the result into DTensor, we need to infer the layout of result from the layout of input tensors
        # and op type.
        return func(*plain_args, **plain_kwargs)

    @property
    def device_mesh(self):
        '''
        The device mesh stored in the layout of this tensor.
        '''
        layout = self.dist_layout
        return layout.device_mesh

    @property
    def sharding_spec(self):
        '''
        The sharding specification stored in the layout of this tensor.
        '''
        layout = self.dist_layout
        return layout.sharding_spec

    def to(self, *args, **kwargs):
        '''
        Move the local tensor to a new device and/or convert it to a new dtype.

        The arguments are forwarded to ``torch.Tensor.to`` on the local tensor. The
        recorded dtype and the device type of the layout are refreshed afterwards,
        and this same object is returned.
        '''
        moved = self.local_tensor.to(*args, **kwargs)
        self.local_tensor = moved
        self.data_type = moved.dtype
        self.dist_layout.device_type = moved.device

        # TODO: update the device mesh process groups or we should just cache
        # both the cpu process groups and the cuda process groups?
        return self

    def to_local(self):
        '''
        Return the tensor held on this rank.
        '''
        return self.local_tensor

    def to_global(self):
        '''
        Recover the global tensor from the distributed tensor.

        Note: This function will all_gather the local tensor to the global tensor and it
        will not change the layout of the DTensor. This function is mainly used for debugging or
        check the correctness of the distributed tensor.
        '''
        local, layout = self.local_tensor, self.dist_layout
        return to_global(local, layout)
def distribute_tensor(local_tensor: torch.Tensor, dist_layout: Layout) -> DTensor:
    '''
    Wrap a tensor into a ``DTensor`` that follows the given layout.

    Args:
        local_tensor: tensor to be distributed.
        dist_layout: the layout specification of the distributed tensor.

    Returns:
        A 'DTensor' object.
    '''
    dtensor = DTensor(local_tensor, dist_layout)
    return dtensor
def distribute_module(module: torch.nn.Module, partition_fn: Optional[callable] = None) -> torch.nn.Module:
    '''
    Replace the parameters of ``module`` with the tensors produced by ``partition_fn``.

    Note: This function is subject to future change as the DParam has not been implemented yet.

    Args:
        module: the module whose parameters are converted; it is modified in place.
        partition_fn: callable invoked as ``partition_fn(name, param.data)``, where ``name``
            is the qualified parameter name yielded by ``module.named_parameters()``.
            When it is None the module is returned untouched.

    Returns:
        The same ``module`` object.
    '''
    if partition_fn is None:
        # Nothing describes how to partition the parameters, so leave the module as it is
        # instead of failing on a call to None.
        return module

    # Take a snapshot first: the loop re-registers parameters on the modules being walked.
    for name, param in list(module.named_parameters()):
        if param is None or isinstance(param, DTensor):
            continue

        # named_parameters() yields dotted names for nested modules (e.g. 'block.weight').
        # A parameter name may not contain '.', so resolve the submodule that owns the
        # parameter and set it there under its short name.
        owner_path, _, param_name = name.rpartition('.')
        owner = module
        if owner_path:
            for child_name in owner_path.split('.'):
                owner = getattr(owner, child_name)

        # TODO: we could convert the parameter to DParam here,
        # the type of the parameter could be an optional argument.
        # Keep the original requires_grad flag so frozen parameters stay frozen.
        new_param = torch.nn.Parameter(partition_fn(name, param.data), requires_grad=param.requires_grad)
        setattr(owner, param_name, new_param)
    return module
def construct_default_sharding_spec(tensor: torch.Tensor) -> ShardingSpec:
    '''
    Build the default sharding specification of ``tensor``: a ``ShardingSpec`` sized to
    the number of dimensions of the tensor, with an empty partition dict.
    '''
    num_dims = tensor.dim()
    no_partition = dict()
    return ShardingSpec(dim_size=num_dims, dim_partition_dict=no_partition)