You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/colossalai/tensor/d_tensor/api.py

541 lines
18 KiB

import copy
import operator
from functools import reduce
from typing import Union
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor.sharding_spec import DimSpec
from .layout import Layout
from .layout_converter import LayoutConverter
from .sharding_spec import ShardingSpec
# Module-wide converter used by `_apply_layout` and `redistribute`; its
# `cached_solution` cache is emptied by `clear_layout_converter()`.
layout_converter = LayoutConverter()
# Sharding-sequence entry that `get_shard_dim_1d` searches for; presumably it
# denotes a dimension split over device-mesh axis 0 (TODO confirm against DimSpec).
_SHARD_DIM = DimSpec([0])
def get_shard_dim_1d(p: torch.Tensor):
    """
    Locate the tensor dimension that is sharded, e.g. in 1D tensor parallelism.

    The dimension is the position of the first entry equal to ``_SHARD_DIM`` in the
    tensor's ``dist_layout.sharding_spec.sharding_sequence``.

    Args:
        p (torch.Tensor): a distributed tensor (one that carries ``dist_layout``).

    Returns:
        int: index of the sharded dimension.

    Raises:
        ValueError: if ``p`` is not a distributed tensor. The ``.index`` lookup
            presumably also raises ValueError when no dimension matches ``_SHARD_DIM``.
    """
    if is_distributed_tensor(p):
        sequence = p.dist_layout.sharding_spec.sharding_sequence
        return sequence.index(_SHARD_DIM)
    raise ValueError("p is not a distributed tensor")
def clear_layout_converter():
    """Empty the conversion cache (``cached_solution``) of the module-level ``layout_converter``."""
    # The module global is only mutated here, never rebound, so no `global`
    # declaration is needed.
    solution_cache = layout_converter.cached_solution
    solution_cache.clear()
def is_distributed_tensor(tensor: torch.Tensor) -> bool:
    """
    Tell whether ``tensor`` is a distributed tensor.

    A tensor counts as distributed purely because it has a ``dist_layout``
    attribute; the attribute's value is not inspected.

    Args:
        tensor (torch.Tensor): The tensor to be checked.

    Returns:
        bool: True if ``tensor`` has a ``dist_layout`` attribute, otherwise False.
    """
    absent = object()
    return getattr(tensor, "dist_layout", absent) is not absent
def is_sharded(dtensor: torch.Tensor) -> bool:
    """
    Check if a distributed tensor is sharded.

    A distributed tensor is treated as sharded when its local shape differs from
    the global shape recorded in its ``dist_layout``.

    Args:
        dtensor (torch.Tensor): The distributed tensor to be checked.

    Returns:
        bool: True if the tensor is sharded, False otherwise.
    """
    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
    # Equal local and global shapes mean this rank holds the whole tensor, i.e. it
    # is NOT sharded. The previous `==` comparison returned the inverse.
    return list(dtensor.shape) != list(dtensor.dist_layout.global_shape)
def _hijack_detach_and_clone(dtensor: torch.Tensor) -> torch.Tensor:
"""
Hijack the detach and clone methods of the tensor to make sure the dist_layout is copied.
Args:
tensor (torch.Tensor): The tensor to be hijacked.
Returns:
torch.Tensor: The hijacked tensor.
"""
dtensor._old_detach = dtensor.detach
dtensor._old_clone = dtensor.clone
def new_detach(self):
t_ = self._old_detach()
t_.dist_layout = copy.deepcopy(self.dist_layout)
return t_
def new_clone(self, *args, **kwargs):
t_ = self._old_clone(*args, **kwargs)
t_.dist_layout = copy.deepcopy(self.dist_layout)
return t_
# bind the new methods to the tensor
dtensor.detach = new_detach.__get__(dtensor)
dtensor.clone = new_clone.__get__(dtensor)
return dtensor
def _construct_default_sharding_spec(
    tensor: torch.Tensor,
) -> ShardingSpec:
    """
    Build a sharding spec that partitions none of ``tensor``'s dimensions.

    Args:
        tensor (`torch.Tensor`): the tensor whose rank sizes the spec.

    Returns:
        A `ShardingSpec` with ``dim_size == tensor.dim()`` and an empty partition dict.
    """
    rank = tensor.dim()
    no_partitions = {}
    return ShardingSpec(dim_size=rank, dim_partition_dict=no_partitions)
