mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
521 lines
18 KiB
521 lines
18 KiB
import copy
|
|
import operator
|
|
from functools import reduce
|
|
from typing import Union
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
from torch.distributed import ProcessGroup
|
|
|
|
from colossalai.device.device_mesh import DeviceMesh
|
|
|
|
from .layout import Layout
|
|
from .layout_converter import LayoutConverter
|
|
from .sharding_spec import ShardingSpec
|
|
|
|
# Module-level LayoutConverter used by `_apply_layout` and `redistribute`;
# `clear_layout_converter()` empties its `cached_solution`.
layout_converter = LayoutConverter()
|
|
|
|
|
|
def clear_layout_converter():
    """
    Clear the cached conversion solutions of the module-level layout converter.

    The converter object itself is kept; only the contents of its
    ``cached_solution`` are removed.
    """
    # No ``global`` statement is needed: the module-level name is only read
    # here, and the cache is emptied by mutating the object it refers to.
    layout_converter.cached_solution.clear()
|
|
|
|
|
|
def is_distributed_tensor(tensor: torch.Tensor) -> bool:
    """
    Tell whether ``tensor`` has been turned into a distributed tensor.

    A tensor counts as distributed when it carries a ``dist_layout`` attribute.

    Args:
        tensor (torch.Tensor): The tensor to inspect.

    Returns:
        bool: ``True`` if ``tensor`` has a ``dist_layout`` attribute, ``False`` otherwise.
    """
    missing = object()
    return getattr(tensor, "dist_layout", missing) is not missing
|
|
|
|
|
|
def is_sharded(dtensor: torch.Tensor) -> bool:
    """
    Check if a distributed tensor is sharded.

    A distributed tensor is considered sharded when its local shape differs
    from the global shape recorded in its ``dist_layout``.

    Args:
        dtensor (torch.Tensor): The distributed tensor to be checked.

    Returns:
        bool: True if the tensor is sharded, False otherwise.
    """
    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
    # A local tensor whose shape equals the global shape holds the whole
    # tensor, so it is reported as *not* sharded (the previous `==` inverted this).
    return list(dtensor.shape) != list(dtensor.dist_layout.global_shape)
|
|
|
|
|
|
def _hijack_detach_and_clone(dtensor: torch.Tensor) -> torch.Tensor:
    """
    Hijack the ``detach`` and ``clone`` methods of the tensor so that their
    results carry a deep copy of ``dist_layout``.

    The methods being replaced are captured in a closure (and also kept on the
    tensor as ``_old_detach`` / ``_old_clone``). Because the wrappers call the
    captured methods instead of looking ``self._old_*`` up at call time,
    hijacking the same tensor more than once stacks wrappers. It no longer
    makes a wrapper call itself until ``RecursionError``.

    Args:
        dtensor (torch.Tensor): The tensor to be hijacked. It must have a
            ``dist_layout`` attribute by the time ``detach``/``clone`` is called.

    Returns:
        torch.Tensor: The same tensor object, hijacked in place.
    """
    old_detach = dtensor.detach
    old_clone = dtensor.clone
    # kept for backward compatibility with code that inspects these attributes
    dtensor._old_detach = old_detach
    dtensor._old_clone = old_clone

    def new_detach(self):
        t_ = old_detach()
        t_.dist_layout = copy.deepcopy(self.dist_layout)
        return t_

    def new_clone(self, *args, **kwargs):
        t_ = old_clone(*args, **kwargs)
        t_.dist_layout = copy.deepcopy(self.dist_layout)
        return t_

    # bind the new methods to the tensor; instance attributes take precedence
    # over the (non-data descriptor) methods defined on the tensor class
    dtensor.detach = new_detach.__get__(dtensor)
    dtensor.clone = new_clone.__get__(dtensor)
    return dtensor
|
|
|
|
|
|
def _construct_default_sharding_spec(
    tensor: torch.Tensor,
) -> ShardingSpec:
    """
    Build a sharding spec that partitions none of the tensor's dimensions.

    Args:
        tensor (`torch.Tensor`): the tensor the spec will describe; only its
            number of dimensions is used.

    Returns:
        A `ShardingSpec` with an empty ``dim_partition_dict``, i.e. without any
        sharding specified.
    """
    ndim = tensor.dim()
    no_partition = {}
    return ShardingSpec(dim_size=ndim, dim_partition_dict=no_partition)
