diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py
index bdabbbf..f0a358a 100644
--- a/internlm/core/context/parallel_context.py
+++ b/internlm/core/context/parallel_context.py
@@ -328,6 +328,9 @@ class ParallelContext(metaclass=SingletonMeta):
                 return False
         return self.is_last_rank(ParallelMode.PIPELINE)
 
+    def is_no_pp_or_last_stage(self):
+        return not self.is_initialized(ParallelMode.PIPELINE) or self.is_pipeline_last_stage()
+
     def get_world_size(self, parallel_mode: ParallelMode):
         """Returns the world size for `parallel_mode`.
 
diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index 0c00bfd..e7d8932 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -16,7 +16,7 @@ from internlm.utils.common import get_master_node
 from internlm.utils.logger import get_logger
 from internlm.utils.timeout import llm_timeout
 
-# check pacakge
+# check package
 try:
     import numa
     from numa import memory, schedule
@@ -78,10 +78,16 @@ def args_sanity_check():
     else:
         pp = gpc.config.parallel.pipeline.size
 
+    # check fsdp config
     if "use_fsdp" not in gpc.config.parallel:
         gpc.config.parallel._add_item("use_fsdp", False)
-
-    assert not (gpc.config.parallel.use_fsdp and pp > 1), "FSDP not support when pipeline size > 1, please set pipeline size to 1 or close FSDP"
+    assert not (
+        gpc.config.parallel.use_fsdp and pp > 1
+    ), "FSDP not support when pipeline size > 1, please set pipeline size to 1 or close FSDP"
+    if gpc.config.parallel.use_fsdp:
+        assert (
+            torch.__version__ >= "2.0.1"
+        ), f"requires torch>=2.0.1 when using fsdp but current version is {torch.__version__}"
 
     # processing the data config in gpc
     data = gpc.config.data
diff --git a/internlm/model/metrics.py b/internlm/model/metrics.py
index 24ce592..3a77f8b 100644
--- a/internlm/model/metrics.py
+++ b/internlm/model/metrics.py
@@ -6,7 +6,6 @@ from torch_scatter import scatter
 
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.utils.parallel import is_no_pp_or_last_stage
 
 
 class AccPerplex:
@@ -138,7 +137,7 @@ class AccPerplex:
         self.total_log_probs += total_log_probs
 
     def get_metric(self, reset=True):
-        if is_no_pp_or_last_stage() and self.dp_pg is not None:
+        if gpc.is_no_pp_or_last_stage() and self.dp_pg is not None:
             torch.distributed.all_reduce(self.right, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
             torch.distributed.all_reduce(self.total, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
             torch.distributed.all_reduce(self.total_log_probs, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
@@ -236,7 +235,7 @@ class LossWithTypeId:
             self.ds_token_num += token_num_type
 
     def get_metric(self, reset=True):
-        if is_no_pp_or_last_stage() and self.dp_pg is not None:
+        if gpc.is_no_pp_or_last_stage() and self.dp_pg is not None:
             torch.distributed.all_reduce(self.loss, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
             torch.distributed.all_reduce(self.token_num, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
             if hasattr(self, "total_type_count"):
diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index 444ce9a..b2f42e0 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -48,11 +48,7 @@ from internlm.train.utils import create_param_groups
 from internlm.utils.common import DummyProfile
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
-from internlm.utils.parallel import (
-    is_no_pp_or_last_stage,
-    sync_model_param,
-    sync_model_param_within_tp,
-)
+from internlm.utils.parallel import sync_model_param, sync_model_param_within_tp
 from internlm.utils.registry import MODEL_INITIALIZER
 from internlm.utils.timeout import llm_timeout
 
@@ -85,7 +81,7 @@ def initialize_model():
     else:
         model = NaiveAMPModel(
             model=model,
-            output_to_fp32=is_no_pp_or_last_stage(),
+            output_to_fp32=gpc.is_no_pp_or_last_stage(),
             dtype=gpc.config.model.get("dtype", torch.half),
             sync_buffer=False,
         )
@@ -113,12 +109,13 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
     RMSNorm = try_import_RMSNorm()
     if gpc.config.parallel.use_fsdp:
         # pre-save info for tensor parallel
-        tp_dict = dict()
-        for name, param in model.named_parameters():
-            if hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL):
-                tp_dict.update({name.replace("model.", ""): True})
-            else:
-                tp_dict.update({name.replace("model.", ""): False})
+        if gpc.get_world_size(ParallelMode.TENSOR) > 1:
+            tp_dict = dict()
+            for name, param in model.named_parameters():
+                if hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL):
+                    tp_dict.update({name.replace("model.", ""): True})
+                else:
+                    tp_dict.update({name.replace("model.", ""): False})
 
         # set wrap_policy for fsdp wrap
         transformer_wrap_policy = functools.partial(
@@ -136,13 +133,14 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
             forward_prefetch=True,
             backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
             limit_all_gathers=True,
-            use_orig_params=True,
+            use_orig_params=False,
         )
 
-        # re-set attribute for fsdp module
-        for (name, param), pre in zip(model.named_parameters(), tp_dict):
-            if pre in name and tp_dict[pre]:
-                setattr(param, IS_TENSOR_PARALLEL, True)
+        # re-set attribute for fsdp module with tensor parallel
+        if gpc.get_world_size(ParallelMode.TENSOR) > 1:
+            for (name, param), pre in zip(model.named_parameters(), tp_dict):
+                if pre in name and tp_dict[pre]:
+                    setattr(param, IS_TENSOR_PARALLEL, True)
 
     return model
 
@@ -421,7 +419,7 @@ def record_current_batch_training_metrics(
     timer.store_last_timers()
     if success_update in (0, True):
         train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA)
-    if is_no_pp_or_last_stage():
+    if gpc.is_no_pp_or_last_stage():
         acc_perplex = metric.get_metric()
 
     if success_update and gpc.is_rank_for_log():
diff --git a/internlm/utils/parallel.py b/internlm/utils/parallel.py
index 3a10227..69e1170 100644
--- a/internlm/utils/parallel.py
+++ b/internlm/utils/parallel.py
@@ -51,10 +51,6 @@ def sync_model_param_within_tp(model):
             dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode))
 
 
-def is_no_pp_or_last_stage():
-    return not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE)
-
-
 def get_parallel_log_file_name():
     if gpc.is_rank_for_log():
         fn_prefix = "main_"  # Indicates a rank with more output information
diff --git a/train.py b/train.py
index 71ce548..139bac1 100644
--- a/train.py
+++ b/train.py
@@ -307,9 +307,6 @@ if __name__ == "__main__":
     # initialize distributed environment
     initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
     assert hasattr(gpc, "config") and gpc.config is not None
-    if gpc.config.parallel.use_fsdp:
-        assert torch.__version__ >= "2.0.1", f"requires torch>=2.0.1 but current version is {torch.__version__}"
-
     # initialize monitor manager context
     with initialize_monitor_manager(