mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] mix gather (#1977)
* Add mix-gather * Add comments * Add comments * Polish comments * Change the global rank assumption * Add tests * Add two-step tests * Fix 10 and 01 * Skip test becasue the number of GPUspull/2018/head
parent
7242bffc5f
commit
d655eea515
|
@ -52,6 +52,9 @@ class DeviceMesh:
|
|||
self.process_groups_dict = self.create_process_groups_for_logical_mesh()
|
||||
if self.need_flatten:
|
||||
self.flatten_device_mesh = self.flatten()
|
||||
# Create a new member `flatten_device_meshes` to distinguish from original flatten methods (Because I'm not sure if there are functions that rely on the self.flatten())
|
||||
self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha,
|
||||
self.mesh_beta)
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
|
@ -199,3 +202,38 @@ class DeviceMesh:
|
|||
penalty_factor = num_devices / 2.0
|
||||
return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] *
|
||||
(num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001)
|
||||
|
||||
|
||||
class FlattenDeviceMesh(DeviceMesh):
|
||||
|
||||
def __init__(self, physical_mesh_id, mesh_shape, mesh_alpha=None, mesh_beta=None):
|
||||
super().__init__(physical_mesh_id,
|
||||
mesh_shape,
|
||||
mesh_alpha,
|
||||
mesh_beta,
|
||||
init_process_group=False,
|
||||
need_flatten=False)
|
||||
# Different from flatten(), mesh_shape leaves unchanged, mesh_alpha and mesh_beta are scalars
|
||||
self.mesh_alpha = max(self.mesh_alpha)
|
||||
self.mesh_beta = min(self.mesh_beta)
|
||||
# Different from original process_groups_dict, rank_list is not stored
|
||||
self.process_number_dict = self.create_process_numbers_for_logical_mesh()
|
||||
|
||||
def create_process_numbers_for_logical_mesh(self):
|
||||
'''
|
||||
Build 1d DeviceMesh in column-major(0) and row-major(1)
|
||||
for example:
|
||||
mesh_shape = (2,4)
|
||||
# [[0, 1, 2, 3],
|
||||
# [4, 5, 6, 7]]
|
||||
# return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]}
|
||||
'''
|
||||
num_devices = reduce(operator.mul, self.mesh_shape, 1)
|
||||
process_numbers_dict = {}
|
||||
process_numbers_dict[0] = torch.arange(num_devices).reshape(self.mesh_shape).transpose(1, 0).flatten().tolist()
|
||||
process_numbers_dict[1] = torch.arange(num_devices).reshape(self.mesh_shape).flatten().tolist()
|
||||
return process_numbers_dict
|
||||
|
||||
def mix_gather_cost(self, num_bytes):
|
||||
num_devices = reduce(operator.mul, self.mesh_shape, 1)
|
||||
return (self.mesh_alpha + self.mesh_beta * (num_devices - 1) / num_devices * num_bytes + 0.1)
|
||||
|
|
|
@ -79,6 +79,132 @@ def _all_reduce(tensor, comm_spec, async_op=False):
|
|||
return tensor
|
||||
|
||||
|
||||
def _mix_gather(tensor, comm_spec):
|
||||
'''
|
||||
Implement mix gather operation on device mesh based on information provided by comm_spec.
|
||||
Mix gather is the all-gather operation on all devices in the device_mesh(FlattenDeviceMesh) of the comm_spec. It is
|
||||
different from _all_gather because _mix_gather does all-gather in two dimensions of device mesh, while _all_gather
|
||||
only does all-gather in one dimension.
|
||||
Assume index of f and b target pairs are 'f' and 'b'
|
||||
ShardingSpec => gather_dim, logical_process_axes
|
||||
S0S1 => [b, f], (1, 0)
|
||||
S1S0 => [b, f], (0, 1)
|
||||
S01R => [f], (1, 1)
|
||||
RS01 => [b], (1, 1)
|
||||
Example:
|
||||
mesh_shape = (2,4)
|
||||
# [[0, 1, 2, 3],
|
||||
# [4, 5, 6, 7]]
|
||||
# return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]}
|
||||
S0S1:
|
||||
leading_group_dim = 1
|
||||
process_group = "[0, 1, 2, 3, 4, 5, 6, 7]"
|
||||
tensor_list = [(0,0),(0,1),(0,2),(0,3),(1,0),(1,1),(1,2),(1,3)] # [(slice_id_f, slice_id_b),...]
|
||||
mesh_shape = (2,4)
|
||||
cat_slice = [4,2]
|
||||
tmp_tensor_list = [(...,shape[f],shape[b]*4,...),(...,shape[f],shape[b]*4,...)]
|
||||
tmp_tensor_list[0] = torch.cat(((0,0),(0,1),(0,2),(0,3)), dim=b)
|
||||
tmp_tensor_list[1] = torch.cat(((1,0),(1,1),(1,2),(1,3)), dim=b)
|
||||
output = torch.cat((tmp_tensor_list[0],tmp_tensor_list[1]), dim=a)
|
||||
S1S0:
|
||||
leading_group_dim = 0
|
||||
process_group = "[0, 4, 1, 5, 2, 6, 3, 7]"
|
||||
tensor_list = [(0,0),(0,1),(1,0),(1,1),(2,0),(2,1),(3,0),(3,1)]
|
||||
mesh_shape = (2,4)
|
||||
cat_slice = [2,4]
|
||||
tmp_tensor_list = [(...,shape[f],shape[b]*2,...),(...,shape[f],shape[b]*2,...),(...,shape[f],shape[b]*2,...),(...,shape[f],shape[b]*2,...)]
|
||||
tmp_tensor_list[0] = torch.cat(((0,0),(0,1)), dim=b)
|
||||
tmp_tensor_list[1] = torch.cat(((1,0),(1,1)), dim=b)
|
||||
tmp_tensor_list[2] = torch.cat(((2,0),(2,1)), dim=b)
|
||||
tmp_tensor_list[3] = torch.cat(((3,0),(3,1)), dim=b)
|
||||
S10R:
|
||||
leading_group_dim = 0
|
||||
process_group = "[0, 4, 1, 5, 2, 6, 3, 7]"
|
||||
tensor_list = [(0,0),(1,0),(2,0),(3,0),(4,0),(5,0),(6,0),(7,0)]
|
||||
S01R:
|
||||
leading_group_dim = 1
|
||||
process_group = "[0, 1, 2, 3, 4, 5, 6, 7]"
|
||||
tensor_list = [(0,0),(1,0),(2,0),(3,0),(4,0),(5,0),(6,0),(7,0)]
|
||||
'''
|
||||
total_slices = comm_spec.device_mesh.mesh_shape[0]
|
||||
tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(total_slices)]
|
||||
leading_group_dim = comm_spec.logical_process_axes[0]
|
||||
assert len(comm_spec.device_mesh.process_groups_dict) == 1
|
||||
_, process_group = comm_spec.device_mesh.process_groups_dict[0][0]
|
||||
process_number_list = comm_spec.device_meshes.process_number_dict[leading_group_dim]
|
||||
|
||||
# Global all_gather
|
||||
dist.all_gather(tensor_list, tensor, group=process_group)
|
||||
|
||||
# This is very ugly. I'm figuring out more elegant methods
|
||||
tensor_list_sorted = [
|
||||
torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(total_slices)
|
||||
]
|
||||
for i in range(total_slices):
|
||||
tensor_list_sorted[i] = tensor_list[process_number_list[i]]
|
||||
tensor_list = tensor_list_sorted
|
||||
|
||||
if comm_spec.logical_process_axes[0] == comm_spec.logical_process_axes[1]:
|
||||
output = torch.cat(tuple(tensor_list), comm_spec.gather_dim[0]).contiguous()
|
||||
else:
|
||||
mesh_shape = comm_spec.device_meshes.mesh_shape
|
||||
cat_slice = [mesh_shape[comm_spec.logical_process_axes[0]], mesh_shape[comm_spec.logical_process_axes[1]]]
|
||||
tmp_tensor_shape = list(tensor.shape)
|
||||
tmp_tensor_shape[comm_spec.gather_dim[0]] *= cat_slice[0]
|
||||
tmp_tensor_shape = torch.Size(tmp_tensor_shape)
|
||||
tmp_tensor_list = [
|
||||
torch.zeros(tmp_tensor_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(cat_slice[1])
|
||||
]
|
||||
for i in range(cat_slice[1]):
|
||||
tmp_tensor_list[i] = torch.cat(tuple(tensor_list[i * cat_slice[0]:(i + 1) * cat_slice[0]]),
|
||||
comm_spec.gather_dim[0]).contiguous()
|
||||
output = torch.cat(tuple(tmp_tensor_list), comm_spec.gather_dim[1]).contiguous()
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _mix_split(tensor, comm_spec):
|
||||
'''
|
||||
Implement mix split operation. Mix split is only called for the backward of mix gather (Use ctx to keep consistent)
|
||||
Mix split shards the tensor on device mesh based on information provided by comm_spec. It is different from split
|
||||
because _mix_split shards the tensor in two dimensions of device mesh, while _split only shards in one dimension.
|
||||
Assume index of f and b target pairs are 'f' and 'b'
|
||||
S0S1 => [b, f], (1, 0)
|
||||
S1S0 => [b, f], (0, 1)
|
||||
S01R => [f], (0, 0)
|
||||
RS01 => [b], (0, 0)
|
||||
Example:
|
||||
mesh_shape = (2,4)
|
||||
# [[0, 1, 2, 3],
|
||||
# [4, 5, 6, 7]]
|
||||
# return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]}
|
||||
'''
|
||||
mesh_shape = comm_spec.device_meshes.mesh_shape
|
||||
dim = comm_spec.gather_dim
|
||||
total_slices = comm_spec.device_mesh.mesh_shape[0]
|
||||
|
||||
# Get global rank
|
||||
rank = dist.get_rank()
|
||||
|
||||
leading_group_dim = comm_spec.logical_process_axes[0]
|
||||
process_number_list = comm_spec.device_meshes.process_number_dict[leading_group_dim]
|
||||
rank = process_number_list.index(rank)
|
||||
|
||||
if comm_spec.logical_process_axes[0] == comm_spec.logical_process_axes[1]:
|
||||
length = tensor.shape[dim[0]] // total_slices
|
||||
start = length * rank
|
||||
output = torch.narrow(tensor, dim[0], start, length).contiguous()
|
||||
else:
|
||||
tensor_shape = [tensor.shape[dim[0]], tensor.shape[dim[1]]]
|
||||
rank_slice = [mesh_shape[comm_spec.logical_process_axes[0]], mesh_shape[comm_spec.logical_process_axes[1]]]
|
||||
length = [tensor_shape[0] // rank_slice[0], tensor_shape[1] // rank_slice[1]]
|
||||
start = [(rank % rank_slice[0]) * length[0], (rank // rank_slice[0]) * length[1]]
|
||||
tmp_output = torch.narrow(tensor, dim[0], start[0], length[0]).contiguous()
|
||||
output = torch.narrow(tmp_output, dim[1], start[1], length[1]).contiguous()
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class _ReduceGrad(torch.autograd.Function):
|
||||
"""
|
||||
A customized communication operation which forward is an identity operation,
|
||||
|
@ -204,6 +330,22 @@ class _AllToAll(torch.autograd.Function):
|
|||
return _all_to_all(grad_outputs, ctx.comm_spec), None
|
||||
|
||||
|
||||
class _MixGatherForwardMixSplitBackward(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def symbolic(graph, input_):
|
||||
return _mix_gather(input_)
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, comm_spec):
|
||||
ctx.comm_spec = comm_spec
|
||||
return _mix_gather(input_, comm_spec)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return _mix_split(grad_output, ctx.comm_spec), None
|
||||
|
||||
|
||||
def reduce_grad(input_, comm_spec):
|
||||
return _ReduceGrad.apply(input_, comm_spec)
|
||||
|
||||
|
@ -224,12 +366,17 @@ def all_to_all(input_, comm_spec):
|
|||
return _AllToAll.apply(input_, comm_spec)
|
||||
|
||||
|
||||
def mixgather_forward_split_backward(input_, comm_spec):
|
||||
return _MixGatherForwardMixSplitBackward.apply(input_, comm_spec)
|
||||
|
||||
|
||||
class CollectiveCommPattern(Enum):
|
||||
GATHER_FWD_SPLIT_BWD = 'gather_fwd_split_bwd'
|
||||
ALL2ALL_FWD_ALL2ALL_BWD = 'all2all_fwd_all2all_bwd'
|
||||
SPLIT_FWD_GATHER_BWD = 'split_fwd_gather_bwd'
|
||||
ALLREDUCE_FWD_IDENTITY_BWD = 'all_reduce_fwd_identity_bwd'
|
||||
IDENTITY_FWD_ALLREDUCE_BWD = 'identity_fwd_all_reduce_bwd'
|
||||
MIXGATHER_FWD_SPLIT_BWD = "mixgather_fwd_split_bwd"
|
||||
|
||||
|
||||
class CommSpec:
|
||||
|
@ -255,7 +402,8 @@ class CommSpec:
|
|||
gather_dim=None,
|
||||
shard_dim=None,
|
||||
logical_process_axis=None,
|
||||
forward_only=False):
|
||||
forward_only=False,
|
||||
mix_gather=False):
|
||||
self.comm_pattern = comm_pattern
|
||||
self.sharding_spec = sharding_spec
|
||||
self.gather_dim = gather_dim
|
||||
|
@ -263,8 +411,14 @@ class CommSpec:
|
|||
self.logical_process_axis = logical_process_axis
|
||||
self.forward_only = forward_only
|
||||
if isinstance(self.logical_process_axis, list):
|
||||
self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh
|
||||
self.logical_process_axis = 0
|
||||
if not mix_gather:
|
||||
self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh
|
||||
self.logical_process_axis = 0
|
||||
else:
|
||||
self.device_meshes = self.sharding_spec.device_mesh.flatten_device_meshes
|
||||
self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh
|
||||
# Create a new member `logical_process_axes` to distinguish from original flatten
|
||||
self.logical_process_axes = logical_process_axis
|
||||
else:
|
||||
self.device_mesh = self.sharding_spec.device_mesh
|
||||
|
||||
|
@ -289,6 +443,10 @@ class CommSpec:
|
|||
elif self.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD:
|
||||
res_list.append(f"comm_pattern:IDENTITY_FWD_ALLREDUCE_BWD, ")
|
||||
res_list.append(f"logical_process_axis:{self.logical_process_axis})")
|
||||
elif self.comm_pattern == CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD:
|
||||
res_list.append(f"comm_pattern:MIXGATHER_FWD_SPLIT_BWD, ")
|
||||
res_list.append(f"gather_dim:{self.gather_dim}, ")
|
||||
res_list.append(f"logical_process_asex:{self.logical_process_axes})")
|
||||
|
||||
return ''.join(res_list)
|
||||
|
||||
|
@ -324,6 +482,11 @@ class CommSpec:
|
|||
forward_communication_cost = 10
|
||||
backward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)
|
||||
|
||||
if self.comm_pattern == CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD:
|
||||
# no need for axis because all devices are used in mix_gather
|
||||
forward_communication_cost = self.device_mesh.mix_gather_cost(comm_size)
|
||||
backward_communication_cost = 10
|
||||
|
||||
if self.forward_only:
|
||||
cost_dict["forward"] = forward_communication_cost
|
||||
cost_dict["backward"] = 0
|
||||
|
@ -356,4 +519,5 @@ pattern_to_func_dict = {
|
|||
CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: split_forward_gather_backward,
|
||||
CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: reduce_input,
|
||||
CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: reduce_grad,
|
||||
CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD: mixgather_forward_split_backward,
|
||||
}
|
||||
|
|
|
@ -7,7 +7,7 @@ import torch
|
|||
|
||||
from colossalai.context.singleton_meta import SingletonMeta
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException
|
||||
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
|
||||
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, mix_gather_simulator, shard_simulator
|
||||
|
||||
from .comm_spec import *
|
||||
|
||||
|
@ -328,6 +328,59 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
|||
pass
|
||||
return valid_spec_dict
|
||||
|
||||
def get_all_mix_gather_spec(self, source_spec: ShardingSpec,
|
||||
orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]:
|
||||
'''
|
||||
S0S1 -> RR
|
||||
S1S0 -> RR
|
||||
S01R -> RR
|
||||
RS01 -> RR
|
||||
'''
|
||||
valid_spec_dict = {}
|
||||
comm_pathern = CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD
|
||||
tensor_dims = len(source_spec.entire_shape)
|
||||
for f_index in range(tensor_dims - 1):
|
||||
for b_index in range(f_index + 1, tensor_dims):
|
||||
if (f_index not in source_spec.dim_partition_dict) and (b_index not in source_spec.dim_partition_dict):
|
||||
continue
|
||||
else:
|
||||
if f_index in source_spec.dim_partition_dict:
|
||||
# skip (S10, R) -> (R, R)
|
||||
if len(f_target_pair[1]) == 2 and f_target_pair[1][0] >= f_target_pair[1][1]:
|
||||
continue
|
||||
f_target_pair = (f_index, deepcopy(source_spec.dim_partition_dict[f_index]))
|
||||
else:
|
||||
f_target_pair = (f_index, [])
|
||||
if b_index in source_spec.dim_partition_dict:
|
||||
# skip (R, S10) -> (R, R)
|
||||
if len(b_target_pair[1]) == 2 and b_target_pair[1][0] >= b_target_pair[1][1]:
|
||||
continue
|
||||
b_target_pair = (b_index, deepcopy(source_spec.dim_partition_dict[b_index]))
|
||||
else:
|
||||
b_target_pair = (b_index, [])
|
||||
|
||||
gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair)
|
||||
comm_spec = CommSpec(comm_pathern,
|
||||
sharding_spec=source_spec,
|
||||
gather_dim=gather_dim,
|
||||
logical_process_axis=logical_process_axes,
|
||||
forward_only=self.forward_only,
|
||||
mix_gather=True)
|
||||
cost_dict = comm_spec.get_comm_cost()
|
||||
new_dim_partition_dict = {}
|
||||
# generate new sharding spec
|
||||
try:
|
||||
new_sharding_spec = ShardingSpec(source_spec.device_mesh,
|
||||
source_spec.entire_shape,
|
||||
dim_partition_dict=new_dim_partition_dict)
|
||||
for phase, cost in cost_dict.items():
|
||||
cost_dict[phase] = cost + orig_cost_dict[phase]
|
||||
valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
|
||||
except ShardingSpecException:
|
||||
pass
|
||||
|
||||
return valid_spec_dict
|
||||
|
||||
def get_all_one_step_transform_spec(self, source_spec: ShardingSpec, orig_cost_dict) -> Dict[ShardingSpec, float]:
|
||||
'''
|
||||
Get all valid sharding specs from source_spec with one step transform, and
|
||||
|
|
|
@ -90,6 +90,31 @@ def shard_simulator(target_pair, legal_sharding_dims):
|
|||
return shard_list_list
|
||||
|
||||
|
||||
def mix_gather_simulator(f_target_pair, b_target_pair):
|
||||
'''
|
||||
Assume index of f and b target pairs are 'f' and 'b'
|
||||
S0S1 => Input: (f, [0]), (b, [1]) Output: [b, f], (1, 0)
|
||||
S1S0 => Input: (f, [1]), (b, [0]) Output: [b, f], (0, 1)
|
||||
S01R => Input: (f, [0, 1]), (b, []) Output: [f], (1, 1)
|
||||
RS01 => Input: (f, []), (b, [0, 1]) Output: [b], (1, 1)
|
||||
S10R => Input: (f, [0, 1]), (b, []) Output: [f], (0, 0)
|
||||
RS10 => Input: (f, []), (b, [0, 1]) Output: [b], (0, 0)
|
||||
'''
|
||||
if f_target_pair[1] and b_target_pair[1]:
|
||||
leading_dim = b_target_pair[1] > f_target_pair[1]
|
||||
return [b_target_pair[0], f_target_pair[0]], [int(leading_dim), int(leading_dim ^ 1)]
|
||||
if f_target_pair[1]:
|
||||
leading_dim = f_target_pair[1][0] < f_target_pair[1][1]
|
||||
return [
|
||||
f_target_pair[0],
|
||||
], [int(leading_dim), int(leading_dim)]
|
||||
if b_target_pair[1]:
|
||||
leading_dim = b_target_pair[1][0] < b_target_pair[1][1]
|
||||
return [
|
||||
b_target_pair[0],
|
||||
], [int(leading_dim), int(leading_dim)]
|
||||
|
||||
|
||||
# The function is credited to PyTorch Team
|
||||
def named_params_with_colotensor(
|
||||
module: nn.Module,
|
||||
|
|
|
@ -0,0 +1,333 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
from colossalai.tensor.utils import mix_gather_simulator
|
||||
from colossalai.utils import free_port
|
||||
|
||||
|
||||
def check_mix_gather_S0S1(device_mesh, rank):
|
||||
tensor_to_check = torch.arange(64).reshape((8, 8)).cuda()
|
||||
(f, b) = (0, 1)
|
||||
f_target_pair = (f, [0])
|
||||
b_target_pair = (b, [1])
|
||||
gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair)
|
||||
tensor_slice = [4, 2] # (4, 2)
|
||||
rank_slice = 4
|
||||
f_start = (rank // rank_slice) * tensor_slice[0]
|
||||
b_start = (rank % rank_slice) * tensor_slice[1]
|
||||
tensor_to_comm = tensor_to_check[f_start:f_start + tensor_slice[0],
|
||||
b_start:b_start + tensor_slice[1]].contiguous().cuda()
|
||||
|
||||
dim_partition_dict = {0: [0], 1: [1]}
|
||||
|
||||
# DistSpec:
|
||||
# shard_sequence: S0,S1
|
||||
# device_mesh_shape: (2, 4)
|
||||
source_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)
|
||||
|
||||
comm_spec = CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD,
|
||||
sharding_spec=source_spec,
|
||||
gather_dim=gather_dim,
|
||||
logical_process_axis=logical_process_axes,
|
||||
forward_only=True,
|
||||
mix_gather=True)
|
||||
tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
|
||||
|
||||
assert tensor_to_comm.equal(tensor_to_check)
|
||||
|
||||
|
||||
def check_two_all_gather_S0S1(device_mesh, rank):
|
||||
tensor_width = 8
|
||||
tensor_to_check = torch.arange(int(tensor_width * tensor_width)).reshape((tensor_width, tensor_width)).cuda()
|
||||
|
||||
dim_partition_dict = {0: [0], 1: [1]}
|
||||
|
||||
tensor_slice = [tensor_width // 2, tensor_width // 4] # (4, 2)
|
||||
rank_slice = 4
|
||||
f_start = (rank // rank_slice) * tensor_slice[0]
|
||||
b_start = (rank % rank_slice) * tensor_slice[1]
|
||||
tensor_to_comm = tensor_to_check[f_start:f_start + tensor_slice[0],
|
||||
b_start:b_start + tensor_slice[1]].contiguous().cuda()
|
||||
|
||||
# DistSpec:
|
||||
# shard_sequence: S0,S1
|
||||
# device_mesh_shape: (2, 4)
|
||||
sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)
|
||||
|
||||
# CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:0)
|
||||
comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
|
||||
sharding_spec,
|
||||
gather_dim=0,
|
||||
logical_process_axis=0)
|
||||
|
||||
tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
|
||||
|
||||
dim_partition_dict = {1: [1]}
|
||||
# DistSpec:
|
||||
# shard_sequence: R,S1
|
||||
# device_mesh_shape: (2, 4)
|
||||
sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)
|
||||
|
||||
# CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1)
|
||||
comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
|
||||
sharding_spec,
|
||||
gather_dim=1,
|
||||
logical_process_axis=1)
|
||||
|
||||
tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
|
||||
|
||||
assert tensor_to_comm.equal(tensor_to_check)
|
||||
|
||||
|
||||
def check_mix_gather_S1S0(device_mesh, rank):
|
||||
tensor_to_check = torch.arange(64).reshape((8, 8)).cuda()
|
||||
(f, b) = (0, 1)
|
||||
f_target_pair = (f, [1])
|
||||
b_target_pair = (b, [0])
|
||||
gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair)
|
||||
tensor_slice = [2, 4]
|
||||
rank_slice = 4
|
||||
f_start = (rank % rank_slice) * tensor_slice[0]
|
||||
b_start = (rank // rank_slice) * tensor_slice[1]
|
||||
tensor_to_comm = tensor_to_check[f_start:f_start + tensor_slice[0],
|
||||
b_start:b_start + tensor_slice[1]].contiguous().cuda()
|
||||
|
||||
dim_partition_dict = {0: [1], 1: [0]}
|
||||
|
||||
# DistSpec:
|
||||
# shard_sequence: S1,S0
|
||||
# device_mesh_shape: (2, 4)
|
||||
source_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)
|
||||
|
||||
comm_spec = CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD,
|
||||
sharding_spec=source_spec,
|
||||
gather_dim=gather_dim,
|
||||
logical_process_axis=logical_process_axes,
|
||||
forward_only=True,
|
||||
mix_gather=True)
|
||||
tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
|
||||
|
||||
assert tensor_to_comm.equal(tensor_to_check)
|
||||
|
||||
|
||||
def check_two_all_gather_S1S0(device_mesh, rank):
|
||||
tensor_width = 8
|
||||
tensor_to_check = torch.arange(int(tensor_width * tensor_width)).reshape((tensor_width, tensor_width)).cuda()
|
||||
|
||||
tensor_slice = [tensor_width // 4, tensor_width // 2] # (4, 2)
|
||||
rank_slice = 4
|
||||
f_start = (rank % rank_slice) * tensor_slice[0]
|
||||
b_start = (rank // rank_slice) * tensor_slice[1]
|
||||
tensor_to_comm = tensor_to_check[f_start:f_start + tensor_slice[0],
|
||||
b_start:b_start + tensor_slice[1]].contiguous().cuda()
|
||||
|
||||
dim_partition_dict = {0: [1], 1: [0]}
|
||||
|
||||
# DistSpec:
|
||||
# shard_sequence: S1,S0
|
||||
# device_mesh_shape: (2, 4)
|
||||
sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)
|
||||
|
||||
# CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:1)
|
||||
comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
|
||||
sharding_spec,
|
||||
gather_dim=0,
|
||||
logical_process_axis=1)
|
||||
|
||||
tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
|
||||
|
||||
dim_partition_dict = {1: [0]}
|
||||
# DistSpec:
|
||||
# shard_sequence: R,S0
|
||||
# device_mesh_shape: (2, 4)
|
||||
sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)
|
||||
|
||||
# CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:0)
|
||||
comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
|
||||
sharding_spec,
|
||||
gather_dim=1,
|
||||
logical_process_axis=0)
|
||||
|
||||
tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
|
||||
|
||||
assert tensor_to_comm.equal(tensor_to_check)
|
||||
|
||||
|
||||
def check_mix_gather_S01R(device_mesh, rank):
|
||||
tensor_to_check = torch.arange(64).reshape((8, 8)).cuda()
|
||||
(f, b) = (0, 1)
|
||||
f_target_pair = (f, [0, 1])
|
||||
b_target_pair = (b, [])
|
||||
gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair)
|
||||
tensor_to_comm = tensor_to_check[rank:rank + 1, :].contiguous().cuda()
|
||||
|
||||
dim_partition_dict = {0: [0, 1]}
|
||||
# DistSpec:
|
||||
# shard_sequence: S01,R
|
||||
# device_mesh_shape: (2, 4)
|
||||
source_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)
|
||||
|
||||
comm_spec = CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD,
|
||||
sharding_spec=source_spec,
|
||||
gather_dim=gather_dim,
|
||||
logical_process_axis=logical_process_axes,
|
||||
forward_only=True,
|
||||
mix_gather=True)
|
||||
tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
|
||||
|
||||
assert tensor_to_comm.equal(tensor_to_check)
|
||||
|
||||
|
||||
def check_two_all_gather_S01R(device_mesh, rank):
|
||||
tensor_width = 8
|
||||
tensor_to_check = torch.arange(int(tensor_width * tensor_width)).reshape((tensor_width, tensor_width)).cuda()
|
||||
|
||||
rank_stride = tensor_width // 8
|
||||
tensor_to_comm = tensor_to_check[rank:rank + rank_stride, :].contiguous().cuda()
|
||||
|
||||
dim_partition_dict = {0: [0, 1]}
|
||||
|
||||
# DistSpec:
|
||||
# shard_sequence: S01, R
|
||||
# device_mesh_shape: (2, 4)
|
||||
sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)
|
||||
|
||||
# CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:0)
|
||||
comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
|
||||
sharding_spec,
|
||||
gather_dim=0,
|
||||
logical_process_axis=1)
|
||||
|
||||
tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
|
||||
|
||||
dim_partition_dict = {0: [0]}
|
||||
|
||||
# DistSpec:
|
||||
# shard_sequence: S1, R
|
||||
# device_mesh_shape: (2, 4)
|
||||
sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)
|
||||
|
||||
# CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:1)
|
||||
comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
|
||||
sharding_spec,
|
||||
gather_dim=0,
|
||||
logical_process_axis=0)
|
||||
|
||||
tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
|
||||
|
||||
assert tensor_to_comm.equal(tensor_to_check)
|
||||
|
||||
|
||||
def check_mix_gather_RS01(device_mesh, rank):
|
||||
tensor_to_check = torch.arange(64).reshape((8, 8)).cuda()
|
||||
|
||||
(f, b) = (0, 1)
|
||||
f_target_pair = (f, [])
|
||||
b_target_pair = (b, [0, 1])
|
||||
gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair)
|
||||
tensor_to_comm = tensor_to_check[:, rank:rank + 1].contiguous().cuda()
|
||||
|
||||
dim_partition_dict = {1: [0, 1]}
|
||||
# DistSpec:
|
||||
# shard_sequence: R, S01
|
||||
# device_mesh_shape: (2, 4)
|
||||
source_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)
|
||||
|
||||
comm_spec = CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD,
|
||||
sharding_spec=source_spec,
|
||||
gather_dim=gather_dim,
|
||||
logical_process_axis=logical_process_axes,
|
||||
forward_only=True,
|
||||
mix_gather=True)
|
||||
tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
|
||||
|
||||
assert tensor_to_comm.equal(tensor_to_check)
|
||||
|
||||
|
||||
def check_two_all_gather_RS01(device_mesh, rank):
|
||||
tensor_width = 8
|
||||
tensor_to_check = torch.arange(int(tensor_width * tensor_width)).reshape((tensor_width, tensor_width)).cuda()
|
||||
|
||||
rank_stride = tensor_width // 8
|
||||
tensor_to_comm = tensor_to_check[:, rank:rank + rank_stride].contiguous().cuda()
|
||||
|
||||
dim_partition_dict = {1: [0, 1]}
|
||||
|
||||
# DistSpec:
|
||||
# shard_sequence: R, S01
|
||||
# device_mesh_shape: (2, 4)
|
||||
sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)
|
||||
|
||||
# CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:0)
|
||||
comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
|
||||
sharding_spec,
|
||||
gather_dim=1,
|
||||
logical_process_axis=1)
|
||||
|
||||
tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
|
||||
|
||||
dim_partition_dict = {1: [0]}
|
||||
|
||||
# DistSpec:
|
||||
# shard_sequence: R, S1
|
||||
# device_mesh_shape: (2, 4)
|
||||
sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)
|
||||
|
||||
# CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1)
|
||||
comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
|
||||
sharding_spec,
|
||||
gather_dim=1,
|
||||
logical_process_axis=0)
|
||||
|
||||
tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
|
||||
|
||||
assert tensor_to_comm.equal(tensor_to_check)
|
||||
|
||||
|
||||
def check_comm(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
|
||||
physical_mesh_id = torch.arange(0, 8)
|
||||
assert rank == gpc.get_global_rank()
|
||||
|
||||
mesh_shape = (2, 4)
|
||||
# [[0, 1, 2, 3],
|
||||
# [4, 5, 6, 7]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True, need_flatten=True)
|
||||
|
||||
check_mix_gather_S0S1(device_mesh, rank)
|
||||
|
||||
check_two_all_gather_S0S1(device_mesh, rank)
|
||||
|
||||
check_mix_gather_S1S0(device_mesh, rank)
|
||||
|
||||
check_two_all_gather_S1S0(device_mesh, rank)
|
||||
|
||||
check_mix_gather_S01R(device_mesh, rank)
|
||||
|
||||
check_two_all_gather_S01R(device_mesh, rank)
|
||||
|
||||
check_mix_gather_RS01(device_mesh, rank)
|
||||
|
||||
check_two_all_gather_RS01(device_mesh, rank)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Skip because the check functions assume 8 GPUS but CI only have 4 GPUs")
|
||||
def test_mix_gather():
|
||||
world_size = 8
|
||||
run_func = partial(check_comm, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_mix_gather()
|
Loading…
Reference in New Issue