import copy
import operator
from functools import reduce
from typing import Callable, Optional, Union

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from colossalai.device.device_mesh import DeviceMesh

from .layout import Layout
from .layout_converter import LayoutConverter
from .sharding_spec import ShardingSpec

# Module-wide layout converter; its `cached_solution` is reset by `clear_layout_converter`.
layout_converter = LayoutConverter()


def clear_layout_converter():
    """
    Clear the cached conversion solutions (`cached_solution`) of the module-level layout converter.
    """
    layout_converter.cached_solution.clear()


def is_distributed_tensor(tensor: torch.Tensor) -> bool:
    """
    Check whether the given tensor is a distributed tensor.

    A tensor counts as distributed when it carries a `dist_layout` attribute.

    Args:
        tensor (torch.Tensor): The tensor to be checked.

    Returns:
        bool: Whether the given tensor is a distributed tensor.
    """
    return hasattr(tensor, "dist_layout")


def is_sharded(dtensor: torch.Tensor) -> bool:
    """
    Check if a distributed tensor is sharded.

    A distributed tensor is considered sharded when its local shape differs from its global shape.

    Args:
        dtensor (torch.Tensor): The distributed tensor to be checked.

    Returns:
        bool: True if the tensor is sharded, False otherwise.

    Raises:
        AssertionError: If ``dtensor`` is not a distributed tensor.
    """
    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
    return list(dtensor.shape) != list(dtensor.dist_layout.global_shape)


def _install_detach_and_clone_hooks(
    tensor: torch.Tensor, copy_attrs: Callable[[torch.Tensor, torch.Tensor], None]
) -> torch.Tensor:
    """
    Replace ``tensor.detach`` and ``tensor.clone`` with wrappers that propagate metadata.

    Each wrapper calls the real method and then ``copy_attrs(source, result)``, so the returned tensor
    carries the same distributed metadata as ``tensor``. The original methods stay reachable as
    ``tensor._old_detach`` and ``tensor._old_clone``.

    Args:
        tensor (torch.Tensor): The tensor whose methods are replaced (modified in place).
        copy_attrs (Callable): Copies the metadata attributes from the source tensor to the result.

    Returns:
        torch.Tensor: ``tensor`` itself.
    """
    tensor_cls = type(tensor)
    # Bind the class implementations instead of reading `tensor.detach` / `tensor.clone`.
    # If the tensor was hijacked before (e.g. the same parameter passed to a `*_to_existing_param`
    # helper twice), the instance attributes are already our wrappers. Storing those as `_old_*`
    # would make each wrapper call itself recursively without end.
    tensor._old_detach = tensor_cls.detach.__get__(tensor, tensor_cls)
    tensor._old_clone = tensor_cls.clone.__get__(tensor, tensor_cls)

    def new_detach(self):
        t_ = self._old_detach()
        copy_attrs(self, t_)
        return t_

    def new_clone(self, *args, **kwargs):
        t_ = self._old_clone(*args, **kwargs)
        copy_attrs(self, t_)
        return t_

    # bind the new methods to this tensor instance only
    tensor.detach = new_detach.__get__(tensor)
    tensor.clone = new_clone.__get__(tensor)
    return tensor


def _hijack_detach_and_clone(dtensor: torch.Tensor) -> torch.Tensor:
    """
    Hijack the detach and clone methods of the tensor to make sure the dist_layout is copied.

    Calling this more than once on the same tensor is safe.

    Args:
        dtensor (torch.Tensor): The tensor to be hijacked (modified in place).

    Returns:
        torch.Tensor: The hijacked tensor (the same object as ``dtensor``).
    """

    def _copy_dist_layout(src: torch.Tensor, dst: torch.Tensor) -> None:
        # deep-copied so the new tensor gets its own Layout object instead of sharing the source's
        dst.dist_layout = copy.deepcopy(src.dist_layout)

    return _install_detach_and_clone_hooks(dtensor, _copy_dist_layout)


def _construct_default_sharding_spec(
    tensor: torch.Tensor,
) -> ShardingSpec:
    """
    Construct the default sharding specification for the tensor.

    Args:
        tensor (`torch.Tensor`): the tensor to be sharded.

    Returns:
        A `ShardingSpec` object without any sharding specified.
    """
    return ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={})