|
|
|
|
|
|
def _apply_layout(tensor, layout):
    """
    Apply ``layout`` to a full local tensor while initializing a distributed tensor.

    The layout converter needs both a source and a target layout. The source
    layout built here describes ``tensor`` as unsharded (same device mesh,
    default sharding spec, global shape = ``tensor.shape``). ``layout`` is the
    target layout.

    Args:
        tensor (torch.Tensor): the local, not-yet-sharded tensor.
        layout (Layout): the target layout.

    Returns:
        torch.Tensor: the tensor returned by ``layout_converter.apply``.
    """
    unsharded_layout = Layout(
        device_mesh=layout.device_mesh,
        sharding_spec=_construct_default_sharding_spec(tensor),
        global_shape=tensor.shape,
    )
    return layout_converter.apply(tensor=tensor, source_layout=unsharded_layout, target_layout=layout)
|
|
|
|
|
|
def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor:
    """
    Convert the given (full) tensor to a distributed tensor.

    The current shape of ``tensor`` is recorded as the global shape. The tensor
    is then converted to the layout described by ``device_mesh`` and
    ``sharding_spec``.

    Args:
        tensor (torch.Tensor): The tensor to be converted; must not already be distributed.
        device_mesh (DeviceMesh): The device mesh for abstraction of the compute devices.
        sharding_spec (ShardingSpec): The sharding specification which describes how the tensor will be sharded.

    Returns:
        torch.Tensor: The distributed tensor, with ``detach``/``clone`` patched
        to carry its ``dist_layout``.
    """
    assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."

    target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=tensor.shape)
    # shard the full tensor into the piece described by target_layout,
    # then patch detach/clone so the layout survives those calls
    dtensor = _apply_layout(tensor, target_layout)
    return _hijack_detach_and_clone(dtensor)
|
|
|
|
def init_as_dtensor(
    tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size
) -> torch.Tensor:
    """
    Turn ``tensor`` into a distributed tensor in place, without resharding it.

    Unlike ``distribute_tensor``, no data is moved or split. The layout built
    from ``device_mesh``, ``sharding_spec`` and ``global_shape`` is attached
    as ``dist_layout``, and ``detach``/``clone`` are patched. The caller
    presumably passes data that already matches that layout (e.g. a local
    shard); this is not checked.

    Args:
        tensor (torch.Tensor): The tensor to mark as distributed; must not already be one.
        device_mesh (DeviceMesh): The device mesh for abstraction of the compute devices.
        sharding_spec (ShardingSpec): The sharding specification of the tensor.
        global_shape (torch.Size): The shape of the full (global) tensor.

    Returns:
        torch.Tensor: ``tensor`` itself.
    """
    assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."
    tensor.dist_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)
    _hijack_detach_and_clone(tensor)
    return tensor
|
|
|
|
def redistribute(dtensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor:
    """
    Convert a distributed tensor from its current layout to a new one.

    The source layout is ``dtensor.dist_layout``. The target layout is built
    from ``device_mesh`` and ``sharding_spec`` and keeps the tensor's global shape.

    Args:
        dtensor (torch.Tensor): the distributed tensor to be converted.
        device_mesh (DeviceMesh): the device mesh for abstraction of the compute devices.
        sharding_spec (ShardingSpec): the target sharding specification.

    Returns:
        torch.Tensor: the tensor returned by the layout converter for the target
        layout. Use this return value rather than relying on ``dtensor`` being
        updated in place.
    """
    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
    global_shape = get_global_shape(dtensor)
    target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)
    # NOTE(review): unlike distribute_tensor, the result's detach/clone are not
    # patched here; confirm whether callers need that.
    resharded_tensor = layout_converter.apply(
        tensor=dtensor, source_layout=dtensor.dist_layout, target_layout=target_layout
    )
    return resharded_tensor
