mirror of https://github.com/InternLM/InternLM
feat(configs/7B_sft.py): move fsdp config to parallel zero1
parent
bd809a61f2
commit
edd7f9e8e1
|
@ -124,7 +124,7 @@ pipeline parallel: pipeline parallel size, only 1 is accepted currently.
|
|||
tensor parallel: tensor parallel size, usually the number of GPUs per node, only 1 is accepted currently.
|
||||
"""
|
||||
parallel = dict(
|
||||
zero1=8,
|
||||
zero1=dict(size=8, fsdp=False),
|
||||
)
|
||||
|
||||
cudnn_deterministic = False
|
||||
|
|
|
@ -154,7 +154,7 @@ pipeline parallel (dict):
|
|||
tensor parallel: tensor parallel size, usually the number of GPUs per node.
|
||||
"""
|
||||
parallel = dict(
|
||||
zero1=8,
|
||||
zero1=dict(size=8, fsdp=False),
|
||||
tensor=1,
|
||||
pipeline=dict(size=1, interleaved_overlap=True),
|
||||
sequence_parallel=False,
|
||||
|
|
|
@ -436,10 +436,10 @@ class ParallelContext(metaclass=SingletonMeta):
|
|||
# if zo_size < dp_size, ckpts saving will introduce redundent storage for model weights
|
||||
# because pytorch "ShardTensor" need to ensure current global rank equals to saved shard's global rank
|
||||
# pytorch vision: 1.13.1+cu117
|
||||
if self.data_parallel_size > self.zero1_parallel_size and self.config.parallel.get("use_fsdp", False):
|
||||
if self.data_parallel_size > self.zero1_parallel_size and self.config.parallel.zero1.get("fsdp", False):
|
||||
logger.warning(
|
||||
f"zo size: {self.zero1_parallel_size} < dp size: {self.data_parallel_size}, "
|
||||
"will introduce redundancy when saving ckpts, recommend setting them to same value"
|
||||
"will introduce redundancy when saving fsdp model ckpts, recommend setting them to same value"
|
||||
)
|
||||
|
||||
def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str):
|
||||
|
@ -508,7 +508,7 @@ class ParallelContext(metaclass=SingletonMeta):
|
|||
initializers.append(pgroup_initializer.Initializer_Model(*initializer_args))
|
||||
initializers.append(pgroup_initializer.Initializer_Tensor(*initializer_args))
|
||||
initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args))
|
||||
if self.config.parallel.get("use_fsdp", False):
|
||||
if self.config.parallel.zero1.get("fsdp", False):
|
||||
initializers.append(pgroup_initializer.Initializer_Zero3_dp(*initializer_args))
|
||||
initializers.append(pgroup_initializer.Initializer_Nettest(*initializer_args))
|
||||
if self.pipeline_parallel_size > 1:
|
||||
|
|
|
@ -79,12 +79,14 @@ def args_sanity_check():
|
|||
pp = gpc.config.parallel.pipeline.size
|
||||
|
||||
# check fsdp config
|
||||
if "use_fsdp" not in gpc.config.parallel:
|
||||
gpc.config.parallel._add_item("use_fsdp", False)
|
||||
if "fsdp" not in gpc.config.parallel.zero1:
|
||||
gpc.config.parallel.zero1._add_item("fsdp", False)
|
||||
|
||||
assert not (
|
||||
gpc.config.parallel.use_fsdp and pp > 1
|
||||
), "FSDP not support when pipeline size > 1, please set pipeline size to 1 or close FSDP"
|
||||
if gpc.config.parallel.use_fsdp:
|
||||
gpc.config.parallel.zero1.fsdp and pp > 1
|
||||
), "FSDP is not supportted when pipeline size > 1, please set pipeline size to 1 or disabled FSDP"
|
||||
|
||||
if gpc.config.parallel.zero1.fsdp:
|
||||
assert (
|
||||
torch.__version__ >= "2.0.1"
|
||||
), f"requires torch>=2.0.1 when using fsdp but current version is {torch.__version__}"
|
||||
|
@ -288,7 +290,7 @@ def args_sanity_check():
|
|||
if "moe_gate_k" not in model:
|
||||
model._add_item("moe_gate_k", 2)
|
||||
assert not (
|
||||
gpc.config.model.num_experts > 1 and gpc.config.parallel.use_fsdp
|
||||
gpc.config.model.num_experts > 1 and gpc.config.parallel.zero1.fsdp
|
||||
), "FSDP does not support num_experts > 1"
|
||||
|
||||
# process the parallel config
|
||||
|
|
|
@ -20,7 +20,7 @@ logger = get_logger(__file__)
|
|||
|
||||
class FSDPadaptOptimizer(BaseOptimizer):
|
||||
"""
|
||||
optimizer for Pytorch FSDP if 'use_fsdp' is True in config file
|
||||
optimizer for Pytorch FSDP if 'parallel.zero1.fsdp' is True in config file
|
||||
reserve some necessary components of hybird-optim:
|
||||
grad_scaler;
|
||||
grad_clip and unscale;
|
||||
|
@ -48,7 +48,6 @@ class FSDPadaptOptimizer(BaseOptimizer):
|
|||
|
||||
# clip gradient
|
||||
self._clip_grad_norm = zero_cfg.clip_grad_norm
|
||||
self.use_fsdp = gpc.config.parallel.use_fsdp
|
||||
|
||||
# fp16 and fp32 params
|
||||
# fp16 share mem space with model.FlatParam, fp32 share mem space with optim.param_group
|
||||
|
|
|
@ -111,7 +111,7 @@ def initialize_model():
|
|||
|
||||
|
||||
def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
|
||||
if gpc.config.parallel.use_fsdp:
|
||||
if gpc.config.parallel.zero1.fsdp:
|
||||
# set wrap_policy for fsdp wrap
|
||||
transformer_wrap_policy = functools.partial(
|
||||
transformer_auto_wrap_policy,
|
||||
|
@ -168,7 +168,7 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
|
|||
eps=adam_cfg.adam_eps,
|
||||
)
|
||||
|
||||
if not gpc.config.parallel.use_fsdp:
|
||||
if not gpc.config.parallel.zero1.fsdp:
|
||||
optimizer = HybridZeroOptimizer(
|
||||
naive_optimizer,
|
||||
grad_scal_cfg=gpc.config.grad_scaler,
|
||||
|
|
|
@ -265,7 +265,7 @@ def save_model_checkpoint(folder, model):
|
|||
model: The model to be saved
|
||||
"""
|
||||
|
||||
if gpc.config.parallel.use_fsdp:
|
||||
if gpc.config.parallel.zero1.fsdp:
|
||||
states = get_shard_state_dict(model)
|
||||
else:
|
||||
states = model.state_dict()
|
||||
|
@ -285,18 +285,18 @@ def save_model_checkpoint(folder, model):
|
|||
# even if pp is not considered, it will definitely not be written on the same machine.
|
||||
should_save_rank_pair = set() # (tp_rank, dp_rank)
|
||||
for i in range(tp_size):
|
||||
if gpc.config.parallel.use_fsdp:
|
||||
if gpc.config.parallel.zero1.fsdp:
|
||||
for j in range(dp_size):
|
||||
should_save_rank_pair.add((i, j))
|
||||
else:
|
||||
should_save_rank_pair.add((i, i % dp_size))
|
||||
|
||||
if (tp_rank, dp_rank) in should_save_rank_pair:
|
||||
f_dp = f"_dp{dp_rank}" if gpc.config.parallel.use_fsdp else ""
|
||||
f_dp = f"_dp{dp_rank}" if gpc.config.parallel.zero1.fsdp else ""
|
||||
fn = f"model_tp{tp_rank}_pp{pp_rank}{f_dp}.pt"
|
||||
fp = os.path.join(folder, fn)
|
||||
llm_save(fp, saved_obj=states)
|
||||
if not gpc.config.parallel.use_fsdp or dp_rank == tp_rank % dp_size:
|
||||
if not gpc.config.parallel.zero1.fsdp or dp_rank == tp_rank % dp_size:
|
||||
topo_fn = f"topo_tp{tp_rank}_pp{pp_rank}.json"
|
||||
topo_fp = os.path.join(folder, topo_fn)
|
||||
llm_save(topo_fp, saved_obj=topo)
|
||||
|
@ -338,15 +338,15 @@ def load_model_checkpoint(folder, model):
|
|||
|
||||
# avoid ckpt misuse between FSDP and no-FSDP
|
||||
test_fn = list([f for f in fns if f.startswith("model_t") and not f.endswith(".md5")]).pop()
|
||||
assert ("_dp" in test_fn and gpc.config.parallel.use_fsdp) or (
|
||||
"_dp" not in test_fn and not gpc.config.parallel.use_fsdp
|
||||
assert ("_dp" in test_fn and gpc.config.parallel.zero1.fsdp) or (
|
||||
"_dp" not in test_fn and not gpc.config.parallel.zero1.fsdp
|
||||
), "FSDP model wants to load no-FSDP ckpts or reverse"
|
||||
|
||||
max_pp, max_tp, max_zo = 0, 0, 0
|
||||
for fn in fns:
|
||||
if fn.startswith("model_t") and not fn.endswith(".md5"):
|
||||
segements = os.path.splitext(fn)[0].split("_")
|
||||
if gpc.config.parallel.use_fsdp:
|
||||
if gpc.config.parallel.zero1.fsdp:
|
||||
max_zo = max(max_zo, int(segements[-1][2:]))
|
||||
max_pp = max(max_pp, int(segements[-2][2:]))
|
||||
max_tp = max(max_tp, int(segements[-3][2:]))
|
||||
|
@ -360,12 +360,12 @@ def load_model_checkpoint(folder, model):
|
|||
assert (
|
||||
tp_size == max_tp + 1
|
||||
), f"The weights are save for {max_tp+1} parallelism, while current has {tp_size} tensor parallelism"
|
||||
if gpc.config.parallel.use_fsdp:
|
||||
if gpc.config.parallel.zero1.fsdp:
|
||||
assert (
|
||||
dp_size == max_zo + 1
|
||||
), f"The weights are save for {max_zo+1} FSDP shards , while current has {dp_size} FSDP shards"
|
||||
|
||||
if gpc.config.parallel.use_fsdp:
|
||||
if gpc.config.parallel.zero1.fsdp:
|
||||
should_load_name = f"model_tp{tp_rank}_pp{pp_rank}_dp{dp_rank}.pt"
|
||||
else:
|
||||
should_load_name = f"model_tp{tp_rank}_pp{pp_rank}.pt"
|
||||
|
@ -388,7 +388,7 @@ def load_model_checkpoint(folder, model):
|
|||
# try to load expert parameter to separate files if model have moe layer
|
||||
try_load_moe_checkpoint(folder, model, states, tp_rank, pp_rank)
|
||||
|
||||
if gpc.config.parallel.use_fsdp:
|
||||
if gpc.config.parallel.zero1.fsdp:
|
||||
missing_k, unexpected_keys = load_shard_state_dict(model, states, strict=False)
|
||||
else:
|
||||
missing_k, unexpected_keys = model.load_state_dict(states, strict=False)
|
||||
|
|
Loading…
Reference in New Issue