def _apply_layout(tensor: torch.Tensor, layout: Layout) -> torch.Tensor:
    """
    Apply the layout to the local tensor during the initialization process.

    Args:
        tensor (torch.Tensor): the full (unsharded) tensor.
        layout (Layout): the target distributed layout.

    Returns:
        torch.Tensor: the tensor converted to ``layout`` by the module-level layout converter.
    """
    # The layout converter requires a source and a target layout: the source layout describes the
    # unsharded tensor, and `layout` is the target layout of the sharded tensor.
    source_spec = _construct_default_sharding_spec(tensor)
    source_layout = Layout(device_mesh=layout.device_mesh, sharding_spec=source_spec, global_shape=tensor.shape)
    sharded_tensor = layout_converter.apply(tensor=tensor, source_layout=source_layout, target_layout=layout)
    return sharded_tensor


def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor:
    """
    Convert the given tensor to a distributed tensor.

    Args:
        tensor (torch.Tensor): The tensor to be converted.
        device_mesh (DeviceMesh): The device mesh for abstraction of the compute devices.
        sharding_spec (ShardingSpec): The sharding specification which describes how the tensor will be sharded.

    Returns:
        torch.Tensor: The distributed tensor.

    Raises:
        AssertionError: If ``tensor`` is already a distributed tensor.
    """
    assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."
    dist_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=tensor.shape)

    # shard tensor
    sharded_tensor = _apply_layout(tensor, dist_layout)

    # make detach()/clone() keep the distributed layout
    _hijack_detach_and_clone(sharded_tensor)

    return sharded_tensor


def init_as_dtensor(
    tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size
) -> torch.Tensor:
    """
    Mark a tensor that already holds the local shard as a distributed tensor.

    Unlike `distribute_tensor`, no layout conversion is performed. The layout built from the arguments
    is attached to ``tensor`` as-is, and ``tensor`` is returned.

    Args:
        tensor (torch.Tensor): The local shard; presumably already sharded according to
            ``sharding_spec`` (not checked here).
        device_mesh (DeviceMesh): The device mesh for abstraction of the compute devices.
        sharding_spec (ShardingSpec): The sharding specification of the tensor.
        global_shape (torch.Size): The shape of the full (unsharded) tensor.

    Returns:
        torch.Tensor: ``tensor`` itself, now carrying a `dist_layout`.

    Raises:
        AssertionError: If ``tensor`` is already a distributed tensor.
    """
    assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."
    dist_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)

    # attach the layout without moving any data
    tensor.dist_layout = dist_layout

    # make detach()/clone() keep the distributed layout
    _hijack_detach_and_clone(tensor)

    return tensor


def redistribute(dtensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor:
    """
    Convert a distributed tensor to the layout given by ``device_mesh`` and ``sharding_spec``.

    The global shape is kept unchanged. The result is returned. It may be a different tensor object
    than ``dtensor``, so callers should always use the return value.

    Args:
        dtensor (torch.Tensor): the distributed tensor to be converted.
        device_mesh (DeviceMesh): the device mesh for abstraction of the compute devices.
        sharding_spec (ShardingSpec): the target sharding specification.

    Returns:
        torch.Tensor: the resharded distributed tensor.

    Raises:
        AssertionError: If ``dtensor`` is not a distributed tensor.
    """
    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
    global_shape = get_global_shape(dtensor)
    target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)
    resharded_tensor = layout_converter.apply(
        tensor=dtensor, source_layout=dtensor.dist_layout, target_layout=target_layout
    )
    # as in distribute_tensor, make detach()/clone() of the result keep the distributed layout
    _hijack_detach_and_clone(resharded_tensor)
    return resharded_tensor


def to_global(dtensor: torch.Tensor) -> torch.Tensor:
    """
    Convert a distributed tensor to its global (fully replicated, unsharded) tensor.

    Args:
        dtensor (torch.Tensor): the distributed tensor to be converted.

    Returns:
        torch.Tensor: the global tensor, with shape equal to the global shape of ``dtensor``.

    Raises:
        AssertionError: If ``dtensor`` is not a distributed tensor.
    """
    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
    global_sharding_spec = ShardingSpec(dtensor.dim(), {})
    device_mesh = get_device_mesh(dtensor)
    global_shape = get_global_shape(dtensor)
    global_layout = Layout(device_mesh=device_mesh, sharding_spec=global_sharding_spec, global_shape=global_shape)
    # Use the module-level converter so this call shares the `cached_solution` that
    # `clear_layout_converter` manages.
    global_tensor = layout_converter.apply(dtensor, dtensor.dist_layout, global_layout)
    return global_tensor


