[moe] fix mixtral optim checkpoint (#5344)

pull/5372/head
Hongxin Liu 2024-02-01 13:33:09 +08:00 committed by ver217
parent 956b561b54
commit 65e5d6baa5
2 changed files with 12 additions and 2 deletions

@@ -393,7 +393,11 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
             # Store param groups.
             index_file.append_meta_data("param_groups", param_group_file)
             group_file_path = os.path.join(checkpoint, param_group_file)
-            save_param_groups(optimizer.param_info, group_file_path)
+            param_groups = [
+                {**group, "params": group_info["params"]}
+                for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+            ]
+            save_param_groups({"param_groups": param_groups}, group_file_path)
             # Store index file.
             index_file.append_meta_data("total_size", total_size)
             index_file.write_index_file(save_index_file)
@@ -451,7 +455,11 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
             # Store param groups.
             final_index_file.append_meta_data("param_groups", param_group_file)
             group_file_path = os.path.join(checkpoint, param_group_file)
-            save_param_groups(optimizer.param_info, group_file_path)
+            param_groups = [
+                {**group, "params": group_info["params"]}
+                for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+            ]
+            save_param_groups({"param_groups": param_groups}, group_file_path)
 
             final_index_file.write_index_file(final_index_file_path)
             rmtree(tmp_index_file_folder)
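
In both save paths above, the old code wrote optimizer.param_info verbatim, i.e. the param groups captured when the optimizer was wrapped, so any hyperparameter changed afterwards (such as a scheduler-updated learning rate) was silently dropped from the checkpoint. The fix rebuilds each group from the live optimizer.param_groups (current hyperparameters) while keeping the serializable "params" id lists from param_info. A minimal standalone sketch of that merge, using toy structures in place of the real optimizer internals:

# Illustrative sketch only; the toy dicts below are assumptions, not the
# real HybridParallelCheckpointIO data structures.
current_groups = [
    {"lr": 0.1, "weight_decay": 0.01, "params": ["<live tensor objects>"]},
]
param_info = {"param_groups": [{"params": [0, 1, 2]}]}  # stable ids captured at wrap time

param_groups = [
    {**group, "params": group_info["params"]}  # live hyperparams + saved param ids
    for group, group_info in zip(current_groups, param_info["param_groups"])
]
assert param_groups == [{"lr": 0.1, "weight_decay": 0.01, "params": [0, 1, 2]}]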

@@ -117,6 +117,8 @@ def check_mixtral_moe_layer():
 
     # check save optimizer
     optimizer.step()
+    for group in optimizer.param_groups:
+        group["lr"] = 0.1
     snapshot = get_optimizer_snapshot(optimizer.unwrap())
     booster.save_optimizer(optimizer, "mixtral_optim", shard=True)
     dist.barrier()
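
The test change mutates lr after optimizer.step(), precisely the kind of post-initialization hyperparameter update the old save path lost. A hedged sketch of the round-trip check this sets up (the reload-and-compare step is assumed from the visible get_optimizer_snapshot call, not shown in this hunk):

# Assumed continuation of the test: reload the checkpoint and verify the
# mutated lr survived the round trip (with the old save path it reverted
# to the value recorded at wrap time).
booster.load_optimizer(optimizer, "mixtral_optim")
for group in optimizer.param_groups:
    assert group["lr"] == 0.1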