From 96171d5f2816a9e8166fe2a0e4e1a2394881aff7 Mon Sep 17 00:00:00 2001
From: zaglc
Date: Tue, 26 Sep 2023 17:36:59 +0800
Subject: [PATCH] fix bug for loading ckpts when zero1 < dp_size

---
 internlm/core/context/parallel_context.py | 10 ++++++++++
 internlm/utils/model_checkpoint.py        | 20 +++++++++++---------
 2 files changed, 21 insertions(+), 9 deletions(-)

diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py
index 3100236..16c84a8 100644
--- a/internlm/core/context/parallel_context.py
+++ b/internlm/core/context/parallel_context.py
@@ -423,6 +423,16 @@ class ParallelContext(metaclass=SingletonMeta):
         assert self.zero1_parallel_size > 0
         assert self.data_parallel_size % self.zero1_parallel_size == 0
 
+        # check for fsdp:
+        # if zo_size < dp_size, saving ckpts will introduce redundant storage for model weights,
+        # because pytorch "ShardedTensor" requires the current global rank to equal the saved shard's global rank
+        # pytorch version: 1.13.1+cu117
+        if self.data_parallel_size > self.zero1_parallel_size:
+            logger.warning(
+                f"zo size: {self.zero1_parallel_size} < dp size: {self.data_parallel_size}, "
+                "which will introduce redundancy when saving ckpts; recommend setting them to the same value"
+            )
+
     def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str):
         if key in config:
             ele = config[key]
diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py
index 095a61b..272fdae 100644
--- a/internlm/utils/model_checkpoint.py
+++ b/internlm/utils/model_checkpoint.py
@@ -11,6 +11,7 @@ from enum import Enum
 from typing import Callable, Dict, Union
 
 import torch
+from torch.distributed._shard.api import load_with_process_group
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp import StateDictType
 
@@ -193,7 +194,6 @@ def try_load_internlm_ckpt(ckpt_mm, load_info, train_state: TrainState):
     load_content_str = ""
     load_ckpt_folder = load_info["path"]
     load_content: CheckpointLoadMask = load_info["content"]
-    print(load_ckpt_folder, load_content)
 
     if gpc.is_rank_for_log():
         logger.info(f"Try load_ckpt_folder: {load_ckpt_folder}")
@@ -277,7 +277,6 @@ def save_model_checkpoint(folder, model):
     if folder is not None:
         dp_size = gpc.get_world_size(ParallelMode.DATA)
         tp_size = gpc.get_world_size(ParallelMode.TENSOR)
-        zo_size = gpc.get_world_size(ParallelMode.ZERO1)
         dp_rank = gpc.get_local_rank(ParallelMode.DATA)
         tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
         pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
@@ -287,7 +286,7 @@ def save_model_checkpoint(folder, model):
         should_save_rank_pair = set()  # (tp_rank, dp_rank)
         for i in range(tp_size):
             if gpc.config.parallel.use_fsdp:
-                for j in range(zo_size):
+                for j in range(dp_size):
                     should_save_rank_pair.add((i, j))
             else:
                 should_save_rank_pair.add((i, i % dp_size))
@@ -320,10 +319,10 @@ def load_model_checkpoint(folder, model):
     tp_size = gpc.get_world_size(ParallelMode.TENSOR)
     pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
-    zo_size = gpc.get_world_size(ParallelMode.ZERO1)
+    dp_size = gpc.get_world_size(ParallelMode.DATA)
     tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
     pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
-    zo_rank = gpc.get_local_rank(ParallelMode.ZERO1)
+    dp_rank = gpc.get_local_rank(ParallelMode.DATA)
 
     fns = get_fns(folder)
@@ -353,15 +352,18 @@ def load_model_checkpoint(folder, model):
         ), f"The weights are save for {max_tp+1} parallelism, while current has {tp_size} tensor parallelism"
     if gpc.config.parallel.use_fsdp:
         assert (
-            zo_size == max_zo + 1
-        ), f"The weights are save for {max_zo+1} FSDP shards , while current has {zo_size} FSDP shards"
+            dp_size == max_zo + 1
+        ), f"The weights are saved for {max_zo+1} FSDP shards, while the current run has {dp_size} FSDP shards"
 
     if gpc.config.parallel.use_fsdp:
-        should_load_name = f"model_tp{tp_rank}_pp{pp_rank}_zo{zo_rank}.pt"
+        should_load_name = f"model_tp{tp_rank}_pp{pp_rank}_zo{dp_rank}.pt"
     else:
         should_load_name = f"model_tp{tp_rank}_pp{pp_rank}.pt"
     fp = os.path.join(folder, should_load_name)
-    states = llm_load(fp, map_location=get_current_device())
+
+    # for loading FSDP shards, we need to set the process group
+    with load_with_process_group(gpc.get_group(ParallelMode.ZERO1)):
+        states = llm_load(fp, map_location=get_current_device())
 
     if gpc.config.parallel.use_fsdp:
         missing_k, unexpected_keys = load_shard_state_dict(model, states, strict=False)
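
Reviewer note (not part of the patch): the load-side change hinges on how PyTorch
deserializes ShardedTensor objects. The sketch below shows the pattern the patch
relies on, assuming PyTorch 1.13.1; the helper name load_sharded_ckpt and the way
the process group is obtained are illustrative only, not code from this repo.

    # Minimal sketch, assuming torch 1.13.1, where load_with_process_group is a
    # context manager provided by torch.distributed._shard.api.
    import torch
    import torch.distributed as dist
    from torch.distributed._shard.api import load_with_process_group

    def load_sharded_ckpt(path: str, shard_group: dist.ProcessGroup):
        # torch.load re-binds each ShardedTensor shard to a rank during
        # deserialization. Without the context manager, the default (global)
        # process group is used, and a shard whose saved global rank differs
        # from the current global rank cannot be restored.
        with load_with_process_group(shard_group):
            return torch.load(path, map_location="cpu")

In the patch the group passed in is gpc.get_group(ParallelMode.ZERO1), so shard
ranks are resolved within the ZeRO-1 group rather than against global ranks.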