fix bug for loading ckpts when zero1 < dp_size

pull/293/head
zaglc 2023-09-26 17:36:59 +08:00
parent 056996f8b3
commit 96171d5f28
2 changed files with 21 additions and 9 deletions


@ -423,6 +423,16 @@ class ParallelContext(metaclass=SingletonMeta):
assert self.zero1_parallel_size > 0
assert self.data_parallel_size % self.zero1_parallel_size == 0
# check for fsdp:
# if zo_size < dp_size, ckpt saving will introduce redundant storage for model weights,
# because PyTorch "ShardedTensor" requires the current global rank to equal the saved shard's global rank
# (PyTorch version: 1.13.1+cu117)
if self.data_parallel_size > self.zero1_parallel_size:
logger.warning(
f"zo size: {self.zero1_parallel_size} < dp size: {self.data_parallel_size}, "
"will introduce redundancy when saving ckpts; recommend setting them to the same value"
)

def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str):
if key in config:
ele = config[key]
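
For illustration, here is a minimal standalone sketch of the check added above. The sizes are made-up assumptions, and a plain print stands in for the ParallelContext attributes and logger:

    # Illustrative sizes only; in the real code these come from the parallel config.
    data_parallel_size = 8
    zero1_parallel_size = 4  # zo size smaller than dp size

    assert zero1_parallel_size > 0
    assert data_parallel_size % zero1_parallel_size == 0

    if data_parallel_size > zero1_parallel_size:
        # Under FSDP, PyTorch's ShardedTensor ties each shard to the global rank that
        # saved it, so every dp rank writes its own shard and weights are duplicated
        # on disk whenever zo size < dp size.
        print(
            f"zo size: {zero1_parallel_size} < dp size: {data_parallel_size}, "
            "will introduce redundancy when saving ckpts; recommend setting them to the same value"
        )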


@ -11,6 +11,7 @@ from enum import Enum
from typing import Callable, Dict, Union
import torch
from torch.distributed._shard.api import load_with_process_group
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
@ -193,7 +194,6 @@ def try_load_internlm_ckpt(ckpt_mm, load_info, train_state: TrainState):
load_content_str = ""
load_ckpt_folder = load_info["path"]
load_content: CheckpointLoadMask = load_info["content"]
print(load_ckpt_folder, load_content)
if gpc.is_rank_for_log():
logger.info(f"Try load_ckpt_folder: {load_ckpt_folder}")
@ -277,7 +277,6 @@ def save_model_checkpoint(folder, model):
if folder is not None:
dp_size = gpc.get_world_size(ParallelMode.DATA)
tp_size = gpc.get_world_size(ParallelMode.TENSOR)
zo_size = gpc.get_world_size(ParallelMode.ZERO1)
dp_rank = gpc.get_local_rank(ParallelMode.DATA)
tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
@ -287,7 +286,7 @@ def save_model_checkpoint(folder, model):
should_save_rank_pair = set() # (tp_rank, dp_rank)
for i in range(tp_size):
if gpc.config.parallel.use_fsdp:
for j in range(zo_size):
for j in range(dp_size):
should_save_rank_pair.add((i, j))
else:
should_save_rank_pair.add((i, i % dp_size))
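
A small self-contained sketch of the resulting writer set; tp_size=2 and dp_size=4 are illustrative assumptions, not values taken from the config:

    # Which (tp_rank, dp_rank) pairs actually write model shards.
    tp_size, dp_size = 2, 4
    use_fsdp = True

    should_save_rank_pair = set()  # (tp_rank, dp_rank)
    for i in range(tp_size):
        if use_fsdp:
            # every dp rank owns a distinct FSDP shard, so all pairs must save
            for j in range(dp_size):
                should_save_rank_pair.add((i, j))
        else:
            # without FSDP the weights are replicated across dp, one writer per tp rank suffices
            should_save_rank_pair.add((i, i % dp_size))

    print(sorted(should_save_rank_pair))
    # FSDP:     [(0, 0), (0, 1), (0, 2), (0, 3), (1, 0), (1, 1), (1, 2), (1, 3)]
    # non-FSDP: [(0, 0), (1, 1)]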
@ -320,10 +319,10 @@ def load_model_checkpoint(folder, model):
tp_size = gpc.get_world_size(ParallelMode.TENSOR)
pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
zo_size = gpc.get_world_size(ParallelMode.ZERO1)
dp_size = gpc.get_world_size(ParallelMode.DATA)
tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
zo_rank = gpc.get_local_rank(ParallelMode.ZERO1)
dp_rank = gpc.get_local_rank(ParallelMode.DATA)
fns = get_fns(folder)
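
The next hunk checks the saved shard counts against the current parallel sizes. A rough sketch of that bookkeeping is below; the regex and file list are illustrative assumptions, and the repository's actual filename parsing may differ:

    import re

    # Made-up listing for a run saved with tp_size=2, pp_size=1 and 2 FSDP shards.
    fns = ["model_tp0_pp0_zo0.pt", "model_tp0_pp0_zo1.pt",
           "model_tp1_pp0_zo0.pt", "model_tp1_pp0_zo1.pt"]

    max_tp, max_zo = 0, 0
    for fn in fns:
        m = re.match(r"model_tp(\d+)_pp(\d+)_zo(\d+)\.pt", fn)
        if m:
            max_tp = max(max_tp, int(m.group(1)))
            max_zo = max(max_zo, int(m.group(3)))

    # After the fix, the FSDP shard count is compared against dp_size rather than zo_size.
    dp_size = 2
    assert dp_size == max_zo + 1, (
        f"The weights are saved for {max_zo + 1} FSDP shards, "
        f"while the current run has {dp_size} FSDP shards"
    )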
@ -353,14 +352,17 @@ def load_model_checkpoint(folder, model):
), f"The weights are save for {max_tp+1} parallelism, while current has {tp_size} tensor parallelism"
if gpc.config.parallel.use_fsdp:
assert (
zo_size == max_zo + 1
), f"The weights are saved for {max_zo+1} FSDP shards, while the current run has {zo_size} FSDP shards"
dp_size == max_zo + 1
), f"The weights are saved for {max_zo+1} FSDP shards, while the current run has {dp_size} FSDP shards"
if gpc.config.parallel.use_fsdp:
should_load_name = f"model_tp{tp_rank}_pp{pp_rank}_zo{zo_rank}.pt"
should_load_name = f"model_tp{tp_rank}_pp{pp_rank}_zo{dp_rank}.pt"
else:
should_load_name = f"model_tp{tp_rank}_pp{pp_rank}.pt"
fp = os.path.join(folder, should_load_name)
# when loading FSDP shards, we need to set the process group
with load_with_process_group(gpc.get_group(ParallelMode.ZERO1)):
states = llm_load(fp, map_location=get_current_device())
if gpc.config.parallel.use_fsdp: