mirror of https://github.com/hpcaitech/ColossalAI
[moe] fix mixtral checkpoint io (#5314)
parent da39d21b71
commit b60be18dcc
@@ -135,6 +135,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):

        Path(checkpoint).mkdir(parents=True, exist_ok=True)

        if self.real_dp_rank != 0:
            dist.barrier()
            return

        # ep_rank 0 saves all the parameters and buffers.
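The hunk above is the coordination pattern this fix relies on: a rank that skips the actual write must still reach the same dist.barrier() as the rank that saves, otherwise the collectives fall out of step. A minimal sketch of that pattern, assuming torch.distributed is already initialized; save_on_dp_rank0, dp_rank, and the plain torch.save call are illustrative stand-ins, not ColossalAI's API:

import torch
import torch.distributed as dist

def save_on_dp_rank0(state_dict, path, dp_rank):
    # Non-saving ranks wait at the barrier and exit, mirroring the early return above.
    if dp_rank != 0:
        dist.barrier()
        return
    # Only one replica per data-parallel group touches the filesystem.
    torch.save(state_dict, path)
    # Release the ranks that returned early once the file is on disk.
    dist.barrier()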
@@ -171,6 +172,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):

                    f"index located at {save_index_file}."
                )

            dist.barrier()
        else:
            # When pipeline is used, each stage produces its own shard files and index files.
            # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
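With pipeline parallelism each stage only holds part of the model, so the else branch above has every stage write its own shard files plus a per-stage index file into the temporary ./tmp_index_files/ folder. A rough sketch of that per-stage step with hypothetical file names (the real naming and index handling live in ColossalAI's checkpoint utilities):

import json
import os

def write_stage_index(tmp_dir, pp_rank, weight_map, total_size):
    # Each pipeline stage records only the parameters it actually saved.
    os.makedirs(tmp_dir, exist_ok=True)
    index = {"metadata": {"total_size": total_size}, "weight_map": weight_map}
    stage_index = os.path.join(tmp_dir, f"pytorch_model.bin.index.pp{pp_rank}.json")
    with open(stage_index, "w") as f:
        json.dump(index, f, indent=2)
    return stage_index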
@@ -201,10 +203,10 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):

                index_file.append_meta_data("total_size", total_size)
                index_file.write_index_file(save_index_file)
            else:
                dist.barrier()
                return

            dist.barrier(self.pp_group)
            dist.barrier(self.ep_group)
            dist.barrier()

            # The global master rank integrates the index files and clean the folder.
            if self.coordinator.is_master():
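The three barriers above (pipeline group, expert-parallel group, then global) guarantee that every stage has finished writing into the temporary folder before any rank reads it; only then does the global master merge the per-stage indices and remove the folder. A hedged sketch of such a merge step, with a hypothetical merge_stage_indices helper rather than the library's own routine:

import glob
import json
import os
import shutil

def merge_stage_indices(tmp_dir, checkpoint_dir, final_name="pytorch_model.bin.index.json"):
    merged = {"metadata": {"total_size": 0}, "weight_map": {}}
    for path in sorted(glob.glob(os.path.join(tmp_dir, "*.json"))):
        with open(path) as f:
            stage = json.load(f)
        merged["metadata"]["total_size"] += stage["metadata"]["total_size"]
        merged["weight_map"].update(stage["weight_map"])
    with open(os.path.join(checkpoint_dir, final_name), "w") as f:
        json.dump(merged, f, indent=2)
    shutil.rmtree(tmp_dir)  # the "clean the folder" step from the comment above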
@@ -360,6 +362,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):

        # Devices along the same dp_group share the same copies of states when zero is not used.
        # In this case only let the device with dp_rank == 0 save the model.
        if not self.use_zero and self.real_dp_rank != 0:
            dist.barrier()
            return

        # Then collect the sharded states along dp_group(if using zero)/tp_group.
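The optimizer path applies the same rule: without ZeRO every rank in a dp group holds an identical copy of the states, so only dp_rank == 0 writes while the others just hit the barrier; with ZeRO the states are sharded along the dp group and must be gathered first. A simplified sketch under those assumptions (the shard gathering and merging here are illustrative, not ColossalAI's implementation):

import torch
import torch.distributed as dist

def save_optimizer_states(optimizer, path, dp_rank, dp_size, use_zero, dp_group):
    if not use_zero and dp_rank != 0:
        # Duplicated copies: skip the write, but still sync with the saving rank.
        dist.barrier()
        return
    state_dict = optimizer.state_dict()
    if use_zero:
        # ZeRO shards the states across the dp group; collect every shard on all ranks.
        shards = [None] * dp_size
        dist.all_gather_object(shards, state_dict, group=dp_group)
        state_dict = shards  # merging the shards into one dict is library-specific
    if dp_rank == 0:
        torch.save(state_dict, path)
    dist.barrier()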
@@ -401,6 +404,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):

                    f"index located at {save_index_file}."
                )

            dist.barrier()
        else:
            # When pipeline is used, each stage produces its own shard files and index files.
            # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
@@ -428,10 +432,10 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):

                index_file.append_meta_data("total_size", total_size)
                index_file.write_index_file(save_index_file)
            else:
                dist.barrier()
                return

            dist.barrier(self.pp_group)
            dist.barrier(self.ep_group)
            dist.barrier()

            # The global master rank integrates the index files and clean the folder.
            if self.coordinator.is_master():