fix failure in lint-check

pull/293/head
zaglc 2023-09-08 13:19:42 +08:00
parent 31d2a2916d
commit 1d60f90ed9
4 changed files with 26 additions and 41 deletions

View File

@ -392,7 +392,7 @@ class Initializer_Nettest(ProcessGroupInitializer):
ranks_in_group = ranks
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
class Initializer_Zero3_dp(ProcessGroupInitializer):
"""A ProcessGroupInitializer for data parallelism.
@ -421,7 +421,6 @@ class Initializer_Zero3_dp(ProcessGroupInitializer):
assert self.world_size % self.data_parallel_size == 0
def init_dist_group(self, use_cpu: bool = False):
"""Initialize data parallel groups, and assign local_ranks and groups to each gpu.

View File

@ -90,13 +90,13 @@ class FSDPadaptOptimizer(BaseOptimizer):
'''
def __init__(
self,
optimizer: Optimizer,
grad_scal_cfg: Config = None,
zero_cfg: Config = None,
):
self,
optimizer: Optimizer,
grad_scal_cfg: Config = None,
zero_cfg: Config = None,
):
super().__init__(optim=optimizer)
# gradient scaler
self.grad_scaler = DynamicGradScaler(
initial_scale=grad_scal_cfg.fp16.initial_scale,
@ -113,7 +113,7 @@ class FSDPadaptOptimizer(BaseOptimizer):
self.use_fsdp = gpc.config.parallel.use_fsdp
# mark whether a module is part of TP or not
is_tensor_parallel_dict = dict()
# TODO: is_tensor_parallel_dict = dict()
# fp16 and fp32 params
# fp16 share mem space with model.FlatParam, fp32 share mem space with optim.param_group
@ -150,7 +150,7 @@ class FSDPadaptOptimizer(BaseOptimizer):
parameters=params,
last_stage=True
)
return norm_group
def zero_grad(self):
@ -187,12 +187,12 @@ class FSDPadaptOptimizer(BaseOptimizer):
logger.warning("Overflow occurs, please check it.")
self.zero_grad()
return False, None
# get the global norm
global_norm_groups = {}
if self._clip_grad_norm > 0:
for group_name, norm in norm_groups.items():
global_norm_groups[group_name] = norm**0.5
global_norm_groups[group_name] = norm**0.5
# create gradient for fp32 params
for group_idx in range(len(self.param_groups)):
@ -207,7 +207,7 @@ class FSDPadaptOptimizer(BaseOptimizer):
# unscale
self._unscale_and_clip_grads(list(global_norm_groups.values()), loss_scale)
self.optim.step()
self.optim.step()
self.zero_grad()
# update fp16 param
@ -221,7 +221,6 @@ class FSDPadaptOptimizer(BaseOptimizer):
global_norm_groups[group_name] = global_norm / loss_scale
return True, global_norm_groups
def clip_grad_norm(self, model, max_norm):
# will conduct in the step()
pass

View File

@ -42,18 +42,11 @@ from internlm.utils.registry import MODEL_INITIALIZER
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
CPUOffload,
BackwardPrefetch,
ShardingStrategy,
MixedPrecision,
BackwardPrefetch,
)
from torch.distributed.fsdp.wrap import (
size_based_auto_wrap_policy,
transformer_auto_wrap_policy,
enable_wrap,
wrap,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
import functools
from internlm.model.modeling_internlm import PackedFlashBaseLayer1D, PackedFlashInternLm1D
@ -107,23 +100,17 @@ def warp_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
if gpc.config.parallel.use_fsdp:
transformer_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls = {PackedFlashBaseLayer1D, PackedFlashInternLm1D}
transformer_layer_cls={PackedFlashBaseLayer1D, PackedFlashInternLm1D}
)
mx = MixedPrecision(
param_dtype=gpc.config.model.dtype, reduce_dtype=gpc.config.model.dtype,
buffer_dtype=gpc.config.model.dtype, keep_low_precision_grads=True)
grp = gpc.get_group(ParallelMode.ZERO1)
model = FSDP(module=model,
model = FSDP(module=model,
process_group=grp,
sharding_strategy=ShardingStrategy.FULL_SHARD,
auto_wrap_policy=transformer_wrap_policy,
forward_prefetch=True,
backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
#cpu_offload=CPUOfload(offload_params=True)
#mixed_precision=mx,
#device_id=torch.cuda.current_device()
)
)
return model
@ -159,9 +146,9 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
)
else:
optimizer = FSDPadaptOptimizer(
naive_optimizer,
grad_scal_cfg=gpc.config.grad_scaler,
zero_cfg=gpc.config.hybrid_zero_optimizer,
naive_optimizer,
grad_scal_cfg=gpc.config.grad_scaler,
zero_cfg=gpc.config.hybrid_zero_optimizer,
)
beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler)

View File

@ -57,14 +57,14 @@ def get_model_topology(model):
def get_state_dict(model):
"""
Only used for FSDP module saving.
It's a warper of model.state_dict() and with the context of 'FSDP.state_dict_type', the sharded parameter
Only used for FSDP module saving.
It's a warper of model.state_dict() and with the context of 'FSDP.state_dict_type', the sharded parameter
(saved as model.flat_param_xx in sharded FSDP module) will be gathered at every gpu.
'offload_to_cpu' means that the model states are to be offloaded to cpu chunk by chunk, avoiding OOM in gpu
"""
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullStateDictConfig, StateDictType# , FullOptimStateDictConfig
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
# TODO: rank0_only can save memory for non-rank0 gpu, but when tp is enabled, model saving will left some parameters
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
@ -91,9 +91,9 @@ def save_model_checkpoint(folder, model):
if gpc.config.parallel.use_fsdp:
states = get_state_dict(model)
else:
else:
states = model.state_dict()
topo = get_model_topology(model)
if folder is not None: