mirror of https://github.com/InternLM/InternLM
feat(config): add checkpoint_fraction into config
parent 81b10e81d9
commit caa13f0dae
@@ -98,6 +98,7 @@ beta2_scheduler = dict(
 model = dict(
     checkpoint=False,
+    checkpoint_fraction=0,
     num_attention_heads=NUM_ATTENTION_HEAD,
     embed_split_hidden=True,
     vocab_size=VOCAB_SIZE,
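Editor's note: a minimal sketch of how the new key might be set in a training config. The numeric values below (and the 0.5 fraction) are illustrative assumptions, not taken from this commit; only the checkpoint / checkpoint_fraction fields are the subject of the change.

    # Illustrative config sketch (values are assumptions, not from the commit).
    VOCAB_SIZE = 103168
    NUM_ATTENTION_HEAD = 32

    model = dict(
        checkpoint=True,           # enable activation checkpointing
        checkpoint_fraction=0.5,   # checkpoint roughly half of the transformer layers
        num_attention_heads=NUM_ATTENTION_HEAD,
        embed_split_hidden=True,
        vocab_size=VOCAB_SIZE,
    )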
@@ -138,16 +138,34 @@ def args_sanity_check():
     logger.info(f"cudnn.deterministic: {torch.backends.cudnn.deterministic}")
     logger.info(f"clip_grad_norm: {clip_grad_norm}")

-    if "dtype" not in gpc.config.model:
+    model = gpc.config.model
+    if "dtype" not in model:
         logger.warning("dtype is not set, use torch.float16 by default!")
-        gpc.config.model._add_item("dtype", torch.float16)
+        model._add_item("dtype", torch.float16)
     else:
-        if gpc.config.model.dtype == "torch.bfloat16":
-            gpc.config.model.dtype = torch.bfloat16
-        elif gpc.config.model.dtype in ("torch.float16", "torch.half"):
-            gpc.config.model.dtype = torch.float16
+        if model.dtype == "torch.bfloat16":
+            model.dtype = torch.bfloat16
+        elif model.dtype in ("torch.float16", "torch.half"):
+            model.dtype = torch.float16
         else:
-            assert gpc.config.model.dtype in ["torch.float16", "torch.half", "torch.bfloat16"]
+            assert model.dtype in ["torch.float16", "torch.half", "torch.bfloat16"]

+    if "checkpoint_fraction" in model:
+        if model.checkpoint_fraction <= 0:
+            model._add_item("checkpoint", False)
+        elif model.checkpoint_fraction <= 1:
+            model._add_item("checkpoint", True)
+        else:
+            raise RuntimeError("checkpoint_fraction must be within [0, 1]")
+    else:
+        if "checkpoint" in model:
+            if model.checkpoint is True:
+                model._add_item("checkpoint_fraction", 1)
+            else:
+                model._add_item("checkpoint_fraction", 0)
+        else:
+            model._add_item("checkpoint", False)
+            model._add_item("checkpoint_fraction", 0)
+
     if gpc.is_rank_for_log():
         logger.info("+" * 15 + " Model Info " + "+" * 15)  # pylint: disable=W1201
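Editor's note: the net effect of the new sanity check is a simple resolution rule between the two fields. The sketch below restates that rule as a standalone function, using a plain dict in place of the gpc.config.model object; the function name and the assert examples are illustrative, not part of the repository.

    # Standalone sketch of the checkpoint / checkpoint_fraction resolution rule,
    # using a plain dict instead of gpc.config.model (illustrative only).
    def resolve_checkpoint_settings(model: dict) -> dict:
        if "checkpoint_fraction" in model:
            frac = model["checkpoint_fraction"]
            if frac <= 0:
                model["checkpoint"] = False   # nothing to checkpoint
            elif frac <= 1:
                model["checkpoint"] = True    # checkpoint that share of layers
            else:
                raise RuntimeError("checkpoint_fraction must be within [0, 1]")
        else:
            # Fall back to the boolean flag: True means all layers, False means none.
            model.setdefault("checkpoint", False)
            model["checkpoint_fraction"] = 1 if model["checkpoint"] else 0
        return model

    # Resulting behaviour for a few example configs:
    assert resolve_checkpoint_settings({"checkpoint_fraction": 0.5})["checkpoint"] is True
    assert resolve_checkpoint_settings({"checkpoint": True})["checkpoint_fraction"] == 1
    assert resolve_checkpoint_settings({})["checkpoint_fraction"] == 0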
@@ -258,7 +258,7 @@ class PackedFlashInternLm1D(nn.Module):
         drop_rate: float = 0.0,
         dtype: torch.dtype = torch.float,
         checkpoint: bool = False,
-        checkpoint_fraction: float = 1.0,
+        checkpoint_fraction: float = 0,
         layer_norm_epsilon: float = 1e-5,
         first: bool = False,
         last: bool = False,
@@ -276,11 +276,9 @@ class PackedFlashInternLm1D(nn.Module):
     ):
         super().__init__()

-        if checkpoint_fraction <= 0:
-            checkpoint = False
-        if not checkpoint:
-            checkpoint_fraction = 0
-        checkpoint_layer_num = num_layers * checkpoint_fraction
+        checkpoint_layer_num = num_layers * checkpoint_fraction if checkpoint else 0
+        print(f"checkpoint_layer_num: {checkpoint_layer_num}", flush=True)

         if is_reward:
             head_cls = RewardModelLinear
         else:
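Editor's note: this hunk only shows checkpoint_layer_num being computed; how the count is consumed per block is not part of the diff. The sketch below assumes the common pattern of checkpointing the first checkpoint_layer_num blocks, and the helper name and int() rounding are illustrative assumptions.

    # Sketch of how a fractional setting could map to per-layer activation
    # checkpointing (the per-block wiring is an assumption for illustration).
    def layers_to_checkpoint(num_layers: int, checkpoint: bool, checkpoint_fraction: float) -> list[bool]:
        checkpoint_layer_num = int(num_layers * checkpoint_fraction) if checkpoint else 0
        # Checkpoint the first checkpoint_layer_num blocks, run the rest normally.
        return [layer_idx < checkpoint_layer_num for layer_idx in range(num_layers)]

    # With 32 layers and checkpoint_fraction=0.25, the first 8 blocks recompute
    # activations in the backward pass and the remaining 24 keep them in memory.
    flags = layers_to_checkpoint(32, checkpoint=True, checkpoint_fraction=0.25)
    assert flags.count(True) == 8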
@@ -408,11 +406,6 @@ def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"),

     models = []

-    if kwargs["checkpoint"] is True:
-        kwargs["checkpoint_fraction"] = 1.0
-    else:
-        kwargs["checkpoint_fraction"] = 0
-
     for start, end in parts:
         kwargs["num_layers"] = end - start
         kwargs["first"] = start == 0
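Editor's note: before this removal the builder derived checkpoint_fraction from the boolean flag, so a fractional value could never reach the per-chunk models. A tiny sketch of that difference, with illustrative numbers:

    # Why the override had to go: under the old code any fractional setting was
    # clobbered before the per-chunk models were built (numbers are illustrative).
    kwargs = {"checkpoint": True, "checkpoint_fraction": 0.25}

    # Old behaviour: the flag rewrote the fraction, so 0.25 became 1.0.
    old = dict(kwargs)
    old["checkpoint_fraction"] = 1.0 if old["checkpoint"] is True else 0
    assert old["checkpoint_fraction"] == 1.0

    # New behaviour: kwargs pass through unchanged and the model sees 0.25.
    assert kwargs["checkpoint_fraction"] == 0.25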
@@ -436,6 +429,7 @@ def build_model_with_cfg(
 def build_model_with_cfg(
     num_chunks=1,
     checkpoint=False,
+    checkpoint_fraction=0,
     dtype=torch.float,
     embed_split_hidden=False,
     num_layers=48,
@@ -491,6 +485,7 @@ def build_model_with_cfg(
         hidden_size=hidden_size,
         num_attention_heads=num_attention_heads,
         checkpoint=checkpoint,
+        checkpoint_fraction=checkpoint_fraction,
         dtype=dtype,
         embed_split_hidden=embed_split_hidden,
         vocab_size=vocab_size,
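Editor's note: taken together, the new keyword threads from the training config through build_model_with_cfg into the model constructor. The end-to-end sketch below uses stand-in classes and a simplified builder; the names TinyModel and build_model_from_cfg and their defaults are assumptions, not the real InternLM signatures.

    # Minimal plumbing sketch with stand-in classes; the real modules take many
    # more arguments than shown here.
    class TinyModel:
        def __init__(self, num_layers=8, checkpoint=False, checkpoint_fraction=0):
            self.checkpoint_layer_num = num_layers * checkpoint_fraction if checkpoint else 0

    def build_model_from_cfg(model_cfg: dict) -> TinyModel:
        # Forward the config keys straight through, as the patched
        # build_model_with_cfg now does with checkpoint_fraction.
        return TinyModel(
            num_layers=model_cfg.get("num_layers", 8),
            checkpoint=model_cfg.get("checkpoint", False),
            checkpoint_fraction=model_cfg.get("checkpoint_fraction", 0),
        )

    model = build_model_from_cfg({"num_layers": 32, "checkpoint": True, "checkpoint_fraction": 0.5})
    assert model.checkpoint_layer_num == 16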