[moe] fix mixtral checkpoint io (#5314)

pull/5372/head
Hongxin Liu 2024-01-27 16:06:33 +08:00 committed by ver217
parent da39d21b71
commit b60be18dcc
1 changed file with 8 additions and 4 deletions

View File

@ -135,6 +135,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
Path(checkpoint).mkdir(parents=True, exist_ok=True)
if self.real_dp_rank != 0:
dist.barrier()
return
# ep_rank 0 saves all the parameters and buffers.
@ -171,6 +172,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
f"index located at {save_index_file}."
)
dist.barrier()
else:
# When pipeline is used, each stage produces its own shard files and index files.
# Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
@ -201,10 +203,10 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
else:
dist.barrier()
return
dist.barrier(self.pp_group)
dist.barrier(self.ep_group)
dist.barrier()
# The global master rank integrates the index files and cleans the folder.
if self.coordinator.is_master():
@ -360,6 +362,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
# Devices along the same dp_group share the same copies of states when zero is not used.
# In this case only let the device with dp_rank == 0 save the model.
if not self.use_zero and self.real_dp_rank != 0:
dist.barrier()
return
# Then collect the sharded states along dp_group (if using zero)/tp_group.
@ -401,6 +404,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
f"index located at {save_index_file}."
)
dist.barrier()
else:
# When pipeline is used, each stage produces its own shard files and index files.
# Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
@ -428,10 +432,10 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
else:
dist.barrier()
return
dist.barrier(self.pp_group)
dist.barrier(self.ep_group)
dist.barrier()
# The global master rank integrates the index files and cleans the folder.
if self.coordinator.is_master():