mirror of https://github.com/hpcaitech/ColossalAI
[moe] fix mixtral optim checkpoint (#5344)
parent 956b561b54
commit 65e5d6baa5
@@ -393,7 +393,11 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
             # Store param groups.
             index_file.append_meta_data("param_groups", param_group_file)
             group_file_path = os.path.join(checkpoint, param_group_file)
-            save_param_groups(optimizer.param_info, group_file_path)
+            param_groups = [
+                {**group, "params": group_info["params"]}
+                for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+            ]
+            save_param_groups({"param_groups": param_groups}, group_file_path)
             # Store index file.
             index_file.append_meta_data("total_size", total_size)
             index_file.write_index_file(save_index_file)
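The hunk above replaces the direct dump of optimizer.param_info, which appears to hold the param groups as captured when the optimizer was wrapped, with groups rebuilt from the live optimizer: each saved group takes its current hyperparameters (e.g. lr) from optimizer.param_groups but keeps the serializable "params" index list recorded in optimizer.param_info. A minimal sketch of that merge, using hypothetical stand-in data for optimizer.param_groups and optimizer.param_info:

# Hypothetical stand-ins: live_param_groups mimics optimizer.param_groups
# (current hyperparameters, real tensor objects), param_info mimics
# optimizer.param_info (params stored as index lists at wrap time).
live_param_groups = [{"lr": 0.1, "weight_decay": 0.0, "params": ["<tensor objects>"]}]
param_info = {"param_groups": [{"lr": 0.001, "params": [0, 1, 2]}]}

# Same merge as in the patch: live hyperparameters, saved param indices.
param_groups = [
    {**group, "params": group_info["params"]}
    for group, group_info in zip(live_param_groups, param_info["param_groups"])
]

assert param_groups[0]["lr"] == 0.1            # reflects the current learning rate
assert param_groups[0]["params"] == [0, 1, 2]  # keeps the index list, not tensors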
@@ -451,7 +455,11 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
             # Store param groups.
             final_index_file.append_meta_data("param_groups", param_group_file)
             group_file_path = os.path.join(checkpoint, param_group_file)
-            save_param_groups(optimizer.param_info, group_file_path)
+            param_groups = [
+                {**group, "params": group_info["params"]}
+                for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+            ]
+            save_param_groups({"param_groups": param_groups}, group_file_path)
 
             final_index_file.write_index_file(final_index_file_path)
             rmtree(tmp_index_file_folder)
@@ -117,6 +117,8 @@ def check_mixtral_moe_layer():
 
     # check save optimizer
     optimizer.step()
+    for group in optimizer.param_groups:
+        group["lr"] = 0.1
     snapshot = get_optimizer_snapshot(optimizer.unwrap())
     booster.save_optimizer(optimizer, "mixtral_optim", shard=True)
     dist.barrier()
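The test now bumps the learning rate after optimizer.step() and before saving, so the checkpoint only matches the snapshot taken by get_optimizer_snapshot if the save path picks up the live hyperparameters rather than the values recorded at setup time. A minimal sketch of that kind of round-trip check, with a hypothetical plain torch.optim.Adam standing in for the boosted Mixtral optimizer and a plain state_dict round trip standing in for the sharded save/load:

import torch
from torch import nn
from torch.optim import Adam

# Hypothetical stand-in model/optimizer for the boosted Mixtral setup in the test.
model = nn.Linear(4, 4)
optimizer = Adam(model.parameters(), lr=1e-3)
model(torch.randn(2, 4)).sum().backward()
optimizer.step()

# Change a hyperparameter after stepping, as the test does with group["lr"] = 0.1.
for group in optimizer.param_groups:
    group["lr"] = 0.1

saved = optimizer.state_dict()  # snapshot taken after the change

# A reloaded optimizer should reproduce the updated lr, not the construction-time one.
new_optimizer = Adam(model.parameters(), lr=1e-3)
new_optimizer.load_state_dict(saved)
assert all(group["lr"] == 0.1 for group in new_optimizer.param_groups)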