From 5b62a3957a7bca6ea16fa4681825cc33103a176a Mon Sep 17 00:00:00 2001
From: zaglc
Date: Mon, 25 Sep 2023 16:08:40 +0800
Subject: [PATCH] fix load ckpt bug

---
 internlm/initialize/launch.py      |  5 +-
 internlm/utils/model_checkpoint.py | 98 +++++++++++++++++++++++-------
 train.py                           |  2 +-
 3 files changed, 80 insertions(+), 25 deletions(-)

diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index 388051a..1c5f7a7 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -66,11 +66,12 @@ def args_sanity_check():
         pp = gpc.config.parallel.pipeline
     else:
         pp = gpc.config.parallel.pipeline.size
+    tp = gpc.config.parallel.tensor
 
     if "use_fsdp" not in gpc.config.parallel:
         gpc.config.parallel._add_item("use_fsdp", False)
-    elif gpc.config.parallel.use_fsdp and pp > 1:
-        logger.warning("FSDP not support when pipeline parallel is enabled, auto-close FSDP")
+    elif gpc.config.parallel.use_fsdp and (pp > 1 or tp > 1):
+        logger.warning("FSDP is not supported when pipeline/tensor parallelism is enabled; disabling FSDP")
         gpc.config.parallel._add_item("use_fsdp", False)
 
     # processing the data config in gpc
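The new check simply disables FSDP whenever pipeline or tensor parallelism is requested, rather than failing later during checkpointing. For reference, a sketch of a parallel config that keeps use_fsdp active under this check; the field names mirror the gpc.config.parallel accesses above, and the zero1 value is purely illustrative:

    # hypothetical InternLM-style parallel section (values illustrative)
    parallel = dict(
        zero1=8,                # sharding group size for ZeRO/FSDP
        tensor=1,               # tp > 1 would auto-disable FSDP
        pipeline=dict(size=1),  # pp > 1 would auto-disable FSDP
        use_fsdp=True,
    )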
diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py
index 1e5007b..1201571 100644
--- a/internlm/utils/model_checkpoint.py
+++ b/internlm/utils/model_checkpoint.py
@@ -11,7 +11,7 @@ from enum import Enum
 from typing import Callable, Dict, Union
 
 import torch
-from torch.distributed.fsdp import FullStateDictConfig
+from torch.distributed.fsdp import FullStateDictConfig, LocalStateDictConfig
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp import StateDictType
@@ -157,7 +157,7 @@ def get_model_topology(model):
     return topos
 
 
-def get_state_dict(model):
+def get_shard_state_dict(shard_model):
     """
-    Only used for FSDP module saving.
-    It's a warper of model.state_dict() and with the context of 'FSDP.state_dict_type', the sharded parameter
-
+    Only used for FSDP module saving.
+    It's a wrapper of model.state_dict(); within the 'FSDP.state_dict_type' context,
+    each rank returns only the shard of the parameters it owns.
     """
 
     # TODO: rank0_only can save memory for non-rank0 gpu, but when tp is enabled, model saving will left some parameters
-    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
-    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
-        states = model.state_dict()
+    # save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
+    # with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
+    #     states = model.state_dict()
 
-    return states
+    # in this version, an FSDP model can only be saved in sharded shape
+    with FSDP.state_dict_type(shard_model, StateDictType.LOCAL_STATE_DICT):
+        shard_states = shard_model.state_dict()
+
+    return shard_states
+
+
+def load_shard_state_dict(shard_model, shard_state, **kwargs):
+    """
+    Only used for FSDP module loading.
+    The counterpart of get_shard_state_dict(): within the 'FSDP.state_dict_type' context,
+    each rank loads only the shard of the parameters it owns.
+    """
+
+    with FSDP.state_dict_type(shard_model, StateDictType.LOCAL_STATE_DICT):
+        missing_k, unexpected_keys = shard_model.load_state_dict(shard_state, **kwargs)
+
+    return (missing_k, unexpected_keys)
 
 
 def try_load_internlm_ckpt(ckpt_mm, load_info, train_state: TrainState):
     load_content_str = ""
     load_ckpt_folder = load_info["path"]
     load_content: CheckpointLoadMask = load_info["content"]
 
     if gpc.is_rank_for_log():
         logger.info(f"Try load_ckpt_folder: {load_ckpt_folder}")
@@ -241,6 +256,10 @@ def save_model_checkpoint(folder, model):
     - folder
         - model_tp{tp_rank}_pp{pp_rank}.pt
 
+    If fsdp is enabled, the saved weights are named:
+    - folder
+        - model_tp{tp_rank}_pp{pp_rank}_zo{zo_rank}.pt
+
     If the tp is inconsistent with the saved one in the future use, the weight needs to be converted before loading.
 
     Args:
@@ -249,7 +268,7 @@ def save_model_checkpoint(folder, model):
     """
 
     if gpc.config.parallel.use_fsdp:
-        states = get_state_dict(model)
+        states = get_shard_state_dict(model)
     else:
         states = model.state_dict()
 
@@ -258,6 +277,7 @@ def save_model_checkpoint(folder, model):
     if folder is not None:
         dp_size = gpc.get_world_size(ParallelMode.DATA)
         tp_size = gpc.get_world_size(ParallelMode.TENSOR)
+        zo_size = gpc.get_world_size(ParallelMode.ZERO1)
         dp_rank = gpc.get_local_rank(ParallelMode.DATA)
         tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
         pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
@@ -266,15 +286,21 @@ def save_model_checkpoint(folder, model):
         # even if pp is not considered, it will definitely not be written on the same machine.
         should_save_rank_pair = set()  # (tp_rank, dp_rank)
         for i in range(tp_size):
-            should_save_rank_pair.add((i, i % dp_size))
+            if gpc.config.parallel.use_fsdp:
+                for j in range(zo_size):
+                    should_save_rank_pair.add((i, j))
+            else:
+                should_save_rank_pair.add((i, i % dp_size))
 
-        if (tp_rank, dp_rank) in should_save_rank_pair:
-            fn = f"model_tp{tp_rank}_pp{pp_rank}.pt"
-            fp = os.path.join(folder, fn)
-            llm_save(fp, saved_obj=states)
-            topo_fn = f"topo_tp{tp_rank}_pp{pp_rank}.json"
-            topo_fp = os.path.join(folder, topo_fn)
-            llm_save(topo_fp, saved_obj=topo)
+        if (tp_rank, dp_rank) in should_save_rank_pair:
+            f_zo = f"_zo{dp_rank}" if gpc.config.parallel.use_fsdp else ""
+            fn = f"model_tp{tp_rank}_pp{pp_rank}{f_zo}.pt"
+            fp = os.path.join(folder, fn)
+            llm_save(fp, saved_obj=states)
+            if not gpc.config.parallel.use_fsdp or dp_rank == tp_rank % dp_size:
+                topo_fn = f"topo_tp{tp_rank}_pp{pp_rank}.json"
+                topo_fp = os.path.join(folder, topo_fn)
+                llm_save(topo_fp, saved_obj=topo)
 
         torch.distributed.barrier()
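Both helpers above rely on FSDP's LOCAL_STATE_DICT mode, in which every rank serializes and restores only the shard it owns; this is also why save_model_checkpoint writes one file per (tp_rank, zo_rank) pair under FSDP, while the dp_rank == tp_rank % dp_size condition keeps the topology json written once per tensor rank, as in the non-FSDP path. A minimal round-trip sketch of the mechanism, assuming torch.distributed is initialized, shard_model is already FSDP-wrapped, and path_for_this_rank is a hypothetical per-rank file path:

    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import StateDictType

    def shard_state_roundtrip(shard_model: FSDP, path_for_this_rank: str) -> None:
        # save: every rank writes only its own shard
        with FSDP.state_dict_type(shard_model, StateDictType.LOCAL_STATE_DICT):
            torch.save(shard_model.state_dict(), path_for_this_rank)
        # load: every rank reads back the shard it wrote; world size and
        # wrapping must match the saving run, which is what the tp/pp/zo
        # checks in load_model_checkpoint below enforce
        shard_states = torch.load(path_for_this_rank)
        with FSDP.state_dict_type(shard_model, StateDictType.LOCAL_STATE_DICT):
            shard_model.load_state_dict(shard_states)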
""" tp_size = gpc.get_world_size(ParallelMode.TENSOR) pp_size = gpc.get_world_size(ParallelMode.PIPELINE) + zo_size = gpc.get_world_size(ParallelMode.ZERO1) tp_rank = gpc.get_local_rank(ParallelMode.TENSOR) pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + zo_rank = gpc.get_local_rank(ParallelMode.ZERO1) fns = get_fns(folder) - max_pp, max_tp = 0, 0 + + # avoid ckpt misuse between FSDP and no-FSDP + test_fn = list([f for f in fns if f.startswith("model_t") and not f.endswith(".md5")]).pop() + assert ( + "_zo" in test_fn and gpc.config.parallel.use_fsdp or + "_zo" not in test_fn and not gpc.config.parallel.use_fsdp + ), f"FSDP model wants to load no-FSDP ckpts or reverse" + + max_pp, max_tp, max_zo = 0, 0, 0 for fn in fns: if fn.startswith("model_t") and not fn.endswith(".md5"): segements = os.path.splitext(fn)[0].split("_") - max_pp = max(max_pp, int(segements[-1][2:])) - max_tp = max(max_tp, int(segements[-2][2:])) + if gpc.config.parallel.use_fsdp: + max_zo = max(max_zo, int(segements[-1][2:])) + max_pp = max(max_pp, int(segements[-2][2:])) + max_tp = max(max_tp, int(segements[-3][2:])) + else: + max_pp = max(max_pp, int(segements[-1][2:])) + max_tp = max(max_tp, int(segements[-2][2:])) + assert ( + zo_size == max_zo + 1 + ), f"The weights are save for {max_zo+1} FSDP shards , while current has {zo_size} FSDP shards" assert ( pp_size == max_pp + 1 ), f"The weights are save for {max_pp+1} pipelines, while current has {pp_size} pipelines" @@ -308,11 +356,17 @@ def load_model_checkpoint(folder, model): tp_size == max_tp + 1 ), f"The weights are save for {max_tp+1} parallelism, while current has {tp_size} tensor parallelism" - should_load_name = f"model_tp{tp_rank}_pp{pp_rank}.pt" + if gpc.config.parallel.use_fsdp: + should_load_name = f"model_tp{tp_rank}_pp{pp_rank}_zo{zo_rank}.pt" + else: + should_load_name = f"model_tp{tp_rank}_pp{pp_rank}.pt" fp = os.path.join(folder, should_load_name) states = llm_load(fp, map_location=get_current_device()) - missing_k, unexpected_keys = model.load_state_dict(states, strict=False) + if gpc.config.parallel.use_fsdp: + missing_k, unexpected_keys = load_shard_state_dict(model, states, strict=False) + else: + missing_k, unexpected_keys = model.load_state_dict(states, strict=False) if len(missing_k) != 0: logger.warning(f"Warning: missing keys {missing_k}") if len(unexpected_keys) != 0: diff --git a/train.py b/train.py index cab61c7..7043fa0 100644 --- a/train.py +++ b/train.py @@ -111,7 +111,7 @@ def main(args): # initialize and resume train state train_state = TrainState(gpc.config, train_dl.batch_sampler) - # if fsdp enabled, warp the model + # if fsdp enabled, wrap the model model = wrap_FSDP_model(model) optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)