|
|
|
@ -429,6 +429,7 @@ class CommSpec:
|
|
|
|
|
if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: |
|
|
|
|
res_list.append(f"comm_pattern:GATHER_FWD_SPLIT_BWD, ") |
|
|
|
|
res_list.append(f"gather_dim:{self.gather_dim}, ") |
|
|
|
|
res_list.append(f"shard_dim:{self.shard_dim}, ") |
|
|
|
|
res_list.append(f"logical_process_axis:{self.logical_process_axis})") |
|
|
|
|
elif self.comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: |
|
|
|
|
res_list.append(f"comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, ") |
|
|
|
@ -437,6 +438,7 @@ class CommSpec:
|
|
|
|
|
res_list.append(f"logical_process_axis: {self.logical_process_axis})") |
|
|
|
|
elif self.comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: |
|
|
|
|
res_list.append(f"comm_pattern:SPLIT_FWD_GATHER_BWD, ") |
|
|
|
|
res_list.append(f"gather_dim:{self.gather_dim}, ") |
|
|
|
|
res_list.append(f"shard_dim:{self.shard_dim}, ") |
|
|
|
|
res_list.append(f"logical_process_axis:{self.logical_process_axis})") |
|
|
|
|
elif self.comm_pattern == CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: |
|
|
|
|