From 65e5d6baa51314414a6d0a3533226e978708408c Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Thu, 1 Feb 2024 13:33:09 +0800
Subject: [PATCH] [moe] fix mixtral optim checkpoint (#5344)

---
 .../colossal_moe/models/mixtral_checkpoint.py | 12 ++++++++++--
 .../ColossalMoE/tests/test_moe_checkpoint.py  |  2 ++
 2 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py b/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py
index 629ad7349..d08dfd5f8 100644
--- a/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py
+++ b/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py
@@ -393,7 +393,11 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
         # Store param groups.
         index_file.append_meta_data("param_groups", param_group_file)
         group_file_path = os.path.join(checkpoint, param_group_file)
-        save_param_groups(optimizer.param_info, group_file_path)
+        param_groups = [
+            {**group, "params": group_info["params"]}
+            for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+        ]
+        save_param_groups({"param_groups": param_groups}, group_file_path)
         # Store index file.
         index_file.append_meta_data("total_size", total_size)
         index_file.write_index_file(save_index_file)
@@ -451,7 +455,11 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
             # Store param groups.
             final_index_file.append_meta_data("param_groups", param_group_file)
             group_file_path = os.path.join(checkpoint, param_group_file)
-            save_param_groups(optimizer.param_info, group_file_path)
+            param_groups = [
+                {**group, "params": group_info["params"]}
+                for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+            ]
+            save_param_groups({"param_groups": param_groups}, group_file_path)
 
             final_index_file.write_index_file(final_index_file_path)
             rmtree(tmp_index_file_folder)
diff --git a/applications/ColossalMoE/tests/test_moe_checkpoint.py b/applications/ColossalMoE/tests/test_moe_checkpoint.py
index d3848bc14..822e7410f 100644
--- a/applications/ColossalMoE/tests/test_moe_checkpoint.py
+++ b/applications/ColossalMoE/tests/test_moe_checkpoint.py
@@ -117,6 +117,8 @@ def check_mixtral_moe_layer():
 
     # check save optimizer
     optimizer.step()
+    for group in optimizer.param_groups:
+        group["lr"] = 0.1
     snapshot = get_optimizer_snapshot(optimizer.unwrap())
     booster.save_optimizer(optimizer, "mixtral_optim", shard=True)
     dist.barrier()
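
Note on the save-path change: `optimizer.param_info` appears to capture the param
groups as they were when the optimizer was wrapped, so any hyperparameter updated
afterwards (e.g. a scheduler adjusting `lr`) would be dropped from the checkpoint.
The list comprehension in the patch merges the live hyperparameters from
`optimizer.param_groups` with the serializable `params` index lists recorded in
`param_info`. A minimal sketch of that merge, with plain dicts standing in for the
torch objects (all concrete values are illustrative, not from the source):

# Stand-in for optimizer.param_groups: lr was changed after construction.
live_param_groups = [
    {"lr": 0.1, "betas": (0.9, 0.999), "weight_decay": 0.0, "params": ["<live tensor refs>"]},
]
# Stand-in for optimizer.param_info: construction-time values plus param indices.
param_info = {
    "param_groups": [
        {"lr": 0.001, "betas": (0.9, 0.999), "weight_decay": 0.0, "params": [0, 1, 2]},
    ]
}

# The merge from the patch: current hyperparameters win, index lists stay serializable.
param_groups = [
    {**group, "params": group_info["params"]}
    for group, group_info in zip(live_param_groups, param_info["param_groups"])
]

assert param_groups[0]["lr"] == 0.1            # updated lr is preserved
assert param_groups[0]["params"] == [0, 1, 2]  # indices, not live tensors, are saved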
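
Note on the test change: the two added lines mutate `lr` on every param group after
`optimizer.step()`, so the subsequent save/load round trip can only pass if the
checkpoint records the current hyperparameters rather than the construction-time
ones. A self-contained sketch of that round trip, using JSON stand-ins for the real
`save_param_groups` helper (the actual helper and its on-disk format belong to
ColossalAI; only the round-trip logic is illustrated):

import json

def save_param_groups_sketch(state, path):
    # stand-in for the real save_param_groups helper
    with open(path, "w") as f:
        json.dump(state, f)

def load_param_groups_sketch(path):
    with open(path) as f:
        return json.load(f)

groups = [{"lr": 0.001, "params": [0]}]
for group in groups:  # mirrors the two lines added to the test
    group["lr"] = 0.1
save_param_groups_sketch({"param_groups": groups}, "param_groups.json")
restored = load_param_groups_sketch("param_groups.json")
# Saving the construction-time info would have stored 0.001 here.
assert restored["param_groups"][0]["lr"] == 0.1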