fix failure in lint-check

pull/293/head
zaglc 2023-09-08 13:19:42 +08:00
parent 31d2a2916d
commit 1d60f90ed9
4 changed files with 26 additions and 41 deletions

View File

@@ -392,7 +392,7 @@ class Initializer_Nettest(ProcessGroupInitializer):
                ranks_in_group = ranks

        return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode


class Initializer_Zero3_dp(ProcessGroupInitializer):
    """A ProcessGroupInitializer for data parallelism.
@@ -421,7 +421,6 @@ class Initializer_Zero3_dp(ProcessGroupInitializer):
        assert self.world_size % self.data_parallel_size == 0

    def init_dist_group(self, use_cpu: bool = False):
        """Initialize data parallel groups, and assign local_ranks and groups to each gpu.

View File

@@ -90,13 +90,13 @@ class FSDPadaptOptimizer(BaseOptimizer):
    '''

    def __init__(
        self,
        optimizer: Optimizer,
        grad_scal_cfg: Config = None,
        zero_cfg: Config = None,
    ):
        super().__init__(optim=optimizer)

        # gradient scaler
        self.grad_scaler = DynamicGradScaler(
            initial_scale=grad_scal_cfg.fp16.initial_scale,
@@ -113,7 +113,7 @@ class FSDPadaptOptimizer(BaseOptimizer):
        self.use_fsdp = gpc.config.parallel.use_fsdp

        # mark whether a module is part of TP or not
-        is_tensor_parallel_dict = dict()
+        # TODO: is_tensor_parallel_dict = dict()

        # fp16 and fp32 params
        # fp16 share mem space with model.FlatParam, fp32 share mem space with optim.param_group
@@ -150,7 +150,7 @@ class FSDPadaptOptimizer(BaseOptimizer):
            parameters=params,
            last_stage=True
        )

        return norm_group

    def zero_grad(self):
@@ -187,12 +187,12 @@ class FSDPadaptOptimizer(BaseOptimizer):
            logger.warning("Overflow occurs, please check it.")
            self.zero_grad()
            return False, None

        # get the global norm
        global_norm_groups = {}
        if self._clip_grad_norm > 0:
            for group_name, norm in norm_groups.items():
                global_norm_groups[group_name] = norm**0.5

        # create gradient for fp32 params
        for group_idx in range(len(self.param_groups)):
@@ -207,7 +207,7 @@ class FSDPadaptOptimizer(BaseOptimizer):
        # unscale
        self._unscale_and_clip_grads(list(global_norm_groups.values()), loss_scale)

        self.optim.step()
        self.zero_grad()

        # update fp16 param
@@ -221,7 +221,6 @@ class FSDPadaptOptimizer(BaseOptimizer):
            global_norm_groups[group_name] = global_norm / loss_scale

        return True, global_norm_groups

    def clip_grad_norm(self, model, max_norm):
        # will conduct in the step()
        pass
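The step() logic shown in these hunks (take the square root of the accumulated norm per group, then pass the group norms and the loss scale to _unscale_and_clip_grads before optim.step()) follows the usual mixed-precision pattern: fold the loss scale and the clipping coefficient into a single division applied to every gradient. A sketch of that pattern, assuming the helper behaves like other apex-style optimizers; names and the epsilon are illustrative, not the actual _unscale_and_clip_grads body:

def unscale_and_clip_grads(grads, total_norm, loss_scale, clip_grad_norm):
    # Illustrative sketch of the standard unscale-and-clip step.
    combined_scale = loss_scale
    if clip_grad_norm > 0.0:
        # total_norm was computed on scaled grads, so undo the scale before comparing
        clip = ((total_norm / loss_scale) + 1e-6) / clip_grad_norm
        if clip > 1.0:
            combined_scale = clip * loss_scale  # dividing by this both unscales and clips
    for grad in grads:
        grad.data.mul_(1.0 / combined_scale)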

View File

@@ -42,18 +42,11 @@ from internlm.utils.registry import MODEL_INITIALIZER
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
-    CPUOffload,
-    BackwardPrefetch,
    ShardingStrategy,
    MixedPrecision,
    BackwardPrefetch,
)
-from torch.distributed.fsdp.wrap import (
-    size_based_auto_wrap_policy,
-    transformer_auto_wrap_policy,
-    enable_wrap,
-    wrap,
-)
+from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

import functools
from internlm.model.modeling_internlm import PackedFlashBaseLayer1D, PackedFlashInternLm1D
@@ -107,23 +100,17 @@ def warp_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
    if gpc.config.parallel.use_fsdp:
        transformer_wrap_policy = functools.partial(
            transformer_auto_wrap_policy,
-            transformer_layer_cls = {PackedFlashBaseLayer1D, PackedFlashInternLm1D}
+            transformer_layer_cls={PackedFlashBaseLayer1D, PackedFlashInternLm1D}
        )

-        mx = MixedPrecision(
-            param_dtype=gpc.config.model.dtype, reduce_dtype=gpc.config.model.dtype,
-            buffer_dtype=gpc.config.model.dtype, keep_low_precision_grads=True)

        grp = gpc.get_group(ParallelMode.ZERO1)
        model = FSDP(module=model,
                     process_group=grp,
                     sharding_strategy=ShardingStrategy.FULL_SHARD,
                     auto_wrap_policy=transformer_wrap_policy,
                     forward_prefetch=True,
                     backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
-                     #cpu_offload=CPUOfload(offload_params=True)
-                     #mixed_precision=mx,
-                     #device_id=torch.cuda.current_device()
-                     )
+                     )

    return model
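For readers unfamiliar with the wrapping call kept above: transformer_auto_wrap_policy makes every module whose class appears in transformer_layer_cls its own FSDP unit, so parameters are sharded and re-gathered per transformer block instead of for the whole model at once. A self-contained sketch of the same pattern on a toy model; torch.distributed is assumed to be initialized, and nn.TransformerEncoderLayer stands in for PackedFlashBaseLayer1D/PackedFlashInternLm1D:

import functools
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy, BackwardPrefetch
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

def wrap_toy_model(model: nn.Module, process_group=None):
    # Each nn.TransformerEncoderLayer becomes its own FSDP unit; FULL_SHARD shards
    # params, grads and optimizer state across the given process group.
    policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={nn.TransformerEncoderLayer},
    )
    return FSDP(
        module=model,
        process_group=process_group,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        auto_wrap_policy=policy,
        forward_prefetch=True,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    )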
@@ -159,9 +146,9 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
        )
    else:
        optimizer = FSDPadaptOptimizer(
            naive_optimizer,
            grad_scal_cfg=gpc.config.grad_scaler,
            zero_cfg=gpc.config.hybrid_zero_optimizer,
        )

    beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler)

View File

@@ -57,14 +57,14 @@ def get_model_topology(model):
def get_state_dict(model):
    """
    Only used for FSDP module saving.
    It's a warper of model.state_dict() and with the context of 'FSDP.state_dict_type', the sharded parameter
    (saved as model.flat_param_xx in sharded FSDP module) will be gathered at every gpu.
    'offload_to_cpu' means that the model states are to be offloaded to cpu chunk by chunk, avoiding OOM in gpu
    """
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-    from torch.distributed.fsdp import FullStateDictConfig, StateDictType# , FullOptimStateDictConfig
+    from torch.distributed.fsdp import FullStateDictConfig, StateDictType

    # TODO: rank0_only can save memory for non-rank0 gpu, but when tp is enabled, model saving will left some parameters
    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
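The save_policy built above is only half of the pattern the docstring describes; "the context of 'FSDP.state_dict_type'" refers to the standard usage below, under which model.state_dict() returns full, CPU-offloaded tensors instead of the sharded flat_param_xx entries. A minimal sketch, assuming the rest of get_state_dict (not shown in this hunk) follows the usual form:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullStateDictConfig, StateDictType

def gather_full_state_dict(model):
    # Inside this context, state_dict() gathers the sharded parameters on every rank
    # (rank0_only=False) and moves them to CPU chunk by chunk (offload_to_cpu=True).
    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
        return model.state_dict()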
@@ -91,9 +91,9 @@ def save_model_checkpoint(folder, model):
    if gpc.config.parallel.use_fsdp:
        states = get_state_dict(model)
    else:
        states = model.state_dict()

    topo = get_model_topology(model)

    if folder is not None: