From 4b03c25f85cd2dbe91a07c8febff3b2452ea39c2 Mon Sep 17 00:00:00 2001 From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com> Date: Thu, 25 Aug 2022 16:48:12 +0800 Subject: [PATCH] [tensor]add 1D device mesh (#1492) --- colossalai/device/device_mesh.py | 27 ++++++++++++++++-- colossalai/tensor/shape_consistency.py | 23 +++++++++------ tests/test_tensor/test_comm_spec_apply.py | 28 ++++++++++++++++++- .../test_shape_consistency_apply.py | 1 - 4 files changed, 66 insertions(+), 13 deletions(-) diff --git a/colossalai/device/device_mesh.py b/colossalai/device/device_mesh.py index ee9380603..df010e7d7 100644 --- a/colossalai/device/device_mesh.py +++ b/colossalai/device/device_mesh.py @@ -25,7 +25,13 @@ class DeviceMesh: (default: False) """ - def __init__(self, physical_mesh_id, mesh_shape, mesh_alpha=None, mesh_beta=None, init_process_group=False): + def __init__(self, + physical_mesh_id, + mesh_shape, + mesh_alpha=None, + mesh_beta=None, + init_process_group=False, + need_flatten=True): self.physical_mesh_id = physical_mesh_id self.mesh_shape = mesh_shape self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape) @@ -39,8 +45,12 @@ class DeviceMesh: mesh_beta = [1] * len(self.mesh_shape) self.mesh_alpha = tuple(mesh_alpha) self.mesh_beta = tuple(mesh_beta) - if init_process_group: + self.init_process_group = init_process_group + self.need_flatten = need_flatten + if self.init_process_group: self.process_groups_dict = self.create_process_groups_for_logical_mesh() + if self.need_flatten: + self.flatten_device_mesh = self.flatten() @property def shape(self): @@ -54,6 +64,19 @@ class DeviceMesh: def logical_mesh_id(self): return self._logical_mesh_id + def flatten(self): + """ + Flatten the logical mesh into an effective 1d logical mesh, + """ + flatten_mesh_shape_size = len(self.mesh_shape) + flatten_mesh_shape = [self.num_devices] + return DeviceMesh(self.physical_mesh_id, + tuple(flatten_mesh_shape), + mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1), + mesh_beta=[min(self.mesh_beta)] * (flatten_mesh_shape_size - 1), + init_process_group=self.init_process_group, + need_flatten=False) + def _global_rank_to_logical_rank_map(self, tensor, index_list): ''' This method is a helper function to build convert_map recursively. diff --git a/colossalai/tensor/shape_consistency.py b/colossalai/tensor/shape_consistency.py index 612290e4e..5e7ec68f3 100644 --- a/colossalai/tensor/shape_consistency.py +++ b/colossalai/tensor/shape_consistency.py @@ -3,6 +3,7 @@ from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator from enum import Enum from copy import deepcopy +from typing import Dict, List, Optional, Tuple, Union import torch.distributed as dist import math from functools import reduce @@ -29,9 +30,9 @@ class CommSpec: Argument: comm_pattern(CollectiveCommPattern): decribe the communication method used in this spec. sharding_spec(ShardingSpec): This is sharding spec of the tensor which will join the communication action. - gather_dim(int, optional): The gather_dim of the tensor will be gathered. - shard_dim(int, optional): The shard_dim of the tensor will be sharded. - logical_process_axis(int, optional): The mesh_dim to implement the communication action. + gather_dim(int, Optional): The gather_dim of the tensor will be gathered. + shard_dim(int, Optional): The shard_dim of the tensor will be sharded. 
+ logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action. ''' def __init__(self, comm_pattern, sharding_spec, gather_dim=None, shard_dim=None, logical_process_axis=None): @@ -40,6 +41,11 @@ class CommSpec: self.gather_dim = gather_dim self.shard_dim = shard_dim self.logical_process_axis = logical_process_axis + if isinstance(self.logical_process_axis, list): + self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh + self.logical_process_axis = 0 + else: + self.device_mesh = self.sharding_spec.device_mesh def __repr__(self): res_list = ["CommSpec:("] @@ -70,11 +76,11 @@ class CommSpec: ''' comm_size = reduce(operator.mul, self.sharding_spec.get_sharded_shape_per_device(), 1) if self.comm_pattern == CollectiveCommPattern.ALLGATHER: - return self.sharding_spec.device_mesh.all_gather_cost(comm_size, self.logical_process_axis) + return self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis) if self.comm_pattern == CollectiveCommPattern.ALLTOALL: - return self.sharding_spec.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis) + return self.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis) if self.comm_pattern == CollectiveCommPattern.ALLREDUCE: - return self.sharding_spec.device_mesh.all_reduce_cost(comm_size, self.logical_process_axis) + return self.device_mesh.all_reduce_cost(comm_size, self.logical_process_axis) if self.comm_pattern == CollectiveCommPattern.SHARD: return 0 raise RuntimeError(f"Could not find a matching CollectiveCommPattern for {self.comm_pattern}.") @@ -87,15 +93,14 @@ class CommSpec: Argument: tensor(torch.Tensor): Tensor stored in each device, which could be different in different ranks. ''' - device_mesh = self.sharding_spec.device_mesh - process_groups_list = device_mesh.process_groups_dict[self.logical_process_axis] + process_groups_list = self.device_mesh.process_groups_dict[self.logical_process_axis] if self.comm_pattern == CollectiveCommPattern.ALLGATHER: for rank_list, process_group in process_groups_list: if dist.get_rank() in rank_list: tensor_list = [ torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) - for _ in range(self.sharding_spec.device_mesh.mesh_shape[self.logical_process_axis]) + for _ in range(self.device_mesh.mesh_shape[self.logical_process_axis]) ] tensor = tensor group = process_group diff --git a/tests/test_tensor/test_comm_spec_apply.py b/tests/test_tensor/test_comm_spec_apply.py index 4bc35c782..fd843f058 100644 --- a/tests/test_tensor/test_comm_spec_apply.py +++ b/tests/test_tensor/test_comm_spec_apply.py @@ -133,13 +133,36 @@ def check_all_reduce(device_mesh, rank): # device_mesh_shape: (2, 2) sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict) - # CommSpec:CommSpec:(comm_pattern:all_reduce, logical_process_axis:0) + # CommSpec:(comm_pattern:all_reduce, logical_process_axis:0) comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE, sharding_spec, logical_process_axis=0) comm_spec.covert_spec_to_action(tensor_to_comm) assert tensor_to_comm.equal(tensor_to_check) +def check_all_reduce_in_flatten_device_mesh(device_mesh, rank): + # tensor to comm + tensor_to_comm = torch.ones(2, 2).cuda() * rank + + # reduce through logical process axis 0 at flatten device mesh + # tensor to check + # tensor([[6., 6.], + # [6., 6.]]) + tensor_to_check = torch.tensor([[6, 6], [6, 6]], dtype=tensor_to_comm.dtype).cuda() + + dim_partition_dict = {} + # DistSpec: + # shard_sequence: R,R + # 
device_mesh_shape: (2, 2) + sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:all_reduce, logical_process_axis:[0, 1]) + comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE, sharding_spec, logical_process_axis=[0, 1]) + comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + def check_comm(rank, world_size, port): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') @@ -162,6 +185,9 @@ def check_comm(rank, world_size, port): # test all reduce check_all_reduce(device_mesh, rank) + + # test all reduce in 1D flatten device mesh + check_all_reduce_in_flatten_device_mesh(device_mesh, rank) gpc.destroy() diff --git a/tests/test_tensor/test_shape_consistency_apply.py b/tests/test_tensor/test_shape_consistency_apply.py index 66880bac3..843203916 100644 --- a/tests/test_tensor/test_shape_consistency_apply.py +++ b/tests/test_tensor/test_shape_consistency_apply.py @@ -64,7 +64,6 @@ def check_apply(rank, world_size, port): tensor_to_comm.sharding_spec = sharding_spec_source shape_consistency_manager.apply(tensor_to_comm, sharding_spec_target) - print(tensor_to_comm) assert tensor_to_comm.equal(tensor_to_check) assert str(tensor_to_comm.sharding_spec.sharding_sequence) == str(sharding_spec_target.sharding_sequence)
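
Usage note (not part of the commit): a minimal sketch of how the flattened 1D mesh introduced by this patch can be driven end to end. It assumes four GPUs, a physical mesh id of torch.arange(4), and an already-initialized 4-process NCCL group (e.g. launched the way check_comm does in the tests above); import paths follow the files touched by this patch, and the tensor values mirror check_all_reduce_in_flatten_device_mesh.

    import torch
    import torch.distributed as dist

    from colossalai.device.device_mesh import DeviceMesh
    from colossalai.tensor.sharding_spec import ShardingSpec
    from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec

    # A 2x2 logical mesh over 4 physical devices. With the default need_flatten=True,
    # the constructor also builds the 1D flatten_device_mesh of shape (4,).
    physical_mesh_id = torch.arange(0, 4)
    device_mesh = DeviceMesh(physical_mesh_id, (2, 2), init_process_group=True)
    assert device_mesh.flatten_device_mesh.mesh_shape == (4,)

    # Passing a list of mesh dims as logical_process_axis makes CommSpec switch to the
    # flattened mesh, so the all-reduce spans every device of the 2x2 mesh.
    rank = dist.get_rank()
    tensor_to_comm = torch.ones(2, 2).cuda() * rank
    sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict={})
    comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE, sharding_spec, logical_process_axis=[0, 1])
    comm_spec.covert_spec_to_action(tensor_to_comm)

    # With ranks 0..3 the result on every device is 0 + 1 + 2 + 3 = 6.
    assert tensor_to_comm.equal(torch.full((2, 2), 6.0).cuda())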