fix load ckpt bug

pull/293/head
zaglc 2023-09-25 16:08:40 +08:00
parent 37dbe6398b
commit 5b62a3957a
3 changed files with 80 additions and 25 deletions


@@ -66,11 +66,12 @@ def args_sanity_check():
         pp = gpc.config.parallel.pipeline
     else:
         pp = gpc.config.parallel.pipeline.size
+    tp = gpc.config.parallel.tensor
 
     if "use_fsdp" not in gpc.config.parallel:
         gpc.config.parallel._add_item("use_fsdp", False)
-    elif gpc.config.parallel.use_fsdp and pp > 1:
-        logger.warning("FSDP not support when pipeline parallel is enabled, auto-close FSDP")
+    elif gpc.config.parallel.use_fsdp and (pp > 1 or tp > 1):
+        logger.warning("FSDP not support when pipeline/tensor parallel is enabled, auto-close FSDP")
         gpc.config.parallel._add_item("use_fsdp", False)
 
     # processing the data config in gpc
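For context, a minimal sketch of a parallel config that would hit the new auto-close branch above. The field names follow the hunk; the concrete values are made up:

# hypothetical config snippet, values are illustrative only
parallel = dict(
    tensor=2,                 # tp > 1 ...
    pipeline=dict(size=1),
    use_fsdp=True,            # ... so use_fsdp is reset to False with a warning
)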


@@ -11,7 +11,7 @@ from enum import Enum
 from typing import Callable, Dict, Union
 
 import torch
-from torch.distributed.fsdp import FullStateDictConfig
+from torch.distributed.fsdp import FullStateDictConfig, LocalStateDictConfig
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp import StateDictType
@@ -157,7 +157,7 @@ def get_model_topology(model):
     return topos
 
 
-def get_state_dict(model):
+def get_shard_state_dict(shard_model):
     """
     Only used for FSDP module saving.
     It's a warper of model.state_dict() and with the context of 'FSDP.state_dict_type', the sharded parameter
@@ -167,18 +167,33 @@ def get_state_dict(model):
     """
     # TODO: rank0_only can save memory for non-rank0 gpu, but when tp is enabled, model saving will left some parameters
-    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
-    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
-        states = model.state_dict()
-    return states
+    # save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
+    # with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
+    #     states = model.state_dict()
+
+    # in this version, FSDP model can only save with sharded shape
+    with FSDP.state_dict_type(shard_model, StateDictType.LOCAL_STATE_DICT):
+        shard_states = shard_model.state_dict()
+
+    return shard_states
+
+
+def load_shard_state_dict(shard_model, shard_state, **kwargs):
+    """
+    Only used for FSDP module loading.
+    """
+    with FSDP.state_dict_type(shard_model, StateDictType.LOCAL_STATE_DICT):
+        missing_k, unexpected_keys = shard_model.load_state_dict(shard_state, kwargs)
+
+    return (missing_k, unexpected_keys)
 
 
 def try_load_internlm_ckpt(ckpt_mm, load_info, train_state: TrainState):
     load_content_str = ""
     load_ckpt_folder = load_info["path"]
     load_content: CheckpointLoadMask = load_info["content"]
+    print(load_ckpt_folder, load_content)
 
     if gpc.is_rank_for_log():
         logger.info(f"Try load_ckpt_folder: {load_ckpt_folder}")
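As background for get_shard_state_dict / load_shard_state_dict above, a standalone sketch (not InternLM code) of torch's LOCAL_STATE_DICT save/load pattern; it assumes the process group is already initialized and that model is FSDP-wrapped:

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType


def save_local_shard(model: FSDP, path: str) -> None:
    # each rank saves only its own flattened shard, so no gather across ranks is needed
    with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
        torch.save(model.state_dict(), path)


def load_local_shard(model: FSDP, path: str):
    # the shard layout must match the one used at save time (same world size and wrapping)
    with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
        shard = torch.load(path, map_location="cpu")
        return model.load_state_dict(shard)

Because local shards are only valid for the same FSDP layout, the loader below adds a misuse check and a zo_size assertion.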
@@ -241,6 +256,10 @@ def save_model_checkpoint(folder, model):
         - folder
             - model_tp{tp_rank}_pp{pp_rank}.pt
 
+    If fsdp is activated, the saved weight is named:
+        - folder
+            - model_tp{tp_rank}_pp{pp_rank}_zo{zo_rank}
+
     If the tp is inconsistent with the saved one in the future use, the weight needs to be converted before loading.
 
     Args:
@@ -249,7 +268,7 @@ def save_model_checkpoint(folder, model):
     """
 
     if gpc.config.parallel.use_fsdp:
-        states = get_state_dict(model)
+        states = get_shard_state_dict(model)
     else:
         states = model.state_dict()
@@ -258,6 +277,7 @@ def save_model_checkpoint(folder, model):
     if folder is not None:
         dp_size = gpc.get_world_size(ParallelMode.DATA)
         tp_size = gpc.get_world_size(ParallelMode.TENSOR)
+        zo_size = gpc.get_world_size(ParallelMode.ZERO1)
         dp_rank = gpc.get_local_rank(ParallelMode.DATA)
         tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
         pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
@@ -266,15 +286,21 @@ def save_model_checkpoint(folder, model):
         # even if pp is not considered, it will definitely not be written on the same machine.
         should_save_rank_pair = set()  # (tp_rank, dp_rank)
         for i in range(tp_size):
-            should_save_rank_pair.add((i, i % dp_size))
+            if gpc.config.parallel.use_fsdp:
+                for j in range(zo_size):
+                    should_save_rank_pair.add((i, j))
+            else:
+                should_save_rank_pair.add((i, i % dp_size))
 
         if (tp_rank, dp_rank) in should_save_rank_pair:
-            fn = f"model_tp{tp_rank}_pp{pp_rank}.pt"
-            fp = os.path.join(folder, fn)
-            llm_save(fp, saved_obj=states)
-            topo_fn = f"topo_tp{tp_rank}_pp{pp_rank}.json"
-            topo_fp = os.path.join(folder, topo_fn)
-            llm_save(topo_fp, saved_obj=topo)
+            f_zo = f"_zo{dp_rank}" if gpc.config.parallel.use_fsdp else ""
+            fn = f"model_tp{tp_rank}_pp{pp_rank}{f_zo}.pt"
+            fp = os.path.join(folder, fn)
+            llm_save(fp, saved_obj=states)
+            if not gpc.config.parallel.use_fsdp or dp_rank == tp_rank % dp_size:
+                topo_fn = f"topo_tp{tp_rank}_pp{pp_rank}.json"
+                topo_fp = os.path.join(folder, topo_fn)
+                llm_save(topo_fp, saved_obj=topo)
 
     torch.distributed.barrier()
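A toy illustration (plain Python, sizes made up) of which (tp_rank, dp_rank) pairs save and which file names the FSDP branch above produces:

# toy example of the naming scheme above; tp/zo sizes and pp_rank are made up
tp_size, zo_size, pp_rank = 2, 4, 0

files = [
    f"model_tp{tp_rank}_pp{pp_rank}_zo{dp_rank}.pt"
    for tp_rank in range(tp_size)
    for dp_rank in range(zo_size)  # with use_fsdp, every ZeRO shard rank writes its own piece
]
print(files)  # model_tp0_pp0_zo0.pt ... model_tp1_pp0_zo3.pt

The topology json is still written only once per tp rank, which is what the extra dp_rank == tp_rank % dp_size guard above ensures.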
@@ -283,24 +309,46 @@ def load_model_checkpoint(folder, model):
     """
     There should be weights with names similar to the following under the folder.
         - folder
             - model_tp{tp_rank}_pp{pp_rank}.pt
 
+    If fsdp is activated, the saved weight is named:
+        - folder
+            - model_tp{tp_rank}_pp{pp_rank}_zo{zo_rank}
+
     If the tp is inconsistent with the saved one in the future use, the weight needs to be converted before loading.
     """
 
     tp_size = gpc.get_world_size(ParallelMode.TENSOR)
     pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
+    zo_size = gpc.get_world_size(ParallelMode.ZERO1)
     tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
     pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
+    zo_rank = gpc.get_local_rank(ParallelMode.ZERO1)
 
     fns = get_fns(folder)
-    max_pp, max_tp = 0, 0
+
+    # avoid ckpt misuse between FSDP and no-FSDP
+    test_fn = list([f for f in fns if f.startswith("model_t") and not f.endswith(".md5")]).pop()
+    assert (
+        "_zo" in test_fn and gpc.config.parallel.use_fsdp or
+        "_zo" not in test_fn and not gpc.config.parallel.use_fsdp
+    ), f"FSDP model wants to load no-FSDP ckpts or reverse"
+
+    max_pp, max_tp, max_zo = 0, 0, 0
     for fn in fns:
         if fn.startswith("model_t") and not fn.endswith(".md5"):
             segements = os.path.splitext(fn)[0].split("_")
-            max_pp = max(max_pp, int(segements[-1][2:]))
-            max_tp = max(max_tp, int(segements[-2][2:]))
+            if gpc.config.parallel.use_fsdp:
+                max_zo = max(max_zo, int(segements[-1][2:]))
+                max_pp = max(max_pp, int(segements[-2][2:]))
+                max_tp = max(max_tp, int(segements[-3][2:]))
+            else:
+                max_pp = max(max_pp, int(segements[-1][2:]))
+                max_tp = max(max_tp, int(segements[-2][2:]))
 
+    assert (
+        zo_size == max_zo + 1
+    ), f"The weights are save for {max_zo+1} FSDP shards , while current has {zo_size} FSDP shards"
     assert (
         pp_size == max_pp + 1
     ), f"The weights are save for {max_pp+1} pipelines, while current has {pp_size} pipelines"
@@ -308,11 +356,17 @@ def load_model_checkpoint(folder, model):
         tp_size == max_tp + 1
     ), f"The weights are save for {max_tp+1} parallelism, while current has {tp_size} tensor parallelism"
 
-    should_load_name = f"model_tp{tp_rank}_pp{pp_rank}.pt"
+    if gpc.config.parallel.use_fsdp:
+        should_load_name = f"model_tp{tp_rank}_pp{pp_rank}_zo{zo_rank}.pt"
+    else:
+        should_load_name = f"model_tp{tp_rank}_pp{pp_rank}.pt"
     fp = os.path.join(folder, should_load_name)
     states = llm_load(fp, map_location=get_current_device())
-    missing_k, unexpected_keys = model.load_state_dict(states, strict=False)
+    if gpc.config.parallel.use_fsdp:
+        missing_k, unexpected_keys = load_shard_state_dict(model, states, strict=False)
+    else:
+        missing_k, unexpected_keys = model.load_state_dict(states, strict=False)
+
     if len(missing_k) != 0:
         logger.warning(f"Warning: missing keys {missing_k}")
 
     if len(unexpected_keys) != 0:
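For reference, a small standalone sketch (not the InternLM code path) of how the _tp/_pp/_zo suffixes used above can be parsed back into ranks:

import os


def parse_fsdp_ranks(fn: str):
    # "model_tp1_pp0_zo3.pt" -> (1, 0, 3); assumes the FSDP naming scheme above
    segments = os.path.splitext(fn)[0].split("_")
    zo = int(segments[-1][2:])
    pp = int(segments[-2][2:])
    tp = int(segments[-3][2:])
    return tp, pp, zo


assert parse_fsdp_ranks("model_tp1_pp0_zo3.pt") == (1, 0, 3)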


@@ -111,7 +111,7 @@ def main(args):
     # initialize and resume train state
     train_state = TrainState(gpc.config, train_dl.batch_sampler)
 
-    # if fsdp enabled, warp the model
+    # if fsdp enabled, wrap the model
     model = wrap_FSDP_model(model)
 
     optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)
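wrap_FSDP_model is the repo's helper; as a rough idea only (a hedged sketch, not the actual implementation), such a wrapper typically just wraps the module in torch's FSDP when the config enables it:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def wrap_fsdp(model, use_fsdp: bool):
    # hypothetical helper: wrap only when use_fsdp is enabled in the parallel config;
    # real code would also pass a wrap policy, mixed-precision settings, etc.
    if use_fsdp:
        model = FSDP(model)
    return model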