fix format problem

pull/293/head
zaglc 2023-09-14 17:03:36 +08:00
parent aedd88e5a7
commit 9b1b0c5c20
8 changed files with 48 additions and 46 deletions

View File

@@ -11,9 +11,9 @@ from .process_group_initializer import (
     Initializer_Pipeline,
     Initializer_Tensor,
     Initializer_Zero1,
+    Initializer_Zero3_dp,
     ParallelMode,
     ProcessGroupInitializer,
-    Initializer_Zero3_dp,
 )
 from .random import (
     add_seed,

View File

@@ -478,7 +478,7 @@ class ParallelContext(metaclass=SingletonMeta):
         initializers.append(pgroup_initializer.Initializer_Model(*initializer_args))
         initializers.append(pgroup_initializer.Initializer_Tensor(*initializer_args))
         initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args))
-        if self.config.parallel.use_fsdp:
+        if self.config.parallel.get("use_fsdp", False):
             initializers.append(pgroup_initializer.Initializer_Zero3_dp(*initializer_args))
         initializers.append(pgroup_initializer.Initializer_Nettest(*initializer_args))
         if self.pipeline_parallel_size > 1:
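
The guarded lookup matters because configs written before FSDP support may not define a `use_fsdp` key at all. A minimal sketch of the difference, assuming the parallel config behaves like a dict-style namespace (the `Config` class below is only illustrative):

class Config(dict):
    """Toy stand-in for a dict-like config namespace (illustrative only)."""

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as exc:
            raise AttributeError(name) from exc

parallel = Config(zero1=8, tensor=1)  # no "use_fsdp" entry, as in older configs

# parallel.use_fsdp raises AttributeError on such a config, which is the risk
# of the old attribute-style check; the guarded form degrades to the non-FSDP path:
use_fsdp = parallel.get("use_fsdp", False)  # -> False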

View File

@@ -1,6 +1,10 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
-from .hybrid_zero_optim import HybridZeroOptimizer, FSDPadaptOptimizer, reload_zero_fp32_buff
+from .hybrid_zero_optim import (
+    FSDPadaptOptimizer,
+    HybridZeroOptimizer,
+    reload_zero_fp32_buff,
+)
-__all__ = ["HybridZeroOptimizer", "FSDPadaptOptimizer", "reload_zero_fp32_buff"]
+__all__ = ["HybridZeroOptimizer", "FSDPadaptOptimizer", "reload_zero_fp32_buff"]

View File

@@ -82,13 +82,13 @@ class BaseOptimizer(Optimizer):
 class FSDPadaptOptimizer(BaseOptimizer):
-    '''
+    """
     optimizer for Pytorch FSDP if 'use_fsdp' is True in config file
     reserve some necessary components of hybird-optim:
         grad_scaler;
         grad_clip and unscale;
         state_dict and load_state_dict
-    '''
+    """
     def __init__(
         self,
@@ -146,11 +146,7 @@ class FSDPadaptOptimizer(BaseOptimizer):
     def _compute_norm_with_fsdp_flatten(self, group_id):
         params = self._fp16_param_groups[group_id]
         gradients = [p.grad for p in params]
-        norm_group = compute_norm(
-            gradients=gradients,
-            parameters=params,
-            last_stage=True
-        )
+        norm_group = compute_norm(gradients=gradients, parameters=params, last_stage=True)
         return norm_group
@@ -178,7 +174,6 @@ class FSDPadaptOptimizer(BaseOptimizer):
             norm_group = self._compute_norm_with_fsdp_flatten(group_idx)
             if norm_group == -1:
                 found_inf = True
-                break
             norm_groups[group_name] = norm_group
         loss_scale = float(self.loss_scale.item()) # backup
@@ -187,7 +182,7 @@
             if gpc.is_rank_for_log():
                 logger.warning("Overflow occurs, please check it.")
             self.zero_grad()
-            return False, None
+            return False, norm_groups
         # get the global norm
         global_norm_groups = {}
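
Returning the norm dictionary instead of None keeps the caller's logging path uniform even when an overflow skips the step. A rough sketch of the control flow with simplified names (not the actual InternLM API):

def skip_step_on_overflow(norm_groups, zero_grad):
    # a group norm of -1 signals inf/nan gradients in that group
    found_inf = any(norm == -1 for norm in norm_groups.values())
    if found_inf:
        zero_grad()                  # discard the bad gradients
        return False, norm_groups    # step skipped, norms still available to log
    return True, norm_groups

Presumably this is why the return value was changed from None to norm_groups: a caller that formats the norms for its metrics line no longer needs a special case for overflow steps.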
@@ -211,10 +206,12 @@
         self.optim.step()
         self.zero_grad()
         # update fp16 param
         for group_idx in range(len(self._fp16_param_groups)):
             fp16_params = self._fp16_param_groups[group_idx]
             fp32_tensor_params = self._fp32_param_tensor_groups[group_idx]
+            # release fp32 grad
             release_param_grad(fp32_tensor_params)
+            # update fp16 param
             for p, q in zip(fp16_params, fp32_tensor_params):
                 p.data.copy_(q)
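
The loop above is the usual mixed-precision master-weight pattern: the inner optimizer steps on fp32 copies, the fp32 gradients are released, and the updated values are copied back into the fp16 parameters. A self-contained sketch of that pattern in plain PyTorch (generic names, not the InternLM classes):

import torch

fp16_params = [torch.nn.Parameter(torch.zeros(4, dtype=torch.float16))]
fp32_params = [p.detach().clone().float().requires_grad_() for p in fp16_params]
optim = torch.optim.AdamW(fp32_params, lr=1e-3)

# a training step leaves gradients on the fp32 master copies
fp32_params[0].grad = torch.randn(4)
optim.step()

# release fp32 grads, then push the updated master weights back into fp16
for p in fp32_params:
    p.grad = None
for p16, p32 in zip(fp16_params, fp32_params):
    p16.data.copy_(p32.data)  # implicit downcast fp32 -> fp16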
@@ -272,8 +269,8 @@
         assert set(flat_fp32_weights.keys()) == set(self._fp32_param_tensor_groups)
         for group_idx, param in flat_fp32_weights.items():
             self_param = self._fp32_param_tensor_groups[group_idx]
-            assert (
-                len(self_param) == len(param)
+            assert len(self_param) == len(
+                param
             ), f"The number of flat tensor is inconsistent, {len(self_param)} != {len(param)}"
             for p, q in zip(self_param, param):
                 p.data.copy_(q.data)

View File

@@ -6,7 +6,7 @@ from .training_internlm import (
     initialize_optimizer,
     load_new_batch,
     record_current_batch_training_metrics,
-    warp_FSDP_model,
+    wrap_FSDP_model,
 )
 __all__ = [
@@ -17,5 +17,5 @@ __all__ = [
     "initialize_optimizer",
     "load_new_batch",
     "record_current_batch_training_metrics",
-    "warp_FSDP_model",
+    "wrap_FSDP_model",
 ]

View File

@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
+import functools
 import time
 from functools import partial
 from typing import Callable, Iterable, Union
@@ -8,6 +9,12 @@ from typing import Callable, Iterable, Union
 import torch
 import torch.distributed as dist
 from torch import nn
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp.fully_sharded_data_parallel import (
+    BackwardPrefetch,
+    ShardingStrategy,
+)
+from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 from torch.utils.data import ConcatDataset, DataLoader
 from internlm.core.context import ParallelMode
@@ -25,11 +32,15 @@ from internlm.data.packed_dataset import (
     get_packed_dataset_without_short_length,
 )
 from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
+from internlm.model.modeling_internlm import (
+    PackedFlashBaseLayer1D,
+    PackedFlashInternLm1D,
+)
 from internlm.monitor import send_heartbeat, set_env_var
 from internlm.monitor.monitor import monitor_manager as mm
 from internlm.solver.beta2_scheduler import Beta2Scheduler
 from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
-from internlm.solver.optimizer import HybridZeroOptimizer, FSDPadaptOptimizer
+from internlm.solver.optimizer import FSDPadaptOptimizer, HybridZeroOptimizer
 from internlm.solver.optimizer.utils import ParamBcastSyncHandler
 from internlm.utils.common import DummyProfile
 from internlm.utils.logger import get_logger
@@ -42,17 +53,6 @@ from internlm.utils.parallel import (
 from internlm.utils.registry import MODEL_INITIALIZER
 from internlm.utils.timeout import llm_timeout
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.fsdp.fully_sharded_data_parallel import (
-    ShardingStrategy,
-    MixedPrecision,
-    BackwardPrefetch,
-)
-from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
-import functools
-from internlm.model.modeling_internlm import PackedFlashBaseLayer1D, PackedFlashInternLm1D
 logger = get_logger(__file__)
@@ -103,19 +103,20 @@ def initialize_model():
     return model
-def warp_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
+def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
     if gpc.config.parallel.use_fsdp:
         transformer_wrap_policy = functools.partial(
-            transformer_auto_wrap_policy,
-            transformer_layer_cls={PackedFlashBaseLayer1D, PackedFlashInternLm1D}
+            transformer_auto_wrap_policy, transformer_layer_cls={PackedFlashBaseLayer1D, PackedFlashInternLm1D}
         )
         grp = gpc.get_group(ParallelMode.ZERO1)
-        model = FSDP(module=model,
-                     process_group=grp,
-                     sharding_strategy=ShardingStrategy.FULL_SHARD,
-                     auto_wrap_policy=transformer_wrap_policy,
-                     forward_prefetch=True,
-                     backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
+        model = FSDP(
+            module=model,
+            process_group=grp,
+            sharding_strategy=ShardingStrategy.FULL_SHARD,
+            auto_wrap_policy=transformer_wrap_policy,
+            forward_prefetch=True,
+            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
             limit_all_gathers=True,
         )
     return model
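
For reference, a sketch of what the switch looks like on the config side; the surrounding field layout is an assumption here, only the use_fsdp key is taken from this change:

# illustrative training-config snippet (field layout assumed, not part of this commit)
parallel = dict(
    zero1=8,        # this group also serves as the FSDP process group (ParallelMode.ZERO1 above)
    tensor=1,
    pipeline=dict(size=1),
    use_fsdp=True,  # routes the model through wrap_FSDP_model()
)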

View File

@@ -11,6 +11,9 @@ from enum import Enum
 from typing import Callable, Dict, Union
 import torch
+from torch.distributed.fsdp import FullStateDictConfig
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import StateDictType
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
@@ -162,13 +165,10 @@ def get_state_dict(model):
     'offload_to_cpu' means that the model states are to be offloaded to cpu chunk by chunk, avoiding OOM in gpu
     """
-    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-    from torch.distributed.fsdp import FullStateDictConfig, StateDictType
     # TODO: rank0_only can save memory for non-rank0 gpu, but when tp is enabled, model saving will left some parameters
    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
-    with FSDP.state_dict_type(
-        model, StateDictType.FULL_STATE_DICT, save_policy):
+    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
         states = model.state_dict()
     return states
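
A short sketch of how a checkpoint routine could consume get_state_dict: the gathered dict is a plain CPU state dict, so a single writer is enough (the rank-0 policy and path below are illustrative, not this repo's checkpoint code):

import torch
import torch.distributed as dist

def save_fsdp_checkpoint(model, path="model_ckpt.pt"):
    states = get_state_dict(model)   # full, CPU-offloaded state dict on every rank
    if dist.get_rank() == 0:
        torch.save(states, path)     # behaves like a regular torch.save checkpoint
    dist.barrier()                   # keep ranks in sync before training resumes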

View File

@@ -27,7 +27,7 @@ from internlm.train import (
     initialize_optimizer,
     load_new_batch,
     record_current_batch_training_metrics,
-    warp_FSDP_model,
+    wrap_FSDP_model,
 )
 from internlm.utils.common import (
     BatchSkipper,
@@ -111,8 +111,8 @@ def main(args):
     # initialize and resume train state
     train_state = TrainState(gpc.config, train_dl.batch_sampler)
-    # if fsdp enabled, warp the model
-    model = warp_FSDP_model(model)
+    # if fsdp enabled, warp the model
+    model = wrap_FSDP_model(model)
     optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)
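
One detail worth noting in this call order: the model is wrapped before initialize_optimizer, so the optimizer is built over the parameters FSDP actually registers (its flattened shards) rather than the original module parameters, which FSDP replaces when use_orig_params is off. A condensed sketch of the sequence, assuming the usual initialize_model helper from the same training module:

model = initialize_model()
model = wrap_FSDP_model(model)   # wrap first: FSDP swaps in flattened parameters
optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)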