From ddcf58cacf9581d9c59a18f8276d52a061818fab Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Fri, 9 Jun 2023 09:41:27 +0800 Subject: [PATCH] Revert "[sync] sync feature/shardformer with develop" --- colossalai/device/README.md | 73 -- colossalai/device/device_mesh.py | 444 +++---- colossalai/lazy/lazy_init.py | 16 +- colossalai/nn/layer/parallel_1d/_operation.py | 1 - colossalai/shardformer/README.md | 296 ----- colossalai/shardformer/__init__.py | 0 colossalai/shardformer/layer/__init__.py | 0 colossalai/shardformer/layer/_operation.py | 97 -- .../shardformer/layer/dist_crossentropy.py | 105 -- colossalai/shardformer/layer/dropout.py | 58 - colossalai/shardformer/layer/layers.py | 1043 ----------------- colossalai/shardformer/model/__init__.py | 0 colossalai/shardformer/model/modeling_bert.py | 67 -- colossalai/shardformer/policies/__init__.py | 0 colossalai/shardformer/policies/autopolicy.py | 58 - colossalai/shardformer/policies/basepolicy.py | 217 ---- colossalai/shardformer/policies/bert.py | 170 --- colossalai/shardformer/policies/gpt2.py | 118 -- colossalai/shardformer/shard/__init__.py | 5 - colossalai/shardformer/shard/shard_config.py | 20 - colossalai/shardformer/shard/sharder.py | 266 ----- colossalai/shardformer/shard/slicer.py | 161 --- colossalai/shardformer/test/config.py | 1 - colossalai/shardformer/test/module_test.py | 50 - colossalai/shardformer/test/test.py | 124 -- colossalai/shardformer/utils/__init__.py | 0 colossalai/shardformer/utils/utils.py | 58 - colossalai/tensor/comm_spec.py | 89 +- colossalai/tensor/d_tensor/RAEDME.md | 103 -- colossalai/tensor/d_tensor/__init__.py | 4 - colossalai/tensor/d_tensor/comm_spec.py | 88 +- colossalai/tensor/d_tensor/d_tensor.py | 114 +- colossalai/tensor/d_tensor/layout.py | 30 +- .../tensor/d_tensor/layout_converter.py | 86 +- colossalai/tensor/d_tensor/sharding_spec.py | 31 +- docs/sidebars.json | 1 - docs/source/en/features/lazy_init.md | 71 -- docs/source/zh-Hans/features/lazy_init.md | 71 -- tests/test_device/test_device_mesh.py | 13 +- tests/test_device/test_init_logical_pg.py | 16 +- tests/test_lazy/lazy_init_utils.py | 10 +- tests/test_lazy/test_distribute.py | 28 +- .../test_dtensor/test_comm_spec.py | 33 +- .../test_tensor/test_dtensor/test_dtensor.py | 17 +- .../test_dtensor/test_layout_converter.py | 41 +- tests/test_tensor/test_shape_consistency.py | 7 +- tests/test_tensor/test_sharded_linear.py | 2 +- tests/test_tensor/test_sharding_spec.py | 2 +- 48 files changed, 437 insertions(+), 3868 deletions(-) delete mode 100644 colossalai/device/README.md delete mode 100644 colossalai/shardformer/README.md delete mode 100644 colossalai/shardformer/__init__.py delete mode 100644 colossalai/shardformer/layer/__init__.py delete mode 100644 colossalai/shardformer/layer/_operation.py delete mode 100644 colossalai/shardformer/layer/dist_crossentropy.py delete mode 100644 colossalai/shardformer/layer/dropout.py delete mode 100644 colossalai/shardformer/layer/layers.py delete mode 100644 colossalai/shardformer/model/__init__.py delete mode 100644 colossalai/shardformer/model/modeling_bert.py delete mode 100644 colossalai/shardformer/policies/__init__.py delete mode 100644 colossalai/shardformer/policies/autopolicy.py delete mode 100644 colossalai/shardformer/policies/basepolicy.py delete mode 100644 colossalai/shardformer/policies/bert.py delete mode 100644 colossalai/shardformer/policies/gpt2.py delete mode 100644 colossalai/shardformer/shard/__init__.py delete mode 100644 colossalai/shardformer/shard/shard_config.py delete mode 
100644 colossalai/shardformer/shard/sharder.py delete mode 100644 colossalai/shardformer/shard/slicer.py delete mode 100644 colossalai/shardformer/test/config.py delete mode 100644 colossalai/shardformer/test/module_test.py delete mode 100644 colossalai/shardformer/test/test.py delete mode 100644 colossalai/shardformer/utils/__init__.py delete mode 100644 colossalai/shardformer/utils/utils.py delete mode 100644 colossalai/tensor/d_tensor/RAEDME.md delete mode 100644 docs/source/en/features/lazy_init.md delete mode 100644 docs/source/zh-Hans/features/lazy_init.md diff --git a/colossalai/device/README.md b/colossalai/device/README.md deleted file mode 100644 index 8f835735b..000000000 --- a/colossalai/device/README.md +++ /dev/null @@ -1,73 +0,0 @@ -# 🗄 Device - -## 📚 Table of Contents - -- [🗄 Device](#-device) - - [📚 Table of Contents](#-table-of-contents) - - [🔗 Introduction](#-introduction) - - [📝 Design](#-design) - - [🔨 Usage](#-usage) - -## 🔗 Introduction - -This module contains the implementation of the abstraction of the device topology. It is used to represent the device topology and manage the distributed information related to the network. - -## 📝 Design - - -This module is inspired by the DeviceMesh in the [Alpa project](https://github.com/alpa-projects/alpa) and the device array can be represented as a 1D or 2D mesh. We will be extending the device mesh to support 3D mesh in the future. - - -## 🔨 Usage - -- Create a device mesh - -```python -# this is the list of global ranks involved in the device mesh -# assume we have 4 GPUs and the global ranks for these GPUs are 0, 1, 2, 3 -physical_mesh_id = torch.arange(4) -mesh_shape = [2, 2] -device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) -``` - -- View the mesh - - -```python -# view the mesh shape -# expect output -# [2, 2] -print(device_mesh.shape) - - -# view the logical mesh with global ranks -# expect output -# [ -# [0, 1], -# [2, 3] -# ] -print(device_mesh.logical_mesh_id) - -# view the number of devices in the mesh -# expect output -# 4 -print(device_mesh.num_devices) - -``` - -- Initialize the process group - -```python -# intialize process group -device_mesh.init_logical_process_group() - - -# get the process group for a rank with respect to an axis -# this is the process group involving global ranks 0 and 2 -print(device_mesh.get_process_group(axis=0, global_rank=0)) - -# get the ranks in the process with respect to an axis -# expect output -# [0, 2] -print(device_mesh.get_ranks_in_process_group(axis=0, global_rank=0)) -``` diff --git a/colossalai/device/device_mesh.py b/colossalai/device/device_mesh.py index 0490a4401..2a5f747fb 100644 --- a/colossalai/device/device_mesh.py +++ b/colossalai/device/device_mesh.py @@ -3,19 +3,11 @@ with some changes. """ import operator -from dataclasses import dataclass from functools import reduce -from typing import Dict, List, Union +from typing import List, Tuple import torch import torch.distributed as dist -from torch.distributed import ProcessGroup - - -@dataclass -class ProcessGroupContainer: - process_group: ProcessGroup - ranks: List[int] # modified from alpa LogicalDeviceMesh(https://github.com/alpa-projects/alpa/blob/main/alpa/shard_parallel/auto_sharding.py) @@ -35,11 +27,9 @@ class DeviceMesh: during initializing the DeviceMesh instance if the init_process_group set to True. Otherwise, users need to call create_process_groups_for_logical_mesh manually to init logical process group. 
(default: False) - device (str): the device for the process groups used by the DeviceMesh instance. (default: 'cuda') + need_flatten(bool, optional): initialize flatten_device_mesh during initializing the DeviceMesh instance if the need_flatten set to True. """ - _DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo"} - def __init__(self, physical_mesh_id: torch.Tensor, mesh_shape: torch.Size = None, @@ -47,140 +37,48 @@ class DeviceMesh: mesh_alpha: List[float] = None, mesh_beta: List[float] = None, init_process_group: bool = False, - device: str = 'cuda'): - # ============================ - # Physical & Logical Mesh IDs - # ============================ - self._physical_mesh_id = physical_mesh_id - assert physical_mesh_id.dim() == 1, "physical_mesh_id should be a 1D tensor." - - # logical mesh ids can be obtained via two ways - # 1. provide physical mesh id and provide mesh shape - # 2. directly supply the logical mesh id - assert mesh_shape is None or logical_mesh_id is None, \ - "Only one of mesh_shape and logical_mesh_id can be specified." \ - "Logical mesh IDs are obtained from either mesh_shape + phyiscal_mesh_id or directly from the user-supplied logical_mesh_id" - + need_flatten: bool = True): + self.physical_mesh_id = physical_mesh_id if logical_mesh_id is None: self.mesh_shape = mesh_shape - self._logical_mesh_id = self._physical_mesh_id.reshape(self.mesh_shape) + self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape) else: self._logical_mesh_id = logical_mesh_id self.mesh_shape = self._logical_mesh_id.shape - # ensure two things: - # 1. logical and physical mesh IDs should contain the same elements - # 2. there is no duplicate IDs in each mesh, e.g. [2, 2] is not allowed - assert torch.equal(torch.unique(self._physical_mesh_id), torch.unique(self.logical_mesh_id)), \ - "physical and logical mesh IDs should contain the same elements, please check if you have consistent physical_mesh_id and logical_mesh_id." - assert torch.unique(self._physical_mesh_id).numel() == self._physical_mesh_id.numel(), \ - "Found duplicate IDs in the phyiscal_mesh_id and this is not allowed, please check your physical_mesh_id again." - assert torch.unique(self.logical_mesh_id).numel() == self.logical_mesh_id.numel(), \ - "Found duplicate IDs in the logical_mesh_id and this is not allowed, please check your logical_mesh_id again." - - # =============================================== + # map global rank into logical rank + self.convert_map = {} + self._global_rank_to_logical_rank_map(self._logical_mesh_id, []) # coefficient for alpha-beta communication model - # alpha is latency and beta is bandwidth - # =============================================== - # if the values are not provided, we assume they are 1 for simplicity if mesh_alpha is None: mesh_alpha = [1] * len(self.mesh_shape) if mesh_beta is None: mesh_beta = [1] * len(self.mesh_shape) - self.mesh_alpha = tuple(mesh_alpha) self.mesh_beta = tuple(mesh_beta) - - # ensure the alpha and beta have the same shape - assert len(self.mesh_alpha) == len(self.mesh_beta), \ - "mesh_alpha and mesh_beta should have the same length, please check your mesh_alpha and mesh_beta again." - - # ========================= - # Device for Process Group - # ========================= - self._device = device - self._dist_backend = self._DIST_BACKEND[device] - - # ========================= - # Process Group Management - # ========================= - # the _global_to_local_rank_mapping is structured as follows - # { - # : [ , , , ...] 
- # } - self._global_to_local_rank_mapping = dict() - self._init_global_to_logical_rank_mapping(mapping=self._global_to_local_rank_mapping, - tensor=self.logical_mesh_id) - - # create process group - self._process_group_dict = {} - self._ranks_in_the_process_group = {} - self._global_rank_of_current_process = None - self._is_initialized = False - - # initialize process group if specified - self._init_ranks_in_the_same_group() - self._init_process_group = init_process_group - if init_process_group: - self.init_logical_process_group() + self.init_process_group = init_process_group + self.need_flatten = need_flatten + if self.init_process_group: + self.process_groups_dict = self.create_process_groups_for_logical_mesh() + if self.need_flatten and self._logical_mesh_id.dim() > 1: + self.flatten_device_mesh = self.flatten() + # Create a new member `flatten_device_meshes` to distinguish from original flatten methods (Because I'm not sure if there are functions that rely on the self.flatten()) + # self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha, + # self.mesh_beta) @property - def shape(self) -> torch.Size: - """ - Return the shape of the logical mesh. - """ + def shape(self): return self.mesh_shape @property - def num_devices(self) -> int: - """ - Return the number of devices contained in the device mesh. - """ - return reduce(operator.mul, self._physical_mesh_id.shape, 1) + def num_devices(self): + return reduce(operator.mul, self.physical_mesh_id.shape, 1) @property - def logical_mesh_id(self) -> torch.Tensor: - """ - Return the logical mesh id. - """ + def logical_mesh_id(self): return self._logical_mesh_id - def get_process_group(self, axis: int, global_rank: int = None) -> ProcessGroup: - """ - Return the process group on the specified axis. - - Args: - axis (int): the axis of the process group. - global_rank (int, optional): the global rank of the process group. If not specified, the current process is used. (default: None) - """ - if global_rank is None: - global_rank = self._global_rank_of_current_process - return self._process_group_dict[global_rank][axis] - - def get_process_group_for_all_axes(self, global_rank: int = None) -> Dict[int, ProcessGroup]: - """ - Return the process groups for all axes. - - Args: - global_rank (int, optional): the global rank of the process - """ - if global_rank is None: - global_rank = self._global_rank_of_current_process - return self._process_group_dict[global_rank] - - def get_ranks_in_process_group(self, axis: int, global_rank: int = None) -> List[int]: - """ - Return the ranks in the process group on the specified axis. - - Args: - axis (int): the axis of the process group. 
- global_rank (int, optional): the global rank of the process - """ - if global_rank is None: - global_rank = self._global_rank_of_current_process - return self._ranks_in_the_process_group[global_rank][axis] - - def __deepcopy__(self, memo) -> "DeviceMesh": + def __deepcopy__(self, memo): cls = self.__class__ result = cls.__new__(cls) memo[id(self)] = result @@ -188,206 +86,111 @@ class DeviceMesh: if k != 'process_groups_dict': setattr(result, k, __import__("copy").deepcopy(v, memo)) else: - # process group cannot be copied - # thus, we share them directly setattr(result, k, v) + return result - def _init_global_to_logical_rank_mapping(self, - mapping: Dict, - tensor: torch.Tensor, - index_list: List[int] = []) -> Dict[int, List[int]]: + def flatten(self): """ - Build a global rank to local rank mapping for each process group in different axis in the logical device mesh. - - Args: - mapping (Dict): a dictionary that maps the global rank to the local rank in the logical device mesh. - tensor (torch.Tensor): the tensor that contains the logical mesh ids. - index_list (List[int]) - - Returns: - mapping (Dict): a dictionary that maps the global rank to the local rank in the logical device mesh. - The value is a list of integers and each integer represents the local rank in the indexed axis. + Flatten the logical mesh into an effective 1d logical mesh, """ - for index, inner_tensor in enumerate(tensor): - # index means the local rank in the current axis - # inner_tensor refers to the processes with the same local rank + flatten_mesh_shape_size = len(self.mesh_shape) + flatten_mesh_shape = [self.num_devices] + return DeviceMesh(self.physical_mesh_id, + tuple(flatten_mesh_shape), + mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1), + mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1), + init_process_group=self.init_process_group, + need_flatten=False) + def _global_rank_to_logical_rank_map(self, tensor, index_list): + ''' + This method is a helper function to build convert_map recursively. + ''' + for index, inner_tensor in enumerate(tensor): if inner_tensor.numel() == 1: - # if the inner_tensor only has one element, it means that - # it already reaches the last axis - # we append its local_rank in the last axis to the index_list - # and assign to the mapping - # the value of the mapping is the the local rank at the indexed axis of the device mesh - mapping[int(inner_tensor)] = index_list + [index] + self.convert_map[int(inner_tensor)] = index_list + [index] else: - # we recursively go into the function until we reach the last axis - # meanwhile, we should add the local rank in the current axis in the index_list - self._init_global_to_logical_rank_mapping(mapping, inner_tensor, index_list + [index]) + self._global_rank_to_logical_rank_map(inner_tensor, index_list + [index]) - def init_logical_process_group(self): + def create_process_groups_for_logical_mesh(self): ''' This method is used to initialize the logical process groups which will be used in communications among logical device mesh. Note: if init_process_group set to False, you have to call this method manually. Otherwise, the communication related function, such as ShapeConsistencyManager.apply will raise errors. 
''' - # sanity check - assert dist.is_initialized, "The torch.distributed should be initialized before calling init_logical_process_group" - assert not self._is_initialized, "The logical process group has been initialized, do not call init_logical_process_group twice" - - # update the global rank of the current process - self._global_rank_of_current_process = dist.get_rank() - duplicate_check_list = [] - - # flatten the global ranks to 1D list - global_rank_flatten_list = self._physical_mesh_id.view(-1).tolist() - - for global_rank in global_rank_flatten_list: - # find the other ranks which are in the same process group as global_rank - ranks_in_same_group_by_axis = self._collate_global_ranks_in_same_process_group(global_rank) - - for axis, ranks_in_same_group in ranks_in_same_group_by_axis.items(): - # skip duplicated process group creation - if ranks_in_same_group in duplicate_check_list: - continue - - # create the process group - pg_handler = dist.new_group(ranks=ranks_in_same_group, backend=self._dist_backend) - - # keep this process group in the process_groups_dict - for rank in ranks_in_same_group: - if rank not in self._process_group_dict: - self._process_group_dict[rank] = dict() - self._process_group_dict[rank][axis] = pg_handler - - # update the init flag - # we only allow init for once - self._is_initialized = True - - def _init_ranks_in_the_same_group(self): - """ - This method is used to initialize the ranks_in_the_same_group dictionary. - """ - # flatten the global ranks to 1D list - global_rank_flatten_list = self._physical_mesh_id.view(-1).tolist() - + process_groups_dict = {} + check_duplicate_list = [] + global_rank_flatten_list = self.physical_mesh_id.view(-1).tolist() for global_rank in global_rank_flatten_list: - # find the other ranks which are in the same process group as global_rank - ranks_in_same_group_by_axis = self._collate_global_ranks_in_same_process_group(global_rank) - - for axis, ranks_in_same_group in ranks_in_same_group_by_axis.items(): - # create dict for each rank - if global_rank not in self._process_group_dict: - self._ranks_in_the_process_group[global_rank] = dict() + process_groups = self.global_rank_to_process_groups_with_global_rank(global_rank) + for axis, process_group in process_groups.items(): + if axis not in process_groups_dict: + process_groups_dict[axis] = [] + if process_group not in check_duplicate_list: + check_duplicate_list.append(process_group) + process_group_handler = dist.new_group(process_group) + process_groups_dict[axis].append((process_group, process_group_handler)) - # keep this process group in the process_groups_dict - self._ranks_in_the_process_group[global_rank][axis] = ranks_in_same_group + return process_groups_dict - def global_rank_to_local_rank(self, rank: int, axis: int = None) -> Union[List[int], int]: - """ - Return the local rank of the given global rank in the logical device mesh. + def global_rank_to_logical_rank(self, rank): + return self.convert_map[rank] - Args: - rank (int): the global rank in the logical device mesh. - axis (int): the axis of the logical device mesh. - """ - local_ranks = self._global_to_local_rank_mapping[rank] - if axis: - return local_ranks[axis] - else: - return local_ranks - - def _collate_global_ranks_in_same_process_group(self, global_rank): + def global_rank_to_process_groups_with_logical_rank(self, rank): ''' - Give a global rank and return all global ranks involved in its associated process group in each axis. 
- - Example: - - ```python - sphysical_mesh_id = torch.arange(0, 16) - mesh_shape = (4, 4) - - # logical mesh will look like - # [[0, 1, 2, 3], - # [4, 5, 6, 7], - # [8, 9, 10,11], - # [12,13,14,15]] - - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - print(device_mesh.collate_global_ranks_in_same_process_group(0)) - - # key is axis name - # value is a list of global ranks in same axis with rank 0 - # output will look like - # { - 0: [0, 4, 8, 12], - 1: [0, 1, 2, 3] - # } + Give a global rank and return all logical process groups of this rank. + for example: + physical_mesh_id = torch.arange(0, 16).reshape(2, 8) + mesh_shape = (4, 4) + # [[0, 1, 2, 3], + # [4, 5, 6, 7], + # [8, 9, 10,11], + # [12,13,14,15]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + print(device_mesh.global_rank_to_process_groups_with_logical_rank(0)) + output: + # key is axis name + # value is a list of logical ranks in same axis with rank 0 + {0: [[0, 0], [1, 0], [2, 0], [3, 0]], 1: [[0, 0], [0, 1], [0, 2], [0, 3]]} ''' - # We have init the global rank to local rank by calling _init_global_to_logical_rank_mapping - # for self._global_to_local_rank_mapping - # the key is the global rank - # the value is the list of local ranks corresponding to the global rank with respect of different axes - # we can see the list of local ranks as the process coordinates for simplicity - # the key and value are all unique, therefore, - # we can also to use the coordinates to find the global rank - - # ========================================================================= - # Step 1 - # find all the process_coordinates for processes in the same process group - # as the given global rank - # ========================================================================= - - # each - processes_in_the_same_process_group = {} - - for dim in range(self.logical_mesh_id.dim()): - # iterate over the dimension size so that we can include all processes - # in the same process group in the given axis - # the _local_rank refers to the local rank of the current process - for _local_rank in range(self.logical_mesh_id.shape[dim]): - - # if this dimension is not initailized yet, - # initialize it with an empty array - if dim not in processes_in_the_same_process_group: - processes_in_the_same_process_group[dim] = [] - - # get the local rank corresponding to the global rank - process_coordinates = self._global_to_local_rank_mapping[global_rank].copy() - - # replace the local rank in the given dimension with the - # lcoal rank of the current process iterated - process_coordinates[dim] = _local_rank - processes_in_the_same_process_group[dim].append(process_coordinates) - - # ================================================================= - # Step 2 - # Use local rank combination to find its corresponding global rank - # ================================================================= - # the key of the dict is the axis - # the value is the list of global ranks which are in the same process group as the given global rank - global_pg_ranks = {} - for dim, coordinates_of_all_processes in processes_in_the_same_process_group.items(): - global_pg_ranks[dim] = [] - for process_coordinates in coordinates_of_all_processes: - # find the global rank by local rank combination - for _global_rank, _process_coordinates in self._global_to_local_rank_mapping.items(): - if process_coordinates == _process_coordinates: - global_pg_ranks[dim].append(_global_rank) - return global_pg_ranks - - def flatten(self): - """ - Flatten the logical mesh into an 
effective 1d logical mesh, - """ - flatten_mesh_shape_size = len(self.mesh_shape) - flatten_mesh_shape = [self.num_devices] - return DeviceMesh(self._physical_mesh_id, - tuple(flatten_mesh_shape), - mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1), - mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1), - init_process_group=self._init_process_group) + process_groups = {} + for d in range(self.logical_mesh_id.dim()): + for replacer in range(self.logical_mesh_id.shape[d]): + if d not in process_groups: + process_groups[d] = [] + process_group_member = self.convert_map[rank].copy() + process_group_member[d] = replacer + process_groups[d].append(process_group_member) + return process_groups + + def global_rank_to_process_groups_with_global_rank(self, rank): + ''' + Give a global rank and return all process groups of this rank. + for example: + physical_mesh_id = torch.arange(0, 16).reshape(2, 8) + mesh_shape = (4, 4) + # [[0, 1, 2, 3], + # [4, 5, 6, 7], + # [8, 9, 10,11], + # [12,13,14,15]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + print(device_mesh.global_rank_to_process_groups_with_global_rank(0)) + output: + # key is axis name + # value is a list of global ranks in same axis with rank 0 + {0: [0, 4, 8, 12], 1: [0, 1, 2, 3]} + ''' + logical_process_groups = self.global_rank_to_process_groups_with_logical_rank(rank) + process_groups = {} + for dim, logical_ranks in logical_process_groups.items(): + process_groups[dim] = [] + for logical_rank in logical_ranks: + for g_rank, l_rank in self.convert_map.items(): + if l_rank == logical_rank: + process_groups[dim].append(g_rank) + return process_groups def all_gather_cost(self, num_bytes, mesh_dim): num_devices = self.logical_mesh_id.shape[mesh_dim] @@ -409,3 +212,38 @@ class DeviceMesh: penalty_factor = num_devices / 2.0 return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001) + + +class FlattenDeviceMesh(DeviceMesh): + + def __init__(self, physical_mesh_id, mesh_shape, mesh_alpha=None, mesh_beta=None): + super().__init__(physical_mesh_id, + mesh_shape, + mesh_alpha, + mesh_beta, + init_process_group=False, + need_flatten=False) + # Different from flatten(), mesh_shape leaves unchanged, mesh_alpha and mesh_beta are scalars + self.mesh_alpha = max(self.mesh_alpha) + self.mesh_beta = min(self.mesh_beta) + # Different from original process_groups_dict, rank_list is not stored + self.process_number_dict = self.create_process_numbers_for_logical_mesh() + + def create_process_numbers_for_logical_mesh(self): + ''' + Build 1d DeviceMesh in column-major(0) and row-major(1) + for example: + mesh_shape = (2,4) + # [[0, 1, 2, 3], + # [4, 5, 6, 7]] + # return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]} + ''' + num_devices = reduce(operator.mul, self.mesh_shape, 1) + process_numbers_dict = {} + process_numbers_dict[0] = torch.arange(num_devices).reshape(self.mesh_shape).transpose(1, 0).flatten().tolist() + process_numbers_dict[1] = torch.arange(num_devices).reshape(self.mesh_shape).flatten().tolist() + return process_numbers_dict + + def mix_gather_cost(self, num_bytes): + num_devices = reduce(operator.mul, self.mesh_shape, 1) + return (self.mesh_alpha + self.mesh_beta * (num_devices - 1) / num_devices * num_bytes + 0.1) diff --git a/colossalai/lazy/lazy_init.py b/colossalai/lazy/lazy_init.py index ca8914362..76f550dc4 100644 --- a/colossalai/lazy/lazy_init.py +++ b/colossalai/lazy/lazy_init.py @@ -1,5 +1,5 @@ 
from types import MethodType -from typing import Callable, Dict, Optional, Union +from typing import Callable, Optional, Union import torch import torch.distributed as dist @@ -8,9 +8,8 @@ from torch import Tensor from torch.utils._pytree import tree_map from colossalai._analyzer._subclasses import MetaTensor -from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.d_tensor.d_tensor import DTensor -from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec +from colossalai.tensor.d_tensor.layout import Layout # reference: https://pytorch.org/cppdocs/notes/tensor_creation.html _NORMAL_FACTORY = [ @@ -173,7 +172,7 @@ class LazyTensor(torch.Tensor): self.clean() return _convert_cls(self, target) - def distribute(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor: + def distribute(self, layout: Layout) -> torch.Tensor: """Distribute the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace), according to the layout. Args: @@ -184,7 +183,7 @@ class LazyTensor(torch.Tensor): """ target = self._materialize_data() self.clean() - local_tensor = DTensor(target, device_mesh, sharding_spec).local_tensor + local_tensor = DTensor(target, layout).local_tensor return _convert_cls(self, local_tensor) def clean(self) -> None: @@ -537,10 +536,7 @@ class LazyInitContext: return _apply_to_lazy_module(module, apply_fn, verbose) @staticmethod - def distribute(module: nn.Module, - device_mesh: DeviceMesh, - sharding_spec_dict: Dict[str, ShardingSpec], - verbose: bool = False) -> nn.Module: + def distribute(module: nn.Module, layout_dict: dict, verbose: bool = False) -> nn.Module: """Distribute all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place. Args: @@ -550,7 +546,7 @@ class LazyInitContext: """ def apply_fn(name: str, p: LazyTensor): - p.distribute(device_mesh, sharding_spec_dict[name]) + p.distribute(layout_dict[name]) return _apply_to_lazy_module(module, apply_fn, verbose) diff --git a/colossalai/nn/layer/parallel_1d/_operation.py b/colossalai/nn/layer/parallel_1d/_operation.py index 300baf9c1..394334558 100644 --- a/colossalai/nn/layer/parallel_1d/_operation.py +++ b/colossalai/nn/layer/parallel_1d/_operation.py @@ -1,6 +1,5 @@ import torch import torch.distributed as dist - from colossalai.core import global_context as gpc try: diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md deleted file mode 100644 index 93a4f1e57..000000000 --- a/colossalai/shardformer/README.md +++ /dev/null @@ -1,296 +0,0 @@ -# ⚡️ ShardFormer - -## 📚 Table of Contents - -- [⚡️ ShardFormer](#️-shardformer) - - [📚 Table of Contents](#-table-of-contents) - - [🔗 Introduction](#-introduction) - - [🔨 Usage](#-usage) - - [🔮 Simple example](#-simple-example) - - [💡 Policy](#-policy) - - [😊 Module](#-module) - - -## 🔗 Introduction - -**Shardformer** is a module that automatically parallelizes the mainstream models in libraries such as HuggingFace and TIMM. This module aims to make parallelization hassle-free for users who are not from the system background. 
- -## 🔨 Usage - -The sample API usage is given below: - -``` python -from colossalai.shardformer import shard_model -from transformers import BertForMaskedLM - -# create huggingface model as normal -model = BertForMaskedLM.from_pretrained("bert-base-uncased") - -# make the huggingface model paralleled to ShardModel -# auto policy: -sharded_model = shard_model(model) - -# custom policy: -from xxx import -sharded_model = shard_model(model, ) - -# do angthing as normal -... -``` - -## 🔮 Simple example - -``` shell -# inference -colossalai run --nproc_per_node 2 --master_port 29500 test.py --config config.py --mode inference -# train -colossalai run --nproc_per_node 2 --master_port 29500 test.py --config config.py --mode train -``` - - -## 💡 Policy - -If you wanna parallel the model in a custom way, just overwrite the policy class for the Hugging Face model. - -You should do: - -1. Inherit Policy class -2. Overwrite `argument_policy` method - - In this method, you need to list which layers class you wanna modify and the attributes and parameters in those layers. Shardformer will replace all the layer belonging to the class you specified. - - `attr_dict` is dict contains all the attributes need to be modified in this layer. - - `param_funcs` is a list contains some functions which will return the path of the weight and bias from the layer. -3. Overwrite `inject_policy` method (Optional) - - Shardformer will inject the model according to this method. If you need to modify the forward or backward progress (like distributed corssentropy loss in Bert) you need to overwrite this method. -4. Overwrite or add the param functions - - These functions use a suffix to record the path of weight or bias for the layer. - - The return is a list contains some `Col_Layer` or `Row_Layer` objects, which means slice along col and row respectively. -5. Overwrite `binding_policy` (Optional) - - Overwrite to specify Shardformer will bind some weight between layers, like embedding and unembedding layers. - - This function will return a dict, the key and value are the suffix of weight need to be binded. - -More details can be found in shardformer/policies/basepolicy.py -``` python -from colossalai.shardformer.policies.basepolicy import Policy, Layer, Col_Layer, Row_Layer, Argument - -CustomPolicy(Policy): -@staticmethod - def argument_policy(model_config, shard_config: int) -> Dict[nn.Module, Argument]: - r""" - Return the dict for the modify policy, the key is the original layer class and the value is the - argument for the modify layer - - Args: - model_config (:class:`tansformer.Config`): The config of transformer model - shard_config (:class:`ShardConfig`): The config for sharding model - - Return: - Dict for the modify policy, - :: - { - origin layer class1 (nn.Module): Argument( - attr_dict = { - argument1: value1, - argument2: value2, - ... - }, - param_funcs = [ - staticmethod1, - staticmethod2, - ... - ] - ), - origin layer class2 (nn.Module): Argument( - attr_dict = { - argument1: value1, - argument2: value2, - ... - }, - param_funcs = [ - staticmethod1, - staticmethod2, - ... - ] - ), - ... 
- } - - """ - raise NotImplementedError - - @staticmethod - def inject_policy() -> Tuple[nn.Module, nn.Module]: - r""" - Return the dict for the inject model - - Return: - The injected model, key is the original model and value is the new shardmodel - :: - (OrignModel, CustomModel) - in `CustomModel`, we can overwrite the forward and backward process - """ - return () - - @staticmethod - def binding_policy() -> Dict: - r""" - Return the dict for the binding model - - Return: - This method should return the binding relationship for some layers share the weight or bias, - the key and value is the suffix of the weight or bias of the model - :: - return { - "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight", - } - """ - return NotImplementedError - - @staticmethod - def attn_in() -> List: - """ - Attention qkv layer - - Returns: - List[Layer]: List of layer object, each layer is the new - """ - return NotImplementedError - - @staticmethod - def attn_out() -> List: - """ - Attention output projection layer - - Returns: - List[Layer]: List of layer object - """ - return NotImplementedError - - @staticmethod - def mlp_in() -> List: - """ - h -> 4h mlp layer - - Returns: - List[Layer]: List of layer object - """ - return NotImplementedError - - @staticmethod - def mlp_out() -> List: - """ - 4h -> h mlp layer - - Returns: - List[Layer]: List of layer object - """ - return NotImplementedError - - @staticmethod - def embedding() -> List: - """ - Partially slice the embedding layer - vocab_size->vocab_size//gpu_nums - - Return: - List[Layer]: List of layer object - """ - return NotImplementedError - - @staticmethod - def unembedding() -> List: - """ - Partially slice the embedding layer - vocab_size->vocab_size//gpu_nums - - Return: - List[Layer]: List of layer object - """ - return NotImplementedError - -``` - - -## 😊 Module - - 1. Flowchart - -

- -

- - 2. Important Modules - - - CLASS `shard_model`: - - This is the user api to use shardformer, just create a model from transformers and define a custom policy or use shardformer autopolicy to make a shard model. - - - CLASS `Layer`: - - Parameters: - - weight (str): The weight suffix of the layer - - bias (str): The bias suffix of the layer - - replace_layer (:class:`colosalai.nn`): The layer to replace the original layer - - ignore (bool): Whether to ignore this layer if it is not in the model - - This class is used to specify the replacement policy for a particular layer. If `replace_layer` is None, only parameter partitioning will be performed without replacing the layer class. - - CLASS `Col_Layer(Layer)`: - - gather_output (bool): Whether to gather the output of the layer - - This class inherited from `Layer`, representing the layer will be sliced along column. - - CLASS `Row_Layer(Layer)`: - - This class inherited from `Layer`, representing the layer will be sliced along row. - - - CLASS `Policy`: - - In Shardformer, this class holds significant importance as it defines the model partitioning methods, required parameter modifications, and model injection techniques all within a single Policy class. - - `Policy.attn_in()/attn_out()/mlp_in()/mlp_out()/embedding()/unembedding()`...... - - These functions define the partitioning methods of the parameters at different locations in the model. Each function returns a list of objects of Layer class that specify the replacement approach for these parameters. Shardformer also supports user-defined functions for modifying their models, in addition to the listed functions. - - `Policy.argument_policy()` - - In this function, the user should use multiple dict to define which class of layers will require replacement. This includes the attributes and parameters that need to be modified or replaced. Attributes are stored in the form of a "suffix-string: value" dict, while parameters are stored via multiple static methods that return the replacement approach. - - `Policy.inject_policy()` - - This function will return the injected model to replace the original model. The new model should be a nn.Module class which includes modified forward or backward functions or anything else. - - `Policy.binding_policy()` - - This function will return the weight sharing information in the model in some dict. The key and value are both the suffixes of the shared parameters. - - - CLASS `ModelSharder(model, policy)`: - - This class helps shard the model, the parameter is the created transformers model and the custom policy. If custom policy is None, shardformer will automatically get already defined policy for the model. - - `ModelShard.inject_model()` - - This function is used to inject the model to modify the forward and backward progress. - - `ModelShard.replace_layer()` - - This function is used to replace the original layers with colossalai layer to make them paralleled and can do distributed communication. - - `ModelShard.bind_layer()` - - This function is used to help different layers share weight or bias. - - - CLASS `Slicer`: - - This class is used to slice tensor according to policy. - - - 3. DistCrossEntropy Loss - - Overview - - In order to reduce the communication size, caculate the crossentropy before all gather, refer to [Megatron-LM](https://github.com/NVIDIA/Megatron-LM), reduce the communication size from [batch_size * seq_length * vocab_size] to [batch_size * seq_length]. 
The origin loss function is: - $$ loss = -\log(\frac{\exp(x[class])}{\sum_i\exp(x[i])})$$ - - alse can be represented as: - - $$ loss = \log(\sum_i\exp(x[i])) - x[class]$$ - - - Step - - - First get the maximum logits across all the devices, make all the logist minus the maximun value to scale the value less than zero to avoid the value of exp being too large - - - Get a mask to mask the logits not in the local device - - - Caculate the loss according to the second formula diff --git a/colossalai/shardformer/__init__.py b/colossalai/shardformer/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py deleted file mode 100644 index e817ea3eb..000000000 --- a/colossalai/shardformer/layer/_operation.py +++ /dev/null @@ -1,97 +0,0 @@ -import torch -import torch.distributed as dist - -from colossalai.core import global_context as gpc - -try: - import fused_mix_prec_layer_norm_cuda -except: - fused_mix_prec_layer_norm_cuda = None - - -class FusedLayerNormAffineFunction1D(torch.autograd.Function): - r"""Layernorm - - Args: - input: input matrix. - weight: weight matrix. - bias: bias matrix. - normalized_shape: input shape from an expected input of size. - :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]` - If a single integer is used, it is treated as a singleton list, and this module will - normalize over the last dimension which is expected to be of that specific size. - eps: a value added to the denominator for numerical stability - """ - - @staticmethod - def forward(ctx, input, weight, bias, normalized_shape, eps): - ctx.normalized_shape = normalized_shape - ctx.eps = eps - input_ = input.contiguous() - weight_ = weight.contiguous() - bias_ = bias.contiguous() - output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(input_, ctx.normalized_shape, weight_, - bias_, ctx.eps) - ctx.save_for_backward(input_, weight_, bias_, mean, invvar) - return output - - @staticmethod - def backward(ctx, grad_output): - input_, weight_, bias_, mean, invvar = ctx.saved_tensors - grad_input = grad_weight = grad_bias = None - grad_input, grad_weight, grad_bias \ - = fused_mix_prec_layer_norm_cuda.backward_affine( - grad_output.contiguous(), mean, invvar, - input_, ctx.normalized_shape, - weight_, bias_, ctx.eps) - - return grad_input, grad_weight, grad_bias, None, None - - -class LinearWithAsyncCommunication(torch.autograd.Function): - """ - Linear layer execution with asynchronous communication in backprop. 
- """ - - @staticmethod - def forward(ctx, input_, weight, bias, parallel_mode, async_grad_allreduce): - ctx.save_for_backward(input_, weight) - ctx.use_bias = bias is not None - ctx.parallel_mode = parallel_mode - ctx.async_grad_allreduce = async_grad_allreduce - - output = torch.matmul(input_, weight.t()) - if bias is not None: - output = output + bias - return output - - @staticmethod - def backward(ctx, grad_output): - input, weight = ctx.saved_tensors - use_bias = ctx.use_bias - - total_input = input - grad_input = grad_output.matmul(weight) - grad_output = grad_output.contiguous() - # Convert the tensor shapes to 2D for execution compatibility - grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]) - total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2]) - - if ctx.async_grad_allreduce: - # Asynchronous all-reduce - handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True) - # Delay the start of weight gradient computation shortly (3us) to have - # all-reduce scheduled first and have GPU resources allocated - _ = torch.empty(1, device=grad_output.device) + 1 - - grad_weight = grad_output.t().matmul(total_input) - grad_bias = grad_output.sum(dim=0) if use_bias else None - - if ctx.async_grad_allreduce: - handle.wait() - - return grad_input, grad_weight, grad_bias, None, None, None - - -def linear_with_async_comm(input_, weight, bias, parallel_mode, async_grad_allreduce): - return LinearWithAsyncCommunication.apply(input_, weight, bias, parallel_mode, async_grad_allreduce) diff --git a/colossalai/shardformer/layer/dist_crossentropy.py b/colossalai/shardformer/layer/dist_crossentropy.py deleted file mode 100644 index 186959467..000000000 --- a/colossalai/shardformer/layer/dist_crossentropy.py +++ /dev/null @@ -1,105 +0,0 @@ -import torch -import torch.distributed as dist -import torch.nn as nn -import torch.nn.functional as F -from torch.autograd import Function - - -class DistCrossEntropy(Function): - r""" - Overwrite the forward and backward function to calculate the cross entropy loss before gather - - Args: - Function (:class:`torch.autograd.Function`): default - """ - - @staticmethod - def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor): - r""" - Calculate the cross entropy loss before gather, the origin loss function is as follows: - loss = -log(exp(x[class])/sum(exp(x[i])) - and can be rewrite as: - loss = log(sum(exp(x[i])) - x[class] - - To avoid the `nan` of log(sim(exp(x[i]))), we minus the max of x[i] - - Args: - vocab_logits (:class:`torch.Tensor`): The logits of the vocabulary, shape is - [batch_size, seq_len, vocab_size] - labels (:class:`torch.Tensor`): The labels of the vocabulary, shape is - [batch_size, seq_len] - - Returns: - :class:`torch.Tensor`: The cross entropy loss - """ - # get the max - logits_max = torch.max(vocab_logits, dim=-1)[0] - dist.all_reduce(logits_max, op=dist.ReduceOp.MAX) - - # minus the max to avoid the result of sum of exp is too large and the log is nan - vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1) - - # mask the target in the local device - partition_vocab_size = vocab_logits.size()[-1] - rank = dist.get_rank() - world_size = dist.get_world_size() - global_vocab_size = partition_vocab_size * world_size - - # [down, up) => false, other device and -100 => true - delta = (global_vocab_size + world_size - 1) // world_size - down_shreshold = rank * delta - up_shreshold = down_shreshold + delta - 
mask = (target < down_shreshold) | (target >= up_shreshold) - masked_target = target.clone() - down_shreshold - masked_target[mask] = 0 - - # reshape the logist and target - # reshape the vocab_logits to [bath_size * seq_len, vocab_size] - # reshape the labels to [bath_size * seq_len] - logits_2d = vocab_logits.view(-1, partition_vocab_size) - masked_target_1d = masked_target.view(-1) - - # extract the x[class] and set the x[other device] to zero - pred_logits_1d = logits_2d[torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device), - masked_target_1d] - pred_logits_1d = pred_logits_1d.clone().contiguous() - pred_logits = pred_logits_1d.view_as(target) - pred_logits[mask] = 0.0 - - # allreduce the get all x(i,y) - dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM) - exp_logits = vocab_logits - torch.exp(vocab_logits, out=exp_logits) - sum_exp_logits = torch.sum(exp_logits, dim=-1) - dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM) - - # calculate the loss - # loss = log(sum(exp(x[i]))) - x[class] - loss = torch.log(sum_exp_logits) - pred_logits - loss = torch.sum(loss).div_(loss.numel()) - - # caculate the softmax - exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) - ctx.save_for_backward(exp_logits, mask, masked_target_1d) - - return loss - - @staticmethod - def backward(ctx, grad_output): - # retrieve the saved tensors - exp_logits, mask, masked_target_1d = ctx.saved_tensors - - # use exp logits as the input grad - grad_logits = exp_logits - partion_vocab_size = grad_logits.shape[-1] - grad_logits_2d = grad_logits.view(-1, partion_vocab_size) - - update = 1.0 - mask.view(-1).float() - grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update - - grad_logits.mul_(grad_output.unsqueeze(dim=-1)) - return grad_logits, None, None - - -def applyDistCrossEntropy(vocab_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: - return DistCrossEntropy.apply(vocab_logits, labels) diff --git a/colossalai/shardformer/layer/dropout.py b/colossalai/shardformer/layer/dropout.py deleted file mode 100644 index acc114029..000000000 --- a/colossalai/shardformer/layer/dropout.py +++ /dev/null @@ -1,58 +0,0 @@ -import os -import time -from contextlib import contextmanager - -import torch -import torch.nn as nn - - -class SeedManager: - """ - This class is a random state manager to change random state for different random seed. - - """ - - def __init__(self): - original_state = torch.cuda.get_rng_state() - seed = int(f"{int(time.time())}{os.environ['RANK']}") - torch.cuda.manual_seed(int(seed)) - self.dropout_state = torch.cuda.get_rng_state() - torch.cuda.set_rng_state(original_state) - - def set_mode(self, rng_state): - torch.cuda.set_rng_state(rng_state) - - def get_current_mode(self): - current_state = torch.cuda.get_rng_state() - return current_state - - @contextmanager - def dropout_mode(self): - """ - This is a context manager to change the dropout state and recover the original state. 
- - Usage: - :: - >>> with _seed_manager.dropout_mode(): - >>> input = super().forward(input) - """ - try: - current_mode = self.get_current_mode() - yield self.set_mode(self.dropout_state) - finally: - self.dropout_state = self.get_current_mode() - self.set_mode(current_mode) - - -_seed_manager = SeedManager() - - -class Dropout1D(nn.Dropout): - - def __init__(self, p=0.5, inplace=False): - super().__init__(p, inplace) - - def forward(self, input): - with _seed_manager.dropout_mode(): - input = super().forward(input) - return input diff --git a/colossalai/shardformer/layer/layers.py b/colossalai/shardformer/layer/layers.py deleted file mode 100644 index f5123885b..000000000 --- a/colossalai/shardformer/layer/layers.py +++ /dev/null @@ -1,1043 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import math -from collections import OrderedDict -from typing import Callable, Tuple - -import torch -import torch.nn.functional as F -from torch import Tensor -from torch.nn.parameter import Parameter - -from colossalai.communication import broadcast -from colossalai.context import ParallelMode, seed -from colossalai.core import global_context as gpc -from colossalai.global_variables import tensor_parallel_env as env -from colossalai.kernel import LayerNorm -from colossalai.nn import init as init -from colossalai.nn.layer.base_layer import ParallelLayer -from colossalai.nn.layer.colossalai_layer._utils import ColossalaiModule -from colossalai.nn.layer.parallel_1d._utils import ( - gather_forward_split_backward, - get_parallel_input, - reduce_grad, - reduce_input, - set_parallel_input, - split_forward_gather_backward, -) -from colossalai.nn.layer.utils import divide, set_tensor_parallel_attribute_by_partition -from colossalai.nn.layer.vanilla import VanillaLayerNorm, VanillaPatchEmbedding -from colossalai.registry import LAYERS -from colossalai.utils.checkpointing import ( - broadcast_state_dict, - gather_tensor_parallel_state_dict, - partition_tensor_parallel_state_dict, -) -from colossalai.utils.cuda import get_current_device - -from ._operation import linear_with_async_comm - -Fast_LN = None -try: - from apex.contrib.layer_norm.layer_norm import FastLayerNorm - Fast_LN = FastLayerNorm -except ImportError: - pass - - -# @LAYERS.register_module -class Linear1D(ColossalaiModule): - r"""Linear layer for 1D parallelism. - - Args: - in_features (int): size of each input sample. - out_features (int): size of each output sample. - bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - gather_output (bool, optional): Whether to call all-gather on output, defaults to False. - skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer, - which is preserved for kernel fusion, defaults to False - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - - More details about ``initializer`` please refer to - `init `_. 
- """ - - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - gather_output: bool = False, - skip_bias_add: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): - parallel_input = get_parallel_input() - if not parallel_input and not gather_output: - layer = Linear1D_Col(in_features, - out_features, - bias=bias, - dtype=dtype, - skip_bias_add=skip_bias_add, - weight_initializer=weight_initializer, - bias_initializer=bias_initializer) - else: - layer = Linear1D_Row(in_features, - out_features, - bias=bias, - dtype=dtype, - parallel_input=parallel_input, - skip_bias_add=skip_bias_add, - weight_initializer=weight_initializer, - bias_initializer=bias_initializer) - super().__init__(layer) - - -# @LAYERS.register_module -class LayerNorm1D(ColossalaiModule): - r""" - Layer Normalization for colossalai - - Args: - normalized_shape (int): input shape from an expected input of size. - :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] - \times \ldots \times \text{normalized_shape}[-1]]` - If a single integer is used, it is treated as a singleton list, and this module will - normalize over the last dimension which is expected to be of that specific size. - eps (float): a value added to the denominator for numerical stability, defaults to 1e-05. - bias (bool, optional): Whether to add a bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - """ - - _fast_ln_supported_sizes = [ - 1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, - 24576, 25600, 30720, 32768, 40960, 49152, 65536 - ] - - def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None): - if Fast_LN is not None and normalized_shape in self._fast_ln_supported_sizes: - norm = Fast_LN(normalized_shape, eps=eps).to(dtype) - else: - norm = None - try: - from apex.normalization import FusedLayerNorm - norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype) - except ImportError: - norm = LayerNorm(normalized_shape, eps=eps).to(dtype) - super().__init__(norm) - - def _load_from_state_dict(self, state_dict, prefix, *args): - local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - # bias - bias = state_dict.pop(bias_key, None) - if bias is not None: - local_state[bias_key] = bias - - local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D) - super()._load_from_state_dict(local_state, prefix, *args) - - def _save_to_state_dict(self, destination, prefix, keep_vars): - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - super()._save_to_state_dict(destination, prefix, keep_vars) - - -# @LAYERS.register_module -class Classifier1D(ParallelLayer): - r"""RowLinear with given weight. Classifier of 1D parallelism. - - Args: - in_features (int): size of each input sample. - num_classes (int): number of classes. - weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. - bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. 
- weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - - More details about ``initializer`` please refer to - `init `_. - """ - - def __init__(self, - in_features: int, - num_classes: int, - weight: Parameter = None, - bias: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): - super().__init__() - self.in_features = in_features - self.num_classes = num_classes - self.parallel_input = get_parallel_input() - - # Divide the weight matrix along the last dimension. - self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size) - - # Parameters. - # Initialize weight. - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} - if weight is not None: - self.weight = weight - self.has_weight = False - else: - self.weight = Parameter(torch.empty(self.num_classes, self.input_size_per_partition, **factory_kwargs)) - self.has_weight = True - if bias: - self.bias = Parameter(torch.empty(self.num_classes, **factory_kwargs)) - else: - self.bias = None - with seed(ParallelMode.TENSOR): - self.reset_parameters(weight_initializer, bias_initializer) - self._set_tensor_parallel_attributes() - set_parallel_input(False) - env.vocab_parallel = False - - def reset_parameters(self, weight_initializer, bias_initializer) -> None: - fan_in, fan_out = self.in_features, self.num_classes - if self.has_weight: - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) - broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D) - - def _set_tensor_parallel_attributes(self): - if self.has_weight: - num_partition = gpc.get_world_size(ParallelMode.TENSOR) - set_tensor_parallel_attribute_by_partition(self.weight, num_partition) - - def _load_from_global_state_dict(self, state_dict, prefix, *args): - local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - if self.has_weight: - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - # bias - if self.bias is not None: - bias = state_dict.pop(bias_key, None) - if bias is not None: - local_state[bias_key] = bias - - local_state = partition_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }) - super()._load_from_global_state_dict(local_state, prefix, *args) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - local_state = OrderedDict() - if self.has_weight: - local_state[weight_key] = self.weight - if self.bias is not None: - local_state[bias_key] = self.bias - local_state = gather_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, - keep_vars=keep_vars) - destination.update(local_state) - - def forward(self, input_: Tensor) -> Tensor: - # Set up backprop all-reduce. 
- if self.parallel_input: - assert input_.shape[-1] == self.weight.shape[-1], \ - 'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) - input_ = input_ - else: - assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \ - 'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size) - input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1) - - output_parallel = F.linear(input_, self.weight) - output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) - if self.bias is not None: - output = output + self.bias - return output - - -# @LAYERS.register_module -class VocabParallelClassifier1D(ParallelLayer): - r"""ColLinear with given weight. Classifier of 1D parallelism. - - Args: - in_features (int): size of each input sample. - num_classes (int): number of classes. - weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. - bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - - More details about ``initializer`` please refer to - `init `_. - """ - - def __init__(self, - in_features: int, - num_classes: int, - weight: Parameter = None, - bias: bool = True, - dtype: torch.dtype = None, - gather_output: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): - super().__init__() - self.in_features = in_features - self.num_classes = num_classes - self.gather_output = gather_output - self.parallel_input = get_parallel_input() - - # Divide the weight matrix along the last dimension. - self.num_classes_per_partition = divide(num_classes, gpc.tensor_parallel_size) - - # Parameters. - # Initialize weight. 
- factory_kwargs = {'device': get_current_device(), 'dtype': dtype} - if weight is not None: - self.weight = weight - self.has_weight = False - else: - self.weight = Parameter(torch.empty(self.num_classes_per_partition, self.in_features, **factory_kwargs)) - self.has_weight = True - if bias: - self.bias = Parameter(torch.empty(self.num_classes_per_partition, **factory_kwargs)) - else: - self.bias = None - with seed(ParallelMode.TENSOR): - self.reset_parameters(weight_initializer, bias_initializer) - self._set_tensor_parallel_attributes() - set_parallel_input(False) - env.vocab_parallel = True - - def reset_parameters(self, weight_initializer, bias_initializer) -> None: - fan_in, fan_out = self.in_features, self.num_classes - if self.has_weight: - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) - - def _set_tensor_parallel_attributes(self): - num_partition = gpc.get_world_size(ParallelMode.TENSOR) - if self.has_weight: - set_tensor_parallel_attribute_by_partition(self.weight, num_partition) - if self.bias is not None: - set_tensor_parallel_attribute_by_partition(self.bias, num_partition) - - def _load_from_global_state_dict(self, state_dict, prefix, *args): - local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - if self.has_weight: - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - # bias - if self.bias is not None: - bias = state_dict.pop(bias_key, None) - if bias is not None: - local_state[bias_key] = bias - - local_state = partition_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }) - super()._load_from_global_state_dict(local_state, prefix, *args) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - local_state = OrderedDict() - if self.has_weight: - local_state[weight_key] = self.weight - if self.bias is not None: - local_state[bias_key] = self.bias - local_state = gather_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, - keep_vars=keep_vars) - destination.update(local_state) - - def forward(self, input_: Tensor) -> Tensor: - assert input_.shape[-1] == self.weight.shape[-1], \ - 'Invalid shapes in VocabParallelClassifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) - # Set up backprop all-reduce. - input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) - # Matrix multiply. - output_parallel = F.linear(input_parallel, self.weight, self.bias) - if self.gather_output: - # All-gather across the partitions. - output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) - else: - output = output_parallel - return output - - -# @LAYERS.register_module -class Linear1D_Col(ParallelLayer): - r"""Linear layer with column parallelism. - - The linear layer is defined as :math:`Y = XA + b`. A is parallelized along - its second dimension as :math:`A = [A_1, ..., A_p]`. - - Args: - in_features (int): size of each input sample. - out_features (int): size of each output sample. 
- bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - gather_output (bool, optional): If true, call all-gather on output and make Y available - to all GPUs, otherwise, every GPU will have its output - which is :math:`Y_i = XA_i`, defaults to False - skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer, - which is preserved for kernel fusion, defaults to False - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - - More details about ``initializer`` please refer to - `init `_. - """ - - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - gather_output: bool = False, - skip_bias_add: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): - super().__init__() - - # Keep input parameters - self.in_features = in_features - self.out_features = out_features - self.gather_output = gather_output - self.skip_bias_add = skip_bias_add - - if skip_bias_add and not bias: - raise ValueError('cannot skip bias addition if bias is None') - - # self.out_features_per_partition = divide(out_features*2, gpc.tensor_parallel_size) - self.out_features_per_partition = out_features - - # Parameters. - # Initialize weight. - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} - self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs)) - - if bias: - self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs)) - else: - self.bias = None - with seed(ParallelMode.TENSOR): - self.reset_parameters(weight_initializer, bias_initializer) - self._set_tensor_parallel_attributes() - is_parallel_output = not self.gather_output - set_parallel_input(is_parallel_output) - - def reset_parameters(self, weight_initializer, bias_initializer) -> None: - fan_in, fan_out = self.in_features, self.out_features - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) - - def _set_tensor_parallel_attributes(self): - num_partition = gpc.get_world_size(ParallelMode.TENSOR) - set_tensor_parallel_attribute_by_partition(self.weight, num_partition) - if self.bias is not None: - set_tensor_parallel_attribute_by_partition(self.bias, num_partition) - - def _load_from_global_state_dict(self, state_dict, prefix, *args): - local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - # bias - if self.bias is not None: - bias = state_dict.pop(bias_key, None) - if bias is not None: - local_state[bias_key] = bias - - local_state = partition_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }) - super()._load_from_global_state_dict(local_state, prefix, *args) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = 
prefix + 'weight' - bias_key = prefix + 'bias' - local_state = OrderedDict({weight_key: self.weight}) - if self.bias is not None: - local_state[bias_key] = self.bias - local_state = gather_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, - keep_vars=keep_vars) - destination.update(local_state) - - def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: - assert input_.shape[-1] == self.weight.shape[-1], \ - 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) - # Set up backprop all-reduce. - # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) - input_parallel = input_ - # Matrix multiply. - bias = self.bias if not self.skip_bias_add else None - # output_parallel = F.linear(input_parallel, self.weight, bias) - output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, ParallelMode.PARALLEL_1D, True) - if self.gather_output: - # All-gather across the partitions. - output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) - else: - output = output_parallel - - if self.skip_bias_add: - return output, self.bias - else: - return output - - -# @LAYERS.register_module -class Linear1D_Row(ParallelLayer): - r""" Linear layer with row parallelism - - Args: - in_features (int): size of each input sample. - out_features (int): size of each output sample. - bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - parallel_input (bool, optional): If set to ``True``, it's assumed that the input is split, defaults to False. - skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer, - which is preserved for kernel fusion, defaults to False - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - - More details about ``initializer`` please refer to - `init `_. - """ - - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - parallel_input: bool = True, - skip_bias_add: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - stream_chunk_num: int = 1): - super().__init__() - - self.stream_chunk_num = stream_chunk_num - - # Keep input parameters - self.in_features = in_features - self.out_features = out_features - self.parallel_input = parallel_input - self.skip_bias_add = skip_bias_add - - if skip_bias_add and not bias: - raise ValueError('cannot skip bias addition if bias is None') - - # Divide the weight matrix along the last dimension. - # self.input_size_per_partition = divide(in_features*2, gpc.tensor_parallel_size) - self.input_size_per_partition = in_features - - # Parameters. - # Initialize weight. 
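Linear1D_Col above and Linear1D_Row here are usually paired so that the column-parallel output feeds the row-parallel input without any gather in between; a single all-reduce at the end recovers the full result. A single-process sketch of that composition, with hypothetical shapes:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
world_size = 2
x = torch.randn(4, 8)
w1 = torch.randn(32, 8)    # column-parallel weight, split along its output dim
w2 = torch.randn(8, 32)    # row-parallel weight, split along its input dim

w1_shards = w1.chunk(world_size, dim=0)
w2_shards = w2.chunk(world_size, dim=1)

# Each rank keeps its intermediate activation local; only the final partial
# outputs are summed, mirroring the one all-reduce in the row-parallel forward.
partial = [F.linear(F.linear(x, a), b) for a, b in zip(w1_shards, w2_shards)]
assert torch.allclose(sum(partial), F.linear(F.linear(x, w1), w2), atol=1e-4)
```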
- factory_kwargs = {'device': get_current_device(), 'dtype': dtype} - self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs)) - - if self.stream_chunk_num > 1: - # TODO() work for inference only - self.chunk_weight() - if bias: - self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) - else: - self.bias = None - with seed(ParallelMode.TENSOR): - self.reset_parameters(weight_initializer, bias_initializer) - self._set_tensor_parallel_attributes() - set_parallel_input(False) - - def chunk_weight(self): - self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0) - - def reset_parameters(self, weight_initializer, bias_initializer) -> None: - fan_in, fan_out = self.in_features, self.out_features - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) - broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D) - - def _set_tensor_parallel_attributes(self): - num_partition = gpc.get_world_size(ParallelMode.TENSOR) - set_tensor_parallel_attribute_by_partition(self.weight, num_partition) - - def _load_from_global_state_dict(self, state_dict, prefix, *args): - local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - # bias - if self.bias is not None: - bias = state_dict.pop(bias_key, None) - if bias is not None: - local_state[bias_key] = bias - - local_state = partition_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }) - super()._load_from_global_state_dict(local_state, prefix, *args) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - local_state = OrderedDict({weight_key: self.weight}) - if self.bias is not None: - local_state[bias_key] = self.bias - local_state = gather_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, - keep_vars=keep_vars) - destination.update(local_state) - - def forward(self, input_: Tensor) -> Tensor: - # Set up backprop all-reduce. - if self.parallel_input: - assert input_.shape[-1] == self.weight.shape[-1], \ - 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) - input_ = input_ - else: - assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \ - 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. 
Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size) - input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1) - - if self.stream_chunk_num > 1: - if self.training: - raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!") - with torch.no_grad(): - output_parallel_list = [None for i in range(self.stream_chunk_num)] - handle_list = [] - for i in range(self.stream_chunk_num): - output_parallel_list[i] = F.linear(input_, self.weight_list[i]) - handle = torch.distributed.all_reduce(output_parallel_list[i], - group=gpc.get_group(ParallelMode.PARALLEL_1D), - async_op=True) - handle_list.append(handle) - # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D) - for handle in handle_list: - handle.wait() - output = torch.cat(output_parallel_list, dim=-1) - else: - output_parallel = F.linear(input_, self.weight) - # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False) - output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) - if not self.skip_bias_add: - if self.bias is not None: - output = output + self.bias - return output - else: - return output, self.bias - - -# @LAYERS.register_module -class Embedding1D(ParallelLayer): - r"""Embedding for 1D parallelism. - - Args: - num_embeddings (int): number of embeddings. - embedding_dim (int): dimension of embedding. - padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; - therefore, the embedding vector at padding_idx is not updated during training, - i.e. it remains as a fixed “pad”, defaults to None. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - weight_initializer (:class:`typing.Callable`, optional): - he initializer of weight, defaults to normal initializer. - - The ``args`` and ``kwargs`` used in :class:`torch.nn.functional.embedding` should contain: - :: - - max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is - renormalized to have norm max_norm. Note: this will modify weight in-place. - norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. - scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse - of frequency of the words in the mini-batch. Default False. - sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. - - More details about ``args`` and ``kwargs`` could be found in - `Embedding `_. 
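Embedding1D shards only the embedding dimension: every rank can look up any token id but produces just its slice of each vector, and the slices are gathered on the last dimension, which is what the forward further down does with gather_forward_split_backward. A single-process sketch of that equivalence, using hypothetical sizes:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
world_size = 2
num_embeddings, embedding_dim = 16, 8
weight = torch.randn(num_embeddings, embedding_dim)
ids = torch.randint(0, num_embeddings, (4, 5))

# Shard the embedding dimension; each "rank" returns its slice of every vector,
# and concatenating the slices reproduces the full lookup.
shards = weight.chunk(world_size, dim=-1)
gathered = torch.cat([F.embedding(ids, s) for s in shards], dim=-1)

assert torch.equal(gathered, F.embedding(ids, weight))
```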
- - More details about ``initializer`` please refer to - `init `_ - """ - - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: torch.dtype = None, - weight_initializer: Callable = init.normal_(), - *args, - **kwargs): - super().__init__() - - self.num_embeddings = num_embeddings - self.embed_dim = embedding_dim - embed_dim_per_partition = divide(embedding_dim, gpc.tensor_parallel_size) - - self.padding_idx = padding_idx - self.embed_args = args - self.embed_kwargs = kwargs - - self.weight = Parameter( - torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)) - - self.reset_parameters(weight_initializer) - self._set_tensor_parallel_attributes() - set_parallel_input(False) - - def _set_tensor_parallel_attributes(self): - set_tensor_parallel_attribute_by_partition(self.weight, gpc.tensor_parallel_size) - - def reset_parameters(self, weight_initializer) -> None: - with seed(ParallelMode.TENSOR): - fan_in, fan_out = self.num_embeddings, self.embed_dim - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - self._fill_padding_idx_with_zero() - - def _fill_padding_idx_with_zero(self) -> None: - if self.padding_idx is not None: - with torch.no_grad(): - self.weight[self.padding_idx].fill_(0) - - def _load_from_global_state_dict(self, state_dict, prefix, *args): - local_state = OrderedDict() - weight_key = prefix + 'weight' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - - local_state = partition_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={weight_key: -1}, - partition_states={weight_key: True}) - super()._load_from_global_state_dict(local_state, prefix, *args) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - local_state = OrderedDict({weight_key: self.weight}) - local_state = gather_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={weight_key: -1}, - partition_states={weight_key: True}, - keep_vars=keep_vars) - destination.update(local_state) - - def forward(self, input_: Tensor) -> Tensor: - - output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) - - output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) - - return output - - -# @LAYERS.register_module -class VocabParallelEmbedding1D(ParallelLayer): - r"""Embedding parallelized in the vocabulary dimension. - - Args: - num_embeddings (int): number of embeddings. - embedding_dim (int): dimension of embedding. - padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; - therefore, the embedding vector at padding_idx is not updated during training, - i.e. it remains as a fixed “pad”, defaults to None. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - weight_initializer (:class:`typing.Callable`, optional): - he initializer of weight, defaults to normal initializer. - - The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: - :: - - max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is - renormalized to have norm max_norm. Note: this will modify weight in-place. - norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. 
- scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse - of frequency of the words in the mini-batch. Default False. - sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. - - More details about ``args`` and ``kwargs`` could be found in - `Embedding `_. - - More details about initializer please refer to - `init `_. - """ - - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: torch.dtype = None, - weight_initializer: Callable = init.normal_(), - *args, - **kwargs): - super().__init__() - self.num_embeddings = num_embeddings - self.embed_dim = embedding_dim - self.padding_idx = padding_idx - self.embed_args = args - self.embed_kwargs = kwargs - - tensor_parallel_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) - tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) - # self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size) - self.num_embeddings_per_partition = num_embeddings - self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition - self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition - - self.weight = Parameter( - torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=get_current_device(), dtype=dtype)) - - self.reset_parameters(weight_initializer) - self._set_tensor_parallel_attributes() - set_parallel_input(False) - env.vocab_parallel = True - - def _set_tensor_parallel_attributes(self): - set_tensor_parallel_attribute_by_partition(self.weight, gpc.tensor_parallel_size) - - def reset_parameters(self, weight_initializer) -> None: - with seed(ParallelMode.TENSOR): - fan_in, fan_out = self.num_embeddings, self.embed_dim - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - self._fill_padding_idx_with_zero() - - def _fill_padding_idx_with_zero(self) -> None: - if self.padding_idx is not None and \ - self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index: - with torch.no_grad(): - self.weight[self.padding_idx - self.vocab_start_index].fill_(0) - - def _load_from_global_state_dict(self, state_dict, prefix, *args): - local_state = OrderedDict() - weight_key = prefix + 'weight' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - - local_state = partition_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={weight_key: 0}, - partition_states={weight_key: True}) - super()._load_from_global_state_dict(local_state, prefix, *args) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - local_state = OrderedDict({weight_key: self.weight}) - local_state = gather_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={weight_key: 0}, - partition_states={weight_key: True}, - keep_vars=keep_vars) - destination.update(local_state) - - def forward(self, input_: Tensor) -> Tensor: - # Build the mask. - input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) - # Mask the input. - masked_input = input_.clone() - self.vocab_start_index - masked_input[input_mask] = 0 - - output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args, - **self.embed_kwargs) - - # Mask the output embedding. - output_parallel[input_mask, :] = 0. 
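The mask-and-zero step just above, combined with the all-reduce that follows, is what lets each rank hold only a contiguous vocabulary range. A single-process emulation of the same arithmetic, with hypothetical sizes and no process group:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
world_size = 2
num_embeddings, embedding_dim = 16, 8
weight = torch.randn(num_embeddings, embedding_dim)
ids = torch.randint(0, num_embeddings, (4, 5))
per_rank = num_embeddings // world_size

# Emulate every rank: mask ids outside the local vocab range, look them up
# locally, zero the masked rows, then sum across ranks (the all-reduce).
output = torch.zeros(*ids.shape, embedding_dim)
for rank, shard in enumerate(weight.chunk(world_size, dim=0)):
    start = rank * per_rank
    mask = (ids < start) | (ids >= start + per_rank)
    local_ids = (ids - start).masked_fill(mask, 0)
    partial = F.embedding(local_ids, shard)
    partial[mask, :] = 0.0
    output += partial

assert torch.equal(output, F.embedding(ids, weight))
```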
- # Reduce across all the model parallel GPUs. - output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) - return output - - -# @LAYERS.register_module -class Dropout1D(ParallelLayer): - """Dropout layer of 1D parallelism. - - Args: - p (float, optional): probability of an element to be zeroed, defaults 0.5. - inplace (bool, optional): whether to do dropout in-place, default to be False. - """ - - def __init__(self, p: float = 0.5, inplace: bool = False): - super().__init__() - self.parallel_input = get_parallel_input() - self.p = p - self.inplace = inplace - - def forward(self, input_: Tensor) -> Tensor: - if self.parallel_input: - with seed(ParallelMode.TENSOR): - output = F.dropout(input_, self.p, self.training, self.inplace) - else: - output = F.dropout(input_, self.p, self.training, self.inplace) - return output - - -# @LAYERS.register_module -class PatchEmbedding1D(ColossalaiModule): - """ - 2D Image to Patch Embedding - - :param img_size: image size - :type img_size: int - :param patch_size: patch size - :type patch_size: int - :param in_chans: number of channels of input image - :type in_chans: int - :param embed_size: size of embedding - :type embed_size: int - :param dtype: The dtype of parameters, defaults to None - :type dtype: torch.dtype, optional - :param flatten: whether to flatten output tensor, defaults to True - :type flatten: bool, optional - :param weight_initializer: The initializer of weight, defaults to kaiming uniform initializer - :type weight_initializer: typing.Callable, optional - :param bias_initializer: The initializer of bias, defaults to xavier uniform initializer - :type bias_initializer: typing.Callable, optional - :param position_embed_initializer: The initializer of position embedding, defaults to zero - :type position_embed_initializer: typing.Callable, optional - """ - - def __init__(self, - img_size: int, - patch_size: int, - in_chans: int, - embed_size: int, - dtype: torch.dtype = None, - flatten: bool = True, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - position_embed_initializer: Callable = init.zeros_()): - embed = VanillaPatchEmbedding(img_size, - patch_size, - in_chans, - embed_size, - dtype=dtype, - flatten=flatten, - weight_initializer=weight_initializer, - bias_initializer=bias_initializer, - position_embed_initializer=position_embed_initializer) - super().__init__(embed) - - def _load_from_state_dict(self, state_dict, prefix, *args): - local_state = OrderedDict() - param_keys = [prefix + 'weight', prefix + 'bias', prefix + 'cls_token', prefix + 'pos_embed'] - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - for key in param_keys: - param = state_dict.pop(key, None) - if param is not None: - local_state[key] = param - - local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D) - super()._load_from_state_dict(local_state, prefix, *args) - - def _save_to_state_dict(self, destination, prefix, keep_vars): - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - super()._save_to_state_dict(destination, prefix, keep_vars) diff --git a/colossalai/shardformer/model/__init__.py b/colossalai/shardformer/model/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/colossalai/shardformer/model/modeling_bert.py b/colossalai/shardformer/model/modeling_bert.py deleted file mode 100644 index bd07ab80c..000000000 --- a/colossalai/shardformer/model/modeling_bert.py +++ /dev/null @@ -1,67 +0,0 @@ -from typing import Any, 
Dict, List, Type - -import torch -import torch.nn as nn -from torch.nn import CrossEntropyLoss -from transformers import BertForMaskedLM -from transformers.models.bert.modeling_bert import MaskedLMOutput - -from ..layer.dist_crossentropy import applyDistCrossEntropy - - -class BertForMaskedLM_(BertForMaskedLM): - - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - **kwargs, - ): - # print("[Inject OK] Injected forward method") - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.bert( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = outputs[0] - prediction_scores = self.cls(sequence_output) - - masked_lm_loss = None - - if labels is not None: - masked_lm_loss = applyDistCrossEntropy(prediction_scores, labels) - # if labels is not None: - # loss_fct = CrossEntropyLoss() # -100 index = padding token - # masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) - - if not return_dict: - output = (prediction_scores,) + outputs[2:] - return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output - - return MaskedLMOutput( - loss=masked_lm_loss, - logits=prediction_scores, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) diff --git a/colossalai/shardformer/policies/__init__.py b/colossalai/shardformer/policies/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/autopolicy.py deleted file mode 100644 index 54cc63ba1..000000000 --- a/colossalai/shardformer/policies/autopolicy.py +++ /dev/null @@ -1,58 +0,0 @@ -import torch.nn as nn - - -def build_policies(): - r""" - Build the policies for the model - - Return: - The dict for the policies - """ - auto_policy_dict = {} - - from transformers import BertForMaskedLM - - from .bert import BertForMaskedLMPolicy - auto_policy_dict[BertForMaskedLM] = BertForMaskedLMPolicy - - from transformers import BertForSequenceClassification - - from .bert import BertForSequenceClassificationPolicy - auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy - - from transformers import GPT2Model - - from .gpt2 import GPT2Policy - auto_policy_dict[GPT2Model] = GPT2Policy - - from transformers import GPT2LMHeadModel - - from .gpt2 import GPT2LMHeadModelPolicy - auto_policy_dict[GPT2LMHeadModel] = GPT2LMHeadModelPolicy - - return auto_policy_dict - - -def get_autopolicy(model: nn.Module): - r""" - Return the auto policy for the model - - Args: - model (:class:`nn.Module`): The model to get the auto policy - - Return: - :class:`Policy`: The auto policy for the model - """ - auto_policy_dict = build_policies() - policy = auto_policy_dict.get(model.__class__, None) - if policy is None: - raise NotImplementedError( - f"Auto policy for {model.__class__.__qualname__} is not implemented\n Supported models are {[i.__qualname__ for i 
in auto_policy_dict.keys()]}" - ) - return policy - - -# from transformers.models.bert.modeling_bert import BertForMaskedLM, BertForPreTraining -# model = BertForPreTraining -# policy = get_autopolicy(model) -# print(policy) diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py deleted file mode 100644 index 644d115a2..000000000 --- a/colossalai/shardformer/policies/basepolicy.py +++ /dev/null @@ -1,217 +0,0 @@ -# part of code modified from https://github.com/tunib-ai/parallelformers - -from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Tuple, Type - -import torch.nn as nn - - -@dataclass -class Argument: - r""" - The argument class for the policy - - Args: - attr_dict (Dict[str, Any]): The dict for the param setting - param_funcs (:class:`List[Callable]`): The list for the param functions - """ - attr_dict: Dict[str, Any] - param_funcs: List[Callable] - - -@dataclass -class Layer: - r""" - The layer object for the policy - - Args: - weight (str): The weight suffix of the layer - bias (str): The bias suffix of the layer - replace_layer (:class:`colosalai.nn`): The layer to replace the original layer - ignore (bool): Whether to ignore this layer if it is not in the model - reversed (bool): Whether the weight in layer is reversed, commonly the weight in `torch.nn.Linear` is [out, in], - but in GPT2 `Conv1D` layer is [in, out] which is reversed. - n_cast (int): The number of weight will cast to, like q, k, v in attention layer, n_cast should be 3. commonly in TP, we just chunk the weight with the number of devices, - but in multi-head attention, we need to chunk the weight with the number of devices * n_head, and - each device should have a part of Q, K and V weight. - """ - weight: str = None - bias: str = None - replace_layer: Any = None - ignore: bool = False - reversed: bool = False - n_cast: int = None - - -@dataclass -class Col_Layer(Layer): - r""" - Class for col shard layer in MegatronLM - - Args: - gather_output (bool): Whether to gather the output of the layer - """ - gather_output: bool = False - - -@dataclass -class Row_Layer(Layer): - r""" - Class for col shard layer in MegatronLM - """ - pass - - -class Policy(): - r""" - The base class for all the policies - For each different model, it should have a different policy class, like BertPolicy for Bert Model - or OPTPolicy for OPT model. - AutoPolicy: - Shardformer already defined some policies for huggingface model, just set ``custom_policy`` = None - to use the auto policy. In shardformer autopolicy, we define a base policy for one type model, - like BertPolicy, and for each different Bert modle in huggingface like, BertForMaskedLM, - BertForSequenceClassification, etc., for each different Bert model we difine different policy class - and overwrite the method like ``inject_policy`` to modify the forward and backward process. - - CustomPolicy: - If you want to define your own policy, you can set ``custom_policy`` = CustomPolicy, and overwrite - all the methods in ``Policy`` class. You can refer to any policy we defined like the ``BertPolicy`` - class for the example. 
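A hypothetical skeleton of such a custom policy, written against the interface described above. ``MyBlock`` and the attribute paths are placeholders, and the imports assume the shardformer package that this patch removes; it is a sketch, not part of the original file:

```python
import torch.nn as nn

import colossalai.shardformer.layer.layers as col_nn
from colossalai.shardformer.policies.basepolicy import Argument, Col_Layer, Policy, Row_Layer


class MyBlock(nn.Module):
    """Placeholder transformer block with a fused QKV projection and an output projection."""

    def __init__(self, hidden: int = 64, num_heads: int = 4):
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(hidden, 3 * hidden)
        self.proj = nn.Linear(hidden, hidden)


class MyModelPolicy(Policy):

    @staticmethod
    def argument_policy(model_config, world_size: int):
        # Shrink per-device attributes, then name the param_funcs that list
        # which weights to slice for this block class.
        return {
            MyBlock:
                Argument(attr_dict={"num_heads": model_config.num_attention_heads // world_size},
                         param_funcs=[MyModelPolicy.attn_in, MyModelPolicy.attn_out]),
        }

    @staticmethod
    def attn_in():
        return [Col_Layer(weight="qkv.weight", bias="qkv.bias", n_cast=3,
                          replace_layer=col_nn.Linear1D_Col)]

    @staticmethod
    def attn_out():
        return [Row_Layer(weight="proj.weight", bias="proj.bias",
                          replace_layer=col_nn.Linear1D_Row)]
```

Such a class could then be passed as the ``policy`` argument of ``shard_model`` (defined later in this patch) instead of relying on the auto policy lookup.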
- - """ - - @staticmethod - def argument_policy(model_config, shard_config: int) -> Dict[nn.Module, Argument]: - r""" - Return the dict for the modify policy, the key is the original layer class and the value is the - argument for the modify layer - - Args: - model_config (:class:`tansformer.Config`): The config of transformer model - shard_config (:class:`ShardConfig`): The config for sharding model - - Return: - Dict for the modify policy, - :: - { - origin layer class1 (nn.Module): Argument( - attr_dict = { - argument1: value1, - argument2: value2, - ... - }, - param_funcs = [ - staticmethod1, - staticmethod2, - ... - ] - ), - origin layer class2 (nn.Module): Argument( - attr_dict = { - argument1: value1, - argument2: value2, - ... - }, - param_funcs = [ - staticmethod1, - staticmethod2, - ... - ] - ), - ... - } - - """ - raise NotImplementedError - - @staticmethod - def inject_policy() -> Tuple[nn.Module, nn.Module]: - r""" - Return the dict for the inject model - - Return: - The injected model, key is the original model and value is the new shardmodel - :: - (OrignModel, CustomModel) - in `CustomModel`, we can overwrite the forward and backward process - """ - return None - - @staticmethod - def binding_policy() -> Dict: - r""" - Return the dict for the binding model - - Return: - This method should return the binding relationship for some layers share the weight or bias, - the key and value is the suffix of the weight or bias of the model - :: - return { - "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight", - } - """ - return None - - @staticmethod - def attn_in() -> List: - r""" - Attention qkv layer - In this kind of method, we should return the list of ``Layer`` object, each ``Layer`` object should be - ``Layer`` for no slicing, ``Col_Layer`` for col slicing, ``Row_Layer`` for row slicing. And the parameters - in ``Layer`` object can refer to the ``Layer`` class. 
- - Returns: - List[Layer]: List of layer object, each layer is the new - """ - return NotImplementedError - - @staticmethod - def attn_out() -> List: - r""" - Attention output projection layer - - Returns: - List[Layer]: List of layer object - """ - return NotImplementedError - - @staticmethod - def mlp_in() -> List: - r""" - h -> 4h mlp layer - - Returns: - List[Layer]: List of layer object - """ - return NotImplementedError - - @staticmethod - def mlp_out() -> List: - r""" - 4h -> h mlp layer - - Returns: - List[Layer]: List of layer object - """ - return NotImplementedError - - @staticmethod - def embedding() -> List: - r""" - Partially slice the embedding layer - - Return: - List[Layer]: List of layer object - """ - return NotImplementedError - - @staticmethod - def unembedding() -> List: - r""" - Partially slice the embedding layer - - Return: - List[Layer]: List of layer object - """ - return None diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py deleted file mode 100644 index 89b32f065..000000000 --- a/colossalai/shardformer/policies/bert.py +++ /dev/null @@ -1,170 +0,0 @@ -from typing import Any, Callable, Dict, List, Tuple, Type - -import torch.nn as nn -from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertLMPredictionHead - -import colossalai.shardformer.layer.layers as col_nn - -from .basepolicy import Argument, Col_Layer, Layer, Policy, Row_Layer - - -class BertPolicy(Policy): - - @staticmethod - def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]: - return { - BertLayer: - Argument( - attr_dict={ - # 1. shard hidden size - "attention.self.all_head_size": config.hidden_size // world_size, - "crossattention.self.all_head_size": config.hidden_size // world_size, - # 2. shard number of heads - "attention.self.num_attention_heads": config.num_attention_heads // world_size, - "crossattention.self.num_attention_heads": config.num_attention_heads // world_size, - }, - param_funcs=[BertPolicy.attn_in, BertPolicy.attn_out, BertPolicy.mlp_in, BertPolicy.mlp_out]), - BertEmbeddings: - Argument( - attr_dict={ - # 1. shard vocab size - # "word_embeddings.num_embeddings": config.vocab_size // world_size, - # 2. add the size of the sliced embedding layer excluding the last slice - "word_embeddings.dim_size": (config.vocab_size + world_size - 1) // world_size, - }, - param_funcs=[ - BertPolicy.embedding, - ]), - BertLMPredictionHead: - Argument( - attr_dict={ - # 1. shard vocab size - # "word_embeddings.num_embeddings": config.vocab_size // world_size, - # 2. 
add the size of the sliced embedding layer excluding the last slice - }, - param_funcs=[ - BertPolicy.unembedding, - ]) - } - - @staticmethod - def binding_policy() -> Dict: - return { - "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight", - } - - @staticmethod - def attn_in() -> List: - return [ - Col_Layer( - weight="attention.self.query.weight", - bias="attention.self.query.bias", - replace_layer=col_nn.Linear1D_Col, - ), - Col_Layer( - weight="attention.self.key.weight", - bias="attention.self.key.bias", - replace_layer=col_nn.Linear1D_Col, - ), - Col_Layer( - weight="attention.self.value.weight", - bias="attention.self.value.bias", - replace_layer=col_nn.Linear1D_Col, - ), - Col_Layer( - weight="crossattention.self.query.weight", - bias="crossattention.self.query.bias", - replace_layer=col_nn.Linear1D_Col, - ignore=True, - ), - Col_Layer( - weight="crossattention.self.key.weight", - bias="crossattention.self.key.bias", - replace_layer=col_nn.Linear1D_Col, - ignore=True, - ), - Col_Layer( - weight="crossattention.self.value.weight", - bias="crossattention.self.value.bias", - replace_layer=col_nn.Linear1D_Col, - ignore=True, - ), - ] - - @staticmethod - def attn_out() -> List: - return [ - Row_Layer( - weight="attention.output.dense.weight", - bias="attention.output.dense.bias", - replace_layer=col_nn.Linear1D_Row, - ), - Row_Layer( - weight="crossattention.output.dense.weight", - bias="crossattention.output.dense.bias", - replace_layer=col_nn.Linear1D_Row, - ignore=True, - ), - ] - - @staticmethod - def mlp_in() -> List: - return [ - Col_Layer( - weight="intermediate.dense.weight", - bias="intermediate.dense.bias", - replace_layer=col_nn.Linear1D_Col, - ), - ] - - @staticmethod - def mlp_out() -> List: - return [ - Row_Layer( - weight="output.dense.weight", - bias="output.dense.bias", - replace_layer=col_nn.Linear1D_Row, - ), - ] - - @staticmethod - def embedding() -> List: - return [Col_Layer( - weight="word_embeddings.weight", - replace_layer=col_nn.VocabParallelEmbedding1D, - )] - - @staticmethod - def unembedding() -> List: - return [ - Col_Layer( - weight="decoder.weight", - bias="decoder.bias", - replace_layer=col_nn.Linear1D_Col, - # gather_output=True, - ) - ] - - -from transformers import BertForMaskedLM - -from colossalai.shardformer.model.modeling_bert import BertForMaskedLM_ - - -class BertForMaskedLMPolicy(BertPolicy): - - @staticmethod - def inject_policy() -> Tuple[nn.Module, nn.Module]: - return (BertForMaskedLM, BertForMaskedLM_) - - -class BertForSequenceClassificationPolicy(BertPolicy): - - @staticmethod - def inject_policy() -> Dict: - return {} - - -# model = BertForMaskedLM.from_pretrained("bert-base-uncased") -# _ = BertForMaskedLMPolicy(model) -# print(isinstance(model,list(_.inject_policy().keys())[0])) diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py deleted file mode 100644 index 44dc9c72f..000000000 --- a/colossalai/shardformer/policies/gpt2.py +++ /dev/null @@ -1,118 +0,0 @@ -from typing import Any, Callable, Dict, List, Tuple, Type - -import torch.nn as nn -from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model - -import colossalai.shardformer.layer.layers as col_nn - -from .basepolicy import Argument, Col_Layer, Layer, Policy, Row_Layer - - -class GPT2Policy(Policy): - - @staticmethod - def argument_policy(config, world_size): - return { - GPT2Model: - Argument(attr_dict={}, param_funcs=[ - GPT2Policy.embedding, - ]), - GPT2Block: - Argument( - attr_dict={ - # 1. 
reduce hidden size - "attn.embed_dim": config.hidden_size // world_size, - "attn.split_size": config.hidden_size // world_size, - "crossattention.embed_dim": config.hidden_size // world_size, - "crossattention.split_size": config.hidden_size // world_size, - # 2. reduce number of heads - "attn.num_heads": config.num_attention_heads // world_size, - "crossattention.num_heads": config.num_attention_heads // world_size, - }, - param_funcs=[ - GPT2Policy.attn_in, - GPT2Policy.attn_out, - GPT2Policy.mlp_in, - GPT2Policy.mlp_out, - ]), - } - - @staticmethod - def attn_in() -> List: - return [ - Col_Layer(weight="attn.c_attn.weight", - bias="attn.c_attn.bias", - n_cast=3, - reversed=True, - replace_layer=col_nn.Linear1D_Col), - Col_Layer(weight="crossattention.c_attn.weight", - bias="crossattention.c_attn.bias", - n_cast=2, - reversed=True, - ignore=True, - replace_layer=col_nn.Linear1D_Col), - Col_Layer(weight="crossattention.q_attn.weight", - bias="crossattention.q_attn.bias", - reversed=True, - ignore=True, - replace_layer=col_nn.Linear1D_Col) - ] - - @staticmethod - def attn_out() -> List: - return [ - Row_Layer(weight="attn.c_proj.weight", - bias="attn.c_proj.bias", - reversed=True, - replace_layer=col_nn.Linear1D_Row), - Row_Layer(weight="crossattention.c_proj.weight", - bias="crossattention.c_proj.bias", - reversed=True, - ignore=True, - replace_layer=col_nn.Linear1D_Row) - ] - - @staticmethod - def mlp_in() -> List: - return [ - Col_Layer(weight="mlp.c_fc.weight", bias="mlp.c_fc.bias", reversed=True, replace_layer=col_nn.Linear1D_Col), - ] - - @staticmethod - def mlp_out() -> List: - return [ - Row_Layer(weight="mlp.c_proj.weight", - bias="mlp.c_proj.bias", - reversed=True, - replace_layer=col_nn.Linear1D_Row) - ] - - @staticmethod - def embedding() -> List: - return [Col_Layer(weight="wte.weight", replace_layer=col_nn.VocabParallelEmbedding1D)] - - -from transformers import GPT2LMHeadModel - - -class GPT2LMHeadModelPolicy(GPT2Policy): - - @staticmethod - def argument_policy(config, world_size): - base_argument = GPT2Policy.argument_policy(config, world_size) - argument = { - GPT2LMHeadModel: Argument(attr_dict={}, param_funcs=[ - GPT2LMHeadModelPolicy.unembedding, - ]), - } - argument.update(base_argument) - return argument - - @staticmethod - def unembedding() -> List: - return [ - Col_Layer(weight="lm_head.weight", - bias="lm_head.bias", - replace_layer=col_nn.Linear1D_Col, - gather_output=True) - ] diff --git a/colossalai/shardformer/shard/__init__.py b/colossalai/shardformer/shard/__init__.py deleted file mode 100644 index d5f70163a..000000000 --- a/colossalai/shardformer/shard/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .shard_config import ShardConfig -from .sharder import ModelSharder, shard_model -from .slicer import Slicer - -__all__ = ['ShardConfig', 'ModelSharder', 'shard_model', 'Slicer'] diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py deleted file mode 100644 index 4cf9162b9..000000000 --- a/colossalai/shardformer/shard/shard_config.py +++ /dev/null @@ -1,20 +0,0 @@ -from dataclasses import dataclass - -__all__ = ['ShardConfig'] - - -@dataclass -class ShardConfig: - """ - The config for sharding the huggingface model for test - """ - rank: int - fp16: bool = True - num_gpus: int = 2 - world_size: int = 2 - backend = "nccl" - verbose: str = 'simple' - seed: int = None - require_grad: bool = False - master_addr: str = "127.0.0.1" - master_port: int = 29500 diff --git a/colossalai/shardformer/shard/sharder.py 
b/colossalai/shardformer/shard/sharder.py deleted file mode 100644 index 1ada75e06..000000000 --- a/colossalai/shardformer/shard/sharder.py +++ /dev/null @@ -1,266 +0,0 @@ -from typing import Any, Callable, Dict, List - -import torch -import torch.nn as nn -from transformers.pytorch_utils import Conv1D - -from ..policies.autopolicy import get_autopolicy -from ..policies.basepolicy import Policy -from ..utils.utils import getattr_, hasattr_, setattr_ -from .shard_config import ShardConfig -from .slicer import Slicer - -__all__ = ['ModelSharder', 'shard_model'] - - -class ModelSharder(object): - r""" - Shard the original huggingface model according to the policy - - Args: - policy (:class:`Policy`): The policy to shard the model - model (:class:`torch.Module`): The model to shard - shard_config: The setting of distributed model - """ - - def __init__( - self, - model: nn.Module, - policy: Policy, - shard_config: ShardConfig = None, # TODO - ) -> None: - self.model = model - self.policy = get_autopolicy(self.model) if policy is None else policy - self.slicer = Slicer(shard_config) - self.shard_config = shard_config - self.model_config = self.model.config - - def shard(self) -> None: - self.reshape_embedding() - self.inject_model(self.model) - self.replace_layer(self.model) - self.bind_layer(self.model) - - def reshape_embedding(self,) -> None: - r""" - Reshape the Embedding layer to make the embedding dimension divisible by world_size - """ - vocab_size = self.model_config.vocab_size - world_size = self.shard_config.world_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) - self.model_config = self.model.config - - def inject_model( - self, - model: nn.Module, - ) -> None: - r""" - Replace the model to policy defined model - Mainly modify the forward and backward to fit distributed model - - e.g. 
- :: - BertForMaskedLM.forward -> BertForMaskedLM_.forward - """ - inject_policy = self.policy.inject_policy() - - if inject_policy is None: - return - org_model_cls = inject_policy[0] - shard_model_cls = inject_policy[1] - - if model.__class__ == org_model_cls: - for key in shard_model_cls.__dict__.keys(): - if hasattr(model.__class__, key): - setattr( - model.__class__, - key, - getattr(shard_model_cls, key), - ) - else: - raise NotImplementedError(f"{model.__class__} is not implemented so far") - - def replace_layer( - self, - model: nn.Module, - ) -> None: - r""" - Replace the layer according to the policy, and replace the layer one by one - - Args: - model (:class:`torch.nn.Module`): The layer to shard - """ - argument_policies = self.policy.argument_policy(self.model_config, self.shard_config.world_size) - for argument_policy in argument_policies.items(): - origin_layer_cls = argument_policy[0] - attr_dict = argument_policy[1].attr_dict - param_funcs = argument_policy[1].param_funcs - self.traverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs) - - def traverse_replace_layer( - self, - layer: nn.Module, - origin_cls: nn.Module, - attr_dict: Dict[str, Any], - param_funcs: List[Callable], - ) -> None: - r""" - Reverse the replace layer operation - - Args: - layer (:class:`torch.nn.Module`): The object of layer to shard - origin_cls (:class:`transformers.model`): The origin layer class - attr_dict (Dict): The attribute dict to modify - policy_cls (:class:`Policy`): The policy class - """ - if layer.__class__ == origin_cls: - for k, v in attr_dict.items(): - setattr_(layer, k, v, ignore=True) - self.shard_one_layer(layer, param_funcs) - for name, child in layer.named_children(): - self.traverse_replace_layer(child, origin_cls, attr_dict, param_funcs) - return layer - - def shard_one_layer( - self, - org_layer: nn.Module, - param_funcs: List[Callable], - ) -> None: - r""" - Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict - - Args: - org_layer (:class:`torch.nn.Module`): The origin layer object to shard - param_funcs (:class:`List[typing.Callable]`): The function list to get shard information in policy class - - """ - for func in param_funcs: - policy_layers = func() - for policy_layer in policy_layers: - weight = None - bias = None - weight_attr = policy_layer.weight - bias_attr = policy_layer.bias - replace_layer_cls = policy_layer.replace_layer - ignore = policy_layer.ignore - n_cast = policy_layer.n_cast - reversed = policy_layer.reversed - if policy_layer.__class__.__name__ == "Col_Layer": - gather_output = policy_layer.gather_output - - if weight_attr is not None: - if hasattr_(org_layer, weight_attr): - weight = getattr_(org_layer, weight_attr) - elif not ignore: - raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {weight_attr}") - - if bias_attr is not None: - if hasattr_(org_layer, bias_attr): - bias = getattr_(org_layer, bias_attr) - elif not ignore: - raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {bias_attr}") - - # dont have the attribute in policy, and ignore is true - if weight is None and bias is None and ignore: - continue - - # set the sliced weight and bias to the new nn_col layer - assert weight is not None or bias is not None - layer_attr = (lambda x: x[:x.rfind(".")])(weight_attr or bias_attr) - - # slice weight and bias - weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__, n_cast, 
reversed) - - # create new object to replace the origin layer - if replace_layer_cls is not None: - if isinstance(getattr_(org_layer, layer_attr), (nn.Linear, Conv1D)): - if replace_layer_cls.__name__ == "Linear1D_Row": - replace_layer = replace_layer_cls(weight.shape[1], - weight.shape[0], - bias=False if bias is None else True) - elif replace_layer_cls.__name__ == "Linear1D_Col": - replace_layer = replace_layer_cls(weight.shape[0], - weight.shape[1], - bias=False if bias is None else True, - gather_output=gather_output) - setattr_(org_layer, layer_attr, replace_layer, ignore=ignore) - self.set_param(replace_layer, weight, bias) - elif isinstance(getattr_(org_layer, layer_attr), nn.Embedding): - replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1], - getattr_(org_layer, f"{layer_attr}.padding_idx", ignore=True)) - setattr_(org_layer, layer_attr, replace_layer, ignore=ignore) - self.set_param(replace_layer, weight, bias) - else: - raise NotImplementedError( - f"Replacing {getattr_(org_layer, layer_attr).__class__} is not implemented so far") - # do not replace the layer object, just replace the weight and bias - else: - self.set_param(org_layer, layer_attr, weight, bias) - - def set_param(self, - layer: Any, - weight: torch.Tensor = None, - bias: torch.Tensor = None, - layer_attr: str = "") -> None: - r""" - Reset the weight and bias of the layer object - - Args: - layer (:class:`torch.nn.Module`): The layer object - layer_attr (str): The attribute name of the layer - weight (:class:`torch.Tensor`): The weight of the layer - bias (:class:`torch.Tensor`): The bias of the layer - """ - assert weight is not None or bias is not None - if weight is not None: - setattr_(layer, "weight" if layer_attr == "" else layer_attr + ".weight", nn.Parameter(weight.contiguous())) - self.set_layer_size(layer, layer_attr, weight.shape) - if bias is not None: - setattr_(layer, "bias" if layer_attr == "" else layer_attr + ".bias", nn.Parameter(bias.contiguous())) - - def set_layer_size(self, layer: nn.Module, layer_attr: str, size: torch.Size) -> None: - r""" - Set the layer attribute - - Args: - layer (:class:`torch.nn.Module`): The layer object - layer_attr (str): The attribute name of the layer - size (:class:`torch.Size`): The size of the tensor - """ - # Tensor.shape[0] -> out_features, Tensor.shape[1] -> in_features - attrs = ["out_features", "in_features"] - for i, attr in enumerate(attrs): - if hasattr_(layer, f"{layer_attr}.{attr}"): - setattr_(layer, f"{layer_attr}.{attr}", size[i]) - - def bind_layer(self, model: nn.Module) -> None: - r""" - Bind the layer according to the binding policy - - Args: - model (:class:`torch.nn.Module`): The shard model - """ - binding_map = self.policy.binding_policy() - if binding_map is None: - return - for k, v in binding_map.items(): - param = getattr_(model, k) - param = nn.Parameter(param) - setattr_(model, k, param) - setattr_(model, v, param) - - -def shard_model(model: nn.Module, shard_config: ShardConfig = None, policy: Policy = None): - r""" - The function is used to shard the PyTorch model. 
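A hypothetical end-to-end use of this helper on two GPUs (launched with torchrun), assuming the shardformer package that this patch removes and the auto policy lookup shown earlier:

```python
import torch.distributed as dist
from transformers import BertForMaskedLM

import colossalai
from colossalai.shardformer.shard import ShardConfig, shard_model

colossalai.launch_from_torch(config={})
shard_config = ShardConfig(rank=dist.get_rank(), world_size=dist.get_world_size())
model = BertForMaskedLM.from_pretrained("bert-base-uncased")
model = shard_model(model, shard_config)   # policy=None -> auto policy chosen from the model class
```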
- - Args: - model (`torch.nn.Model`): the origin huggingface model - shard_config (`ShardConfig`): the config for distribute information - policy (`Policy`): the custom policy for sharding - """ - sharder = ModelSharder(model=model, shard_config=shard_config, policy=policy) - sharder.shard() - return model diff --git a/colossalai/shardformer/shard/slicer.py b/colossalai/shardformer/shard/slicer.py deleted file mode 100644 index 6d35bd193..000000000 --- a/colossalai/shardformer/shard/slicer.py +++ /dev/null @@ -1,161 +0,0 @@ -import torch - -from ..policies.basepolicy import Col_Layer, Layer, Row_Layer -from .shard_config import ShardConfig - -dim_mapping = {Col_Layer: 1, Row_Layer: 0} - - -class Slicer(): - - def __init__( - self, - shardconfig: ShardConfig #TODO - ) -> None: - self.shardconfig = shardconfig - - def slice_weight_bias( - self, - weight: torch.Tensor, - bias: torch.Tensor, - policy_layer_cls: Layer, - n_cast: int = None, - reversed: bool = False, - ): - r""" - Slice the weight and bias according to policy layer cls - ``Layer`` -> do nothing - ``Col_Layer`` -> slice the weight and bias along dim 1 - ``Row_Layer`` -> slice the weight along dim 0 and do not slice bias - - Args: - weight (:class:`torch.nn.Module`): The weight of the layer - bias: (:class:`torch.nn.Module`): The bias of the layer - policy_layer_class (:class:`Policy`): The class represent how to slice the tensor - """ - if policy_layer_cls == Layer: - return weight, bias - - dim = dim_mapping[policy_layer_cls] if not reversed else (1 - dim_mapping[policy_layer_cls]) - # print(weight.shape, dim) - if policy_layer_cls == Col_Layer: - weight = self.slice_tensor(weight, dim, False, n_cast) - bias = self.slice_tensor(bias, 0, True) - elif policy_layer_cls == Row_Layer: - weight = self.slice_tensor(weight, dim, False, n_cast) - else: - raise NotImplementedError(f"The policy layer class {policy_layer_cls} is not supported") - if reversed: - weight = weight.transpose(0, 1).contiguous() - return weight, bias - - def slice_tensor( - self, - tensor_in: torch.Tensor, - dim: int, - is_bias: bool, - n_cast: int = None, - ) -> torch.Tensor: - r""" - Slice tensor according to the config - - Args: - tensor_in (:class:`torch.Tensor`): The tensor to slice - dim (int): The dimension to slice - is_bias (bool): Whether the tensor is bias - """ - if tensor_in is None: - return None - if not is_bias: - return self.slice_2d(tensor_in, dim, n_cast) - else: - return self.slice_1d(tensor_in, n_cast) - - def slice_2d( - self, - tensor: torch.Tensor, - dim: int, - n_cast: int = None, - ) -> torch.Tensor: - r""" - Slice the 2D tensor - - Args: - tensor (:class:`torch.Tensor`): The tensor to slice - dim (int): The dimension to slice - """ - assert dim in [0, 1], f"Only support 2D tensor, but got {dim}D tensor" - if dim == 0: - return self.slice_row(tensor, n_cast) - elif dim == 1: - return self.slice_col(tensor, n_cast) - - def slice_1d( - self, - tensor: torch.Tensor, - n_cast: int = None, - ) -> torch.Tensor: - r""" - Slice the 1D tensor - - Args: - tensor (:class:`torch.Tensor`): The tensor to slice - - Returns: - :class:`torch.Tensor`: The sliced tensor - """ - if n_cast is None: - return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous() - else: - tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0) - chunk_list = [ - tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size) - ] - return torch.cat(chunk_list, dim=0).contiguous() - - def 
slice_col( - self, - tensor: torch.Tensor, - n_cast: int = None, - ) -> torch.Tensor: - r""" - Slice the tensor in column - - Args: - tensor (:class:`torch.Tensor`): The tensor to slice - - Returns: - :class:`torch.Tensor`: The sliced tensor - - """ - if n_cast is None: - return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous() - else: - tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0) - chunk_list = [ - tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size) - ] - return torch.cat(chunk_list, dim=0).contiguous() - - def slice_row( - self, - tensor: torch.Tensor, - n_cast: int = None, - ) -> torch.Tensor: - r""" - Slice the tensor in column - - Args: - tensor (:class:`torch.Tensor`): The tensor to slice - - Returns: - :class:`torch.Tensor`: The sliced tensor - """ - if n_cast is None: - return tensor.chunk(self.shardconfig.world_size, dim=1)[self.shardconfig.rank].contiguous() - else: - tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=1) - chunk_list = [ - tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size) - ] - return torch.cat(chunk_list, dim=1).contiguous() diff --git a/colossalai/shardformer/test/config.py b/colossalai/shardformer/test/config.py deleted file mode 100644 index 2b80d8b3c..000000000 --- a/colossalai/shardformer/test/config.py +++ /dev/null @@ -1 +0,0 @@ -parallel = dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')) diff --git a/colossalai/shardformer/test/module_test.py b/colossalai/shardformer/test/module_test.py deleted file mode 100644 index 83dc7ec6c..000000000 --- a/colossalai/shardformer/test/module_test.py +++ /dev/null @@ -1,50 +0,0 @@ -import os - -import torch -import torch.nn as nn -import torch.nn.functional as F - -import colossalai -from colossalai.shardformer.layer.dist_crossentropy import applyDistCrossEntropy -from colossalai.shardformer.layer.dropout import Dropout1D - - -def get_args(): - parser = colossalai.get_default_parser() - parser.add_argument("--module", type=str, default='distloss') - return parser.parse_args() - - -def test_dist_crossentropy(): - pred = torch.randn(2, 4, 8, requires_grad=True) - labels = torch.randint(8, (1, 4)).repeat(2, 1) - - pred_ = pred.view(-1, 8) - labels_ = labels.view(-1) - loss = F.cross_entropy(pred_, labels_) - loss.backward() - print(f"normal loss:{loss}") - - pred = pred.chunk(int(os.environ['WORLD_SIZE']), -1)[int(os.environ['RANK'])] - loss = applyDistCrossEntropy(pred.to('cuda'), labels.to('cuda')) - loss.backward() - print(f"dist loss:{loss}") - - -def test_dropout(): - input = torch.randn(5, 4).to("cuda") - m = Dropout1D(p=0.2).to("cuda") - for i in range(2): - print(f"Output: {m(input)}") - print(torch.randn(1)) - - -if __name__ == '__main__': - args = get_args() - colossalai.launch_from_torch(config={}) - if args.module == 'distloss': - test_dist_crossentropy() - elif args.module == 'dropout': - test_dropout() - else: - print("not implemented yet") diff --git a/colossalai/shardformer/test/test.py b/colossalai/shardformer/test/test.py deleted file mode 100644 index e2d5a94c7..000000000 --- a/colossalai/shardformer/test/test.py +++ /dev/null @@ -1,124 +0,0 @@ -import os -import random - -import torch -import torch.nn as nn -from datasets import load_dataset -from torch.utils.data import DataLoader -from tqdm.auto import tqdm -from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, 
GPT2LMHeadModel, get_scheduler - -import colossalai -from colossalai.shardformer.shard import ShardConfig, shard_model -from colossalai.utils import get_current_device, print_rank_0 - -os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' - - -def get_args(): - parser = colossalai.get_default_parser() - parser.add_argument("--mode", type=str, default='inference') - parser.add_argument("--save_model", action='store_true') - parser.add_argument("--model", type=str, default='bert-base-uncased') - return parser.parse_args() - - -def load_data(args): - tokenizer = AutoTokenizer.from_pretrained(args.model) - if tokenizer.pad_token is None: - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - # tokenizer.pad_token_id = 0 - datasets = load_dataset('wikitext', 'wikitext-2-raw-v1') - # datasets=load_dataset("yelp_review_full") - tokenized_datasets = datasets.map( - lambda examples: tokenizer(examples["text"], truncation=True, padding="max_length"), batched=True) - tokenized_datasets = tokenized_datasets.remove_columns(["text"]) - # tokenized_datasets=tokenized_datasets.rename_column("label","labels") - tokenized_datasets.set_format("torch") - - train_dataset = tokenized_datasets["train"] - test_dataset = tokenized_datasets["test"] - - datacollector = DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15, return_tensors="pt") - train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=datacollector) - eval_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=True, collate_fn=datacollector) - return train_dataloader, eval_dataloader - - -def inference(model: nn.Module, args): - print(model) - # print(model.wte.weight.shape) - tokenizer = AutoTokenizer.from_pretrained(args.model) - if tokenizer.pad_token is None: - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - tokenizer.pad_token_id = 0 - token = "Hello, my dog is cute" - inputs = tokenizer(token, return_tensors="pt") - inputs.to("cuda") - model.eval() - model.to("cuda") - outputs = model(**inputs) - print(outputs[0]) - - -def train(model: nn.Module, args, num_epoch: int = 3): - train_dataloader, eval_dataloader = load_data(args) - optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5) - num_training = num_epoch * len(train_dataloader) - progress_bar = tqdm(range(num_training)) - lr_scheduler = get_scheduler(name="linear", - optimizer=optimizer, - num_warmup_steps=0, - num_training_steps=num_training) - best_test_loss = float("inf") - model.to("cuda") - model.train() - for epoch in range(num_epoch): - progress_bar.set_description(f"Rank {get_current_device()} epoch {epoch}") - for batch in train_dataloader: - optimizer.zero_grad() - batch = {k: v.to('cuda') for k, v in batch.items()} - outputs = model(**batch) - loss = outputs.loss - loss.backward() - optimizer.step() - lr_scheduler.step() - progress_bar.update(1) - train_loss = loss - - loss = 0.0 - for batch in eval_dataloader: - batch = {k: v.to('cuda') for k, v in batch.items()} - outputs = model(**batch) - # loss = outputs.loss - assert not torch.isnan(outputs.loss), f"{batch}" - loss += outputs.loss.item() - # loss = criterion(outputs.logits, batch["input_ids"]) - test_loss = loss / len(eval_dataloader) - print_rank_0(f"Train Loss: {train_loss:.4f} Test Loss:{test_loss:.4f}") - if args.save_model and test_loss < best_test_loss: - best_test_loss = test_loss - torch.save(model.state_dict(), "./checkpoints/best_model.pth") - - -if __name__ == "__main__": - args = get_args() - colossalai.launch_from_torch(config=args.config) - 
if args.model == 'bert-base-uncased': - model = BertForMaskedLM.from_pretrained("bert-base-uncased") - elif args.model == 'gpt2': - model = GPT2LMHeadModel.from_pretrained("gpt2") - else: - raise AttributeError("model not supported") - shard_config = ShardConfig( - rank=int(str(get_current_device()).split(':')[-1]), - world_size=int(os.environ['WORLD_SIZE']), - ) - sharded_model = shard_model(model, shard_config) - - if args.mode == "train": - train(sharded_model, args) - elif args.mode == "inference": - inference(sharded_model, args) - else: - raise NotImplementedError diff --git a/colossalai/shardformer/utils/__init__.py b/colossalai/shardformer/utils/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/colossalai/shardformer/utils/utils.py b/colossalai/shardformer/utils/utils.py deleted file mode 100644 index eb84edd88..000000000 --- a/colossalai/shardformer/utils/utils.py +++ /dev/null @@ -1,58 +0,0 @@ -def hasattr_(obj, attr: str): - r""" - Check whether the object has the multi sublevel attr - - Args: - obj (object): The object to check - attr (str): The multi level attr to check - """ - attrs = attr.split('.') - for a in attrs: - try: - obj = getattr(obj, a) - except AttributeError: - return False - return True - - -def setattr_(obj, attr: str, value, ignore: bool = False): - r""" - Set the object's multi sublevel attr to value, if ignore, ignore when it doesn't exist - - Args: - obj (object): The object to set - attr (str): The multi level attr to set - value (Any): The value to set - ignore (bool): Whether to ignore when the attr doesn't exist - """ - - attrs = attr.split('.') - for a in attrs[:-1]: - try: - obj = getattr(obj, a) - except AttributeError: - if ignore: - return - raise AttributeError(f"Object {obj} has no attribute {attr}") - setattr(obj, attrs[-1], value) - - -def getattr_(obj, attr: str, ignore: bool = None): - r""" - Get the object's multi sublevel attr - - Args: - obj (object): The object to set - attr (str): The multi level attr to set - ignore (bool): Whether to ignore when the attr doesn't exist - """ - - attrs = attr.split('.') - for a in attrs: - try: - obj = getattr(obj, a) - except AttributeError: - if ignore: - return None - raise AttributeError(f"Object {obj} has no attribute {attr}") - return obj diff --git a/colossalai/tensor/comm_spec.py b/colossalai/tensor/comm_spec.py index dd873c852..af38d2a50 100644 --- a/colossalai/tensor/comm_spec.py +++ b/colossalai/tensor/comm_spec.py @@ -16,66 +16,69 @@ def _all_gather(tensor, comm_spec): ''' Implement all gather operation on device mesh based on information provided by comm_spec. ''' - process_groups = comm_spec.device_mesh.get_process_group_for_all_axes() - process_group = process_groups[comm_spec.logical_process_axis] - - tensor_list = [ - torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) - for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]) - ] - # without this contiguous operation, the all gather may get some unexpected results. 
- tensor = tensor.contiguous() - dist.all_gather(tensor_list, tensor, group=process_group) - output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous() - return output + process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] + for rank_list, process_group in process_groups_list: + if dist.get_rank() in rank_list: + tensor_list = [ + torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) + for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]) + ] + # without this contiguous operation, the all gather may get some unexpected results. + tensor = tensor.contiguous() + dist.all_gather(tensor_list, tensor, group=process_group) + output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous() + return output def _split(tensor, comm_spec): ''' Implement shard operation on device mesh based on information provided by comm_spec. ''' - process_groups = comm_spec.device_mesh.get_process_group_for_all_axes() - process_group = process_groups[comm_spec.logical_process_axis] - - dim = comm_spec.shard_dim - length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group) - start = length * dist.get_rank(process_group) - output = torch.narrow(tensor, dim, start, length).contiguous() - return output + process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] + for rank_list, _ in process_groups_list: + if dist.get_rank() in rank_list: + dim = comm_spec.shard_dim + length = tensor.shape[comm_spec.shard_dim] // len(rank_list) + start = length * rank_list.index(dist.get_rank()) + output = torch.narrow(tensor, dim, start, length).contiguous() + return output def _all_to_all(tensor, comm_spec): ''' Implement all to all operation on device mesh based on information provided by comm_spec. 
''' - process_groups = comm_spec.device_mesh.get_process_group_for_all_axes() - process_group = process_groups[comm_spec.logical_process_axis] - world_size = dist.get_world_size(process_group) - - new_shape = list(tensor.shape) - new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // world_size - new_shape = torch.Size(new_shape) - output_tensor_list = [torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)] - dim = comm_spec.shard_dim - length = tensor.shape[comm_spec.shard_dim] // world_size - input_tensor_list = [torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(world_size)] - group = process_group - dist.all_to_all(output_tensor_list, input_tensor_list, group) - output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous() - return output + process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] + for rank_list, process_group in process_groups_list: + if dist.get_rank() in rank_list: + new_shape = list(tensor.shape) + new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list) + new_shape = torch.Size(new_shape) + output_tensor_list = [ + torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list)) + ] + dim = comm_spec.shard_dim + length = tensor.shape[comm_spec.shard_dim] // len(rank_list) + input_tensor_list = [ + torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list)) + ] + group = process_group + dist.all_to_all(output_tensor_list, input_tensor_list, group) + output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous() + return output def _all_reduce(tensor, comm_spec, async_op=False): ''' Implement all reduce operation on device mesh based on information provided by comm_spec. ''' - process_groups = comm_spec.device_mesh.get_process_group_for_all_axes() - process_group = process_groups[comm_spec.logical_process_axis] - - if not tensor.is_contiguous(): - tensor = tensor.contiguous() - dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op) - return tensor + process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] + for rank_list, process_group in process_groups_list: + if dist.get_rank() in rank_list: + if not tensor.is_contiguous(): + tensor = tensor.contiguous() + dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op) + return tensor def _mix_gather(tensor, comm_spec): @@ -411,7 +414,7 @@ class CommSpec: self.forward_only = forward_only if isinstance(self.logical_process_axis, list): if not mix_gather: - self.device_mesh = self.sharding_spec.device_mesh.flatten() + self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh self.logical_process_axis = 0 else: self.device_meshes = self.sharding_spec.device_mesh.flatten_device_meshes diff --git a/colossalai/tensor/d_tensor/RAEDME.md b/colossalai/tensor/d_tensor/RAEDME.md deleted file mode 100644 index 95d866388..000000000 --- a/colossalai/tensor/d_tensor/RAEDME.md +++ /dev/null @@ -1,103 +0,0 @@ -# 🔢 Distributed Tensor - -## 📚 Table of Contents - -- [🔢 Distributed Tensor](#-distributed-tensor) - - [📚 Table of Contents](#-table-of-contents) - - [🔗 Introduction](#-introduction) - - [📝 Design](#-design) - - [🔨 Usage](#-usage) - - [🎈 Progress Log](#-progress-log) - -## 🔗 Introduction - -Distributed tensor is a type of tensor that is distributed across multiple devices. 
It is a wrapper of PyTorch tensor, and it is used to support distributed training. -It can represent the device topology and tensor placement over the devices in the topology. It also provides a set of APIs to manipulate the distributed tensor. - -## 📝 Design - -Our implementation is inspired by the work [Alpa](https://arxiv.org/abs/2201.12023), which unifies data parallelism and tensor parallelism as intra-op parallelism. It uses notations `S` to represent the sharded dimension and `R` to represent the replicated dimension. For example, given a 2D matrix, `[S, R]` represents the tensor is sharded over the first dimension. - -Each sharded dimension will have a subscript to represent its placement over the devices. Assuming we have 4 GPUs and the GPUs are arranged in a 2 x 2 manner. Let's say we have a 2D matrix like below: - - -```text - [1, 2, 3, 4 ] -A = [4, 5, 6, 7 ] - [8, 9, 10, 11] - [12, 13, 14, 15] -``` - -`[S0, R]` would mean that the first dimension is sharded over the rows in the device topology. - -```text -| --------------------—————————————————————-| -| | | -| [1, 2, 3, 4 ] | [1, 2, 3, 4 ] | -| [4, 5, 6, 7 ] | [4, 5, 6, 7 ] | -| | | -| --------------------——————————————————----- -| | | -| [8, 9, 10, 11] | [8, 9, 10, 11] | -| [12, 13, 14, 15] | [12, 13, 14, 15] | -| | | -| --------------------——————————————————----- -``` - -`[S01, R]` would mean that the first dimension is sharded over both the row and column in the device topology. - -```text -| --------------------—————————————————————-| -| | | -| [1, 2, 3, 4 ] | [4, 5, 6, 7 ] | -| | | -| --------------------——————————————————----- -| | | -| [8, 9, 10, 11] | [12, 13, 14, 15] | -| | | -| --------------------——————————————————----- -``` - -## 🔨 Usage - -A sample API usage is given below. - -```python -import torch - -import colossalai -from colossalai.device.device_mesh import DeviceMesh -from colossalai.tensor.d_tensor import DTensor, ShardingSpec - -colossalai.launch_from_torch(config={}) - -# define your device mesh -# assume you have 4 GPUs -physical_mesh_id = torch.arange(0, 4).reshape(1, 4) -mesh_shape = (2, 2) -device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - -# define a tensor -a = torch.rand(16, 32).cuda() - -# create sharding spec for the tensor -# assume the sharding spec is [S0, R] -dim_partition_dict = {0: [0]} -sharding_spec = ShardingSpec(a.dim(), dim_partition_dict) - -# create a distributed tensor -d_tensor = DTensor(a, device_mesh, sharding_spec) -print(d_tensor) - -global_tensor = d_tensor.to_global() -print(global_tensor) -``` - - -## 🎈 Progress Log - -- [x] Support layout conversion -- [x] Support sharding on 2D device mesh -- [ ] Support sharding on 3D device mesh -- [ ] Support sharding 4D device mesh -- [ ] Support sharding info saving and offline tensor merge (we can save tensor as dtensor and gather the tensors back to the global tensor based on the sharding info in a single process in CPU, useful for distributed training checkpoint load and save.) 
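The sharding notation described in the deleted README above (`[S0, R]`, `[S01, R]` over a 2 x 2 logical mesh) can be checked without a multi-GPU run. The sketch below is illustrative only and not part of the patch: it reproduces those two placements with plain PyTorch slicing. The helper names, the rank-to-(row, col) mapping, and the use of `torch.arange` for the sample matrix are assumptions made for the example; nothing here calls the `DTensor` API itself.

```python
import torch

# A 4 x 4 matrix and a 2 x 2 logical mesh, analogous to the README example.
A = torch.arange(1, 17).reshape(4, 4)
MESH_SHAPE = (2, 2)

def shard_s0_r(tensor, row, col):
    # [S0, R]: dim 0 is sharded over mesh axis 0 only, so devices that share
    # a mesh row hold identical slices (the column coordinate is unused).
    return tensor.chunk(MESH_SHAPE[0], dim=0)[row]

def shard_s01_r(tensor, row, col):
    # [S01, R]: dim 0 is sharded over both mesh axes, i.e. into
    # MESH_SHAPE[0] * MESH_SHAPE[1] pieces, one per device.
    index = row * MESH_SHAPE[1] + col
    return tensor.chunk(MESH_SHAPE[0] * MESH_SHAPE[1], dim=0)[index]

# Assumed rank -> (row, col) mapping on the 2 x 2 logical mesh.
for rank, (row, col) in enumerate([(0, 0), (0, 1), (1, 0), (1, 1)]):
    print(f"rank {rank} [S0, R]:\n{shard_s0_r(A, row, col)}")
    print(f"rank {rank} [S01, R]:\n{shard_s01_r(A, row, col)}")
```

Running the sketch shows ranks 0 and 1 holding the same top half of the matrix under `[S0, R]`, while `[S01, R]` gives each of the four ranks one distinct row, matching the two ASCII diagrams in the README.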
diff --git a/colossalai/tensor/d_tensor/__init__.py b/colossalai/tensor/d_tensor/__init__.py index af77f4f0e..e69de29bb 100644 --- a/colossalai/tensor/d_tensor/__init__.py +++ b/colossalai/tensor/d_tensor/__init__.py @@ -1,4 +0,0 @@ -from .d_tensor import DTensor -from .sharding_spec import ShardingSpec - -__all__ = ['DTensor', 'ShardingSpec'] diff --git a/colossalai/tensor/d_tensor/comm_spec.py b/colossalai/tensor/d_tensor/comm_spec.py index 79b2e3ef9..159125fa1 100644 --- a/colossalai/tensor/d_tensor/comm_spec.py +++ b/colossalai/tensor/d_tensor/comm_spec.py @@ -24,12 +24,12 @@ class CommSpec: ''' Communication spec is used to record the communication action. It converts the communication spec to real action which will be used in runtime. It contains comm_pattern to determine the - communication method, process_group_dict to determine the process groups, gather_dim and shard_dim + communication method, process_groups_dict to determine the process groups, gather_dim and shard_dim to determine the buffer shape, and logical_process_axis Argument: - comm_pattern(CollectiveCommPattern): decribe the communication method used in this spec. - process_group_dict(Dict): A dict which contains the process groups used to apply this CommSpec. + comm_pattern(CollectiveCommPattern): describe the communication method used in this spec. + process_groups_dict(Dict): A dict which contains the process groups used to apply this CommSpec. gather_dim(int, Optional): The gather_dim of the tensor will be gathered. shard_dim(int, Optional): The shard_dim of the tensor will be sharded. logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action. @@ -37,7 +37,7 @@ class CommSpec: def __init__(self, comm_pattern: CollectiveCommPattern, - process_group_dict: Dict, + process_groups_dict: Dict, gather_dim: int = None, shard_dim: int = None, logical_process_axis: int = None): @@ -45,7 +45,7 @@ class CommSpec: self.gather_dim = gather_dim self.shard_dim = shard_dim self.logical_process_axis = logical_process_axis - self.process_group_dict = process_group_dict + self.process_groups_dict = process_groups_dict def __repr__(self): res_list = ["CommSpec:("] @@ -92,56 +92,68 @@ def _all_gather(tensor: torch.Tensor, comm_spec: CommSpec): ''' Implement all gather operation on device mesh based on information provided by comm_spec. ''' - process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis] - world_size = dist.get_world_size(process_group) - tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)] - # without this contiguous operation, the all gather may get some unexpected results. - tensor = tensor.contiguous() - dist.all_gather(tensor_list, tensor, group=process_group) - output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous() - return output + process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis] + for rank_list, process_group in process_groups_list: + if dist.get_rank() in rank_list: + tensor_list = [ + torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list)) + ] + # without this contiguous operation, the all gather may get some unexpected results. 
+ tensor = tensor.contiguous() + dist.all_gather(tensor_list, tensor, group=process_group) + output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous() + return output def _split(tensor: torch.Tensor, comm_spec: CommSpec): ''' Implement shard operation on device mesh based on information provided by comm_spec. ''' - process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis] - dim = comm_spec.shard_dim - length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group) - start = length * dist.get_rank(process_group) - output = torch.narrow(tensor, dim, start, length).contiguous() - return output + process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis] + for rank_list, _ in process_groups_list: + if dist.get_rank() in rank_list: + dim = comm_spec.shard_dim + length = tensor.shape[comm_spec.shard_dim] // len(rank_list) + start = length * rank_list.index(dist.get_rank()) + output = torch.narrow(tensor, dim, start, length).contiguous() + return output def _all_to_all(tensor: torch.Tensor, comm_spec: CommSpec): ''' Implement all to all operation on device mesh based on information provided by comm_spec. ''' - process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis] - world_size = dist.get_world_size(process_group) - new_shape = list(tensor.shape) - new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // world_size - new_shape = torch.Size(new_shape) - output_tensor_list = [torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)] - dim = comm_spec.shard_dim - length = tensor.shape[comm_spec.shard_dim] // world_size - input_tensor_list = [torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(world_size)] - group = process_group - dist.all_to_all(output_tensor_list, input_tensor_list, group) - output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous() - return output + process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis] + for rank_list, process_group in process_groups_list: + if dist.get_rank() in rank_list: + new_shape = list(tensor.shape) + new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list) + new_shape = torch.Size(new_shape) + output_tensor_list = [ + torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list)) + ] + dim = comm_spec.shard_dim + length = tensor.shape[comm_spec.shard_dim] // len(rank_list) + input_tensor_list = [ + torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list)) + ] + group = process_group + dist.all_to_all(output_tensor_list, input_tensor_list, group) + output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous() + return output def _all_reduce(tensor: torch.Tensor, comm_spec: CommSpec, async_op: bool = False): ''' Implement all reduce operation on device mesh based on information provided by comm_spec. 
''' - process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis] - if not tensor.is_contiguous(): - tensor = tensor.contiguous() - dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op) - return tensor + process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis] + for rank_list, process_group in process_groups_list: + if dist.get_rank() in rank_list: + if not tensor.is_contiguous(): + tensor = tensor.contiguous() + dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op) + return tensor class _ReduceGrad(torch.autograd.Function): @@ -257,7 +269,7 @@ class _AllToAll(torch.autograd.Function): def forward(ctx, input_, comm_spec): output = _all_to_all(input_, comm_spec) comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern, - process_group_dict=comm_spec.process_group_dict, + process_groups_dict=comm_spec.process_groups_dict, gather_dim=comm_spec.shard_dim, shard_dim=comm_spec.gather_dim, logical_process_axis=comm_spec.logical_process_axis) diff --git a/colossalai/tensor/d_tensor/d_tensor.py b/colossalai/tensor/d_tensor/d_tensor.py index 6bda0f4e5..c1fe9d50a 100644 --- a/colossalai/tensor/d_tensor/d_tensor.py +++ b/colossalai/tensor/d_tensor/d_tensor.py @@ -3,119 +3,55 @@ from typing import Optional import torch from torch.utils._pytree import tree_map -from colossalai.device.device_mesh import DeviceMesh - from .layout import Layout from .layout_converter import LayoutConverter, to_global from .sharding_spec import ShardingSpec -__all__ = ['DTensor', 'distribute_tensor', 'distribute_module', 'construct_default_sharding_spec'] - layout_converter = LayoutConverter() class DTensor(torch.Tensor): - """ - DTensor stands for distributed tensor. It is a subclass of `torch.Tensor` and contains meta information - about the tensor distribution. The meta information includes the device mesh, the sharding specification, - and the entire shape of the tensor. - - During runtime, we will not directly use the DTensor objects for computation. Instead, we will only use the - `DTensor.local_tensor` for computation. The `DTensor.local_tensor` is the local tensor in the current rank. - In this way, all tensors involved in computation will only be native PyTorch tensors. - - Example: - ```python - from colossalai.device import DeviceMesh - - # define your device mesh - # assume you have 4 GPUs - physical_mesh_id = torch.arange(0, 4).reshape(1, 4) - mesh_shape = (2, 2) - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - - # define a tensor - x = torch.rand(16, 32) - - # create sharding spec for the tensor - # assume the sharding spec is [S, R] - dim_partition_dict = { - 0: 1 - } - sharding_spec = ShardingSpec(a.dim(), dim_partition_dict) - - # create a distributed tensor - d_tensor = DTensor(x, device_mesh, sharding_spec) - ``` - Args: - tensor (`torch.Tensor`): the unsharded tensor. - device_mesh (`DeviceMesh`): the device mesh for abstraction of the compute devices. - sharding_spec (`ShardingSpec`): the sharding specification which describes how the tensor will be sharded. - """ - - def __init__(self, tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec): - # ensure this tensor is not a DTensor - assert not isinstance(tensor, DTensor), 'The input tensor should not be a DTensor.' 
- - # store meta info - self.local_tensor = tensor - self.data_type = tensor.dtype - self.global_shape = tensor.shape - - # create distributed layout - dist_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=self.global_shape) + def __init__(self, local_tensor: torch.Tensor, dist_layout: Layout): + self.local_tensor = local_tensor + self.data_type = local_tensor.dtype + self.entire_shape = local_tensor.shape self.dist_layout = dist_layout - - # shard the tensor self._apply_layout() @staticmethod - def __new__(cls, tensor, *args, **kwargs): - return torch.Tensor._make_subclass(cls, tensor, tensor.requires_grad) + def __new__(cls, local_tensor, layout): + return torch.Tensor._make_subclass(cls, local_tensor, local_tensor.requires_grad) def __repr__(self): - return f"DTensor(\n{self.to_global()}\n{self.dist_layout}" + return f"DTensor({self.to_global()}, {self.dist_layout})" def __str__(self): return self.__repr__() - def layout_convert(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> None: + def layout_convert(self, target_layout): ''' Convert the layout of the tensor from source_spec to target_spec. - This will update the `local_tensor` and `dist_layout` in place. - - Args: - target_layout (Layout): the target layout specification. ''' - target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=self.global_shape) - self.local_tensor = layout_converter.apply(tensor=self.local_tensor, - source_layout=self.dist_layout, - target_layout=target_layout) + self.local_tensor = layout_converter.apply(self.local_tensor, self.dist_layout, target_layout) self.dist_layout = target_layout def _apply_layout(self): ''' Apply the layout to the local tensor during initializing process. ''' - # layout converter requires a source and target laytout - # we construct the source layer for an unsharded tensor - # and use self.dist_layer as the targer layout for the sharded tensor source_spec = construct_default_sharding_spec(self.local_tensor) source_layout = Layout(device_mesh=self.dist_layout.device_mesh, + device_type=self.dist_layout.device_type, sharding_spec=source_spec, - global_shape=self.global_shape) - self.local_tensor = layout_converter.apply(tensor=self.local_tensor, - source_layout=source_layout, - target_layout=self.dist_layout) + entire_shape=self.entire_shape) + self.local_tensor = layout_converter.apply(self.local_tensor, source_layout, self.dist_layout) @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} - # convert all DTensors to native pytorch tensors - # so that operations will be conducted on native tensors def filter_arg(arg): if isinstance(arg, DTensor): return arg.local_tensor @@ -124,9 +60,9 @@ class DTensor(torch.Tensor): args = tree_map(filter_arg, args) kwargs = tree_map(filter_arg, kwargs) - - # NOTE: if we want to convert the result into DTensor, we need to infer the layout of result from the layout of input tensors + # if we want to convert the result into DTensor, we need to infer the layout of result from the layout of input tensors # and op type. + return func(*args, **kwargs) @property @@ -149,6 +85,7 @@ class DTensor(torch.Tensor): ''' self.local_tensor = self.local_tensor.to(*args, **kwargs) self.data_type = self.local_tensor.dtype + self.dist_layout.device_type = self.local_tensor.device # TODO: update the device mesh process groups or we should just cache # both the cpu process groups and the cuda process groups? 
return self @@ -161,7 +98,7 @@ class DTensor(torch.Tensor): def to_global(self): ''' - Recover the global tensor from the distributed tensor by returning a new `torch.Tensor` object. + Recover the global tensor from the distributed tensor. Note: This function will all_gather the local tensor to the global tensor and it will not change the layout of the DTensor. This function is mainly used for debugging or @@ -170,29 +107,24 @@ class DTensor(torch.Tensor): return to_global(self.local_tensor, self.dist_layout) -def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> DTensor: +def distribute_tensor(local_tensor: torch.Tensor, dist_layout: Layout) -> DTensor: ''' Distribute the local tensor to the distributed tensor according to the dist_layout specified. Args: - tensor (`torch.Tensor`): tensor to be distributed. - device_mesh (`DeviceMesh`): the device mesh for abstraction of the compute devices. - sharding_spec (`ShardingSpec`): the sharding specification which describes how the tensor will be sharded. + local_tensor: tensor to be distributed. + dist_layout: the layout specification of the distributed tensor. Returns: A 'DTensor' object. ''' - return DTensor(tensor, device_mesh, sharding_spec) + return DTensor(local_tensor, dist_layout) def distribute_module(module: torch.nn.Module, partition_fn: Optional[callable] = None) -> torch.nn.Module: ''' This function converts all the parameters in the module to DTensor(DParam). - Args: - module (`torch.nn.Module`): the module to be distributed. - partition_fn (callable): the partition function which will be used to partition the parameters. - Note: This function is subject to future change as the DParam has not been implemented yet. ''' for name, param in module.named_parameters(): @@ -206,11 +138,5 @@ def distribute_module(module: torch.nn.Module, partition_fn: Optional[callable] def construct_default_sharding_spec(tensor: torch.Tensor,) -> ShardingSpec: ''' Construct the default sharding specification for the tensor. - - Args: - tensor (`torch.Tensor`): the tensor to be sharded. - - Returns: - A `ShardingSpec` object without any sharding specified. ''' return ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={}) diff --git a/colossalai/tensor/d_tensor/layout.py b/colossalai/tensor/d_tensor/layout.py index 2946611b4..ee7ef74a9 100644 --- a/colossalai/tensor/d_tensor/layout.py +++ b/colossalai/tensor/d_tensor/layout.py @@ -11,32 +11,28 @@ from .sharding_spec import ShardingSpec class Layout: - """ - Layout of a tensor refers to the tensor placement on the device mesh and how the tensor is sharded over the devices. + """Layout of a tensor. - Args: - device_mesh (`DeviceMesh`): the device mesh to store the tensor distributed. - sharding_spec (`ShardingSpec`): the sharding specification to describe how the tensor is sharded. - global_shape (`torch.Size`): the entire shape of the global tensor. + Attributes: + device_mesh: the device mesh to store the tensor distributed. + device_type: the type of the device mesh, e.g. 'cpu' or 'cuda'. + sharding_spec: the sharding specification to describe how the tensor is sharded. + entire_shape: the entire shape of the global tensor. 
""" - def __init__(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size): + def __init__(self, device_mesh: DeviceMesh, device_type: torch.device, sharding_spec: ShardingSpec, + entire_shape: torch.Size): self.device_mesh = device_mesh + self.device_type = device_type self.sharding_spec = sharding_spec - self.global_shape = global_shape + self.entire_shape = entire_shape self._sanity_check() def __hash__(self) -> int: return hash(f'{self.sharding_spec}') - def get_sharded_shape_per_device(self) -> torch.Size: - """ - Compute the shape of the sharded tensor on each device. - - Returns: - `torch.Size`: the shape of the sharded tensor on each device. - """ - sharded_shape = list(self.global_shape) + def get_sharded_shape_per_device(self): + sharded_shape = list(self.entire_shape) for dim, shard_list in self.sharding_spec.dim_partition_dict.items(): mesh_list = [self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list] shard_partitions = reduce(operator.mul, mesh_list, 1) @@ -60,7 +56,7 @@ class Layout: # make sure that the sharding for a dimension is divisible by the number of devices for dim, shard_list in sharding_spec.dim_partition_dict.items(): - tensor_dim_size = self.global_shape[dim] + tensor_dim_size = self.entire_shape[dim] num_devices = 1 for element in shard_list: diff --git a/colossalai/tensor/d_tensor/layout_converter.py b/colossalai/tensor/d_tensor/layout_converter.py index 6eff92ea6..cf02aac30 100644 --- a/colossalai/tensor/d_tensor/layout_converter.py +++ b/colossalai/tensor/d_tensor/layout_converter.py @@ -3,8 +3,10 @@ from copy import deepcopy from dataclasses import dataclass from typing import Dict, List, Tuple +import numpy as np import torch +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem from colossalai.context.singleton_meta import SingletonMeta from colossalai.tensor.d_tensor.comm_spec import * from colossalai.tensor.d_tensor.layout import Layout @@ -26,21 +28,13 @@ class LayoutConverterOptions: pass -def to_global(distributed_tensor: "DTensor", layout: Layout) -> torch.Tensor: - """ - Convert a distributed tensor to the global tensor with the given layout. - This function returns a native `torch.Tensor` object. - - - Args: - distributed_tensor (`DTensor`): the distributed tensor to be converted. - layout (`Layout`): the target layout specification. - """ +def to_global(distributed_tensor: torch.Tensor, layout: Layout) -> torch.Tensor: layout_converter = LayoutConverter() global_sharding_spec = ShardingSpec(distributed_tensor.dim(), {}) global_layout = Layout(device_mesh=layout.device_mesh, + device_type=layout.device_type, sharding_spec=global_sharding_spec, - global_shape=layout.global_shape) + entire_shape=layout.entire_shape) with torch.no_grad(): global_tensor = layout_converter.apply(distributed_tensor, layout, global_layout) return global_tensor @@ -55,9 +49,6 @@ def set_layout_converting_options(options: LayoutConverterOptions): class LayoutConverter(metaclass=SingletonMeta): - """ - LayoutConverter is a singleton class which converts the layout of a distributed tensor. 
- """ def __init__(self): self._options = None @@ -100,14 +91,15 @@ class LayoutConverter(metaclass=SingletonMeta): # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - global_shape = (4, 4, 4) + entire_shape = (4, 4, 4) dim_partition_dict = {0: [0], 1: [1]} # [S0,S1,R] sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict) layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), sharding_spec=sharding_spec, - global_shape=global_shape) + entire_shape=entire_shape) rst_dict = layout_converter.all_gather_transform_layouts(layout) for layout, comm_spec in rst_dict.items(): @@ -120,12 +112,7 @@ class LayoutConverter(metaclass=SingletonMeta): valid_spec_dict = {} comm_pattern = CollectiveCommPattern.GATHER_FWD_SPLIT_BWD source_spec = source_layout.sharding_spec - - # the key of the dict is the axis - # the value is the process group - current_rank = source_layout.device_mesh._global_rank_of_current_process - process_group_dict = source_layout.device_mesh._process_group_dict[current_rank] - + process_groups_dict = source_layout.device_mesh.process_groups_dict for target_pair in source_spec.dim_partition_dict.items(): shard_list = all_gather_simulator(target_pair) index = target_pair[0] @@ -143,7 +130,7 @@ class LayoutConverter(metaclass=SingletonMeta): logical_process_axis = target_pair[1][-1] comm_spec = CommSpec( comm_pattern, - process_group_dict=process_group_dict, + process_groups_dict=process_groups_dict, gather_dim=gather_dim, # shard_dim will be used during backward shard_dim=gather_dim, @@ -154,7 +141,8 @@ class LayoutConverter(metaclass=SingletonMeta): new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict) new_layout = Layout(device_mesh=source_layout.device_mesh, sharding_spec=new_sharding_spec, - global_shape=source_layout.global_shape) + device_type=source_layout.device_type, + entire_shape=source_layout.entire_shape) valid_spec_dict[new_layout] = comm_spec except LayoutException: @@ -179,14 +167,15 @@ class LayoutConverter(metaclass=SingletonMeta): # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - global_shape = (4, 4, 4) + entire_shape = (4, 4, 4) dim_partition_dict = {0: [0], 1: [1]} # [S0,S1,R] sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict) layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), sharding_spec=sharding_spec, - global_shape=global_shape) + entire_shape=entire_shape) rst_dict = layout_converter.all_to_all_transform_layout(layout) for layout, comm_spec in rst_dict.items(): @@ -199,12 +188,7 @@ class LayoutConverter(metaclass=SingletonMeta): ''' valid_spec_dict = {} comm_pattern = CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD - - # the key of the dict is the axis - # the value is the process group - current_rank = source_layout.device_mesh._global_rank_of_current_process - process_group_dict = source_layout.device_mesh._process_group_dict[current_rank] - + process_groups_dict = source_layout.device_mesh.process_groups_dict source_spec = source_layout.sharding_spec tensor_dims = source_spec.dims for f_index in range(tensor_dims - 1): @@ -245,7 +229,7 @@ class LayoutConverter(metaclass=SingletonMeta): shard_dim = f_index logical_process_axis = b_target_pair[1][-1] comm_spec = CommSpec(comm_pattern, - process_group_dict=process_group_dict, + process_groups_dict, gather_dim=gather_dim, shard_dim=shard_dim, 
logical_process_axis=logical_process_axis) @@ -268,7 +252,8 @@ class LayoutConverter(metaclass=SingletonMeta): new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict) new_layout = Layout(device_mesh=source_layout.device_mesh, sharding_spec=new_sharding_spec, - global_shape=source_layout.global_shape) + device_type=source_layout.device_type, + entire_shape=source_layout.entire_shape) valid_spec_dict[new_layout] = comm_spec except LayoutException: pass @@ -293,15 +278,16 @@ class LayoutConverter(metaclass=SingletonMeta): # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - global_shape = (4, 4, 4) + entire_shape = (4, 4, 4) dim_partition_dict = {0: [0]} # [S0,R,R] sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict) layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), sharding_spec=sharding_spec, - global_shape=global_shape) + entire_shape=entire_shape) rst_dict = layout_converter.shard_transform_layout(layout) for layout, comm_spec in rst_dict.items(): @@ -315,11 +301,7 @@ class LayoutConverter(metaclass=SingletonMeta): valid_spec_dict = {} comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD source_spec = source_layout.sharding_spec - - # the key of the dict is the axis - # the value is the process group - current_rank = source_layout.device_mesh._global_rank_of_current_process - process_group_dict = source_layout.device_mesh._process_group_dict[current_rank] + process_groups_dict = source_layout.device_mesh.process_groups_dict # legal sharding dims means the mesh_id is still available to use. legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.mesh_shape))] @@ -347,7 +329,7 @@ class LayoutConverter(metaclass=SingletonMeta): shard_dim = index logical_process_axis = shard_list[-1] comm_spec = CommSpec(comm_pattern, - process_group_dict=process_group_dict, + process_groups_dict, gather_dim=shard_dim, shard_dim=shard_dim, logical_process_axis=logical_process_axis) @@ -358,7 +340,8 @@ class LayoutConverter(metaclass=SingletonMeta): dim_partition_dict=new_dim_partition_dict) new_layout = Layout(device_mesh=source_layout.device_mesh, sharding_spec=new_sharding_spec, - global_shape=source_layout.global_shape) + device_type=source_layout.device_type, + entire_shape=source_layout.entire_shape) valid_spec_dict[new_layout] = comm_spec except LayoutException: pass @@ -416,7 +399,7 @@ class LayoutConverter(metaclass=SingletonMeta): # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - global_shape = (4, 4, 4) + entire_shape = (4, 4, 4) dim_partition_source = {1: [0, 1]} dim_partition_target = {0: [0, 1]} @@ -424,14 +407,16 @@ class LayoutConverter(metaclass=SingletonMeta): # [R,S01,R] sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source) source_layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), sharding_spec=sharding_spec_source, - global_shape=global_shape) + entire_shape=entire_shape) # [S01,R,R] sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target) target_layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), sharding_spec=sharding_spec_target, - global_shape=global_shape) + entire_shape=entire_shape) transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout) transform_path_str = '->'.join([str(layout.sharding_spec.sharding_sequence) 
for layout in transform_path]) @@ -520,19 +505,21 @@ class LayoutConverter(metaclass=SingletonMeta): # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - global_shape = (4, 4, 4) + entire_shape = (4, 4, 4) # [S0,R,R] sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source) source_layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), sharding_spec=sharding_spec_source, - global_shape=global_shape) + entire_shape=entire_shape) # [R,S0,R] sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target) target_layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), sharding_spec=sharding_spec_target, - global_shape=global_shape) + entire_shape=entire_shape) if rank in (0, 1): sharded_tensor_0 = torch.zeros(2, 1) @@ -567,4 +554,3 @@ class LayoutConverter(metaclass=SingletonMeta): for comm_spec in comm_action_sequence: tensor = comm_spec.covert_spec_to_action(tensor) return tensor - return tensor diff --git a/colossalai/tensor/d_tensor/sharding_spec.py b/colossalai/tensor/d_tensor/sharding_spec.py index 45b05e10e..565012b58 100644 --- a/colossalai/tensor/d_tensor/sharding_spec.py +++ b/colossalai/tensor/d_tensor/sharding_spec.py @@ -116,21 +116,21 @@ class DimSpec: def dim_diff(self, other): ''' - The difference between two DimSpec. + The difference between two _DimSpec. Argument: - other(DimSpec): the dim spec to compare with. + other(_DimSpec): the dim spec to compare with. Return: difference(int): the difference between two _DimSpec. Example: - ```python - dim_spec = DimSpec([0]) - other_dim_spec = DimSpec([0, 1]) + dim_spec = _DimSpec([0]) + other_dim_spec = _DimSpec([0, 1]) print(dim_spec.difference(other_dim_spec)) - # output: 5 - ``` + + Output: + 5 ''' difference = self.difference_dict[(str(self), str(other))] return difference @@ -142,13 +142,9 @@ class ShardingSpec: [R, R, S0, S1], which means Argument: - dim_size (int): The number of dimensions of the tensor to be sharded. - dim_partition_dict (Dict[int, List[int]], optional): The key is the dimension of tensor to be sharded, - and the value of the key describe which logical axis will be sharded in that dimension. Defaults to None. - E.g. {0: [0, 1]} means the first dimension of the tensor will be sharded in logical axis 0 and 1. - sharding_sequence (List[DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1]. - Generally, users should specify either dim_partition_dict or sharding_sequence. - If both are given, users must ensure that they are consistent with each other. Defaults to None. + dim_partition_dict(Dict[int, List[int]], optional): The key is the dimension of tensor to be sharded, + and the value of the key describe which logical axis will be sharded in that dimension. + sharding_sequence(List[DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1]. ''' def __init__(self, @@ -212,7 +208,6 @@ class ShardingSpec: pair of sharding sequence. Example: - ```python dim_partition_dict = {0: [0, 1]} # DistSpec: # shard_sequence: S01,R,R @@ -224,8 +219,10 @@ class ShardingSpec: # device_mesh_shape: (4, 4) sharding_spec_to_compare = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_to_compare) print(sharding_spec.sharding_sequence_difference(sharding_spec_to_compare)) - # output: 25 - ``` + + Output: + 25 + Argument: other(ShardingSpec): The ShardingSpec to compared with. 
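The `dim_partition_dict` convention used throughout the hunks above (e.g. `{0: [0, 1]}` meaning that tensor dim 0 is sharded over logical mesh axes 0 and 1) is what determines the per-device shard shape. The snippet below is a minimal standalone sketch of that arithmetic, mirroring what `get_sharded_shape_per_device` in `layout.py` computes earlier in this patch; the function name, mesh shape, and example specs here are chosen for illustration and are not the library API.

```python
import operator
from functools import reduce

def sharded_shape_per_device(global_shape, dim_partition_dict, mesh_shape):
    # Divide each sharded tensor dim by the product of the mesh extents of
    # every logical axis it is sharded over.
    shape = list(global_shape)
    for dim, shard_list in dim_partition_dict.items():
        divisor = reduce(operator.mul, (mesh_shape[axis] for axis in shard_list), 1)
        assert shape[dim] % divisor == 0, "sharded dim must divide evenly across devices"
        shape[dim] //= divisor
    return tuple(shape)

mesh_shape = (2, 2)  # a 2 x 2 logical mesh, as in the docstring examples above
print(sharded_shape_per_device((4, 4, 4), {0: [0]}, mesh_shape))          # [S0, R, R]  -> (2, 4, 4)
print(sharded_shape_per_device((4, 4, 4), {0: [0, 1]}, mesh_shape))       # [S01, R, R] -> (1, 4, 4)
print(sharded_shape_per_device((4, 4, 4), {0: [0], 1: [1]}, mesh_shape))  # [S0, S1, R] -> (2, 2, 4)
```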
diff --git a/docs/sidebars.json b/docs/sidebars.json index c3cfbbeef..8be40e451 100644 --- a/docs/sidebars.json +++ b/docs/sidebars.json @@ -64,7 +64,6 @@ }, "features/pipeline_parallel", "features/nvme_offload", - "features/lazy_init", "features/cluster_utils" ] }, diff --git a/docs/source/en/features/lazy_init.md b/docs/source/en/features/lazy_init.md deleted file mode 100644 index 40f5da1cb..000000000 --- a/docs/source/en/features/lazy_init.md +++ /dev/null @@ -1,71 +0,0 @@ -# Lazy initialization - -Author: Hongxin Liu - -**Prerequisite** -- [Booster API](../basics/booster_api.md) -- [Booster Plugins](../basics/booster_plugins.md) -- [Booster Checkpoint](../basics/booster_checkpoint.md) - -**Related discussion** -- [Lazy initialization of model](https://github.com/hpcaitech/ColossalAI/discussions/3124) - -## Introduction - -LazyTensor allows DL framework (PyTorch) to execute operations lazily, by storing all operations related to it and reruning them when it's required to be materialized. - -LazyInit defers model initialization and it's based on LazyTensor. - -This is especially useful when we use model parallelism to train large models, in which case the model cannot fit in GPU memory. Through this, we can initialize model tensors using meta tensor and do static analysis to get shard strategy. And then materialize each tensor and apply the shard strategy. The static analysis can be omitted if the shard strategy is known in advance. - -## Usage - -You may use lazy initialization when using Gemini, tensor parallelism, pipeline parallelism, and auto-parallelism. In other cases, you may not need to use lazy initialization. - -Gemini is compatible with lazy initialization. You can use them together directly. - -```python -from colossalai.booster import Booster -from colossalai.booster.plugin import GeminiPlugin -from colossalai.lazy import LazyInitContext -from colossalai.nn.optimizer import HybridAdam -from torch.nn import Linear -import colossalai - -colossalai.launch_from_torch({}) - -plugin = GeminiPlugin() -booster = Booster(plugin=plugin) - -with LazyInitContext(): - model = Linear(10, 10) - -optimizer = HybridAdam(model.parameters()) -model, optimizer, *_ = booster.boost(model, optimizer) -``` - -Note that using lazy initialization when using Gemini is not necessary but recommended. If you don't use lazy initialization, you may get OOM error when initializing the model. If you use lazy initialization, you can avoid this error. - -> ⚠ Lazy initialization support for tensor parallelism, pipeline parallelism, and auto-parallelism is still under development. - -### Load from pretrained model - -We should not load pretrained weight in `LazyInitContext`. If so, lazy initialization is meaningless, as the checkpoint is loaded and it takes much GPU memory. A recommended way is to initialize model from scratch in `LazyInitContext` and load pretrained weight outside `LazyInitContext` after calling `Booster.boost()`. - - -```python -with LazyInitContext(): - model = GPT2LMHeadModel(config) - -optimizer = ... -lr_scheduler = ... -dataloader = ... -model, optimizer, lr_scheduler, dataloader = booster.boost(model, optimizer, lr_scheduler, dataloader) - -booster.load_model(model, pretrained_path) -``` - - -As booster supports both pytorch-fashion checkpoint and huggingface/transformers-fashion pretrained weight, the `pretrained_path` of the above pseudo-code can be either a checkpoint file path or a pretrained weight path. Note that it does not support loading pretrained weights from network. 
You should download the pretrained weight first and then use a local path. - - diff --git a/docs/source/zh-Hans/features/lazy_init.md b/docs/source/zh-Hans/features/lazy_init.md deleted file mode 100644 index 9a3cd90ca..000000000 --- a/docs/source/zh-Hans/features/lazy_init.md +++ /dev/null @@ -1,71 +0,0 @@ -# 惰性初始化 - -作者: Hongxin Liu - -**前置教程** -- [Booster API](../basics/booster_api.md) -- [Booster 插件](../basics/booster_plugins.md) -- [Booster Checkpoint](../basics/booster_checkpoint.md) - -**相关讨论** -- [模型的惰性初始化](https://github.com/hpcaitech/ColossalAI/discussions/3124) - -## 引言 - -LazyTensor 允许深度学习框架 (PyTorch) 延迟执行操作,方法是存储与其相关的所有操作并在需要具体化时重新运行它们。 - -LazyInit 基于 LazyTensor,并支持延迟模型初始化。 - -这在我们使用模型并行来训练大型模型时特别有用,在这种情况下模型无法容纳在 GPU 内存中。通过这个,我们可以使用 Meta 张量初始化模型张量并进行静态分析以获得分片策略。然后具体化每个张量并应用分片策略。如果事先知道分片策略,则可以省略静态分析。 - -## 用法 - -您可以在使用 Gemini、张量并行、流水线并行和自动并行时使用惰性初始化。在其他情况下,您可能不需要使用惰性初始化。 - -Gemini 与惰性初始化兼容。您可以直接将它们一起使用。 - -```python -from colossalai.booster import Booster -from colossalai.booster.plugin import GeminiPlugin -from colossalai.lazy import LazyInitContext -from colossalai.nn.optimizer import HybridAdam -from torch.nn import Linear -import colossalai - -colossalai.launch_from_torch({}) - -plugin = GeminiPlugin() -booster = Booster(plugin=plugin) - -with LazyInitContext(): - model = Linear(10, 10) - -optimizer = HybridAdam(model.parameters()) -model, optimizer, *_ = booster.boost(model, optimizer) -``` - -请注意,在使用 Gemini 时使用惰性初始化不是必需的,但建议使用。如果不使用惰性初始化,在初始化模型时可能会出现 OOM 错误。如果使用惰性初始化,则可以避免此错误。 - -> ⚠ 对张量并行、流水线并行和自动并行的惰性初始化支持仍在开发中。 - -### 从预训练模型加载 - -我们不应该在 `LazyInitContext` 中加载预训练权重。如果这样,惰性初始化就没有意义,因为检查点已加载并且需要大量 GPU 内存。推荐的方法是在 `LazyInitContext` 中初始化模型,并在调用 `Booster.boost()` 后在 `LazyInitContext` 之外加载预训练权重。 - - -```python -with LazyInitContext(): - model = GPT2LMHeadModel(config) - -optimizer = ... -lr_scheduler = ... -dataloader = ... 
-model, optimizer, lr_scheduler, dataloader = booster.boost(model, optimizer, lr_scheduler, dataloader) - -booster.load_model(model, pretrained_path) -``` - - -由于 booster 同时支持 pytorch 风格的 checkpoint 和 huggingface/transformers 风格的预训练权重,上述伪代码的 `pretrained_pa​​th` 可以是 checkpoint 文件路径或预训练权重路径。请注意,它不支持从网络加载预训练权重。您应该先下载预训练的权重,然后使用本地路径。 - - diff --git a/tests/test_device/test_device_mesh.py b/tests/test_device/test_device_mesh.py index 19d41d233..3be057b3a 100644 --- a/tests/test_device/test_device_mesh.py +++ b/tests/test_device/test_device_mesh.py @@ -1,19 +1,20 @@ -import torch - from colossalai.device.device_mesh import DeviceMesh +import torch def test_device_mesh(): - physical_mesh_id = torch.arange(0, 16) + physical_mesh_id = torch.arange(0, 16).reshape(2, 8) mesh_shape = (4, 4) # [[0, 1, 2, 3], # [4, 5, 6, 7], # [8, 9, 10,11], # [12,13,14,15]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - assert device_mesh.global_rank_to_local_rank(5) == [1, 1] - assert device_mesh.global_rank_to_local_rank(11) == [2, 3] - assert device_mesh.get_ranks_in_process_group(axis=1, global_rank=2) == [0, 1, 2, 3] + assert device_mesh.convert_map[5] == [1, 1] + assert device_mesh.convert_map[11] == [2, 3] + assert device_mesh.global_rank_to_process_groups_with_logical_rank(0)[0] == [[0, 0], [1, 0], [2, 0], [3, 0]] + assert device_mesh.global_rank_to_process_groups_with_logical_rank(2)[1] == [[0, 0], [0, 1], [0, 2], [0, 3]] + assert device_mesh.global_rank_to_process_groups_with_global_rank(2)[1] == [0, 1, 2, 3] if __name__ == '__main__': diff --git a/tests/test_device/test_init_logical_pg.py b/tests/test_device/test_init_logical_pg.py index 7c6339eff..2b7060c48 100644 --- a/tests/test_device/test_init_logical_pg.py +++ b/tests/test_device/test_init_logical_pg.py @@ -20,12 +20,16 @@ def check_layer(rank, world_size, port): # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - - for axis in range(len(mesh_shape)): - tensor = torch.ones(4).cuda() - pg = device_mesh.get_process_group(axis=axis) - dist.all_reduce(tensor, op=ReduceOp.SUM, group=pg) - assert tensor.equal(tensor_to_check) + logical_pg_dict = {0: [[0, 2], [1, 3]], 1: [[0, 1], [2, 3]]} + logical_process_groups = device_mesh.process_groups_dict + + for mesh_dim, pgs in logical_pg_dict.items(): + for index, pg in enumerate(pgs): + if rank in pg: + tensor = torch.ones(4).cuda() + group = logical_process_groups[mesh_dim][index][1] + dist.all_reduce(tensor, op=ReduceOp.SUM, group=group) + assert tensor.equal(tensor_to_check) gpc.destroy() diff --git a/tests/test_lazy/lazy_init_utils.py b/tests/test_lazy/lazy_init_utils.py index 2911012fa..85bfd0e27 100644 --- a/tests/test_lazy/lazy_init_utils.py +++ b/tests/test_lazy/lazy_init_utils.py @@ -6,9 +6,7 @@ import numpy as np import torch from packaging import version -from colossalai.device.device_mesh import DeviceMesh from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor -from colossalai.tensor.d_tensor.layout import Layout from colossalai.tensor.d_tensor.layout_converter import to_global from tests.kit.model_zoo.registry import ModelAttribute @@ -83,8 +81,7 @@ def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False, print(f'{model.__class__.__name__} pass') -def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, device_mesh: DeviceMesh, - sharding_spec_dict: dict) -> None: +def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, layout_dict: 
diff --git a/tests/test_lazy/lazy_init_utils.py b/tests/test_lazy/lazy_init_utils.py
index 2911012fa..85bfd0e27 100644
--- a/tests/test_lazy/lazy_init_utils.py
+++ b/tests/test_lazy/lazy_init_utils.py
@@ -6,9 +6,7 @@ import numpy as np
 import torch
 from packaging import version
 
-from colossalai.device.device_mesh import DeviceMesh
 from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor
-from colossalai.tensor.d_tensor.layout import Layout
 from colossalai.tensor.d_tensor.layout_converter import to_global
 from tests.kit.model_zoo.registry import ModelAttribute
 
@@ -83,8 +81,7 @@ def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False,
     print(f'{model.__class__.__name__} pass')
 
 
-def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, device_mesh: DeviceMesh,
-                            sharding_spec_dict: dict) -> None:
+def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, layout_dict: dict) -> None:
     state = model.state_dict()
     distributed_state = distributed_model.state_dict()
 
@@ -94,7 +91,6 @@ def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.
         assert n1 == n2
         t1 = t1.cuda()
         t2 = t2.cuda()
-        if n2 in sharding_spec_dict:
-            layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_dict[n2], global_shape=t1.shape)
-            t2 = to_global(t2, layout)
+        if n2 in layout_dict:
+            t2 = to_global(t2, layout_dict[n2])
         assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}'
diff --git a/tests/test_lazy/test_distribute.py b/tests/test_lazy/test_distribute.py
index efa43eab5..d515b175a 100644
--- a/tests/test_lazy/test_distribute.py
+++ b/tests/test_lazy/test_distribute.py
@@ -26,19 +26,23 @@ def find_shard_dim(shape: torch.Size) -> Optional[int]:
     return dim
 
 
-def make_sharding_spec(original_tensor: torch.Tensor) -> Layout:
+def make_layout(device_mesh: DeviceMesh, original_tensor: torch.Tensor) -> Layout:
     shard_dim = find_shard_dim(original_tensor.shape)
     dim_partition_dict = {shard_dim: [0]} if shard_dim is not None else {}
     target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict=dim_partition_dict)
-    return target_sharding_spec
+    layout = Layout(device_mesh=device_mesh,
+                    device_type=torch.device('cuda'),
+                    sharding_spec=target_sharding_spec,
+                    entire_shape=original_tensor.shape)
+    return layout
 
 
 def _get_current_name(prefix: str, name: str) -> str:
     return f'{prefix}.{name}'.lstrip('.')
 
 
-def generate_sharding_spec_dict(model: nn.Module) -> dict:
-    sharding_spec_dict = {}
+def generate_layout_dict(model: nn.Module, device_mesh: DeviceMesh) -> dict:
+    layout_dict = {}
 
     @torch.no_grad()
     def generate_recursively(module: nn.Module, prefix: str = ''):
@@ -49,17 +53,17 @@ def generate_sharding_spec_dict(model: nn.Module) -> dict:
         # initialize tensors directly attached to the current module
         for name, param in module.named_parameters(recurse=False):
             if isinstance(param, LazyTensor):
-                sharding_spec = make_sharding_spec(param)
-                sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec
+                layout = make_layout(device_mesh, param)
+                layout_dict[_get_current_name(prefix, name)] = layout
 
         for name, buf in module.named_buffers(recurse=False):
             if isinstance(buf, LazyTensor):
-                sharding_spec = make_sharding_spec(buf)
-                sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec
+                layout = make_layout(device_mesh, buf)
+                layout_dict[_get_current_name(prefix, name)] = layout
 
     generate_recursively(model)
-    return sharding_spec_dict
+    return layout_dict
 
 
 @parameterize('subset', ['torchvision', 'diffusers', 'timm', 'transformers', 'torchaudio', 'deepfm', 'dlrm'])
@@ -81,9 +85,9 @@ def run_dist_lazy_init(subset, seed: int = 42):
         ctx = LazyInitContext()
         with ctx:
             deferred_model = model_fn()
-        sharding_spec_dict = generate_sharding_spec_dict(deferred_model)
-        ctx.distribute(deferred_model, device_mesh, sharding_spec_dict, verbose=True)
-        assert_dist_model_equal(model, deferred_model, device_mesh, sharding_spec_dict)
+        layout_dict = generate_layout_dict(deferred_model, device_mesh)
+        ctx.distribute(deferred_model, layout_dict, verbose=True)
+        assert_dist_model_equal(model, deferred_model, layout_dict)
 
 
 def run_dist(rank, world_size, port) -> None:
diff --git a/tests/test_tensor/test_dtensor/test_comm_spec.py b/tests/test_tensor/test_dtensor/test_comm_spec.py
index 0797e01e7..d1f5b9299 100644
--- a/tests/test_tensor/test_dtensor/test_comm_spec.py
+++ b/tests/test_tensor/test_dtensor/test_comm_spec.py
@@ -125,6 +125,23 @@ def check_all_reduce_bwd(process_groups_dict, rank):
     assert tensor_to_comm.equal(tensor_to_check)
 
 
+def check_all_reduce_in_flatten_device_mesh(process_groups_dict, rank):
+    # tensor to comm
+    tensor_to_comm = torch.ones(2, 2).cuda() * rank
+
+    # reduce through logical process axis 0 at flatten device mesh
+    # tensor to check
+    # tensor([[6., 6.],
+    #         [6., 6.]])
+    tensor_to_check = torch.tensor([[6, 6], [6, 6]], dtype=tensor_to_comm.dtype).cuda()
+
+    # CommSpec:(comm_pattern:all_reduce, logical_process_axis:[0, 1])
+    comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, process_groups_dict, logical_process_axis=0)
+    tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
+
+    assert tensor_to_comm.equal(tensor_to_check)
+
+
 def check_comm(rank, world_size, port):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
@@ -136,22 +153,24 @@ def check_comm(rank, world_size, port):
     # [[0, 1,
     #  [2, 3]]
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-
-    process_group_dict = device_mesh._process_group_dict[rank]
+    process_groups_dict = device_mesh.process_groups_dict
 
     # test all gather
-    check_all_gather(process_group_dict, rank)
+    check_all_gather(process_groups_dict, rank)
 
     # test shard
-    check_shard(process_group_dict, rank)
+    check_shard(process_groups_dict, rank)
 
     # test all to all
-    check_all_to_all(process_group_dict, rank)
+    check_all_to_all(process_groups_dict, rank)
 
     # test all reduce
-    check_all_reduce_fwd(process_group_dict, rank)
-    check_all_reduce_bwd(process_group_dict, rank)
+    check_all_reduce_fwd(process_groups_dict, rank)
+    check_all_reduce_bwd(process_groups_dict, rank)
 
+    flatten_process_groups_dict = device_mesh.flatten_device_mesh.process_groups_dict
+    # test all reduce in 1D flatten device mesh
+    check_all_reduce_in_flatten_device_mesh(flatten_process_groups_dict, rank)
     gpc.destroy()
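For reference, the expected tensor in `check_all_reduce_in_flatten_device_mesh` is simply the sum of every rank's contribution: with four ranks each holding `rank * ones(2, 2)`, an all-reduce yields 0 + 1 + 2 + 3 = 6 in every entry. A CPU-only sketch of that arithmetic using plain `torch.distributed` with the `gloo` backend (standalone; it does not use `CommSpec`, and the port number is an arbitrary assumption):

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank: int, world_size: int):
    # any free port works; 29501 is just a placeholder
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29501'
    dist.init_process_group('gloo', rank=rank, world_size=world_size)

    # each rank contributes rank * ones(2, 2); the all-reduce sums them
    tensor = torch.ones(2, 2) * rank
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)

    # 0 + 1 + 2 + 3 = 6, matching tensor_to_check in the test above
    assert tensor.equal(torch.full((2, 2), 6.0))
    dist.destroy_process_group()


if __name__ == '__main__':
    mp.spawn(run, args=(4,), nprocs=4)
```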
diff --git a/tests/test_tensor/test_dtensor/test_dtensor.py b/tests/test_tensor/test_dtensor/test_dtensor.py
index 50a3bfb15..3ca369acb 100644
--- a/tests/test_tensor/test_dtensor/test_dtensor.py
+++ b/tests/test_tensor/test_dtensor/test_dtensor.py
@@ -31,9 +31,13 @@ def check_dtensor(rank, world_size, port):
     device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
     target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]})
-    d_tensor = DTensor(original_tensor, device_mesh, target_sharding_spec)
+    layout = Layout(device_mesh=device_mesh,
+                    device_type=torch.device('cuda'),
+                    sharding_spec=target_sharding_spec,
+                    entire_shape=original_tensor.shape)
+    d_tensor = DTensor(original_tensor, layout)
 
-    assert d_tensor.global_shape == original_tensor.shape
+    assert d_tensor.entire_shape == original_tensor.shape
     assert d_tensor.data_type == original_tensor.dtype
 
     if rank in (0, 1):
@@ -53,7 +57,12 @@ def check_dtensor(rank, world_size, port):
         raise ValueError(f'rank {rank} is not in the device mesh')
 
     new_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0, 1]})
-    d_tensor.layout_convert(device_mesh, new_sharding_spec)
+    new_layout = Layout(device_mesh=device_mesh,
+                        device_type=torch.device('cuda'),
+                        sharding_spec=new_sharding_spec,
+                        entire_shape=original_tensor.shape)
+
+    d_tensor.layout_convert(new_layout)
 
     if rank == 0:
         assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 0, 1))
@@ -66,7 +75,7 @@ def check_dtensor(rank, world_size, port):
     else:
         raise ValueError(f'rank {rank} is not in the device mesh')
 
-    dtensor_from_local = distribute_tensor(original_tensor, device_mesh, new_sharding_spec)
+    dtensor_from_local = distribute_tensor(original_tensor, new_layout)
 
     if rank == 0:
         assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 0, 1))
diff --git a/tests/test_tensor/test_dtensor/test_layout_converter.py b/tests/test_tensor/test_dtensor/test_layout_converter.py
index 6608e4787..5f56decb5 100644
--- a/tests/test_tensor/test_dtensor/test_layout_converter.py
+++ b/tests/test_tensor/test_dtensor/test_layout_converter.py
@@ -12,9 +12,9 @@ from colossalai.tensor.d_tensor.layout_converter import LayoutConverter
 from colossalai.tensor.d_tensor.sharding_spec import DimSpec, ShardingSpec
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 
-global_shape = torch.Size((64, 32, 16))
+entire_shape = torch.Size((64, 32, 16))
 layout_converter = LayoutConverter()
-physical_mesh_id = torch.arange(0, 4)
+physical_mesh_id = torch.arange(0, 4).reshape(2, 2)
 mesh_shape = (2, 2)
 
@@ -30,7 +30,10 @@ def check_one_step_transform(rank, world_size, port):
     # shard_sequence: S0,S1,R
     # device_mesh_shape: (2, 2)
     sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
-    layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)
+    layout = Layout(device_mesh=device_mesh,
+                    device_type=torch.device('cuda'),
+                    sharding_spec=sharding_spec,
+                    entire_shape=entire_shape)
 
     rst_dict = layout_converter.all_gather_transform_layouts(layout)
 
@@ -46,7 +49,10 @@ def check_one_step_transform(rank, world_size, port):
     # shard_sequence: S0,S1,R
     # device_mesh_shape: (4, 4)
     sharding_spec_all2all = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict_all2all)
-    layout_all2all = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_all2all, global_shape=global_shape)
+    layout_all2all = Layout(device_mesh=device_mesh,
+                            device_type=torch.device('cuda'),
+                            sharding_spec=sharding_spec_all2all,
+                            entire_shape=entire_shape)
 
     rst_dict_all2all = layout_converter.all_to_all_transform_layout(layout_all2all)
 
@@ -65,7 +71,10 @@ def check_one_step_transform(rank, world_size, port):
     # shard_sequence: S0,R,R
     # device_mesh_shape: (4, 4)
     sharding_spec_shard = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_shard)
-    shard_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_shard, global_shape=global_shape)
+    shard_layout = Layout(device_mesh=device_mesh,
+                          device_type=torch.device('cuda'),
+                          sharding_spec=sharding_spec_shard,
+                          entire_shape=entire_shape)
 
     rst_dict_shard = layout_converter.shard_transform_layout(shard_layout)
 
@@ -91,13 +100,19 @@ def check_layout_converting(rank, world_size, port):
     # shard_sequence: R,S01,R
     # device_mesh_shape: (4, 4)
     sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
-    source_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_source, global_shape=global_shape)
+    source_layout = Layout(device_mesh=device_mesh,
+                           device_type=torch.device('cuda'),
+                           sharding_spec=sharding_spec_source,
+                           entire_shape=entire_shape)
 
     # DistSpec:
     #     shard_sequence: S01,R,R
     #     device_mesh_shape: (4, 4)
     sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
-    target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_target, global_shape=global_shape)
+    target_layout = Layout(device_mesh=device_mesh,
+                           device_type=torch.device('cuda'),
+                           sharding_spec=sharding_spec_target,
+                           entire_shape=entire_shape)
 
     transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout)
 
@@ -144,15 +159,21 @@ def check_layout_converting_apply(rank, world_size, port):
     # shard_sequence: R,S01,R
     # device_mesh_shape: (4, 4)
     sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
-    source_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_source, global_shape=global_shape)
+    source_layout = Layout(device_mesh=device_mesh,
+                           device_type=torch.device('cuda'),
+                           sharding_spec=sharding_spec_source,
+                           entire_shape=entire_shape)
 
     # DistSpec:
     #     shard_sequence: S01,R,R
     #     device_mesh_shape: (4, 4)
     sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
-    target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_target, global_shape=global_shape)
+    target_layout = Layout(device_mesh=device_mesh,
+                           device_type=torch.device('cuda'),
+                           sharding_spec=sharding_spec_target,
+                           entire_shape=entire_shape)
 
-    original_tensor = torch.rand(global_shape).cuda()
+    original_tensor = torch.rand(entire_shape).cuda()
 
     # tensor_to_apply: [R, S01, R]
     tensor_to_apply = original_tensor.narrow(1, rank * 8, 8)
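The shard-sequence comments above also explain the `narrow(1, rank * 8, 8)` call: under `[R, S01, R]` on the (2, 2) mesh, dimension 1 of the (64, 32, 16) tensor is split across both mesh axes, i.e. four devices, so each local shard keeps 32 / 4 = 8 elements. A small standalone sketch of that shape arithmetic (plain Python; `local_shard_shape` is an illustrative helper, not the `ShardingSpec` API):

```python
import math

import torch


def local_shard_shape(global_shape, dim_partition_dict, mesh_shape):
    """Illustrative only: shrink each sharded dim by the product of its mesh-axis sizes."""
    shape = list(global_shape)
    for dim, mesh_axes in dim_partition_dict.items():
        num_shards = math.prod(mesh_shape[axis] for axis in mesh_axes)
        assert shape[dim] % num_shards == 0
        shape[dim] //= num_shards
    return torch.Size(shape)


entire_shape = torch.Size((64, 32, 16))
# [R, S01, R]: dim 1 is sharded over both axes of the (2, 2) mesh -> 4 shards of size 8
assert local_shard_shape(entire_shape, {1: [0, 1]}, (2, 2)) == torch.Size((64, 8, 16))
# [S0, S1, R]: dim 0 over axis 0, dim 1 over axis 1 -> local shape (32, 16, 16)
assert local_shard_shape(entire_shape, {0: [0], 1: [1]}, (2, 2)) == torch.Size((32, 16, 16))
```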
diff --git a/tests/test_tensor/test_shape_consistency.py b/tests/test_tensor/test_shape_consistency.py
index 859eef051..6fe9ee292 100644
--- a/tests/test_tensor/test_shape_consistency.py
+++ b/tests/test_tensor/test_shape_consistency.py
@@ -1,10 +1,9 @@
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager, CollectiveCommPattern
 import torch
-
+from colossalai.tensor.sharding_spec import _DimSpec, ShardingSpec
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.shape_consistency import CollectiveCommPattern, ShapeConsistencyManager
-from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
 
-physical_mesh_id = torch.arange(0, 16)
+physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
 mesh_shape = (4, 4)
 # [[0, 1, 2, 3],
 #  [4, 5, 6, 7],
diff --git a/tests/test_tensor/test_sharded_linear.py b/tests/test_tensor/test_sharded_linear.py
index 9bd9805e9..d66d4fec1 100644
--- a/tests/test_tensor/test_sharded_linear.py
+++ b/tests/test_tensor/test_sharded_linear.py
@@ -26,7 +26,7 @@ def run_dist(rank, world_size, port):
     # the mesh is in the following topo
     # [[0, 1],
     #  [2, 3]]
-    physical_mesh_id = torch.arange(0, 4)
+    physical_mesh_id = torch.arange(0, 4).reshape(2, 2)
     mesh_shape = (2, 2)
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
     row_id = rank // 2
diff --git a/tests/test_tensor/test_sharding_spec.py b/tests/test_tensor/test_sharding_spec.py
index 5007c4141..909c84ef0 100644
--- a/tests/test_tensor/test_sharding_spec.py
+++ b/tests/test_tensor/test_sharding_spec.py
@@ -5,7 +5,7 @@ from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
 
 
 def test_sharding_spec():
-    physical_mesh_id = torch.arange(0, 16)
+    physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
     mesh_shape = (4, 4)
     # [[0, 1, 2, 3],
     #  [4, 5, 6, 7],