ColossalAI/tests/test_tensor/test_comm_spec_apply.py

225 lines
7.5 KiB
Python
Raw Normal View History

import torch
from functools import partial
import pytest
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed import ReduceOp
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import CommSpec, CollectiveCommPattern
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.sharding_spec import ShardingSpec
def check_all_gather(device_mesh, rank):
# tensor to comm
if rank in (0, 2):
sharded_tensor_to_comm = torch.ones(2, 2).cuda()
else:
sharded_tensor_to_comm = torch.zeros(2, 2).cuda()
# tensor to check
tensor_to_check = torch.cat((torch.ones(2, 2), torch.zeros(2, 2)), 1).cuda()
# test all gather
dim_partition_dict = {1: [1]}
# DistSpec:
# shard_sequence: R,S1
# device_mesh_shape: (2, 2)
sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)
# CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1)
comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
sharding_spec,
gather_dim=1,
logical_process_axis=1)
comm_spec.covert_spec_to_action(sharded_tensor_to_comm)
assert sharded_tensor_to_comm.equal(tensor_to_check)
def check_shard(device_mesh, rank):
# tensor to comm
sharded_tensor_to_comm_0 = torch.zeros(2, 2).cuda()
sharded_tensor_to_comm_1 = torch.ones(2, 2).cuda()
# tensor([[0., 0., 1., 1.],
# [0., 0., 1., 1.]])
tensor_to_shard = torch.cat((sharded_tensor_to_comm_0, sharded_tensor_to_comm_1), 1)
# test shard
dim_partition_dict = {}
# DistSpec:
# shard_sequence: R,R
# device_mesh_shape: (2, 2)
sharding_spec = ShardingSpec(device_mesh, tensor_to_shard.shape, dim_partition_dict=dim_partition_dict)
# CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1)
comm_spec = CommSpec(CollectiveCommPattern.SPLIT_FWD_GATHER_BWD, sharding_spec, shard_dim=1, logical_process_axis=1)
comm_spec.covert_spec_to_action(tensor_to_shard)
if rank in (0, 2):
assert tensor_to_shard.equal(sharded_tensor_to_comm_0)
if rank in (1, 3):
assert tensor_to_shard.equal(sharded_tensor_to_comm_1)
def check_all_to_all(device_mesh, rank):
# tensor to comm
if rank in (0, 1):
sharded_tensor_0 = torch.zeros(2, 1)
sharded_tensor_1 = torch.ones(2, 1)
# tensor([[0., 1.],
# [0., 1.]])
tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda()
if rank in (2, 3):
sharded_tensor_0 = torch.ones(2, 1) * 2
sharded_tensor_1 = torch.ones(2, 1) * 3
# tensor([[2., 3.],
# [2., 3.]])
tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda()
if rank in (0, 1):
# tensor([[0.],
# [0.],
# [2.],
# [2.]])
tensor_to_check = torch.tensor([[0], [0], [2], [2]], dtype=tensor_to_comm.dtype).cuda()
if rank in (2, 3):
# tensor([[1.],
# [1.],
# [3.],
# [3.]])
tensor_to_check = torch.tensor([[1], [1], [3], [3]], dtype=tensor_to_comm.dtype).cuda()
# test shard
dim_partition_dict = {0: [0]}
# DistSpec:
# shard_sequence: S0,R
# device_mesh_shape: (2, 2)
sharding_spec = ShardingSpec(device_mesh, torch.Size((4, 2)), dim_partition_dict=dim_partition_dict)
# CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1)
comm_spec = CommSpec(CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD,
sharding_spec,
gather_dim=0,
shard_dim=1,
logical_process_axis=0)
comm_spec.covert_spec_to_action(tensor_to_comm)
assert tensor_to_comm.equal(tensor_to_check)
def check_all_reduce_fwd(device_mesh, rank):
# tensor to comm
tensor_to_comm = torch.ones(2, 2).cuda() * rank
# reduce through logical process axis 0
# tensor to check
if rank in (0, 2):
# tensor([[2., 2.],
# [2., 2.]])
tensor_to_check = torch.tensor([[2, 2], [2, 2]], dtype=tensor_to_comm.dtype).cuda()
if rank in (1, 3):
# tensor([[4., 4.],
# [4., 4.]])
tensor_to_check = torch.tensor([[4, 4], [4, 4]], dtype=tensor_to_comm.dtype).cuda()
dim_partition_dict = {}
# DistSpec:
# shard_sequence: R,R
# device_mesh_shape: (2, 2)
sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict)
comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, sharding_spec, logical_process_axis=0)
comm_spec.covert_spec_to_action(tensor_to_comm)
assert tensor_to_comm.equal(tensor_to_check)
def check_all_reduce_bwd(device_mesh, rank):
# tensor to comm
tensor_to_comm = torch.ones(2, 2).cuda() * rank
tensor_to_check = torch.ones(2, 2).cuda() * rank
dim_partition_dict = {}
# DistSpec:
# shard_sequence: R,R
# device_mesh_shape: (2, 2)
sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict)
comm_spec = CommSpec(CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, sharding_spec, logical_process_axis=0)
comm_spec.covert_spec_to_action(tensor_to_comm)
assert tensor_to_comm.equal(tensor_to_check)
2022-08-25 08:48:12 +00:00
def check_all_reduce_in_flatten_device_mesh(device_mesh, rank):
# tensor to comm
tensor_to_comm = torch.ones(2, 2).cuda() * rank
# reduce through logical process axis 0 at flatten device mesh
# tensor to check
# tensor([[6., 6.],
# [6., 6.]])
tensor_to_check = torch.tensor([[6, 6], [6, 6]], dtype=tensor_to_comm.dtype).cuda()
dim_partition_dict = {}
# DistSpec:
# shard_sequence: R,R
# device_mesh_shape: (2, 2)
sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict)
# CommSpec:(comm_pattern:all_reduce, logical_process_axis:[0, 1])
comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, sharding_spec, logical_process_axis=[0, 1])
2022-08-25 08:48:12 +00:00
comm_spec.covert_spec_to_action(tensor_to_comm)
assert tensor_to_comm.equal(tensor_to_check)
def check_comm(rank, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
physical_mesh_id = torch.arange(0, 4)
assert rank == gpc.get_global_rank()
mesh_shape = (2, 2)
# [[0, 1,
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
# test all gather
check_all_gather(device_mesh, rank)
# test shard
check_shard(device_mesh, rank)
# test all to all
check_all_to_all(device_mesh, rank)
# test all reduce
check_all_reduce_fwd(device_mesh, rank)
check_all_reduce_bwd(device_mesh, rank)
2022-08-25 08:48:12 +00:00
# test all reduce in 1D flatten device mesh
check_all_reduce_in_flatten_device_mesh(device_mesh, rank)
gpc.destroy()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_comm_spec():
world_size = 4
run_func = partial(check_comm, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_comm_spec()