mirror of https://github.com/InternLM/InternLM

fix failure in lint-check
parent 31d2a2916d
commit 1d60f90ed9

@@ -392,7 +392,7 @@ class Initializer_Nettest(ProcessGroupInitializer):
                 ranks_in_group = ranks

         return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode


 class Initializer_Zero3_dp(ProcessGroupInitializer):
     """A ProcessGroupInitializer for data parallelism.

@@ -421,7 +421,6 @@ class Initializer_Zero3_dp(ProcessGroupInitializer):

         assert self.world_size % self.data_parallel_size == 0

-
     def init_dist_group(self, use_cpu: bool = False):
         """Initialize data parallel groups, and assign local_ranks and groups to each gpu.

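For context on the contract these initializers satisfy, below is a minimal, hypothetical sketch of how a data-parallel init_dist_group() is commonly implemented with torch.distributed; the rank layout, helper name, and the separate gloo group for cpu collectives are assumptions, not the repository's exact code (the ParallelMode value in the real return tuple is omitted here).

import torch.distributed as dist

def init_data_parallel_groups(rank, world_size, data_parallel_size):
    # Every rank must create every group; each rank keeps only the group
    # that contains its own global rank.
    num_groups = world_size // data_parallel_size
    local_rank = group_world_size = process_group = cpu_group = ranks_in_group = None
    for i in range(num_groups):
        ranks = [i + j * num_groups for j in range(data_parallel_size)]
        group = dist.new_group(ranks)
        # keep a gloo-backed twin for cpu-side collectives (assumption)
        group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
        if rank in ranks:
            local_rank = ranks.index(rank)
            group_world_size = len(ranks)
            process_group = group
            cpu_group = group_cpu
            ranks_in_group = ranks
    return local_rank, group_world_size, process_group, cpu_group, ranks_in_group
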
@@ -90,13 +90,13 @@ class FSDPadaptOptimizer(BaseOptimizer):
     '''

     def __init__(
         self,
         optimizer: Optimizer,
         grad_scal_cfg: Config = None,
         zero_cfg: Config = None,
     ):
         super().__init__(optim=optimizer)

         # gradient scaler
         self.grad_scaler = DynamicGradScaler(
             initial_scale=grad_scal_cfg.fp16.initial_scale,

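DynamicGradScaler itself is not part of this diff; as a hedged illustration of what such a scaler does, here is a generic grow/backoff loss-scale sketch (class name and defaults are assumptions, not InternLM's implementation).

class ToyDynamicScaler:
    # Halve the scale whenever an overflow is seen; double it after
    # `growth_interval` consecutive overflow-free steps.
    def __init__(self, initial_scale=2**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=1000):
        self.scale = float(initial_scale)
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0

    def update(self, found_overflow: bool):
        if found_overflow:
            self.scale *= self.backoff_factor
            self._good_steps = 0
        else:
            self._good_steps += 1
            if self._good_steps % self.growth_interval == 0:
                self.scale *= self.growth_factor
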
@@ -113,7 +113,7 @@ class FSDPadaptOptimizer(BaseOptimizer):
         self.use_fsdp = gpc.config.parallel.use_fsdp

         # mark whether a module is part of TP or not
-        is_tensor_parallel_dict = dict()
+        # TODO: is_tensor_parallel_dict = dict()

         # fp16 and fp32 params
         # fp16 share mem space with model.FlatParam, fp32 share mem space with optim.param_group

@@ -150,7 +150,7 @@ class FSDPadaptOptimizer(BaseOptimizer):
             parameters=params,
             last_stage=True
         )

         return norm_group

     def zero_grad(self):

@@ -187,12 +187,12 @@ class FSDPadaptOptimizer(BaseOptimizer):
             logger.warning("Overflow occurs, please check it.")
             self.zero_grad()
             return False, None

         # get the global norm
         global_norm_groups = {}
         if self._clip_grad_norm > 0:
             for group_name, norm in norm_groups.items():
                 global_norm_groups[group_name] = norm**0.5

         # create gradient for fp32 params
         for group_idx in range(len(self.param_groups)):

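The `norm**0.5` above suggests each entry of norm_groups carries an accumulated sum of squared gradient norms. As a hedged sketch (not the repository's norm helper), a global gradient norm over sharded parameters is typically computed like this:

import torch
import torch.distributed as dist

def global_grad_norm(parameters, process_group=None):
    # Sum the squared local gradient norms, all-reduce the sum across the
    # group, then take the square root to obtain the global L2 norm.
    params = [p for p in parameters if p.grad is not None]
    device = params[0].grad.device if params else torch.device("cpu")
    local_sq = torch.zeros(1, device=device)
    for p in params:
        local_sq += p.grad.detach().float().norm(2) ** 2
    dist.all_reduce(local_sq, op=dist.ReduceOp.SUM, group=process_group)
    return local_sq.sqrt().item()
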
@@ -207,7 +207,7 @@ class FSDPadaptOptimizer(BaseOptimizer):
         # unscale
         self._unscale_and_clip_grads(list(global_norm_groups.values()), loss_scale)

         self.optim.step()
         self.zero_grad()

         # update fp16 param

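_unscale_and_clip_grads() is not shown in this diff; the usual technique it names folds loss-scale removal and norm clipping into a single division per gradient. A generic sketch, with argument names assumed:

def unscale_and_clip_(grads, total_norm, loss_scale, max_norm, eps=1e-6):
    # Divide every gradient once by a combined factor that both removes the
    # loss scale and, if needed, clips the global norm to max_norm.
    combined_scale = loss_scale
    if max_norm > 0:
        clip = ((total_norm / loss_scale) + eps) / max_norm
        if clip > 1.0:
            combined_scale = clip * loss_scale
    for g in grads:
        g.div_(combined_scale)
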
@@ -221,7 +221,6 @@ class FSDPadaptOptimizer(BaseOptimizer):
             global_norm_groups[group_name] = global_norm / loss_scale
         return True, global_norm_groups

-
     def clip_grad_norm(self, model, max_norm):
         # will conduct in the step()
         pass

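The two return paths above define the step() contract: (False, None) when an overflow skips the update, (True, global_norm_groups) otherwise. A hypothetical call site (the scheduler argument is purely illustrative) could rely on it like this:

def train_step(optimizer, lr_scheduler):
    # Hypothetical caller relying only on the (success, norm_groups) contract.
    success, grad_norm_groups = optimizer.step()
    if success:
        lr_scheduler.step()  # advance the schedule only when a real update happened
    return success, grad_norm_groups
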
@@ -42,18 +42,11 @@ from internlm.utils.registry import MODEL_INITIALIZER

 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.fully_sharded_data_parallel import (
-    CPUOffload,
-    BackwardPrefetch,
     ShardingStrategy,
     MixedPrecision,
     BackwardPrefetch,
 )
-from torch.distributed.fsdp.wrap import (
-    size_based_auto_wrap_policy,
-    transformer_auto_wrap_policy,
-    enable_wrap,
-    wrap,
-)
+from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 import functools
 from internlm.model.modeling_internlm import PackedFlashBaseLayer1D, PackedFlashInternLm1D

@@ -107,23 +100,17 @@ def warp_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
     if gpc.config.parallel.use_fsdp:
         transformer_wrap_policy = functools.partial(
             transformer_auto_wrap_policy,
-            transformer_layer_cls = {PackedFlashBaseLayer1D, PackedFlashInternLm1D}
+            transformer_layer_cls={PackedFlashBaseLayer1D, PackedFlashInternLm1D}
         )
-        mx = MixedPrecision(
-            param_dtype=gpc.config.model.dtype, reduce_dtype=gpc.config.model.dtype,
-            buffer_dtype=gpc.config.model.dtype, keep_low_precision_grads=True)
         grp = gpc.get_group(ParallelMode.ZERO1)
         model = FSDP(module=model,
                      process_group=grp,
                      sharding_strategy=ShardingStrategy.FULL_SHARD,
                      auto_wrap_policy=transformer_wrap_policy,
                      forward_prefetch=True,
                      backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
-                     #cpu_offload=CPUOfload(offload_params=True)
-                     #mixed_precision=mx,
-                     #device_id=torch.cuda.current_device()
-                     )
+                     )

     return model


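transformer_auto_wrap_policy turns every module whose class is listed in transformer_layer_cls into its own FSDP unit, which is what the hunk above does for PackedFlashBaseLayer1D and PackedFlashInternLm1D. A self-contained toy version, where Block is a placeholder layer and an initialized default process group plus a CUDA device are assumed (e.g. launched with torchrun):

import functools
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

class Block(nn.Module):
    # stand-in for a transformer layer class such as PackedFlashBaseLayer1D
    def __init__(self, dim=64):
        super().__init__()
        self.ff = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.ff(x))

def build_sharded_model(num_layers=4):
    policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
    model = nn.Sequential(*[Block() for _ in range(num_layers)])
    # each Block becomes its own flattened, sharded FSDP unit
    return FSDP(model, auto_wrap_policy=policy, device_id=torch.cuda.current_device())
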
@@ -159,9 +146,9 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
         )
     else:
         optimizer = FSDPadaptOptimizer(
             naive_optimizer,
             grad_scal_cfg=gpc.config.grad_scaler,
             zero_cfg=gpc.config.hybrid_zero_optimizer,
         )

     beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler)

@@ -57,14 +57,14 @@ def get_model_topology(model):

 def get_state_dict(model):
     """
     Only used for FSDP module saving.
     It's a wrapper of model.state_dict() and, with the context of 'FSDP.state_dict_type', the sharded parameter
     (saved as model.flat_param_xx in sharded FSDP module) will be gathered at every gpu.
     'offload_to_cpu' means that the model states are to be offloaded to cpu chunk by chunk, avoiding OOM in gpu

     """
     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-    from torch.distributed.fsdp import FullStateDictConfig, StateDictType# , FullOptimStateDictConfig
+    from torch.distributed.fsdp import FullStateDictConfig, StateDictType

     # TODO: rank0_only can save memory for non-rank0 gpu, but when tp is enabled, model saving will leave some parameters
     save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)

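The docstring above describes PyTorch's standard full-state-dict gathering pattern; a minimal sketch of the context-manager usage that get_state_dict presumably wraps (the function name here is illustrative):

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullStateDictConfig, StateDictType

def gather_full_state_dict(fsdp_model):
    # Inside the state_dict_type context, state_dict() returns unsharded
    # tensors, moved to cpu chunk by chunk thanks to offload_to_cpu=True.
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, cfg):
        return fsdp_model.state_dict()
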
@@ -91,9 +91,9 @@ def save_model_checkpoint(folder, model):

     if gpc.config.parallel.use_fsdp:
         states = get_state_dict(model)
     else:
         states = model.state_dict()

     topo = get_model_topology(model)

     if folder is not None:

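For completeness, a hypothetical continuation of the branch above that writes the gathered states to disk; the rank-0-only policy and the filename are illustrative, not the repository's checkpoint layout:

import os
import torch
import torch.distributed as dist

def save_states(folder, states):
    # write once from global rank 0; other ranks only participated in gathering
    if dist.get_rank() == 0:
        os.makedirs(folder, exist_ok=True)
        torch.save(states, os.path.join(folder, "model.pt"))
    dist.barrier()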