def _resolve_1d_device_mesh(
    group_or_device_mesh: Optional[Union[ProcessGroup, DeviceMesh]], sharding_kind: str
) -> DeviceMesh:
    """
    Turn a process group (or None, meaning the global group) into a device mesh, or validate a given 1D mesh.

    Args:
        group_or_device_mesh (Optional[Union[ProcessGroup, DeviceMesh]]): The group or mesh to shard over.
        sharding_kind (str): Used in the error message, e.g. "row-wise".

    Returns:
        DeviceMesh: The device mesh to shard over.

    Raises:
        AssertionError: If no group is available or a non-1D DeviceMesh is given.
    """
    # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
    if group_or_device_mesh is None:
        group_or_device_mesh = dist.GroupMember.WORLD
        # GroupMember.WORLD is None before torch.distributed is initialized
        assert group_or_device_mesh is not None, (
            "torch.distributed is not initialized; pass a ProcessGroup or DeviceMesh explicitly."
        )

    if isinstance(group_or_device_mesh, ProcessGroup):
        return DeviceMesh.from_process_group(group_or_device_mesh)

    assert len(group_or_device_mesh.shape) == 1, f"Only 1D DeviceMesh is accepted for {sharding_kind} sharding."
    return group_or_device_mesh


def shard_rowwise(
    tensor: torch.Tensor,
    group_or_device_mesh: Optional[Union[ProcessGroup, DeviceMesh]] = None,
) -> torch.Tensor:
    """
    Shard the first dim of the given tensor.

    Args:
        tensor (torch.Tensor): The tensor to be sharded.
        group_or_device_mesh (Union[ProcessGroup, DeviceMesh], optional): The group or device mesh to shard the tensor.
            If None, the tensor will be sharded with respect to the global process group.
            Defaults to None.

    Returns:
        torch.Tensor: The sharded tensor.
    """
    device_mesh = _resolve_1d_device_mesh(group_or_device_mesh, "row-wise")
    sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={0: [0]})
    return distribute_tensor(tensor, device_mesh, sharding_spec)


def shard_colwise(
    tensor: torch.Tensor, group_or_device_mesh: Optional[Union[ProcessGroup, DeviceMesh]] = None
) -> torch.Tensor:
    """
    Shard the last dim of the given tensor.

    Args:
        tensor (torch.Tensor): The tensor to be sharded.
        group_or_device_mesh (Union[ProcessGroup, DeviceMesh], optional): The group or device mesh to shard the tensor.
            If None, the tensor will be sharded with respect to the global process group.
            Defaults to None.

    Returns:
        torch.Tensor: The sharded tensor.
    """
    device_mesh = _resolve_1d_device_mesh(group_or_device_mesh, "column-wise")
    # -1 selects the last dimension
    sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={-1: [0]})
    return distribute_tensor(tensor, device_mesh, sharding_spec)


def sharded_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True) -> torch.nn.Parameter:
    """
    Wrap a distributed tensor into a new ``torch.nn.Parameter`` that keeps its distributed layout.

    Args:
        dtensor (torch.Tensor): The distributed tensor.
        requires_grad (bool, optional): Passed to ``torch.nn.Parameter``. Defaults to True.

    Returns:
        torch.nn.Parameter: The parameter, sharing ``dtensor.dist_layout``.
    """
    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
    param = torch.nn.Parameter(dtensor, requires_grad=requires_grad)

    # make it distributed as well
    param.dist_layout = dtensor.dist_layout
    _hijack_detach_and_clone(param)

    return param


def sharded_tensor_to_existing_param(dtensor: torch.Tensor, param: torch.nn.Parameter) -> None:
    """
    Set ``param.data`` to a distributed tensor and attach its distributed layout to ``param``.

    Args:
        dtensor (torch.Tensor): The distributed tensor.
        param (torch.nn.Parameter): The existing parameter (modified in place).
    """
    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
    param.data = dtensor
    # make it distributed as well
    param.dist_layout = dtensor.dist_layout
    _hijack_detach_and_clone(param)


