mirror of https://github.com/InternLM/InternLM
fix format problem
parent aedd88e5a7
commit 9b1b0c5c20
@@ -11,9 +11,9 @@ from .process_group_initializer import (
     Initializer_Pipeline,
     Initializer_Tensor,
     Initializer_Zero1,
+    Initializer_Zero3_dp,
     ParallelMode,
     ProcessGroupInitializer,
-    Initializer_Zero3_dp,
 )
 from .random import (
     add_seed,

@@ -478,7 +478,7 @@ class ParallelContext(metaclass=SingletonMeta):
         initializers.append(pgroup_initializer.Initializer_Model(*initializer_args))
         initializers.append(pgroup_initializer.Initializer_Tensor(*initializer_args))
         initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args))
-        if self.config.parallel.use_fsdp:
+        if self.config.parallel.get("use_fsdp", False):
             initializers.append(pgroup_initializer.Initializer_Zero3_dp(*initializer_args))
         initializers.append(pgroup_initializer.Initializer_Nettest(*initializer_args))
         if self.pipeline_parallel_size > 1:

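Switching from attribute access to a dict-style lookup with a default keeps this code working on configs that never mention 'use_fsdp'. A minimal sketch of the difference, assuming 'parallel' behaves like a plain dict (the values below are illustrative, not from the patch):

    parallel = {"zero1": 8, "tensor": 1, "pipeline": 1}  # no "use_fsdp" key
    # parallel["use_fsdp"] would raise KeyError; attribute access would raise AttributeError
    if parallel.get("use_fsdp", False):  # falls back to False instead of raising
        print("would append Initializer_Zero3_dp")
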
@@ -1,6 +1,10 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

-from .hybrid_zero_optim import HybridZeroOptimizer, FSDPadaptOptimizer, reload_zero_fp32_buff
+from .hybrid_zero_optim import (
+    FSDPadaptOptimizer,
+    HybridZeroOptimizer,
+    reload_zero_fp32_buff,
+)

 __all__ = ["HybridZeroOptimizer", "FSDPadaptOptimizer", "reload_zero_fp32_buff"]

@@ -82,13 +82,13 @@ class BaseOptimizer(Optimizer):


 class FSDPadaptOptimizer(BaseOptimizer):
-    '''
+    """
     optimizer for Pytorch FSDP if 'use_fsdp' is True in config file
     reserve some necessary components of hybird-optim:
         grad_scaler;
         grad_clip and unscale;
         state_dict and load_state_dict
-    '''
+    """

     def __init__(
         self,

@@ -146,11 +146,7 @@ class FSDPadaptOptimizer(BaseOptimizer):
     def _compute_norm_with_fsdp_flatten(self, group_id):
         params = self._fp16_param_groups[group_id]
         gradients = [p.grad for p in params]
-        norm_group = compute_norm(
-            gradients=gradients,
-            parameters=params,
-            last_stage=True
-        )
+        norm_group = compute_norm(gradients=gradients, parameters=params, last_stage=True)

         return norm_group

@@ -178,7 +174,6 @@ class FSDPadaptOptimizer(BaseOptimizer):
             norm_group = self._compute_norm_with_fsdp_flatten(group_idx)
             if norm_group == -1:
                 found_inf = True
-                break
             norm_groups[group_name] = norm_group

         loss_scale = float(self.loss_scale.item())  # backup

@@ -187,7 +182,7 @@ class FSDPadaptOptimizer(BaseOptimizer):
             if gpc.is_rank_for_log():
                 logger.warning("Overflow occurs, please check it.")
             self.zero_grad()
-            return False, None
+            return False, norm_groups

         # get the global norm
         global_norm_groups = {}

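Two related behavior tweaks sit in the hunks above: with the 'break' removed, the loop records a norm entry for every parameter group even after an overflow is detected, and the overflow early-return now hands back that dict instead of None, so the second return value always has the same shape. A minimal sketch of the calling pattern this enables, assuming the surrounding method is the optimizer's step() and keeps the two-value return shown here (variable names are illustrative):

    success, norm_groups = optimizer.step()
    if not success:
        # overflow path: norm_groups is still a dict, safe to iterate or log
        for name, norm in norm_groups.items():
            print(f"group {name}: grad norm {norm}")
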
@@ -211,10 +206,12 @@ class FSDPadaptOptimizer(BaseOptimizer):
         self.optim.step()
         self.zero_grad()

-        # update fp16 param
         for group_idx in range(len(self._fp16_param_groups)):
             fp16_params = self._fp16_param_groups[group_idx]
             fp32_tensor_params = self._fp32_param_tensor_groups[group_idx]
+            # release fp32 grad
+            release_param_grad(fp32_tensor_params)
+            # update fp16 param
             for p, q in zip(fp16_params, fp32_tensor_params):
                 p.data.copy_(q)

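The added lines free the fp32 master gradients once the inner optimizer step has consumed them, then copy the updated fp32 master weights back into the fp16 training parameters. A minimal, self-contained sketch of that mixed-precision update pattern, independent of the project's helpers (tensor names and the 0.01 learning rate are illustrative):

    import torch

    fp16_param = torch.zeros(4, dtype=torch.float16)
    fp32_master = fp16_param.detach().clone().float().requires_grad_(True)

    fp32_master.grad = torch.full_like(fp32_master, 0.1)  # grads are accumulated on the fp32 master copy
    with torch.no_grad():
        fp32_master -= 0.01 * fp32_master.grad  # optimizer step runs on the master copy
    fp32_master.grad = None  # release fp32 grad, as the patch does via release_param_grad
    fp16_param.data.copy_(fp32_master)  # update fp16 param from the master copy
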
@@ -272,8 +269,8 @@ class FSDPadaptOptimizer(BaseOptimizer):
         assert set(flat_fp32_weights.keys()) == set(self._fp32_param_tensor_groups)
         for group_idx, param in flat_fp32_weights.items():
             self_param = self._fp32_param_tensor_groups[group_idx]
-            assert (
-                len(self_param) == len(param)
+            assert len(self_param) == len(
+                param
             ), f"The number of flat tensor is inconsistent, {len(self_param)} != {len(param)}"
             for p, q in zip(self_param, param):
                 p.data.copy_(q.data)

@@ -6,7 +6,7 @@ from .training_internlm import (
     initialize_optimizer,
     load_new_batch,
     record_current_batch_training_metrics,
-    warp_FSDP_model,
+    wrap_FSDP_model,
 )

 __all__ = [

@@ -17,5 +17,5 @@ __all__ = [
     "initialize_optimizer",
     "load_new_batch",
     "record_current_batch_training_metrics",
-    "warp_FSDP_model",
+    "wrap_FSDP_model",
 ]

@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

+import functools
 import time
 from functools import partial
 from typing import Callable, Iterable, Union

@@ -8,6 +9,12 @@ from typing import Callable, Iterable, Union
 import torch
 import torch.distributed as dist
 from torch import nn
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp.fully_sharded_data_parallel import (
+    BackwardPrefetch,
+    ShardingStrategy,
+)
+from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 from torch.utils.data import ConcatDataset, DataLoader

 from internlm.core.context import ParallelMode

@@ -25,11 +32,15 @@ from internlm.data.packed_dataset import (
     get_packed_dataset_without_short_length,
 )
 from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
+from internlm.model.modeling_internlm import (
+    PackedFlashBaseLayer1D,
+    PackedFlashInternLm1D,
+)
 from internlm.monitor import send_heartbeat, set_env_var
 from internlm.monitor.monitor import monitor_manager as mm
 from internlm.solver.beta2_scheduler import Beta2Scheduler
 from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
-from internlm.solver.optimizer import HybridZeroOptimizer, FSDPadaptOptimizer
+from internlm.solver.optimizer import FSDPadaptOptimizer, HybridZeroOptimizer
 from internlm.solver.optimizer.utils import ParamBcastSyncHandler
 from internlm.utils.common import DummyProfile
 from internlm.utils.logger import get_logger

@@ -42,17 +53,6 @@ from internlm.utils.parallel import (
 from internlm.utils.registry import MODEL_INITIALIZER
 from internlm.utils.timeout import llm_timeout

-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.fsdp.fully_sharded_data_parallel import (
-    ShardingStrategy,
-    MixedPrecision,
-    BackwardPrefetch,
-)
-from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
-import functools
-from internlm.model.modeling_internlm import PackedFlashBaseLayer1D, PackedFlashInternLm1D
-
-
 logger = get_logger(__file__)

@@ -103,19 +103,20 @@ def initialize_model():
     return model


-def warp_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
+def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
     if gpc.config.parallel.use_fsdp:
         transformer_wrap_policy = functools.partial(
-            transformer_auto_wrap_policy,
-            transformer_layer_cls={PackedFlashBaseLayer1D, PackedFlashInternLm1D}
+            transformer_auto_wrap_policy, transformer_layer_cls={PackedFlashBaseLayer1D, PackedFlashInternLm1D}
         )
         grp = gpc.get_group(ParallelMode.ZERO1)
-        model = FSDP(module=model,
-                     process_group=grp,
-                     sharding_strategy=ShardingStrategy.FULL_SHARD,
-                     auto_wrap_policy=transformer_wrap_policy,
-                     forward_prefetch=True,
-                     backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
+        model = FSDP(
+            module=model,
+            process_group=grp,
+            sharding_strategy=ShardingStrategy.FULL_SHARD,
+            auto_wrap_policy=transformer_wrap_policy,
+            forward_prefetch=True,
+            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
+            limit_all_gathers=True,
         )

     return model

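This is the function that actually shards the model when 'use_fsdp' is on: an auto-wrap policy keyed on the two transformer block classes, FULL_SHARD sharding over the ZERO1 process group, prefetch in both directions, and the newly added limit_all_gathers=True to cap all-gather memory spikes. A self-contained sketch of the same wrapping recipe on a toy module, assuming torch.distributed is already initialized; ToyBlock stands in for PackedFlashBaseLayer1D / PackedFlashInternLm1D, and the process_group argument replaces gpc.get_group(ParallelMode.ZERO1):

    import functools

    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, ShardingStrategy
    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


    class ToyBlock(nn.Module):  # stand-in for a transformer layer class
        def __init__(self):
            super().__init__()
            self.ff = nn.Linear(64, 64)


    def wrap(model: nn.Module, process_group=None) -> FSDP:
        # wrap every ToyBlock instance in its own FSDP unit, mirroring the patch
        policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={ToyBlock})
        return FSDP(
            module=model,
            process_group=process_group,
            sharding_strategy=ShardingStrategy.FULL_SHARD,
            auto_wrap_policy=policy,
            forward_prefetch=True,
            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
            limit_all_gathers=True,  # rate-limits all-gathers to bound peak GPU memory
        )
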
@@ -11,6 +11,9 @@ from enum import Enum
 from typing import Callable, Dict, Union

 import torch
+from torch.distributed.fsdp import FullStateDictConfig
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import StateDictType

 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc

@@ -162,13 +165,10 @@ def get_state_dict(model):
     'offload_to_cpu' means that the model states are to be offloaded to cpu chunk by chunk, avoiding OOM in gpu

     """
-    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-    from torch.distributed.fsdp import FullStateDictConfig, StateDictType

     # TODO: rank0_only can save memory for non-rank0 gpu, but when tp is enabled, model saving will left some parameters
     save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
-    with FSDP.state_dict_type(
-        model, StateDictType.FULL_STATE_DICT, save_policy):
+    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
         states = model.state_dict()

     return states

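Hoisting the FSDP imports to module level and collapsing the context manager onto one line leaves the checkpoint logic untouched: a full, unsharded state dict is gathered with CPU offload so it can be materialized without exhausting GPU memory. A brief sketch of that save path, assuming an FSDP-wrapped model and an initialized process group; the save_full_checkpoint helper and its path argument are illustrative:

    import torch
    from torch.distributed.fsdp import FullStateDictConfig, StateDictType
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    def save_full_checkpoint(model: FSDP, path: str) -> None:
        # gather the complete state dict on CPU; rank0_only=False gives every rank a copy
        save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
        with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
            states = model.state_dict()
        torch.save(states, path)
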
train.py (4 changed lines)
@@ -27,7 +27,7 @@ from internlm.train import (
     initialize_optimizer,
     load_new_batch,
     record_current_batch_training_metrics,
-    warp_FSDP_model,
+    wrap_FSDP_model,
 )
 from internlm.utils.common import (
     BatchSkipper,

@@ -112,7 +112,7 @@ def main(args):
     train_state = TrainState(gpc.config, train_dl.batch_sampler)

     # if fsdp enabled, warp the model
-    model = warp_FSDP_model(model)
+    model = wrap_FSDP_model(model)

     optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)