# Mirror of https://github.com/hpcaitech/ColossalAI (Python, 122 lines, 4.3 KiB)
import operator
|
|
from copy import deepcopy
|
|
from functools import reduce
|
|
from typing import Dict
|
|
|
|
import torch
|
|
|
|
from colossalai.tensor.sharding_spec import ShardingSpec
|
|
|
|
# Public helpers exported by this module (order is kept as originally declared).
__all__ = [
    "transpose_partition_dim",
    "update_partition_dim",
    "enumerate_all_possible_1d_sharding",
    "enumerate_all_possible_2d_sharding",
    "generate_sharding_size",
]
|
|
|
|
|
|
def transpose_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> ShardingSpec:
    """
    Switch the sharding mesh dimensions for two tensor dimensions. This operation is in-place.

    The entries of ``dim1`` and ``dim2`` are swapped both in ``dim_partition_dict`` and in
    ``entire_shape``, and the sharding spec is then re-initialized with the result.

    Args:
        sharding_spec (ShardingSpec): the sharding spec for which partition dims are switched
        dim1 (int): the tensor dimension to switch; negative values count from the end
        dim2 (int): the tensor dimension to switch; negative values count from the end

    Returns:
        ShardingSpec: the same ``sharding_spec`` object, updated in place.

    Raises:
        AssertionError: if ``entire_shape`` has fewer than 2 dimensions.
        IndexError: if ``dim1`` or ``dim2`` is out of range. The spec is left unmodified.
    """
    assert len(sharding_spec.entire_shape) >= 2, "The entire_shape of the sharding spec must have at least 2 dimensions"
    ndim = len(sharding_spec.entire_shape)

    def _normalize(dim: int) -> int:
        # Validate before touching the spec so a bad dim cannot leave it half-updated.
        if not -ndim <= dim < ndim:
            raise IndexError(f"Dimension {dim} is out of range for a tensor with {ndim} dimensions")
        # Map negative dims to their non-negative equivalent so the dict lookups below use the
        # same index as the shape swap (partition keys are presumably non-negative - confirm in ShardingSpec).
        return dim + ndim if dim < 0 else dim

    dim1 = _normalize(dim1)
    dim2 = _normalize(dim2)

    dim_partition_dict = sharding_spec.dim_partition_dict

    # transpose the dim partition
    dim1_partition = dim_partition_dict.pop(dim1, None)
    dim2_partition = dim_partition_dict.pop(dim2, None)

    # Falsy partitions (missing/None or an empty list) are not re-inserted.
    if dim1_partition:
        dim_partition_dict[dim2] = dim1_partition
    if dim2_partition:
        dim_partition_dict[dim1] = dim2_partition

    # get the transposed shape
    new_shape = list(sharding_spec.entire_shape)
    new_shape[dim1], new_shape[dim2] = new_shape[dim2], new_shape[dim1]

    # re-init the sharding spec with the swapped shape and partition
    sharding_spec.__init__(sharding_spec.device_mesh, torch.Size(new_shape), dim_partition_dict)
    return sharding_spec
