[tensor]add 1D device mesh (#1492)

pull/1487/head
YuliangLiu0306 2022-08-25 16:48:12 +08:00 committed by GitHub
parent b8d0e39eaf
commit 4b03c25f85
4 changed files with 66 additions and 13 deletions

View File

@@ -25,7 +25,13 @@ class DeviceMesh:
         (default: False)
     """

-    def __init__(self, physical_mesh_id, mesh_shape, mesh_alpha=None, mesh_beta=None, init_process_group=False):
+    def __init__(self,
+                 physical_mesh_id,
+                 mesh_shape,
+                 mesh_alpha=None,
+                 mesh_beta=None,
+                 init_process_group=False,
+                 need_flatten=True):
         self.physical_mesh_id = physical_mesh_id
         self.mesh_shape = mesh_shape
         self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape)
@@ -39,8 +45,12 @@ class DeviceMesh:
             mesh_beta = [1] * len(self.mesh_shape)
         self.mesh_alpha = tuple(mesh_alpha)
         self.mesh_beta = tuple(mesh_beta)
-        if init_process_group:
+        self.init_process_group = init_process_group
+        self.need_flatten = need_flatten
+        if self.init_process_group:
             self.process_groups_dict = self.create_process_groups_for_logical_mesh()
+        if self.need_flatten:
+            self.flatten_device_mesh = self.flatten()

     @property
     def shape(self):
@@ -54,6 +64,19 @@ class DeviceMesh:
     def logical_mesh_id(self):
         return self._logical_mesh_id

+    def flatten(self):
+        """
+        Flatten the logical mesh into an effective 1d logical mesh,
+        """
+        flatten_mesh_shape_size = len(self.mesh_shape)
+        flatten_mesh_shape = [self.num_devices]
+        return DeviceMesh(self.physical_mesh_id,
+                          tuple(flatten_mesh_shape),
+                          mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1),
+                          mesh_beta=[min(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
+                          init_process_group=self.init_process_group,
+                          need_flatten=False)
+
     def _global_rank_to_logical_rank_map(self, tensor, index_list):
         '''
         This method is a helper function to build convert_map recursively.
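
With need_flatten defaulting to True, every DeviceMesh now carries a 1D companion mesh built by flatten() and stored as flatten_device_mesh. The following is a minimal sketch of how this looks from the caller's side; the import path and the CPU-only construction (init_process_group=False) are assumptions for illustration and are not part of this diff.

import torch
from colossalai.device.device_mesh import DeviceMesh

# Four devices arranged as a 2 x 2 logical mesh; no process groups are created.
physical_mesh_id = torch.arange(0, 4)
device_mesh = DeviceMesh(physical_mesh_id, (2, 2), init_process_group=False)

# need_flatten defaults to True, so the 1D view is built eagerly in __init__.
flat_mesh = device_mesh.flatten_device_mesh
print(flat_mesh.mesh_shape)    # expected: (4,)

Because flatten() passes need_flatten=False to the new mesh, the flattened mesh does not recursively flatten itself.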

View File

@@ -3,6 +3,7 @@
 from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
 from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
 from enum import Enum
 from copy import deepcopy
+from typing import Dict, List, Optional, Tuple, Union
 import torch.distributed as dist
 import math
 from functools import reduce
@@ -29,9 +30,9 @@ class CommSpec:
     Argument:
         comm_pattern(CollectiveCommPattern): decribe the communication method used in this spec.
         sharding_spec(ShardingSpec): This is sharding spec of the tensor which will join the communication action.
-        gather_dim(int, optional): The gather_dim of the tensor will be gathered.
-        shard_dim(int, optional): The shard_dim of the tensor will be sharded.
-        logical_process_axis(int, optional): The mesh_dim to implement the communication action.
+        gather_dim(int, Optional): The gather_dim of the tensor will be gathered.
+        shard_dim(int, Optional): The shard_dim of the tensor will be sharded.
+        logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action.
     '''

     def __init__(self, comm_pattern, sharding_spec, gather_dim=None, shard_dim=None, logical_process_axis=None):
@@ -40,6 +41,11 @@ class CommSpec:
         self.gather_dim = gather_dim
         self.shard_dim = shard_dim
         self.logical_process_axis = logical_process_axis
+        if isinstance(self.logical_process_axis, list):
+            self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh
+            self.logical_process_axis = 0
+        else:
+            self.device_mesh = self.sharding_spec.device_mesh

     def __repr__(self):
         res_list = ["CommSpec:("]
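
The effect of the new branch: a list-valued logical_process_axis makes the CommSpec operate on the flattened 1D mesh and normalizes the axis to 0, while an integer axis keeps the original 2D mesh. A rough sketch is below; the import paths for CommSpec, CollectiveCommPattern and ShardingSpec are assumptions, since the file names are not shown in this diff.

import torch
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.tensor.shape_consistency import CommSpec, CollectiveCommPattern

device_mesh = DeviceMesh(torch.arange(0, 4), (2, 2), init_process_group=False)
# Fully replicated spec on the 2D mesh (no tensor dimension is sharded).
sharding_spec = ShardingSpec(device_mesh, torch.Size((2, 2)), dim_partition_dict={})

# Asking for both mesh axes routes the spec onto the flattened 1D mesh.
comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE, sharding_spec, logical_process_axis=[0, 1])
assert comm_spec.device_mesh is device_mesh.flatten_device_mesh
assert comm_spec.logical_process_axis == 0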
@@ -70,11 +76,11 @@ class CommSpec:
         '''
         comm_size = reduce(operator.mul, self.sharding_spec.get_sharded_shape_per_device(), 1)
         if self.comm_pattern == CollectiveCommPattern.ALLGATHER:
-            return self.sharding_spec.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)
+            return self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)
         if self.comm_pattern == CollectiveCommPattern.ALLTOALL:
-            return self.sharding_spec.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis)
+            return self.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis)
         if self.comm_pattern == CollectiveCommPattern.ALLREDUCE:
-            return self.sharding_spec.device_mesh.all_reduce_cost(comm_size, self.logical_process_axis)
+            return self.device_mesh.all_reduce_cost(comm_size, self.logical_process_axis)
         if self.comm_pattern == CollectiveCommPattern.SHARD:
             return 0
         raise RuntimeError(f"Could not find a matching CollectiveCommPattern for {self.comm_pattern}.")
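
Because the cost helpers now consult self.device_mesh, a two-axis all-reduce is priced on the flattened mesh (whose mesh_alpha/mesh_beta are the max/min of the original values), while single-axis communication is still priced on the 2D mesh. A short sketch under the same assumed import paths, and assuming this cost method is exposed as get_comm_cost (the method name is not visible in this hunk):

import torch
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.tensor.shape_consistency import CommSpec, CollectiveCommPattern

device_mesh = DeviceMesh(torch.arange(0, 4), (2, 2), init_process_group=False)
sharding_spec = ShardingSpec(device_mesh, torch.Size((2, 2)), dim_partition_dict={})

# Two-axis all-reduce: cost is computed on the flattened 1D mesh of 4 devices.
flat_cost = CommSpec(CollectiveCommPattern.ALLREDUCE, sharding_spec,
                     logical_process_axis=[0, 1]).get_comm_cost()
# Single-axis all-reduce: cost is still computed on the original 2D mesh.
axis0_cost = CommSpec(CollectiveCommPattern.ALLREDUCE, sharding_spec,
                      logical_process_axis=0).get_comm_cost()
print(flat_cost, axis0_cost)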
@@ -87,15 +93,14 @@ class CommSpec:
         Argument:
             tensor(torch.Tensor): Tensor stored in each device, which could be different in different ranks.
         '''
-        device_mesh = self.sharding_spec.device_mesh
-        process_groups_list = device_mesh.process_groups_dict[self.logical_process_axis]
+        process_groups_list = self.device_mesh.process_groups_dict[self.logical_process_axis]

         if self.comm_pattern == CollectiveCommPattern.ALLGATHER:
             for rank_list, process_group in process_groups_list:
                 if dist.get_rank() in rank_list:
                     tensor_list = [
                         torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device)
-                        for _ in range(self.sharding_spec.device_mesh.mesh_shape[self.logical_process_axis])
+                        for _ in range(self.device_mesh.mesh_shape[self.logical_process_axis])
                     ]
                     tensor = tensor
                     group = process_group

View File

@@ -133,13 +133,36 @@ def check_all_reduce(device_mesh, rank):
     #     device_mesh_shape: (2, 2)
     sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict)

-    # CommSpec:CommSpec:(comm_pattern:all_reduce, logical_process_axis:0)
+    # CommSpec:(comm_pattern:all_reduce, logical_process_axis:0)
     comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE, sharding_spec, logical_process_axis=0)
     comm_spec.covert_spec_to_action(tensor_to_comm)

     assert tensor_to_comm.equal(tensor_to_check)


+def check_all_reduce_in_flatten_device_mesh(device_mesh, rank):
+    # tensor to comm
+    tensor_to_comm = torch.ones(2, 2).cuda() * rank
+
+    # reduce through logical process axis 0 at flatten device mesh
+    # tensor to check
+    # tensor([[6., 6.],
+    #         [6., 6.]])
+    tensor_to_check = torch.tensor([[6, 6], [6, 6]], dtype=tensor_to_comm.dtype).cuda()
+
+    dim_partition_dict = {}
+    # DistSpec:
+    #     shard_sequence: R,R
+    #     device_mesh_shape: (2, 2)
+    sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict)
+
+    # CommSpec:(comm_pattern:all_reduce, logical_process_axis:[0, 1])
+    comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE, sharding_spec, logical_process_axis=[0, 1])
+    comm_spec.covert_spec_to_action(tensor_to_comm)
+
+    assert tensor_to_comm.equal(tensor_to_check)
+
+
 def check_comm(rank, world_size, port):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
@@ -162,6 +185,9 @@ def check_comm(rank, world_size, port):
     # test all reduce
     check_all_reduce(device_mesh, rank)

+    # test all reduce in 1D flatten device mesh
+    check_all_reduce_in_flatten_device_mesh(device_mesh, rank)
+
     gpc.destroy()
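
The new check runs inside the same 4-process NCCL job as the existing ones, so all four ranks of the 2 x 2 mesh take part in the flattened all-reduce. A hypothetical driver for check_comm, not part of this diff and assuming it lives in the same test module with free_port and the usual mp.spawn pattern, could look like:

from functools import partial

import torch.multiprocessing as mp

from colossalai.utils import free_port


def test_comm_spec_apply():
    # check_comm(rank, world_size, port) is the function shown in the hunk above.
    world_size = 4
    run_func = partial(check_comm, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_comm_spec_apply()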

View File

@@ -64,7 +64,6 @@ def check_apply(rank, world_size, port):
     tensor_to_comm.sharding_spec = sharding_spec_source
     shape_consistency_manager.apply(tensor_to_comm, sharding_spec_target)

-    print(tensor_to_comm)
     assert tensor_to_comm.equal(tensor_to_check)
     assert str(tensor_to_comm.sharding_spec.sharding_sequence) == str(sharding_spec_target.sharding_sequence)