mirror of https://github.com/InternLM/InternLM

fix load ckpt bug

parent 37dbe6398b
commit 5b62a3957a

@@ -66,11 +66,12 @@ def args_sanity_check():
         pp = gpc.config.parallel.pipeline
     else:
         pp = gpc.config.parallel.pipeline.size
+    tp = gpc.config.parallel.tensor
 
     if "use_fsdp" not in gpc.config.parallel:
         gpc.config.parallel._add_item("use_fsdp", False)
-    elif gpc.config.parallel.use_fsdp and pp > 1:
-        logger.warning("FSDP not support when pipeline parallel is enabled, auto-close FSDP")
+    elif gpc.config.parallel.use_fsdp and (pp > 1 or tp > 1):
+        logger.warning("FSDP not support when pipeline/tensor parallel is enabled, auto-close FSDP")
         gpc.config.parallel._add_item("use_fsdp", False)
 
     # processing the data config in gpc
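
The widened guard now disables FSDP whenever either pipeline or tensor parallelism is active. A minimal sketch of that behavior, using a plain dict as a hypothetical stand-in for gpc.config.parallel:

# Hypothetical stand-in for gpc.config.parallel; illustration only.
parallel = {"pipeline": {"size": 1}, "tensor": 2, "use_fsdp": True}

pp = parallel["pipeline"]["size"]
tp = parallel["tensor"]

# Matches the widened elif above: tensor parallelism alone now also
# triggers the auto-close, not just pipeline parallelism.
if parallel["use_fsdp"] and (pp > 1 or tp > 1):
    print("FSDP not support when pipeline/tensor parallel is enabled, auto-close FSDP")
    parallel["use_fsdp"] = False

assert parallel["use_fsdp"] is False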
@@ -11,7 +11,7 @@ from enum import Enum
 from typing import Callable, Dict, Union
 
 import torch
-from torch.distributed.fsdp import FullStateDictConfig
+from torch.distributed.fsdp import FullStateDictConfig, LocalStateDictConfig
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp import StateDictType
 
@@ -157,7 +157,7 @@ def get_model_topology(model):
     return topos
 
 
-def get_state_dict(model):
+def get_shard_state_dict(shard_model):
     """
     Only used for FSDP module saving.
     It's a warper of model.state_dict() and with the context of 'FSDP.state_dict_type', the sharded parameter
@@ -167,18 +167,33 @@ def get_state_dict(model):
     """
 
-    # TODO: rank0_only can save memory for non-rank0 gpu, but when tp is enabled, model saving will left some parameters
-    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
-    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
-        states = model.state_dict()
+    # save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
+    # with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
+    #     states = model.state_dict()
 
-    return states
+    # in this version, FSDP model can only save with sharded shape
+    with FSDP.state_dict_type(shard_model, StateDictType.LOCAL_STATE_DICT):
+        shard_states = shard_model.state_dict()
+
+    return shard_states
+
+def load_shard_state_dict(shard_model, shard_state, **kwargs):
+    """
+    Only used for FSDP module loading.
+
+    """
+
+    with FSDP.state_dict_type(shard_model, StateDictType.LOCAL_STATE_DICT):
+        missing_k, unexpected_keys = shard_model.load_state_dict(shard_state, kwargs)
+
+    return (missing_k, unexpected_keys)
 
 
 def try_load_internlm_ckpt(ckpt_mm, load_info, train_state: TrainState):
     load_content_str = ""
     load_ckpt_folder = load_info["path"]
     load_content: CheckpointLoadMask = load_info["content"]
 
+    print(load_ckpt_folder, load_content)
     if gpc.is_rank_for_log():
         logger.info(f"Try load_ckpt_folder: {load_ckpt_folder}")
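
The two helpers above wrap PyTorch's sharded (LOCAL_STATE_DICT) checkpoint path. A self-contained sketch of that round trip, assuming a torch build with FSDP and a launch via torchrun (the module and file names are illustrative):

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType

# Launch with e.g.: torchrun --nproc_per_node=2 this_script.py
dist.init_process_group(backend="nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

model = FSDP(torch.nn.Linear(1024, 1024).cuda())

# Save: under LOCAL_STATE_DICT, state_dict() returns only this rank's
# flattened shard, so every rank writes its own file.
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
    shard = model.state_dict()
torch.save(shard, f"shard_rank{rank}.pt")

# Load: shards are not resharded, so the saving and loading world sizes
# must match. Note that strict is a keyword argument of load_state_dict,
# so a kwargs dict should be expanded rather than passed positionally.
shard = torch.load(f"shard_rank{rank}.pt")
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
    missing, unexpected = model.load_state_dict(shard, strict=False)

dist.destroy_process_group()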
@@ -241,6 +256,10 @@ def save_model_checkpoint(folder, model):
         - folder
             - model_tp{tp_rank}_pp{pp_rank}.pt
 
+    If fsdp is activated, the saved weight is named:
+        - folder
+            - model_tp{tp_rank}_pp{pp_rank}_zo{zo_rank}
+
     If the tp is inconsistent with the saved one in the future use, the weight needs to be converted before loading.
 
     Args:
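
Concretely, for a hypothetical run with tp_size=2, pp_size=1, and 4 FSDP (ZeRO1) shards, the layout described in the docstring would contain files such as:

folder/
    model_tp0_pp0_zo0.pt
    model_tp0_pp0_zo1.pt
    model_tp0_pp0_zo2.pt
    model_tp0_pp0_zo3.pt
    model_tp1_pp0_zo0.pt
    ...
    topo_tp0_pp0.json
    topo_tp1_pp0.json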
@@ -249,7 +268,7 @@ def save_model_checkpoint(folder, model):
     """
 
     if gpc.config.parallel.use_fsdp:
-        states = get_state_dict(model)
+        states = get_shard_state_dict(model)
     else:
         states = model.state_dict()
 
@@ -258,6 +277,7 @@ def save_model_checkpoint(folder, model):
     if folder is not None:
         dp_size = gpc.get_world_size(ParallelMode.DATA)
         tp_size = gpc.get_world_size(ParallelMode.TENSOR)
+        zo_size = gpc.get_world_size(ParallelMode.ZERO1)
         dp_rank = gpc.get_local_rank(ParallelMode.DATA)
         tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
         pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
@@ -266,15 +286,21 @@ def save_model_checkpoint(folder, model):
         # even if pp is not considered, it will definitely not be written on the same machine.
         should_save_rank_pair = set()  # (tp_rank, dp_rank)
         for i in range(tp_size):
-            should_save_rank_pair.add((i, i % dp_size))
+            if gpc.config.parallel.use_fsdp:
+                for j in range(zo_size):
+                    should_save_rank_pair.add((i, j))
+            else:
+                should_save_rank_pair.add((i, i % dp_size))
 
-        if (tp_rank, dp_rank) in should_save_rank_pair:
-            fn = f"model_tp{tp_rank}_pp{pp_rank}.pt"
-            fp = os.path.join(folder, fn)
-            llm_save(fp, saved_obj=states)
-            topo_fn = f"topo_tp{tp_rank}_pp{pp_rank}.json"
-            topo_fp = os.path.join(folder, topo_fn)
-            llm_save(topo_fp, saved_obj=topo)
+        if (tp_rank, dp_rank) in should_save_rank_pair:
+            f_zo = f"_zo{dp_rank}" if gpc.config.parallel.use_fsdp else ""
+            fn = f"model_tp{tp_rank}_pp{pp_rank}{f_zo}.pt"
+            fp = os.path.join(folder, fn)
+            llm_save(fp, saved_obj=states)
+            if not gpc.config.parallel.use_fsdp or dp_rank == tp_rank % dp_size:
+                topo_fn = f"topo_tp{tp_rank}_pp{pp_rank}.json"
+                topo_fp = os.path.join(folder, topo_fn)
+                llm_save(topo_fp, saved_obj=topo)
 
         torch.distributed.barrier()
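
To see which ranks actually write under each mode, a small standalone example with assumed sizes tp_size=2, dp_size=4, zo_size=4 (illustrative numbers, not from the diff):

tp_size, dp_size, zo_size = 2, 4, 4  # assumed topology

def saver_pairs(use_fsdp):
    pairs = set()  # (tp_rank, dp_rank)
    for i in range(tp_size):
        if use_fsdp:
            # every ZeRO shard owner must persist its own shard
            for j in range(zo_size):
                pairs.add((i, j))
        else:
            # weights are replicated across dp, so one writer per tp rank
            pairs.add((i, i % dp_size))
    return sorted(pairs)

print(saver_pairs(use_fsdp=False))  # [(0, 0), (1, 1)]
print(saver_pairs(use_fsdp=True))   # all 8 (tp_rank, dp_rank) pairs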
@@ -283,24 +309,46 @@ def load_model_checkpoint(folder, model):
     """
     There should be weights with names similar to the following under the folder.
     - folder
-        - model_tp{tp_rank}_pp{pp_rank}.pt
+        - model_tp{tp_rank}_pp{pp_rank}.pt\
+
+    If fsdp is activated, the saved weight is named:
+    - folder
+        - model_tp{tp_rank}_pp{pp_rank}_zo{zo_rank}
 
     If the tp is inconsistent with the saved one in the future use, the weight needs to be converted before loading.
     """
 
     tp_size = gpc.get_world_size(ParallelMode.TENSOR)
     pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
+    zo_size = gpc.get_world_size(ParallelMode.ZERO1)
     tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
     pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
+    zo_rank = gpc.get_local_rank(ParallelMode.ZERO1)
 
     fns = get_fns(folder)
-    max_pp, max_tp = 0, 0
+
+    # avoid ckpt misuse between FSDP and no-FSDP
+    test_fn = list([f for f in fns if f.startswith("model_t") and not f.endswith(".md5")]).pop()
+    assert (
+        "_zo" in test_fn and gpc.config.parallel.use_fsdp or
+        "_zo" not in test_fn and not gpc.config.parallel.use_fsdp
+    ), f"FSDP model wants to load no-FSDP ckpts or reverse"
+
+    max_pp, max_tp, max_zo = 0, 0, 0
     for fn in fns:
         if fn.startswith("model_t") and not fn.endswith(".md5"):
             segements = os.path.splitext(fn)[0].split("_")
-            max_pp = max(max_pp, int(segements[-1][2:]))
-            max_tp = max(max_tp, int(segements[-2][2:]))
+            if gpc.config.parallel.use_fsdp:
+                max_zo = max(max_zo, int(segements[-1][2:]))
+                max_pp = max(max_pp, int(segements[-2][2:]))
+                max_tp = max(max_tp, int(segements[-3][2:]))
+            else:
+                max_pp = max(max_pp, int(segements[-1][2:]))
+                max_tp = max(max_tp, int(segements[-2][2:]))
 
+    assert (
+        zo_size == max_zo + 1
+    ), f"The weights are save for {max_zo+1} FSDP shards , while current has {zo_size} FSDP shards"
     assert (
         pp_size == max_pp + 1
     ), f"The weights are save for {max_pp+1} pipelines, while current has {pp_size} pipelines"
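
The rank indices are recovered purely from the filename segments; for example (the sample name is hypothetical):

import os

fn = "model_tp1_pp0_zo3.pt"                 # FSDP-style checkpoint name
parts = os.path.splitext(fn)[0].split("_")  # ['model', 'tp1', 'pp0', 'zo3']
zo = int(parts[-1][2:])                     # 3
pp = int(parts[-2][2:])                     # 0
tp = int(parts[-3][2:])                     # 1
print(tp, pp, zo)                           # 1 0 3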
@@ -308,11 +356,17 @@ def load_model_checkpoint(folder, model):
         tp_size == max_tp + 1
     ), f"The weights are save for {max_tp+1} parallelism, while current has {tp_size} tensor parallelism"
 
-    should_load_name = f"model_tp{tp_rank}_pp{pp_rank}.pt"
+    if gpc.config.parallel.use_fsdp:
+        should_load_name = f"model_tp{tp_rank}_pp{pp_rank}_zo{zo_rank}.pt"
+    else:
+        should_load_name = f"model_tp{tp_rank}_pp{pp_rank}.pt"
     fp = os.path.join(folder, should_load_name)
     states = llm_load(fp, map_location=get_current_device())
 
-    missing_k, unexpected_keys = model.load_state_dict(states, strict=False)
+    if gpc.config.parallel.use_fsdp:
+        missing_k, unexpected_keys = load_shard_state_dict(model, states, strict=False)
+    else:
+        missing_k, unexpected_keys = model.load_state_dict(states, strict=False)
     if len(missing_k) != 0:
         logger.warning(f"Warning: missing keys {missing_k}")
     if len(unexpected_keys) != 0:

train.py
@@ -111,7 +111,7 @@ def main(args):
     # initialize and resume train state
     train_state = TrainState(gpc.config, train_dl.batch_sampler)
 
-    # if fsdp enabled, warp the model
+    # if fsdp enabled, wrap the model
     model = wrap_FSDP_model(model)
 
     optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)
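
wrap_FSDP_model is InternLM's own helper and its body is not part of this diff; a minimal sketch of the general shape such a helper can take (the body below is an assumption, not the project's implementation):

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from internlm.core.context import global_context as gpc  # InternLM's global context

def wrap_FSDP_model(model):
    # Pass the model through unchanged unless FSDP is enabled in the
    # parallel config; the real wrapping policy may differ.
    if gpc.config.parallel.use_fsdp:
        model = FSDP(model)
    return model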