def _apply_layout(tensor, layout):
    """
    Convert a full (unsharded) local tensor into ``layout`` during initialization.

    The layout converter needs both a source and a target layout. The source is
    described as an unsharded layout on the same device mesh, with the tensor's
    current shape as its global shape.
    """
    unsharded_spec = _construct_default_sharding_spec(tensor)
    unsharded_layout = Layout(
        device_mesh=layout.device_mesh,
        sharding_spec=unsharded_spec,
        global_shape=tensor.shape,
    )
    return layout_converter.apply(tensor=tensor, source_layout=unsharded_layout, target_layout=layout)
def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor:
    """
    Shard a regular tensor over ``device_mesh`` according to ``sharding_spec``.

    The input's full shape becomes the global shape of the target layout.

    Args:
        tensor (torch.Tensor): The (non-distributed) tensor to be converted.
        device_mesh (DeviceMesh): The device mesh for abstraction of the compute devices.
        sharding_spec (ShardingSpec): The sharding specification which describes how the tensor will be sharded.

    Returns:
        torch.Tensor: The tensor returned by the layout converter, with ``detach``/``clone``
        patched to copy its ``dist_layout``. That attribute is presumably set by
        ``LayoutConverter.apply`` (not visible here).
    """
    assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."
    target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=tensor.shape)
    # _hijack_detach_and_clone returns the very tensor it patches.
    return _hijack_detach_and_clone(_apply_layout(tensor, target_layout))
def init_as_dtensor(
    tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size
) -> torch.Tensor:
    """
    Mark a local tensor as a distributed tensor without moving any data.

    Unlike ``distribute_tensor``, no sharding is performed: only the layout metadata
    is attached. ``tensor`` is presumably already the local shard, with
    ``global_shape`` describing the full tensor.

    Args:
        tensor (torch.Tensor): The (non-distributed) local tensor.
        device_mesh (DeviceMesh): The device mesh for abstraction of the compute devices.
        sharding_spec (ShardingSpec): How the global tensor is sharded.
        global_shape (torch.Size): Shape of the full, unsharded tensor.

    Returns:
        torch.Tensor: The same tensor object, with ``dist_layout`` set and ``detach``/``clone`` patched.
    """
    assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."
    tensor.dist_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)
    return _hijack_detach_and_clone(tensor)
def redistribute(dtensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor:
    """
    Convert a distributed tensor from its current layout to the layout described by
    ``device_mesh`` and ``sharding_spec``.

    The global shape is kept; only the sharding changes.

    Args:
        dtensor (torch.Tensor): the distributed tensor to be converted.
        device_mesh (DeviceMesh): the device mesh for abstraction of the compute devices.
        sharding_spec (ShardingSpec): the target sharding specification.

    Returns:
        torch.Tensor: the resharded tensor produced by the layout converter. Callers
        should use this return value. Whether ``dtensor`` itself is modified depends on
        ``LayoutConverter.apply`` (not visible here).
    """
    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
    global_shape = get_global_shape(dtensor)
    target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)
    resharded_tensor = layout_converter.apply(
        tensor=dtensor, source_layout=dtensor.dist_layout, target_layout=target_layout
    )
    return resharded_tensor
def to_global(dtensor: torch.Tensor) -> torch.Tensor:
    """
    Gather a distributed tensor into the full global tensor.

    The target layout keeps the tensor's device mesh and global shape but partitions
    no dimension (an empty partition dict).

    Args:
        dtensor (torch.Tensor): the distributed tensor to be converted.

    Returns:
        torch.Tensor: the global tensor.
    """
    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
    # Use the module-level converter, like `_apply_layout` and `redistribute` do. A fresh
    # local LayoutConverter() would shadow it, bypass its `cached_solution` cache, and
    # escape `clear_layout_converter()`.
    global_sharding_spec = ShardingSpec(dtensor.dim(), {})
    device_mesh = get_device_mesh(dtensor)
    global_shape = get_global_shape(dtensor)
    global_layout = Layout(device_mesh=device_mesh, sharding_spec=global_sharding_spec, global_shape=global_shape)
    return layout_converter.apply(dtensor, dtensor.dist_layout, global_layout)