def compute_global_numel(dtensor: torch.Tensor) -> int:
    """
    Compute the global number of elements in the distributed tensor.

    Args:
        dtensor (torch.Tensor): The distributed tensor.

    Returns:
        int: The global number of elements in the distributed tensor (1 for a 0-dim global shape).
    """
    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
    # the initializer makes an empty (scalar) global shape yield 1 instead of raising TypeError
    numel = reduce(operator.mul, dtensor.dist_layout.global_shape, 1)
    return numel


def get_layout(dtensor: torch.Tensor) -> Layout:
    """
    Get the layout of the distributed tensor.

    Args:
        dtensor (torch.Tensor): The distributed tensor.

    Returns:
        Layout: The layout of the distributed tensor.
    """
    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
    return dtensor.dist_layout


def get_global_shape(dtensor: torch.Tensor) -> torch.Size:
    """
    Get the global shape of the distributed tensor.

    Args:
        dtensor (torch.Tensor): The distributed tensor.

    Returns:
        torch.Size: The global shape of the distributed tensor.
    """
    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
    return dtensor.dist_layout.global_shape


def get_device_mesh(dtensor: torch.Tensor) -> DeviceMesh:
    """
    Get the device mesh of the distributed tensor.

    Args:
        dtensor (torch.Tensor): The distributed tensor.

    Returns:
        DeviceMesh: The device mesh of the distributed tensor.
    """
    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
    return dtensor.dist_layout.device_mesh


def get_sharding_spec(dtensor: torch.Tensor) -> ShardingSpec:
    """
    Get the sharding spec of the distributed tensor.

    Args:
        dtensor (torch.Tensor): The distributed tensor.

    Returns:
        ShardingSpec: The sharding spec of the distributed tensor.
    """
    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
    return dtensor.dist_layout.sharding_spec


# ======================================================
# Some sharding does not obey the SPMD style
# e.g. Fused QKV layer in GPT2
# we support customize sharding with the following APIs
# ======================================================


def is_customized_distributed_tensor(tensor: torch.Tensor) -> bool:
    """
    Check whether the given tensor is a customized distributed tensor.

    Args:
        tensor (torch.Tensor): The tensor to be checked.

    Returns:
        bool: Whether the given tensor is a customized distributed tensor
            (i.e. it has both `shard_fn` and `gather_fn` attributes).
    """
    return hasattr(tensor, "shard_fn") and hasattr(tensor, "gather_fn")


def _hijack_detach_and_clone_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.Tensor:
    """
    Hijack the detach and clone methods of the tensor to make sure shard_fn and gather_fn are copied.

    Calling this more than once on the same tensor is safe.

    Args:
        dtensor (torch.Tensor): The tensor to be hijacked (modified in place).

    Returns:
        torch.Tensor: The hijacked tensor (the same object as ``dtensor``).
    """

    def _copy_custom_fns(src: torch.Tensor, dst: torch.Tensor) -> None:
        dst.shard_fn = src.shard_fn
        dst.gather_fn = src.gather_fn

    return _install_detach_and_clone_hooks(dtensor, _copy_custom_fns)


def distribute_tensor_with_customization(tensor: torch.Tensor, shard_fn: Callable, gather_fn: Callable):
    """
    Distribute the given tensor with the given shard_fn and gather_fn.

    Example:

    ```python
    # define shard and gather functions
    def shard_fn(tensor):
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
        return tensor.chunk(world_size, dim=0)[rank]

    def gather_fn(tensor):
        world_size = torch.distributed.get_world_size()
        shard_list = [torch.zeros_like(tensor) for _ in range(world_size)]
        torch.distributed.all_gather(shard_list, tensor)
        return torch.cat(shard_list, dim=0)

    # create a distributed tensor
    tensor = torch.rand(4, 4)
    dtensor = distribute_tensor_with_customization(tensor, shard_fn, gather_fn)
    ```

    Args:
        tensor (torch.Tensor): The tensor to be distributed.
        shard_fn (callable): The function to shard the tensor.
        gather_fn (callable): The function to gather the tensor.

    Returns:
        torch.Tensor: The distributed tensor, i.e. ``shard_fn(tensor)`` with both functions attached.
    """
    assert callable(shard_fn), "The shard_fn must be callable."
    assert callable(gather_fn), "The gather_fn must be callable."
    assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."

    sharded_tensor = shard_fn(tensor)

    # set the shard_fn and gather_fn as attributes of the distributed tensor
    sharded_tensor.shard_fn = shard_fn
    sharded_tensor.gather_fn = gather_fn

    # make detach()/clone() keep shard_fn and gather_fn
    _hijack_detach_and_clone_for_customized_distributed_tensor(sharded_tensor)

    return sharded_tensor


