From 154d3ef43219722de812089676760d2fb2667156 Mon Sep 17 00:00:00 2001
From: Frank Lee
Date: Mon, 26 Sep 2022 16:39:37 +0800
Subject: [PATCH] [fix] fixed the collective pattern name for consistency
 (#1649)

* [fix] fixed the collective pattern name for consistency

* polish code
---
 colossalai/tensor/shape_consistency.py    | 10 +++++-----
 tests/test_tensor/test_comm_spec_apply.py |  4 ++--
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/colossalai/tensor/shape_consistency.py b/colossalai/tensor/shape_consistency.py
index 3a1f04c8a..d094a2c37 100644
--- a/colossalai/tensor/shape_consistency.py
+++ b/colossalai/tensor/shape_consistency.py
@@ -235,7 +235,7 @@ class CollectiveCommPattern(Enum):
     GATHER_FWD_SPLIT_BWD = 'gather_fwd_split_bwd'
     ALL2ALL_FWD_ALL2ALL_BWD = 'all2all_fwd_all2all_bwd'
     SPLIT_FWD_GATHER_BWD = 'split_fwd_gather_bwd'
-    REDUCE_FWD_IDENTITY_BWD = 'all_reduce_fwd_identity_bwd'
+    ALLREDUCE_FWD_IDENTITY_BWD = 'all_reduce_fwd_identity_bwd'
     IDENTITY_FWD_ALLREDUCE_BWD = 'identity_fwd_all_reduce_bwd'
 
 
@@ -290,8 +290,8 @@ class CommSpec:
             res_list.append(f"comm_pattern:SPLIT_FWD_GATHER_BWD, ")
             res_list.append(f"shard_dim:{self.shard_dim}, ")
             res_list.append(f"logical_process_axis:{self.logical_process_axis})")
-        elif self.comm_pattern == CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD:
-            res_list.append(f"comm_pattern:REDUCE_FWD_IDENTITY_BWD, ")
+        elif self.comm_pattern == CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD:
+            res_list.append(f"comm_pattern:ALLREDUCE_FWD_IDENTITY_BWD, ")
             res_list.append(f"logical_process_axis:{self.logical_process_axis})")
         elif self.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD:
             res_list.append(f"comm_pattern:IDENTITY_FWD_ALLREDUCE_BWD, ")
@@ -317,7 +317,7 @@ class CommSpec:
             # all to all operation has same logical process axis as forward.
             backward_communication_cost = self.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis)
 
-        if self.comm_pattern == CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD:
+        if self.comm_pattern == CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD:
             forward_communication_cost = self.device_mesh.all_reduce_cost(comm_size, self.logical_process_axis)
             backward_communication_cost = 0
 
@@ -357,7 +357,7 @@ pattern_to_func_dict = {
     CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: gather_forward_split_backward,
     CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: all_to_all,
     CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: split_forward_gather_backward,
-    CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD: reduce_input,
+    CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: reduce_input,
     CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: reduce_grad,
 }
 
diff --git a/tests/test_tensor/test_comm_spec_apply.py b/tests/test_tensor/test_comm_spec_apply.py
index dc51c59d6..245f374cc 100644
--- a/tests/test_tensor/test_comm_spec_apply.py
+++ b/tests/test_tensor/test_comm_spec_apply.py
@@ -136,7 +136,7 @@ def check_all_reduce_fwd(device_mesh, rank):
     # device_mesh_shape: (2, 2)
     sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict)
 
-    comm_spec = CommSpec(CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD, sharding_spec, logical_process_axis=0)
+    comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, sharding_spec, logical_process_axis=0)
     comm_spec.covert_spec_to_action(tensor_to_comm)
 
     assert tensor_to_comm.equal(tensor_to_check)
@@ -177,7 +177,7 @@ def check_all_reduce_in_flatten_device_mesh(device_mesh, rank):
     sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict)
 
     # CommSpec:(comm_pattern:all_reduce, logical_process_axis:[0, 1])
-    comm_spec = CommSpec(CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD, sharding_spec, logical_process_axis=[0, 1])
+    comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, sharding_spec, logical_process_axis=[0, 1])
     comm_spec.covert_spec_to_action(tensor_to_comm)
 
     assert tensor_to_comm.equal(tensor_to_check)
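
Usage sketch (not part of the patch): the snippet below mirrors check_all_reduce_fwd from the test file above and shows the renamed ALLREDUCE_FWD_IDENTITY_BWD pattern in use. The DeviceMesh and ShardingSpec import paths and the (2, 2) mesh setup are assumptions based on the repository layout around this commit; it presumes torch.distributed is already initialized with 4 ranks.

    import torch
    from colossalai.device.device_mesh import DeviceMesh
    from colossalai.tensor.sharding_spec import ShardingSpec
    from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec

    # 4 ranks laid out as a 2 x 2 logical device mesh (assumed setup)
    physical_mesh_id = torch.arange(0, 4)
    device_mesh = DeviceMesh(physical_mesh_id, (2, 2), init_process_group=True)

    # fully replicated 2 x 2 tensor: no dimension is partitioned
    tensor_to_comm = torch.ones(2, 2).cuda()
    sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict={})

    # all-reduce in the forward pass, identity in the backward pass,
    # along mesh axis 0 (this enum member is the name the patch fixes)
    comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
                         sharding_spec,
                         logical_process_axis=0)
    comm_spec.covert_spec_to_action(tensor_to_comm)  # 'covert' is the method's spelling in this codebase

Passing logical_process_axis=[0, 1] instead, as in check_all_reduce_in_flatten_device_mesh, all-reduces across the flattened mesh, i.e. over all 4 ranks at once.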