mirror of https://github.com/hpcaitech/ColossalAI
[dtensor] updated api and doc (#3845)
parent d51e83d642
commit eb39154d40
@ -0,0 +1,73 @@
# 🗄 Device

## 📚 Table of Contents

- [🗄 Device](#-device)
- [📚 Table of Contents](#-table-of-contents)
- [🔗 Introduction](#-introduction)
- [📝 Design](#-design)
- [🔨 Usage](#-usage)

## 🔗 Introduction

This module contains the implementation of the device topology abstraction. It is used to represent the device topology and to manage the distributed information related to the network.

## 📝 Design

This module is inspired by the DeviceMesh in the [Alpa project](https://github.com/alpa-projects/alpa), and the device array can be represented as a 1D or 2D mesh. We will extend the device mesh to support 3D meshes in the future.

## 🔨 Usage

- Create a device mesh

```python
# this is the list of global ranks involved in the device mesh
# assume we have 4 GPUs and the global ranks for these GPUs are 0, 1, 2, 3
physical_mesh_id = torch.arange(4)
mesh_shape = [2, 2]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
```

- View the mesh

```python
# view the mesh shape
# expected output
# [2, 2]
print(device_mesh.shape)

# view the logical mesh with global ranks
# expected output
# [
#   [0, 1],
#   [2, 3]
# ]
print(device_mesh.logical_mesh_id)

# view the number of devices in the mesh
# expected output
# 4
print(device_mesh.num_devices)
```

- Initialize the process group

```python
# initialize the process group
device_mesh.init_logical_process_group()

# get the process group for a rank with respect to an axis
# this is the process group involving global ranks 0 and 2
print(device_mesh.get_process_group(axis=0, global_rank=0))

# get the ranks in the process group with respect to an axis
# expected output
# [0, 2]
print(device_mesh.get_ranks_in_process_group(axis=0, global_rank=0))
```
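
- Query local ranks and process groups

The snippet below is a minimal sketch based on the `DeviceMesh` query APIs introduced in this PR; it continues from the 2 x 2 mesh created above and assumes the logical process groups have been initialized.

```python
# map a global rank to its local ranks on each mesh axis
# for the 2 x 2 mesh above, global rank 3 sits at local ranks [1, 1]
print(device_mesh.global_rank_to_local_rank(3))

# query the local rank on a single axis
# expected output
# 1
print(device_mesh.global_rank_to_local_rank(3, axis=0))

# get the process groups of the current process for all axes,
# returned as a dict mapping the axis index to a ProcessGroup
print(device_mesh.get_process_group_for_all_axes())
```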
@ -3,11 +3,19 @@
|
|||
with some changes. """
|
||||
|
||||
import operator
|
||||
from dataclasses import dataclass
|
||||
from functools import reduce
|
||||
from typing import List, Tuple
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProcessGroupContainer:
|
||||
process_group: ProcessGroup
|
||||
ranks: List[int]
|
||||
|
||||
|
||||
# modified from alpa LogicalDeviceMesh(https://github.com/alpa-projects/alpa/blob/main/alpa/shard_parallel/auto_sharding.py)
|
||||
|
@ -27,9 +35,11 @@ class DeviceMesh:
|
|||
during initializing the DeviceMesh instance if the init_process_group set to True.
|
||||
Otherwise, users need to call create_process_groups_for_logical_mesh manually to init logical process group.
|
||||
(default: False)
|
||||
need_flatten(bool, optional): initialize flatten_device_mesh during initializing the DeviceMesh instance if the need_flatten set to True.
|
||||
device (str): the device for the process groups used by the DeviceMesh instance. (default: 'cuda')
|
||||
"""
|
||||
|
||||
_DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo"}
|
||||
|
||||
def __init__(self,
|
||||
physical_mesh_id: torch.Tensor,
|
||||
mesh_shape: torch.Size = None,
|
||||
|
@ -37,48 +47,140 @@ class DeviceMesh:
|
|||
mesh_alpha: List[float] = None,
|
||||
mesh_beta: List[float] = None,
|
||||
init_process_group: bool = False,
|
||||
need_flatten: bool = True):
|
||||
self.physical_mesh_id = physical_mesh_id
|
||||
device: str = 'cuda'):
|
||||
# ============================
|
||||
# Physical & Logical Mesh IDs
|
||||
# ============================
|
||||
self._physical_mesh_id = physical_mesh_id
|
||||
assert physical_mesh_id.dim() == 1, "physical_mesh_id should be a 1D tensor."
|
||||
|
||||
# logical mesh ids can be obtained via two ways
|
||||
# 1. provide physical mesh id and provide mesh shape
|
||||
# 2. directly supply the logical mesh id
|
||||
assert mesh_shape is None or logical_mesh_id is None, \
|
||||
"Only one of mesh_shape and logical_mesh_id can be specified." \
|
||||
"Logical mesh IDs are obtained from either mesh_shape + phyiscal_mesh_id or directly from the user-supplied logical_mesh_id"
|
||||
|
||||
if logical_mesh_id is None:
|
||||
self.mesh_shape = mesh_shape
|
||||
self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape)
|
||||
self._logical_mesh_id = self._physical_mesh_id.reshape(self.mesh_shape)
|
||||
else:
|
||||
self._logical_mesh_id = logical_mesh_id
|
||||
self.mesh_shape = self._logical_mesh_id.shape
|
||||
|
||||
# map global rank into logical rank
|
||||
self.convert_map = {}
|
||||
self._global_rank_to_logical_rank_map(self._logical_mesh_id, [])
|
||||
# ensure two things:
|
||||
# 1. logical and physical mesh IDs should contain the same elements
|
||||
# 2. there is no duplicate IDs in each mesh, e.g. [2, 2] is not allowed
|
||||
assert torch.equal(torch.unique(self._physical_mesh_id), torch.unique(self.logical_mesh_id)), \
|
||||
"physical and logical mesh IDs should contain the same elements, please check if you have consistent physical_mesh_id and logical_mesh_id."
|
||||
assert torch.unique(self._physical_mesh_id).numel() == self._physical_mesh_id.numel(), \
|
||||
"Found duplicate IDs in the phyiscal_mesh_id and this is not allowed, please check your physical_mesh_id again."
|
||||
assert torch.unique(self.logical_mesh_id).numel() == self.logical_mesh_id.numel(), \
|
||||
"Found duplicate IDs in the logical_mesh_id and this is not allowed, please check your logical_mesh_id again."
|
||||
|
||||
# ===============================================
|
||||
# coefficient for alpha-beta communication model
|
||||
# alpha is latency and beta is bandwidth
|
||||
# ===============================================
|
||||
# if the values are not provided, we assume they are 1 for simplicity
|
||||
if mesh_alpha is None:
|
||||
mesh_alpha = [1] * len(self.mesh_shape)
|
||||
if mesh_beta is None:
|
||||
mesh_beta = [1] * len(self.mesh_shape)
|
||||
|
||||
self.mesh_alpha = tuple(mesh_alpha)
|
||||
self.mesh_beta = tuple(mesh_beta)
|
||||
self.init_process_group = init_process_group
|
||||
self.need_flatten = need_flatten
|
||||
if self.init_process_group:
|
||||
self.process_groups_dict = self.create_process_groups_for_logical_mesh()
|
||||
if self.need_flatten and self._logical_mesh_id.dim() > 1:
|
||||
self.flatten_device_mesh = self.flatten()
|
||||
# Create a new member `flatten_device_meshes` to distinguish from original flatten methods (Because I'm not sure if there are functions that rely on the self.flatten())
|
||||
# self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha,
|
||||
# self.mesh_beta)
|
||||
|
||||
# ensure the alpha and beta have the same shape
|
||||
assert len(self.mesh_alpha) == len(self.mesh_beta), \
|
||||
"mesh_alpha and mesh_beta should have the same length, please check your mesh_alpha and mesh_beta again."
|
||||
|
||||
# =========================
|
||||
# Device for Process Group
|
||||
# =========================
|
||||
self._device = device
|
||||
self._dist_backend = self._DIST_BACKEND[device]
|
||||
|
||||
# =========================
|
||||
# Process Group Management
|
||||
# =========================
|
||||
# the _global_to_local_rank_mapping is structured as follows
|
||||
# {
|
||||
# <global-rank>: [ <local-rank-on-axis-0>, <local-rank-on-axis-1>, <local-rank-on-axis-2>, ...]
|
||||
# }
|
||||
self._global_to_local_rank_mapping = dict()
|
||||
self._init_global_to_logical_rank_mapping(mapping=self._global_to_local_rank_mapping,
|
||||
tensor=self.logical_mesh_id)
|
||||
|
||||
# create process group
|
||||
self._process_group_dict = {}
|
||||
self._ranks_in_the_process_group = {}
|
||||
self._global_rank_of_current_process = None
|
||||
self._is_initialized = False
|
||||
|
||||
# initialize process group if specified
|
||||
self._init_ranks_in_the_same_group()
|
||||
self._init_process_group = init_process_group
|
||||
if init_process_group:
|
||||
self.init_logical_process_group()
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
def shape(self) -> torch.Size:
|
||||
"""
|
||||
Return the shape of the logical mesh.
|
||||
"""
|
||||
return self.mesh_shape
|
||||
|
||||
@property
|
||||
def num_devices(self):
|
||||
return reduce(operator.mul, self.physical_mesh_id.shape, 1)
|
||||
def num_devices(self) -> int:
|
||||
"""
|
||||
Return the number of devices contained in the device mesh.
|
||||
"""
|
||||
return reduce(operator.mul, self._physical_mesh_id.shape, 1)
|
||||
|
||||
@property
|
||||
def logical_mesh_id(self):
|
||||
def logical_mesh_id(self) -> torch.Tensor:
|
||||
"""
|
||||
Return the logical mesh id.
|
||||
"""
|
||||
return self._logical_mesh_id
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
def get_process_group(self, axis: int, global_rank: int = None) -> ProcessGroup:
|
||||
"""
|
||||
Return the process group on the specified axis.
|
||||
|
||||
Args:
|
||||
axis (int): the axis of the process group.
|
||||
global_rank (int, optional): the global rank of the process group. If not specified, the current process is used. (default: None)
|
||||
"""
|
||||
if global_rank is None:
|
||||
global_rank = self._global_rank_of_current_process
|
||||
return self._process_group_dict[global_rank][axis]
|
||||
|
||||
def get_process_group_for_all_axes(self, global_rank: int = None) -> Dict[int, ProcessGroup]:
|
||||
"""
|
||||
Return the process groups for all axes.
|
||||
|
||||
Args:
|
||||
global_rank (int, optional): the global rank of the process
|
||||
"""
|
||||
if global_rank is None:
|
||||
global_rank = self._global_rank_of_current_process
|
||||
return self._process_group_dict[global_rank]
|
||||
|
||||
def get_ranks_in_process_group(self, axis: int, global_rank: int = None) -> List[int]:
|
||||
"""
|
||||
Return the ranks in the process group on the specified axis.
|
||||
|
||||
Args:
|
||||
axis (int): the axis of the process group.
|
||||
global_rank (int, optional): the global rank of the process
|
||||
"""
|
||||
if global_rank is None:
|
||||
global_rank = self._global_rank_of_current_process
|
||||
return self._ranks_in_the_process_group[global_rank][axis]
|
||||
|
||||
def __deepcopy__(self, memo) -> "DeviceMesh":
|
||||
cls = self.__class__
|
||||
result = cls.__new__(cls)
|
||||
memo[id(self)] = result
|
||||
|
@ -86,111 +188,206 @@ class DeviceMesh:
|
|||
if k != 'process_groups_dict':
|
||||
setattr(result, k, __import__("copy").deepcopy(v, memo))
|
||||
else:
|
||||
# process group cannot be copied
|
||||
# thus, we share them directly
|
||||
setattr(result, k, v)
|
||||
|
||||
return result
|
||||
|
||||
def _init_global_to_logical_rank_mapping(self,
|
||||
mapping: Dict,
|
||||
tensor: torch.Tensor,
|
||||
index_list: List[int] = []) -> Dict[int, List[int]]:
|
||||
"""
|
||||
Build a mapping from the global rank to the local ranks on each axis of the logical device mesh.
|
||||
|
||||
Args:
|
||||
mapping (Dict): a dictionary that maps the global rank to the local rank in the logical device mesh.
|
||||
tensor (torch.Tensor): the tensor that contains the logical mesh ids.
|
||||
index_list (List[int]): the list of local ranks accumulated along the axes visited so far during the recursion.
|
||||
|
||||
Returns:
|
||||
mapping (Dict): a dictionary that maps the global rank to the local rank in the logical device mesh.
|
||||
The value is a list of integers and each integer represents the local rank in the indexed axis.
|
||||
"""
|
||||
for index, inner_tensor in enumerate(tensor):
|
||||
# index means the local rank in the current axis
|
||||
# inner_tensor refers to the processes with the same local rank
|
||||
|
||||
if inner_tensor.numel() == 1:
|
||||
# if the inner_tensor only has one element, it means that
|
||||
# it already reaches the last axis
|
||||
# we append its local_rank in the last axis to the index_list
|
||||
# and assign to the mapping
|
||||
# the value of the mapping is the local rank at the indexed axis of the device mesh
|
||||
mapping[int(inner_tensor)] = index_list + [index]
|
||||
else:
|
||||
# we recursively go into the function until we reach the last axis
|
||||
# meanwhile, we should add the local rank in the current axis in the index_list
|
||||
self._init_global_to_logical_rank_mapping(mapping, inner_tensor, index_list + [index])
|
||||
|
||||
def init_logical_process_group(self):
|
||||
'''
|
||||
This method is used to initialize the logical process groups which will be used in communications
|
||||
among logical device mesh.
|
||||
Note: if init_process_group is set to False, you have to call this method manually. Otherwise,
|
||||
communication-related functions, such as ShapeConsistencyManager.apply, will raise errors.
|
||||
'''
|
||||
# sanity check
|
||||
assert dist.is_initialized(), "torch.distributed should be initialized before calling init_logical_process_group"
|
||||
assert not self._is_initialized, "The logical process group has been initialized, do not call init_logical_process_group twice"
|
||||
|
||||
# update the global rank of the current process
|
||||
self._global_rank_of_current_process = dist.get_rank()
|
||||
duplicate_check_list = []
|
||||
|
||||
# flatten the global ranks to 1D list
|
||||
global_rank_flatten_list = self._physical_mesh_id.view(-1).tolist()
|
||||
|
||||
for global_rank in global_rank_flatten_list:
|
||||
# find the other ranks which are in the same process group as global_rank
|
||||
ranks_in_same_group_by_axis = self._collate_global_ranks_in_same_process_group(global_rank)
|
||||
|
||||
for axis, ranks_in_same_group in ranks_in_same_group_by_axis.items():
|
||||
# skip duplicated process group creation
|
||||
if ranks_in_same_group in duplicate_check_list:
|
||||
continue
|
||||
|
||||
# create the process group
|
||||
pg_handler = dist.new_group(ranks=ranks_in_same_group, backend=self._dist_backend)
|
||||
|
||||
# keep this process group in the process_groups_dict
|
||||
for rank in ranks_in_same_group:
|
||||
if rank not in self._process_group_dict:
|
||||
self._process_group_dict[rank] = dict()
|
||||
self._process_group_dict[rank][axis] = pg_handler
|
||||
|
||||
# update the init flag
|
||||
# we only allow init for once
|
||||
self._is_initialized = True
|
||||
|
||||
def _init_ranks_in_the_same_group(self):
|
||||
"""
|
||||
This method is used to initialize the _ranks_in_the_process_group dictionary.
|
||||
"""
|
||||
# flatten the global ranks to 1D list
|
||||
global_rank_flatten_list = self._physical_mesh_id.view(-1).tolist()
|
||||
|
||||
for global_rank in global_rank_flatten_list:
|
||||
# find the other ranks which are in the same process group as global_rank
|
||||
ranks_in_same_group_by_axis = self._collate_global_ranks_in_same_process_group(global_rank)
|
||||
|
||||
for axis, ranks_in_same_group in ranks_in_same_group_by_axis.items():
|
||||
# create dict for each rank
|
||||
if global_rank not in self._ranks_in_the_process_group:
|
||||
self._ranks_in_the_process_group[global_rank] = dict()
|
||||
|
||||
# keep this process group in the process_groups_dict
|
||||
self._ranks_in_the_process_group[global_rank][axis] = ranks_in_same_group
|
||||
|
||||
def global_rank_to_local_rank(self, rank: int, axis: int = None) -> Union[List[int], int]:
|
||||
"""
|
||||
Return the local rank of the given global rank in the logical device mesh.
|
||||
|
||||
Args:
|
||||
rank (int): the global rank in the logical device mesh.
|
||||
axis (int): the axis of the logical device mesh.
|
||||
"""
|
||||
local_ranks = self._global_to_local_rank_mapping[rank]
|
||||
if axis is not None:
|
||||
return local_ranks[axis]
|
||||
else:
|
||||
return local_ranks
|
||||
|
||||
def _collate_global_ranks_in_same_process_group(self, global_rank):
|
||||
'''
|
||||
Given a global rank, return all global ranks involved in its associated process group along each axis.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
physical_mesh_id = torch.arange(0, 16)
|
||||
mesh_shape = (4, 4)
|
||||
|
||||
# logical mesh will look like
|
||||
# [[0, 1, 2, 3],
|
||||
# [4, 5, 6, 7],
|
||||
# [8, 9, 10,11],
|
||||
# [12,13,14,15]]
|
||||
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
print(device_mesh._collate_global_ranks_in_same_process_group(0))
|
||||
|
||||
# key is axis name
|
||||
# value is a list of global ranks in same axis with rank 0
|
||||
# output will look like
|
||||
# {
|
||||
#     0: [0, 4, 8, 12],
|
||||
#     1: [0, 1, 2, 3]
|
||||
# }
|
||||
'''
|
||||
# We have initialized the global rank to local rank mapping by calling _init_global_to_logical_rank_mapping
|
||||
# for self._global_to_local_rank_mapping
|
||||
# the key is the global rank
|
||||
# the value is the list of local ranks corresponding to the global rank with respect of different axes
|
||||
# we can see the list of local ranks as the process coordinates for simplicity
|
||||
# the key and value are all unique, therefore,
|
||||
# we can also use the coordinates to find the global rank
|
||||
|
||||
# =========================================================================
|
||||
# Step 1
|
||||
# find all the process_coordinates for processes in the same process group
|
||||
# as the given global rank
|
||||
# =========================================================================
|
||||
|
||||
# each key is an axis and each value is the list of process coordinates of the processes in the same process group along that axis
|
||||
processes_in_the_same_process_group = {}
|
||||
|
||||
for dim in range(self.logical_mesh_id.dim()):
|
||||
# iterate over the dimension size so that we can include all processes
|
||||
# in the same process group in the given axis
|
||||
# the _local_rank refers to the local rank of the current process
|
||||
for _local_rank in range(self.logical_mesh_id.shape[dim]):
|
||||
|
||||
# if this dimension is not initialized yet,
|
||||
# initialize it with an empty array
|
||||
if dim not in processes_in_the_same_process_group:
|
||||
processes_in_the_same_process_group[dim] = []
|
||||
|
||||
# get the local rank corresponding to the global rank
|
||||
process_coordinates = self._global_to_local_rank_mapping[global_rank].copy()
|
||||
|
||||
# replace the local rank in the given dimension with the
|
||||
# local rank of the current process iterated
|
||||
process_coordinates[dim] = _local_rank
|
||||
processes_in_the_same_process_group[dim].append(process_coordinates)
|
||||
|
||||
# =================================================================
|
||||
# Step 2
|
||||
# Use local rank combination to find its corresponding global rank
|
||||
# =================================================================
|
||||
# the key of the dict is the axis
|
||||
# the value is the list of global ranks which are in the same process group as the given global rank
|
||||
global_pg_ranks = {}
|
||||
for dim, coordinates_of_all_processes in processes_in_the_same_process_group.items():
|
||||
global_pg_ranks[dim] = []
|
||||
for process_coordinates in coordinates_of_all_processes:
|
||||
# find the global rank by local rank combination
|
||||
for _global_rank, _process_coordinates in self._global_to_local_rank_mapping.items():
|
||||
if process_coordinates == _process_coordinates:
|
||||
global_pg_ranks[dim].append(_global_rank)
|
||||
return global_pg_ranks
|
||||
|
||||
def flatten(self):
|
||||
"""
|
||||
Flatten the logical mesh into an effective 1D logical mesh.
|
||||
"""
|
||||
flatten_mesh_shape_size = len(self.mesh_shape)
|
||||
flatten_mesh_shape = [self.num_devices]
|
||||
return DeviceMesh(self.physical_mesh_id,
|
||||
return DeviceMesh(self._physical_mesh_id,
|
||||
tuple(flatten_mesh_shape),
|
||||
mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1),
|
||||
mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
|
||||
init_process_group=self.init_process_group,
|
||||
need_flatten=False)
|
||||
|
||||
def _global_rank_to_logical_rank_map(self, tensor, index_list):
|
||||
'''
|
||||
This method is a helper function to build convert_map recursively.
|
||||
'''
|
||||
for index, inner_tensor in enumerate(tensor):
|
||||
if inner_tensor.numel() == 1:
|
||||
self.convert_map[int(inner_tensor)] = index_list + [index]
|
||||
else:
|
||||
self._global_rank_to_logical_rank_map(inner_tensor, index_list + [index])
|
||||
|
||||
def create_process_groups_for_logical_mesh(self):
|
||||
'''
|
||||
This method is used to initialize the logical process groups which will be used in communications
|
||||
among logical device mesh.
|
||||
Note: if init_process_group set to False, you have to call this method manually. Otherwise,
|
||||
the communication related function, such as ShapeConsistencyManager.apply will raise errors.
|
||||
'''
|
||||
process_groups_dict = {}
|
||||
check_duplicate_list = []
|
||||
global_rank_flatten_list = self.physical_mesh_id.view(-1).tolist()
|
||||
for global_rank in global_rank_flatten_list:
|
||||
process_groups = self.global_rank_to_process_groups_with_global_rank(global_rank)
|
||||
for axis, process_group in process_groups.items():
|
||||
if axis not in process_groups_dict:
|
||||
process_groups_dict[axis] = []
|
||||
if process_group not in check_duplicate_list:
|
||||
check_duplicate_list.append(process_group)
|
||||
process_group_handler = dist.new_group(process_group)
|
||||
process_groups_dict[axis].append((process_group, process_group_handler))
|
||||
|
||||
return process_groups_dict
|
||||
|
||||
def global_rank_to_logical_rank(self, rank):
|
||||
return self.convert_map[rank]
|
||||
|
||||
def global_rank_to_process_groups_with_logical_rank(self, rank):
|
||||
'''
|
||||
Give a global rank and return all logical process groups of this rank.
|
||||
for example:
|
||||
physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
|
||||
mesh_shape = (4, 4)
|
||||
# [[0, 1, 2, 3],
|
||||
# [4, 5, 6, 7],
|
||||
# [8, 9, 10,11],
|
||||
# [12,13,14,15]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
print(device_mesh.global_rank_to_process_groups_with_logical_rank(0))
|
||||
output:
|
||||
# key is axis name
|
||||
# value is a list of logical ranks in same axis with rank 0
|
||||
{0: [[0, 0], [1, 0], [2, 0], [3, 0]], 1: [[0, 0], [0, 1], [0, 2], [0, 3]]}
|
||||
'''
|
||||
process_groups = {}
|
||||
for d in range(self.logical_mesh_id.dim()):
|
||||
for replacer in range(self.logical_mesh_id.shape[d]):
|
||||
if d not in process_groups:
|
||||
process_groups[d] = []
|
||||
process_group_member = self.convert_map[rank].copy()
|
||||
process_group_member[d] = replacer
|
||||
process_groups[d].append(process_group_member)
|
||||
return process_groups
|
||||
|
||||
def global_rank_to_process_groups_with_global_rank(self, rank):
|
||||
'''
|
||||
Give a global rank and return all process groups of this rank.
|
||||
for example:
|
||||
physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
|
||||
mesh_shape = (4, 4)
|
||||
# [[0, 1, 2, 3],
|
||||
# [4, 5, 6, 7],
|
||||
# [8, 9, 10,11],
|
||||
# [12,13,14,15]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
print(device_mesh.global_rank_to_process_groups_with_global_rank(0))
|
||||
output:
|
||||
# key is axis name
|
||||
# value is a list of global ranks in same axis with rank 0
|
||||
{0: [0, 4, 8, 12], 1: [0, 1, 2, 3]}
|
||||
'''
|
||||
logical_process_groups = self.global_rank_to_process_groups_with_logical_rank(rank)
|
||||
process_groups = {}
|
||||
for dim, logical_ranks in logical_process_groups.items():
|
||||
process_groups[dim] = []
|
||||
for logical_rank in logical_ranks:
|
||||
for g_rank, l_rank in self.convert_map.items():
|
||||
if l_rank == logical_rank:
|
||||
process_groups[dim].append(g_rank)
|
||||
return process_groups
|
||||
init_process_group=self._init_process_group)
|
||||
|
||||
def all_gather_cost(self, num_bytes, mesh_dim):
|
||||
num_devices = self.logical_mesh_id.shape[mesh_dim]
|
||||
|
@ -212,38 +409,3 @@ class DeviceMesh:
|
|||
penalty_factor = num_devices / 2.0
|
||||
return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] *
|
||||
(num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001)
|
||||
|
||||
|
||||
class FlattenDeviceMesh(DeviceMesh):
|
||||
|
||||
def __init__(self, physical_mesh_id, mesh_shape, mesh_alpha=None, mesh_beta=None):
|
||||
super().__init__(physical_mesh_id,
|
||||
mesh_shape,
|
||||
mesh_alpha,
|
||||
mesh_beta,
|
||||
init_process_group=False,
|
||||
need_flatten=False)
|
||||
# Different from flatten(), mesh_shape leaves unchanged, mesh_alpha and mesh_beta are scalars
|
||||
self.mesh_alpha = max(self.mesh_alpha)
|
||||
self.mesh_beta = min(self.mesh_beta)
|
||||
# Different from original process_groups_dict, rank_list is not stored
|
||||
self.process_number_dict = self.create_process_numbers_for_logical_mesh()
|
||||
|
||||
def create_process_numbers_for_logical_mesh(self):
|
||||
'''
|
||||
Build 1d DeviceMesh in column-major(0) and row-major(1)
|
||||
for example:
|
||||
mesh_shape = (2,4)
|
||||
# [[0, 1, 2, 3],
|
||||
# [4, 5, 6, 7]]
|
||||
# return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]}
|
||||
'''
|
||||
num_devices = reduce(operator.mul, self.mesh_shape, 1)
|
||||
process_numbers_dict = {}
|
||||
process_numbers_dict[0] = torch.arange(num_devices).reshape(self.mesh_shape).transpose(1, 0).flatten().tolist()
|
||||
process_numbers_dict[1] = torch.arange(num_devices).reshape(self.mesh_shape).flatten().tolist()
|
||||
return process_numbers_dict
|
||||
|
||||
def mix_gather_cost(self, num_bytes):
|
||||
num_devices = reduce(operator.mul, self.mesh_shape, 1)
|
||||
return (self.mesh_alpha + self.mesh_beta * (num_devices - 1) / num_devices * num_bytes + 0.1)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from types import MethodType
|
||||
from typing import Callable, Optional, Union
|
||||
from typing import Callable, Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -8,8 +8,9 @@ from torch import Tensor
|
|||
from torch.utils._pytree import tree_map
|
||||
|
||||
from colossalai._analyzer._subclasses import MetaTensor
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.tensor.d_tensor.d_tensor import DTensor
|
||||
from colossalai.tensor.d_tensor.layout import Layout
|
||||
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
|
||||
|
||||
# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
|
||||
_NORMAL_FACTORY = [
|
||||
|
@ -172,7 +173,7 @@ class LazyTensor(torch.Tensor):
|
|||
self.clean()
|
||||
return _convert_cls(self, target)
|
||||
|
||||
def distribute(self, layout: Layout) -> torch.Tensor:
|
||||
def distribute(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor:
|
||||
"""Distribute the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace), according to the layout.
|
||||
|
||||
Args:
|
||||
|
@ -183,7 +184,7 @@ class LazyTensor(torch.Tensor):
|
|||
"""
|
||||
target = self._materialize_data()
|
||||
self.clean()
|
||||
local_tensor = DTensor(target, layout).local_tensor
|
||||
local_tensor = DTensor(target, device_mesh, sharding_spec).local_tensor
|
||||
return _convert_cls(self, local_tensor)
|
||||
|
||||
def clean(self) -> None:
|
||||
|
@ -536,7 +537,10 @@ class LazyInitContext:
|
|||
return _apply_to_lazy_module(module, apply_fn, verbose)
|
||||
|
||||
@staticmethod
|
||||
def distribute(module: nn.Module, layout_dict: dict, verbose: bool = False) -> nn.Module:
|
||||
def distribute(module: nn.Module,
|
||||
device_mesh: DeviceMesh,
|
||||
sharding_spec_dict: Dict[str, ShardingSpec],
|
||||
verbose: bool = False) -> nn.Module:
|
||||
"""Distribute all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place.
|
||||
|
||||
Args:
|
||||
|
@ -546,7 +550,7 @@ class LazyInitContext:
|
|||
"""
|
||||
|
||||
def apply_fn(name: str, p: LazyTensor):
|
||||
p.distribute(layout_dict[name])
|
||||
p.distribute(device_mesh, sharding_spec_dict[name])
|
||||
|
||||
return _apply_to_lazy_module(module, apply_fn, verbose)
|
||||
|
||||
|
|
|
@ -16,69 +16,66 @@ def _all_gather(tensor, comm_spec):
|
|||
'''
|
||||
Implement all gather operation on device mesh based on information provided by comm_spec.
|
||||
'''
|
||||
process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
|
||||
for rank_list, process_group in process_groups_list:
|
||||
if dist.get_rank() in rank_list:
|
||||
tensor_list = [
|
||||
torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device)
|
||||
for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis])
|
||||
]
|
||||
# without this contiguous operation, the all gather may get some unexpected results.
|
||||
tensor = tensor.contiguous()
|
||||
dist.all_gather(tensor_list, tensor, group=process_group)
|
||||
output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
|
||||
return output
|
||||
process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
|
||||
process_group = process_groups[comm_spec.logical_process_axis]
|
||||
|
||||
tensor_list = [
|
||||
torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device)
|
||||
for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis])
|
||||
]
|
||||
# without this contiguous operation, the all gather may get some unexpected results.
|
||||
tensor = tensor.contiguous()
|
||||
dist.all_gather(tensor_list, tensor, group=process_group)
|
||||
output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
|
||||
return output
|
||||
|
||||
|
||||
def _split(tensor, comm_spec):
|
||||
'''
|
||||
Implement shard operation on device mesh based on information provided by comm_spec.
|
||||
'''
|
||||
process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
|
||||
for rank_list, _ in process_groups_list:
|
||||
if dist.get_rank() in rank_list:
|
||||
dim = comm_spec.shard_dim
|
||||
length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
|
||||
start = length * rank_list.index(dist.get_rank())
|
||||
output = torch.narrow(tensor, dim, start, length).contiguous()
|
||||
return output
|
||||
process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
|
||||
process_group = process_groups[comm_spec.logical_process_axis]
|
||||
|
||||
dim = comm_spec.shard_dim
|
||||
length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group)
|
||||
start = length * dist.get_rank(process_group)
|
||||
output = torch.narrow(tensor, dim, start, length).contiguous()
|
||||
return output
|
||||
|
||||
|
||||
def _all_to_all(tensor, comm_spec):
|
||||
'''
|
||||
Implement all to all operation on device mesh based on information provided by comm_spec.
|
||||
'''
|
||||
process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
|
||||
for rank_list, process_group in process_groups_list:
|
||||
if dist.get_rank() in rank_list:
|
||||
new_shape = list(tensor.shape)
|
||||
new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list)
|
||||
new_shape = torch.Size(new_shape)
|
||||
output_tensor_list = [
|
||||
torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
|
||||
]
|
||||
dim = comm_spec.shard_dim
|
||||
length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
|
||||
input_tensor_list = [
|
||||
torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list))
|
||||
]
|
||||
group = process_group
|
||||
dist.all_to_all(output_tensor_list, input_tensor_list, group)
|
||||
output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
|
||||
return output
|
||||
process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
|
||||
process_group = process_groups[comm_spec.logical_process_axis]
|
||||
world_size = dist.get_world_size(process_group)
|
||||
|
||||
new_shape = list(tensor.shape)
|
||||
new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // world_size
|
||||
new_shape = torch.Size(new_shape)
|
||||
output_tensor_list = [torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
|
||||
dim = comm_spec.shard_dim
|
||||
length = tensor.shape[comm_spec.shard_dim] // world_size
|
||||
input_tensor_list = [torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(world_size)]
|
||||
group = process_group
|
||||
dist.all_to_all(output_tensor_list, input_tensor_list, group)
|
||||
output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
|
||||
return output
|
||||
|
||||
|
||||
def _all_reduce(tensor, comm_spec, async_op=False):
|
||||
'''
|
||||
Implement all reduce operation on device mesh based on information provided by comm_spec.
|
||||
'''
|
||||
process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
|
||||
for rank_list, process_group in process_groups_list:
|
||||
if dist.get_rank() in rank_list:
|
||||
if not tensor.is_contiguous():
|
||||
tensor = tensor.contiguous()
|
||||
dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
|
||||
return tensor
|
||||
process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
|
||||
process_group = process_groups[comm_spec.logical_process_axis]
|
||||
|
||||
if not tensor.is_contiguous():
|
||||
tensor = tensor.contiguous()
|
||||
dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
|
||||
return tensor
|
||||
|
||||
|
||||
def _mix_gather(tensor, comm_spec):
|
||||
|
@ -414,7 +411,7 @@ class CommSpec:
|
|||
self.forward_only = forward_only
|
||||
if isinstance(self.logical_process_axis, list):
|
||||
if not mix_gather:
|
||||
self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh
|
||||
self.device_mesh = self.sharding_spec.device_mesh.flatten()
|
||||
self.logical_process_axis = 0
|
||||
else:
|
||||
self.device_meshes = self.sharding_spec.device_mesh.flatten_device_meshes
|
||||
|
|
|
@ -0,0 +1,103 @@
|
|||
# 🔢 Distributed Tensor

## 📚 Table of Contents

- [🔢 Distributed Tensor](#-distributed-tensor)
- [📚 Table of Contents](#-table-of-contents)
- [🔗 Introduction](#-introduction)
- [📝 Design](#-design)
- [🔨 Usage](#-usage)
- [🎈 Progress Log](#-progress-log)

## 🔗 Introduction

A distributed tensor is a tensor that is sharded across multiple devices. It is a wrapper around a PyTorch tensor and is used to support distributed training.
It can represent the device topology and the tensor placement over the devices in the topology, and it provides a set of APIs to manipulate the distributed tensor.

## 📝 Design

Our implementation is inspired by the work [Alpa](https://arxiv.org/abs/2201.12023), which unifies data parallelism and tensor parallelism as intra-op parallelism. It uses the notation `S` for a sharded dimension and `R` for a replicated dimension. For example, given a 2D matrix, `[S, R]` means that the tensor is sharded over its first dimension.

Each sharded dimension carries a subscript that indicates its placement over the devices. Assume we have 4 GPUs arranged in a 2 x 2 mesh, and a 2D matrix like the one below:

```text
    [1, 2, 3, 4 ]
A = [4, 5, 6, 7 ]
    [8, 9, 10, 11]
    [12, 13, 14, 15]
```

`[S0, R]` means that the first dimension is sharded over the rows of the device topology.

```text
| ------------------------------------------ |
|                    |                       |
|  [1, 2, 3, 4 ]     |  [1, 2, 3, 4 ]        |
|  [4, 5, 6, 7 ]     |  [4, 5, 6, 7 ]        |
|                    |                       |
| ------------------------------------------ |
|                    |                       |
|  [8, 9, 10, 11]    |  [8, 9, 10, 11]       |
|  [12, 13, 14, 15]  |  [12, 13, 14, 15]     |
|                    |                       |
| ------------------------------------------ |
```

`[S01, R]` means that the first dimension is sharded over both the rows and the columns of the device topology.

```text
| ------------------------------------------ |
|                    |                       |
|  [1, 2, 3, 4 ]     |  [4, 5, 6, 7 ]        |
|                    |                       |
| ------------------------------------------ |
|                    |                       |
|  [8, 9, 10, 11]    |  [12, 13, 14, 15]     |
|                    |                       |
| ------------------------------------------ |
```
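
The placements above map directly onto sharding specifications. The snippet below is a minimal sketch of this mapping, assuming the `ShardingSpec` interface shown in the usage section, where `dim_partition_dict` maps a tensor dimension to the list of device mesh axes it is sharded over.

```python
from colossalai.tensor.d_tensor import ShardingSpec

# [S0, R]: tensor dim 0 is sharded over mesh axis 0, dim 1 is replicated
spec_s0_r = ShardingSpec(dim_size=2, dim_partition_dict={0: [0]})

# [S01, R]: tensor dim 0 is sharded over both mesh axes 0 and 1
spec_s01_r = ShardingSpec(dim_size=2, dim_partition_dict={0: [0, 1]})

# [R, R]: a fully replicated tensor uses an empty dim_partition_dict
spec_r_r = ShardingSpec(dim_size=2, dim_partition_dict={})
```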

## 🔨 Usage

A sample API usage is given below.

```python
import torch

import colossalai
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor import DTensor, ShardingSpec

colossalai.launch_from_torch(config={})

# define your device mesh
# assume you have 4 GPUs
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)

# define a tensor
a = torch.rand(16, 32).cuda()

# create a sharding spec for the tensor
# assume the sharding spec is [S0, R]
dim_partition_dict = {0: [0]}
sharding_spec = ShardingSpec(a.dim(), dim_partition_dict)

# create a distributed tensor
d_tensor = DTensor(a, device_mesh, sharding_spec)
print(d_tensor)

global_tensor = d_tensor.to_global()
print(global_tensor)
```
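
The layout of an existing `DTensor` can also be changed after creation. The snippet below is a sketch based on the `layout_convert` method added in this PR; the `[R, S0]` spec is an assumed example that shards the second tensor dimension over mesh axis 0.

```python
# convert the distributed tensor from [S0, R] to [R, S0]
# this updates d_tensor.local_tensor and its layout in place
new_sharding_spec = ShardingSpec(a.dim(), {1: [0]})
d_tensor.layout_convert(device_mesh, new_sharding_spec)
```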

## 🎈 Progress Log

- [x] Support layout conversion
- [x] Support sharding on 2D device mesh
- [ ] Support sharding on 3D device mesh
- [ ] Support sharding on 4D device mesh
- [ ] Support saving sharding information and offline tensor merging (we can save a tensor as a DTensor and gather the shards back into the global tensor based on the sharding information in a single CPU process; this is useful for loading and saving distributed training checkpoints)
|
@ -0,0 +1,4 @@
|
|||
from .d_tensor import DTensor
|
||||
from .sharding_spec import ShardingSpec
|
||||
|
||||
__all__ = ['DTensor', 'ShardingSpec']
|
|
@ -24,12 +24,12 @@ class CommSpec:
|
|||
'''
|
||||
Communication spec is used to record the communication action. It converts the communication spec
|
||||
to real action which will be used in runtime. It contains comm_pattern to determine the
|
||||
communication method, process_groups_dict to determine the process groups, gather_dim and shard_dim
|
||||
communication method, process_group_dict to determine the process groups, gather_dim and shard_dim
|
||||
to determine the buffer shape, and logical_process_axis
|
||||
|
||||
Argument:
|
||||
comm_pattern(CollectiveCommPattern): describe the communication method used in this spec.
|
||||
process_groups_dict(Dict): A dict which contains the process groups used to apply this CommSpec.
|
||||
comm_pattern(CollectiveCommPattern): describe the communication method used in this spec.
|
||||
process_group_dict(Dict): A dict which contains the process groups used to apply this CommSpec.
|
||||
gather_dim(int, Optional): the dimension of the tensor along which it will be gathered.
|
||||
shard_dim(int, Optional): the dimension of the tensor along which it will be sharded.
|
||||
logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action.
|
||||
|
@ -37,7 +37,7 @@ class CommSpec:
|
|||
|
||||
def __init__(self,
|
||||
comm_pattern: CollectiveCommPattern,
|
||||
process_groups_dict: Dict,
|
||||
process_group_dict: Dict,
|
||||
gather_dim: int = None,
|
||||
shard_dim: int = None,
|
||||
logical_process_axis: int = None):
|
||||
|
@ -45,7 +45,7 @@ class CommSpec:
|
|||
self.gather_dim = gather_dim
|
||||
self.shard_dim = shard_dim
|
||||
self.logical_process_axis = logical_process_axis
|
||||
self.process_groups_dict = process_groups_dict
|
||||
self.process_group_dict = process_group_dict
|
||||
|
||||
def __repr__(self):
|
||||
res_list = ["CommSpec:("]
|
||||
|
@ -92,68 +92,56 @@ def _all_gather(tensor: torch.Tensor, comm_spec: CommSpec):
|
|||
'''
|
||||
Implement all gather operation on device mesh based on information provided by comm_spec.
|
||||
'''
|
||||
process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
|
||||
for rank_list, process_group in process_groups_list:
|
||||
if dist.get_rank() in rank_list:
|
||||
tensor_list = [
|
||||
torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
|
||||
]
|
||||
# without this contiguous operation, the all gather may get some unexpected results.
|
||||
tensor = tensor.contiguous()
|
||||
dist.all_gather(tensor_list, tensor, group=process_group)
|
||||
output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
|
||||
return output
|
||||
process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
|
||||
world_size = dist.get_world_size(process_group)
|
||||
tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
|
||||
# without this contiguous operation, the all gather may get some unexpected results.
|
||||
tensor = tensor.contiguous()
|
||||
dist.all_gather(tensor_list, tensor, group=process_group)
|
||||
output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
|
||||
return output
|
||||
|
||||
|
||||
def _split(tensor: torch.Tensor, comm_spec: CommSpec):
|
||||
'''
|
||||
Implement shard operation on device mesh based on information provided by comm_spec.
|
||||
'''
|
||||
process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
|
||||
for rank_list, _ in process_groups_list:
|
||||
if dist.get_rank() in rank_list:
|
||||
dim = comm_spec.shard_dim
|
||||
length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
|
||||
start = length * rank_list.index(dist.get_rank())
|
||||
output = torch.narrow(tensor, dim, start, length).contiguous()
|
||||
return output
|
||||
process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
|
||||
dim = comm_spec.shard_dim
|
||||
length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group)
|
||||
start = length * dist.get_rank(process_group)
|
||||
output = torch.narrow(tensor, dim, start, length).contiguous()
|
||||
return output
|
||||
|
||||
|
||||
def _all_to_all(tensor: torch.Tensor, comm_spec: CommSpec):
|
||||
'''
|
||||
Implement all to all operation on device mesh based on information provided by comm_spec.
|
||||
'''
|
||||
process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
|
||||
for rank_list, process_group in process_groups_list:
|
||||
if dist.get_rank() in rank_list:
|
||||
new_shape = list(tensor.shape)
|
||||
new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list)
|
||||
new_shape = torch.Size(new_shape)
|
||||
output_tensor_list = [
|
||||
torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
|
||||
]
|
||||
dim = comm_spec.shard_dim
|
||||
length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
|
||||
input_tensor_list = [
|
||||
torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list))
|
||||
]
|
||||
group = process_group
|
||||
dist.all_to_all(output_tensor_list, input_tensor_list, group)
|
||||
output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
|
||||
return output
|
||||
process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
|
||||
world_size = dist.get_world_size(process_group)
|
||||
new_shape = list(tensor.shape)
|
||||
new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // world_size
|
||||
new_shape = torch.Size(new_shape)
|
||||
output_tensor_list = [torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
|
||||
dim = comm_spec.shard_dim
|
||||
length = tensor.shape[comm_spec.shard_dim] // world_size
|
||||
input_tensor_list = [torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(world_size)]
|
||||
group = process_group
|
||||
dist.all_to_all(output_tensor_list, input_tensor_list, group)
|
||||
output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
|
||||
return output
|
||||
|
||||
|
||||
def _all_reduce(tensor: torch.Tensor, comm_spec: CommSpec, async_op: bool = False):
|
||||
'''
|
||||
Implement all reduce operation on device mesh based on information provided by comm_spec.
|
||||
'''
|
||||
process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
|
||||
for rank_list, process_group in process_groups_list:
|
||||
if dist.get_rank() in rank_list:
|
||||
if not tensor.is_contiguous():
|
||||
tensor = tensor.contiguous()
|
||||
dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
|
||||
return tensor
|
||||
process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
|
||||
if not tensor.is_contiguous():
|
||||
tensor = tensor.contiguous()
|
||||
dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
|
||||
return tensor
|
||||
|
||||
|
||||
class _ReduceGrad(torch.autograd.Function):
|
||||
|
@ -269,7 +257,7 @@ class _AllToAll(torch.autograd.Function):
|
|||
def forward(ctx, input_, comm_spec):
|
||||
output = _all_to_all(input_, comm_spec)
|
||||
comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern,
|
||||
process_groups_dict=comm_spec.process_groups_dict,
|
||||
process_group_dict=comm_spec.process_group_dict,
|
||||
gather_dim=comm_spec.shard_dim,
|
||||
shard_dim=comm_spec.gather_dim,
|
||||
logical_process_axis=comm_spec.logical_process_axis)
|
||||
|
|
|
@ -3,55 +3,119 @@ from typing import Optional
|
|||
import torch
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
|
||||
from .layout import Layout
|
||||
from .layout_converter import LayoutConverter, to_global
|
||||
from .sharding_spec import ShardingSpec
|
||||
|
||||
__all__ = ['DTensor', 'distribute_tensor', 'distribute_module', 'construct_default_sharding_spec']
|
||||
|
||||
layout_converter = LayoutConverter()
|
||||
|
||||
|
||||
class DTensor(torch.Tensor):
|
||||
"""
|
||||
DTensor stands for distributed tensor. It is a subclass of `torch.Tensor` and contains meta information
|
||||
about the tensor distribution. The meta information includes the device mesh, the sharding specification,
|
||||
and the entire shape of the tensor.
|
||||
|
||||
def __init__(self, local_tensor: torch.Tensor, dist_layout: Layout):
|
||||
self.local_tensor = local_tensor
|
||||
self.data_type = local_tensor.dtype
|
||||
self.entire_shape = local_tensor.shape
|
||||
During runtime, we will not directly use the DTensor objects for computation. Instead, we will only use the
|
||||
`DTensor.local_tensor` for computation. The `DTensor.local_tensor` is the local tensor in the current rank.
|
||||
In this way, all tensors involved in computation will only be native PyTorch tensors.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from colossalai.device import DeviceMesh
|
||||
|
||||
# define your device mesh
|
||||
# assume you have 4 GPUs
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
|
||||
# define a tensor
|
||||
x = torch.rand(16, 32)
|
||||
|
||||
# create sharding spec for the tensor
|
||||
# assume the sharding spec is [S0, R]
|
||||
dim_partition_dict = {
|
||||
0: [0]
|
||||
}
|
||||
sharding_spec = ShardingSpec(x.dim(), dim_partition_dict)
|
||||
|
||||
# create a distributed tensor
|
||||
d_tensor = DTensor(x, device_mesh, sharding_spec)
|
||||
```
|
||||
|
||||
Args:
|
||||
tensor (`torch.Tensor`): the unsharded tensor.
|
||||
device_mesh (`DeviceMesh`): the device mesh for abstraction of the compute devices.
|
||||
sharding_spec (`ShardingSpec`): the sharding specification which describes how the tensor will be sharded.
|
||||
"""
|
||||
|
||||
def __init__(self, tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec):
|
||||
# ensure this tensor is not a DTensor
|
||||
assert not isinstance(tensor, DTensor), 'The input tensor should not be a DTensor.'
|
||||
|
||||
# store meta info
|
||||
self.local_tensor = tensor
|
||||
self.data_type = tensor.dtype
|
||||
self.global_shape = tensor.shape
|
||||
|
||||
# create distributed layout
|
||||
dist_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=self.global_shape)
|
||||
self.dist_layout = dist_layout
|
||||
|
||||
# shard the tensor
|
||||
self._apply_layout()
|
||||
|
||||
@staticmethod
|
||||
def __new__(cls, local_tensor, layout):
|
||||
return torch.Tensor._make_subclass(cls, local_tensor, local_tensor.requires_grad)
|
||||
def __new__(cls, tensor, *args, **kwargs):
|
||||
return torch.Tensor._make_subclass(cls, tensor, tensor.requires_grad)
|
||||
|
||||
def __repr__(self):
|
||||
return f"DTensor({self.to_global()}, {self.dist_layout})"
|
||||
return f"DTensor(\n{self.to_global()}\n{self.dist_layout}"
|
||||
|
||||
def __str__(self):
|
||||
return self.__repr__()
|
||||
|
||||
def layout_convert(self, target_layout):
|
||||
def layout_convert(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> None:
|
||||
'''
|
||||
Convert the layout of the tensor from source_spec to target_spec.
|
||||
This will update the `local_tensor` and `dist_layout` in place.
|
||||
|
||||
Args:
|
||||
device_mesh (`DeviceMesh`): the target device mesh.
sharding_spec (`ShardingSpec`): the target sharding specification.
|
||||
'''
|
||||
self.local_tensor = layout_converter.apply(self.local_tensor, self.dist_layout, target_layout)
|
||||
target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=self.global_shape)
|
||||
self.local_tensor = layout_converter.apply(tensor=self.local_tensor,
|
||||
source_layout=self.dist_layout,
|
||||
target_layout=target_layout)
|
||||
self.dist_layout = target_layout
|
||||
|
||||
def _apply_layout(self):
|
||||
'''
|
||||
Apply the layout to the local tensor during initializing process.
|
||||
'''
|
||||
# the layout converter requires a source and a target layout
|
||||
# we construct the source layout for an unsharded tensor
|
||||
# and use self.dist_layout as the target layout for the sharded tensor
|
||||
source_spec = construct_default_sharding_spec(self.local_tensor)
|
||||
source_layout = Layout(device_mesh=self.dist_layout.device_mesh,
|
||||
device_type=self.dist_layout.device_type,
|
||||
sharding_spec=source_spec,
|
||||
entire_shape=self.entire_shape)
|
||||
self.local_tensor = layout_converter.apply(self.local_tensor, source_layout, self.dist_layout)
|
||||
global_shape=self.global_shape)
|
||||
self.local_tensor = layout_converter.apply(tensor=self.local_tensor,
|
||||
source_layout=source_layout,
|
||||
target_layout=self.dist_layout)
|
||||
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
# convert all DTensors to native pytorch tensors
|
||||
# so that operations will be conducted on native tensors
|
||||
def filter_arg(arg):
|
||||
if isinstance(arg, DTensor):
|
||||
return arg.local_tensor
|
||||
|
@ -60,9 +124,9 @@ class DTensor(torch.Tensor):
|
|||
|
||||
args = tree_map(filter_arg, args)
|
||||
kwargs = tree_map(filter_arg, kwargs)
|
||||
# if we want to convert the result into DTensor, we need to infer the layout of result from the layout of input tensors
|
||||
# and op type.
|
||||
|
||||
# NOTE: if we want to convert the result into DTensor, we need to infer the layout of result from the layout of input tensors
|
||||
# and op type.
|
||||
return func(*args, **kwargs)
|
||||
|
||||
@property
|
||||
|
@ -85,7 +149,6 @@ class DTensor(torch.Tensor):
|
|||
'''
|
||||
self.local_tensor = self.local_tensor.to(*args, **kwargs)
|
||||
self.data_type = self.local_tensor.dtype
|
||||
self.dist_layout.device_type = self.local_tensor.device
|
||||
# TODO: update the device mesh process groups or we should just cache
|
||||
# both the cpu process groups and the cuda process groups?
|
||||
return self
|
||||
|
@ -98,7 +161,7 @@ class DTensor(torch.Tensor):
|
|||
|
||||
def to_global(self):
|
||||
'''
|
||||
Recover the global tensor from the distributed tensor.
|
||||
Recover the global tensor from the distributed tensor by returning a new `torch.Tensor` object.
|
||||
|
||||
Note: This function will all_gather the local tensor to the global tensor and it
|
||||
will not change the layout of the DTensor. This function is mainly used for debugging or
|
||||
|
@ -107,24 +170,29 @@ class DTensor(torch.Tensor):
|
|||
return to_global(self.local_tensor, self.dist_layout)
|
||||
|
||||
|
||||
def distribute_tensor(local_tensor: torch.Tensor, dist_layout: Layout) -> DTensor:
|
||||
def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> DTensor:
|
||||
'''
|
||||
Distribute the local tensor to the distributed tensor according to the dist_layout specified.
|
||||
|
||||
Args:
|
||||
local_tensor: tensor to be distributed.
|
||||
dist_layout: the layout specification of the distributed tensor.
|
||||
tensor (`torch.Tensor`): tensor to be distributed.
|
||||
device_mesh (`DeviceMesh`): the device mesh for abstraction of the compute devices.
|
||||
sharding_spec (`ShardingSpec`): the sharding specification which describes how the tensor will be sharded.
|
||||
|
||||
Returns:
|
||||
A 'DTensor' object.
|
||||
'''
|
||||
return DTensor(local_tensor, dist_layout)
|
||||
return DTensor(tensor, device_mesh, sharding_spec)
|
||||
|
||||
|
||||
def distribute_module(module: torch.nn.Module, partition_fn: Optional[callable] = None) -> torch.nn.Module:
|
||||
'''
|
||||
This function converts all the parameters in the module to DTensor(DParam).
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`): the module to be distributed.
|
||||
partition_fn (callable): the partition function which will be used to partition the parameters.
|
||||
|
||||
Note: This function is subject to future change as the DParam has not been implemented yet.
|
||||
'''
|
||||
for name, param in module.named_parameters():
|
||||
|
@ -138,5 +206,11 @@ def distribute_module(module: torch.nn.Module, partition_fn: Optional[callable]
|
|||
def construct_default_sharding_spec(tensor: torch.Tensor,) -> ShardingSpec:
|
||||
'''
|
||||
Construct the default sharding specification for the tensor.
|
||||
|
||||
Args:
|
||||
tensor (`torch.Tensor`): the tensor to be sharded.
|
||||
|
||||
Returns:
|
||||
A `ShardingSpec` object without any sharding specified.
|
||||
'''
|
||||
return ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={})
|
||||
|
|
|
@@ -11,28 +11,32 @@ from .sharding_spec import ShardingSpec

class Layout:
"""Layout of a tensor.
"""
Layout of a tensor refers to the tensor placement on the device mesh and how the tensor is sharded over the devices.

Attributes:
device_mesh: the device mesh to store the tensor distributed.
device_type: the type of the device mesh, e.g. 'cpu' or 'cuda'.
sharding_spec: the sharding specification to describe how the tensor is sharded.
entire_shape: the entire shape of the global tensor.
Args:
device_mesh (`DeviceMesh`): the device mesh to store the tensor distributed.
sharding_spec (`ShardingSpec`): the sharding specification to describe how the tensor is sharded.
global_shape (`torch.Size`): the entire shape of the global tensor.
"""

def __init__(self, device_mesh: DeviceMesh, device_type: torch.device, sharding_spec: ShardingSpec,
entire_shape: torch.Size):
def __init__(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size):
self.device_mesh = device_mesh
self.device_type = device_type
self.sharding_spec = sharding_spec
self.entire_shape = entire_shape
self.global_shape = global_shape
self._sanity_check()

def __hash__(self) -> int:
return hash(f'{self.sharding_spec}')

def get_sharded_shape_per_device(self):
sharded_shape = list(self.entire_shape)
def get_sharded_shape_per_device(self) -> torch.Size:
"""
Compute the shape of the sharded tensor on each device.

Returns:
`torch.Size`: the shape of the sharded tensor on each device.
"""
sharded_shape = list(self.global_shape)
for dim, shard_list in self.sharding_spec.dim_partition_dict.items():
mesh_list = [self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list]
shard_partitions = reduce(operator.mul, mesh_list, 1)
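As a worked example of the per-device shape computation above (the numbers are illustrative and match the [S0, S1, R] example used later in this diff): each sharded dimension is divided by the product of the mesh sizes along the axes it is partitioned over.

```python
# global shape (4, 4, 4) on a (2, 2) mesh:
# dim 0 is sharded over mesh axis 0, dim 1 over mesh axis 1, dim 2 is replicated
global_shape = [4, 4, 4]
mesh_shape = (2, 2)
dim_partition_dict = {0: [0], 1: [1]}

sharded_shape = list(global_shape)
for dim, shard_list in dim_partition_dict.items():
    num_shards = 1
    for mesh_dim in shard_list:
        num_shards *= mesh_shape[mesh_dim]
    sharded_shape[dim] //= num_shards

print(sharded_shape)  # [2, 2, 4] -- the shape of the shard held by each device
```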
@@ -56,7 +60,7 @@ class Layout:

# make sure that the sharding for a dimension is divisible by the number of devices
for dim, shard_list in sharding_spec.dim_partition_dict.items():
tensor_dim_size = self.entire_shape[dim]
tensor_dim_size = self.global_shape[dim]
num_devices = 1

for element in shard_list:

@@ -3,10 +3,8 @@ from copy import deepcopy
from dataclasses import dataclass
from typing import Dict, List, Tuple

import numpy as np
import torch

from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.tensor.d_tensor.comm_spec import *
from colossalai.tensor.d_tensor.layout import Layout
@@ -28,13 +26,21 @@ class LayoutConverterOptions:
pass


def to_global(distributed_tensor: torch.Tensor, layout: Layout) -> torch.Tensor:
def to_global(distributed_tensor: "DTensor", layout: Layout) -> torch.Tensor:
"""
Convert a distributed tensor to the global tensor with the given layout.
This function returns a native `torch.Tensor` object.

Args:
distributed_tensor (`DTensor`): the distributed tensor to be converted.
layout (`Layout`): the target layout specification.
"""
layout_converter = LayoutConverter()
global_sharding_spec = ShardingSpec(distributed_tensor.dim(), {})
global_layout = Layout(device_mesh=layout.device_mesh,
device_type=layout.device_type,
sharding_spec=global_sharding_spec,
entire_shape=layout.entire_shape)
global_shape=layout.global_shape)
with torch.no_grad():
global_tensor = layout_converter.apply(distributed_tensor, layout, global_layout)
return global_tensor
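A hedged sketch of how `to_global` is typically driven, mirroring the lazy-init test further down in this diff (which passes the local `torch.Tensor` shard even though the new annotation above says `DTensor`); the shapes and sharding choice are illustrative: describe how the shard is currently placed with a `Layout`, then ask for the gathered global tensor.

```python
import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.layout_converter import to_global
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec

device_mesh = DeviceMesh(torch.arange(0, 4), (2, 2), init_process_group=True)

# the distributed tensor is sharded along dim 0 over mesh axis 0
sharding_spec = ShardingSpec(dim_size=2, dim_partition_dict={0: [0]})
layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=torch.Size((4, 4)))

local_shard = torch.rand(2, 4).cuda()        # this rank's piece of the global (4, 4) tensor
full_tensor = to_global(local_shard, layout)  # a plain (4, 4) torch.Tensor on every rank
```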
@@ -49,6 +55,9 @@ def set_layout_converting_options(options: LayoutConverterOptions):

class LayoutConverter(metaclass=SingletonMeta):
"""
LayoutConverter is a singleton class which converts the layout of a distributed tensor.
"""

def __init__(self):
self._options = None
@@ -91,15 +100,14 @@ class LayoutConverter(metaclass=SingletonMeta):
# [[0, 1,
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
entire_shape = (4, 4, 4)
global_shape = (4, 4, 4)
dim_partition_dict = {0: [0], 1: [1]}

# [S0,S1,R]
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec,
entire_shape=entire_shape)
global_shape=global_shape)

rst_dict = layout_converter.all_gather_transform_layouts(layout)
for layout, comm_spec in rst_dict.items():
@@ -112,7 +120,12 @@ class LayoutConverter(metaclass=SingletonMeta):
valid_spec_dict = {}
comm_pattern = CollectiveCommPattern.GATHER_FWD_SPLIT_BWD
source_spec = source_layout.sharding_spec
process_groups_dict = source_layout.device_mesh.process_groups_dict

# the key of the dict is the axis
# the value is the process group
current_rank = source_layout.device_mesh._global_rank_of_current_process
process_group_dict = source_layout.device_mesh._process_group_dict[current_rank]

for target_pair in source_spec.dim_partition_dict.items():
shard_list = all_gather_simulator(target_pair)
index = target_pair[0]

@@ -130,7 +143,7 @@ class LayoutConverter(metaclass=SingletonMeta):
logical_process_axis = target_pair[1][-1]
comm_spec = CommSpec(
comm_pattern,
process_groups_dict=process_groups_dict,
process_group_dict=process_group_dict,
gather_dim=gather_dim,
# shard_dim will be used during backward
shard_dim=gather_dim,

@@ -141,8 +154,7 @@
new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict)
new_layout = Layout(device_mesh=source_layout.device_mesh,
sharding_spec=new_sharding_spec,
device_type=source_layout.device_type,
entire_shape=source_layout.entire_shape)
global_shape=source_layout.global_shape)

valid_spec_dict[new_layout] = comm_spec
except LayoutException:
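The recurring change in these hunks is that a `CommSpec` no longer receives the mesh-wide `process_groups_dict`; each rank instead looks up its own axis-to-process-group mapping. A hedged sketch of the new pattern, using only names that appear in this diff (note that `_process_group_dict` and `_global_rank_of_current_process` are private `DeviceMesh` attributes, and the explicit import path for `CommSpec` is assumed since the module above uses a wildcard import):

```python
import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern, CommSpec

device_mesh = DeviceMesh(torch.arange(0, 4), (2, 2), init_process_group=True)

# per-rank mapping {mesh_axis: ProcessGroup} for the groups this rank belongs to
current_rank = device_mesh._global_rank_of_current_process
process_group_dict = device_mesh._process_group_dict[current_rank]

# an all-gather along mesh axis 0 that re-splits on the backward pass
comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
                     process_group_dict=process_group_dict,
                     gather_dim=0,
                     shard_dim=0,
                     logical_process_axis=0)
```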
@@ -167,15 +179,14 @@ class LayoutConverter(metaclass=SingletonMeta):
# [[0, 1,
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
entire_shape = (4, 4, 4)
global_shape = (4, 4, 4)
dim_partition_dict = {0: [0], 1: [1]}

# [S0,S1,R]
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec,
entire_shape=entire_shape)
global_shape=global_shape)
rst_dict = layout_converter.all_to_all_transform_layout(layout)

for layout, comm_spec in rst_dict.items():

@@ -188,7 +199,12 @@ class LayoutConverter(metaclass=SingletonMeta):
'''
valid_spec_dict = {}
comm_pattern = CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD
process_groups_dict = source_layout.device_mesh.process_groups_dict

# the key of the dict is the axis
# the value is the process group
current_rank = source_layout.device_mesh._global_rank_of_current_process
process_group_dict = source_layout.device_mesh._process_group_dict[current_rank]

source_spec = source_layout.sharding_spec
tensor_dims = source_spec.dims
for f_index in range(tensor_dims - 1):

@@ -229,7 +245,7 @@ class LayoutConverter(metaclass=SingletonMeta):
shard_dim = f_index
logical_process_axis = b_target_pair[1][-1]
comm_spec = CommSpec(comm_pattern,
process_groups_dict,
process_group_dict=process_group_dict,
gather_dim=gather_dim,
shard_dim=shard_dim,
logical_process_axis=logical_process_axis)

@@ -252,8 +268,7 @@ class LayoutConverter(metaclass=SingletonMeta):
new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict)
new_layout = Layout(device_mesh=source_layout.device_mesh,
sharding_spec=new_sharding_spec,
device_type=source_layout.device_type,
entire_shape=source_layout.entire_shape)
global_shape=source_layout.global_shape)
valid_spec_dict[new_layout] = comm_spec
except LayoutException:
pass
@@ -278,16 +293,15 @@ class LayoutConverter(metaclass=SingletonMeta):
# [[0, 1,
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
entire_shape = (4, 4, 4)
global_shape = (4, 4, 4)

dim_partition_dict = {0: [0]}

# [S0,R,R]
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec,
entire_shape=entire_shape)
global_shape=global_shape)
rst_dict = layout_converter.shard_transform_layout(layout)

for layout, comm_spec in rst_dict.items():

@@ -301,7 +315,11 @@ class LayoutConverter(metaclass=SingletonMeta):
valid_spec_dict = {}
comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD
source_spec = source_layout.sharding_spec
process_groups_dict = source_layout.device_mesh.process_groups_dict

# the key of the dict is the axis
# the value is the process group
current_rank = source_layout.device_mesh._global_rank_of_current_process
process_group_dict = source_layout.device_mesh._process_group_dict[current_rank]

# legal sharding dims means the mesh_id is still available to use.
legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.mesh_shape))]

@@ -329,7 +347,7 @@ class LayoutConverter(metaclass=SingletonMeta):
shard_dim = index
logical_process_axis = shard_list[-1]
comm_spec = CommSpec(comm_pattern,
process_groups_dict,
process_group_dict=process_group_dict,
gather_dim=shard_dim,
shard_dim=shard_dim,
logical_process_axis=logical_process_axis)

@@ -340,8 +358,7 @@ class LayoutConverter(metaclass=SingletonMeta):
dim_partition_dict=new_dim_partition_dict)
new_layout = Layout(device_mesh=source_layout.device_mesh,
sharding_spec=new_sharding_spec,
device_type=source_layout.device_type,
entire_shape=source_layout.entire_shape)
global_shape=source_layout.global_shape)
valid_spec_dict[new_layout] = comm_spec
except LayoutException:
pass
@@ -399,7 +416,7 @@ class LayoutConverter(metaclass=SingletonMeta):
# [[0, 1,
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
entire_shape = (4, 4, 4)
global_shape = (4, 4, 4)

dim_partition_source = {1: [0, 1]}
dim_partition_target = {0: [0, 1]}

@@ -407,16 +424,14 @@
# [R,S01,R]
sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
source_layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec_source,
entire_shape=entire_shape)
global_shape=global_shape)

# [S01,R,R]
sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
target_layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec_target,
entire_shape=entire_shape)
global_shape=global_shape)

transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout)
transform_path_str = '->'.join([str(layout.sharding_spec.sharding_sequence) for layout in transform_path])
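To tie the pieces together, a hedged sketch of the full conversion flow that the example above sets up and the tests below exercise. It reuses `layout_converter`, `source_layout` and `target_layout` exactly as constructed in the docstring example, `local_shard` stands for this rank's piece of the tensor, and the method name `covert_spec_to_action` is copied verbatim from this diff:

```python
# plan the sequence of collective steps between the two layouts
transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout)

# human-readable path: the sharding sequences of the intermediate layouts joined by '->'
print('->'.join(str(layout.sharding_spec.sharding_sequence) for layout in transform_path))

# replay the planned collectives on this rank's shard one step at a time;
# layout_converter.apply(...) (used by to_global above) wraps the same loop
tensor = local_shard
for comm_spec in comm_action_sequence:
    tensor = comm_spec.covert_spec_to_action(tensor)
```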
@@ -505,21 +520,19 @@ class LayoutConverter(metaclass=SingletonMeta):
# [[0, 1,
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
entire_shape = (4, 4, 4)
global_shape = (4, 4, 4)

# [S0,R,R]
sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
source_layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec_source,
entire_shape=entire_shape)
global_shape=global_shape)

# [R,S0,R]
sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
target_layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec_target,
entire_shape=entire_shape)
global_shape=global_shape)

if rank in (0, 1):
sharded_tensor_0 = torch.zeros(2, 1)

@@ -554,3 +567,4 @@ class LayoutConverter(metaclass=SingletonMeta):
for comm_spec in comm_action_sequence:
tensor = comm_spec.covert_spec_to_action(tensor)
return tensor
return tensor

@@ -1,20 +1,19 @@
from colossalai.device.device_mesh import DeviceMesh
import torch

from colossalai.device.device_mesh import DeviceMesh


def test_device_mesh():
physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
physical_mesh_id = torch.arange(0, 16)
mesh_shape = (4, 4)
# [[0, 1, 2, 3],
# [4, 5, 6, 7],
# [8, 9, 10,11],
# [12,13,14,15]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
assert device_mesh.convert_map[5] == [1, 1]
assert device_mesh.convert_map[11] == [2, 3]
assert device_mesh.global_rank_to_process_groups_with_logical_rank(0)[0] == [[0, 0], [1, 0], [2, 0], [3, 0]]
assert device_mesh.global_rank_to_process_groups_with_logical_rank(2)[1] == [[0, 0], [0, 1], [0, 2], [0, 3]]
assert device_mesh.global_rank_to_process_groups_with_global_rank(2)[1] == [0, 1, 2, 3]
assert device_mesh.global_rank_to_local_rank(5) == [1, 1]
assert device_mesh.global_rank_to_local_rank(11) == [2, 3]
assert device_mesh.get_ranks_in_process_group(axis=1, global_rank=2) == [0, 1, 2, 3]


if __name__ == '__main__':

@@ -20,16 +20,12 @@ def check_layer(rank, world_size, port):
# [[0, 1,
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
logical_pg_dict = {0: [[0, 2], [1, 3]], 1: [[0, 1], [2, 3]]}
logical_process_groups = device_mesh.process_groups_dict

for mesh_dim, pgs in logical_pg_dict.items():
for index, pg in enumerate(pgs):
if rank in pg:
tensor = torch.ones(4).cuda()
group = logical_process_groups[mesh_dim][index][1]
dist.all_reduce(tensor, op=ReduceOp.SUM, group=group)
assert tensor.equal(tensor_to_check)
for axis in range(len(mesh_shape)):
tensor = torch.ones(4).cuda()
pg = device_mesh.get_process_group(axis=axis)
dist.all_reduce(tensor, op=ReduceOp.SUM, group=pg)
assert tensor.equal(tensor_to_check)

gpc.destroy()

@@ -6,7 +6,9 @@ import numpy as np
import torch
from packaging import version

from colossalai.device.device_mesh import DeviceMesh
from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.layout_converter import to_global
from tests.kit.model_zoo.registry import ModelAttribute

@@ -81,7 +83,8 @@ def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False,
print(f'{model.__class__.__name__} pass')


def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, layout_dict: dict) -> None:
def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, device_mesh: DeviceMesh,
sharding_spec_dict: dict) -> None:
state = model.state_dict()
distributed_state = distributed_model.state_dict()

@@ -91,6 +94,7 @@ def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.
assert n1 == n2
t1 = t1.cuda()
t2 = t2.cuda()
if n2 in layout_dict:
t2 = to_global(t2, layout_dict[n2])
if n2 in sharding_spec_dict:
layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_dict[n2], global_shape=t1.shape)
t2 = to_global(t2, layout)
assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}'

@@ -26,23 +26,19 @@ def find_shard_dim(shape: torch.Size) -> Optional[int]:
return dim


def make_layout(device_mesh: DeviceMesh, original_tensor: torch.Tensor) -> Layout:
def make_sharding_spec(original_tensor: torch.Tensor) -> ShardingSpec:
shard_dim = find_shard_dim(original_tensor.shape)
dim_partition_dict = {shard_dim: [0]} if shard_dim is not None else {}
target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict=dim_partition_dict)
layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=target_sharding_spec,
entire_shape=original_tensor.shape)
return layout
return target_sharding_spec


def _get_current_name(prefix: str, name: str) -> str:
return f'{prefix}.{name}'.lstrip('.')


def generate_layout_dict(model: nn.Module, device_mesh: DeviceMesh) -> dict:
layout_dict = {}
def generate_sharding_spec_dict(model: nn.Module) -> dict:
sharding_spec_dict = {}

@torch.no_grad()
def generate_recursively(module: nn.Module, prefix: str = ''):

@@ -53,17 +49,17 @@ def generate_layout_dict(model: nn.Module, device_mesh: DeviceMesh) -> dict:
# initialize tensors directly attached to the current module
for name, param in module.named_parameters(recurse=False):
if isinstance(param, LazyTensor):
layout = make_layout(device_mesh, param)
layout_dict[_get_current_name(prefix, name)] = layout
sharding_spec = make_sharding_spec(param)
sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec

for name, buf in module.named_buffers(recurse=False):
if isinstance(buf, LazyTensor):
layout = make_layout(device_mesh, buf)
layout_dict[_get_current_name(prefix, name)] = layout
sharding_spec = make_sharding_spec(buf)
sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec

generate_recursively(model)

return layout_dict
return sharding_spec_dict


@parameterize('subset', ['torchvision', 'diffusers', 'timm', 'transformers', 'torchaudio', 'deepfm', 'dlrm'])

@@ -85,9 +81,9 @@ def run_dist_lazy_init(subset, seed: int = 42):
ctx = LazyInitContext()
with ctx:
deferred_model = model_fn()
layout_dict = generate_layout_dict(deferred_model, device_mesh)
ctx.distribute(deferred_model, layout_dict, verbose=True)
assert_dist_model_equal(model, deferred_model, layout_dict)
sharding_spec_dict = generate_sharding_spec_dict(deferred_model)
ctx.distribute(deferred_model, device_mesh, sharding_spec_dict, verbose=True)
assert_dist_model_equal(model, deferred_model, device_mesh, sharding_spec_dict)


def run_dist(rank, world_size, port) -> None:

@@ -125,23 +125,6 @@ def check_all_reduce_bwd(process_groups_dict, rank):
assert tensor_to_comm.equal(tensor_to_check)


def check_all_reduce_in_flatten_device_mesh(process_groups_dict, rank):
# tensor to comm
tensor_to_comm = torch.ones(2, 2).cuda() * rank

# reduce through logical process axis 0 at flatten device mesh
# tensor to check
# tensor([[6., 6.],
# [6., 6.]])
tensor_to_check = torch.tensor([[6, 6], [6, 6]], dtype=tensor_to_comm.dtype).cuda()

# CommSpec:(comm_pattern:all_reduce, logical_process_axis:[0, 1])
comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, process_groups_dict, logical_process_axis=0)
tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)

assert tensor_to_comm.equal(tensor_to_check)


def check_comm(rank, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

@@ -153,24 +136,22 @@ def check_comm(rank, world_size, port):
# [[0, 1,
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
process_groups_dict = device_mesh.process_groups_dict

process_group_dict = device_mesh._process_group_dict[rank]

# test all gather
check_all_gather(process_groups_dict, rank)
check_all_gather(process_group_dict, rank)

# test shard
check_shard(process_groups_dict, rank)
check_shard(process_group_dict, rank)

# test all to all
check_all_to_all(process_groups_dict, rank)
check_all_to_all(process_group_dict, rank)

# test all reduce
check_all_reduce_fwd(process_groups_dict, rank)
check_all_reduce_bwd(process_groups_dict, rank)
check_all_reduce_fwd(process_group_dict, rank)
check_all_reduce_bwd(process_group_dict, rank)

flatten_process_groups_dict = device_mesh.flatten_device_mesh.process_groups_dict
# test all reduce in 1D flatten device mesh
check_all_reduce_in_flatten_device_mesh(flatten_process_groups_dict, rank)
gpc.destroy()

@@ -31,13 +31,9 @@ def check_dtensor(rank, world_size, port):

device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]})
layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=target_sharding_spec,
entire_shape=original_tensor.shape)
d_tensor = DTensor(original_tensor, layout)
d_tensor = DTensor(original_tensor, device_mesh, target_sharding_spec)

assert d_tensor.entire_shape == original_tensor.shape
assert d_tensor.global_shape == original_tensor.shape
assert d_tensor.data_type == original_tensor.dtype

if rank in (0, 1):

@@ -57,12 +53,7 @@ def check_dtensor(rank, world_size, port):
raise ValueError(f'rank {rank} is not in the device mesh')

new_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0, 1]})
new_layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=new_sharding_spec,
entire_shape=original_tensor.shape)

d_tensor.layout_convert(new_layout)
d_tensor.layout_convert(device_mesh, new_sharding_spec)

if rank == 0:
assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 0, 1))

@@ -75,7 +66,7 @@ def check_dtensor(rank, world_size, port):
else:
raise ValueError(f'rank {rank} is not in the device mesh')

dtensor_from_local = distribute_tensor(original_tensor, new_layout)
dtensor_from_local = distribute_tensor(original_tensor, device_mesh, new_sharding_spec)

if rank == 0:
assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 0, 1))

@@ -12,9 +12,9 @@ from colossalai.tensor.d_tensor.layout_converter import LayoutConverter
from colossalai.tensor.d_tensor.sharding_spec import DimSpec, ShardingSpec
from colossalai.testing import rerun_if_address_is_in_use, spawn

entire_shape = torch.Size((64, 32, 16))
global_shape = torch.Size((64, 32, 16))
layout_converter = LayoutConverter()
physical_mesh_id = torch.arange(0, 4).reshape(2, 2)
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)

@@ -30,10 +30,7 @@ def check_one_step_transform(rank, world_size, port):
# shard_sequence: S0,S1,R
# device_mesh_shape: (2, 2)
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec,
entire_shape=entire_shape)
layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)

rst_dict = layout_converter.all_gather_transform_layouts(layout)

@@ -49,10 +46,7 @@ def check_one_step_transform(rank, world_size, port):
# shard_sequence: S0,S1,R
# device_mesh_shape: (4, 4)
sharding_spec_all2all = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict_all2all)
layout_all2all = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec_all2all,
entire_shape=entire_shape)
layout_all2all = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_all2all, global_shape=global_shape)

rst_dict_all2all = layout_converter.all_to_all_transform_layout(layout_all2all)

@@ -71,10 +65,7 @@ def check_one_step_transform(rank, world_size, port):
# shard_sequence: S0,R,R
# device_mesh_shape: (4, 4)
sharding_spec_shard = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_shard)
shard_layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec_shard,
entire_shape=entire_shape)
shard_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_shard, global_shape=global_shape)

rst_dict_shard = layout_converter.shard_transform_layout(shard_layout)

@@ -100,19 +91,13 @@ def check_layout_converting(rank, world_size, port):
# shard_sequence: R,S01,R
# device_mesh_shape: (4, 4)
sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
source_layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec_source,
entire_shape=entire_shape)
source_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_source, global_shape=global_shape)

# DistSpec:
# shard_sequence: S01,R,R
# device_mesh_shape: (4, 4)
sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
target_layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec_target,
entire_shape=entire_shape)
target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_target, global_shape=global_shape)

transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout)

@@ -159,21 +144,15 @@ def check_layout_converting_apply(rank, world_size, port):
# shard_sequence: R,S01,R
# device_mesh_shape: (4, 4)
sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
source_layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec_source,
entire_shape=entire_shape)
source_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_source, global_shape=global_shape)

# DistSpec:
# shard_sequence: S01,R,R
# device_mesh_shape: (4, 4)
sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
target_layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec_target,
entire_shape=entire_shape)
target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_target, global_shape=global_shape)

original_tensor = torch.rand(entire_shape).cuda()
original_tensor = torch.rand(global_shape).cuda()

# tensor_to_apply: [R, S01, R]
tensor_to_apply = original_tensor.narrow(1, rank * 8, 8)

@@ -1,9 +1,10 @@
from colossalai.tensor.shape_consistency import ShapeConsistencyManager, CollectiveCommPattern
import torch
from colossalai.tensor.sharding_spec import _DimSpec, ShardingSpec
from colossalai.device.device_mesh import DeviceMesh

physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import CollectiveCommPattern, ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec

physical_mesh_id = torch.arange(0, 16)
mesh_shape = (4, 4)
# [[0, 1, 2, 3],
# [4, 5, 6, 7],

@@ -26,7 +26,7 @@ def run_dist(rank, world_size, port):
# the mesh is in the following topo
# [[0, 1],
# [2, 3]]
physical_mesh_id = torch.arange(0, 4).reshape(2, 2)
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
row_id = rank // 2

@@ -5,7 +5,7 @@ from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec


def test_sharding_spec():
physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
physical_mesh_id = torch.arange(0, 16)
mesh_shape = (4, 4)
# [[0, 1, 2, 3],
# [4, 5, 6, 7],