mirror of https://github.com/hpcaitech/ColossalAI
Frank Lee committed 2 years ago (committed by GitHub)
2 changed files with 155 additions and 0 deletions
colossalai/auto_parallel/solver/op_handler/broadcast.py
@@ -0,0 +1,96 @@
import torch

from enum import Enum, auto
from typing import List

from colossalai.tensor.sharding_spec import ShardingSpec

__all__ = ['BroadcastType', 'is_broadcastable', 'get_broadcast_shape', 'recover_sharding_spec_for_broadcast_shape']


class BroadcastType(Enum):
    EQUAL = auto()
    PADDING = auto()
    MULTIPLE = auto()
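
# Illustration (not part of the original code): broadcasting a physical shape
# (4, 1, 8) to the logical shape (4, 2, 8) classifies the dims, trailing first,
# as EQUAL (8 vs 8), MULTIPLE (2 vs 1), EQUAL (4 vs 4); broadcasting (2, 8) to
# (4, 2, 8) leaves the leading logical dim with no physical counterpart,
# i.e. PADDING.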


def is_broadcastable(shape1: torch.Size, shape2: torch.Size) -> bool:
    """
    Check if two shapes are broadcastable to each other.
    """
    for s1, s2 in zip(shape1[::-1], shape2[::-1]):
        # dims are compatible if they match or if either is 1
        if s1 != 1 and s2 != 1 and s1 != s2:
            return False
    return True
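
# Illustrative usage (not part of the original code): trailing dims are
# compared pairwise, so is_broadcastable(torch.Size([4, 1, 8]), torch.Size([2, 8]))
# returns True (8 vs 8, then 1 vs 2 where a 1 is present), while
# is_broadcastable(torch.Size([4, 2, 8]), torch.Size([4, 8])) returns False,
# since 2 vs 4 neither match nor contain a 1.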


def get_broadcast_shape(shape1: torch.Size, shape2: torch.Size) -> List[int]:
    """
    Compute the broadcast shape given two shapes.
    """
    assert is_broadcastable(shape1, shape2), f'{shape1} and {shape2} are not broadcastable'
    shape1_reverse = shape1[::-1]
    shape2_reverse = shape2[::-1]
    min_common_dim = min(len(shape1), len(shape2))
    dims = []

    # each common trailing dim broadcasts to the larger of the two sizes
    for s1, s2 in zip(shape1_reverse, shape2_reverse):
        dims.append(max(s1, s2))

    # append the remaining leading dims of the longer shape
    dims.extend(shape1_reverse[min_common_dim:])
    dims.extend(shape2_reverse[min_common_dim:])
    return dims[::-1]
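
# Illustrative usage (not part of the original code):
# get_broadcast_shape(torch.Size([4, 1, 8]), torch.Size([2, 8])) returns
# [4, 2, 8]: the common trailing dims take max(8, 8) and max(1, 2), and the
# unmatched leading dim 4 of the longer shape is carried over.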


def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec, logical_shape: torch.Size,
                                              physical_shape: torch.Size) -> ShardingSpec:
    """
    This function computes the sharding spec for the physical shape of a broadcast tensor.

    Args:
        logical_sharding_spec (ShardingSpec): the sharding spec of the logical (broadcast) shape
        logical_shape (torch.Size): the broadcast shape of the tensor
        physical_shape (torch.Size): the shape of the tensor before broadcasting
    """
    # get the number of dimensions
    logical_num_dims = len(logical_shape)
    physical_num_dims = len(physical_shape)

    # track each logical dim and its broadcast type
    logical_dim_broadcast_info = {}

    for i in range(logical_num_dims):
        # walk the dims from trailing to leading
        logical_dim_idx = logical_num_dims - i - 1
        physical_dim_idx = physical_num_dims - i - 1
        logical_dim_size = logical_shape[logical_dim_idx]

        if physical_dim_idx >= 0:
            physical_dim_size = physical_shape[physical_dim_idx]

            if physical_dim_size == logical_dim_size:
                logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.EQUAL
            elif physical_dim_size == 1 and physical_dim_size != logical_dim_size:
                logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.MULTIPLE
        else:
            # this dim exists only in the logical shape
            logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.PADDING

    # generate the sharding spec for the physical shape:
    # dims that are padded or multiplied by broadcasting cannot keep their sharding
    physical_dim_partition = {}
    logical_dim_partition = logical_sharding_spec.dim_partition_dict

    for shape_dim, mesh_dim in logical_dim_partition.items():
        logical_broadcast_type = logical_dim_broadcast_info[shape_dim]

        if logical_broadcast_type in (BroadcastType.PADDING, BroadcastType.MULTIPLE):
            continue
        # map the logical dim index to the corresponding physical dim index
        physical_dim = physical_num_dims - (logical_num_dims - shape_dim)
        physical_dim_partition[physical_dim] = mesh_dim

    physical_sharding_spec = ShardingSpec(device_mesh=logical_sharding_spec.device_mesh,
                                          entire_shape=physical_shape,
                                          dim_partition_dict=physical_dim_partition)

    return physical_sharding_spec
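
As a worked instance of the dim mapping above (illustrative, not from the commit): for logical shape (4, 2, 8) and physical shape (2, 8), logical_num_dims = 3 and physical_num_dims = 2, so a sharding on logical dim 2 lands on physical dim 2 - (3 - 2) = 1, while logical dim 0 is of type PADDING and its sharding is simply dropped. The tests below exercise the MULTIPLE case the same way.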
@@ -0,0 +1,59 @@
import torch

from colossalai.auto_parallel.solver.op_handler.broadcast import is_broadcastable, get_broadcast_shape, recover_sharding_spec_for_broadcast_shape
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.device.device_mesh import DeviceMesh


def test_is_broadcastable():
    x1 = torch.rand(4, 4, 8)
    x2 = torch.rand(1, 8)
    assert is_broadcastable(x1.shape, x2.shape)

    x1 = torch.rand(4, 2, 8)
    x2 = torch.rand(2, 8)
    assert is_broadcastable(x1.shape, x2.shape)

    x1 = torch.rand(4, 2, 8)
    x2 = torch.rand(4, 8)
    assert not is_broadcastable(x1.shape, x2.shape)


def test_get_broadcast_shape():
    x1 = torch.rand(4, 4, 8)
    x2 = torch.rand(1, 8)
    assert get_broadcast_shape(x1.shape, x2.shape) == [4, 4, 8]

    x1 = torch.rand(4, 2, 8)
    x2 = torch.rand(2, 8)
    assert get_broadcast_shape(x1.shape, x2.shape) == [4, 2, 8]

    x1 = torch.rand(4, 2, 8)
    x2 = torch.rand(8)
    assert get_broadcast_shape(x1.shape, x2.shape) == [4, 2, 8]


def test_recover_sharding_spec_for_broadcast_shape():
    x1 = torch.rand(4, 1, 8)
    x2 = torch.rand(2, 8)

    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    # the mesh is laid out as
    # [[0, 1],
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)

    broadcast_shape = get_broadcast_shape(x1.shape, x2.shape)
    logical_sharding_spec_for_x1 = ShardingSpec(device_mesh=device_mesh,
                                                dim_partition_dict={
                                                    0: [0],
                                                    1: [1]
                                                },
                                                entire_shape=broadcast_shape)
    physical_sharding_spec_for_x1 = recover_sharding_spec_for_broadcast_shape(logical_sharding_spec_for_x1,
                                                                              broadcast_shape, x1.shape)
    print(physical_sharding_spec_for_x1)

    assert physical_sharding_spec_for_x1.entire_shape == x1.shape
    # dim 1 of the physical tensor is of broadcast type MULTIPLE, so its sharding is dropped
    assert physical_sharding_spec_for_x1.dim_partition_dict == {0: [0]}
    assert physical_sharding_spec_for_x1.sharding_sequence == ['S0', 'R', 'R']