mirror of https://github.com/InternLM/InternLM
feat(configs/7B_sft.py): move fsdp config to parallel zero1
parent: bd809a61f2
commit: edd7f9e8e1
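
The commit replaces the flat `use_fsdp` switch on `parallel` with an `fsdp` flag nested under `parallel.zero1`. A minimal before/after sketch of the affected config block, assembled from the hunks below (the surrounding values are the defaults shown in the diff, not a recommendation; the old switch previously lived at `parallel.use_fsdp`):

# Before: zero1 is a plain int; FSDP was toggled by a separate parallel.use_fsdp key.
parallel = dict(
    zero1=8,
    tensor=1,
    pipeline=dict(size=1, interleaved_overlap=True),
    sequence_parallel=False,
)

# After: zero1 becomes a dict and carries the FSDP flag itself.
parallel = dict(
    zero1=dict(size=8, fsdp=False),  # set fsdp=True to wrap the model with PyTorch FSDP
    tensor=1,
    pipeline=dict(size=1, interleaved_overlap=True),
    sequence_parallel=False,
)
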
@@ -124,7 +124,7 @@ pipeline parallel: pipeline parallel size, only 1 is accepted currently.
 tensor parallel: tensor parallel size, usually the number of GPUs per node, only 1 is accepted currently.
 """
 parallel = dict(
-    zero1=8,
+    zero1=dict(size=8, fsdp=False),
 )
 
 cudnn_deterministic = False
@@ -154,7 +154,7 @@ pipeline parallel (dict):
 tensor parallel: tensor parallel size, usually the number of GPUs per node.
 """
 parallel = dict(
-    zero1=8,
+    zero1=dict(size=8, fsdp=False),
     tensor=1,
     pipeline=dict(size=1, interleaved_overlap=True),
     sequence_parallel=False,
@@ -436,10 +436,10 @@ class ParallelContext(metaclass=SingletonMeta):
         # if zo_size < dp_size, ckpt saving will introduce redundant storage for model weights
         # because pytorch "ShardedTensor" needs the current global rank to equal the saved shard's global rank
         # pytorch version: 1.13.1+cu117
-        if self.data_parallel_size > self.zero1_parallel_size and self.config.parallel.get("use_fsdp", False):
+        if self.data_parallel_size > self.zero1_parallel_size and self.config.parallel.zero1.get("fsdp", False):
             logger.warning(
                 f"zo size: {self.zero1_parallel_size} < dp size: {self.data_parallel_size}, "
-                "will introduce redundancy when saving ckpts, recommend setting them to the same value"
+                "will introduce redundancy when saving fsdp model ckpts, recommend setting them to the same value"
             )
 
     def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str):
@@ -508,7 +508,7 @@ class ParallelContext(metaclass=SingletonMeta):
         initializers.append(pgroup_initializer.Initializer_Model(*initializer_args))
         initializers.append(pgroup_initializer.Initializer_Tensor(*initializer_args))
         initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args))
-        if self.config.parallel.get("use_fsdp", False):
+        if self.config.parallel.zero1.get("fsdp", False):
             initializers.append(pgroup_initializer.Initializer_Zero3_dp(*initializer_args))
         initializers.append(pgroup_initializer.Initializer_Nettest(*initializer_args))
         if self.pipeline_parallel_size > 1:
@@ -79,12 +79,14 @@ def args_sanity_check():
     pp = gpc.config.parallel.pipeline.size
 
     # check fsdp config
-    if "use_fsdp" not in gpc.config.parallel:
-        gpc.config.parallel._add_item("use_fsdp", False)
+    if "fsdp" not in gpc.config.parallel.zero1:
+        gpc.config.parallel.zero1._add_item("fsdp", False)
 
     assert not (
-        gpc.config.parallel.use_fsdp and pp > 1
-    ), "FSDP not support when pipeline size > 1, please set pipeline size to 1 or close FSDP"
-    if gpc.config.parallel.use_fsdp:
+        gpc.config.parallel.zero1.fsdp and pp > 1
+    ), "FSDP is not supported when pipeline size > 1, please set pipeline size to 1 or disable FSDP"
+
+    if gpc.config.parallel.zero1.fsdp:
         assert (
             torch.__version__ >= "2.0.1"
         ), f"requires torch>=2.0.1 when using fsdp but current version is {torch.__version__}"
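
The sanity check defaults the new key and then validates it against pipeline parallelism. A small self-contained sketch of the same defaulting and validation logic, using a plain dict as a hypothetical stand-in for InternLM's `gpc.config.parallel` object (the real code uses `Config._add_item` and attribute access):

# Hypothetical stand-in for gpc.config.parallel; the real object is an InternLM Config.
parallel = {
    "zero1": {"size": 8},     # user config that omits the fsdp flag
    "pipeline": {"size": 1},
}

# Default the flag if the user did not set it (mirrors zero1._add_item("fsdp", False)).
parallel["zero1"].setdefault("fsdp", False)

# FSDP cannot be combined with pipeline parallelism in this code path.
pp = parallel["pipeline"]["size"]
assert not (parallel["zero1"]["fsdp"] and pp > 1), (
    "FSDP is not supported when pipeline size > 1; set pipeline size to 1 or disable FSDP"
)
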
@@ -288,7 +290,7 @@ def args_sanity_check():
     if "moe_gate_k" not in model:
         model._add_item("moe_gate_k", 2)
     assert not (
-        gpc.config.model.num_experts > 1 and gpc.config.parallel.use_fsdp
+        gpc.config.model.num_experts > 1 and gpc.config.parallel.zero1.fsdp
     ), "FSDP does not support num_experts > 1"
 
     # process the parallel config
@@ -20,7 +20,7 @@ logger = get_logger(__file__)
 
 class FSDPadaptOptimizer(BaseOptimizer):
     """
-    optimizer for Pytorch FSDP if 'use_fsdp' is True in config file
+    optimizer for Pytorch FSDP if 'parallel.zero1.fsdp' is True in config file
     reserve some necessary components of hybrid-optim:
     grad_scaler;
     grad_clip and unscale;
@@ -48,7 +48,6 @@ class FSDPadaptOptimizer(BaseOptimizer):
 
         # clip gradient
         self._clip_grad_norm = zero_cfg.clip_grad_norm
-        self.use_fsdp = gpc.config.parallel.use_fsdp
 
         # fp16 and fp32 params
         # fp16 share mem space with model.FlatParam, fp32 share mem space with optim.param_group
@@ -111,7 +111,7 @@ def initialize_model():
 
 
 def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
-    if gpc.config.parallel.use_fsdp:
+    if gpc.config.parallel.zero1.fsdp:
         # set wrap_policy for fsdp wrap
         transformer_wrap_policy = functools.partial(
             transformer_auto_wrap_policy,
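
`wrap_FSDP_model` now reads the flag from `parallel.zero1.fsdp` before building an auto-wrap policy with `functools.partial(transformer_auto_wrap_policy, ...)`. A minimal sketch of that wrapping pattern with stock PyTorch, assuming `torch.distributed` is already initialized and using a placeholder block class in place of InternLM's transformer layer:

import functools

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


class DecoderBlock(nn.Module):  # placeholder for the real transformer layer class
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def forward(self, x):
        return self.proj(x)


def wrap_with_fsdp(model: nn.Module) -> nn.Module:
    # Wrap every DecoderBlock instance in its own FSDP unit.
    wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={DecoderBlock},
    )
    # Requires torch.distributed.init_process_group(...) to have been called.
    return FSDP(model, auto_wrap_policy=wrap_policy)
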
@@ -168,7 +168,7 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
         eps=adam_cfg.adam_eps,
     )
 
-    if not gpc.config.parallel.use_fsdp:
+    if not gpc.config.parallel.zero1.fsdp:
         optimizer = HybridZeroOptimizer(
             naive_optimizer,
             grad_scal_cfg=gpc.config.grad_scaler,
@@ -265,7 +265,7 @@ def save_model_checkpoint(folder, model):
         model: The model to be saved
     """
 
-    if gpc.config.parallel.use_fsdp:
+    if gpc.config.parallel.zero1.fsdp:
         states = get_shard_state_dict(model)
     else:
         states = model.state_dict()
@@ -285,18 +285,18 @@ def save_model_checkpoint(folder, model):
     # even if pp is not considered, it will definitely not be written on the same machine.
     should_save_rank_pair = set()  # (tp_rank, dp_rank)
     for i in range(tp_size):
-        if gpc.config.parallel.use_fsdp:
+        if gpc.config.parallel.zero1.fsdp:
             for j in range(dp_size):
                 should_save_rank_pair.add((i, j))
         else:
             should_save_rank_pair.add((i, i % dp_size))
 
     if (tp_rank, dp_rank) in should_save_rank_pair:
-        f_dp = f"_dp{dp_rank}" if gpc.config.parallel.use_fsdp else ""
+        f_dp = f"_dp{dp_rank}" if gpc.config.parallel.zero1.fsdp else ""
         fn = f"model_tp{tp_rank}_pp{pp_rank}{f_dp}.pt"
         fp = os.path.join(folder, fn)
         llm_save(fp, saved_obj=states)
-        if not gpc.config.parallel.use_fsdp or dp_rank == tp_rank % dp_size:
+        if not gpc.config.parallel.zero1.fsdp or dp_rank == tp_rank % dp_size:
             topo_fn = f"topo_tp{tp_rank}_pp{pp_rank}.json"
             topo_fp = os.path.join(folder, topo_fn)
             llm_save(topo_fp, saved_obj=topo)
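
Because every data-parallel rank holds its own FSDP shard, the writer above saves one file per (tp, dp) pair and appends a `_dp{rank}` suffix; without FSDP only one dp rank per tp rank writes and the suffix is dropped. A tiny illustrative helper (not part of the codebase) showing the resulting filename scheme:

def model_ckpt_name(tp_rank: int, pp_rank: int, dp_rank: int, fsdp: bool) -> str:
    # Mirrors the f-string in save_model_checkpoint: fsdp checkpoints are sharded per dp rank.
    f_dp = f"_dp{dp_rank}" if fsdp else ""
    return f"model_tp{tp_rank}_pp{pp_rank}{f_dp}.pt"


print(model_ckpt_name(0, 0, 3, fsdp=True))   # model_tp0_pp0_dp3.pt
print(model_ckpt_name(0, 0, 3, fsdp=False))  # model_tp0_pp0.pt
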
@@ -338,15 +338,15 @@ def load_model_checkpoint(folder, model):
 
     # avoid ckpt misuse between FSDP and no-FSDP
     test_fn = list([f for f in fns if f.startswith("model_t") and not f.endswith(".md5")]).pop()
-    assert ("_dp" in test_fn and gpc.config.parallel.use_fsdp) or (
-        "_dp" not in test_fn and not gpc.config.parallel.use_fsdp
+    assert ("_dp" in test_fn and gpc.config.parallel.zero1.fsdp) or (
+        "_dp" not in test_fn and not gpc.config.parallel.zero1.fsdp
     ), "FSDP model wants to load no-FSDP ckpts or reverse"
 
     max_pp, max_tp, max_zo = 0, 0, 0
     for fn in fns:
         if fn.startswith("model_t") and not fn.endswith(".md5"):
             segements = os.path.splitext(fn)[0].split("_")
-            if gpc.config.parallel.use_fsdp:
+            if gpc.config.parallel.zero1.fsdp:
                 max_zo = max(max_zo, int(segements[-1][2:]))
             max_pp = max(max_pp, int(segements[-2][2:]))
             max_tp = max(max_tp, int(segements[-3][2:]))
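
When loading, the code infers the layout that produced the checkpoints by parsing the saved filenames, and uses the trailing `dp` segment only in FSDP mode. A hedged sketch of that parsing step, wrapped in a hypothetical helper for clarity:

import os


def infer_ranks(fn: str, fsdp: bool) -> tuple:
    # "model_tp1_pp0_dp3.pt" -> ["model", "tp1", "pp0", "dp3"]
    segments = os.path.splitext(fn)[0].split("_")
    dp = int(segments[-1][2:]) if fsdp else None   # only fsdp ckpts carry a dp segment
    pp = int(segments[-2][2:]) if fsdp else int(segments[-1][2:])
    tp = int(segments[-3][2:]) if fsdp else int(segments[-2][2:])
    return tp, pp, dp


print(infer_ranks("model_tp1_pp0_dp3.pt", fsdp=True))   # (1, 0, 3)
print(infer_ranks("model_tp1_pp0.pt", fsdp=False))      # (1, 0, None)
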
@@ -360,12 +360,12 @@ def load_model_checkpoint(folder, model):
     assert (
         tp_size == max_tp + 1
     ), f"The weights are saved for {max_tp+1} parallelism, while current has {tp_size} tensor parallelism"
-    if gpc.config.parallel.use_fsdp:
+    if gpc.config.parallel.zero1.fsdp:
         assert (
             dp_size == max_zo + 1
         ), f"The weights are saved for {max_zo+1} FSDP shards, while current has {dp_size} FSDP shards"
 
-    if gpc.config.parallel.use_fsdp:
+    if gpc.config.parallel.zero1.fsdp:
         should_load_name = f"model_tp{tp_rank}_pp{pp_rank}_dp{dp_rank}.pt"
     else:
         should_load_name = f"model_tp{tp_rank}_pp{pp_rank}.pt"
@@ -388,7 +388,7 @@ def load_model_checkpoint(folder, model):
     # try to load expert parameters from separate files if the model has moe layers
     try_load_moe_checkpoint(folder, model, states, tp_rank, pp_rank)
 
-    if gpc.config.parallel.use_fsdp:
+    if gpc.config.parallel.zero1.fsdp:
         missing_k, unexpected_keys = load_shard_state_dict(model, states, strict=False)
     else:
         missing_k, unexpected_keys = model.load_state_dict(states, strict=False)