mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] add shard dim to avoid backward communication error (#2954)
parent
090f14fd6b
commit
47fb214b3b
|
@ -343,6 +343,7 @@ class DefaultReshapeGenerator(ReshapeGenerator):
|
||||||
comm_type=CommType.BEFORE,
|
comm_type=CommType.BEFORE,
|
||||||
arg_index=0)
|
arg_index=0)
|
||||||
input_comm_action.comm_spec.gather_dim = total_mesh_dim_list
|
input_comm_action.comm_spec.gather_dim = total_mesh_dim_list
|
||||||
|
input_comm_action.comm_spec.shard_dim = total_mesh_dim_list
|
||||||
|
|
||||||
elif len(total_mesh_dim_list) >= 2:
|
elif len(total_mesh_dim_list) >= 2:
|
||||||
source_spec = sharding_spec_mapping["input"]
|
source_spec = sharding_spec_mapping["input"]
|
||||||
|
|
|
@ -429,6 +429,7 @@ class CommSpec:
|
||||||
if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD:
|
if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD:
|
||||||
res_list.append(f"comm_pattern:GATHER_FWD_SPLIT_BWD, ")
|
res_list.append(f"comm_pattern:GATHER_FWD_SPLIT_BWD, ")
|
||||||
res_list.append(f"gather_dim:{self.gather_dim}, ")
|
res_list.append(f"gather_dim:{self.gather_dim}, ")
|
||||||
|
res_list.append(f"shard_dim:{self.shard_dim}, ")
|
||||||
res_list.append(f"logical_process_axis:{self.logical_process_axis})")
|
res_list.append(f"logical_process_axis:{self.logical_process_axis})")
|
||||||
elif self.comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD:
|
elif self.comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD:
|
||||||
res_list.append(f"comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, ")
|
res_list.append(f"comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, ")
|
||||||
|
@ -437,6 +438,7 @@ class CommSpec:
|
||||||
res_list.append(f"logical_process_axis: {self.logical_process_axis})")
|
res_list.append(f"logical_process_axis: {self.logical_process_axis})")
|
||||||
elif self.comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD:
|
elif self.comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD:
|
||||||
res_list.append(f"comm_pattern:SPLIT_FWD_GATHER_BWD, ")
|
res_list.append(f"comm_pattern:SPLIT_FWD_GATHER_BWD, ")
|
||||||
|
res_list.append(f"gather_dim:{self.gather_dim}, ")
|
||||||
res_list.append(f"shard_dim:{self.shard_dim}, ")
|
res_list.append(f"shard_dim:{self.shard_dim}, ")
|
||||||
res_list.append(f"logical_process_axis:{self.logical_process_axis})")
|
res_list.append(f"logical_process_axis:{self.logical_process_axis})")
|
||||||
elif self.comm_pattern == CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD:
|
elif self.comm_pattern == CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD:
|
||||||
|
|
Loading…
Reference in New Issue