|
|
|
|
|
|
def update_partition_dim(
    sharding_spec: ShardingSpec, dim_mapping: Dict[int, int], physical_shape: torch.Size, inplace: bool = False
):
    """
    This method is used to update the partition dim dict from the logical one to the physical one.

    Every partitioned logical dim listed in ``dim_mapping`` has its mesh dims moved to the mapped
    physical dim. Partitioned dims that are not listed keep their index. The spec is then
    re-initialized with ``physical_shape``.

    Args:
        sharding_spec (ShardingSpec): the sharding spec for which partition dims are updated
        dim_mapping (Dict[int, int]): the mapping from the logical tensor dimension to the physical tensor
            dimension. Logical dims that are not partitioned have nothing to move and are skipped.
        physical_shape (torch.Size): the physical shape for the tensor
        inplace (bool): if True, update ``sharding_spec`` itself; otherwise update and return a deep copy.

    Returns:
        ShardingSpec: the updated sharding spec (``sharding_spec`` itself when ``inplace`` is True).

    Raises:
        KeyError: if two partitioned dims would end up on the same physical dim. The spec's
            partition dict is not modified in that case.
    """
    if inplace:
        current_sharding_spec = sharding_spec
    else:
        current_sharding_spec = deepcopy(sharding_spec)

    # Work on a shallow copy so that raising below does not leave the spec's own dict half-popped.
    old_dim_partition_dict = dict(current_sharding_spec.dim_partition_dict)
    new_dim_partition_dict = {}

    # assign new dim
    for old_dim, new_dim in dim_mapping.items():
        mesh_dims = old_dim_partition_dict.pop(old_dim, None)
        if mesh_dims is None:
            # this logical dim is not partitioned, so there is nothing to move
            continue
        if new_dim in new_dim_partition_dict:
            # two logical dims mapped to the same physical dim would silently overwrite each other
            raise KeyError(f"There are duplicated entries for the tensor sharding dimension {new_dim}")
        new_dim_partition_dict[new_dim] = mesh_dims

    # partitioned dims that are not remapped keep their original index
    for tensor_dim, mesh_dims in old_dim_partition_dict.items():
        if tensor_dim in new_dim_partition_dict:
            raise KeyError(f"There are duplicated entries for the tensor sharding dimension {tensor_dim}")
        new_dim_partition_dict[tensor_dim] = mesh_dims

    # update sharding spec (the device mesh is taken from the caller's spec, as before)
    current_sharding_spec.__init__(
        device_mesh=sharding_spec.device_mesh, entire_shape=physical_shape, dim_partition_dict=new_dim_partition_dict
    )
    return current_sharding_spec
|
|
|
|
|
|
def enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size):
    """
    List every way to shard tensor dims over two mesh dimensions.

    Two kinds of entries are returned, in this order:

    1. For every pair of tensor dims ``first < second``, two dicts: one with
       ``mesh_dim_0`` on ``first`` and ``mesh_dim_1`` on ``second``, then the swapped assignment.
    2. For every tensor dim, a dict that shards it on both mesh dims (``[mesh_dim_0, mesh_dim_1]``).

    Args:
        mesh_dim_0: the first mesh dimension
        mesh_dim_1: the second mesh dimension
        dim_size (int): the number of tensor dimensions to consider

    Returns:
        List[Dict[int, List]]: the dim partition dicts described above.
    """
    two_axis_specs = []
    for first in range(dim_size):
        for second in range(first + 1, dim_size):
            two_axis_specs += [
                {first: [mesh_dim_0], second: [mesh_dim_1]},
                {first: [mesh_dim_1], second: [mesh_dim_0]},
            ]

    flattened_specs = [{tensor_dim: [mesh_dim_0, mesh_dim_1]} for tensor_dim in range(dim_size)]
    return two_axis_specs + flattened_specs
|
|
|
|
|
|
def enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size):
    """
    List every way to shard a single tensor dim along one mesh dimension.

    Args:
        mesh_dim_0: the mesh dimension used for sharding
        dim_size (int): the number of tensor dimensions to consider

    Returns:
        List[Dict[int, List]]: one ``{tensor_dim: [mesh_dim_0]}`` dict per tensor dim, in ascending order.
    """
    return [{tensor_dim: [mesh_dim_0]} for tensor_dim in range(dim_size)]
|
|
|
|
|
|
def generate_sharding_size(dim_partition_dict, device_mesh):
    """
    Compute the total number of shards implied by a dim partition dict.

    The result is the product of ``device_mesh.shape[mesh_dim]`` over every mesh dim listed
    in any value of ``dim_partition_dict``.

    Args:
        dim_partition_dict (Dict[int, List[int]]): mapping from tensor dim to the mesh dims it is sharded on
        device_mesh: an object whose ``shape`` can be indexed by mesh dim

    Returns:
        int: the total sharding size. It is 1 when nothing is sharded, including tensor dims
        mapped to an empty mesh-dim list (which previously raised TypeError via ``reduce``).
    """
    total_sharding_size = 1
    for mesh_dim_list in dim_partition_dict.values():
        for mesh_dim in mesh_dim_list:
            total_sharding_size *= device_mesh.shape[mesh_dim]
    return total_sharding_size
|