|
|
|
|
|
|
def to_global(dtensor: torch.Tensor) -> torch.Tensor:
    """
    Gather a distributed tensor into the full global tensor.

    The tensor is converted from its current ``dist_layout`` to a layout with
    the same device mesh and global shape but no sharded dimension.

    Args:
        dtensor (torch.Tensor): the distributed tensor to be converted.

    Returns:
        torch.Tensor: the global (unsharded) tensor.
    """
    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."

    # An empty partition dict means no dimension is sharded.
    global_sharding_spec = ShardingSpec(dtensor.dim(), {})
    device_mesh = get_device_mesh(dtensor)
    global_shape = get_global_shape(dtensor)
    global_layout = Layout(device_mesh=device_mesh, sharding_spec=global_sharding_spec, global_shape=global_shape)

    # Use the module-level converter instead of a fresh local one that shadowed
    # it, so this conversion shares the cache managed by clear_layout_converter().
    global_tensor = layout_converter.apply(dtensor, dtensor.dist_layout, global_layout)
    return global_tensor
|
|
|
|
|
|
def shard_rowwise(
    tensor: torch.Tensor,
    group_or_device_mesh: Union[ProcessGroup, DeviceMesh, None] = None,
) -> torch.Tensor:
    """
    Shard the first dim of the given tensor.

    Args:
        tensor (torch.Tensor): The tensor to be sharded.
        group_or_device_mesh (Union[ProcessGroup, DeviceMesh], optional): The group or device mesh to shard the tensor.
            If None, the tensor will be sharded with respect to the global process group.
            A DeviceMesh must be 1D. Defaults to None.

    Returns:
        torch.Tensor: The sharded tensor.
    """
    # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
    if group_or_device_mesh is None:
        group_or_device_mesh = dist.GroupMember.WORLD

    if isinstance(group_or_device_mesh, ProcessGroup):
        device_mesh = DeviceMesh.from_process_group(group_or_device_mesh)
    else:
        assert len(group_or_device_mesh.shape) == 1, "Only 1D DeviceMesh is accepted for row-wise sharding."
        device_mesh = group_or_device_mesh

    # presumably {tensor dim: [mesh axes]}: split dim 0 across the single mesh axis
    sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={0: [0]})
    return distribute_tensor(tensor, device_mesh, sharding_spec)
|
|
|
|
|
|
def shard_colwise(
    tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh, None] = None
) -> torch.Tensor:
    """
    Shard the last dim of the given tensor.

    Args:
        tensor (torch.Tensor): The tensor to be sharded.
        group_or_device_mesh (Union[ProcessGroup, DeviceMesh], optional): The group or device mesh to shard the tensor.
            If None, the tensor will be sharded with respect to the global process group.
            A DeviceMesh must be 1D. Defaults to None.

    Returns:
        torch.Tensor: The sharded tensor.
    """
    # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
    if group_or_device_mesh is None:
        group_or_device_mesh = dist.GroupMember.WORLD

    if isinstance(group_or_device_mesh, ProcessGroup):
        device_mesh = DeviceMesh.from_process_group(group_or_device_mesh)
    else:
        assert len(group_or_device_mesh.shape) == 1, "Only 1D DeviceMesh is accepted for column-wise sharding."
        device_mesh = group_or_device_mesh

    # presumably {tensor dim: [mesh axes]}: split the last dim (-1) across the single mesh axis
    sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={-1: [0]})
    return distribute_tensor(tensor, device_mesh, sharding_spec)
|
|
|
|
|
|
def sharded_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True):
    """
    Wrap a distributed tensor in a new ``torch.nn.Parameter`` that keeps its layout.

    The parameter is given the same ``dist_layout`` object as ``dtensor`` (not
    a copy), and its ``detach``/``clone`` are patched like a distributed tensor's.

    Args:
        dtensor (torch.Tensor): The distributed tensor to wrap.
        requires_grad (bool): Passed to ``torch.nn.Parameter``. Defaults to True.

    Returns:
        torch.nn.Parameter: The new distributed parameter.
    """
    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
    param = torch.nn.Parameter(dtensor, requires_grad=requires_grad)
    # attach the layout explicitly so the parameter is distributed as well
    param.dist_layout = dtensor.dist_layout
    return _hijack_detach_and_clone(param)
