diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py
index 39983e918..24f75e352 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py
@@ -343,6 +343,7 @@ class DefaultReshapeGenerator(ReshapeGenerator):
                                                                comm_type=CommType.BEFORE,
                                                                arg_index=0)
             input_comm_action.comm_spec.gather_dim = total_mesh_dim_list
+            input_comm_action.comm_spec.shard_dim = total_mesh_dim_list
 
         elif len(total_mesh_dim_list) >= 2:
             source_spec = sharding_spec_mapping["input"]
diff --git a/colossalai/tensor/comm_spec.py b/colossalai/tensor/comm_spec.py
index b31c06994..0d8de1062 100644
--- a/colossalai/tensor/comm_spec.py
+++ b/colossalai/tensor/comm_spec.py
@@ -429,6 +429,7 @@ class CommSpec:
         if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD:
             res_list.append(f"comm_pattern:GATHER_FWD_SPLIT_BWD, ")
             res_list.append(f"gather_dim:{self.gather_dim}, ")
+            res_list.append(f"shard_dim:{self.shard_dim}, ")
             res_list.append(f"logical_process_axis:{self.logical_process_axis})")
         elif self.comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD:
             res_list.append(f"comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, ")
@@ -437,6 +438,7 @@ class CommSpec:
             res_list.append(f"logical_process_axis: {self.logical_process_axis})")
         elif self.comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD:
             res_list.append(f"comm_pattern:SPLIT_FWD_GATHER_BWD, ")
+            res_list.append(f"gather_dim:{self.gather_dim}, ")
             res_list.append(f"shard_dim:{self.shard_dim}, ")
             res_list.append(f"logical_process_axis:{self.logical_process_axis})")
         elif self.comm_pattern == CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD:
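
Note: to sanity-check the __repr__ change without a distributed setup, the following is a minimal standalone sketch. It does not use the real colossalai.tensor.comm_spec.CommSpec; the Pattern and ToySpec names are assumptions made for illustration only. It just mirrors the updated formatting, so that both gather_dim and shard_dim are printed for the gather/split patterns (the backward communication uses the dimension the forward pass does not).

from enum import Enum


class Pattern(Enum):
    # Illustrative stand-in for CollectiveCommPattern, covering only the two
    # patterns touched by this diff.
    GATHER_FWD_SPLIT_BWD = "gather_fwd_split_bwd"
    SPLIT_FWD_GATHER_BWD = "split_fwd_gather_bwd"


class ToySpec:
    # Illustrative stand-in for CommSpec carrying only the fields used by __repr__.
    def __init__(self, comm_pattern, gather_dim, shard_dim, logical_process_axis):
        self.comm_pattern = comm_pattern
        self.gather_dim = gather_dim
        self.shard_dim = shard_dim
        self.logical_process_axis = logical_process_axis

    def __repr__(self):
        # After the change, both gather/split patterns report gather_dim and
        # shard_dim, matching the fields now set on the comm spec.
        res_list = [
            "CommSpec:(",
            f"comm_pattern:{self.comm_pattern.name}, ",
            f"gather_dim:{self.gather_dim}, ",
            f"shard_dim:{self.shard_dim}, ",
            f"logical_process_axis:{self.logical_process_axis})",
        ]
        return "".join(res_list)


if __name__ == "__main__":
    # Expected output includes both dims for each pattern, e.g.
    # CommSpec:(comm_pattern:GATHER_FWD_SPLIT_BWD, gather_dim:0, shard_dim:0, logical_process_axis:1)
    print(ToySpec(Pattern.GATHER_FWD_SPLIT_BWD, gather_dim=0, shard_dim=0, logical_process_axis=1))
    print(ToySpec(Pattern.SPLIT_FWD_GATHER_BWD, gather_dim=1, shard_dim=1, logical_process_axis=0))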