def shard_rowwise(
    tensor: torch.Tensor,
    group_or_device_mesh: Union[ProcessGroup, DeviceMesh, None] = None,
) -> torch.Tensor:
    """
    Shard the first dim of the given tensor.

    Args:
        tensor (torch.Tensor): The tensor to be sharded.
        group_or_device_mesh (Union[ProcessGroup, DeviceMesh], optional): The process group or
            1D device mesh to shard the tensor over. If None, the tensor will be sharded with
            respect to the global process group. Defaults to None.

    Returns:
        torch.Tensor: The sharded tensor.
    """
    # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
    if group_or_device_mesh is None:
        group_or_device_mesh = dist.GroupMember.WORLD

    if isinstance(group_or_device_mesh, ProcessGroup):
        device_mesh = DeviceMesh.from_process_group(group_or_device_mesh)
    else:
        assert len(group_or_device_mesh.shape) == 1, "Only 1D DeviceMesh is accepted for row-wise sharding."
        device_mesh = group_or_device_mesh

    # Tensor dim 0 is partitioned over device-mesh axis 0.
    sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={0: [0]})
    return distribute_tensor(tensor, device_mesh, sharding_spec)
def shard_colwise(
    tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh, None] = None
) -> torch.Tensor:
    """
    Shard the last dim of the given tensor.

    Args:
        tensor (torch.Tensor): The tensor to be sharded.
        group_or_device_mesh (Union[ProcessGroup, DeviceMesh], optional): The process group or
            1D device mesh to shard the tensor over. If None, the tensor will be sharded with
            respect to the global process group. Defaults to None.

    Returns:
        torch.Tensor: The sharded tensor.
    """
    # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
    if group_or_device_mesh is None:
        group_or_device_mesh = dist.GroupMember.WORLD

    if isinstance(group_or_device_mesh, ProcessGroup):
        device_mesh = DeviceMesh.from_process_group(group_or_device_mesh)
    else:
        assert len(group_or_device_mesh.shape) == 1, "Only 1D DeviceMesh is accepted for column-wise sharding."
        device_mesh = group_or_device_mesh

    # The last tensor dim (-1) is partitioned over device-mesh axis 0.
    sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={-1: [0]})
    return distribute_tensor(tensor, device_mesh, sharding_spec)
def sharded_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True):
    """
    Wrap a distributed tensor in a new ``torch.nn.Parameter`` that is distributed too.

    Args:
        dtensor (torch.Tensor): the distributed tensor to wrap.
        requires_grad (bool): forwarded to ``torch.nn.Parameter``. Defaults to True.

    Returns:
        torch.nn.Parameter: a parameter that references the same ``dist_layout`` object
        as ``dtensor`` (not a copy), with ``detach``/``clone`` patched.
    """
    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
    wrapped = torch.nn.Parameter(dtensor, requires_grad=requires_grad)
    # Attach the layout so the parameter is recognised as a distributed tensor as well.
    wrapped.dist_layout = dtensor.dist_layout
    return _hijack_detach_and_clone(wrapped)
def sharded_tensor_to_existing_param(dtensor: torch.Tensor, param: torch.nn.Parameter) -> None:
    """
    Load a distributed tensor into an existing parameter and make the parameter distributed.

    Args:
        dtensor (torch.Tensor): the distributed tensor whose data and layout are adopted.
        param (torch.nn.Parameter): the parameter updated in place (``data``, ``dist_layout``,
            ``detach``/``clone``).
    """
    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
    source_layout = dtensor.dist_layout
    param.data = dtensor
    param.dist_layout = source_layout
    _hijack_detach_and_clone(param)
def compute_global_numel(dtensor: torch.Tensor) -> int:
    """
    Compute the global number of elements in the distributed tensor.

    Args:
        dtensor (torch.Tensor): The distributed tensor.

    Returns:
        int: The product of ``dtensor.dist_layout.global_shape``. An empty (0-dim) global
        shape gives 1, matching ``torch.Size([]).numel()``.
    """
    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
    # The initializer 1 makes a scalar's empty shape yield 1; without it, reduce()
    # raises TypeError on an empty sequence.
    return reduce(operator.mul, dtensor.dist_layout.global_shape, 1)
def get_layout(dtensor: torch.Tensor) -> Layout:
    """
    Return the layout attached to a distributed tensor.

    Args:
        dtensor (torch.Tensor): a tensor carrying ``dist_layout``.

    Returns:
        Layout: ``dtensor.dist_layout`` itself (not a copy).
    """
    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
    layout = dtensor.dist_layout
    return layout