def init_tensor_as_customization_distributed(tensor: torch.Tensor, shard_fn: Callable, gather_fn: Callable):
    """
    Mark a tensor as a customized distributed tensor by attaching the given shard_fn and gather_fn.

    Unlike `distribute_tensor_with_customization`, ``shard_fn`` is NOT called. ``tensor`` is presumably
    already the local shard, and it is returned as-is.

    Example:

    ```python
    # define shard and gather functions
    def shard_fn(tensor):
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
        return tensor.chunk(world_size, dim=0)[rank]

    def gather_fn(tensor):
        world_size = torch.distributed.get_world_size()
        shard_list = [torch.zeros_like(tensor) for _ in range(world_size)]
        torch.distributed.all_gather(shard_list, tensor)
        return torch.cat(shard_list, dim=0)

    # mark an (already local) tensor as distributed
    tensor = torch.rand(4, 4)
    dtensor = init_tensor_as_customization_distributed(tensor, shard_fn, gather_fn)
    ```

    Args:
        tensor (torch.Tensor): The tensor to be marked as distributed.
        shard_fn (callable): The function to shard the tensor.
        gather_fn (callable): The function to gather the tensor.

    Returns:
        torch.Tensor: ``tensor`` itself, with both functions attached.
    """
    assert callable(shard_fn), "The shard_fn must be callable."
    assert callable(gather_fn), "The gather_fn must be callable."
    assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."

    # set the shard_fn and gather_fn as attributes of the distributed tensor
    tensor.shard_fn = shard_fn
    tensor.gather_fn = gather_fn

    # make detach()/clone() keep shard_fn and gather_fn
    _hijack_detach_and_clone_for_customized_distributed_tensor(tensor)

    return tensor


def to_global_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.Tensor:
    """
    Gather the given tensor to the global tensor.

    Args:
        dtensor (torch.Tensor): The distributed tensor.

    Returns:
        torch.Tensor: The global tensor, as returned by ``dtensor.gather_fn(dtensor)``.
    """
    assert is_customized_distributed_tensor(dtensor), "The input tensor is not a customized distributed tensor."
    return dtensor.gather_fn(dtensor)


def customized_distributed_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True) -> torch.nn.Parameter:
    """
    Convert the given customized distributed tensor to a parameter.

    Args:
        dtensor (torch.Tensor): The customized distributed tensor.
        requires_grad (bool, optional): Passed to ``torch.nn.Parameter``. Defaults to True.

    Returns:
        torch.nn.Parameter: The parameter, carrying the same shard_fn and gather_fn.
    """
    assert is_customized_distributed_tensor(dtensor), "The input tensor is not a customized distributed tensor."

    param = torch.nn.Parameter(dtensor, requires_grad=requires_grad)

    # make it distributed as well
    param.shard_fn = dtensor.shard_fn
    param.gather_fn = dtensor.gather_fn
    _hijack_detach_and_clone_for_customized_distributed_tensor(param)
    return param


def customized_distributed_tensor_to_existing_param(dtensor: torch.Tensor, param: torch.nn.Parameter) -> None:
    """
    Convert the given customized distributed tensor to an existing parameter.

    Args:
        dtensor (torch.Tensor): The customized distributed tensor.
        param (torch.nn.Parameter): The existing parameter (modified in place).
    """
    assert is_customized_distributed_tensor(dtensor), "The input tensor is not a customized distributed tensor."

    param.data = dtensor.data
    param.shard_fn = dtensor.shard_fn
    param.gather_fn = dtensor.gather_fn
    _hijack_detach_and_clone_for_customized_distributed_tensor(param)