|
|
|
|
|
|
def sharded_tensor_to_existing_param(dtensor: torch.Tensor, param: torch.nn.Parameter) -> None:
    """
    Load a distributed tensor into an existing parameter and make the parameter distributed.

    ``param.data`` is replaced by ``dtensor``. The parameter receives
    ``dtensor``'s ``dist_layout``, and its ``detach``/``clone`` are patched.
    Calling this repeatedly on the same parameter is safe.

    Args:
        dtensor (torch.Tensor): The distributed tensor holding the new data.
        param (torch.nn.Parameter): The parameter to update in place.
    """
    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
    param.data = dtensor
    # make it distributed as well
    param.dist_layout = dtensor.dist_layout
    # If `param` was hijacked before (e.g. by an earlier call), its detach/clone
    # are instance-level wrappers. Wrapping those wrappers again can make them
    # recurse without end, so drop them and wrap the tensor class's own methods.
    for name in ("detach", "clone"):
        param.__dict__.pop(name, None)
    _hijack_detach_and_clone(param)
|
|
|
|
|
|
def compute_global_numel(dtensor: torch.Tensor) -> int:
    """
    Compute the global number of elements in the distributed tensor.

    Args:
        dtensor (torch.Tensor): The distributed tensor.

    Returns:
        int: The product of the global shape's dimensions (1 for an empty,
        i.e. 0-dim, global shape).
    """
    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
    # Start from 1 so a 0-dim (scalar) tensor, whose global shape is empty,
    # yields 1 instead of reduce() raising TypeError on an empty sequence.
    numel = reduce(operator.mul, dtensor.dist_layout.global_shape, 1)
    return numel
|
|
|
|
|
|
def get_layout(dtensor: torch.Tensor) -> Layout:
    """
    Return the layout attached to a distributed tensor.

    Args:
        dtensor (torch.Tensor): A tensor carrying a ``dist_layout`` attribute.

    Returns:
        Layout: The tensor's ``dist_layout``.
    """
    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
    layout = dtensor.dist_layout
    return layout
|
|
|
|
|
|
def get_global_shape(dtensor: torch.Tensor) -> torch.Size:
    """
    Return the global (unsharded) shape of a distributed tensor.

    Args:
        dtensor (torch.Tensor): The distributed tensor.

    Returns:
        torch.Size: ``global_shape`` of the tensor's layout.
    """
    # get_layout performs the "is distributed" assertion
    layout = get_layout(dtensor)
    return layout.global_shape
|
|
|
|
|
|
def get_device_mesh(dtensor: torch.Tensor) -> DeviceMesh:
    """
    Return the device mesh a distributed tensor is laid out on.

    Args:
        dtensor (torch.Tensor): The distributed tensor.

    Returns:
        DeviceMesh: ``device_mesh`` of the tensor's layout.
    """
    # get_layout performs the "is distributed" assertion
    layout = get_layout(dtensor)
    return layout.device_mesh
|
|
|
|
|
|
def get_sharding_spec(dtensor: torch.Tensor) -> ShardingSpec:
    """
    Return the sharding spec of a distributed tensor.

    Args:
        dtensor (torch.Tensor): The distributed tensor.

    Returns:
        ShardingSpec: ``sharding_spec`` of the tensor's layout.
    """
    # get_layout performs the "is distributed" assertion
    layout = get_layout(dtensor)
    return layout.sharding_spec
|
|
|
|
|
|
# ======================================================
# Some sharding does not follow the SPMD style,
# e.g. the fused QKV layer in GPT-2.
# Customized sharding is supported through the following APIs.
# ======================================================
|
|
def is_customized_distributed_tensor(tensor: torch.Tensor):
    """
    Tell whether ``tensor`` is a customized distributed tensor.

    Such a tensor carries both a ``shard_fn`` and a ``gather_fn`` attribute.

    Args:
        tensor (torch.Tensor): The tensor to inspect.

    Returns:
        bool: ``True`` only if both attributes are present.
    """
    required = ("shard_fn", "gather_fn")
    return all(hasattr(tensor, attr) for attr in required)
