[tensor] support runtime ShardingSpec apply (#1453)

* [tensor] support runtime ShardingSpec apply

* polish code

* polish code
pull/1469/head
YuliangLiu0306 2022-08-19 13:39:51 +08:00 committed by GitHub
parent 177d3f5718
commit b73fb7a077
5 changed files with 485 additions and 11 deletions

colossalai/device/device_mesh.py

@@ -1,6 +1,7 @@
from functools import reduce
import operator
import torch
+import torch.distributed as dist

class DeviceMesh:
@@ -18,9 +19,13 @@ class DeviceMesh:
            communication cost (default: None)
        mesh_beta (List[float], optional): coefficients used for computing
            communication cost (default: None)
+        init_process_group (bool, optional): if set to True, the logical process groups are initialized
+            while the DeviceMesh instance is being constructed. Otherwise, users need to call
+            create_process_groups_for_logical_mesh manually to initialize the logical process groups.
+            (default: False)
    """

-    def __init__(self, physical_mesh_id, mesh_shape, mesh_alpha=None, mesh_beta=None):
+    def __init__(self, physical_mesh_id, mesh_shape, mesh_alpha=None, mesh_beta=None, init_process_group=False):
        self.physical_mesh_id = physical_mesh_id
        self.mesh_shape = mesh_shape
        self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape)
@@ -34,6 +39,8 @@ class DeviceMesh:
            mesh_beta = [1] * len(self.mesh_shape)
        self.mesh_alpha = tuple(mesh_alpha)
        self.mesh_beta = tuple(mesh_beta)
+        if init_process_group:
+            self.process_groups_dict = self.create_process_groups_for_logical_mesh()

    @property
    def shape(self):
@@ -57,6 +64,28 @@ class DeviceMesh:
        else:
            self._global_rank_to_logical_rank_map(inner_tensor, index_list + [index])

+    def create_process_groups_for_logical_mesh(self):
+        '''
+        This method initializes the logical process groups used for communication across the
+        logical device mesh.
+        Note: if init_process_group is set to False, you have to call this method manually. Otherwise,
+        communication-related functions, such as ShapeConsistencyManager.apply, will raise errors.
+        '''
+        process_groups_dict = {}
+        check_duplicate_list = []
+        global_rank_flatten_list = self.physical_mesh_id.view(-1).tolist()
+        for global_rank in global_rank_flatten_list:
+            process_groups = self.global_rank_to_process_groups_with_global_rank(global_rank)
+            for axis, process_group in process_groups.items():
+                if axis not in process_groups_dict:
+                    process_groups_dict[axis] = []
+                if process_group not in check_duplicate_list:
+                    check_duplicate_list.append(process_group)
+                    process_group_handler = dist.new_group(process_group)
+                    process_groups_dict[axis].append((process_group, process_group_handler))
+        return process_groups_dict

    def global_rank_to_logical_rank(self, rank):
        return self.convert_map[rank]
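To make the grouping produced by create_process_groups_for_logical_mesh concrete: for the 2x2 mesh over ranks 0-3 used throughout the tests below, the expected layout can be sketched as follows. This is a minimal sketch, not part of the diff; expected_rank_lists is an illustrative name, and the second element of each stored pair is the handle returned by dist.new_group.

# Minimal sketch of the expected process_groups_dict layout for
# physical_mesh_id = torch.arange(0, 4) and mesh_shape = (2, 2):
#
#     logical mesh: [[0, 1],
#                    [2, 3]]
#
# process_groups_dict maps each mesh axis to a list of (rank_list, process_group) pairs.
expected_rank_lists = {
    0: [[0, 2], [1, 3]],    # mesh axis 0: groups formed along the columns
    1: [[0, 1], [2, 3]],    # mesh axis 1: groups formed along the rows
}
# e.g. process_groups_dict[0][0] == ([0, 2], <process group handle for ranks 0 and 2>)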

colossalai/tensor/shape_consistency.py

@@ -3,15 +3,18 @@ from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
from enum import Enum
from copy import deepcopy
+import torch.distributed as dist
import math
from functools import reduce
import operator
+from torch.distributed import ReduceOp

class CollectiveCommPattern(Enum):
    ALLGATHER = 'all_gather'
    ALLTOALL = 'all_to_all'
    SHARD = 'shard'
+    ALLREDUCE = 'all_reduce'

class CommSpec:
@@ -41,7 +44,7 @@ class CommSpec:
    def __repr__(self):
        res_list = ["CommSpec:("]
        if self.comm_pattern == CollectiveCommPattern.ALLGATHER:
-            res_list.append(f"comm_pattern:allgather, ")
+            res_list.append(f"comm_pattern:all_gather, ")
            res_list.append(f"gather_dim:{self.gather_dim}, ")
            res_list.append(f"logical_process_axis:{self.logical_process_axis})")
        elif self.comm_pattern == CollectiveCommPattern.ALLTOALL:
@@ -49,15 +52,19 @@ class CommSpec:
            res_list.append(f"gather_dim:{self.gather_dim}, ")
            res_list.append(f"shard_dim:{self.shard_dim}, ")
            res_list.append(f"logical_process_axis: {self.logical_process_axis})")
-        else:
+        elif self.comm_pattern == CollectiveCommPattern.SHARD:
            res_list.append(f"comm_pattern:shard, ")
            res_list.append(f"shard_dim:{self.shard_dim}, ")
            res_list.append(f"logical_process_axis:{self.logical_process_axis})")
+        elif self.comm_pattern == CollectiveCommPattern.ALLREDUCE:
+            res_list.append(f"comm_pattern:all_reduce, ")
+            res_list.append(f"logical_process_axis:{self.logical_process_axis})")

        return ''.join(res_list)

    def get_comm_cost(self):
        '''
-        For all_gather and all2all operation, the formula provided in DeviceMesh with alpha-beta model is used to
+        For the all_gather, all_to_all and all_reduce operations, the formula provided in DeviceMesh with the alpha-beta model is used to
        compute the communication cost.
        For shard operation, it is an on-chip operation, so the communication cost is zero.
        '''
@@ -66,10 +73,77 @@
            return self.sharding_spec.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)
        if self.comm_pattern == CollectiveCommPattern.ALLTOALL:
            return self.sharding_spec.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis)
-        return 0
+        if self.comm_pattern == CollectiveCommPattern.ALLREDUCE:
+            return self.sharding_spec.device_mesh.all_reduce_cost(comm_size, self.logical_process_axis)
+        if self.comm_pattern == CollectiveCommPattern.SHARD:
+            return 0
+        raise RuntimeError(f"Could not find a matching CollectiveCommPattern for {self.comm_pattern}.")

-    def covert_spec_to_action(self):
-        pass
+    def covert_spec_to_action(self, tensor):
+        '''
+        Convert the CommSpec into a runtime action, i.e. perform the real collective communication on the target tensor.
+        The collective communication action is directed by the CommSpec.
+
+        Argument:
+            tensor (torch.Tensor): the tensor stored on each device, which may hold different values on different ranks.
+        '''
+        device_mesh = self.sharding_spec.device_mesh
+        process_groups_list = device_mesh.process_groups_dict[self.logical_process_axis]
+
+        if self.comm_pattern == CollectiveCommPattern.ALLGATHER:
+            for rank_list, process_group in process_groups_list:
+                if dist.get_rank() in rank_list:
+                    tensor_list = [
+                        torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device)
+                        for _ in range(self.sharding_spec.device_mesh.mesh_shape[self.logical_process_axis])
+                    ]
+                    tensor = tensor
+                    group = process_group
+                    dist.all_gather(tensor_list, tensor, group=group)
+                    tensor.data = torch.cat(tuple(tensor_list), self.gather_dim)
+
+        elif self.comm_pattern == CollectiveCommPattern.SHARD:
+            for rank_list, process_group in process_groups_list:
+                if dist.get_rank() in rank_list:
+                    tensor = tensor
+                    dim = self.shard_dim
+                    length = tensor.shape[self.shard_dim] // len(rank_list)
+                    start = length * rank_list.index(dist.get_rank())
+                    tensor.data = torch.narrow(tensor, dim, start, length)
+
+        elif self.comm_pattern == CollectiveCommPattern.ALLTOALL:
+            for rank_list, process_group in process_groups_list:
+                if dist.get_rank() in rank_list:
+                    new_shape = list(tensor.shape)
+                    new_shape[self.shard_dim] = new_shape[self.shard_dim] // len(rank_list)
+                    new_shape = torch.Size(new_shape)
+                    output_tensor_list = [
+                        torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
+                    ]
+                    dim = self.shard_dim
+                    length = tensor.shape[self.shard_dim] // len(rank_list)
+                    input_tensor_list = [
+                        torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list))
+                    ]
+                    group = process_group
+                    dist.all_to_all(output_tensor_list, input_tensor_list, group)
+                    tensor.data = torch.cat(tuple(output_tensor_list), self.gather_dim)
+
+        elif self.comm_pattern == CollectiveCommPattern.ALLREDUCE:
+            # For the consistency of collective communication operations, we temporarily do not
+            # allow an all_reduce over two different mesh dimensions at the same time.
+            # e.g.: MatMul[(R, S01), (S01, R)] -> Partial(R, R),
+            # all_reduce(Partial, logical_pg=(0, 1)) is NOT allowed; instead,
+            # we need to do it in two steps:
+            #   1. all_reduce(Partial, logical_pg=1)
+            #   2. all_reduce(Partial, logical_pg=0)
+            for rank_list, process_group in process_groups_list:
+                if dist.get_rank() in rank_list:
+                    dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group)
+                    tensor.data = tensor
+        else:
+            tensor.data = tensor
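To make the two-step restriction in the comment above concrete, here is a sketch of reducing a tensor that is partial over both axes of the 2x2 mesh. It is illustrative only: sharding_spec is assumed to be an R,R ShardingSpec built as in check_all_reduce below, and partial_tensor is a hypothetical name.

# Illustrative only: a sum that is partial over BOTH mesh axes is reduced with two
# single-axis CommSpecs instead of one group spanning all four ranks.
partial_tensor = torch.ones(2, 2).cuda()    # each rank holds its own partial sum

# step 1: reduce along mesh axis 1 (rank groups {0, 1} and {2, 3})
CommSpec(CollectiveCommPattern.ALLREDUCE, sharding_spec, logical_process_axis=1).covert_spec_to_action(partial_tensor)

# step 2: reduce along mesh axis 0 (rank groups {0, 2} and {1, 3})
CommSpec(CollectiveCommPattern.ALLREDUCE, sharding_spec, logical_process_axis=0).covert_spec_to_action(partial_tensor)

# after both steps every rank holds the sum over all four devices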

class ShapeConsistencyManager:
@@ -191,7 +265,7 @@ class ShapeConsistencyManager:
                else:
                    f_target_pair = (f_index, [])
                if b_index in source_spec.dim_partition_dict:
-                    # skip (R, R) -> (R, S01) is NOT allowed
+                    # skip (R, S01) -> (S01, R) is NOT allowed
                    if len(source_spec.dim_partition_dict[b_index]) >= 2:
                        continue
                    b_target_pair = (b_index, deepcopy(source_spec.dim_partition_dict[b_index]))
@@ -409,7 +483,7 @@ class ShapeConsistencyManager:
            self.cached_spec_pairs_transform_path[spec_pairs] = (transform_path, comm_action_sequence)
            return (transform_path, comm_action_sequence, total_cost)

-        temp_sharding_spec = deepcopy(source_spec)
+        temp_sharding_spec = source_spec
        transform_path.append(temp_sharding_spec)

        # To avoid dead loop, the loop will break after MAX_TRANSFORM_STEPS transforms
        while total_steps <= MAX_TRANSFORM_STEPS:
@@ -428,9 +502,9 @@ class ShapeConsistencyManager:
                    return (transform_path, comm_action_sequence, total_cost)

                if spec_difference < best_difference_score:
-                    temp_sharding_spec = deepcopy(sharding_spec)
+                    temp_sharding_spec = sharding_spec
                    temp_cost = cost
-                    temp_comm_spec = deepcopy(comm_spec)
+                    temp_comm_spec = comm_spec
                    best_difference_score = spec_difference

            transform_path.append(temp_sharding_spec)
@@ -439,3 +513,67 @@ class ShapeConsistencyManager:
            total_steps += 1

        raise RuntimeError(f"Could not find a valid transform path within {MAX_TRANSFORM_STEPS} steps.")

+    def apply(self, tensor_with_sharding_spec, target_spec):
+        '''
+        Apply target_spec to the tensor with source sharding spec; the transform path is generated by the
+        shape_consistency method.
+
+        Argument:
+            tensor_with_sharding_spec (torch.Tensor): a tensor with source sharding spec to be transformed to the target spec.
+            target_spec (ShardingSpec): the tensor transform process will be directed by the target_spec.
+
+        Example:
+            physical_mesh_id = torch.arange(0, 4)
+            mesh_shape = (2, 2)
+            # [[0, 1],
+            #  [2, 3]]
+            device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
+            entire_shape = torch.Size((4, 2))
+            shape_consistency_manager = ShapeConsistencyManager()
+            dim_partition_source = {0: [0]}
+            dim_partition_target = {1: [0]}
+
+            # DistSpec:
+            #     shard_sequence: S0,R
+            #     device_mesh_shape: (2, 2)
+            sharding_spec_source = ShardingSpec(device_mesh, entire_shape, dim_partition_source)
+
+            # DistSpec:
+            #     shard_sequence: R,S0
+            #     device_mesh_shape: (2, 2)
+            sharding_spec_target = ShardingSpec(device_mesh, entire_shape, dim_partition_target)
+
+            if rank in (0, 1):
+                sharded_tensor_0 = torch.zeros(2, 1)
+                sharded_tensor_1 = torch.ones(2, 1)
+                # tensor([[0., 1.],
+                #         [0., 1.]])
+                tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda()
+            if rank in (2, 3):
+                sharded_tensor_0 = torch.ones(2, 1) * 2
+                sharded_tensor_1 = torch.ones(2, 1) * 3
+                # tensor([[2., 3.],
+                #         [2., 3.]])
+                tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda()
+
+            tensor_to_comm.sharding_spec = sharding_spec_source
+            shape_consistency_manager.apply(tensor_to_comm, sharding_spec_target)
+            print(tensor_to_comm)
+
+        Output in rank0 and rank2:
+            tensor([[0.],
+                    [0.],
+                    [2.],
+                    [2.]])
+
+        Output in rank1 and rank3:
+            tensor([[1.],
+                    [1.],
+                    [3.],
+                    [3.]])
+        '''
+        _, comm_action_sequence, _ = self.shape_consistency(tensor_with_sharding_spec.sharding_spec, target_spec)
+        for comm_spec in comm_action_sequence:
+            comm_spec.covert_spec_to_action(tensor_with_sharding_spec)
+        tensor_with_sharding_spec.sharding_spec = target_spec
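If you want to see what apply will do before it mutates the tensor, the underlying shape_consistency call can be invoked directly. The sketch below reuses the names from the docstring example above and only prints the planned path; for the S0,R to R,S0 example it should report a short sequence of CommSpecs such as an all2all along mesh axis 0.

# Inspect the planned conversion without touching the tensor (names reused from the example above).
transform_path, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
    sharding_spec_source, sharding_spec_target)
for spec, comm_spec in zip(transform_path[1:], comm_action_sequence):
    # transform_path starts with the source spec; each later entry pairs with the CommSpec that produces it
    print(spec.sharding_sequence, comm_spec)
print('total communication cost:', total_cost)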

New test file: creation of logical process groups on the device mesh

@@ -0,0 +1,49 @@
import torch
from functools import partial
import pytest
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed import ReduceOp

from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.device.device_mesh import DeviceMesh


def check_layer(rank, world_size, port):
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    physical_mesh_id = torch.arange(0, 4)
    assert rank == gpc.get_global_rank()

    tensor_to_check = torch.tensor([2, 2, 2, 2]).cuda()
    mesh_shape = (2, 2)
    # [[0, 1],
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
    logical_pg_dict = {0: [[0, 2], [1, 3]], 1: [[0, 1], [2, 3]]}
    logical_process_groups = device_mesh.process_groups_dict

    for mesh_dim, pgs in logical_pg_dict.items():
        for index, pg in enumerate(pgs):
            if rank in pg:
                tensor = torch.ones(4).cuda()
                group = logical_process_groups[mesh_dim][index][1]
                dist.all_reduce(tensor, op=ReduceOp.SUM, group=group)
                assert tensor.equal(tensor_to_check)

    gpc.destroy()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_logical_pg():
    world_size = 4
    run_func = partial(check_layer, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_logical_pg()

New test file: runtime CommSpec conversion (covert_spec_to_action)

@@ -0,0 +1,177 @@
import torch
from functools import partial
import pytest
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed import ReduceOp

from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import CommSpec, CollectiveCommPattern
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.sharding_spec import ShardingSpec


def check_all_gather(device_mesh, rank):
    # tensor to comm
    if rank in (0, 2):
        sharded_tensor_to_comm = torch.ones(2, 2).cuda()
    else:
        sharded_tensor_to_comm = torch.zeros(2, 2).cuda()

    # tensor to check
    tensor_to_check = torch.cat((torch.ones(2, 2), torch.zeros(2, 2)), 1).cuda()

    # test all gather
    dim_partition_dict = {1: [1]}

    # DistSpec:
    #     shard_sequence: R,S1
    #     device_mesh_shape: (2, 2)
    sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)

    # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1)
    comm_spec = CommSpec(CollectiveCommPattern.ALLGATHER, sharding_spec, gather_dim=1, logical_process_axis=1)
    comm_spec.covert_spec_to_action(sharded_tensor_to_comm)

    assert sharded_tensor_to_comm.equal(tensor_to_check)


def check_shard(device_mesh, rank):
    # tensor to comm
    sharded_tensor_to_comm_0 = torch.zeros(2, 2).cuda()
    sharded_tensor_to_comm_1 = torch.ones(2, 2).cuda()

    # tensor([[0., 0., 1., 1.],
    #         [0., 0., 1., 1.]])
    tensor_to_shard = torch.cat((sharded_tensor_to_comm_0, sharded_tensor_to_comm_1), 1)

    # test shard
    dim_partition_dict = {}

    # DistSpec:
    #     shard_sequence: R,R
    #     device_mesh_shape: (2, 2)
    sharding_spec = ShardingSpec(device_mesh, tensor_to_shard.shape, dim_partition_dict=dim_partition_dict)

    # CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1)
    comm_spec = CommSpec(CollectiveCommPattern.SHARD, sharding_spec, shard_dim=1, logical_process_axis=1)
    comm_spec.covert_spec_to_action(tensor_to_shard)

    if rank in (0, 2):
        assert tensor_to_shard.equal(sharded_tensor_to_comm_0)
    if rank in (1, 3):
        assert tensor_to_shard.equal(sharded_tensor_to_comm_1)


def check_all_to_all(device_mesh, rank):
    # tensor to comm
    if rank in (0, 1):
        sharded_tensor_0 = torch.zeros(2, 1)
        sharded_tensor_1 = torch.ones(2, 1)
        # tensor([[0., 1.],
        #         [0., 1.]])
        tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda()
    if rank in (2, 3):
        sharded_tensor_0 = torch.ones(2, 1) * 2
        sharded_tensor_1 = torch.ones(2, 1) * 3
        # tensor([[2., 3.],
        #         [2., 3.]])
        tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda()

    if rank in (0, 1):
        # tensor([[0.],
        #         [0.],
        #         [2.],
        #         [2.]])
        tensor_to_check = torch.tensor([[0], [0], [2], [2]], dtype=tensor_to_comm.dtype).cuda()
    if rank in (2, 3):
        # tensor([[1.],
        #         [1.],
        #         [3.],
        #         [3.]])
        tensor_to_check = torch.tensor([[1], [1], [3], [3]], dtype=tensor_to_comm.dtype).cuda()

    # test all to all
    dim_partition_dict = {0: [0]}

    # DistSpec:
    #     shard_sequence: S0,R
    #     device_mesh_shape: (2, 2)
    sharding_spec = ShardingSpec(device_mesh, torch.Size((4, 2)), dim_partition_dict=dim_partition_dict)

    # CommSpec:(comm_pattern:all2all, gather_dim:0, shard_dim:1, logical_process_axis:0)
    comm_spec = CommSpec(CollectiveCommPattern.ALLTOALL,
                         sharding_spec,
                         gather_dim=0,
                         shard_dim=1,
                         logical_process_axis=0)
    comm_spec.covert_spec_to_action(tensor_to_comm)

    assert tensor_to_comm.equal(tensor_to_check)


def check_all_reduce(device_mesh, rank):
    # tensor to comm
    tensor_to_comm = torch.ones(2, 2).cuda() * rank

    # reduce through logical process axis 0
    # tensor to check
    if rank in (0, 2):
        # tensor([[2., 2.],
        #         [2., 2.]])
        tensor_to_check = torch.tensor([[2, 2], [2, 2]], dtype=tensor_to_comm.dtype).cuda()
    if rank in (1, 3):
        # tensor([[4., 4.],
        #         [4., 4.]])
        tensor_to_check = torch.tensor([[4, 4], [4, 4]], dtype=tensor_to_comm.dtype).cuda()

    dim_partition_dict = {}

    # DistSpec:
    #     shard_sequence: R,R
    #     device_mesh_shape: (2, 2)
    sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict)

    # CommSpec:(comm_pattern:all_reduce, logical_process_axis:0)
    comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE, sharding_spec, logical_process_axis=0)
    comm_spec.covert_spec_to_action(tensor_to_comm)

    assert tensor_to_comm.equal(tensor_to_check)


def check_comm(rank, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    physical_mesh_id = torch.arange(0, 4)
    assert rank == gpc.get_global_rank()

    mesh_shape = (2, 2)
    # [[0, 1],
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)

    # test all gather
    check_all_gather(device_mesh, rank)

    # test shard
    check_shard(device_mesh, rank)

    # test all to all
    check_all_to_all(device_mesh, rank)

    # test all reduce
    check_all_reduce(device_mesh, rank)

    gpc.destroy()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_comm_spec():
    world_size = 4
    run_func = partial(check_comm, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_comm_spec()

New test file: ShapeConsistencyManager.apply

@@ -0,0 +1,81 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp

from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.shape_consistency import ShapeConsistencyManager, CollectiveCommPattern


def check_apply(rank, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    # [[0, 1],
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
    entire_shape = torch.Size((4, 2))
    shape_consistency_manager = ShapeConsistencyManager()

    dim_partition_source = {0: [0]}
    dim_partition_target = {1: [0]}

    # DistSpec:
    #     shard_sequence: S0,R
    #     device_mesh_shape: (2, 2)
    sharding_spec_source = ShardingSpec(device_mesh, entire_shape, dim_partition_source)

    # DistSpec:
    #     shard_sequence: R,S0
    #     device_mesh_shape: (2, 2)
    sharding_spec_target = ShardingSpec(device_mesh, entire_shape, dim_partition_target)

    if rank in (0, 1):
        sharded_tensor_0 = torch.zeros(2, 1)
        sharded_tensor_1 = torch.ones(2, 1)
        # tensor([[0., 1.],
        #         [0., 1.]])
        tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda()
    if rank in (2, 3):
        sharded_tensor_0 = torch.ones(2, 1) * 2
        sharded_tensor_1 = torch.ones(2, 1) * 3
        # tensor([[2., 3.],
        #         [2., 3.]])
        tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda()

    if rank in (0, 1):
        # tensor([[0.],
        #         [0.],
        #         [2.],
        #         [2.]])
        tensor_to_check = torch.tensor([[0], [0], [2], [2]], dtype=tensor_to_comm.dtype).cuda()
    if rank in (2, 3):
        # tensor([[1.],
        #         [1.],
        #         [3.],
        #         [3.]])
        tensor_to_check = torch.tensor([[1], [1], [3], [3]], dtype=tensor_to_comm.dtype).cuda()

    tensor_to_comm.sharding_spec = sharding_spec_source
    shape_consistency_manager.apply(tensor_to_comm, sharding_spec_target)
    print(tensor_to_comm)

    assert tensor_to_comm.equal(tensor_to_check)
    assert str(tensor_to_comm.sharding_spec.sharding_sequence) == str(sharding_spec_target.sharding_sequence)


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_apply():
    world_size = 4
    run_func = partial(check_apply, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_apply()