ColossalAI/colossalai/tensor/d_tensor/api.py

460 lines
16 KiB
Python
Raw Normal View History

import copy
import operator
from functools import reduce
from typing import Callable, Union

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from colossalai.device.device_mesh import DeviceMesh

from .layout import Layout
from .layout_converter import LayoutConverter
from .sharding_spec import ShardingSpec
# Module-level converter shared by the helpers in this file. It exposes a
# `cached_solution` container that `clear_layout_converter` empties.
layout_converter = LayoutConverter()


def clear_layout_converter():
    """
    Empty the ``cached_solution`` container of the module-level ``layout_converter``.
    """
    # No `global` statement is needed: the module-level name is only read, never rebound.
    solution_cache = layout_converter.cached_solution
    solution_cache.clear()
def is_distributed_tensor(tensor: torch.Tensor) -> bool:
    """
    Check whether the given tensor is a distributed tensor.

    A tensor counts as distributed when it carries a ``dist_layout`` attribute.

    Args:
        tensor (torch.Tensor): The tensor to be checked.

    Returns:
        bool: Whether the given tensor is a distributed tensor.
    """
    return hasattr(tensor, "dist_layout")


def is_sharded(dtensor: torch.Tensor) -> bool:
    """
    Check if a distributed tensor is sharded.

    This is a shape-based check: the tensor is reported as sharded when its
    local shape differs from the global shape recorded in its ``dist_layout``.

    Args:
        dtensor (torch.Tensor): The distributed tensor to be checked.

    Returns:
        bool: True if the tensor is sharded, False otherwise.
    """
    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
    # A local shard that matches the global shape holds the full tensor, so it is
    # NOT sharded. The previous `==` comparison returned the inverted answer.
    return list(dtensor.shape) != list(dtensor.dist_layout.global_shape)
def _hijack_detach_and_clone(dtensor: torch.Tensor) -> torch.Tensor:
    """
    Hijack the detach and clone methods of the tensor to make sure the dist_layout is copied.

    The tensor's ``detach``/``clone`` are replaced (on this instance only) by wrappers
    that call the original methods and attach a deep copy of ``self.dist_layout`` to
    the result. Calling this function more than once on the same tensor is safe.

    Args:
        dtensor (torch.Tensor): The tensor to be hijacked. It must have a ``dist_layout``
            attribute by the time ``detach``/``clone`` is called.

    Returns:
        torch.Tensor: The same tensor object, with ``detach``/``clone`` rebound.
    """
    # Save the original methods only once. If the tensor was already hijacked,
    # `dtensor.detach` is a wrapper that looks up `self._old_detach`. Storing that
    # wrapper as `_old_detach` would make it call itself forever (RecursionError).
    if not hasattr(dtensor, '_old_detach'):
        dtensor._old_detach = dtensor.detach
    if not hasattr(dtensor, '_old_clone'):
        dtensor._old_clone = dtensor.clone

    def new_detach(self):
        t_ = self._old_detach()
        t_.dist_layout = copy.deepcopy(self.dist_layout)
        return t_

    def new_clone(self, *args, **kwargs):
        t_ = self._old_clone(*args, **kwargs)
        t_.dist_layout = copy.deepcopy(self.dist_layout)
        return t_

    # bind the new methods to this tensor instance (instance attributes shadow the class methods)
    dtensor.detach = new_detach.__get__(dtensor)
    dtensor.clone = new_clone.__get__(dtensor)
    return dtensor
def _construct_default_sharding_spec(tensor: torch.Tensor) -> ShardingSpec:
    """
    Build a sharding spec that leaves every dimension of ``tensor`` unsharded.

    Args:
        tensor (torch.Tensor): the tensor whose rank (``tensor.dim()``) sets the spec's ``dim_size``.

    Returns:
        ShardingSpec: a spec with an empty ``dim_partition_dict``, i.e. no dimension is partitioned.
    """
    rank = tensor.dim()
    return ShardingSpec(dim_size=rank, dim_partition_dict={})
