mirror of https://github.com/InternLM/InternLM
fix bug for loading ckpts when zero1 < dp_size
parent 056996f8b3
commit 96171d5f28
@@ -423,6 +423,16 @@ class ParallelContext(metaclass=SingletonMeta):
         assert self.zero1_parallel_size > 0
+        assert self.data_parallel_size % self.zero1_parallel_size == 0
+
+        # check for fsdp:
+        # if zo_size < dp_size, checkpoint saving will introduce redundant storage for model weights,
+        # because pytorch "ShardedTensor" needs the current global rank to equal the saved shard's global rank
+        # pytorch version: 1.13.1+cu117
+        if self.data_parallel_size > self.zero1_parallel_size:
+            logger.warning(
+                f"zo size: {self.zero1_parallel_size} < dp size: {self.data_parallel_size}, \
+                will introduce redundancy when saving ckpts, recommend setting them to the same value"
+            )

     def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str):
         if key in config:
             ele = config[key]
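The new check can be read in isolation: the ZeRO-1 size must divide the data-parallel size, and whenever it is strictly smaller the warning fires. A minimal, standalone sketch of the same logic with plain integers (the function name and the example sizes are made up for illustration):

    # Standalone sketch of the config-time check above; names and sizes are illustrative only.
    import logging

    logger = logging.getLogger("parallel_config")

    def check_zero1_vs_dp(data_parallel_size: int, zero1_parallel_size: int) -> None:
        assert zero1_parallel_size > 0
        assert data_parallel_size % zero1_parallel_size == 0
        # With FSDP, a ZeRO-1 group smaller than the dp group means extra dp ranks
        # re-save the same model shards, hence the redundancy warning.
        if data_parallel_size > zero1_parallel_size:
            logger.warning(
                "zo size: %d < dp size: %d, will introduce redundancy when saving ckpts, "
                "recommend setting them to the same value",
                zero1_parallel_size,
                data_parallel_size,
            )

    check_zero1_vs_dp(data_parallel_size=8, zero1_parallel_size=4)  # warns
    check_zero1_vs_dp(data_parallel_size=8, zero1_parallel_size=8)  # silent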
@@ -11,6 +11,7 @@ from enum import Enum
 from typing import Callable, Dict, Union

 import torch
+from torch.distributed._shard.api import load_with_process_group
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp import StateDictType
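The newly imported load_with_process_group is a context manager from torch.distributed._shard.api: inside the with-block, ShardedTensor objects deserialized by torch.load are attached to the given process group rather than the default global one. A rough usage sketch, assuming distributed is already initialized (the group argument and path are placeholders, not taken from this diff):

    # Rough sketch of the load_with_process_group pattern (PyTorch 1.13-era API);
    # `shard_group` and the checkpoint path are placeholders.
    import torch
    from torch.distributed._shard.api import load_with_process_group

    def load_sharded_states(path: str, shard_group):
        # ShardedTensors restored inside this context are bound to `shard_group`
        # instead of the default (global) process group.
        with load_with_process_group(shard_group):
            return torch.load(path, map_location="cpu")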
@@ -193,7 +194,6 @@ def try_load_internlm_ckpt(ckpt_mm, load_info, train_state: TrainState):
     load_content_str = ""
     load_ckpt_folder = load_info["path"]
     load_content: CheckpointLoadMask = load_info["content"]
-    print(load_ckpt_folder, load_content)
     if gpc.is_rank_for_log():
         logger.info(f"Try load_ckpt_folder: {load_ckpt_folder}")
@@ -277,7 +277,6 @@ def save_model_checkpoint(folder, model):
     if folder is not None:
         dp_size = gpc.get_world_size(ParallelMode.DATA)
         tp_size = gpc.get_world_size(ParallelMode.TENSOR)
-        zo_size = gpc.get_world_size(ParallelMode.ZERO1)
         dp_rank = gpc.get_local_rank(ParallelMode.DATA)
         tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
         pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
@@ -287,7 +286,7 @@ def save_model_checkpoint(folder, model):
         should_save_rank_pair = set()  # (tp_rank, dp_rank)
         for i in range(tp_size):
             if gpc.config.parallel.use_fsdp:
-                for j in range(zo_size):
+                for j in range(dp_size):
                     should_save_rank_pair.add((i, j))
             else:
                 should_save_rank_pair.add((i, i % dp_size))
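Switching the inner loop from zo_size to dp_size changes which ranks act as writers: under FSDP every (tp_rank, dp_rank) pair now saves its own shard, while the non-FSDP branch still keeps one writer per tensor-parallel rank. A standalone sketch of the selection logic with hypothetical sizes:

    # Standalone sketch of the writer-selection logic above; sizes are hypothetical.
    def writers(tp_size: int, dp_size: int, use_fsdp: bool) -> set:
        should_save_rank_pair = set()  # (tp_rank, dp_rank)
        for i in range(tp_size):
            if use_fsdp:
                # each dp rank holds a distinct FSDP shard, so every pair must write
                for j in range(dp_size):
                    should_save_rank_pair.add((i, j))
            else:
                # unsharded weights: one dp rank per tp rank is enough
                should_save_rank_pair.add((i, i % dp_size))
        return should_save_rank_pair

    print(sorted(writers(tp_size=2, dp_size=4, use_fsdp=True)))   # 8 writer pairs
    print(sorted(writers(tp_size=2, dp_size=4, use_fsdp=False)))  # 2 writer pairs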
@@ -320,10 +319,10 @@ def load_model_checkpoint(folder, model):

     tp_size = gpc.get_world_size(ParallelMode.TENSOR)
     pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
-    zo_size = gpc.get_world_size(ParallelMode.ZERO1)
+    dp_size = gpc.get_world_size(ParallelMode.DATA)
     tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
     pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
-    zo_rank = gpc.get_local_rank(ParallelMode.ZERO1)
+    dp_rank = gpc.get_local_rank(ParallelMode.DATA)

     fns = get_fns(folder)
@@ -353,14 +352,17 @@ def load_model_checkpoint(folder, model):
     ), f"The weights are saved for {max_tp+1} parallelism, while current has {tp_size} tensor parallelism"
     if gpc.config.parallel.use_fsdp:
         assert (
-            zo_size == max_zo + 1
-        ), f"The weights are saved for {max_zo+1} FSDP shards, while current has {zo_size} FSDP shards"
+            dp_size == max_zo + 1
+        ), f"The weights are saved for {max_zo+1} FSDP shards, while current has {dp_size} FSDP shards"

     if gpc.config.parallel.use_fsdp:
-        should_load_name = f"model_tp{tp_rank}_pp{pp_rank}_zo{zo_rank}.pt"
+        should_load_name = f"model_tp{tp_rank}_pp{pp_rank}_zo{dp_rank}.pt"
     else:
         should_load_name = f"model_tp{tp_rank}_pp{pp_rank}.pt"
     fp = os.path.join(folder, should_load_name)

-    states = llm_load(fp, map_location=get_current_device())
+    # for FSDP shards loading, we need to set the process group
+    with load_with_process_group(gpc.get_group(ParallelMode.ZERO1)):
+        states = llm_load(fp, map_location=get_current_device())

     if gpc.config.parallel.use_fsdp:
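The hunk is cut off right after the final use_fsdp check, so what follows is not shown; presumably the loaded states are handed back to the FSDP-wrapped module as a per-rank (local) state dict. The snippet below is only a sketch of the usual PyTorch 1.13 pattern with StateDictType.LOCAL_STATE_DICT, under the assumption that `model` is the FSDP-wrapped module and `states` is the dict loaded above; it is not the repository's actual code.

    # Sketch only: applying per-rank sharded states to an FSDP-wrapped model
    # (PyTorch 1.13-era API). `model` and `states` are assumed from the surrounding code.
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import StateDictType

    def apply_local_fsdp_states(model: FSDP, states: dict) -> None:
        # LOCAL_STATE_DICT: each rank loads exactly the shard it wrote, which is why
        # the shard file name above is keyed by the rank that saved it.
        with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
            model.load_state_dict(states)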