mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] added sharding spec conversion for linear handler (#1687)
parent af718e83f2
commit 4973157ad7
@@ -0,0 +1,68 @@
from copy import deepcopy
from typing import Dict

import torch

from colossalai.tensor.sharding_spec import ShardingSpec

def switch_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> ShardingSpec:
    """
    Switch the sharding mesh dimensions for two tensor dimensions. This operation is in-place.

    Args:
        sharding_spec (ShardingSpec): the sharding spec for which the partition dims are switched
        dim1 (int): the first tensor dimension to switch
        dim2 (int): the second tensor dimension to switch
    """
    assert len(sharding_spec.entire_shape) == 2, \
        'switch_partition_dim only supports 2-dimensional tensors'
    dim_partition_dict = sharding_spec.dim_partition_dict
    dim1_partition = dim_partition_dict.pop(dim1, None)
    dim2_partition = dim_partition_dict.pop(dim2, None)

    # compare against None explicitly so an empty mesh-dim list is still swapped
    if dim1_partition is not None:
        dim_partition_dict[dim2] = dim1_partition

    if dim2_partition is not None:
        dim_partition_dict[dim1] = dim2_partition

    # re-init the sharding spec so the sharding sequence is rebuilt from the new dict
    sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
    return sharding_spec

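A minimal usage sketch, not part of this commit: it assumes a 2x2 DeviceMesh from colossalai.device.device_mesh and the illustrative shape torch.Size((8, 8)). A 2D spec sharded on tensor dim 0 along mesh dim 0 ([S0, R]) has its sharding moved to tensor dim 1 ([R, S0]).

# Hypothetical usage sketch (not from the commit): swap which tensor dim is sharded.
from colossalai.device.device_mesh import DeviceMesh

physical_mesh_id = torch.arange(4)
device_mesh = DeviceMesh(physical_mesh_id, (2, 2))    # assumed 2x2 logical mesh

# shard tensor dim 0 along mesh dim 0, i.e. [S0, R]
spec = ShardingSpec(device_mesh, torch.Size((8, 8)), dim_partition_dict={0: [0]})
switch_partition_dim(spec, 0, 1)
assert spec.dim_partition_dict == {1: [0]}    # now [R, S0]
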
def update_partition_dim(sharding_spec: ShardingSpec,
                         dim_mapping: Dict[int, int],
                         physical_shape: torch.Size,
                         inplace: bool = False) -> ShardingSpec:
    """
    Update the partition dim dict of a sharding spec from the logical tensor dimensions
    to the physical tensor dimensions.

    Args:
        sharding_spec (ShardingSpec): the sharding spec for which partition dims are updated
        dim_mapping (Dict[int, int]): the mapping from the logical tensor dimension to the physical tensor dimension
        physical_shape (torch.Size): the physical shape of the tensor
        inplace (bool): whether to modify the given sharding spec or a deep copy of it; defaults to False
    """
    if inplace:
        current_sharding_spec = sharding_spec
    else:
        current_sharding_spec = deepcopy(sharding_spec)

    old_dim_partition_dict = current_sharding_spec.dim_partition_dict
    new_dim_partition_dict = {}

    # move the mesh dims of each remapped logical dim to its physical dim;
    # logical dims that are not sharded are simply skipped
    for old_dim, new_dim in dim_mapping.items():
        if old_dim in old_dim_partition_dict:
            mesh_dims = old_dim_partition_dict.pop(old_dim)
            new_dim_partition_dict[new_dim] = mesh_dims

    # carry over the tensor dims which are not remapped
    for tensor_dim, mesh_dims in old_dim_partition_dict.items():
        if tensor_dim in new_dim_partition_dict:
            raise KeyError(f"There are duplicated entries for the tensor sharding dimension {tensor_dim}")
        else:
            new_dim_partition_dict[tensor_dim] = mesh_dims

    # re-init the sharding spec with the physical shape and the new partition dict
    current_sharding_spec.__init__(device_mesh=sharding_spec.device_mesh,
                                   entire_shape=physical_shape,
                                   dim_partition_dict=new_dim_partition_dict)
    return current_sharding_spec

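A hedged sketch of the conversion, reusing the device_mesh from the sketch above; the shapes and the dim mapping are illustrative assumptions, not taken from the commit. A logical 2D spec sharded as [S0, S1] is remapped onto a transposed physical tensor by swapping dims 0 and 1.

# Hypothetical usage sketch (not from the commit): remap a logical spec onto a
# transposed physical weight, e.g. when the physical layout is (in, out) vs (out, in).
logical_spec = ShardingSpec(device_mesh, torch.Size((8, 16)), dim_partition_dict={0: [0], 1: [1]})
physical_spec = update_partition_dim(logical_spec,
                                     dim_mapping={0: 1, 1: 0},
                                     physical_shape=torch.Size((16, 8)))
assert physical_spec.dim_partition_dict == {1: [0], 0: [1]}    # [S1, S0]
assert logical_spec.dim_partition_dict == {0: [0], 1: [1]}     # untouched since inplace=False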