diff --git a/configs/7B_sft.py b/configs/7B_sft.py
index b757ae6..0061685 100644
--- a/configs/7B_sft.py
+++ b/configs/7B_sft.py
@@ -98,6 +98,7 @@ beta2_scheduler = dict(
 model = dict(
     checkpoint=False,
+    checkpoint_fraction=0,
     num_attention_heads=NUM_ATTENTION_HEAD,
     embed_split_hidden=True,
     vocab_size=VOCAB_SIZE,
diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index d6eda24..918456b 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -138,16 +138,34 @@ def args_sanity_check():
         logger.info(f"cudnn.deterministic: {torch.backends.cudnn.deterministic }")
         logger.info(f"clip_grad_norm: {clip_grad_norm}")
 
-    if "dtype" not in gpc.config.model:
+    model = gpc.config.model
+    if "dtype" not in model:
         logger.warning("dtype is not set, use torch.float16 by defalut!")
-        gpc.config.model._add_item("dtype", torch.float16)
+        model._add_item("dtype", torch.float16)
     else:
-        if gpc.config.model.dtype == "torch.bfloat16":
-            gpc.config.model.dtype = torch.bfloat16
-        elif gpc.config.model.dtype in ("torch.float16", "torch.half"):
-            gpc.config.model.dtype = torch.float16
+        if model.dtype == "torch.bfloat16":
+            model.dtype = torch.bfloat16
+        elif model.dtype in ("torch.float16", "torch.half"):
+            model.dtype = torch.float16
         else:
-            assert gpc.config.model.dtype in ["torch.float16", "torch.half", "torch.bfloat16"]
+            assert model.dtype in ["torch.float16", "torch.half", "torch.bfloat16"]
+
+    if "checkpoint_fraction" in model:
+        if model.checkpoint_fraction <= 0:
+            model._add_item("checkpoint", False)
+        elif model.checkpoint_fraction <= 1:
+            model._add_item("checkpoint", True)
+        else:
+            raise RuntimeError("checkpoint_fraction must be in [0, 1]")
+    else:
+        if "checkpoint" in model:
+            if model.checkpoint is True:
+                model._add_item("checkpoint_fraction", 1)
+            else:
+                model._add_item("checkpoint_fraction", 0)
+        else:
+            model._add_item("checkpoint", False)
+            model._add_item("checkpoint_fraction", 0)
 
     if gpc.is_rank_for_log():
         logger.info("+" * 15 + " Model Info " + "+" * 15)  # pylint: disable=W1201
diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py
index e8cc2a9..d6c0806 100644
--- a/internlm/model/modeling_internlm.py
+++ b/internlm/model/modeling_internlm.py
@@ -258,7 +258,7 @@ class PackedFlashInternLm1D(nn.Module):
         drop_rate: float = 0.0,
         dtype: torch.dtype = torch.float,
         checkpoint: bool = False,
-        checkpoint_fraction: float = 1.0,
+        checkpoint_fraction: float = 0,
         layer_norm_epsilon: float = 1e-5,
         first: bool = False,
         last: bool = False,
@@ -276,11 +276,9 @@ class PackedFlashInternLm1D(nn.Module):
     ):
         super().__init__()
 
-        if checkpoint_fraction <= 0:
-            checkpoint = False
-        if not checkpoint:
-            checkpoint_fraction = 0
-        checkpoint_layer_num = num_layers * checkpoint_fraction
+        checkpoint_layer_num = num_layers * checkpoint_fraction if checkpoint else 0
+        print(f"checkpoint_layer_num: {checkpoint_layer_num}", flush=True)
+
         if is_reward:
             head_cls = RewardModelLinear
         else:
@@ -408,11 +406,6 @@ def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"),
     models = []
 
-    if kwargs["checkpoint"] is True:
-        kwargs["checkpoint_fraction"] = 1.0
-    else:
-        kwargs["checkpoint_fraction"] = 0
-
     for start, end in parts:
         kwargs["num_layers"] = end - start
         kwargs["first"] = start == 0
@@ -436,6 +429,7 @@ def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"),
 def build_model_with_cfg(
     num_chunks=1,
     checkpoint=False,
+    checkpoint_fraction=0,
     dtype=torch.float,
     embed_split_hidden=False,
     num_layers=48,
@@ -491,6 +485,7 @@ def build_model_with_cfg(
         hidden_size=hidden_size,
         num_attention_heads=num_attention_heads,
         checkpoint=checkpoint,
+        checkpoint_fraction=checkpoint_fraction,
         dtype=dtype,
         embed_split_hidden=embed_split_hidden,
         vocab_size=vocab_size,
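
For reference, here is a minimal standalone sketch (not part of the patch; the helper names resolve_checkpoint_config and num_checkpointed_layers, and the plain-dict config, are hypothetical) of how the new sanity check reconciles checkpoint and checkpoint_fraction, and how the layer count is then derived in PackedFlashInternLm1D.__init__:

# Hypothetical sketch of the resolution rules added in args_sanity_check();
# a plain dict stands in for gpc.config.model.
def resolve_checkpoint_config(model: dict):
    if "checkpoint_fraction" in model:
        fraction = model["checkpoint_fraction"]
        if fraction <= 0:
            model["checkpoint"] = False
        elif fraction <= 1:
            model["checkpoint"] = True
        else:
            raise RuntimeError("checkpoint_fraction must be in [0, 1]")
    elif "checkpoint" in model:
        # Only the boolean flag was given: True maps to full checkpointing,
        # False to none.
        model["checkpoint_fraction"] = 1 if model["checkpoint"] else 0
    else:
        # Neither key set: activation checkpointing stays disabled.
        model["checkpoint"] = False
        model["checkpoint_fraction"] = 0
    return model["checkpoint"], model["checkpoint_fraction"]


def num_checkpointed_layers(num_layers: int, checkpoint: bool, fraction: float):
    # Mirrors the simplified expression in PackedFlashInternLm1D.__init__().
    return num_layers * fraction if checkpoint else 0


# Example: a 32-layer model with checkpoint_fraction=0.5 checkpoints 16 layers.
ckpt, frac = resolve_checkpoint_config({"checkpoint_fraction": 0.5})
assert num_checkpointed_layers(32, ckpt, frac) == 16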