def _apply_layout(tensor, layout):
    """
    Apply ``layout`` to a full, unsharded local tensor during initialization.

    The layout converter needs both a source and a target layout. The source layout
    here describes the unsharded tensor on the same device mesh as ``layout``, and
    ``layout`` itself is the target.

    Args:
        tensor (torch.Tensor): the unsharded tensor to convert.
        layout (Layout): the target layout.

    Returns:
        The tensor returned by ``layout_converter.apply`` for this conversion.
    """
    unsharded_spec = _construct_default_sharding_spec(tensor)
    unsharded_layout = Layout(
        device_mesh=layout.device_mesh,
        sharding_spec=unsharded_spec,
        global_shape=tensor.shape,
    )
    return layout_converter.apply(tensor=tensor, source_layout=unsharded_layout, target_layout=layout)
def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor:
    """
    Shard a regular tensor according to ``sharding_spec`` over ``device_mesh``.

    Args:
        tensor (torch.Tensor): The full tensor to be converted. It must not already be distributed.
        device_mesh (DeviceMesh): The device mesh for abstraction of the compute devices.
        sharding_spec (ShardingSpec): The sharding specification which describes how the tensor will be sharded.

    Returns:
        torch.Tensor: The local shard, with ``detach``/``clone`` patched to carry its ``dist_layout``.
    """
    assert not is_distributed_tensor(tensor), 'The input tensor is already a distributed tensor.'
    # the global shape is taken from the input, which is the full tensor
    target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=tensor.shape)
    local_shard = _apply_layout(tensor, target_layout)
    # make copies produced by detach()/clone() keep the layout information
    _hijack_detach_and_clone(local_shard)
    return local_shard
def redistribute(dtensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor:
    '''
    Convert a distributed tensor to the layout given by ``device_mesh`` and ``sharding_spec``.

    The global shape is read from the tensor's current ``dist_layout``. The conversion
    from that layout to the target layout is delegated to the module-level
    ``layout_converter``. Callers should use the returned tensor.

    Args:
        dtensor (torch.Tensor): the distributed tensor to be converted.
        device_mesh (DeviceMesh): the device mesh for abstraction of the compute devices.
        sharding_spec (ShardingSpec): the target sharding specification.

    Returns:
        torch.Tensor: the tensor returned by the layout converter for the target layout.
    '''
    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
    global_shape = get_global_shape(dtensor)
    target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)
    resharded_tensor = layout_converter.apply(tensor=dtensor,
                                              source_layout=dtensor.dist_layout,
                                              target_layout=target_layout)
    return resharded_tensor
def to_global(dtensor: torch.Tensor) -> torch.Tensor:
    """
    Convert a distributed tensor to the full global tensor.

    The target layout keeps the tensor's device mesh and global shape but partitions
    no dimension. This function returns a native `torch.Tensor` object.

    Args:
        dtensor (torch.Tensor): the distributed tensor to be converted.

    Returns:
        torch.Tensor: the global tensor.
    """
    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
    # An empty partition dict means no dimension is sharded.
    global_sharding_spec = ShardingSpec(dtensor.dim(), {})
    device_mesh = get_device_mesh(dtensor)
    global_shape = get_global_shape(dtensor)
    global_layout = Layout(device_mesh=device_mesh, sharding_spec=global_sharding_spec, global_shape=global_shape)
    # Reuse the module-level converter (like `_apply_layout` and `redistribute`) rather
    # than shadowing it with a fresh local instance. This keeps conversions on the
    # converter whose cache `clear_layout_converter` resets.
    global_tensor = layout_converter.apply(dtensor, dtensor.dist_layout, global_layout)
    return global_tensor