|
|
|
|
|
|
def _hijack_detach_and_clone_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.Tensor:
    """
    Hijack the ``detach`` and ``clone`` methods of the tensor so that their
    results carry the same ``shard_fn`` and ``gather_fn`` as the tensor.

    The methods being replaced are captured in a closure (and also kept on the
    tensor as ``_old_detach`` / ``_old_clone``). Because the wrappers call the
    captured methods instead of looking ``self._old_*`` up at call time,
    hijacking the same tensor more than once stacks wrappers. It no longer
    makes a wrapper call itself until ``RecursionError``.

    Args:
        dtensor (torch.Tensor): The tensor to be hijacked. It must have
            ``shard_fn`` and ``gather_fn`` attributes by the time
            ``detach``/``clone`` is called.

    Returns:
        torch.Tensor: The same tensor object, hijacked in place.
    """
    old_detach = dtensor.detach
    old_clone = dtensor.clone
    # kept for backward compatibility with code that inspects these attributes
    dtensor._old_detach = old_detach
    dtensor._old_clone = old_clone

    def new_detach(self):
        t_ = old_detach()
        t_.shard_fn = self.shard_fn
        t_.gather_fn = self.gather_fn
        return t_

    def new_clone(self, *args, **kwargs):
        t_ = old_clone(*args, **kwargs)
        t_.shard_fn = self.shard_fn
        t_.gather_fn = self.gather_fn
        return t_

    # bind the new methods to the tensor; instance attributes take precedence
    # over the (non-data descriptor) methods defined on the tensor class
    dtensor.detach = new_detach.__get__(dtensor)
    dtensor.clone = new_clone.__get__(dtensor)
    return dtensor
|
|
|
|
|
|
def distribute_tensor_with_customization(tensor: torch.Tensor, shard_fn, gather_fn: callable):
    """
    Shard the full tensor with ``shard_fn`` and mark the result as a customized distributed tensor.

    ``shard_fn(tensor)`` is called to obtain the local piece. ``shard_fn`` and
    ``gather_fn`` are then stored on that piece, and its ``detach``/``clone``
    are patched so both functions are carried over to their results.

    Example:

    ```python
    # define shard and gather functions
    def shard_fn(tensor):
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
        return tensor.chunk(world_size, dim=0)[rank]

    def gather_fn(tensor):
        world_size = torch.distributed.get_world_size()
        shard_list = [torch.zeros_like(tensor) for _ in range(world_size)]
        torch.distributed.all_gather(shard_list, tensor)
        return torch.cat(shard_list, dim=0)

    # create a distributed tensor
    tensor = torch.rand(4, 4)
    dtensor = distribute_tensor_with_customization(tensor, shard_fn, gather_fn)
    ```

    Args:
        tensor (torch.Tensor): The full tensor to be distributed; must not already be distributed.
        shard_fn (callable): Maps the full tensor to the local shard.
        gather_fn (callable): Maps a local shard back to the full tensor.

    Returns:
        torch.Tensor: The local shard returned by ``shard_fn``, carrying both functions.
    """
    assert callable(shard_fn), "The shard_fn must be callable."
    assert callable(gather_fn), "The gather_fn must be callable."
    assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."

    local_shard = shard_fn(tensor)
    # remember how to shard / gather this tensor later
    local_shard.shard_fn, local_shard.gather_fn = shard_fn, gather_fn
    # patch detach/clone so the two functions survive those calls
    return _hijack_detach_and_clone_for_customized_distributed_tensor(local_shard)
