mirror of https://github.com/InternLM/InternLM
fix failure in lint-check
parent 31d2a2916d
commit 1d60f90ed9
@@ -421,7 +421,6 @@ class Initializer_Zero3_dp(ProcessGroupInitializer):
         assert self.world_size % self.data_parallel_size == 0
 
-
     def init_dist_group(self, use_cpu: bool = False):
         """Initialize data parallel groups, and assign local_ranks and groups to each gpu.
 
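Note: the hunk above only removes a stray blank line around init_dist_group, the method that builds the data-parallel groups guarded by the world_size % data_parallel_size assert. As a point of reference, a minimal sketch of that kind of group layout in plain torch.distributed (contiguous rank blocks, hypothetical helper name, process group assumed to be already initialized) could look like:

import torch.distributed as dist


def build_data_parallel_group(world_size: int, data_parallel_size: int):
    """Partition ranks into data-parallel groups and return the group that
    contains the current rank (illustrative contiguous-block layout)."""
    assert world_size % data_parallel_size == 0
    my_rank = dist.get_rank()
    my_group = None
    for i in range(world_size // data_parallel_size):
        ranks = list(range(i * data_parallel_size, (i + 1) * data_parallel_size))
        group = dist.new_group(ranks)  # every rank must create every group
        if my_rank in ranks:
            my_group = group
    return my_group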
@@ -90,11 +90,11 @@ class FSDPadaptOptimizer(BaseOptimizer):
     '''
 
     def __init__(
-            self,
-            optimizer: Optimizer,
-            grad_scal_cfg: Config = None,
-            zero_cfg: Config = None,
-    ):
+        self,
+        optimizer: Optimizer,
+        grad_scal_cfg: Config = None,
+        zero_cfg: Config = None,
+    ):
         super().__init__(optim=optimizer)
 
         # gradient scaler
@@ -113,7 +113,7 @@ class FSDPadaptOptimizer(BaseOptimizer):
         self.use_fsdp = gpc.config.parallel.use_fsdp
 
         # mark whether a module is part of TP or not
-        is_tensor_parallel_dict = dict()
+        # TODO: is_tensor_parallel_dict = dict()
 
         # fp16 and fp32 params
         # fp16 share mem space with model.FlatParam, fp32 share mem space with optim.param_group
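The "# fp16 and fp32 params" comments in the hunk above describe the usual mixed-precision master-weight layout: the optimizer keeps fp32 copies of the fp16 flat parameters and steps on those copies. A generic sketch of that pattern (an illustration of the idea, not the FSDPadaptOptimizer code itself):

import torch


def build_fp32_master_params(fp16_params):
    """Create fp32 master copies that the optimizer actually updates."""
    return [p.detach().clone().float().requires_grad_(True) for p in fp16_params]


def copy_master_to_model(fp16_params, fp32_params):
    """After optimizer.step() on the fp32 copies, write the results back into
    the fp16 parameter storage used by the model."""
    with torch.no_grad():
        for p16, p32 in zip(fp16_params, fp32_params):
            p16.copy_(p32)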
@@ -221,7 +221,6 @@ class FSDPadaptOptimizer(BaseOptimizer):
             global_norm_groups[group_name] = global_norm / loss_scale
         return True, global_norm_groups
 
-
     def clip_grad_norm(self, model, max_norm):
         # will conduct in the step()
         pass
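In the hunk above, global_norm / loss_scale converts a norm measured on loss-scaled gradients back to the true gradient norm, and clip_grad_norm is deliberately a no-op because clipping is folded into step(). A generic unscale-then-clip sketch (illustrative, not this optimizer's implementation):

def unscale_and_clip_(grads, scaled_global_norm: float, loss_scale: float, max_norm: float) -> float:
    """Undo the loss scale on the global norm, then rescale gradients so the
    unscaled global norm does not exceed max_norm. Returns the unscaled norm."""
    global_norm = scaled_global_norm / loss_scale
    clip_coef = min(max_norm / (global_norm + 1e-6), 1.0)
    combined = clip_coef / loss_scale  # one multiply: unscale and clip together
    for g in grads:
        g.mul_(combined)
    return global_norm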
@@ -42,18 +42,11 @@ from internlm.utils.registry import MODEL_INITIALIZER
 
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.fully_sharded_data_parallel import (
-    CPUOffload,
     BackwardPrefetch,
     ShardingStrategy,
     MixedPrecision,
-    BackwardPrefetch,
 )
-from torch.distributed.fsdp.wrap import (
-    size_based_auto_wrap_policy,
-    transformer_auto_wrap_policy,
-    enable_wrap,
-    wrap,
-)
+from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 import functools
 from internlm.model.modeling_internlm import PackedFlashBaseLayer1D, PackedFlashInternLm1D
 
@@ -107,11 +100,8 @@ def warp_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
     if gpc.config.parallel.use_fsdp:
         transformer_wrap_policy = functools.partial(
             transformer_auto_wrap_policy,
-            transformer_layer_cls = {PackedFlashBaseLayer1D, PackedFlashInternLm1D}
+            transformer_layer_cls={PackedFlashBaseLayer1D, PackedFlashInternLm1D}
         )
-        mx = MixedPrecision(
-            param_dtype=gpc.config.model.dtype, reduce_dtype=gpc.config.model.dtype,
-            buffer_dtype=gpc.config.model.dtype, keep_low_precision_grads=True)
         grp = gpc.get_group(ParallelMode.ZERO1)
         model = FSDP(module=model,
                      process_group=grp,
@@ -119,10 +109,7 @@ def warp_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
                      auto_wrap_policy=transformer_wrap_policy,
                      forward_prefetch=True,
                      backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
-                     #cpu_offload=CPUOfload(offload_params=True)
-                     #mixed_precision=mx,
-                     #device_id=torch.cuda.current_device()
-                     )
+        )
 
     return model
 
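The two hunks above touch the FSDP wrapping path in warp_FSDP_model. For reference, a self-contained sketch of the same transformer_auto_wrap_policy + FSDP pattern on a stock module (the layer class and process group here are stand-ins for InternLM's PackedFlashBaseLayer1D/PackedFlashInternLm1D and its ZERO1 group; torch.distributed must already be initialized before wrapping):

import functools

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


def wrap_with_fsdp(model: nn.Module, process_group=None) -> FSDP:
    # Shard every TransformerEncoderLayer as its own FSDP unit (stand-in for
    # the packed InternLM layer classes shown in the diff above).
    wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={nn.TransformerEncoderLayer},
    )
    return FSDP(
        module=model,
        process_group=process_group,  # e.g. the ZERO1 group in the code above
        auto_wrap_policy=wrap_policy,
        forward_prefetch=True,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    )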
@@ -64,7 +64,7 @@ def get_state_dict(model):
 
     """
     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-    from torch.distributed.fsdp import FullStateDictConfig, StateDictType# , FullOptimStateDictConfig
+    from torch.distributed.fsdp import FullStateDictConfig, StateDictType
 
     # TODO: rank0_only can save memory for non-rank0 gpu, but when tp is enabled, model saving will left some parameters
     save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
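The final hunk trims a commented-out FullOptimStateDictConfig reference from an import inside get_state_dict. The surrounding pattern is PyTorch's full-state-dict gathering for FSDP modules; a minimal sketch (assuming model is already FSDP-wrapped, hypothetical helper name):

from torch.distributed.fsdp import FullStateDictConfig, StateDictType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def gather_full_state_dict(model: FSDP):
    """Gather an unsharded state dict. offload_to_cpu=True keeps the gathered
    tensors off the GPU; rank0_only=False returns the dict on every rank."""
    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
        return model.state_dict()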