def shard_rowwise(
    tensor: torch.Tensor,
    group_or_device_mesh: Union[ProcessGroup, DeviceMesh, None] = None,
) -> torch.Tensor:
    """
    Shard the first dim of the given tensor.

    Args:
        tensor (torch.Tensor): The tensor to be sharded.
        group_or_device_mesh (Union[ProcessGroup, DeviceMesh], optional): The group or device mesh to shard the tensor.
            A DeviceMesh must be 1D.
            If None, the tensor will be sharded with respect to the global process group.
            Defaults to None.

    Returns:
        torch.Tensor: The sharded tensor.
    """
    # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
    if group_or_device_mesh is None:
        group_or_device_mesh = dist.GroupMember.WORLD
        # NOTE(review): in recent PyTorch releases GroupMember.WORLD is None until the default
        # process group is initialized. Fail clearly instead of with an AttributeError below.
        assert group_or_device_mesh is not None, \
            'The default process group is not initialized; call torch.distributed.init_process_group ' \
            'or pass group_or_device_mesh explicitly.'
    if isinstance(group_or_device_mesh, ProcessGroup):
        device_mesh = DeviceMesh.from_process_group(group_or_device_mesh)
    else:
        assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.'
        device_mesh = group_or_device_mesh
    # partition tensor dim 0 along mesh axis 0
    sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={0: [0]})
    return distribute_tensor(tensor, device_mesh, sharding_spec)
def shard_colwise(tensor: torch.Tensor,
                  group_or_device_mesh: Union[ProcessGroup, DeviceMesh, None] = None) -> torch.Tensor:
    """
    Shard the last dim of the given tensor.

    Args:
        tensor (torch.Tensor): The tensor to be sharded.
        group_or_device_mesh (Union[ProcessGroup, DeviceMesh], optional): The group or device mesh to shard the tensor.
            A DeviceMesh must be 1D.
            If None, the tensor will be sharded with respect to the global process group.
            Defaults to None.

    Returns:
        torch.Tensor: The sharded tensor.
    """
    # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
    if group_or_device_mesh is None:
        group_or_device_mesh = dist.GroupMember.WORLD
        # NOTE(review): in recent PyTorch releases GroupMember.WORLD is None until the default
        # process group is initialized. Fail clearly instead of with an AttributeError below.
        assert group_or_device_mesh is not None, \
            'The default process group is not initialized; call torch.distributed.init_process_group ' \
            'or pass group_or_device_mesh explicitly.'
    if isinstance(group_or_device_mesh, ProcessGroup):
        device_mesh = DeviceMesh.from_process_group(group_or_device_mesh)
    else:
        assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for column-wise sharding.'
        device_mesh = group_or_device_mesh
    # the key -1 is a negative (from-the-end) dimension index, i.e. the last tensor dim,
    # partitioned along mesh axis 0
    sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={-1: [0]})
    return distribute_tensor(tensor, device_mesh, sharding_spec)
def sharded_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True):
    """
    Wrap a distributed tensor in a ``torch.nn.Parameter`` that keeps its layout.

    Args:
        dtensor (torch.Tensor): the distributed tensor to wrap.
        requires_grad (bool): forwarded to ``torch.nn.Parameter``. Defaults to True.

    Returns:
        torch.nn.Parameter: a parameter holding the same ``dist_layout`` object as ``dtensor``,
        with ``detach``/``clone`` patched to carry that layout to copies.
    """
    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
    layout = dtensor.dist_layout
    wrapped = torch.nn.Parameter(dtensor, requires_grad=requires_grad)
    # the new Parameter object does not carry dtensor's Python attributes, so re-attach them
    wrapped.dist_layout = layout
    _hijack_detach_and_clone(wrapped)
    return wrapped
def sharded_tensor_to_existing_param(dtensor: torch.Tensor, param: torch.nn.Parameter) -> None:
    """
    Load a distributed tensor into an existing parameter, in place.

    ``param.data`` is replaced by ``dtensor``, ``param`` takes over ``dtensor``'s
    ``dist_layout``, and its ``detach``/``clone`` are patched to carry that layout.

    Args:
        dtensor (torch.Tensor): the distributed tensor providing data and layout.
        param (torch.nn.Parameter): the parameter to update.
    """
    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
    layout = dtensor.dist_layout
    param.data = dtensor
    param.dist_layout = layout
    _hijack_detach_and_clone(param)
def compute_global_numel(dtensor: torch.Tensor) -> int:
    """
    Compute the global number of elements in the distributed tensor.

    This is the product of the global shape recorded in ``dist_layout``. A scalar
    (empty global shape) has 1 element.

    Args:
        dtensor (torch.Tensor): The distributed tensor.

    Returns:
        int: The global number of elements in the distributed tensor.
    """
    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
    # The initializer 1 makes an empty global shape (0-dim tensor) yield 1.
    # Without it, reduce() raises TypeError on an empty sequence.
    numel = reduce(operator.mul, dtensor.dist_layout.global_shape, 1)
    return numel
