[autoparallel] mix gather (#1977)

* Add mix-gather

* Add comments

* Add comments

* Polish comments

* Change the global rank assumption

* Add tests

* Add two-step tests

* Fix 10 and 01

* Skip test because of the number of GPUs
Genghan Zhang 2022-11-23 21:49:17 +08:00 committed by GitHub
parent 7242bffc5f
commit d655eea515
5 changed files with 617 additions and 4 deletions
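For orientation, here is a minimal sketch (not part of the diff) of how the new MIXGATHER_FWD_SPLIT_BWD pattern is exercised, mirroring the test file added below. It assumes an already-launched 8-GPU NCCL context and a (2, 4) device mesh.

# Illustrative sketch only: assumes a launched 8-GPU NCCL context (colossalai launch),
# mirroring the new test file added at the end of this commit.
import torch
import torch.distributed as dist

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.tensor.utils import mix_gather_simulator

rank = dist.get_rank()
device_mesh = DeviceMesh(torch.arange(0, 8), (2, 4), init_process_group=True, need_flatten=True)

# S0S1 -> RR on an (8, 8) tensor: each rank holds a (4, 2) shard.
full = torch.arange(64).reshape(8, 8).cuda()
shard = full[(rank // 4) * 4:(rank // 4) * 4 + 4, (rank % 4) * 2:(rank % 4) * 2 + 2].contiguous()
gather_dim, logical_process_axes = mix_gather_simulator((0, [0]), (1, [1]))
source_spec = ShardingSpec(device_mesh, full.shape, dim_partition_dict={0: [0], 1: [1]})
comm_spec = CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD,
                     sharding_spec=source_spec,
                     gather_dim=gather_dim,
                     logical_process_axis=logical_process_axes,
                     forward_only=True,
                     mix_gather=True)
assert comm_spec.covert_spec_to_action(shard).equal(full)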

View File

@@ -52,6 +52,9 @@ class DeviceMesh:
self.process_groups_dict = self.create_process_groups_for_logical_mesh()
if self.need_flatten:
self.flatten_device_mesh = self.flatten()
# Create a new member `flatten_device_meshes` to distinguish it from the original flatten() method (it is unclear whether other functions still rely on self.flatten())
self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha,
self.mesh_beta)
@property
def shape(self):
@@ -199,3 +202,38 @@ class DeviceMesh:
penalty_factor = num_devices / 2.0
return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] *
(num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001)
class FlattenDeviceMesh(DeviceMesh):
def __init__(self, physical_mesh_id, mesh_shape, mesh_alpha=None, mesh_beta=None):
super().__init__(physical_mesh_id,
mesh_shape,
mesh_alpha,
mesh_beta,
init_process_group=False,
need_flatten=False)
# Different from flatten(): mesh_shape remains unchanged, while mesh_alpha and mesh_beta become scalars
self.mesh_alpha = max(self.mesh_alpha)
self.mesh_beta = min(self.mesh_beta)
# Different from original process_groups_dict, rank_list is not stored
self.process_number_dict = self.create_process_numbers_for_logical_mesh()
def create_process_numbers_for_logical_mesh(self):
'''
Build 1d DeviceMesh in column-major(0) and row-major(1)
for example:
mesh_shape = (2,4)
# [[0, 1, 2, 3],
# [4, 5, 6, 7]]
# return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]}
'''
num_devices = reduce(operator.mul, self.mesh_shape, 1)
process_numbers_dict = {}
process_numbers_dict[0] = torch.arange(num_devices).reshape(self.mesh_shape).transpose(1, 0).flatten().tolist()
process_numbers_dict[1] = torch.arange(num_devices).reshape(self.mesh_shape).flatten().tolist()
return process_numbers_dict
def mix_gather_cost(self, num_bytes):
num_devices = reduce(operator.mul, self.mesh_shape, 1)
return (self.mesh_alpha + self.mesh_beta * (num_devices - 1) / num_devices * num_bytes + 0.1)
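As a standalone illustration (not part of the diff), the two orderings described in the create_process_numbers_for_logical_mesh docstring can be reproduced directly:

# Column-major and row-major 1-D orderings of a (2, 4) logical mesh.
import torch

mesh_shape = (2, 4)
num_devices = 8
col_major = torch.arange(num_devices).reshape(mesh_shape).transpose(1, 0).flatten().tolist()
row_major = torch.arange(num_devices).reshape(mesh_shape).flatten().tolist()
assert col_major == [0, 4, 1, 5, 2, 6, 3, 7]
assert row_major == [0, 1, 2, 3, 4, 5, 6, 7]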

View File

@@ -79,6 +79,132 @@ def _all_reduce(tensor, comm_spec, async_op=False):
return tensor
def _mix_gather(tensor, comm_spec):
'''
Implement mix gather operation on device mesh based on information provided by comm_spec.
Mix gather is the all-gather operation on all devices in the device_mesh(FlattenDeviceMesh) of the comm_spec. It is
different from _all_gather because _mix_gather does all-gather in two dimensions of device mesh, while _all_gather
only does all-gather in one dimension.
Assume the tensor dimension indices of the f and b target pairs are 'f' and 'b'.
ShardingSpec => gather_dim, logical_process_axes
S0S1 => [b, f], (1, 0)
S1S0 => [b, f], (0, 1)
S01R => [f], (1, 1)
RS01 => [b], (1, 1)
Example:
mesh_shape = (2,4)
# [[0, 1, 2, 3],
# [4, 5, 6, 7]]
# return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]}
S0S1:
leading_group_dim = 1
process_group = "[0, 1, 2, 3, 4, 5, 6, 7]"
tensor_list = [(0,0),(0,1),(0,2),(0,3),(1,0),(1,1),(1,2),(1,3)] # [(slice_id_f, slice_id_b),...]
mesh_shape = (2,4)
cat_slice = [4,2]
tmp_tensor_list = [(...,shape[f],shape[b]*4,...),(...,shape[f],shape[b]*4,...)]
tmp_tensor_list[0] = torch.cat(((0,0),(0,1),(0,2),(0,3)), dim=b)
tmp_tensor_list[1] = torch.cat(((1,0),(1,1),(1,2),(1,3)), dim=b)
output = torch.cat((tmp_tensor_list[0],tmp_tensor_list[1]), dim=f)
S1S0:
leading_group_dim = 0
process_group = "[0, 4, 1, 5, 2, 6, 3, 7]"
tensor_list = [(0,0),(0,1),(1,0),(1,1),(2,0),(2,1),(3,0),(3,1)]
mesh_shape = (2,4)
cat_slice = [2,4]
tmp_tensor_list = [(...,shape[f],shape[b]*2,...),(...,shape[f],shape[b]*2,...),(...,shape[f],shape[b]*2,...),(...,shape[f],shape[b]*2,...)]
tmp_tensor_list[0] = torch.cat(((0,0),(0,1)), dim=b)
tmp_tensor_list[1] = torch.cat(((1,0),(1,1)), dim=b)
tmp_tensor_list[2] = torch.cat(((2,0),(2,1)), dim=b)
tmp_tensor_list[3] = torch.cat(((3,0),(3,1)), dim=b)
S10R:
leading_group_dim = 0
process_group = "[0, 4, 1, 5, 2, 6, 3, 7]"
tensor_list = [(0,0),(1,0),(2,0),(3,0),(4,0),(5,0),(6,0),(7,0)]
S01R:
leading_group_dim = 1
process_group = "[0, 1, 2, 3, 4, 5, 6, 7]"
tensor_list = [(0,0),(1,0),(2,0),(3,0),(4,0),(5,0),(6,0),(7,0)]
'''
total_slices = comm_spec.device_mesh.mesh_shape[0]
tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(total_slices)]
leading_group_dim = comm_spec.logical_process_axes[0]
assert len(comm_spec.device_mesh.process_groups_dict) == 1
_, process_group = comm_spec.device_mesh.process_groups_dict[0][0]
process_number_list = comm_spec.device_meshes.process_number_dict[leading_group_dim]
# Global all_gather
dist.all_gather(tensor_list, tensor, group=process_group)
# TODO: replace this reordering with something more elegant
tensor_list_sorted = [
torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(total_slices)
]
for i in range(total_slices):
tensor_list_sorted[i] = tensor_list[process_number_list[i]]
tensor_list = tensor_list_sorted
if comm_spec.logical_process_axes[0] == comm_spec.logical_process_axes[1]:
output = torch.cat(tuple(tensor_list), comm_spec.gather_dim[0]).contiguous()
else:
mesh_shape = comm_spec.device_meshes.mesh_shape
cat_slice = [mesh_shape[comm_spec.logical_process_axes[0]], mesh_shape[comm_spec.logical_process_axes[1]]]
tmp_tensor_shape = list(tensor.shape)
tmp_tensor_shape[comm_spec.gather_dim[0]] *= cat_slice[0]
tmp_tensor_shape = torch.Size(tmp_tensor_shape)
tmp_tensor_list = [
torch.zeros(tmp_tensor_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(cat_slice[1])
]
for i in range(cat_slice[1]):
tmp_tensor_list[i] = torch.cat(tuple(tensor_list[i * cat_slice[0]:(i + 1) * cat_slice[0]]),
comm_spec.gather_dim[0]).contiguous()
output = torch.cat(tuple(tmp_tensor_list), comm_spec.gather_dim[1]).contiguous()
return output
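A single-process illustration (not part of the diff) of the reassembly that follows dist.all_gather in _mix_gather, for the S0S1 example in the docstring above; the shards are cut by hand from an arange(64) tensor:

import torch

full = torch.arange(64).reshape(8, 8)
# tensor_list in sorted (row-major) order: rank r holds a (4, 2) shard of the (8, 8) tensor.
tensor_list = [full[(r // 4) * 4:(r // 4) * 4 + 4, (r % 4) * 2:(r % 4) * 2 + 2] for r in range(8)]

gather_dim = [1, 0]    # [b, f]
cat_slice = [4, 2]     # [mesh_shape[axes[0]], mesh_shape[axes[1]]] for axes (1, 0)
tmp_tensor_list = [
    torch.cat(tuple(tensor_list[i * cat_slice[0]:(i + 1) * cat_slice[0]]), dim=gather_dim[0])
    for i in range(cat_slice[1])
]
output = torch.cat(tuple(tmp_tensor_list), dim=gather_dim[1])
assert output.equal(full)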
def _mix_split(tensor, comm_spec):
'''
Implement mix split operation. Mix split is only called for the backward of mix gather (Use ctx to keep consistent)
Mix split shards the tensor on device mesh based on information provided by comm_spec. It is different from split
because _mix_split shards the tensor in two dimensions of device mesh, while _split only shards in one dimension.
Assume the tensor dimension indices of the f and b target pairs are 'f' and 'b'.
S0S1 => [b, f], (1, 0)
S1S0 => [b, f], (0, 1)
S01R => [f], (1, 1)
RS01 => [b], (1, 1)
Example:
mesh_shape = (2,4)
# [[0, 1, 2, 3],
# [4, 5, 6, 7]]
# return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]}
'''
mesh_shape = comm_spec.device_meshes.mesh_shape
dim = comm_spec.gather_dim
total_slices = comm_spec.device_mesh.mesh_shape[0]
# Get global rank
rank = dist.get_rank()
leading_group_dim = comm_spec.logical_process_axes[0]
process_number_list = comm_spec.device_meshes.process_number_dict[leading_group_dim]
rank = process_number_list.index(rank)
if comm_spec.logical_process_axes[0] == comm_spec.logical_process_axes[1]:
length = tensor.shape[dim[0]] // total_slices
start = length * rank
output = torch.narrow(tensor, dim[0], start, length).contiguous()
else:
tensor_shape = [tensor.shape[dim[0]], tensor.shape[dim[1]]]
rank_slice = [mesh_shape[comm_spec.logical_process_axes[0]], mesh_shape[comm_spec.logical_process_axes[1]]]
length = [tensor_shape[0] // rank_slice[0], tensor_shape[1] // rank_slice[1]]
start = [(rank % rank_slice[0]) * length[0], (rank // rank_slice[0]) * length[1]]
tmp_output = torch.narrow(tensor, dim[0], start[0], length[0]).contiguous()
output = torch.narrow(tmp_output, dim[1], start[1], length[1]).contiguous()
return output
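The inverse operation, again as a single-process illustration (not part of the diff) of the two torch.narrow calls in the non-matching-axes branch of _mix_split, for the same S0S1 layout and an arbitrarily chosen rank:

import torch

full = torch.arange(64).reshape(8, 8)
dim = [1, 0]                                   # gather_dim for S0S1: [b, f]
mesh_shape = (2, 4)
rank = 5                                       # index in the row-major process-number list (== global rank here)
rank_slice = [mesh_shape[1], mesh_shape[0]]    # [mesh_shape[axes[0]], mesh_shape[axes[1]]] for axes (1, 0)
length = [full.shape[dim[0]] // rank_slice[0], full.shape[dim[1]] // rank_slice[1]]
start = [(rank % rank_slice[0]) * length[0], (rank // rank_slice[0]) * length[1]]
shard = torch.narrow(torch.narrow(full, dim[0], start[0], length[0]), dim[1], start[1], length[1])
assert shard.equal(full[4:8, 2:4])             # global rank 5 sits at mesh row 1, col 1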
class _ReduceGrad(torch.autograd.Function):
"""
A customized communication operation which forward is an identity operation,
@@ -204,6 +330,22 @@ class _AllToAll(torch.autograd.Function):
return _all_to_all(grad_outputs, ctx.comm_spec), None
class _MixGatherForwardMixSplitBackward(torch.autograd.Function):
@staticmethod
def symbolic(graph, input_):
return _mix_gather(input_)
@staticmethod
def forward(ctx, input_, comm_spec):
ctx.comm_spec = comm_spec
return _mix_gather(input_, comm_spec)
@staticmethod
def backward(ctx, grad_output):
return _mix_split(grad_output, ctx.comm_spec), None
def reduce_grad(input_, comm_spec):
return _ReduceGrad.apply(input_, comm_spec)
@@ -224,12 +366,17 @@ def all_to_all(input_, comm_spec):
return _AllToAll.apply(input_, comm_spec)
def mixgather_forward_split_backward(input_, comm_spec):
return _MixGatherForwardMixSplitBackward.apply(input_, comm_spec)
class CollectiveCommPattern(Enum):
GATHER_FWD_SPLIT_BWD = 'gather_fwd_split_bwd'
ALL2ALL_FWD_ALL2ALL_BWD = 'all2all_fwd_all2all_bwd'
SPLIT_FWD_GATHER_BWD = 'split_fwd_gather_bwd'
ALLREDUCE_FWD_IDENTITY_BWD = 'all_reduce_fwd_identity_bwd'
IDENTITY_FWD_ALLREDUCE_BWD = 'identity_fwd_all_reduce_bwd'
MIXGATHER_FWD_SPLIT_BWD = "mixgather_fwd_split_bwd"
class CommSpec:
@@ -255,7 +402,8 @@ class CommSpec:
gather_dim=None,
shard_dim=None,
logical_process_axis=None,
forward_only=False):
forward_only=False,
mix_gather=False):
self.comm_pattern = comm_pattern
self.sharding_spec = sharding_spec
self.gather_dim = gather_dim
@@ -263,8 +411,14 @@ class CommSpec:
self.logical_process_axis = logical_process_axis
self.forward_only = forward_only
if isinstance(self.logical_process_axis, list):
self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh
self.logical_process_axis = 0
if not mix_gather:
self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh
self.logical_process_axis = 0
else:
self.device_meshes = self.sharding_spec.device_mesh.flatten_device_meshes
self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh
# Create a new member `logical_process_axes` to distinguish from original flatten
self.logical_process_axes = logical_process_axis
else:
self.device_mesh = self.sharding_spec.device_mesh
@@ -289,6 +443,10 @@ class CommSpec:
elif self.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD:
res_list.append(f"comm_pattern:IDENTITY_FWD_ALLREDUCE_BWD, ")
res_list.append(f"logical_process_axis:{self.logical_process_axis})")
elif self.comm_pattern == CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD:
res_list.append(f"comm_pattern:MIXGATHER_FWD_SPLIT_BWD, ")
res_list.append(f"gather_dim:{self.gather_dim}, ")
res_list.append(f"logical_process_asex:{self.logical_process_axes})")
return ''.join(res_list)
@@ -324,6 +482,11 @@ class CommSpec:
forward_communication_cost = 10
backward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)
if self.comm_pattern == CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD:
# no need for axis because all devices are used in mix_gather
forward_communication_cost = self.device_mesh.mix_gather_cost(comm_size)
backward_communication_cost = 10
if self.forward_only:
cost_dict["forward"] = forward_communication_cost
cost_dict["backward"] = 0
@@ -356,4 +519,5 @@ pattern_to_func_dict = {
CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: split_forward_gather_backward,
CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: reduce_input,
CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: reduce_grad,
CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD: mixgather_forward_split_backward,
}
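For context, a hedged sketch of how such a pattern table is typically consumed; apply_comm_spec is a hypothetical helper, not part of this diff, and assumes it lives in the same module as pattern_to_func_dict:

def apply_comm_spec(tensor, comm_spec):
    # Hypothetical helper (illustration only): look up the forward/backward pair
    # registered for this pattern and apply it to the local tensor.
    comm_func = pattern_to_func_dict[comm_spec.comm_pattern]
    return comm_func(tensor, comm_spec)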

View File

@@ -7,7 +7,7 @@ import torch
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, mix_gather_simulator, shard_simulator
from .comm_spec import *
@@ -328,6 +328,59 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
pass
return valid_spec_dict
def get_all_mix_gather_spec(self, source_spec: ShardingSpec,
orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]:
'''
S0S1 -> RR
S1S0 -> RR
S01R -> RR
RS01 -> RR
'''
valid_spec_dict = {}
comm_pattern = CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD
tensor_dims = len(source_spec.entire_shape)
for f_index in range(tensor_dims - 1):
for b_index in range(f_index + 1, tensor_dims):
if (f_index not in source_spec.dim_partition_dict) and (b_index not in source_spec.dim_partition_dict):
continue
else:
if f_index in source_spec.dim_partition_dict:
f_target_pair = (f_index, deepcopy(source_spec.dim_partition_dict[f_index]))
# skip (S10, R) -> (R, R)
if len(f_target_pair[1]) == 2 and f_target_pair[1][0] >= f_target_pair[1][1]:
continue
else:
f_target_pair = (f_index, [])
if b_index in source_spec.dim_partition_dict:
b_target_pair = (b_index, deepcopy(source_spec.dim_partition_dict[b_index]))
# skip (R, S10) -> (R, R)
if len(b_target_pair[1]) == 2 and b_target_pair[1][0] >= b_target_pair[1][1]:
continue
else:
b_target_pair = (b_index, [])
gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair)
comm_spec = CommSpec(comm_pattern,
sharding_spec=source_spec,
gather_dim=gather_dim,
logical_process_axis=logical_process_axes,
forward_only=self.forward_only,
mix_gather=True)
cost_dict = comm_spec.get_comm_cost()
new_dim_partition_dict = {}
# generate new sharding spec
try:
new_sharding_spec = ShardingSpec(source_spec.device_mesh,
source_spec.entire_shape,
dim_partition_dict=new_dim_partition_dict)
for phase, cost in cost_dict.items():
cost_dict[phase] = cost + orig_cost_dict[phase]
valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
except ShardingSpecException:
pass
return valid_spec_dict
def get_all_one_step_transform_spec(self, source_spec: ShardingSpec, orig_cost_dict) -> Dict[ShardingSpec, float]:
'''
Get all valid sharding specs from source_spec with one step transform, and

View File

@@ -90,6 +90,31 @@ def shard_simulator(target_pair, legal_sharding_dims):
return shard_list_list
def mix_gather_simulator(f_target_pair, b_target_pair):
'''
Assume the tensor dimension indices of the f and b target pairs are 'f' and 'b'.
S0S1 => Input: (f, [0]), (b, [1]) Output: [b, f], (1, 0)
S1S0 => Input: (f, [1]), (b, [0]) Output: [b, f], (0, 1)
S01R => Input: (f, [0, 1]), (b, []) Output: [f], (1, 1)
RS01 => Input: (f, []), (b, [0, 1]) Output: [b], (1, 1)
S10R => Input: (f, [1, 0]), (b, []) Output: [f], (0, 0)
RS10 => Input: (f, []), (b, [1, 0]) Output: [b], (0, 0)
'''
if f_target_pair[1] and b_target_pair[1]:
leading_dim = b_target_pair[1] > f_target_pair[1]
return [b_target_pair[0], f_target_pair[0]], [int(leading_dim), int(leading_dim ^ 1)]
if f_target_pair[1]:
leading_dim = f_target_pair[1][0] < f_target_pair[1][1]
return [
f_target_pair[0],
], [int(leading_dim), int(leading_dim)]
if b_target_pair[1]:
leading_dim = b_target_pair[1][0] < b_target_pair[1][1]
return [
b_target_pair[0],
], [int(leading_dim), int(leading_dim)]
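A quick standalone check (not part of the diff) of the mapping documented in the docstring above; it needs no distributed setup:

from colossalai.tensor.utils import mix_gather_simulator

assert mix_gather_simulator((0, [0]), (1, [1])) == ([1, 0], [1, 0])    # S0S1
assert mix_gather_simulator((0, [1]), (1, [0])) == ([1, 0], [0, 1])    # S1S0
assert mix_gather_simulator((0, [0, 1]), (1, [])) == ([0], [1, 1])     # S01R
assert mix_gather_simulator((0, []), (1, [0, 1])) == ([1], [1, 1])     # RS01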
# The function is credited to PyTorch Team
def named_params_with_colotensor(
module: nn.Module,

View File

@@ -0,0 +1,333 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.core import global_context as gpc
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.tensor.utils import mix_gather_simulator
from colossalai.utils import free_port
def check_mix_gather_S0S1(device_mesh, rank):
tensor_to_check = torch.arange(64).reshape((8, 8)).cuda()
(f, b) = (0, 1)
f_target_pair = (f, [0])
b_target_pair = (b, [1])
gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair)
tensor_slice = [4, 2] # (4, 2)
rank_slice = 4
f_start = (rank // rank_slice) * tensor_slice[0]
b_start = (rank % rank_slice) * tensor_slice[1]
tensor_to_comm = tensor_to_check[f_start:f_start + tensor_slice[0],
b_start:b_start + tensor_slice[1]].contiguous().cuda()
dim_partition_dict = {0: [0], 1: [1]}
# DistSpec:
# shard_sequence: S0,S1
# device_mesh_shape: (2, 4)
source_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)
comm_spec = CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD,
sharding_spec=source_spec,
gather_dim=gather_dim,
logical_process_axis=logical_process_axes,
forward_only=True,
mix_gather=True)
tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
assert tensor_to_comm.equal(tensor_to_check)
def check_two_all_gather_S0S1(device_mesh, rank):
tensor_width = 8
tensor_to_check = torch.arange(int(tensor_width * tensor_width)).reshape((tensor_width, tensor_width)).cuda()
dim_partition_dict = {0: [0], 1: [1]}
tensor_slice = [tensor_width // 2, tensor_width // 4] # (4, 2)
rank_slice = 4
f_start = (rank // rank_slice) * tensor_slice[0]
b_start = (rank % rank_slice) * tensor_slice[1]
tensor_to_comm = tensor_to_check[f_start:f_start + tensor_slice[0],
b_start:b_start + tensor_slice[1]].contiguous().cuda()
# DistSpec:
# shard_sequence: S0,S1
# device_mesh_shape: (2, 4)
sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)
# CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:0)
comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
sharding_spec,
gather_dim=0,
logical_process_axis=0)
tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
dim_partition_dict = {1: [1]}
# DistSpec:
# shard_sequence: R,S1
# device_mesh_shape: (2, 4)
sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)
# CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1)
comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
sharding_spec,
gather_dim=1,
logical_process_axis=1)
tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
assert tensor_to_comm.equal(tensor_to_check)
def check_mix_gather_S1S0(device_mesh, rank):
tensor_to_check = torch.arange(64).reshape((8, 8)).cuda()
(f, b) = (0, 1)
f_target_pair = (f, [1])
b_target_pair = (b, [0])
gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair)
tensor_slice = [2, 4]
rank_slice = 4
f_start = (rank % rank_slice) * tensor_slice[0]
b_start = (rank // rank_slice) * tensor_slice[1]
tensor_to_comm = tensor_to_check[f_start:f_start + tensor_slice[0],
b_start:b_start + tensor_slice[1]].contiguous().cuda()
dim_partition_dict = {0: [1], 1: [0]}
# DistSpec:
# shard_sequence: S1,S0
# device_mesh_shape: (2, 4)
source_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)
comm_spec = CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD,
sharding_spec=source_spec,
gather_dim=gather_dim,
logical_process_axis=logical_process_axes,
forward_only=True,
mix_gather=True)
tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
assert tensor_to_comm.equal(tensor_to_check)
def check_two_all_gather_S1S0(device_mesh, rank):
tensor_width = 8
tensor_to_check = torch.arange(int(tensor_width * tensor_width)).reshape((tensor_width, tensor_width)).cuda()
tensor_slice = [tensor_width // 4, tensor_width // 2]    # (2, 4)
rank_slice = 4
f_start = (rank % rank_slice) * tensor_slice[0]
b_start = (rank // rank_slice) * tensor_slice[1]
tensor_to_comm = tensor_to_check[f_start:f_start + tensor_slice[0],
b_start:b_start + tensor_slice[1]].contiguous().cuda()
dim_partition_dict = {0: [1], 1: [0]}
# DistSpec:
# shard_sequence: S1,S0
# device_mesh_shape: (2, 4)
sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)
# CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:1)
comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
sharding_spec,
gather_dim=0,
logical_process_axis=1)
tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
dim_partition_dict = {1: [0]}
# DistSpec:
# shard_sequence: R,S0
# device_mesh_shape: (2, 4)
sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)
# CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:0)
comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
sharding_spec,
gather_dim=1,
logical_process_axis=0)
tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
assert tensor_to_comm.equal(tensor_to_check)
def check_mix_gather_S01R(device_mesh, rank):
tensor_to_check = torch.arange(64).reshape((8, 8)).cuda()
(f, b) = (0, 1)
f_target_pair = (f, [0, 1])
b_target_pair = (b, [])
gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair)
tensor_to_comm = tensor_to_check[rank:rank + 1, :].contiguous().cuda()
dim_partition_dict = {0: [0, 1]}
# DistSpec:
# shard_sequence: S01,R
# device_mesh_shape: (2, 4)
source_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)
comm_spec = CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD,
sharding_spec=source_spec,
gather_dim=gather_dim,
logical_process_axis=logical_process_axes,
forward_only=True,
mix_gather=True)
tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
assert tensor_to_comm.equal(tensor_to_check)
def check_two_all_gather_S01R(device_mesh, rank):
tensor_width = 8
tensor_to_check = torch.arange(int(tensor_width * tensor_width)).reshape((tensor_width, tensor_width)).cuda()
rank_stride = tensor_width // 8
tensor_to_comm = tensor_to_check[rank:rank + rank_stride, :].contiguous().cuda()
dim_partition_dict = {0: [0, 1]}
# DistSpec:
# shard_sequence: S01, R
# device_mesh_shape: (2, 4)
sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)
# CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:1)
comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
sharding_spec,
gather_dim=0,
logical_process_axis=1)
tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
dim_partition_dict = {0: [0]}
# DistSpec:
# shard_sequence: S0, R
# device_mesh_shape: (2, 4)
sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)
# CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:0)
comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
sharding_spec,
gather_dim=0,
logical_process_axis=0)
tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
assert tensor_to_comm.equal(tensor_to_check)
def check_mix_gather_RS01(device_mesh, rank):
tensor_to_check = torch.arange(64).reshape((8, 8)).cuda()
(f, b) = (0, 1)
f_target_pair = (f, [])
b_target_pair = (b, [0, 1])
gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair)
tensor_to_comm = tensor_to_check[:, rank:rank + 1].contiguous().cuda()
dim_partition_dict = {1: [0, 1]}
# DistSpec:
# shard_sequence: R, S01
# device_mesh_shape: (2, 4)
source_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)
comm_spec = CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD,
sharding_spec=source_spec,
gather_dim=gather_dim,
logical_process_axis=logical_process_axes,
forward_only=True,
mix_gather=True)
tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
assert tensor_to_comm.equal(tensor_to_check)
def check_two_all_gather_RS01(device_mesh, rank):
tensor_width = 8
tensor_to_check = torch.arange(int(tensor_width * tensor_width)).reshape((tensor_width, tensor_width)).cuda()
rank_stride = tensor_width // 8
tensor_to_comm = tensor_to_check[:, rank:rank + rank_stride].contiguous().cuda()
dim_partition_dict = {1: [0, 1]}
# DistSpec:
# shard_sequence: R, S01
# device_mesh_shape: (2, 4)
sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)
# CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1)
comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
sharding_spec,
gather_dim=1,
logical_process_axis=1)
tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
dim_partition_dict = {1: [0]}
# DistSpec:
# shard_sequence: R, S0
# device_mesh_shape: (2, 4)
sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)
# CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:0)
comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
sharding_spec,
gather_dim=1,
logical_process_axis=0)
tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
assert tensor_to_comm.equal(tensor_to_check)
def check_comm(rank, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
physical_mesh_id = torch.arange(0, 8)
assert rank == gpc.get_global_rank()
mesh_shape = (2, 4)
# [[0, 1, 2, 3],
# [4, 5, 6, 7]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True, need_flatten=True)
check_mix_gather_S0S1(device_mesh, rank)
check_two_all_gather_S0S1(device_mesh, rank)
check_mix_gather_S1S0(device_mesh, rank)
check_two_all_gather_S1S0(device_mesh, rank)
check_mix_gather_S01R(device_mesh, rank)
check_two_all_gather_S01R(device_mesh, rank)
check_mix_gather_RS01(device_mesh, rank)
check_two_all_gather_RS01(device_mesh, rank)
@pytest.mark.skip(reason="Skip because the check functions assume 8 GPUS but CI only have 4 GPUs")
def test_mix_gather():
world_size = 8
run_func = partial(check_comm, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_mix_gather()