diff --git a/colossalai/device/README.md b/colossalai/device/README.md new file mode 100644 index 000000000..8f835735b --- /dev/null +++ b/colossalai/device/README.md @@ -0,0 +1,73 @@ +# 🗄 Device + +## 📚 Table of Contents + +- [🗄 Device](#-device) + - [📚 Table of Contents](#-table-of-contents) + - [🔗 Introduction](#-introduction) + - [📝 Design](#-design) + - [🔨 Usage](#-usage) + +## 🔗 Introduction + +This module contains the implementation of the abstraction of the device topology. It is used to represent the device topology and manage the distributed information related to the network. + +## 📝 Design + + +This module is inspired by the DeviceMesh in the [Alpa project](https://github.com/alpa-projects/alpa) and the device array can be represented as a 1D or 2D mesh. We will be extending the device mesh to support 3D mesh in the future. + + +## 🔨 Usage + +- Create a device mesh + +```python +# this is the list of global ranks involved in the device mesh +# assume we have 4 GPUs and the global ranks for these GPUs are 0, 1, 2, 3 +physical_mesh_id = torch.arange(4) +mesh_shape = [2, 2] +device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) +``` + +- View the mesh + + +```python +# view the mesh shape +# expect output +# [2, 2] +print(device_mesh.shape) + + +# view the logical mesh with global ranks +# expect output +# [ +# [0, 1], +# [2, 3] +# ] +print(device_mesh.logical_mesh_id) + +# view the number of devices in the mesh +# expect output +# 4 +print(device_mesh.num_devices) + +``` + +- Initialize the process group + +```python +# intialize process group +device_mesh.init_logical_process_group() + + +# get the process group for a rank with respect to an axis +# this is the process group involving global ranks 0 and 2 +print(device_mesh.get_process_group(axis=0, global_rank=0)) + +# get the ranks in the process with respect to an axis +# expect output +# [0, 2] +print(device_mesh.get_ranks_in_process_group(axis=0, global_rank=0)) +``` diff --git a/colossalai/device/device_mesh.py b/colossalai/device/device_mesh.py index 2a5f747fb..0490a4401 100644 --- a/colossalai/device/device_mesh.py +++ b/colossalai/device/device_mesh.py @@ -3,11 +3,19 @@ with some changes. """ import operator +from dataclasses import dataclass from functools import reduce -from typing import List, Tuple +from typing import Dict, List, Union import torch import torch.distributed as dist +from torch.distributed import ProcessGroup + + +@dataclass +class ProcessGroupContainer: + process_group: ProcessGroup + ranks: List[int] # modified from alpa LogicalDeviceMesh(https://github.com/alpa-projects/alpa/blob/main/alpa/shard_parallel/auto_sharding.py) @@ -27,9 +35,11 @@ class DeviceMesh: during initializing the DeviceMesh instance if the init_process_group set to True. Otherwise, users need to call create_process_groups_for_logical_mesh manually to init logical process group. (default: False) - need_flatten(bool, optional): initialize flatten_device_mesh during initializing the DeviceMesh instance if the need_flatten set to True. + device (str): the device for the process groups used by the DeviceMesh instance. (default: 'cuda') """ + _DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo"} + def __init__(self, physical_mesh_id: torch.Tensor, mesh_shape: torch.Size = None, @@ -37,48 +47,140 @@ class DeviceMesh: mesh_alpha: List[float] = None, mesh_beta: List[float] = None, init_process_group: bool = False, - need_flatten: bool = True): - self.physical_mesh_id = physical_mesh_id + device: str = 'cuda'): + # ============================ + # Physical & Logical Mesh IDs + # ============================ + self._physical_mesh_id = physical_mesh_id + assert physical_mesh_id.dim() == 1, "physical_mesh_id should be a 1D tensor." + + # logical mesh ids can be obtained via two ways + # 1. provide physical mesh id and provide mesh shape + # 2. directly supply the logical mesh id + assert mesh_shape is None or logical_mesh_id is None, \ + "Only one of mesh_shape and logical_mesh_id can be specified." \ + "Logical mesh IDs are obtained from either mesh_shape + phyiscal_mesh_id or directly from the user-supplied logical_mesh_id" + if logical_mesh_id is None: self.mesh_shape = mesh_shape - self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape) + self._logical_mesh_id = self._physical_mesh_id.reshape(self.mesh_shape) else: self._logical_mesh_id = logical_mesh_id self.mesh_shape = self._logical_mesh_id.shape - # map global rank into logical rank - self.convert_map = {} - self._global_rank_to_logical_rank_map(self._logical_mesh_id, []) + # ensure two things: + # 1. logical and physical mesh IDs should contain the same elements + # 2. there is no duplicate IDs in each mesh, e.g. [2, 2] is not allowed + assert torch.equal(torch.unique(self._physical_mesh_id), torch.unique(self.logical_mesh_id)), \ + "physical and logical mesh IDs should contain the same elements, please check if you have consistent physical_mesh_id and logical_mesh_id." + assert torch.unique(self._physical_mesh_id).numel() == self._physical_mesh_id.numel(), \ + "Found duplicate IDs in the phyiscal_mesh_id and this is not allowed, please check your physical_mesh_id again." + assert torch.unique(self.logical_mesh_id).numel() == self.logical_mesh_id.numel(), \ + "Found duplicate IDs in the logical_mesh_id and this is not allowed, please check your logical_mesh_id again." + + # =============================================== # coefficient for alpha-beta communication model + # alpha is latency and beta is bandwidth + # =============================================== + # if the values are not provided, we assume they are 1 for simplicity if mesh_alpha is None: mesh_alpha = [1] * len(self.mesh_shape) if mesh_beta is None: mesh_beta = [1] * len(self.mesh_shape) + self.mesh_alpha = tuple(mesh_alpha) self.mesh_beta = tuple(mesh_beta) - self.init_process_group = init_process_group - self.need_flatten = need_flatten - if self.init_process_group: - self.process_groups_dict = self.create_process_groups_for_logical_mesh() - if self.need_flatten and self._logical_mesh_id.dim() > 1: - self.flatten_device_mesh = self.flatten() - # Create a new member `flatten_device_meshes` to distinguish from original flatten methods (Because I'm not sure if there are functions that rely on the self.flatten()) - # self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha, - # self.mesh_beta) + + # ensure the alpha and beta have the same shape + assert len(self.mesh_alpha) == len(self.mesh_beta), \ + "mesh_alpha and mesh_beta should have the same length, please check your mesh_alpha and mesh_beta again." + + # ========================= + # Device for Process Group + # ========================= + self._device = device + self._dist_backend = self._DIST_BACKEND[device] + + # ========================= + # Process Group Management + # ========================= + # the _global_to_local_rank_mapping is structured as follows + # { + # : [ , , , ...] + # } + self._global_to_local_rank_mapping = dict() + self._init_global_to_logical_rank_mapping(mapping=self._global_to_local_rank_mapping, + tensor=self.logical_mesh_id) + + # create process group + self._process_group_dict = {} + self._ranks_in_the_process_group = {} + self._global_rank_of_current_process = None + self._is_initialized = False + + # initialize process group if specified + self._init_ranks_in_the_same_group() + self._init_process_group = init_process_group + if init_process_group: + self.init_logical_process_group() @property - def shape(self): + def shape(self) -> torch.Size: + """ + Return the shape of the logical mesh. + """ return self.mesh_shape @property - def num_devices(self): - return reduce(operator.mul, self.physical_mesh_id.shape, 1) + def num_devices(self) -> int: + """ + Return the number of devices contained in the device mesh. + """ + return reduce(operator.mul, self._physical_mesh_id.shape, 1) @property - def logical_mesh_id(self): + def logical_mesh_id(self) -> torch.Tensor: + """ + Return the logical mesh id. + """ return self._logical_mesh_id - def __deepcopy__(self, memo): + def get_process_group(self, axis: int, global_rank: int = None) -> ProcessGroup: + """ + Return the process group on the specified axis. + + Args: + axis (int): the axis of the process group. + global_rank (int, optional): the global rank of the process group. If not specified, the current process is used. (default: None) + """ + if global_rank is None: + global_rank = self._global_rank_of_current_process + return self._process_group_dict[global_rank][axis] + + def get_process_group_for_all_axes(self, global_rank: int = None) -> Dict[int, ProcessGroup]: + """ + Return the process groups for all axes. + + Args: + global_rank (int, optional): the global rank of the process + """ + if global_rank is None: + global_rank = self._global_rank_of_current_process + return self._process_group_dict[global_rank] + + def get_ranks_in_process_group(self, axis: int, global_rank: int = None) -> List[int]: + """ + Return the ranks in the process group on the specified axis. + + Args: + axis (int): the axis of the process group. + global_rank (int, optional): the global rank of the process + """ + if global_rank is None: + global_rank = self._global_rank_of_current_process + return self._ranks_in_the_process_group[global_rank][axis] + + def __deepcopy__(self, memo) -> "DeviceMesh": cls = self.__class__ result = cls.__new__(cls) memo[id(self)] = result @@ -86,111 +188,206 @@ class DeviceMesh: if k != 'process_groups_dict': setattr(result, k, __import__("copy").deepcopy(v, memo)) else: + # process group cannot be copied + # thus, we share them directly setattr(result, k, v) - return result + def _init_global_to_logical_rank_mapping(self, + mapping: Dict, + tensor: torch.Tensor, + index_list: List[int] = []) -> Dict[int, List[int]]: + """ + Build a global rank to local rank mapping for each process group in different axis in the logical device mesh. + + Args: + mapping (Dict): a dictionary that maps the global rank to the local rank in the logical device mesh. + tensor (torch.Tensor): the tensor that contains the logical mesh ids. + index_list (List[int]) + + Returns: + mapping (Dict): a dictionary that maps the global rank to the local rank in the logical device mesh. + The value is a list of integers and each integer represents the local rank in the indexed axis. + """ + for index, inner_tensor in enumerate(tensor): + # index means the local rank in the current axis + # inner_tensor refers to the processes with the same local rank + + if inner_tensor.numel() == 1: + # if the inner_tensor only has one element, it means that + # it already reaches the last axis + # we append its local_rank in the last axis to the index_list + # and assign to the mapping + # the value of the mapping is the the local rank at the indexed axis of the device mesh + mapping[int(inner_tensor)] = index_list + [index] + else: + # we recursively go into the function until we reach the last axis + # meanwhile, we should add the local rank in the current axis in the index_list + self._init_global_to_logical_rank_mapping(mapping, inner_tensor, index_list + [index]) + + def init_logical_process_group(self): + ''' + This method is used to initialize the logical process groups which will be used in communications + among logical device mesh. + Note: if init_process_group set to False, you have to call this method manually. Otherwise, + the communication related function, such as ShapeConsistencyManager.apply will raise errors. + ''' + # sanity check + assert dist.is_initialized, "The torch.distributed should be initialized before calling init_logical_process_group" + assert not self._is_initialized, "The logical process group has been initialized, do not call init_logical_process_group twice" + + # update the global rank of the current process + self._global_rank_of_current_process = dist.get_rank() + duplicate_check_list = [] + + # flatten the global ranks to 1D list + global_rank_flatten_list = self._physical_mesh_id.view(-1).tolist() + + for global_rank in global_rank_flatten_list: + # find the other ranks which are in the same process group as global_rank + ranks_in_same_group_by_axis = self._collate_global_ranks_in_same_process_group(global_rank) + + for axis, ranks_in_same_group in ranks_in_same_group_by_axis.items(): + # skip duplicated process group creation + if ranks_in_same_group in duplicate_check_list: + continue + + # create the process group + pg_handler = dist.new_group(ranks=ranks_in_same_group, backend=self._dist_backend) + + # keep this process group in the process_groups_dict + for rank in ranks_in_same_group: + if rank not in self._process_group_dict: + self._process_group_dict[rank] = dict() + self._process_group_dict[rank][axis] = pg_handler + + # update the init flag + # we only allow init for once + self._is_initialized = True + + def _init_ranks_in_the_same_group(self): + """ + This method is used to initialize the ranks_in_the_same_group dictionary. + """ + # flatten the global ranks to 1D list + global_rank_flatten_list = self._physical_mesh_id.view(-1).tolist() + + for global_rank in global_rank_flatten_list: + # find the other ranks which are in the same process group as global_rank + ranks_in_same_group_by_axis = self._collate_global_ranks_in_same_process_group(global_rank) + + for axis, ranks_in_same_group in ranks_in_same_group_by_axis.items(): + # create dict for each rank + if global_rank not in self._process_group_dict: + self._ranks_in_the_process_group[global_rank] = dict() + + # keep this process group in the process_groups_dict + self._ranks_in_the_process_group[global_rank][axis] = ranks_in_same_group + + def global_rank_to_local_rank(self, rank: int, axis: int = None) -> Union[List[int], int]: + """ + Return the local rank of the given global rank in the logical device mesh. + + Args: + rank (int): the global rank in the logical device mesh. + axis (int): the axis of the logical device mesh. + """ + local_ranks = self._global_to_local_rank_mapping[rank] + if axis: + return local_ranks[axis] + else: + return local_ranks + + def _collate_global_ranks_in_same_process_group(self, global_rank): + ''' + Give a global rank and return all global ranks involved in its associated process group in each axis. + + Example: + + ```python + sphysical_mesh_id = torch.arange(0, 16) + mesh_shape = (4, 4) + + # logical mesh will look like + # [[0, 1, 2, 3], + # [4, 5, 6, 7], + # [8, 9, 10,11], + # [12,13,14,15]] + + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + print(device_mesh.collate_global_ranks_in_same_process_group(0)) + + # key is axis name + # value is a list of global ranks in same axis with rank 0 + # output will look like + # { + 0: [0, 4, 8, 12], + 1: [0, 1, 2, 3] + # } + ''' + # We have init the global rank to local rank by calling _init_global_to_logical_rank_mapping + # for self._global_to_local_rank_mapping + # the key is the global rank + # the value is the list of local ranks corresponding to the global rank with respect of different axes + # we can see the list of local ranks as the process coordinates for simplicity + # the key and value are all unique, therefore, + # we can also to use the coordinates to find the global rank + + # ========================================================================= + # Step 1 + # find all the process_coordinates for processes in the same process group + # as the given global rank + # ========================================================================= + + # each + processes_in_the_same_process_group = {} + + for dim in range(self.logical_mesh_id.dim()): + # iterate over the dimension size so that we can include all processes + # in the same process group in the given axis + # the _local_rank refers to the local rank of the current process + for _local_rank in range(self.logical_mesh_id.shape[dim]): + + # if this dimension is not initailized yet, + # initialize it with an empty array + if dim not in processes_in_the_same_process_group: + processes_in_the_same_process_group[dim] = [] + + # get the local rank corresponding to the global rank + process_coordinates = self._global_to_local_rank_mapping[global_rank].copy() + + # replace the local rank in the given dimension with the + # lcoal rank of the current process iterated + process_coordinates[dim] = _local_rank + processes_in_the_same_process_group[dim].append(process_coordinates) + + # ================================================================= + # Step 2 + # Use local rank combination to find its corresponding global rank + # ================================================================= + # the key of the dict is the axis + # the value is the list of global ranks which are in the same process group as the given global rank + global_pg_ranks = {} + for dim, coordinates_of_all_processes in processes_in_the_same_process_group.items(): + global_pg_ranks[dim] = [] + for process_coordinates in coordinates_of_all_processes: + # find the global rank by local rank combination + for _global_rank, _process_coordinates in self._global_to_local_rank_mapping.items(): + if process_coordinates == _process_coordinates: + global_pg_ranks[dim].append(_global_rank) + return global_pg_ranks + def flatten(self): """ Flatten the logical mesh into an effective 1d logical mesh, """ flatten_mesh_shape_size = len(self.mesh_shape) flatten_mesh_shape = [self.num_devices] - return DeviceMesh(self.physical_mesh_id, + return DeviceMesh(self._physical_mesh_id, tuple(flatten_mesh_shape), mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1), mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1), - init_process_group=self.init_process_group, - need_flatten=False) - - def _global_rank_to_logical_rank_map(self, tensor, index_list): - ''' - This method is a helper function to build convert_map recursively. - ''' - for index, inner_tensor in enumerate(tensor): - if inner_tensor.numel() == 1: - self.convert_map[int(inner_tensor)] = index_list + [index] - else: - self._global_rank_to_logical_rank_map(inner_tensor, index_list + [index]) - - def create_process_groups_for_logical_mesh(self): - ''' - This method is used to initialize the logical process groups which will be used in communications - among logical device mesh. - Note: if init_process_group set to False, you have to call this method manually. Otherwise, - the communication related function, such as ShapeConsistencyManager.apply will raise errors. - ''' - process_groups_dict = {} - check_duplicate_list = [] - global_rank_flatten_list = self.physical_mesh_id.view(-1).tolist() - for global_rank in global_rank_flatten_list: - process_groups = self.global_rank_to_process_groups_with_global_rank(global_rank) - for axis, process_group in process_groups.items(): - if axis not in process_groups_dict: - process_groups_dict[axis] = [] - if process_group not in check_duplicate_list: - check_duplicate_list.append(process_group) - process_group_handler = dist.new_group(process_group) - process_groups_dict[axis].append((process_group, process_group_handler)) - - return process_groups_dict - - def global_rank_to_logical_rank(self, rank): - return self.convert_map[rank] - - def global_rank_to_process_groups_with_logical_rank(self, rank): - ''' - Give a global rank and return all logical process groups of this rank. - for example: - physical_mesh_id = torch.arange(0, 16).reshape(2, 8) - mesh_shape = (4, 4) - # [[0, 1, 2, 3], - # [4, 5, 6, 7], - # [8, 9, 10,11], - # [12,13,14,15]] - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - print(device_mesh.global_rank_to_process_groups_with_logical_rank(0)) - output: - # key is axis name - # value is a list of logical ranks in same axis with rank 0 - {0: [[0, 0], [1, 0], [2, 0], [3, 0]], 1: [[0, 0], [0, 1], [0, 2], [0, 3]]} - ''' - process_groups = {} - for d in range(self.logical_mesh_id.dim()): - for replacer in range(self.logical_mesh_id.shape[d]): - if d not in process_groups: - process_groups[d] = [] - process_group_member = self.convert_map[rank].copy() - process_group_member[d] = replacer - process_groups[d].append(process_group_member) - return process_groups - - def global_rank_to_process_groups_with_global_rank(self, rank): - ''' - Give a global rank and return all process groups of this rank. - for example: - physical_mesh_id = torch.arange(0, 16).reshape(2, 8) - mesh_shape = (4, 4) - # [[0, 1, 2, 3], - # [4, 5, 6, 7], - # [8, 9, 10,11], - # [12,13,14,15]] - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - print(device_mesh.global_rank_to_process_groups_with_global_rank(0)) - output: - # key is axis name - # value is a list of global ranks in same axis with rank 0 - {0: [0, 4, 8, 12], 1: [0, 1, 2, 3]} - ''' - logical_process_groups = self.global_rank_to_process_groups_with_logical_rank(rank) - process_groups = {} - for dim, logical_ranks in logical_process_groups.items(): - process_groups[dim] = [] - for logical_rank in logical_ranks: - for g_rank, l_rank in self.convert_map.items(): - if l_rank == logical_rank: - process_groups[dim].append(g_rank) - return process_groups + init_process_group=self._init_process_group) def all_gather_cost(self, num_bytes, mesh_dim): num_devices = self.logical_mesh_id.shape[mesh_dim] @@ -212,38 +409,3 @@ class DeviceMesh: penalty_factor = num_devices / 2.0 return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001) - - -class FlattenDeviceMesh(DeviceMesh): - - def __init__(self, physical_mesh_id, mesh_shape, mesh_alpha=None, mesh_beta=None): - super().__init__(physical_mesh_id, - mesh_shape, - mesh_alpha, - mesh_beta, - init_process_group=False, - need_flatten=False) - # Different from flatten(), mesh_shape leaves unchanged, mesh_alpha and mesh_beta are scalars - self.mesh_alpha = max(self.mesh_alpha) - self.mesh_beta = min(self.mesh_beta) - # Different from original process_groups_dict, rank_list is not stored - self.process_number_dict = self.create_process_numbers_for_logical_mesh() - - def create_process_numbers_for_logical_mesh(self): - ''' - Build 1d DeviceMesh in column-major(0) and row-major(1) - for example: - mesh_shape = (2,4) - # [[0, 1, 2, 3], - # [4, 5, 6, 7]] - # return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]} - ''' - num_devices = reduce(operator.mul, self.mesh_shape, 1) - process_numbers_dict = {} - process_numbers_dict[0] = torch.arange(num_devices).reshape(self.mesh_shape).transpose(1, 0).flatten().tolist() - process_numbers_dict[1] = torch.arange(num_devices).reshape(self.mesh_shape).flatten().tolist() - return process_numbers_dict - - def mix_gather_cost(self, num_bytes): - num_devices = reduce(operator.mul, self.mesh_shape, 1) - return (self.mesh_alpha + self.mesh_beta * (num_devices - 1) / num_devices * num_bytes + 0.1) diff --git a/colossalai/lazy/lazy_init.py b/colossalai/lazy/lazy_init.py index 76f550dc4..ca8914362 100644 --- a/colossalai/lazy/lazy_init.py +++ b/colossalai/lazy/lazy_init.py @@ -1,5 +1,5 @@ from types import MethodType -from typing import Callable, Optional, Union +from typing import Callable, Dict, Optional, Union import torch import torch.distributed as dist @@ -8,8 +8,9 @@ from torch import Tensor from torch.utils._pytree import tree_map from colossalai._analyzer._subclasses import MetaTensor +from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.d_tensor.d_tensor import DTensor -from colossalai.tensor.d_tensor.layout import Layout +from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec # reference: https://pytorch.org/cppdocs/notes/tensor_creation.html _NORMAL_FACTORY = [ @@ -172,7 +173,7 @@ class LazyTensor(torch.Tensor): self.clean() return _convert_cls(self, target) - def distribute(self, layout: Layout) -> torch.Tensor: + def distribute(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor: """Distribute the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace), according to the layout. Args: @@ -183,7 +184,7 @@ class LazyTensor(torch.Tensor): """ target = self._materialize_data() self.clean() - local_tensor = DTensor(target, layout).local_tensor + local_tensor = DTensor(target, device_mesh, sharding_spec).local_tensor return _convert_cls(self, local_tensor) def clean(self) -> None: @@ -536,7 +537,10 @@ class LazyInitContext: return _apply_to_lazy_module(module, apply_fn, verbose) @staticmethod - def distribute(module: nn.Module, layout_dict: dict, verbose: bool = False) -> nn.Module: + def distribute(module: nn.Module, + device_mesh: DeviceMesh, + sharding_spec_dict: Dict[str, ShardingSpec], + verbose: bool = False) -> nn.Module: """Distribute all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place. Args: @@ -546,7 +550,7 @@ class LazyInitContext: """ def apply_fn(name: str, p: LazyTensor): - p.distribute(layout_dict[name]) + p.distribute(device_mesh, sharding_spec_dict[name]) return _apply_to_lazy_module(module, apply_fn, verbose) diff --git a/colossalai/tensor/comm_spec.py b/colossalai/tensor/comm_spec.py index af38d2a50..dd873c852 100644 --- a/colossalai/tensor/comm_spec.py +++ b/colossalai/tensor/comm_spec.py @@ -16,69 +16,66 @@ def _all_gather(tensor, comm_spec): ''' Implement all gather operation on device mesh based on information provided by comm_spec. ''' - process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, process_group in process_groups_list: - if dist.get_rank() in rank_list: - tensor_list = [ - torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) - for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]) - ] - # without this contiguous operation, the all gather may get some unexpected results. - tensor = tensor.contiguous() - dist.all_gather(tensor_list, tensor, group=process_group) - output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous() - return output + process_groups = comm_spec.device_mesh.get_process_group_for_all_axes() + process_group = process_groups[comm_spec.logical_process_axis] + + tensor_list = [ + torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) + for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]) + ] + # without this contiguous operation, the all gather may get some unexpected results. + tensor = tensor.contiguous() + dist.all_gather(tensor_list, tensor, group=process_group) + output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous() + return output def _split(tensor, comm_spec): ''' Implement shard operation on device mesh based on information provided by comm_spec. ''' - process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, _ in process_groups_list: - if dist.get_rank() in rank_list: - dim = comm_spec.shard_dim - length = tensor.shape[comm_spec.shard_dim] // len(rank_list) - start = length * rank_list.index(dist.get_rank()) - output = torch.narrow(tensor, dim, start, length).contiguous() - return output + process_groups = comm_spec.device_mesh.get_process_group_for_all_axes() + process_group = process_groups[comm_spec.logical_process_axis] + + dim = comm_spec.shard_dim + length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group) + start = length * dist.get_rank(process_group) + output = torch.narrow(tensor, dim, start, length).contiguous() + return output def _all_to_all(tensor, comm_spec): ''' Implement all to all operation on device mesh based on information provided by comm_spec. ''' - process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, process_group in process_groups_list: - if dist.get_rank() in rank_list: - new_shape = list(tensor.shape) - new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list) - new_shape = torch.Size(new_shape) - output_tensor_list = [ - torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list)) - ] - dim = comm_spec.shard_dim - length = tensor.shape[comm_spec.shard_dim] // len(rank_list) - input_tensor_list = [ - torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list)) - ] - group = process_group - dist.all_to_all(output_tensor_list, input_tensor_list, group) - output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous() - return output + process_groups = comm_spec.device_mesh.get_process_group_for_all_axes() + process_group = process_groups[comm_spec.logical_process_axis] + world_size = dist.get_world_size(process_group) + + new_shape = list(tensor.shape) + new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // world_size + new_shape = torch.Size(new_shape) + output_tensor_list = [torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)] + dim = comm_spec.shard_dim + length = tensor.shape[comm_spec.shard_dim] // world_size + input_tensor_list = [torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(world_size)] + group = process_group + dist.all_to_all(output_tensor_list, input_tensor_list, group) + output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous() + return output def _all_reduce(tensor, comm_spec, async_op=False): ''' Implement all reduce operation on device mesh based on information provided by comm_spec. ''' - process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, process_group in process_groups_list: - if dist.get_rank() in rank_list: - if not tensor.is_contiguous(): - tensor = tensor.contiguous() - dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op) - return tensor + process_groups = comm_spec.device_mesh.get_process_group_for_all_axes() + process_group = process_groups[comm_spec.logical_process_axis] + + if not tensor.is_contiguous(): + tensor = tensor.contiguous() + dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op) + return tensor def _mix_gather(tensor, comm_spec): @@ -414,7 +411,7 @@ class CommSpec: self.forward_only = forward_only if isinstance(self.logical_process_axis, list): if not mix_gather: - self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh + self.device_mesh = self.sharding_spec.device_mesh.flatten() self.logical_process_axis = 0 else: self.device_meshes = self.sharding_spec.device_mesh.flatten_device_meshes diff --git a/colossalai/tensor/d_tensor/RAEDME.md b/colossalai/tensor/d_tensor/RAEDME.md new file mode 100644 index 000000000..95d866388 --- /dev/null +++ b/colossalai/tensor/d_tensor/RAEDME.md @@ -0,0 +1,103 @@ +# 🔢 Distributed Tensor + +## 📚 Table of Contents + +- [🔢 Distributed Tensor](#-distributed-tensor) + - [📚 Table of Contents](#-table-of-contents) + - [🔗 Introduction](#-introduction) + - [📝 Design](#-design) + - [🔨 Usage](#-usage) + - [🎈 Progress Log](#-progress-log) + +## 🔗 Introduction + +Distributed tensor is a type of tensor that is distributed across multiple devices. It is a wrapper of PyTorch tensor, and it is used to support distributed training. +It can represent the device topology and tensor placement over the devices in the topology. It also provides a set of APIs to manipulate the distributed tensor. + +## 📝 Design + +Our implementation is inspired by the work [Alpa](https://arxiv.org/abs/2201.12023), which unifies data parallelism and tensor parallelism as intra-op parallelism. It uses notations `S` to represent the sharded dimension and `R` to represent the replicated dimension. For example, given a 2D matrix, `[S, R]` represents the tensor is sharded over the first dimension. + +Each sharded dimension will have a subscript to represent its placement over the devices. Assuming we have 4 GPUs and the GPUs are arranged in a 2 x 2 manner. Let's say we have a 2D matrix like below: + + +```text + [1, 2, 3, 4 ] +A = [4, 5, 6, 7 ] + [8, 9, 10, 11] + [12, 13, 14, 15] +``` + +`[S0, R]` would mean that the first dimension is sharded over the rows in the device topology. + +```text +| --------------------—————————————————————-| +| | | +| [1, 2, 3, 4 ] | [1, 2, 3, 4 ] | +| [4, 5, 6, 7 ] | [4, 5, 6, 7 ] | +| | | +| --------------------——————————————————----- +| | | +| [8, 9, 10, 11] | [8, 9, 10, 11] | +| [12, 13, 14, 15] | [12, 13, 14, 15] | +| | | +| --------------------——————————————————----- +``` + +`[S01, R]` would mean that the first dimension is sharded over both the row and column in the device topology. + +```text +| --------------------—————————————————————-| +| | | +| [1, 2, 3, 4 ] | [4, 5, 6, 7 ] | +| | | +| --------------------——————————————————----- +| | | +| [8, 9, 10, 11] | [12, 13, 14, 15] | +| | | +| --------------------——————————————————----- +``` + +## 🔨 Usage + +A sample API usage is given below. + +```python +import torch + +import colossalai +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.d_tensor import DTensor, ShardingSpec + +colossalai.launch_from_torch(config={}) + +# define your device mesh +# assume you have 4 GPUs +physical_mesh_id = torch.arange(0, 4).reshape(1, 4) +mesh_shape = (2, 2) +device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + +# define a tensor +a = torch.rand(16, 32).cuda() + +# create sharding spec for the tensor +# assume the sharding spec is [S0, R] +dim_partition_dict = {0: [0]} +sharding_spec = ShardingSpec(a.dim(), dim_partition_dict) + +# create a distributed tensor +d_tensor = DTensor(a, device_mesh, sharding_spec) +print(d_tensor) + +global_tensor = d_tensor.to_global() +print(global_tensor) +``` + + +## 🎈 Progress Log + +- [x] Support layout conversion +- [x] Support sharding on 2D device mesh +- [ ] Support sharding on 3D device mesh +- [ ] Support sharding 4D device mesh +- [ ] Support sharding info saving and offline tensor merge (we can save tensor as dtensor and gather the tensors back to the global tensor based on the sharding info in a single process in CPU, useful for distributed training checkpoint load and save.) diff --git a/colossalai/tensor/d_tensor/__init__.py b/colossalai/tensor/d_tensor/__init__.py index e69de29bb..af77f4f0e 100644 --- a/colossalai/tensor/d_tensor/__init__.py +++ b/colossalai/tensor/d_tensor/__init__.py @@ -0,0 +1,4 @@ +from .d_tensor import DTensor +from .sharding_spec import ShardingSpec + +__all__ = ['DTensor', 'ShardingSpec'] diff --git a/colossalai/tensor/d_tensor/comm_spec.py b/colossalai/tensor/d_tensor/comm_spec.py index 159125fa1..79b2e3ef9 100644 --- a/colossalai/tensor/d_tensor/comm_spec.py +++ b/colossalai/tensor/d_tensor/comm_spec.py @@ -24,12 +24,12 @@ class CommSpec: ''' Communication spec is used to record the communication action. It converts the communication spec to real action which will be used in runtime. It contains comm_pattern to determine the - communication method, process_groups_dict to determine the process groups, gather_dim and shard_dim + communication method, process_group_dict to determine the process groups, gather_dim and shard_dim to determine the buffer shape, and logical_process_axis Argument: - comm_pattern(CollectiveCommPattern): describe the communication method used in this spec. - process_groups_dict(Dict): A dict which contains the process groups used to apply this CommSpec. + comm_pattern(CollectiveCommPattern): decribe the communication method used in this spec. + process_group_dict(Dict): A dict which contains the process groups used to apply this CommSpec. gather_dim(int, Optional): The gather_dim of the tensor will be gathered. shard_dim(int, Optional): The shard_dim of the tensor will be sharded. logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action. @@ -37,7 +37,7 @@ class CommSpec: def __init__(self, comm_pattern: CollectiveCommPattern, - process_groups_dict: Dict, + process_group_dict: Dict, gather_dim: int = None, shard_dim: int = None, logical_process_axis: int = None): @@ -45,7 +45,7 @@ class CommSpec: self.gather_dim = gather_dim self.shard_dim = shard_dim self.logical_process_axis = logical_process_axis - self.process_groups_dict = process_groups_dict + self.process_group_dict = process_group_dict def __repr__(self): res_list = ["CommSpec:("] @@ -92,68 +92,56 @@ def _all_gather(tensor: torch.Tensor, comm_spec: CommSpec): ''' Implement all gather operation on device mesh based on information provided by comm_spec. ''' - process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, process_group in process_groups_list: - if dist.get_rank() in rank_list: - tensor_list = [ - torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list)) - ] - # without this contiguous operation, the all gather may get some unexpected results. - tensor = tensor.contiguous() - dist.all_gather(tensor_list, tensor, group=process_group) - output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous() - return output + process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis] + world_size = dist.get_world_size(process_group) + tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)] + # without this contiguous operation, the all gather may get some unexpected results. + tensor = tensor.contiguous() + dist.all_gather(tensor_list, tensor, group=process_group) + output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous() + return output def _split(tensor: torch.Tensor, comm_spec: CommSpec): ''' Implement shard operation on device mesh based on information provided by comm_spec. ''' - process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, _ in process_groups_list: - if dist.get_rank() in rank_list: - dim = comm_spec.shard_dim - length = tensor.shape[comm_spec.shard_dim] // len(rank_list) - start = length * rank_list.index(dist.get_rank()) - output = torch.narrow(tensor, dim, start, length).contiguous() - return output + process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis] + dim = comm_spec.shard_dim + length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group) + start = length * dist.get_rank(process_group) + output = torch.narrow(tensor, dim, start, length).contiguous() + return output def _all_to_all(tensor: torch.Tensor, comm_spec: CommSpec): ''' Implement all to all operation on device mesh based on information provided by comm_spec. ''' - process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, process_group in process_groups_list: - if dist.get_rank() in rank_list: - new_shape = list(tensor.shape) - new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list) - new_shape = torch.Size(new_shape) - output_tensor_list = [ - torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list)) - ] - dim = comm_spec.shard_dim - length = tensor.shape[comm_spec.shard_dim] // len(rank_list) - input_tensor_list = [ - torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list)) - ] - group = process_group - dist.all_to_all(output_tensor_list, input_tensor_list, group) - output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous() - return output + process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis] + world_size = dist.get_world_size(process_group) + new_shape = list(tensor.shape) + new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // world_size + new_shape = torch.Size(new_shape) + output_tensor_list = [torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)] + dim = comm_spec.shard_dim + length = tensor.shape[comm_spec.shard_dim] // world_size + input_tensor_list = [torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(world_size)] + group = process_group + dist.all_to_all(output_tensor_list, input_tensor_list, group) + output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous() + return output def _all_reduce(tensor: torch.Tensor, comm_spec: CommSpec, async_op: bool = False): ''' Implement all reduce operation on device mesh based on information provided by comm_spec. ''' - process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, process_group in process_groups_list: - if dist.get_rank() in rank_list: - if not tensor.is_contiguous(): - tensor = tensor.contiguous() - dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op) - return tensor + process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis] + if not tensor.is_contiguous(): + tensor = tensor.contiguous() + dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op) + return tensor class _ReduceGrad(torch.autograd.Function): @@ -269,7 +257,7 @@ class _AllToAll(torch.autograd.Function): def forward(ctx, input_, comm_spec): output = _all_to_all(input_, comm_spec) comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern, - process_groups_dict=comm_spec.process_groups_dict, + process_group_dict=comm_spec.process_group_dict, gather_dim=comm_spec.shard_dim, shard_dim=comm_spec.gather_dim, logical_process_axis=comm_spec.logical_process_axis) diff --git a/colossalai/tensor/d_tensor/d_tensor.py b/colossalai/tensor/d_tensor/d_tensor.py index c1fe9d50a..6bda0f4e5 100644 --- a/colossalai/tensor/d_tensor/d_tensor.py +++ b/colossalai/tensor/d_tensor/d_tensor.py @@ -3,55 +3,119 @@ from typing import Optional import torch from torch.utils._pytree import tree_map +from colossalai.device.device_mesh import DeviceMesh + from .layout import Layout from .layout_converter import LayoutConverter, to_global from .sharding_spec import ShardingSpec +__all__ = ['DTensor', 'distribute_tensor', 'distribute_module', 'construct_default_sharding_spec'] + layout_converter = LayoutConverter() class DTensor(torch.Tensor): + """ + DTensor stands for distributed tensor. It is a subclass of `torch.Tensor` and contains meta information + about the tensor distribution. The meta information includes the device mesh, the sharding specification, + and the entire shape of the tensor. - def __init__(self, local_tensor: torch.Tensor, dist_layout: Layout): - self.local_tensor = local_tensor - self.data_type = local_tensor.dtype - self.entire_shape = local_tensor.shape + During runtime, we will not directly use the DTensor objects for computation. Instead, we will only use the + `DTensor.local_tensor` for computation. The `DTensor.local_tensor` is the local tensor in the current rank. + In this way, all tensors involved in computation will only be native PyTorch tensors. + + Example: + ```python + from colossalai.device import DeviceMesh + + # define your device mesh + # assume you have 4 GPUs + physical_mesh_id = torch.arange(0, 4).reshape(1, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + + # define a tensor + x = torch.rand(16, 32) + + # create sharding spec for the tensor + # assume the sharding spec is [S, R] + dim_partition_dict = { + 0: 1 + } + sharding_spec = ShardingSpec(a.dim(), dim_partition_dict) + + # create a distributed tensor + d_tensor = DTensor(x, device_mesh, sharding_spec) + ``` + + Args: + tensor (`torch.Tensor`): the unsharded tensor. + device_mesh (`DeviceMesh`): the device mesh for abstraction of the compute devices. + sharding_spec (`ShardingSpec`): the sharding specification which describes how the tensor will be sharded. + """ + + def __init__(self, tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec): + # ensure this tensor is not a DTensor + assert not isinstance(tensor, DTensor), 'The input tensor should not be a DTensor.' + + # store meta info + self.local_tensor = tensor + self.data_type = tensor.dtype + self.global_shape = tensor.shape + + # create distributed layout + dist_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=self.global_shape) self.dist_layout = dist_layout + + # shard the tensor self._apply_layout() @staticmethod - def __new__(cls, local_tensor, layout): - return torch.Tensor._make_subclass(cls, local_tensor, local_tensor.requires_grad) + def __new__(cls, tensor, *args, **kwargs): + return torch.Tensor._make_subclass(cls, tensor, tensor.requires_grad) def __repr__(self): - return f"DTensor({self.to_global()}, {self.dist_layout})" + return f"DTensor(\n{self.to_global()}\n{self.dist_layout}" def __str__(self): return self.__repr__() - def layout_convert(self, target_layout): + def layout_convert(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> None: ''' Convert the layout of the tensor from source_spec to target_spec. + This will update the `local_tensor` and `dist_layout` in place. + + Args: + target_layout (Layout): the target layout specification. ''' - self.local_tensor = layout_converter.apply(self.local_tensor, self.dist_layout, target_layout) + target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=self.global_shape) + self.local_tensor = layout_converter.apply(tensor=self.local_tensor, + source_layout=self.dist_layout, + target_layout=target_layout) self.dist_layout = target_layout def _apply_layout(self): ''' Apply the layout to the local tensor during initializing process. ''' + # layout converter requires a source and target laytout + # we construct the source layer for an unsharded tensor + # and use self.dist_layer as the targer layout for the sharded tensor source_spec = construct_default_sharding_spec(self.local_tensor) source_layout = Layout(device_mesh=self.dist_layout.device_mesh, - device_type=self.dist_layout.device_type, sharding_spec=source_spec, - entire_shape=self.entire_shape) - self.local_tensor = layout_converter.apply(self.local_tensor, source_layout, self.dist_layout) + global_shape=self.global_shape) + self.local_tensor = layout_converter.apply(tensor=self.local_tensor, + source_layout=source_layout, + target_layout=self.dist_layout) @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} + # convert all DTensors to native pytorch tensors + # so that operations will be conducted on native tensors def filter_arg(arg): if isinstance(arg, DTensor): return arg.local_tensor @@ -60,9 +124,9 @@ class DTensor(torch.Tensor): args = tree_map(filter_arg, args) kwargs = tree_map(filter_arg, kwargs) - # if we want to convert the result into DTensor, we need to infer the layout of result from the layout of input tensors - # and op type. + # NOTE: if we want to convert the result into DTensor, we need to infer the layout of result from the layout of input tensors + # and op type. return func(*args, **kwargs) @property @@ -85,7 +149,6 @@ class DTensor(torch.Tensor): ''' self.local_tensor = self.local_tensor.to(*args, **kwargs) self.data_type = self.local_tensor.dtype - self.dist_layout.device_type = self.local_tensor.device # TODO: update the device mesh process groups or we should just cache # both the cpu process groups and the cuda process groups? return self @@ -98,7 +161,7 @@ class DTensor(torch.Tensor): def to_global(self): ''' - Recover the global tensor from the distributed tensor. + Recover the global tensor from the distributed tensor by returning a new `torch.Tensor` object. Note: This function will all_gather the local tensor to the global tensor and it will not change the layout of the DTensor. This function is mainly used for debugging or @@ -107,24 +170,29 @@ class DTensor(torch.Tensor): return to_global(self.local_tensor, self.dist_layout) -def distribute_tensor(local_tensor: torch.Tensor, dist_layout: Layout) -> DTensor: +def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> DTensor: ''' Distribute the local tensor to the distributed tensor according to the dist_layout specified. Args: - local_tensor: tensor to be distributed. - dist_layout: the layout specification of the distributed tensor. + tensor (`torch.Tensor`): tensor to be distributed. + device_mesh (`DeviceMesh`): the device mesh for abstraction of the compute devices. + sharding_spec (`ShardingSpec`): the sharding specification which describes how the tensor will be sharded. Returns: A 'DTensor' object. ''' - return DTensor(local_tensor, dist_layout) + return DTensor(tensor, device_mesh, sharding_spec) def distribute_module(module: torch.nn.Module, partition_fn: Optional[callable] = None) -> torch.nn.Module: ''' This function converts all the parameters in the module to DTensor(DParam). + Args: + module (`torch.nn.Module`): the module to be distributed. + partition_fn (callable): the partition function which will be used to partition the parameters. + Note: This function is subject to future change as the DParam has not been implemented yet. ''' for name, param in module.named_parameters(): @@ -138,5 +206,11 @@ def distribute_module(module: torch.nn.Module, partition_fn: Optional[callable] def construct_default_sharding_spec(tensor: torch.Tensor,) -> ShardingSpec: ''' Construct the default sharding specification for the tensor. + + Args: + tensor (`torch.Tensor`): the tensor to be sharded. + + Returns: + A `ShardingSpec` object without any sharding specified. ''' return ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={}) diff --git a/colossalai/tensor/d_tensor/layout.py b/colossalai/tensor/d_tensor/layout.py index ee7ef74a9..2946611b4 100644 --- a/colossalai/tensor/d_tensor/layout.py +++ b/colossalai/tensor/d_tensor/layout.py @@ -11,28 +11,32 @@ from .sharding_spec import ShardingSpec class Layout: - """Layout of a tensor. + """ + Layout of a tensor refers to the tensor placement on the device mesh and how the tensor is sharded over the devices. - Attributes: - device_mesh: the device mesh to store the tensor distributed. - device_type: the type of the device mesh, e.g. 'cpu' or 'cuda'. - sharding_spec: the sharding specification to describe how the tensor is sharded. - entire_shape: the entire shape of the global tensor. + Args: + device_mesh (`DeviceMesh`): the device mesh to store the tensor distributed. + sharding_spec (`ShardingSpec`): the sharding specification to describe how the tensor is sharded. + global_shape (`torch.Size`): the entire shape of the global tensor. """ - def __init__(self, device_mesh: DeviceMesh, device_type: torch.device, sharding_spec: ShardingSpec, - entire_shape: torch.Size): + def __init__(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size): self.device_mesh = device_mesh - self.device_type = device_type self.sharding_spec = sharding_spec - self.entire_shape = entire_shape + self.global_shape = global_shape self._sanity_check() def __hash__(self) -> int: return hash(f'{self.sharding_spec}') - def get_sharded_shape_per_device(self): - sharded_shape = list(self.entire_shape) + def get_sharded_shape_per_device(self) -> torch.Size: + """ + Compute the shape of the sharded tensor on each device. + + Returns: + `torch.Size`: the shape of the sharded tensor on each device. + """ + sharded_shape = list(self.global_shape) for dim, shard_list in self.sharding_spec.dim_partition_dict.items(): mesh_list = [self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list] shard_partitions = reduce(operator.mul, mesh_list, 1) @@ -56,7 +60,7 @@ class Layout: # make sure that the sharding for a dimension is divisible by the number of devices for dim, shard_list in sharding_spec.dim_partition_dict.items(): - tensor_dim_size = self.entire_shape[dim] + tensor_dim_size = self.global_shape[dim] num_devices = 1 for element in shard_list: diff --git a/colossalai/tensor/d_tensor/layout_converter.py b/colossalai/tensor/d_tensor/layout_converter.py index cf02aac30..6eff92ea6 100644 --- a/colossalai/tensor/d_tensor/layout_converter.py +++ b/colossalai/tensor/d_tensor/layout_converter.py @@ -3,10 +3,8 @@ from copy import deepcopy from dataclasses import dataclass from typing import Dict, List, Tuple -import numpy as np import torch -from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem from colossalai.context.singleton_meta import SingletonMeta from colossalai.tensor.d_tensor.comm_spec import * from colossalai.tensor.d_tensor.layout import Layout @@ -28,13 +26,21 @@ class LayoutConverterOptions: pass -def to_global(distributed_tensor: torch.Tensor, layout: Layout) -> torch.Tensor: +def to_global(distributed_tensor: "DTensor", layout: Layout) -> torch.Tensor: + """ + Convert a distributed tensor to the global tensor with the given layout. + This function returns a native `torch.Tensor` object. + + + Args: + distributed_tensor (`DTensor`): the distributed tensor to be converted. + layout (`Layout`): the target layout specification. + """ layout_converter = LayoutConverter() global_sharding_spec = ShardingSpec(distributed_tensor.dim(), {}) global_layout = Layout(device_mesh=layout.device_mesh, - device_type=layout.device_type, sharding_spec=global_sharding_spec, - entire_shape=layout.entire_shape) + global_shape=layout.global_shape) with torch.no_grad(): global_tensor = layout_converter.apply(distributed_tensor, layout, global_layout) return global_tensor @@ -49,6 +55,9 @@ def set_layout_converting_options(options: LayoutConverterOptions): class LayoutConverter(metaclass=SingletonMeta): + """ + LayoutConverter is a singleton class which converts the layout of a distributed tensor. + """ def __init__(self): self._options = None @@ -91,15 +100,14 @@ class LayoutConverter(metaclass=SingletonMeta): # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - entire_shape = (4, 4, 4) + global_shape = (4, 4, 4) dim_partition_dict = {0: [0], 1: [1]} # [S0,S1,R] sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict) layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec, - entire_shape=entire_shape) + global_shape=global_shape) rst_dict = layout_converter.all_gather_transform_layouts(layout) for layout, comm_spec in rst_dict.items(): @@ -112,7 +120,12 @@ class LayoutConverter(metaclass=SingletonMeta): valid_spec_dict = {} comm_pattern = CollectiveCommPattern.GATHER_FWD_SPLIT_BWD source_spec = source_layout.sharding_spec - process_groups_dict = source_layout.device_mesh.process_groups_dict + + # the key of the dict is the axis + # the value is the process group + current_rank = source_layout.device_mesh._global_rank_of_current_process + process_group_dict = source_layout.device_mesh._process_group_dict[current_rank] + for target_pair in source_spec.dim_partition_dict.items(): shard_list = all_gather_simulator(target_pair) index = target_pair[0] @@ -130,7 +143,7 @@ class LayoutConverter(metaclass=SingletonMeta): logical_process_axis = target_pair[1][-1] comm_spec = CommSpec( comm_pattern, - process_groups_dict=process_groups_dict, + process_group_dict=process_group_dict, gather_dim=gather_dim, # shard_dim will be used during backward shard_dim=gather_dim, @@ -141,8 +154,7 @@ class LayoutConverter(metaclass=SingletonMeta): new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict) new_layout = Layout(device_mesh=source_layout.device_mesh, sharding_spec=new_sharding_spec, - device_type=source_layout.device_type, - entire_shape=source_layout.entire_shape) + global_shape=source_layout.global_shape) valid_spec_dict[new_layout] = comm_spec except LayoutException: @@ -167,15 +179,14 @@ class LayoutConverter(metaclass=SingletonMeta): # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - entire_shape = (4, 4, 4) + global_shape = (4, 4, 4) dim_partition_dict = {0: [0], 1: [1]} # [S0,S1,R] sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict) layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec, - entire_shape=entire_shape) + global_shape=global_shape) rst_dict = layout_converter.all_to_all_transform_layout(layout) for layout, comm_spec in rst_dict.items(): @@ -188,7 +199,12 @@ class LayoutConverter(metaclass=SingletonMeta): ''' valid_spec_dict = {} comm_pattern = CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD - process_groups_dict = source_layout.device_mesh.process_groups_dict + + # the key of the dict is the axis + # the value is the process group + current_rank = source_layout.device_mesh._global_rank_of_current_process + process_group_dict = source_layout.device_mesh._process_group_dict[current_rank] + source_spec = source_layout.sharding_spec tensor_dims = source_spec.dims for f_index in range(tensor_dims - 1): @@ -229,7 +245,7 @@ class LayoutConverter(metaclass=SingletonMeta): shard_dim = f_index logical_process_axis = b_target_pair[1][-1] comm_spec = CommSpec(comm_pattern, - process_groups_dict, + process_group_dict=process_group_dict, gather_dim=gather_dim, shard_dim=shard_dim, logical_process_axis=logical_process_axis) @@ -252,8 +268,7 @@ class LayoutConverter(metaclass=SingletonMeta): new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict) new_layout = Layout(device_mesh=source_layout.device_mesh, sharding_spec=new_sharding_spec, - device_type=source_layout.device_type, - entire_shape=source_layout.entire_shape) + global_shape=source_layout.global_shape) valid_spec_dict[new_layout] = comm_spec except LayoutException: pass @@ -278,16 +293,15 @@ class LayoutConverter(metaclass=SingletonMeta): # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - entire_shape = (4, 4, 4) + global_shape = (4, 4, 4) dim_partition_dict = {0: [0]} # [S0,R,R] sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict) layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec, - entire_shape=entire_shape) + global_shape=global_shape) rst_dict = layout_converter.shard_transform_layout(layout) for layout, comm_spec in rst_dict.items(): @@ -301,7 +315,11 @@ class LayoutConverter(metaclass=SingletonMeta): valid_spec_dict = {} comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD source_spec = source_layout.sharding_spec - process_groups_dict = source_layout.device_mesh.process_groups_dict + + # the key of the dict is the axis + # the value is the process group + current_rank = source_layout.device_mesh._global_rank_of_current_process + process_group_dict = source_layout.device_mesh._process_group_dict[current_rank] # legal sharding dims means the mesh_id is still available to use. legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.mesh_shape))] @@ -329,7 +347,7 @@ class LayoutConverter(metaclass=SingletonMeta): shard_dim = index logical_process_axis = shard_list[-1] comm_spec = CommSpec(comm_pattern, - process_groups_dict, + process_group_dict=process_group_dict, gather_dim=shard_dim, shard_dim=shard_dim, logical_process_axis=logical_process_axis) @@ -340,8 +358,7 @@ class LayoutConverter(metaclass=SingletonMeta): dim_partition_dict=new_dim_partition_dict) new_layout = Layout(device_mesh=source_layout.device_mesh, sharding_spec=new_sharding_spec, - device_type=source_layout.device_type, - entire_shape=source_layout.entire_shape) + global_shape=source_layout.global_shape) valid_spec_dict[new_layout] = comm_spec except LayoutException: pass @@ -399,7 +416,7 @@ class LayoutConverter(metaclass=SingletonMeta): # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - entire_shape = (4, 4, 4) + global_shape = (4, 4, 4) dim_partition_source = {1: [0, 1]} dim_partition_target = {0: [0, 1]} @@ -407,16 +424,14 @@ class LayoutConverter(metaclass=SingletonMeta): # [R,S01,R] sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source) source_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec_source, - entire_shape=entire_shape) + global_shape=global_shape) # [S01,R,R] sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target) target_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec_target, - entire_shape=entire_shape) + global_shape=global_shape) transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout) transform_path_str = '->'.join([str(layout.sharding_spec.sharding_sequence) for layout in transform_path]) @@ -505,21 +520,19 @@ class LayoutConverter(metaclass=SingletonMeta): # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - entire_shape = (4, 4, 4) + global_shape = (4, 4, 4) # [S0,R,R] sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source) source_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec_source, - entire_shape=entire_shape) + global_shape=global_shape) # [R,S0,R] sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target) target_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec_target, - entire_shape=entire_shape) + global_shape=global_shape) if rank in (0, 1): sharded_tensor_0 = torch.zeros(2, 1) @@ -554,3 +567,4 @@ class LayoutConverter(metaclass=SingletonMeta): for comm_spec in comm_action_sequence: tensor = comm_spec.covert_spec_to_action(tensor) return tensor + return tensor diff --git a/tests/test_device/test_device_mesh.py b/tests/test_device/test_device_mesh.py index 3be057b3a..19d41d233 100644 --- a/tests/test_device/test_device_mesh.py +++ b/tests/test_device/test_device_mesh.py @@ -1,20 +1,19 @@ -from colossalai.device.device_mesh import DeviceMesh import torch +from colossalai.device.device_mesh import DeviceMesh + def test_device_mesh(): - physical_mesh_id = torch.arange(0, 16).reshape(2, 8) + physical_mesh_id = torch.arange(0, 16) mesh_shape = (4, 4) # [[0, 1, 2, 3], # [4, 5, 6, 7], # [8, 9, 10,11], # [12,13,14,15]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - assert device_mesh.convert_map[5] == [1, 1] - assert device_mesh.convert_map[11] == [2, 3] - assert device_mesh.global_rank_to_process_groups_with_logical_rank(0)[0] == [[0, 0], [1, 0], [2, 0], [3, 0]] - assert device_mesh.global_rank_to_process_groups_with_logical_rank(2)[1] == [[0, 0], [0, 1], [0, 2], [0, 3]] - assert device_mesh.global_rank_to_process_groups_with_global_rank(2)[1] == [0, 1, 2, 3] + assert device_mesh.global_rank_to_local_rank(5) == [1, 1] + assert device_mesh.global_rank_to_local_rank(11) == [2, 3] + assert device_mesh.get_ranks_in_process_group(axis=1, global_rank=2) == [0, 1, 2, 3] if __name__ == '__main__': diff --git a/tests/test_device/test_init_logical_pg.py b/tests/test_device/test_init_logical_pg.py index 2b7060c48..7c6339eff 100644 --- a/tests/test_device/test_init_logical_pg.py +++ b/tests/test_device/test_init_logical_pg.py @@ -20,16 +20,12 @@ def check_layer(rank, world_size, port): # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - logical_pg_dict = {0: [[0, 2], [1, 3]], 1: [[0, 1], [2, 3]]} - logical_process_groups = device_mesh.process_groups_dict - for mesh_dim, pgs in logical_pg_dict.items(): - for index, pg in enumerate(pgs): - if rank in pg: - tensor = torch.ones(4).cuda() - group = logical_process_groups[mesh_dim][index][1] - dist.all_reduce(tensor, op=ReduceOp.SUM, group=group) - assert tensor.equal(tensor_to_check) + for axis in range(len(mesh_shape)): + tensor = torch.ones(4).cuda() + pg = device_mesh.get_process_group(axis=axis) + dist.all_reduce(tensor, op=ReduceOp.SUM, group=pg) + assert tensor.equal(tensor_to_check) gpc.destroy() diff --git a/tests/test_lazy/lazy_init_utils.py b/tests/test_lazy/lazy_init_utils.py index 85bfd0e27..2911012fa 100644 --- a/tests/test_lazy/lazy_init_utils.py +++ b/tests/test_lazy/lazy_init_utils.py @@ -6,7 +6,9 @@ import numpy as np import torch from packaging import version +from colossalai.device.device_mesh import DeviceMesh from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor +from colossalai.tensor.d_tensor.layout import Layout from colossalai.tensor.d_tensor.layout_converter import to_global from tests.kit.model_zoo.registry import ModelAttribute @@ -81,7 +83,8 @@ def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False, print(f'{model.__class__.__name__} pass') -def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, layout_dict: dict) -> None: +def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, device_mesh: DeviceMesh, + sharding_spec_dict: dict) -> None: state = model.state_dict() distributed_state = distributed_model.state_dict() @@ -91,6 +94,7 @@ def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn. assert n1 == n2 t1 = t1.cuda() t2 = t2.cuda() - if n2 in layout_dict: - t2 = to_global(t2, layout_dict[n2]) + if n2 in sharding_spec_dict: + layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_dict[n2], global_shape=t1.shape) + t2 = to_global(t2, layout) assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}' diff --git a/tests/test_lazy/test_distribute.py b/tests/test_lazy/test_distribute.py index d515b175a..efa43eab5 100644 --- a/tests/test_lazy/test_distribute.py +++ b/tests/test_lazy/test_distribute.py @@ -26,23 +26,19 @@ def find_shard_dim(shape: torch.Size) -> Optional[int]: return dim -def make_layout(device_mesh: DeviceMesh, original_tensor: torch.Tensor) -> Layout: +def make_sharding_spec(original_tensor: torch.Tensor) -> Layout: shard_dim = find_shard_dim(original_tensor.shape) dim_partition_dict = {shard_dim: [0]} if shard_dim is not None else {} target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict=dim_partition_dict) - layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=target_sharding_spec, - entire_shape=original_tensor.shape) - return layout + return target_sharding_spec def _get_current_name(prefix: str, name: str) -> str: return f'{prefix}.{name}'.lstrip('.') -def generate_layout_dict(model: nn.Module, device_mesh: DeviceMesh) -> dict: - layout_dict = {} +def generate_sharding_spec_dict(model: nn.Module) -> dict: + sharding_spec_dict = {} @torch.no_grad() def generate_recursively(module: nn.Module, prefix: str = ''): @@ -53,17 +49,17 @@ def generate_layout_dict(model: nn.Module, device_mesh: DeviceMesh) -> dict: # initialize tensors directly attached to the current module for name, param in module.named_parameters(recurse=False): if isinstance(param, LazyTensor): - layout = make_layout(device_mesh, param) - layout_dict[_get_current_name(prefix, name)] = layout + sharding_spec = make_sharding_spec(param) + sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec for name, buf in module.named_buffers(recurse=False): if isinstance(buf, LazyTensor): - layout = make_layout(device_mesh, buf) - layout_dict[_get_current_name(prefix, name)] = layout + sharding_spec = make_sharding_spec(buf) + sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec generate_recursively(model) - return layout_dict + return sharding_spec_dict @parameterize('subset', ['torchvision', 'diffusers', 'timm', 'transformers', 'torchaudio', 'deepfm', 'dlrm']) @@ -85,9 +81,9 @@ def run_dist_lazy_init(subset, seed: int = 42): ctx = LazyInitContext() with ctx: deferred_model = model_fn() - layout_dict = generate_layout_dict(deferred_model, device_mesh) - ctx.distribute(deferred_model, layout_dict, verbose=True) - assert_dist_model_equal(model, deferred_model, layout_dict) + sharding_spec_dict = generate_sharding_spec_dict(deferred_model) + ctx.distribute(deferred_model, device_mesh, sharding_spec_dict, verbose=True) + assert_dist_model_equal(model, deferred_model, device_mesh, sharding_spec_dict) def run_dist(rank, world_size, port) -> None: diff --git a/tests/test_tensor/test_dtensor/test_comm_spec.py b/tests/test_tensor/test_dtensor/test_comm_spec.py index d1f5b9299..0797e01e7 100644 --- a/tests/test_tensor/test_dtensor/test_comm_spec.py +++ b/tests/test_tensor/test_dtensor/test_comm_spec.py @@ -125,23 +125,6 @@ def check_all_reduce_bwd(process_groups_dict, rank): assert tensor_to_comm.equal(tensor_to_check) -def check_all_reduce_in_flatten_device_mesh(process_groups_dict, rank): - # tensor to comm - tensor_to_comm = torch.ones(2, 2).cuda() * rank - - # reduce through logical process axis 0 at flatten device mesh - # tensor to check - # tensor([[6., 6.], - # [6., 6.]]) - tensor_to_check = torch.tensor([[6, 6], [6, 6]], dtype=tensor_to_comm.dtype).cuda() - - # CommSpec:(comm_pattern:all_reduce, logical_process_axis:[0, 1]) - comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, process_groups_dict, logical_process_axis=0) - tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) - - assert tensor_to_comm.equal(tensor_to_check) - - def check_comm(rank, world_size, port): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') @@ -153,24 +136,22 @@ def check_comm(rank, world_size, port): # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - process_groups_dict = device_mesh.process_groups_dict + + process_group_dict = device_mesh._process_group_dict[rank] # test all gather - check_all_gather(process_groups_dict, rank) + check_all_gather(process_group_dict, rank) # test shard - check_shard(process_groups_dict, rank) + check_shard(process_group_dict, rank) # test all to all - check_all_to_all(process_groups_dict, rank) + check_all_to_all(process_group_dict, rank) # test all reduce - check_all_reduce_fwd(process_groups_dict, rank) - check_all_reduce_bwd(process_groups_dict, rank) + check_all_reduce_fwd(process_group_dict, rank) + check_all_reduce_bwd(process_group_dict, rank) - flatten_process_groups_dict = device_mesh.flatten_device_mesh.process_groups_dict - # test all reduce in 1D flatten device mesh - check_all_reduce_in_flatten_device_mesh(flatten_process_groups_dict, rank) gpc.destroy() diff --git a/tests/test_tensor/test_dtensor/test_dtensor.py b/tests/test_tensor/test_dtensor/test_dtensor.py index 3ca369acb..50a3bfb15 100644 --- a/tests/test_tensor/test_dtensor/test_dtensor.py +++ b/tests/test_tensor/test_dtensor/test_dtensor.py @@ -31,13 +31,9 @@ def check_dtensor(rank, world_size, port): device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True) target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]}) - layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=target_sharding_spec, - entire_shape=original_tensor.shape) - d_tensor = DTensor(original_tensor, layout) + d_tensor = DTensor(original_tensor, device_mesh, target_sharding_spec) - assert d_tensor.entire_shape == original_tensor.shape + assert d_tensor.global_shape == original_tensor.shape assert d_tensor.data_type == original_tensor.dtype if rank in (0, 1): @@ -57,12 +53,7 @@ def check_dtensor(rank, world_size, port): raise ValueError(f'rank {rank} is not in the device mesh') new_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0, 1]}) - new_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=new_sharding_spec, - entire_shape=original_tensor.shape) - - d_tensor.layout_convert(new_layout) + d_tensor.layout_convert(device_mesh, new_sharding_spec) if rank == 0: assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 0, 1)) @@ -75,7 +66,7 @@ def check_dtensor(rank, world_size, port): else: raise ValueError(f'rank {rank} is not in the device mesh') - dtensor_from_local = distribute_tensor(original_tensor, new_layout) + dtensor_from_local = distribute_tensor(original_tensor, device_mesh, new_sharding_spec) if rank == 0: assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 0, 1)) diff --git a/tests/test_tensor/test_dtensor/test_layout_converter.py b/tests/test_tensor/test_dtensor/test_layout_converter.py index 5f56decb5..6608e4787 100644 --- a/tests/test_tensor/test_dtensor/test_layout_converter.py +++ b/tests/test_tensor/test_dtensor/test_layout_converter.py @@ -12,9 +12,9 @@ from colossalai.tensor.d_tensor.layout_converter import LayoutConverter from colossalai.tensor.d_tensor.sharding_spec import DimSpec, ShardingSpec from colossalai.testing import rerun_if_address_is_in_use, spawn -entire_shape = torch.Size((64, 32, 16)) +global_shape = torch.Size((64, 32, 16)) layout_converter = LayoutConverter() -physical_mesh_id = torch.arange(0, 4).reshape(2, 2) +physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) @@ -30,10 +30,7 @@ def check_one_step_transform(rank, world_size, port): # shard_sequence: S0,S1,R # device_mesh_shape: (2, 2) sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict) - layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec, - entire_shape=entire_shape) + layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape) rst_dict = layout_converter.all_gather_transform_layouts(layout) @@ -49,10 +46,7 @@ def check_one_step_transform(rank, world_size, port): # shard_sequence: S0,S1,R # device_mesh_shape: (4, 4) sharding_spec_all2all = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict_all2all) - layout_all2all = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec_all2all, - entire_shape=entire_shape) + layout_all2all = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_all2all, global_shape=global_shape) rst_dict_all2all = layout_converter.all_to_all_transform_layout(layout_all2all) @@ -71,10 +65,7 @@ def check_one_step_transform(rank, world_size, port): # shard_sequence: S0,R,R # device_mesh_shape: (4, 4) sharding_spec_shard = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_shard) - shard_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec_shard, - entire_shape=entire_shape) + shard_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_shard, global_shape=global_shape) rst_dict_shard = layout_converter.shard_transform_layout(shard_layout) @@ -100,19 +91,13 @@ def check_layout_converting(rank, world_size, port): # shard_sequence: R,S01,R # device_mesh_shape: (4, 4) sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source) - source_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec_source, - entire_shape=entire_shape) + source_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_source, global_shape=global_shape) # DistSpec: # shard_sequence: S01,R,R # device_mesh_shape: (4, 4) sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target) - target_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec_target, - entire_shape=entire_shape) + target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_target, global_shape=global_shape) transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout) @@ -159,21 +144,15 @@ def check_layout_converting_apply(rank, world_size, port): # shard_sequence: R,S01,R # device_mesh_shape: (4, 4) sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source) - source_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec_source, - entire_shape=entire_shape) + source_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_source, global_shape=global_shape) # DistSpec: # shard_sequence: S01,R,R # device_mesh_shape: (4, 4) sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target) - target_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec_target, - entire_shape=entire_shape) + target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_target, global_shape=global_shape) - original_tensor = torch.rand(entire_shape).cuda() + original_tensor = torch.rand(global_shape).cuda() # tensor_to_apply: [R, S01, R] tensor_to_apply = original_tensor.narrow(1, rank * 8, 8) diff --git a/tests/test_tensor/test_shape_consistency.py b/tests/test_tensor/test_shape_consistency.py index 6fe9ee292..859eef051 100644 --- a/tests/test_tensor/test_shape_consistency.py +++ b/tests/test_tensor/test_shape_consistency.py @@ -1,9 +1,10 @@ -from colossalai.tensor.shape_consistency import ShapeConsistencyManager, CollectiveCommPattern import torch -from colossalai.tensor.sharding_spec import _DimSpec, ShardingSpec -from colossalai.device.device_mesh import DeviceMesh -physical_mesh_id = torch.arange(0, 16).reshape(2, 8) +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.shape_consistency import CollectiveCommPattern, ShapeConsistencyManager +from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec + +physical_mesh_id = torch.arange(0, 16) mesh_shape = (4, 4) # [[0, 1, 2, 3], # [4, 5, 6, 7], diff --git a/tests/test_tensor/test_sharded_linear.py b/tests/test_tensor/test_sharded_linear.py index d66d4fec1..9bd9805e9 100644 --- a/tests/test_tensor/test_sharded_linear.py +++ b/tests/test_tensor/test_sharded_linear.py @@ -26,7 +26,7 @@ def run_dist(rank, world_size, port): # the mesh is in the following topo # [[0, 1], # [2, 3]] - physical_mesh_id = torch.arange(0, 4).reshape(2, 2) + physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) row_id = rank // 2 diff --git a/tests/test_tensor/test_sharding_spec.py b/tests/test_tensor/test_sharding_spec.py index 909c84ef0..5007c4141 100644 --- a/tests/test_tensor/test_sharding_spec.py +++ b/tests/test_tensor/test_sharding_spec.py @@ -5,7 +5,7 @@ from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec def test_sharding_spec(): - physical_mesh_id = torch.arange(0, 16).reshape(2, 8) + physical_mesh_id = torch.arange(0, 16) mesh_shape = (4, 4) # [[0, 1, 2, 3], # [4, 5, 6, 7],