def get_global_shape(dtensor: torch.Tensor) -> torch.Size:
    """
    Return the full (unsharded) shape recorded in a distributed tensor's layout.

    Args:
        dtensor (torch.Tensor): a tensor carrying ``dist_layout``.

    Returns:
        torch.Size: ``dtensor.dist_layout.global_shape``.
    """
    # get_layout performs the same distributed-tensor assertion.
    return get_layout(dtensor).global_shape
def get_device_mesh(dtensor: torch.Tensor) -> DeviceMesh:
    """
    Return the device mesh a distributed tensor is laid out on.

    Args:
        dtensor (torch.Tensor): a tensor carrying ``dist_layout``.

    Returns:
        DeviceMesh: ``dtensor.dist_layout.device_mesh``.
    """
    # get_layout performs the same distributed-tensor assertion.
    return get_layout(dtensor).device_mesh
def get_sharding_spec(dtensor: torch.Tensor) -> ShardingSpec:
    """
    Return the sharding specification of a distributed tensor.

    Args:
        dtensor (torch.Tensor): a tensor carrying ``dist_layout``.

    Returns:
        ShardingSpec: ``dtensor.dist_layout.sharding_spec``.
    """
    # get_layout performs the same distributed-tensor assertion.
    return get_layout(dtensor).sharding_spec
# ======================================================
# Some sharding schemes do not follow the SPMD style,
# e.g. the fused QKV layer in GPT-2.
# The following APIs support such customized sharding.
# ======================================================
def is_customized_distributed_tensor(tensor: torch.Tensor):
    """
    Tell whether ``tensor`` is a customized distributed tensor.

    Such tensors are recognised by carrying both a ``shard_fn`` and a ``gather_fn``
    attribute; the attribute values are not inspected.

    Args:
        tensor (torch.Tensor): The tensor to be checked.

    Returns:
        bool: True only if both attributes are present.
    """
    required_attrs = ("shard_fn", "gather_fn")
    return all(hasattr(tensor, name) for name in required_attrs)
def _hijack_detach_and_clone_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.Tensor:
"""
Hijack the detach and clone methods of the tensor to make sure the dist_layout is copied.
Args:
tensor (torch.Tensor): The tensor to be hijacked.
Returns:
torch.Tensor: The hijacked tensor.
"""
dtensor._old_detach = dtensor.detach
dtensor._old_clone = dtensor.clone
def new_detach(self):
t_ = self._old_detach()
t_.shard_fn = self.shard_fn
t_.gather_fn = self.gather_fn
return t_
def new_clone(self, *args, **kwargs):
t_ = self._old_clone(*args, **kwargs)
t_.shard_fn = self.shard_fn
t_.gather_fn = self.gather_fn
return t_
# bind the new methods to the tensor
dtensor.detach = new_detach.__get__(dtensor)
dtensor.clone = new_clone.__get__(dtensor)
return dtensor
def distribute_tensor_with_customization(tensor: torch.Tensor, shard_fn, gather_fn: callable):
    """
    Shard ``tensor`` with a user-supplied ``shard_fn`` and remember how to gather it back.

    Example:

    ```python
    # define shard and gather functions
    def shard_fn(tensor):
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
        return tensor.chunk(world_size, dim=0)[rank]

    def gather_fn(tensor):
        world_size = torch.distributed.get_world_size()
        shard_list = [torch.zeros_like(tensor) for _ in range(world_size)]
        torch.distributed.all_gather(shard_list, tensor)
        return torch.cat(shard_list, dim=0)

    # create a distributed tensor
    tensor = torch.rand(4, 4)
    dtensor = distribute_tensor_with_customization(tensor, shard_fn, gather_fn)
    ```

    Args:
        tensor (torch.Tensor): The tensor to be distributed.
        shard_fn (callable): Called once on ``tensor``; its result is the returned tensor.
        gather_fn (callable): The function to gather the tensor.

    Returns:
        torch.Tensor: ``shard_fn(tensor)``, carrying ``shard_fn``/``gather_fn`` attributes and
        with ``detach``/``clone`` patched to preserve them.
    """
    assert callable(shard_fn), "The shard_fn must be callable."
    assert callable(gather_fn), "The gather_fn must be callable."
    assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."

    local_shard = shard_fn(tensor)
    # Record both functions on the shard so it can later be gathered (or re-sharded).
    local_shard.shard_fn, local_shard.gather_fn = shard_fn, gather_fn
    # Patch detach/clone so copies keep both attributes; the helper returns the same tensor.
    return _hijack_detach_and_clone_for_customized_distributed_tensor(local_shard)
