mirror of https://github.com/InternLM/InternLM

Commit: 9b1b0c5c20 ("fix format problem")
Parent: aedd88e5a7
@@ -11,9 +11,9 @@ from .process_group_initializer import (
     Initializer_Pipeline,
     Initializer_Tensor,
     Initializer_Zero1,
+    Initializer_Zero3_dp,
     ParallelMode,
     ProcessGroupInitializer,
-    Initializer_Zero3_dp,
 )
 from .random import (
     add_seed,
@@ -478,7 +478,7 @@ class ParallelContext(metaclass=SingletonMeta):
         initializers.append(pgroup_initializer.Initializer_Model(*initializer_args))
         initializers.append(pgroup_initializer.Initializer_Tensor(*initializer_args))
         initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args))
-        if self.config.parallel.use_fsdp:
+        if self.config.parallel.get("use_fsdp", False):
             initializers.append(pgroup_initializer.Initializer_Zero3_dp(*initializer_args))
         initializers.append(pgroup_initializer.Initializer_Nettest(*initializer_args))
         if self.pipeline_parallel_size > 1:
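Note: switching from attribute access to .get("use_fsdp", False) keeps older configs working when they never mention the flag. A minimal sketch of the difference, using a plain dict as a stand-in for the dict-like parallel config (not the actual InternLM Config class):

# Hypothetical parallel config that predates the use_fsdp flag.
parallel_cfg = {"zero1": 8, "tensor": 1, "pipeline": 1}

# Key/attribute access would fail when the flag is absent:
#   parallel_cfg["use_fsdp"]  -> KeyError
# .get() falls back to a default instead, so the FSDP group is simply skipped:
use_fsdp = parallel_cfg.get("use_fsdp", False)
print(use_fsdp)  # False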
@@ -1,6 +1,10 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

-from .hybrid_zero_optim import HybridZeroOptimizer, FSDPadaptOptimizer, reload_zero_fp32_buff
+from .hybrid_zero_optim import (
+    FSDPadaptOptimizer,
+    HybridZeroOptimizer,
+    reload_zero_fp32_buff,
+)

 __all__ = ["HybridZeroOptimizer", "FSDPadaptOptimizer", "reload_zero_fp32_buff"]
@@ -82,13 +82,13 @@ class BaseOptimizer(Optimizer):


 class FSDPadaptOptimizer(BaseOptimizer):
-    '''
+    """
     optimizer for Pytorch FSDP if 'use_fsdp' is True in config file
     reserve some necessary components of hybird-optim:
         grad_scaler;
         grad_clip and unscale;
         state_dict and load_state_dict
-    '''
+    """

     def __init__(
         self,
@@ -146,11 +146,7 @@ class FSDPadaptOptimizer(BaseOptimizer):
     def _compute_norm_with_fsdp_flatten(self, group_id):
         params = self._fp16_param_groups[group_id]
         gradients = [p.grad for p in params]
-        norm_group = compute_norm(
-            gradients=gradients,
-            parameters=params,
-            last_stage=True
-        )
+        norm_group = compute_norm(gradients=gradients, parameters=params, last_stage=True)

         return norm_group

@@ -178,7 +174,6 @@ class FSDPadaptOptimizer(BaseOptimizer):
             norm_group = self._compute_norm_with_fsdp_flatten(group_idx)
             if norm_group == -1:
                 found_inf = True
-                break
             norm_groups[group_name] = norm_group

         loss_scale = float(self.loss_scale.item())  # backup
@@ -187,7 +182,7 @@ class FSDPadaptOptimizer(BaseOptimizer):
             if gpc.is_rank_for_log():
                 logger.warning("Overflow occurs, please check it.")
             self.zero_grad()
-            return False, None
+            return False, norm_groups

         # get the global norm
         global_norm_groups = {}
@@ -211,10 +206,12 @@ class FSDPadaptOptimizer(BaseOptimizer):
         self.optim.step()
         self.zero_grad()

         # update fp16 param
         for group_idx in range(len(self._fp16_param_groups)):
             fp16_params = self._fp16_param_groups[group_idx]
             fp32_tensor_params = self._fp32_param_tensor_groups[group_idx]
+            # release fp32 grad
             release_param_grad(fp32_tensor_params)
+            # update fp16 param
             for p, q in zip(fp16_params, fp32_tensor_params):
                 p.data.copy_(q)
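The loop above is the usual mixed-precision master-weights pattern: the optimizer steps on fp32 copies, the fp32 gradients are released, and the updated values are copied back into the fp16 parameters. A minimal standalone sketch of that pattern with plain tensors (not the actual optimizer classes):

import torch

# Hypothetical fp16 "model" parameters and their fp32 master copies.
fp16_params = [torch.randn(4, dtype=torch.float16) for _ in range(2)]
fp32_params = [p.detach().clone().float().requires_grad_() for p in fp16_params]
optim = torch.optim.SGD(fp32_params, lr=0.1)

# Pretend gradients were already produced (normally unscaled fp16 grads cast to fp32).
for p in fp32_params:
    p.grad = torch.ones_like(p)
optim.step()

# Release the fp32 grads, then copy the updated fp32 values back into the fp16 params.
for p in fp32_params:
    p.grad = None
for p16, p32 in zip(fp16_params, fp32_params):
    p16.data.copy_(p32)  # implicit fp32 -> fp16 cast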
@@ -272,8 +269,8 @@ class FSDPadaptOptimizer(BaseOptimizer):
         assert set(flat_fp32_weights.keys()) == set(self._fp32_param_tensor_groups)
         for group_idx, param in flat_fp32_weights.items():
             self_param = self._fp32_param_tensor_groups[group_idx]
-            assert (
-                len(self_param) == len(param)
+            assert len(self_param) == len(
+                param
             ), f"The number of flat tensor is inconsistent, {len(self_param)} != {len(param)}"
             for p, q in zip(self_param, param):
                 p.data.copy_(q.data)
@@ -6,7 +6,7 @@ from .training_internlm import (
     initialize_optimizer,
     load_new_batch,
     record_current_batch_training_metrics,
-    warp_FSDP_model,
+    wrap_FSDP_model,
 )

 __all__ = [
@@ -17,5 +17,5 @@ __all__ = [
     "initialize_optimizer",
     "load_new_batch",
     "record_current_batch_training_metrics",
-    "warp_FSDP_model",
+    "wrap_FSDP_model",
 ]
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

+import functools
 import time
 from functools import partial
 from typing import Callable, Iterable, Union
@@ -8,6 +9,12 @@ from typing import Callable, Iterable, Union
 import torch
 import torch.distributed as dist
 from torch import nn
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp.fully_sharded_data_parallel import (
+    BackwardPrefetch,
+    ShardingStrategy,
+)
+from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 from torch.utils.data import ConcatDataset, DataLoader

 from internlm.core.context import ParallelMode
@@ -25,11 +32,15 @@ from internlm.data.packed_dataset import (
     get_packed_dataset_without_short_length,
 )
 from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
+from internlm.model.modeling_internlm import (
+    PackedFlashBaseLayer1D,
+    PackedFlashInternLm1D,
+)
 from internlm.monitor import send_heartbeat, set_env_var
 from internlm.monitor.monitor import monitor_manager as mm
 from internlm.solver.beta2_scheduler import Beta2Scheduler
 from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
-from internlm.solver.optimizer import HybridZeroOptimizer, FSDPadaptOptimizer
+from internlm.solver.optimizer import FSDPadaptOptimizer, HybridZeroOptimizer
 from internlm.solver.optimizer.utils import ParamBcastSyncHandler
 from internlm.utils.common import DummyProfile
 from internlm.utils.logger import get_logger
@@ -42,17 +53,6 @@ from internlm.utils.parallel import (
 from internlm.utils.registry import MODEL_INITIALIZER
 from internlm.utils.timeout import llm_timeout

-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.fsdp.fully_sharded_data_parallel import (
-    ShardingStrategy,
-    MixedPrecision,
-    BackwardPrefetch,
-)
-from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
-import functools
-from internlm.model.modeling_internlm import PackedFlashBaseLayer1D, PackedFlashInternLm1D
-
-
 logger = get_logger(__file__)

@@ -103,19 +103,20 @@ def initialize_model():
     return model


-def warp_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
+def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
     if gpc.config.parallel.use_fsdp:
         transformer_wrap_policy = functools.partial(
-            transformer_auto_wrap_policy,
-            transformer_layer_cls={PackedFlashBaseLayer1D, PackedFlashInternLm1D}
+            transformer_auto_wrap_policy, transformer_layer_cls={PackedFlashBaseLayer1D, PackedFlashInternLm1D}
         )
         grp = gpc.get_group(ParallelMode.ZERO1)
-        model = FSDP(module=model,
-                     process_group=grp,
-                     sharding_strategy=ShardingStrategy.FULL_SHARD,
-                     auto_wrap_policy=transformer_wrap_policy,
-                     forward_prefetch=True,
-                     backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
+        model = FSDP(
+            module=model,
+            process_group=grp,
+            sharding_strategy=ShardingStrategy.FULL_SHARD,
+            auto_wrap_policy=transformer_wrap_policy,
+            forward_prefetch=True,
+            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
+            limit_all_gathers=True,
         )

     return model
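For reference, the renamed helper keeps the same call pattern in the training setup; a rough sketch of the order, taken from the train.py hunks below (all names come from this diff, surrounding setup elided):

from internlm.train import initialize_model, initialize_optimizer, wrap_FSDP_model

model = initialize_model()
# When gpc.config.parallel enables FSDP, this returns the FSDP-wrapped module;
# otherwise the model is returned unchanged.
model = wrap_FSDP_model(model)
optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)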
@@ -11,6 +11,9 @@ from enum import Enum
 from typing import Callable, Dict, Union

 import torch
+from torch.distributed.fsdp import FullStateDictConfig
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import StateDictType

 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
@@ -162,13 +165,10 @@ def get_state_dict(model):
     'offload_to_cpu' means that the model states are to be offloaded to cpu chunk by chunk, avoiding OOM in gpu

     """
-    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-    from torch.distributed.fsdp import FullStateDictConfig, StateDictType

     # TODO: rank0_only can save memory for non-rank0 gpu, but when tp is enabled, model saving will left some parameters
     save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
-    with FSDP.state_dict_type(
-        model, StateDictType.FULL_STATE_DICT, save_policy):
+    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
         states = model.state_dict()

     return states
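The collapsed with statement above is the standard PyTorch FSDP recipe for gathering a full, unsharded state dict. A hedged sketch of how a helper like get_state_dict might be used to write a checkpoint (the save_checkpoint wrapper and its path argument are hypothetical, not part of this diff):

import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)


def gather_full_state_dict(model: FSDP) -> dict:
    # Gather the full, unsharded state dict; offload_to_cpu avoids holding the
    # whole model on one GPU, rank0_only=False materializes it on every rank.
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
        return model.state_dict()


def save_checkpoint(model: FSDP, path: str) -> None:  # hypothetical helper
    states = gather_full_state_dict(model)
    if dist.get_rank() == 0:
        torch.save(states, path)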
train.py (6 changed lines)
@@ -27,7 +27,7 @@ from internlm.train import (
     initialize_optimizer,
     load_new_batch,
     record_current_batch_training_metrics,
-    warp_FSDP_model,
+    wrap_FSDP_model,
 )
 from internlm.utils.common import (
     BatchSkipper,
@@ -111,8 +111,8 @@ def main(args):
     # initialize and resume train state
     train_state = TrainState(gpc.config, train_dl.batch_sampler)

-    # if fsdp enabled, warp the model
-    model = warp_FSDP_model(model)
+    # if fsdp enabled, wrap the model
+    model = wrap_FSDP_model(model)

     optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)
