[tensor] support runtime ShardingSpec apply (#1453)

* [tensor] support runtime ShardingSpec apply

* polish code

* polish code
pull/1469/head
YuliangLiu0306 2022-08-19 13:39:51 +08:00 committed by GitHub
parent 177d3f5718
commit b73fb7a077
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 485 additions and 11 deletions

View File

@ -1,6 +1,7 @@
from functools import reduce
import operator
import torch
import torch.distributed as dist
class DeviceMesh:
@ -18,9 +19,13 @@ class DeviceMesh:
communication cost (default: None)
mesh_beta (List[float], optional): coefficients used for computing
communication cost (default: None)
init_process_group (bool, optional): initialize logical process group
during initializing the DeviceMesh instance if the init_process_group set to True.
Otherwise, users need to call create_process_groups_for_logical_mesh manually to init logical process group.
(default: False)
"""
def __init__(self, physical_mesh_id, mesh_shape, mesh_alpha=None, mesh_beta=None):
def __init__(self, physical_mesh_id, mesh_shape, mesh_alpha=None, mesh_beta=None, init_process_group=False):
self.physical_mesh_id = physical_mesh_id
self.mesh_shape = mesh_shape
self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape)
@ -34,6 +39,8 @@ class DeviceMesh:
mesh_beta = [1] * len(self.mesh_shape)
self.mesh_alpha = tuple(mesh_alpha)
self.mesh_beta = tuple(mesh_beta)
if init_process_group:
self.process_groups_dict = self.create_process_groups_for_logical_mesh()
@property
def shape(self):
@ -57,6 +64,28 @@ class DeviceMesh:
else:
self._global_rank_to_logical_rank_map(inner_tensor, index_list + [index])
def create_process_groups_for_logical_mesh(self):
'''
This method is used to initialize the logical process groups which will be used in communications
among logical device mesh.
Note: if init_process_group set to False, you have to call this method manually. Otherwise,
the communication related function, such as ShapeConsistencyManager.apply will raise errors.
'''
process_groups_dict = {}
check_duplicate_list = []
global_rank_flatten_list = self.physical_mesh_id.view(-1).tolist()
for global_rank in global_rank_flatten_list:
process_groups = self.global_rank_to_process_groups_with_global_rank(global_rank)
for axis, process_group in process_groups.items():
if axis not in process_groups_dict:
process_groups_dict[axis] = []
if process_group not in check_duplicate_list:
check_duplicate_list.append(process_group)
process_group_handler = dist.new_group(process_group)
process_groups_dict[axis].append((process_group, process_group_handler))
return process_groups_dict
def global_rank_to_logical_rank(self, rank):
return self.convert_map[rank]

View File

@ -3,15 +3,18 @@ from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
from enum import Enum
from copy import deepcopy
import torch.distributed as dist
import math
from functools import reduce
import operator
from torch.distributed import ReduceOp
class CollectiveCommPattern(Enum):
ALLGATHER = 'all_gather'
ALLTOALL = 'all_to_all'
SHARD = 'shard'
ALLREDUCE = 'all_reduce'
class CommSpec:
@ -41,7 +44,7 @@ class CommSpec:
def __repr__(self):
res_list = ["CommSpec:("]
if self.comm_pattern == CollectiveCommPattern.ALLGATHER:
res_list.append(f"comm_pattern:allgather, ")
res_list.append(f"comm_pattern:all_gather, ")
res_list.append(f"gather_dim:{self.gather_dim}, ")
res_list.append(f"logical_process_axis:{self.logical_process_axis})")
elif self.comm_pattern == CollectiveCommPattern.ALLTOALL:
@ -49,15 +52,19 @@ class CommSpec:
res_list.append(f"gather_dim:{self.gather_dim}, ")
res_list.append(f"shard_dim:{self.shard_dim}, ")
res_list.append(f"logical_process_axis: {self.logical_process_axis})")
else:
elif self.comm_pattern == CollectiveCommPattern.SHARD:
res_list.append(f"comm_pattern:shard, ")
res_list.append(f"shard_dim:{self.shard_dim}, ")
res_list.append(f"logical_process_axis:{self.logical_process_axis})")
elif self.comm_pattern == CollectiveCommPattern.ALLREDUCE:
res_list.append(f"comm_pattern:all_reduce, ")
res_list.append(f"logical_process_axis:{self.logical_process_axis})")
return ''.join(res_list)
def get_comm_cost(self):
'''
For all_gather and all2all operation, the formula provided in DeviceMesh with alpha-beta model is used to
For all_gather, all2all, and all_reduce operation, the formula provided in DeviceMesh with alpha-beta model is used to
compute the communication cost.
For shard operation, it is an on-chip operation, so the communication cost is zero.
'''
@ -66,10 +73,77 @@ class CommSpec:
return self.sharding_spec.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)
if self.comm_pattern == CollectiveCommPattern.ALLTOALL:
return self.sharding_spec.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis)
return 0
if self.comm_pattern == CollectiveCommPattern.ALLREDUCE:
return self.sharding_spec.device_mesh.all_reduce_cost(comm_size, self.logical_process_axis)
if self.comm_pattern == CollectiveCommPattern.SHARD:
return 0
raise RuntimeError(f"Could not find a matching CollectiveCommPattern for {self.comm_pattern}.")
def covert_spec_to_action(self):
pass
def covert_spec_to_action(self, tensor):
'''
Convert CommSpec into runtime action, implement real collection communication to target tensor.
The collection communication action is directed by the CommSpec.
Argument:
tensor(torch.Tensor): Tensor stored in each device, which could be different in different ranks.
'''
device_mesh = self.sharding_spec.device_mesh
process_groups_list = device_mesh.process_groups_dict[self.logical_process_axis]
if self.comm_pattern == CollectiveCommPattern.ALLGATHER:
for rank_list, process_group in process_groups_list:
if dist.get_rank() in rank_list:
tensor_list = [
torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device)
for _ in range(self.sharding_spec.device_mesh.mesh_shape[self.logical_process_axis])
]
tensor = tensor
group = process_group
dist.all_gather(tensor_list, tensor, group=group)
tensor.data = torch.cat(tuple(tensor_list), self.gather_dim)
elif self.comm_pattern == CollectiveCommPattern.SHARD:
for rank_list, process_group in process_groups_list:
if dist.get_rank() in rank_list:
tensor = tensor
dim = self.shard_dim
length = tensor.shape[self.shard_dim] // len(rank_list)
start = length * rank_list.index(dist.get_rank())
tensor.data = torch.narrow(tensor, dim, start, length)
elif self.comm_pattern == CollectiveCommPattern.ALLTOALL:
for rank_list, process_group in process_groups_list:
if dist.get_rank() in rank_list:
new_shape = list(tensor.shape)
new_shape[self.shard_dim] = new_shape[self.shard_dim] // len(rank_list)
new_shape = torch.Size(new_shape)
output_tensor_list = [
torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
]
dim = self.shard_dim
length = tensor.shape[self.shard_dim] // len(rank_list)
input_tensor_list = [
torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list))
]
group = process_group
dist.all_to_all(output_tensor_list, input_tensor_list, group)
tensor.data = torch.cat(tuple(output_tensor_list), self.gather_dim)
elif self.comm_pattern == CollectiveCommPattern.ALLREDUCE:
# For the consistency of collective communication operation, we temporally do not
# allow all_reduce two different mesh dimensions in the same time.
# e.g.: MatMul[(R, S01), (S01, R)] -> Partial(R, R),
# all_reduce(Partial, logical_pg=(0, 1)) is NOT allowed, instead
# we need to do this in two steps:
# 1. all_reduce(Partial, logical_pg=1)
# 2. all_reduce(Partial, logical_pg=0)
for rank_list, process_group in process_groups_list:
if dist.get_rank() in rank_list:
dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group)
tensor.data = tensor
else:
tensor.data = tensor
class ShapeConsistencyManager:
@ -191,7 +265,7 @@ class ShapeConsistencyManager:
else:
f_target_pair = (f_index, [])
if b_index in source_spec.dim_partition_dict:
# skip (R, R) -> (R, S01) is NOT allowed
# skip (R, S01) -> (S01, R) is NOT allowed
if len(source_spec.dim_partition_dict[b_index]) >= 2:
continue
b_target_pair = (b_index, deepcopy(source_spec.dim_partition_dict[b_index]))
@ -409,7 +483,7 @@ class ShapeConsistencyManager:
self.cached_spec_pairs_transform_path[spec_pairs] = (transform_path, comm_action_sequence)
return (transform_path, comm_action_sequence, total_cost)
temp_sharding_spec = deepcopy(source_spec)
temp_sharding_spec = source_spec
transform_path.append(temp_sharding_spec)
# To avoid dead loop, the loop will break after MAX_TRANSFORM_STEPS transforms
while total_steps <= MAX_TRANSFORM_STEPS:
@ -428,9 +502,9 @@ class ShapeConsistencyManager:
return (transform_path, comm_action_sequence, total_cost)
if spec_difference < best_difference_score:
temp_sharding_spec = deepcopy(sharding_spec)
temp_sharding_spec = sharding_spec
temp_cost = cost
temp_comm_spec = deepcopy(comm_spec)
temp_comm_spec = comm_spec
best_difference_score = spec_difference
transform_path.append(temp_sharding_spec)
@ -439,3 +513,67 @@ class ShapeConsistencyManager:
total_steps += 1
raise RuntimeError(f"Could not find a valid transform path with in {MAX_TRANSFORM_STEPS} steps.")
def apply(self, tensor_with_sharding_spec, target_spec):
'''
Apply target_spec to tensor with source sharding spec, the transform path is generated by the
shape_consistency method.
Argument:
tensor_with_sharding_spec (torch.Tensor): a tensor with source sharding spec to be transformed to the target spec.
target_spec (ShardingSpec): The tensor transform processes will be directed by the target_spec.
Example:
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1,
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
entire_shape = torch.Size((4, 2))
shape_consistency_manager = ShapeConsistencyManager()
dim_partition_source = {0: [0]}
dim_partition_target = {1: [0]}
# DistSpec:
# shard_sequence: S0,R
# device_mesh_shape: (2, 2)
sharding_spec_source = ShardingSpec(device_mesh, entire_shape, dim_partition_source)
# DistSpec:
# shard_sequence: R,S0
# device_mesh_shape: (2, 2)
sharding_spec_target = ShardingSpec(device_mesh, entire_shape, dim_partition_target)
if rank in (0, 1):
sharded_tensor_0 = torch.zeros(2, 1)
sharded_tensor_1 = torch.ones(2, 1)
# tensor([[0., 1.],
# [0., 1.]])
tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda()
if rank in (2, 3):
sharded_tensor_0 = torch.ones(2, 1) * 2
sharded_tensor_1 = torch.ones(2, 1) * 3
# tensor([[2., 3.],
# [2., 3.]])
tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda()
tensor_to_comm.sharding_spec = sharding_spec_source
shape_consistency_manager.apply(tensor_to_comm, sharding_spec_target)
print(tensor_to_comm)
Output in rank0 and rank2:
tensor([[0.],
[0.],
[2.],
[2.]])
Output in rank1 and rank3:
tensor([[1.],
[1.],
[3.],
[3.]])
'''
_, comm_action_sequence, _ = self.shape_consistency(tensor_with_sharding_spec.sharding_spec, target_spec)
for comm_spec in comm_action_sequence:
comm_spec.covert_spec_to_action(tensor_with_sharding_spec)
tensor_with_sharding_spec.sharding_spec = target_spec

View File

@ -0,0 +1,49 @@
import torch
from functools import partial
import pytest
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed import ReduceOp
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.device.device_mesh import DeviceMesh
def check_layer(rank, world_size, port):
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
physical_mesh_id = torch.arange(0, 4)
assert rank == gpc.get_global_rank()
tensor_to_check = torch.tensor([2, 2, 2, 2]).cuda()
mesh_shape = (2, 2)
# [[0, 1,
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
logical_pg_dict = {0: [[0, 2], [1, 3]], 1: [[0, 1], [2, 3]]}
logical_process_groups = device_mesh.process_groups_dict
for mesh_dim, pgs in logical_pg_dict.items():
for index, pg in enumerate(pgs):
if rank in pg:
tensor = torch.ones(4).cuda()
group = logical_process_groups[mesh_dim][index][1]
dist.all_reduce(tensor, op=ReduceOp.SUM, group=group)
assert tensor.equal(tensor_to_check)
gpc.destroy()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_logical_pg():
world_size = 4
run_func = partial(check_layer, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_logical_pg()

View File

@ -0,0 +1,177 @@
import torch
from functools import partial
import pytest
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed import ReduceOp
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import CommSpec, CollectiveCommPattern
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.sharding_spec import ShardingSpec
def check_all_gather(device_mesh, rank):
# tensor to comm
if rank in (0, 2):
sharded_tensor_to_comm = torch.ones(2, 2).cuda()
else:
sharded_tensor_to_comm = torch.zeros(2, 2).cuda()
# tensor to check
tensor_to_check = torch.cat((torch.ones(2, 2), torch.zeros(2, 2)), 1).cuda()
# test all gather
dim_partition_dict = {1: [1]}
# DistSpec:
# shard_sequence: R,S1
# device_mesh_shape: (2, 2)
sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)
# CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1)
comm_spec = CommSpec(CollectiveCommPattern.ALLGATHER, sharding_spec, gather_dim=1, logical_process_axis=1)
comm_spec.covert_spec_to_action(sharded_tensor_to_comm)
assert sharded_tensor_to_comm.equal(tensor_to_check)
def check_shard(device_mesh, rank):
# tensor to comm
sharded_tensor_to_comm_0 = torch.zeros(2, 2).cuda()
sharded_tensor_to_comm_1 = torch.ones(2, 2).cuda()
# tensor([[0., 0., 1., 1.],
# [0., 0., 1., 1.]])
tensor_to_shard = torch.cat((sharded_tensor_to_comm_0, sharded_tensor_to_comm_1), 1)
# test shard
dim_partition_dict = {}
# DistSpec:
# shard_sequence: R,R
# device_mesh_shape: (2, 2)
sharding_spec = ShardingSpec(device_mesh, tensor_to_shard.shape, dim_partition_dict=dim_partition_dict)
# CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1)
comm_spec = CommSpec(CollectiveCommPattern.SHARD, sharding_spec, shard_dim=1, logical_process_axis=1)
comm_spec.covert_spec_to_action(tensor_to_shard)
if rank in (0, 2):
assert tensor_to_shard.equal(sharded_tensor_to_comm_0)
if rank in (1, 3):
assert tensor_to_shard.equal(sharded_tensor_to_comm_1)
def check_all_to_all(device_mesh, rank):
# tensor to comm
if rank in (0, 1):
sharded_tensor_0 = torch.zeros(2, 1)
sharded_tensor_1 = torch.ones(2, 1)
# tensor([[0., 1.],
# [0., 1.]])
tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda()
if rank in (2, 3):
sharded_tensor_0 = torch.ones(2, 1) * 2
sharded_tensor_1 = torch.ones(2, 1) * 3
# tensor([[2., 3.],
# [2., 3.]])
tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda()
if rank in (0, 1):
# tensor([[0.],
# [0.],
# [2.],
# [2.]])
tensor_to_check = torch.tensor([[0], [0], [2], [2]], dtype=tensor_to_comm.dtype).cuda()
if rank in (2, 3):
# tensor([[1.],
# [1.],
# [3.],
# [3.]])
tensor_to_check = torch.tensor([[1], [1], [3], [3]], dtype=tensor_to_comm.dtype).cuda()
# test shard
dim_partition_dict = {0: [0]}
# DistSpec:
# shard_sequence: S0,R
# device_mesh_shape: (2, 2)
sharding_spec = ShardingSpec(device_mesh, torch.Size((4, 2)), dim_partition_dict=dim_partition_dict)
# CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1)
comm_spec = CommSpec(CollectiveCommPattern.ALLTOALL,
sharding_spec,
gather_dim=0,
shard_dim=1,
logical_process_axis=0)
comm_spec.covert_spec_to_action(tensor_to_comm)
assert tensor_to_comm.equal(tensor_to_check)
def check_all_reduce(device_mesh, rank):
# tensor to comm
tensor_to_comm = torch.ones(2, 2).cuda() * rank
# reduce through logical process axis 0
# tensor to check
if rank in (0, 2):
# tensor([[2., 2.],
# [2., 2.]])
tensor_to_check = torch.tensor([[2, 2], [2, 2]], dtype=tensor_to_comm.dtype).cuda()
if rank in (1, 3):
# tensor([[4., 4.],
# [4., 4.]])
tensor_to_check = torch.tensor([[4, 4], [4, 4]], dtype=tensor_to_comm.dtype).cuda()
dim_partition_dict = {}
# DistSpec:
# shard_sequence: R,R
# device_mesh_shape: (2, 2)
sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict)
# CommSpec:CommSpec:(comm_pattern:all_reduce, logical_process_axis:0)
comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE, sharding_spec, logical_process_axis=0)
comm_spec.covert_spec_to_action(tensor_to_comm)
assert tensor_to_comm.equal(tensor_to_check)
def check_comm(rank, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
physical_mesh_id = torch.arange(0, 4)
assert rank == gpc.get_global_rank()
mesh_shape = (2, 2)
# [[0, 1,
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
# test all gather
check_all_gather(device_mesh, rank)
# test shard
check_shard(device_mesh, rank)
# test all to all
check_all_to_all(device_mesh, rank)
# test all reduce
check_all_reduce(device_mesh, rank)
gpc.destroy()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_comm_spec():
world_size = 4
run_func = partial(check_comm, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_comm_spec()

View File

@ -0,0 +1,81 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.shape_consistency import ShapeConsistencyManager, CollectiveCommPattern
def check_apply(rank, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1,
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
entire_shape = torch.Size((4, 2))
shape_consistency_manager = ShapeConsistencyManager()
dim_partition_source = {0: [0]}
dim_partition_target = {1: [0]}
# DistSpec:
# shard_sequence: S0,R
# device_mesh_shape: (2, 2)
sharding_spec_source = ShardingSpec(device_mesh, entire_shape, dim_partition_source)
# DistSpec:
# shard_sequence: R,S0
# device_mesh_shape: (2, 2)
sharding_spec_target = ShardingSpec(device_mesh, entire_shape, dim_partition_target)
if rank in (0, 1):
sharded_tensor_0 = torch.zeros(2, 1)
sharded_tensor_1 = torch.ones(2, 1)
# tensor([[0., 1.],
# [0., 1.]])
tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda()
if rank in (2, 3):
sharded_tensor_0 = torch.ones(2, 1) * 2
sharded_tensor_1 = torch.ones(2, 1) * 3
# tensor([[2., 3.],
# [2., 3.]])
tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda()
if rank in (0, 1):
# tensor([[0.],
# [0.],
# [2.],
# [2.]])
tensor_to_check = torch.tensor([[0], [0], [2], [2]], dtype=tensor_to_comm.dtype).cuda()
if rank in (2, 3):
# tensor([[1.],
# [1.],
# [3.],
# [3.]])
tensor_to_check = torch.tensor([[1], [1], [3], [3]], dtype=tensor_to_comm.dtype).cuda()
tensor_to_comm.sharding_spec = sharding_spec_source
shape_consistency_manager.apply(tensor_to_comm, sharding_spec_target)
print(tensor_to_comm)
assert tensor_to_comm.equal(tensor_to_check)
assert str(tensor_to_comm.sharding_spec.sharding_sequence) == str(sharding_spec_target.sharding_sequence)
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_apply():
world_size = 4
run_func = partial(check_apply, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_apply()