def get_layout(dtensor: torch.Tensor) -> Layout:
    """
    Return the ``Layout`` stored on a distributed tensor.

    Args:
        dtensor (torch.Tensor): The distributed tensor.

    Returns:
        Layout: The tensor's ``dist_layout`` attribute.
    """
    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
    layout = dtensor.dist_layout
    return layout
def get_global_shape(dtensor: torch.Tensor) -> torch.Size:
    """
    Return the global (unsharded) shape recorded in a distributed tensor's layout.

    Args:
        dtensor (torch.Tensor): The distributed tensor.

    Returns:
        torch.Size: ``dtensor.dist_layout.global_shape``.
    """
    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
    layout = dtensor.dist_layout
    return layout.global_shape
def get_device_mesh(dtensor: torch.Tensor) -> DeviceMesh:
    """
    Return the device mesh recorded in a distributed tensor's layout.

    Args:
        dtensor (torch.Tensor): The distributed tensor.

    Returns:
        DeviceMesh: ``dtensor.dist_layout.device_mesh``.
    """
    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
    layout = dtensor.dist_layout
    return layout.device_mesh
def get_sharding_spec(dtensor: torch.Tensor) -> ShardingSpec:
    """
    Return the sharding spec recorded in a distributed tensor's layout.

    Args:
        dtensor (torch.Tensor): The distributed tensor.

    Returns:
        ShardingSpec: ``dtensor.dist_layout.sharding_spec``.
    """
    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
    layout = dtensor.dist_layout
    return layout.sharding_spec
# ======================================================
# Some shardings do not follow the SPMD style,
# e.g. the fused QKV layer in GPT2.
# Customized sharding is supported through the following APIs.
# ======================================================
def is_customized_distributed_tensor(tensor: torch.Tensor):
    """
    Check whether the given tensor is a customized distributed tensor.

    A tensor qualifies when it carries both a ``shard_fn`` and a ``gather_fn`` attribute.

    Args:
        tensor (torch.Tensor): The tensor to be checked.

    Returns:
        bool: Whether the given tensor is a customized distributed tensor.
    """
    required_attrs = ('shard_fn', 'gather_fn')
    return all(hasattr(tensor, attr) for attr in required_attrs)
def _hijack_detach_and_clone_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.Tensor:
    """
    Hijack the detach and clone methods of the tensor to make sure shard_fn and gather_fn are copied.

    The tensor's ``detach``/``clone`` are replaced (on this instance only) by wrappers
    that call the original methods and attach the same ``shard_fn``/``gather_fn``
    objects to the result. Calling this function more than once on the same tensor is safe.

    Args:
        dtensor (torch.Tensor): The tensor to be hijacked. It must have ``shard_fn`` and
            ``gather_fn`` attributes by the time ``detach``/``clone`` is called.

    Returns:
        torch.Tensor: The same tensor object, with ``detach``/``clone`` rebound.
    """
    # Save the original methods only once. If the tensor was already hijacked,
    # `dtensor.detach` is a wrapper that looks up `self._old_detach`. Storing that
    # wrapper as `_old_detach` would make it call itself forever (RecursionError).
    if not hasattr(dtensor, '_old_detach'):
        dtensor._old_detach = dtensor.detach
    if not hasattr(dtensor, '_old_clone'):
        dtensor._old_clone = dtensor.clone

    def new_detach(self):
        t_ = self._old_detach()
        t_.shard_fn = self.shard_fn
        t_.gather_fn = self.gather_fn
        return t_

    def new_clone(self, *args, **kwargs):
        t_ = self._old_clone(*args, **kwargs)
        t_.shard_fn = self.shard_fn
        t_.gather_fn = self.gather_fn
        return t_

    # bind the new methods to this tensor instance (instance attributes shadow the class methods)
    dtensor.detach = new_detach.__get__(dtensor)
    dtensor.clone = new_clone.__get__(dtensor)
    return dtensor
