mirror of https://github.com/InternLM/InternLM
add support for FSDP with tp
parent 80f1eb9a36
commit 7d52276c13
@@ -63,15 +63,14 @@ def args_sanity_check():
         gpc.config.parallel._add_item("tensor", 1)

     if isinstance(gpc.config.parallel.pipeline, int):
         pp = gpc.config.parallel.pipeline
     else:
         pp = gpc.config.parallel.pipeline.size
-    tp = gpc.config.parallel.tensor

     if "use_fsdp" not in gpc.config.parallel:
         gpc.config.parallel._add_item("use_fsdp", False)
-    elif gpc.config.parallel.use_fsdp and (pp > 1 or tp > 1):
-        logger.warning("FSDP not support when pipeline/tensor parallel is enabled, auto-close FSDP")
+    elif gpc.config.parallel.use_fsdp and pp > 1:
+        logger.warning("FSDP not support when pipeline parallel is enabled, auto-close FSDP")
         gpc.config.parallel._add_item("use_fsdp", False)

     # processing the data config in gpc
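For orientation, a minimal sketch of a parallel config that exercises the relaxed check above. The key names are taken from the code in this hunk; the zero1 value and its meaning are an assumption, not part of this commit:

    # hypothetical config fragment: FSDP combined with tensor parallel
    parallel = dict(
        zero1=8,        # assumed: size of the sharding group
        tensor=2,       # tensor parallel degree > 1 is no longer rejected with FSDP
        pipeline=1,     # pipeline parallel > 1 still auto-closes FSDP (see warning above)
        use_fsdp=True,
    )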
@@ -50,9 +50,6 @@ class FSDPadaptOptimizer(BaseOptimizer):
         self._clip_grad_norm = zero_cfg.clip_grad_norm
-        self.use_fsdp = gpc.config.parallel.use_fsdp

-        # mark whether a module is part of TP or not
-        # TODO: is_tensor_parallel_dict = dict()

         # fp16 and fp32 params
         # fp16 share mem space with model.FlatParam, fp32 share mem space with optim.param_group
         self._fp16_param_groups = dict()
@@ -66,7 +63,7 @@ class FSDPadaptOptimizer(BaseOptimizer):
             self._fp16_param_groups[group_idx] = group_params

             # create copy of fp32 weight
-            fp32_tensor_param = [param.data.float().requires_grad_(True) for param in group_params]
+            fp32_tensor_param = [param.data.float() for param in group_params]
             self._fp32_param_tensor_groups[group_idx] = fp32_tensor_param

             # replace
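The optimizer keeps the fp16 parameters (which alias the model's flat parameters) alongside detached fp32 copies that the actual update runs on. A minimal, self-contained sketch of that master-weight pattern, with illustrative names only:

    import torch

    layer = torch.nn.Linear(8, 8).half()
    fp16_params = list(layer.parameters())                # share storage with the model weights
    fp32_master = [p.data.float() for p in fp16_params]   # fp32 copies the optimizer steps on
    # after backward: cast p.grad to fp32, update fp32_master, then copy back into p.data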
@@ -81,8 +78,8 @@ class FSDPadaptOptimizer(BaseOptimizer):
         loss.backward(retain_graph=retain_graph)

     def _compute_norm_with_fsdp_flatten(self, group_id):
-        params = self._fp16_param_groups[group_id]
-        gradients = [p.grad for p in params]
+        params = [p for p in self._fp16_param_groups[group_id] if p.storage().size() != 0]
+        gradients = [p.grad for p in params if p.storage().size() != 0]
         norm_group = compute_norm(gradients=gradients, parameters=params, last_stage=True)

         return norm_group
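The recurring `p.storage().size() != 0` filter is the central idiom of this change: with FSDP and original parameters kept visible, each rank materializes only its own shard, so parameters it does not own have empty local storage and must be skipped before computing norms, reducing gradients, or stepping. A sketch of the idiom in isolation, with an illustrative helper name:

    def local_shard_params(params):
        # keep only parameters whose shard actually lives on this rank
        return [p for p in params if p.storage().size() != 0]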
@@ -99,7 +96,7 @@ class FSDPadaptOptimizer(BaseOptimizer):
         for group_idx in range(len(self.param_groups)):
             params = self._fp16_param_groups[group_idx]
             for param in params:
-                if param.requires_grad:
+                if param.requires_grad and param.grad is not None:
                     handle = reduce_tensor(tensor=param.grad, parallel_mode=ParallelMode.ZERO3_DP)
                     handle.wait()

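`reduce_tensor` is an InternLM helper that returns an asynchronous communication handle. A rough torch.distributed equivalent of the pattern above, assuming an already-initialized process group (SUM reduction shown for portability; the helper name is illustrative):

    import torch.distributed as dist

    def reduce_grad_async(grad, group=None):
        # kick off the reduction so the caller can overlap it with other work
        handle = dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=group, async_op=True)
        return handle  # caller calls handle.wait() before reading grad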
@@ -131,11 +128,12 @@ class FSDPadaptOptimizer(BaseOptimizer):
         # create gradient for fp32 params
         for group_idx in range(len(self.param_groups)):
             dtype = self._fp32_param_tensor_groups[group_idx][0].dtype
-            fp16_params = self._fp16_param_groups[group_idx]
+            fp16_params = [p for p in self._fp16_param_groups[group_idx] if p.storage().size() != 0]
             grad_fp32 = [p.grad.to(dtype) for p in fp16_params]

             device = self._fp32_param_tensor_groups[group_idx][0].device
-            for p, g in zip(self._fp32_param_tensor_groups[group_idx], grad_fp32):
+            nonzero_fp32 = [p for p in self._fp32_param_tensor_groups[group_idx] if p.storage().size() != 0]
+            for p, g in zip(nonzero_fp32, grad_fp32):
                 p.grad = g.to(device)

         # unscale
@@ -145,8 +143,8 @@ class FSDPadaptOptimizer(BaseOptimizer):
         self.zero_grad()

         for group_idx in range(len(self._fp16_param_groups)):
-            fp16_params = self._fp16_param_groups[group_idx]
-            fp32_tensor_params = self._fp32_param_tensor_groups[group_idx]
+            fp16_params = [p for p in self._fp16_param_groups[group_idx] if p.storage().size() != 0]
+            fp32_tensor_params = [p for p in self._fp32_param_tensor_groups[group_idx] if p.storage().size() != 0]
             # release fp32 grad
             release_param_grad(fp32_tensor_params)
             # update fp16 param
@@ -177,9 +175,10 @@ class FSDPadaptOptimizer(BaseOptimizer):
             if clip > 1.0:
                 combined_scale_groups[group_id] = clip * loss_scale

-        for group_id, grads in self._fp32_param_tensor_groups.items():
-            for g in grads:
-                g.grad.data.mul_(1.0 / combined_scale_groups[group_id])
+        for group_id, param in self._fp32_param_tensor_groups.items():
+            for p in param:
+                if p.storage().size() != 0:
+                    p.grad.data.mul_(1.0 / combined_scale_groups[group_id])

     def state_dict(self):
         states = {}
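The `combined_scale_groups` values above fold the loss scale and the gradient-clipping coefficient into a single divisor, in the style of Apex/Megatron mixed-precision optimizers. A hedged sketch of how such a factor is typically derived (names are illustrative; the exact epsilon and ordering in InternLM may differ):

    def combined_scale(grad_norm, loss_scale, clip_grad_norm):
        clip = 1.0
        if clip_grad_norm > 0.0:
            clip = grad_norm / (loss_scale * clip_grad_norm) + 1e-6
        return clip * loss_scale if clip > 1.0 else loss_scale
    # every surviving gradient is then multiplied by 1.0 / combined_scale(...)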
@@ -17,7 +17,7 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import (
 from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 from torch.utils.data import ConcatDataset, DataLoader

-from internlm.core.context import ParallelMode
+from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.context.random import set_mode
 from internlm.core.naive_amp import NaiveAMPModel
@@ -36,16 +36,8 @@ from internlm.model.modeling_internlm import (
     PackedFlashBaseLayer1D,
     PackedFlashInternLm1D,
 )

-from flash_attn.modules.mha import (
-    CrossAttention,
-    FlashCrossAttention,
-    FlashSelfAttention,
-    SelfAttention,
-    _update_kv_cache,
-)
-
-from internlm.model.utils import try_import_RMSNorm
+from internlm.model.multi_head_attention import MHA
 from internlm.monitor import send_heartbeat, set_env_var
 from internlm.monitor.monitor import monitor_manager as mm
 from internlm.solver.beta2_scheduler import Beta2Scheduler
@@ -117,18 +109,23 @@ def initialize_model():


 def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
+    from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm
+    RMSNorm = try_import_RMSNorm()
     if gpc.config.parallel.use_fsdp:
+        # pre-save info for tensor parallel
+        tp_dict = dict()
+        for name, param in model.named_parameters():
+            if hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL):
+                tp_dict.update({name.replace("model.", ""): True})
+            else:
+                tp_dict.update({name.replace("model.", ""): False})
+
         # set wrap_policy for fsdp wrap
         transformer_wrap_policy = functools.partial(
-            transformer_auto_wrap_policy, transformer_layer_cls={
-                PackedFlashBaseLayer1D,
-                PackedFlashInternLm1D,
-                MHA,
-                FlashCrossAttention,
-                FlashSelfAttention,
-                RMSNorm}
+            transformer_auto_wrap_policy,
+            transformer_layer_cls={PackedFlashBaseLayer1D, PackedFlashInternLm1D, MHA, RMSNorm},
         )

         # wrap the model
         grp = gpc.get_group(ParallelMode.ZERO1)
         model = FSDP(
             module=model,
@@ -138,8 +135,14 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
             forward_prefetch=True,
             backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
             limit_all_gathers=True,
+            use_orig_params=True,
         )

+        # re-set attribute for fsdp module
+        for (name, param), pre in zip(model.named_parameters(), tp_dict):
+            if pre in name and tp_dict[pre]:
+                setattr(param, IS_TENSOR_PARALLEL, True)
+
     return model
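Because the wrapper is built with use_orig_params=True, `model.named_parameters()` still yields the original parameter tensors after wrapping (with FSDP prefixes in the names), which is what lets the code above re-attach IS_TENSOR_PARALLEL by matching the pre-saved names. A self-contained toy sketch of the same wrapping pattern (assumes torch>=2.0 and an already-initialized process group; Block is a stand-in for the InternLM layer classes):

    import functools
    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

    class Block(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(16, 16)

    policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
    wrapped = FSDP(nn.Sequential(Block(), Block()), auto_wrap_policy=policy, use_orig_params=True)
    # the original tensors survive wrapping, so custom flags such as IS_TENSOR_PARALLEL
    # can be restored on them afterwards; names gain "_fsdp_wrapped_module."-style prefixes
    for name, param in wrapped.named_parameters():
        print(name)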
@@ -171,8 +171,8 @@ def get_shard_state_dict(shard_model):
     # with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
     #     states = model.state_dict()

     # in this version, FSDP model can only save with sharded shape
-    with FSDP.state_dict_type(shard_model, StateDictType.LOCAL_STATE_DICT):
+    with FSDP.state_dict_type(shard_model, StateDictType.SHARDED_STATE_DICT):
         shard_states = shard_model.state_dict()

     return shard_states
@@ -184,7 +184,7 @@ def load_shard_state_dict(shard_model, shard_state, **kwargs):

     """

-    with FSDP.state_dict_type(shard_model, StateDictType.LOCAL_STATE_DICT):
+    with FSDP.state_dict_type(shard_model, StateDictType.SHARDED_STATE_DICT):
         missing_k, unexpected_keys = shard_model.load_state_dict(shard_state, kwargs)

     return (missing_k, unexpected_keys)
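Switching from LOCAL_STATE_DICT to SHARDED_STATE_DICT keeps checkpoints per-rank but stores them as ShardedTensor entries rather than opaque flat shards. A hedged, self-contained sketch of the save/load round trip the two functions above implement (the paths and helper names are illustrative, not InternLM APIs):

    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType

    def save_sharded(fsdp_model, path_for_this_rank):
        with FSDP.state_dict_type(fsdp_model, StateDictType.SHARDED_STATE_DICT):
            torch.save(fsdp_model.state_dict(), path_for_this_rank)

    def load_sharded(fsdp_model, path_for_this_rank):
        shard_state = torch.load(path_for_this_rank)
        with FSDP.state_dict_type(fsdp_model, StateDictType.SHARDED_STATE_DICT):
            fsdp_model.load_state_dict(shard_state)
    # note: loading this way assumes the same world size and sharding layout as at save time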