mirror of https://github.com/InternLM/InternLM
fix load ckpt bug
parent 37dbe6398b
commit 5b62a3957a

@@ -66,11 +66,12 @@ def args_sanity_check():
         pp = gpc.config.parallel.pipeline
     else:
         pp = gpc.config.parallel.pipeline.size
+    tp = gpc.config.parallel.tensor

     if "use_fsdp" not in gpc.config.parallel:
         gpc.config.parallel._add_item("use_fsdp", False)
-    elif gpc.config.parallel.use_fsdp and pp > 1:
-        logger.warning("FSDP not support when pipeline parallel is enabled, auto-close FSDP")
+    elif gpc.config.parallel.use_fsdp and (pp > 1 or tp > 1):
+        logger.warning("FSDP not supported when pipeline/tensor parallel is enabled, auto-close FSDP")
         gpc.config.parallel._add_item("use_fsdp", False)

     # processing the data config in gpc
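The guard above can be read in isolation. A minimal sketch, assuming `parallel_cfg` is a plain dict standing in for `gpc.config.parallel`; the helper name `check_fsdp_compat` is illustrative and not part of the commit:

    # Sketch only: mirrors the sanity check above with a plain dict instead of gpc.config.
    def check_fsdp_compat(parallel_cfg: dict) -> dict:
        pipeline = parallel_cfg["pipeline"]
        pp = pipeline if isinstance(pipeline, int) else pipeline["size"]
        tp = parallel_cfg["tensor"]
        if "use_fsdp" not in parallel_cfg:
            parallel_cfg["use_fsdp"] = False
        elif parallel_cfg["use_fsdp"] and (pp > 1 or tp > 1):
            # FSDP is auto-disabled whenever pipeline or tensor parallelism is active
            parallel_cfg["use_fsdp"] = False
        return parallel_cfg

    assert check_fsdp_compat({"pipeline": 1, "tensor": 2, "use_fsdp": True})["use_fsdp"] is False
    assert check_fsdp_compat({"pipeline": 1, "tensor": 1, "use_fsdp": True})["use_fsdp"] is True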
@@ -11,7 +11,7 @@ from enum import Enum
 from typing import Callable, Dict, Union

 import torch
-from torch.distributed.fsdp import FullStateDictConfig
+from torch.distributed.fsdp import FullStateDictConfig, LocalStateDictConfig
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp import StateDictType

@@ -157,7 +157,7 @@ def get_model_topology(model):
     return topos


-def get_state_dict(model):
+def get_shard_state_dict(shard_model):
     """
     Only used for FSDP module saving.
     It's a wrapper of model.state_dict() and with the context of 'FSDP.state_dict_type', the sharded parameter
@@ -167,18 +167,33 @@ def get_state_dict(model):
     """

     # TODO: rank0_only can save memory for non-rank0 gpu, but when tp is enabled, model saving will left some parameters
-    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
-    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
-        states = model.state_dict()
+    # save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
+    # with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
+    #     states = model.state_dict()

-    return states
+    # in this version, FSDP model can only save with sharded shape
+    with FSDP.state_dict_type(shard_model, StateDictType.LOCAL_STATE_DICT):
+        shard_states = shard_model.state_dict()
+
+    return shard_states
+
+
+def load_shard_state_dict(shard_model, shard_state, **kwargs):
+    """
+    Only used for FSDP module loading.
+    """
+
+    with FSDP.state_dict_type(shard_model, StateDictType.LOCAL_STATE_DICT):
+        missing_k, unexpected_keys = shard_model.load_state_dict(shard_state, **kwargs)
+
+    return (missing_k, unexpected_keys)


 def try_load_internlm_ckpt(ckpt_mm, load_info, train_state: TrainState):
     load_content_str = ""
     load_ckpt_folder = load_info["path"]
     load_content: CheckpointLoadMask = load_info["content"]
+    print(load_ckpt_folder, load_content)
     if gpc.is_rank_for_log():
         logger.info(f"Try load_ckpt_folder: {load_ckpt_folder}")
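The two new helpers boil down to wrapping state_dict()/load_state_dict() in FSDP's LOCAL_STATE_DICT context, so each rank handles only its own flattened shard. A hedged sketch of that round trip, assuming torch.distributed is already initialized and `fsdp_model` is an FSDP-wrapped module; the names `save_shard`/`load_shard` and the use of torch.save/torch.load instead of llm_save/llm_load are illustrative:

    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import StateDictType

    def save_shard(fsdp_model: FSDP, path: str) -> None:
        # Each rank serializes only its local shard; the resulting files are tied
        # to the world size / sharding layout used at save time.
        with FSDP.state_dict_type(fsdp_model, StateDictType.LOCAL_STATE_DICT):
            torch.save(fsdp_model.state_dict(), path)

    def load_shard(fsdp_model: FSDP, path: str):
        # Load back the shard belonging to this rank under the same context.
        shard_state = torch.load(path, map_location="cpu")
        with FSDP.state_dict_type(fsdp_model, StateDictType.LOCAL_STATE_DICT):
            return fsdp_model.load_state_dict(shard_state, strict=False)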
@@ -241,6 +256,10 @@ def save_model_checkpoint(folder, model):
     - folder
         - model_tp{tp_rank}_pp{pp_rank}.pt

+    If fsdp is activated, the saved weight is named:
+    - folder
+        - model_tp{tp_rank}_pp{pp_rank}_zo{zo_rank}.pt
+
     If the tp is inconsistent with the saved one in the future use, the weight needs to be converted before loading.

     Args:
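For concreteness, a tiny sketch of the file names this scheme produces, using made-up sizes tp=2, pp=1 and 4 FSDP/ZeRO shards:

    # Illustrative only: enumerate the shard names produced by the naming scheme above.
    tp_size, pp_size, zo_size = 2, 1, 4
    names = [
        f"model_tp{tp}_pp{pp}_zo{zo}.pt"
        for tp in range(tp_size)
        for pp in range(pp_size)
        for zo in range(zo_size)
    ]
    print(names[0], names[-1])  # model_tp0_pp0_zo0.pt model_tp1_pp0_zo3.pt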
@@ -249,7 +268,7 @@ def save_model_checkpoint(folder, model):
     """

     if gpc.config.parallel.use_fsdp:
-        states = get_state_dict(model)
+        states = get_shard_state_dict(model)
     else:
         states = model.state_dict()

@@ -258,6 +277,7 @@ def save_model_checkpoint(folder, model):
     if folder is not None:
         dp_size = gpc.get_world_size(ParallelMode.DATA)
         tp_size = gpc.get_world_size(ParallelMode.TENSOR)
+        zo_size = gpc.get_world_size(ParallelMode.ZERO1)
         dp_rank = gpc.get_local_rank(ParallelMode.DATA)
         tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
         pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
@@ -266,15 +286,21 @@ def save_model_checkpoint(folder, model):
         # even if pp is not considered, it will definitely not be written on the same machine.
         should_save_rank_pair = set()  # (tp_rank, dp_rank)
         for i in range(tp_size):
-            should_save_rank_pair.add((i, i % dp_size))
+            if gpc.config.parallel.use_fsdp:
+                for j in range(zo_size):
+                    should_save_rank_pair.add((i, j))
+            else:
+                should_save_rank_pair.add((i, i % dp_size))

         if (tp_rank, dp_rank) in should_save_rank_pair:
-            fn = f"model_tp{tp_rank}_pp{pp_rank}.pt"
-            fp = os.path.join(folder, fn)
-            llm_save(fp, saved_obj=states)
-            topo_fn = f"topo_tp{tp_rank}_pp{pp_rank}.json"
-            topo_fp = os.path.join(folder, topo_fn)
-            llm_save(topo_fp, saved_obj=topo)
+            f_zo = f"_zo{dp_rank}" if gpc.config.parallel.use_fsdp else ""
+            fn = f"model_tp{tp_rank}_pp{pp_rank}{f_zo}.pt"
+            fp = os.path.join(folder, fn)
+            llm_save(fp, saved_obj=states)
+            if not gpc.config.parallel.use_fsdp or dp_rank == tp_rank % dp_size:
+                topo_fn = f"topo_tp{tp_rank}_pp{pp_rank}.json"
+                topo_fp = os.path.join(folder, topo_fn)
+                llm_save(topo_fp, saved_obj=topo)

         torch.distributed.barrier()

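A standalone sketch of the writer selection above: which (tp_rank, dp_rank) pairs end up saving a model file, with and without FSDP. The sizes are hypothetical, and it assumes, as the code above does, that under FSDP the ZeRO shard count matches the data-parallel size:

    def writers(tp_size: int, dp_size: int, zo_size: int, use_fsdp: bool) -> set:
        pairs = set()  # (tp_rank, dp_rank)
        for i in range(tp_size):
            if use_fsdp:
                # every shard owner holds different parameters, so all of them write
                for j in range(zo_size):
                    pairs.add((i, j))
            else:
                # replicas hold identical weights, so one data-parallel rank per tp rank suffices
                pairs.add((i, i % dp_size))
        return pairs

    assert sorted(writers(2, 4, 4, use_fsdp=False)) == [(0, 0), (1, 1)]
    assert len(writers(2, 4, 4, use_fsdp=True)) == 2 * 4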
@@ -283,24 +309,46 @@ def load_model_checkpoint(folder, model):
     """
     There should be weights with names similar to the following under the folder.
     - folder
         - model_tp{tp_rank}_pp{pp_rank}.pt

+    If fsdp is activated, the saved weight is named:
+    - folder
+        - model_tp{tp_rank}_pp{pp_rank}_zo{zo_rank}.pt
+
     If the tp is inconsistent with the saved one in the future use, the weight needs to be converted before loading.
     """

     tp_size = gpc.get_world_size(ParallelMode.TENSOR)
     pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
+    zo_size = gpc.get_world_size(ParallelMode.ZERO1)
     tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
     pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
+    zo_rank = gpc.get_local_rank(ParallelMode.ZERO1)

     fns = get_fns(folder)
-    max_pp, max_tp = 0, 0
+
+    # avoid ckpt misuse between FSDP and no-FSDP
+    test_fn = list([f for f in fns if f.startswith("model_t") and not f.endswith(".md5")]).pop()
+    assert (
+        "_zo" in test_fn and gpc.config.parallel.use_fsdp or
+        "_zo" not in test_fn and not gpc.config.parallel.use_fsdp
+    ), "FSDP model tries to load non-FSDP ckpts, or the reverse"
+
+    max_pp, max_tp, max_zo = 0, 0, 0
     for fn in fns:
         if fn.startswith("model_t") and not fn.endswith(".md5"):
             segements = os.path.splitext(fn)[0].split("_")
-            max_pp = max(max_pp, int(segements[-1][2:]))
-            max_tp = max(max_tp, int(segements[-2][2:]))
+            if gpc.config.parallel.use_fsdp:
+                max_zo = max(max_zo, int(segements[-1][2:]))
+                max_pp = max(max_pp, int(segements[-2][2:]))
+                max_tp = max(max_tp, int(segements[-3][2:]))
+            else:
+                max_pp = max(max_pp, int(segements[-1][2:]))
+                max_tp = max(max_tp, int(segements[-2][2:]))

+    assert (
+        zo_size == max_zo + 1
+    ), f"The weights are saved for {max_zo+1} FSDP shards, while current has {zo_size} FSDP shards"
     assert (
         pp_size == max_pp + 1
     ), f"The weights are save for {max_pp+1} pipelines, while current has {pp_size} pipelines"
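The filename parsing above can be checked in isolation. A minimal sketch with hypothetical file names; `saved_extents` is an illustrative helper, not part of the commit:

    import os

    def saved_extents(fns, use_fsdp: bool):
        # Recover the tp/pp (and, under FSDP, zo) extents encoded in the checkpoint names.
        max_pp, max_tp, max_zo = 0, 0, 0
        for fn in fns:
            if fn.startswith("model_t") and not fn.endswith(".md5"):
                segments = os.path.splitext(fn)[0].split("_")
                if use_fsdp:
                    max_zo = max(max_zo, int(segments[-1][2:]))
                    max_pp = max(max_pp, int(segments[-2][2:]))
                    max_tp = max(max_tp, int(segments[-3][2:]))
                else:
                    max_pp = max(max_pp, int(segments[-1][2:]))
                    max_tp = max(max_tp, int(segments[-2][2:]))
        return max_tp + 1, max_pp + 1, max_zo + 1

    # e.g. two tensor ranks, one pipeline stage, four FSDP shards:
    fns = [f"model_tp{t}_pp0_zo{z}.pt" for t in range(2) for z in range(4)]
    assert saved_extents(fns, use_fsdp=True) == (2, 1, 4)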
@@ -308,11 +356,17 @@ def load_model_checkpoint(folder, model):
         tp_size == max_tp + 1
     ), f"The weights are save for {max_tp+1} parallelism, while current has {tp_size} tensor parallelism"

-    should_load_name = f"model_tp{tp_rank}_pp{pp_rank}.pt"
+    if gpc.config.parallel.use_fsdp:
+        should_load_name = f"model_tp{tp_rank}_pp{pp_rank}_zo{zo_rank}.pt"
+    else:
+        should_load_name = f"model_tp{tp_rank}_pp{pp_rank}.pt"
     fp = os.path.join(folder, should_load_name)
     states = llm_load(fp, map_location=get_current_device())

-    missing_k, unexpected_keys = model.load_state_dict(states, strict=False)
+    if gpc.config.parallel.use_fsdp:
+        missing_k, unexpected_keys = load_shard_state_dict(model, states, strict=False)
+    else:
+        missing_k, unexpected_keys = model.load_state_dict(states, strict=False)
     if len(missing_k) != 0:
         logger.warning(f"Warning: missing keys {missing_k}")
     if len(unexpected_keys) != 0:
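The per-rank file selection above reduces to a small pure function. A sketch with hypothetical rank values; `shard_file` is an illustrative name:

    def shard_file(tp_rank: int, pp_rank: int, zo_rank: int, use_fsdp: bool) -> str:
        # With FSDP, each ZeRO rank reloads exactly the shard it wrote.
        if use_fsdp:
            return f"model_tp{tp_rank}_pp{pp_rank}_zo{zo_rank}.pt"
        return f"model_tp{tp_rank}_pp{pp_rank}.pt"

    assert shard_file(1, 0, 3, use_fsdp=True) == "model_tp1_pp0_zo3.pt"
    assert shard_file(1, 0, 3, use_fsdp=False) == "model_tp1_pp0.pt"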
train.py
@@ -111,7 +111,7 @@ def main(args):
     # initialize and resume train state
     train_state = TrainState(gpc.config, train_dl.batch_sampler)

-    # if fsdp enabled, warp the model
+    # if fsdp enabled, wrap the model
     model = wrap_FSDP_model(model)

     optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)
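wrap_FSDP_model is called unconditionally here and is expected to be a no-op when use_fsdp is off. A hedged sketch of what such a wrapper typically looks like; the real InternLM helper may differ, and constructing FSDP assumes an already-initialized process group:

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    def wrap_fsdp_if_enabled(model, use_fsdp: bool):
        # Illustrative stand-in for wrap_FSDP_model: leave the model untouched unless
        # FSDP is enabled, otherwise shard its parameters across the default group.
        if not use_fsdp:
            return model
        return FSDP(model)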