def distribute_tensor_with_customization(tensor: torch.Tensor, shard_fn: Callable,
                                         gather_fn: Callable) -> torch.Tensor:
    """
    Distribute the given tensor with the given shard_fn and gather_fn.

    ``shard_fn`` is applied to ``tensor`` to produce the local shard. Both functions are
    then stored on the shard as ``shard_fn``/``gather_fn`` attributes, so that
    ``to_global_for_customized_distributed_tensor`` can later call ``gather_fn``.

    Example:
        ```python
        # define shard and gather functions
        def shard_fn(tensor):
            rank = torch.distributed.get_rank()
            world_size = torch.distributed.get_world_size()
            return tensor.chunk(world_size, dim=0)[rank]

        def gather_fn(tensor):
            world_size = torch.distributed.get_world_size()
            shard_list = [torch.zeros_like(tensor) for _ in range(world_size)]
            torch.distributed.all_gather(shard_list, tensor)
            return torch.cat(shard_list, dim=0)

        # create a distributed tensor
        tensor = torch.rand(4, 4)
        dtensor = distribute_tensor_with_customization(tensor, shard_fn, gather_fn)
        ```

    Args:
        tensor (torch.Tensor): The tensor to be distributed.
        shard_fn (callable): The function to shard the tensor.
        gather_fn (callable): The function to gather the tensor.

    Returns:
        torch.Tensor: The distributed tensor.
    """
    assert callable(shard_fn), 'The shard_fn must be callable.'
    assert callable(gather_fn), 'The gather_fn must be callable.'
    assert not is_distributed_tensor(tensor), 'The input tensor is already a distributed tensor.'
    sharded_tensor = shard_fn(tensor)
    # set the shard_fn and gather_fn as attributes of the distributed tensor
    sharded_tensor.shard_fn = shard_fn
    sharded_tensor.gather_fn = gather_fn
    # make detach()/clone() propagate shard_fn and gather_fn to the copies they return
    _hijack_detach_and_clone_for_customized_distributed_tensor(sharded_tensor)
    return sharded_tensor
def to_global_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.Tensor:
    """
    Rebuild the global tensor from a customized distributed tensor.

    This simply calls the tensor's own ``gather_fn`` on it.

    Args:
        dtensor (torch.Tensor): The customized distributed tensor.

    Returns:
        torch.Tensor: Whatever ``dtensor.gather_fn(dtensor)`` returns (the global tensor).
    """
    assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.'
    gather = dtensor.gather_fn
    return gather(dtensor)
def customized_distributed_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True):
    """
    Wrap a customized distributed tensor in a ``torch.nn.Parameter``.

    The parameter gets the same ``shard_fn``/``gather_fn`` as ``dtensor``, and its
    ``detach``/``clone`` are patched to carry them to copies.

    Args:
        dtensor (torch.Tensor): the customized distributed tensor to wrap.
        requires_grad (bool): forwarded to ``torch.nn.Parameter``. Defaults to True.

    Returns:
        torch.nn.Parameter: the new parameter.
    """
    assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.'
    wrapped = torch.nn.Parameter(dtensor, requires_grad=requires_grad)
    # the new Parameter object does not carry dtensor's Python attributes, so re-attach them
    wrapped.shard_fn, wrapped.gather_fn = dtensor.shard_fn, dtensor.gather_fn
    _hijack_detach_and_clone_for_customized_distributed_tensor(wrapped)
    return wrapped
def customized_distributed_tensor_to_existing_param(dtensor: torch.Tensor, param: torch.nn.Parameter):
    """
    Load a customized distributed tensor into an existing parameter, in place.

    ``param.data`` is replaced by ``dtensor.data``, ``param`` takes over ``dtensor``'s
    ``shard_fn``/``gather_fn``, and its ``detach``/``clone`` are patched to carry them.

    Args:
        dtensor (torch.Tensor): the customized distributed tensor providing data and functions.
        param (torch.nn.Parameter): the parameter to update.
    """
    assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.'
    param.data = dtensor.data
    param.shard_fn, param.gather_fn = dtensor.shard_fn, dtensor.gather_fn
    _hijack_detach_and_clone_for_customized_distributed_tensor(param)