def init_tensor_as_customization_distributed(tensor: torch.Tensor, shard_fn, gather_fn: callable):
    """
    Mark ``tensor`` as a customized distributed tensor without sharding it.

    Unlike ``distribute_tensor_with_customization``, ``shard_fn`` is NOT called here. The
    functions are only attached, so ``tensor`` is presumably already the local shard.

    Example:

    ```python
    # define shard and gather functions
    def shard_fn(tensor):
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
        return tensor.chunk(world_size, dim=0)[rank]

    def gather_fn(tensor):
        world_size = torch.distributed.get_world_size()
        shard_list = [torch.zeros_like(tensor) for _ in range(world_size)]
        torch.distributed.all_gather(shard_list, tensor)
        return torch.cat(shard_list, dim=0)

    # mark an existing local shard as a distributed tensor
    local_shard = shard_fn(torch.rand(4, 4))
    dtensor = init_tensor_as_customization_distributed(local_shard, shard_fn, gather_fn)
    ```

    Args:
        tensor (torch.Tensor): The tensor to be marked; modified in place.
        shard_fn (callable): The function to shard the tensor.
        gather_fn (callable): The function to gather the tensor.

    Returns:
        torch.Tensor: The same tensor object, carrying ``shard_fn``/``gather_fn`` and with
        ``detach``/``clone`` patched to preserve them.
    """
    assert callable(shard_fn), "The shard_fn must be callable."
    assert callable(gather_fn), "The gather_fn must be callable."
    assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."

    tensor.shard_fn, tensor.gather_fn = shard_fn, gather_fn
    # The helper patches detach/clone and returns the same tensor.
    return _hijack_detach_and_clone_for_customized_distributed_tensor(tensor)
def to_global_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.Tensor:
    """
    Reassemble the global tensor by calling the tensor's own ``gather_fn`` on it.

    Args:
        dtensor (torch.Tensor): a customized distributed tensor.

    Returns:
        torch.Tensor: whatever ``dtensor.gather_fn(dtensor)`` returns.
    """
    assert is_customized_distributed_tensor(dtensor), "The input tensor is not a customized distributed tensor."
    gather = dtensor.gather_fn
    return gather(dtensor)
def customized_distributed_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True):
    """
    Wrap a customized distributed tensor in a new ``torch.nn.Parameter``.

    The parameter receives the same ``shard_fn``/``gather_fn`` callables, and its
    ``detach``/``clone`` are patched to preserve them.

    Args:
        dtensor (torch.Tensor): the customized distributed tensor to wrap.
        requires_grad (bool): forwarded to ``torch.nn.Parameter``. Defaults to True.

    Returns:
        torch.nn.Parameter: the new parameter.
    """
    assert is_customized_distributed_tensor(dtensor), "The input tensor is not a customized distributed tensor."
    wrapped = torch.nn.Parameter(dtensor, requires_grad=requires_grad)
    wrapped.shard_fn, wrapped.gather_fn = dtensor.shard_fn, dtensor.gather_fn
    return _hijack_detach_and_clone_for_customized_distributed_tensor(wrapped)
def customized_distributed_tensor_to_existing_param(dtensor: torch.Tensor, param: torch.nn.Parameter):
    """
    Load a customized distributed tensor into an existing parameter.

    ``param`` is updated in place: its ``data`` becomes ``dtensor.data``, it receives the
    same ``shard_fn``/``gather_fn`` callables, and its ``detach``/``clone`` are patched.

    Args:
        dtensor (torch.Tensor): the customized distributed tensor to adopt.
        param (torch.nn.Parameter): the parameter to update.
    """
    assert is_customized_distributed_tensor(dtensor), "The input tensor is not a customized distributed tensor."
    param.data = dtensor.data
    param.shard_fn, param.gather_fn = dtensor.shard_fn, dtensor.gather_fn
    _hijack_detach_and_clone_for_customized_distributed_tensor(param)