add support for FSDP with tp

pull/293/head
zaglc 2023-10-08 15:33:31 +08:00
parent 80f1eb9a36
commit 7d52276c13
5 changed files with 42 additions and 39 deletions

View File

@@ -63,15 +63,14 @@ def args_sanity_check():
         gpc.config.parallel._add_item("tensor", 1)
     if isinstance(gpc.config.parallel.pipeline, int):
         pp = gpc.config.parallel.pipeline
-        pp = gpc.config.parallel.pipelines
     else:
         pp = gpc.config.parallel.pipeline.size
     tp = gpc.config.parallel.tensor
     if "use_fsdp" not in gpc.config.parallel:
         gpc.config.parallel._add_item("use_fsdp", False)
-    elif gpc.config.parallel.use_fsdp and (pp > 1 or tp > 1):
-        logger.warning("FSDP not support when pipeline/tensor parallel is enabled, auto-close FSDP")
+    elif gpc.config.parallel.use_fsdp and pp > 1:
+        logger.warning("FSDP not support when pipeline parallel is enabled, auto-close FSDP")
         gpc.config.parallel._add_item("use_fsdp", False)
     # processing the data config in gpc
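With the relaxed check above, a parallel config can enable FSDP and tensor parallelism at the same time; FSDP is now only auto-disabled when pipeline parallelism is used. A minimal sketch of such a config in the usual InternLM layout (the surrounding keys zero1, sequence_parallel and the pipeline dict are illustrative assumptions, not part of this commit):

parallel = dict(
    zero1=8,                 # assumed sharding-group size for ZeRO/FSDP
    tensor=2,                # tensor parallel size > 1 is now accepted together with FSDP
    pipeline=dict(size=1, interleaved_overlap=True),  # pp must stay 1, or FSDP is auto-closed
    sequence_parallel=False,
    use_fsdp=True,
)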

View File

@@ -50,9 +50,6 @@ class FSDPadaptOptimizer(BaseOptimizer):
         self._clip_grad_norm = zero_cfg.clip_grad_norm
         self.use_fsdp = gpc.config.parallel.use_fsdp
-        # mark whether a module is part of TP or not
-        # TODO: is_tensor_parallel_dict = dict()
         # fp16 and fp32 params
         # fp16 share mem space with model.FlatParam, fp32 share mem space with optim.param_group
         self._fp16_param_groups = dict()
@@ -66,7 +63,7 @@ class FSDPadaptOptimizer(BaseOptimizer):
             self._fp16_param_groups[group_idx] = group_params
             # create copy of fp32 weight
-            fp32_tensor_param = [param.data.float().requires_grad_(True) for param in group_params]
+            fp32_tensor_param = [param.data.float() for param in group_params]
             self._fp32_param_tensor_groups[group_idx] = fp32_tensor_param
             # replace
@@ -81,8 +78,8 @@ class FSDPadaptOptimizer(BaseOptimizer):
         loss.backward(retain_graph=retain_graph)

     def _compute_norm_with_fsdp_flatten(self, group_id):
-        params = self._fp16_param_groups[group_id]
-        gradients = [p.grad for p in params]
+        params = [p for p in self._fp16_param_groups[group_id] if p.storage().size() != 0]
+        gradients = [p.grad for p in params if p.storage().size() != 0]
         norm_group = compute_norm(gradients=gradients, parameters=params, last_stage=True)
         return norm_group
@@ -99,7 +96,7 @@ class FSDPadaptOptimizer(BaseOptimizer):
         for group_idx in range(len(self.param_groups)):
             params = self._fp16_param_groups[group_idx]
             for param in params:
-                if param.requires_grad:
+                if param.requires_grad and param.grad is not None:
                     handle = reduce_tensor(tensor=param.grad, parallel_mode=ParallelMode.ZERO3_DP)
                     handle.wait()
@@ -131,11 +128,12 @@ class FSDPadaptOptimizer(BaseOptimizer):
         # create gradient for fp32 params
         for group_idx in range(len(self.param_groups)):
             dtype = self._fp32_param_tensor_groups[group_idx][0].dtype
-            fp16_params = self._fp16_param_groups[group_idx]
+            fp16_params = [p for p in self._fp16_param_groups[group_idx] if p.storage().size() != 0]
             grad_fp32 = [p.grad.to(dtype) for p in fp16_params]
             device = self._fp32_param_tensor_groups[group_idx][0].device
-            for p, g in zip(self._fp32_param_tensor_groups[group_idx], grad_fp32):
+            nonzero_fp32 = [p for p in self._fp32_param_tensor_groups[group_idx] if p.storage().size() != 0]
+            for p, g in zip(nonzero_fp32, grad_fp32):
                 p.grad = g.to(device)
         # unscale
@@ -145,8 +143,8 @@ class FSDPadaptOptimizer(BaseOptimizer):
         self.zero_grad()
         for group_idx in range(len(self._fp16_param_groups)):
-            fp16_params = self._fp16_param_groups[group_idx]
-            fp32_tensor_params = self._fp32_param_tensor_groups[group_idx]
+            fp16_params = [p for p in self._fp16_param_groups[group_idx] if p.storage().size() != 0]
+            fp32_tensor_params = [p for p in self._fp32_param_tensor_groups[group_idx] if p.storage().size() != 0]
             # release fp32 grad
             release_param_grad(fp32_tensor_params)
             # update fp16 param
@@ -177,9 +175,10 @@ class FSDPadaptOptimizer(BaseOptimizer):
                 if clip > 1.0:
                     combined_scale_groups[group_id] = clip * loss_scale
-        for group_id, grads in self._fp32_param_tensor_groups.items():
-            for g in grads:
-                g.grad.data.mul_(1.0 / combined_scale_groups[group_id])
+        for group_id, param in self._fp32_param_tensor_groups.items():
+            for p in param:
+                if p.storage().size() != 0:
+                    p.grad.data.mul_(1.0 / combined_scale_groups[group_id])

     def state_dict(self):
         states = {}
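The optimizer edits above all follow one pattern: skip parameters whose local storage is empty. A plausible reading, given that the wrapping code later in this commit passes use_orig_params=True to FSDP, is that each rank keeps real storage only for the parameter shards it owns, so the remaining original parameters appear locally as tensors with zero-sized storage and no usable .grad, and have to be filtered out before norm computation, gradient copies and unscaling. A small illustrative helper (the name owned_params is an assumption, not from the diff):

def owned_params(params):
    """Return only the parameters with non-empty storage on this rank.

    Under FSDP with use_orig_params=True, original parameters whose data lives
    in another rank's shard are exposed locally with zero-sized storage, so
    reading their .grad here would be meaningless.
    """
    return [p for p in params if p.storage().size() != 0]

# e.g. computing the norm over locally owned gradients only:
# params = owned_params(self._fp16_param_groups[group_id])
# gradients = [p.grad for p in params]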

View File

@@ -17,7 +17,7 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import (
 from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 from torch.utils.data import ConcatDataset, DataLoader
-from internlm.core.context import ParallelMode
+from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.context.random import set_mode
 from internlm.core.naive_amp import NaiveAMPModel
@@ -36,16 +36,8 @@ from internlm.model.modeling_internlm import (
     PackedFlashBaseLayer1D,
     PackedFlashInternLm1D,
 )
 from internlm.model.multi_head_attention import MHA
-from flash_attn.modules.mha import (
-    CrossAttention,
-    FlashCrossAttention,
-    FlashSelfAttention,
-    SelfAttention,
-    _update_kv_cache,
-)
 from internlm.model.utils import try_import_RMSNorm
 from internlm.monitor import send_heartbeat, set_env_var
 from internlm.monitor.monitor import monitor_manager as mm
 from internlm.solver.beta2_scheduler import Beta2Scheduler
@@ -117,18 +109,23 @@ def initialize_model():
 def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
     from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm
     RMSNorm = try_import_RMSNorm()
     if gpc.config.parallel.use_fsdp:
+        # pre-save info for tensor parallel
+        tp_dict = dict()
+        for name, param in model.named_parameters():
+            if hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL):
+                tp_dict.update({name.replace("model.", ""): True})
+            else:
+                tp_dict.update({name.replace("model.", ""): False})
         # set wrap_policy for fsdp wrap
         transformer_wrap_policy = functools.partial(
-            transformer_auto_wrap_policy, transformer_layer_cls={
-                PackedFlashBaseLayer1D,
-                PackedFlashInternLm1D,
-                MHA,
-                FlashCrossAttention,
-                FlashSelfAttention,
-                RMSNorm}
+            transformer_auto_wrap_policy,
+            transformer_layer_cls={PackedFlashBaseLayer1D, PackedFlashInternLm1D, MHA, RMSNorm},
         )
         # wrap the model
         grp = gpc.get_group(ParallelMode.ZERO1)
         model = FSDP(
             module=model,
@@ -138,8 +135,14 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
             forward_prefetch=True,
             backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
             limit_all_gathers=True,
+            use_orig_params=True,
         )
+        # re-set attribute for fsdp module
+        for (name, param), pre in zip(model.named_parameters(), tp_dict):
+            if pre in name and tp_dict[pre]:
+                setattr(param, IS_TENSOR_PARALLEL, True)
     return model
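The core of wrap_FSDP_model is that use_orig_params=True keeps the original, now shard-backed, parameters reachable through named_parameters(), so per-parameter metadata such as IS_TENSOR_PARALLEL can be recorded before wrapping and re-attached afterwards. A self-contained sketch of that record-then-restore pattern, using a placeholder attribute name and plain suffix matching rather than InternLM's exact helpers:

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

IS_TENSOR_PARALLEL = "is_tensor_parallel"  # placeholder for the constant from internlm.core.context

def wrap_and_retag(model: nn.Module, process_group=None) -> FSDP:
    # 1) remember which original parameters are tensor-parallel before wrapping
    tp_names = {name for name, p in model.named_parameters() if getattr(p, IS_TENSOR_PARALLEL, False)}
    # 2) use_orig_params=True keeps the original parameters (as local shards)
    #    visible through named_parameters() after wrapping
    fsdp_model = FSDP(model, process_group=process_group, use_orig_params=True)
    # 3) FSDP prefixes parameter names (e.g. "_fsdp_wrapped_module."), so match by suffix
    for name, p in fsdp_model.named_parameters():
        if any(name.endswith(orig) for orig in tp_names):
            setattr(p, IS_TENSOR_PARALLEL, True)
    return fsdp_model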

View File

@@ -171,8 +171,8 @@ def get_shard_state_dict(shard_model):
     # with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
     #     states = model.state_dict()
     # in this version, FSDP model can only save with sharded shape
-    with FSDP.state_dict_type(shard_model, StateDictType.LOCAL_STATE_DICT):
+    with FSDP.state_dict_type(shard_model, StateDictType.SHARDED_STATE_DICT):
         shard_states = shard_model.state_dict()

     return shard_states
@@ -184,7 +184,7 @@ def load_shard_state_dict(shard_model, shard_state, **kwargs):
     """
-    with FSDP.state_dict_type(shard_model, StateDictType.LOCAL_STATE_DICT):
+    with FSDP.state_dict_type(shard_model, StateDictType.SHARDED_STATE_DICT):
         missing_k, unexpected_keys = shard_model.load_state_dict(shard_state, **kwargs)

     return (missing_k, unexpected_keys)
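Both checkpoint helpers switch from StateDictType.LOCAL_STATE_DICT to StateDictType.SHARDED_STATE_DICT. Under SHARDED_STATE_DICT each rank's state dict holds ShardedTensor entries for its own shard only, keyed by the original parameter names, which matches the use_orig_params=True wrapping above; LOCAL_STATE_DICT exposes the raw flat parameters and newer torch releases mark it as deprecated. A rough standalone sketch of per-rank save and load under this mode (function names and the per-rank path are assumptions):

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType

def save_sharded(shard_model: FSDP, rank_path: str) -> None:
    # every rank saves a state dict containing only its own shards
    with FSDP.state_dict_type(shard_model, StateDictType.SHARDED_STATE_DICT):
        torch.save(shard_model.state_dict(), rank_path)

def load_sharded(shard_model: FSDP, rank_path: str):
    # must be loaded under the same state-dict type and the same sharding layout
    shard_state = torch.load(rank_path)
    with FSDP.state_dict_type(shard_model, StateDictType.SHARDED_STATE_DICT):
        return shard_model.load_state_dict(shard_state)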

View File

@@ -289,6 +289,8 @@ def main(args):
 if __name__ == "__main__":
+    assert torch.__version__ >= "2.0.1", f"requires torch>=2.0.1 but current version is {torch.__version__}"
     args = parse_args()
     hostname = socket.gethostname()