mirror of https://github.com/InternLM/InternLM
fix bug for loading ckpts when zero1 < dp_size
parent 056996f8b3
commit 96171d5f28
@@ -423,6 +423,16 @@ class ParallelContext(metaclass=SingletonMeta):
         assert self.zero1_parallel_size > 0
         assert self.data_parallel_size % self.zero1_parallel_size == 0
 
+        # check for fsdp:
+        # if zo_size < dp_size, ckpt saving will introduce redundant storage for model weights,
+        # because pytorch "ShardedTensor" requires the current global rank to equal the saved shard's global rank
+        # pytorch version: 1.13.1+cu117
+        if self.data_parallel_size > self.zero1_parallel_size:
+            logger.warning(
+                f"zo size: {self.zero1_parallel_size} < dp size: {self.data_parallel_size}, \
+                will introduce redundancy when saving ckpts, recommend setting them to same value"
+            )
+
     def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str):
         if key in config:
             ele = config[key]
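For context, a minimal sketch of the condition the new warning guards against, assuming the usual constraint that zero1_parallel_size divides data_parallel_size; the helper below is hypothetical and only mirrors the ParallelContext fields used above:

# Hypothetical standalone check mirroring the new ParallelContext guard; not part of the repository.
def warn_if_zero1_smaller_than_dp(zero1_parallel_size: int, data_parallel_size: int) -> bool:
    assert zero1_parallel_size > 0
    assert data_parallel_size % zero1_parallel_size == 0
    if data_parallel_size > zero1_parallel_size:
        # per the commit comment: saving then stores redundant copies of model weights,
        # since a saved ShardedTensor shard is tied to the global rank that wrote it
        print(
            f"zo size: {zero1_parallel_size} < dp size: {data_parallel_size}, "
            "will introduce redundancy when saving ckpts"
        )
        return True
    return False

warn_if_zero1_smaller_than_dp(zero1_parallel_size=2, data_parallel_size=8)  # warns
warn_if_zero1_smaller_than_dp(zero1_parallel_size=8, data_parallel_size=8)  # silent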
@@ -11,6 +11,7 @@ from enum import Enum
 from typing import Callable, Dict, Union
 
 import torch
+from torch.distributed._shard.api import load_with_process_group
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp import StateDictType
 
@@ -193,7 +194,6 @@ def try_load_internlm_ckpt(ckpt_mm, load_info, train_state: TrainState):
     load_content_str = ""
     load_ckpt_folder = load_info["path"]
     load_content: CheckpointLoadMask = load_info["content"]
-    print(load_ckpt_folder, load_content)
     if gpc.is_rank_for_log():
         logger.info(f"Try load_ckpt_folder: {load_ckpt_folder}")
 
@@ -277,7 +277,6 @@ def save_model_checkpoint(folder, model):
     if folder is not None:
         dp_size = gpc.get_world_size(ParallelMode.DATA)
         tp_size = gpc.get_world_size(ParallelMode.TENSOR)
-        zo_size = gpc.get_world_size(ParallelMode.ZERO1)
         dp_rank = gpc.get_local_rank(ParallelMode.DATA)
         tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
         pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
@@ -287,7 +286,7 @@ def save_model_checkpoint(folder, model):
         should_save_rank_pair = set()  # (tp_rank, dp_rank)
         for i in range(tp_size):
             if gpc.config.parallel.use_fsdp:
-                for j in range(zo_size):
+                for j in range(dp_size):
                     should_save_rank_pair.add((i, j))
             else:
                 should_save_rank_pair.add((i, i % dp_size))
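To make the effect of this one-line change concrete, here is a standalone restatement of the writer-selection loop (a hypothetical function, with the gpc lookups replaced by plain arguments): under use_fsdp every (tp_rank, dp_rank) pair now writes its own shard, whereas without FSDP a single data-parallel rank per tensor-parallel rank suffices.

# Hypothetical restatement of the loop above, outside of gpc; not repository code.
def ranks_that_save(tp_size: int, dp_size: int, use_fsdp: bool) -> set:
    should_save_rank_pair = set()  # (tp_rank, dp_rank)
    for i in range(tp_size):
        if use_fsdp:
            # each data-parallel rank holds a distinct FSDP shard, so all of them must save
            for j in range(dp_size):
                should_save_rank_pair.add((i, j))
        else:
            # full weights are replicated across data-parallel ranks; one writer per tp rank
            should_save_rank_pair.add((i, i % dp_size))
    return should_save_rank_pair

print(sorted(ranks_that_save(tp_size=2, dp_size=4, use_fsdp=True)))   # 8 writer pairs
print(sorted(ranks_that_save(tp_size=2, dp_size=4, use_fsdp=False)))  # 2 writer pairs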
@@ -320,10 +319,10 @@ def load_model_checkpoint(folder, model):
 
     tp_size = gpc.get_world_size(ParallelMode.TENSOR)
     pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
-    zo_size = gpc.get_world_size(ParallelMode.ZERO1)
+    dp_size = gpc.get_world_size(ParallelMode.DATA)
     tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
     pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
-    zo_rank = gpc.get_local_rank(ParallelMode.ZERO1)
+    dp_rank = gpc.get_local_rank(ParallelMode.DATA)
 
     fns = get_fns(folder)
 
@@ -353,15 +352,18 @@ def load_model_checkpoint(folder, model):
     ), f"The weights are save for {max_tp+1} parallelism, while current has {tp_size} tensor parallelism"
     if gpc.config.parallel.use_fsdp:
         assert (
-            zo_size == max_zo + 1
-        ), f"The weights are save for {max_zo+1} FSDP shards , while current has {zo_size} FSDP shards"
+            dp_size == max_zo + 1
+        ), f"The weights are save for {max_zo+1} FSDP shards , while current has {dp_size} FSDP shards"
 
     if gpc.config.parallel.use_fsdp:
-        should_load_name = f"model_tp{tp_rank}_pp{pp_rank}_zo{zo_rank}.pt"
+        should_load_name = f"model_tp{tp_rank}_pp{pp_rank}_zo{dp_rank}.pt"
     else:
         should_load_name = f"model_tp{tp_rank}_pp{pp_rank}.pt"
     fp = os.path.join(folder, should_load_name)
-    states = llm_load(fp, map_location=get_current_device())
+
+    # for FSDP shards loading, we need to set process group
+    with load_with_process_group(gpc.get_group(ParallelMode.ZERO1)):
+        states = llm_load(fp, map_location=get_current_device())
 
     if gpc.config.parallel.use_fsdp:
         missing_k, unexpected_keys = load_shard_state_dict(model, states, strict=False)