|
|
|
|
|
|
def init_tensor_as_customization_distributed(tensor: torch.Tensor, shard_fn, gather_fn: callable):
    """
    Mark ``tensor`` as a customized distributed tensor in place, without sharding it.

    Unlike ``distribute_tensor_with_customization``, ``shard_fn`` is NOT
    applied here. ``shard_fn`` and ``gather_fn`` are only attached to
    ``tensor``, and its ``detach``/``clone`` are patched so both functions are
    carried over to their results. The caller presumably passes a tensor that
    is already the local shard.

    Example:

    ```python
    # define shard and gather functions
    def shard_fn(tensor):
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
        return tensor.chunk(world_size, dim=0)[rank]

    def gather_fn(tensor):
        world_size = torch.distributed.get_world_size()
        shard_list = [torch.zeros_like(tensor) for _ in range(world_size)]
        torch.distributed.all_gather(shard_list, tensor)
        return torch.cat(shard_list, dim=0)

    # `local_shard` already holds this rank's piece of the tensor
    local_shard = shard_fn(torch.rand(4, 4))
    dtensor = init_tensor_as_customization_distributed(local_shard, shard_fn, gather_fn)
    ```

    Args:
        tensor (torch.Tensor): The (local) tensor to mark; must not already be distributed.
        shard_fn (callable): The function to shard the tensor.
        gather_fn (callable): The function to gather the tensor.

    Returns:
        torch.Tensor: ``tensor`` itself.
    """
    assert callable(shard_fn), "The shard_fn must be callable."
    assert callable(gather_fn), "The gather_fn must be callable."
    assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."

    # attach shard_fn and gather_fn as attributes (shard_fn is not called here)
    tensor.shard_fn = shard_fn
    tensor.gather_fn = gather_fn

    # patch detach/clone so shard_fn/gather_fn are carried over to their results
    _hijack_detach_and_clone_for_customized_distributed_tensor(tensor)

    return tensor
|
|
|
|
|
|
def to_global_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.Tensor:
    """
    Reassemble the global tensor from a customized distributed tensor.

    Args:
        dtensor (torch.Tensor): A tensor carrying ``shard_fn`` and ``gather_fn``.

    Returns:
        torch.Tensor: Whatever the tensor's own ``gather_fn`` returns for it.
    """
    assert is_customized_distributed_tensor(dtensor), "The input tensor is not a customized distributed tensor."
    gather = dtensor.gather_fn
    return gather(dtensor)
|
|
|
|
|
|
def customized_distributed_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True):
    """
    Wrap a customized distributed tensor in a new ``torch.nn.Parameter``.

    The parameter receives the tensor's ``shard_fn`` and ``gather_fn``, and
    its ``detach``/``clone`` are patched to carry them over.

    Args:
        dtensor (torch.Tensor): The customized distributed tensor to wrap.
        requires_grad (bool): Passed to ``torch.nn.Parameter``. Defaults to True.

    Returns:
        torch.nn.Parameter: The new parameter.
    """
    assert is_customized_distributed_tensor(dtensor), "The input tensor is not a customized distributed tensor."

    param = torch.nn.Parameter(dtensor, requires_grad=requires_grad)
    # carry the customization over so the parameter is distributed as well
    param.shard_fn, param.gather_fn = dtensor.shard_fn, dtensor.gather_fn
    return _hijack_detach_and_clone_for_customized_distributed_tensor(param)
|
|
|
|
|
|
def customized_distributed_tensor_to_existing_param(dtensor: torch.Tensor, param: torch.nn.Parameter) -> None:
    """
    Convert the given customized distributed tensor to an existing parameter.

    ``param.data`` is replaced by ``dtensor.data``. The parameter receives the
    tensor's ``shard_fn`` and ``gather_fn``, and its ``detach``/``clone`` are
    patched. Calling this repeatedly on the same parameter is safe.

    Args:
        dtensor (torch.Tensor): The customized distributed tensor holding the new data.
        param (torch.nn.Parameter): The parameter to update in place.
    """
    assert is_customized_distributed_tensor(dtensor), "The input tensor is not a customized distributed tensor."

    param.data = dtensor.data
    param.shard_fn = dtensor.shard_fn
    param.gather_fn = dtensor.gather_fn
    # If `param` was hijacked before (e.g. by an earlier call), its detach/clone
    # are instance-level wrappers. Wrapping those wrappers again can make them
    # recurse without end, so drop them and wrap the tensor class's own methods.
    for name in ("detach", "clone"):
        param.__dict__.pop(name, None)
    _hijack_detach_and_clone_for_customized_distributed_tensor(param)
|