feat(config): add checkpoint_fraction into config

pull/36/head
wangguoteng.p 2023-07-08 03:59:18 +08:00
parent 81b10e81d9
commit caa13f0dae
3 changed files with 32 additions and 18 deletions

View File

@ -98,6 +98,7 @@ beta2_scheduler = dict(
model = dict(
checkpoint=False,
checkpoint_fraction=0,
num_attention_heads=NUM_ATTENTION_HEAD,
embed_split_hidden=True,
vocab_size=VOCAB_SIZE,

View File

@ -138,16 +138,34 @@ def args_sanity_check():
logger.info(f"cudnn.deterministic: {torch.backends.cudnn.deterministic }")
logger.info(f"clip_grad_norm: {clip_grad_norm}")
if "dtype" not in gpc.config.model:
model = gpc.config.model
if "dtype" not in model:
logger.warning("dtype is not set, use torch.float16 by defalut!")
gpc.config.model._add_item("dtype", torch.float16)
model._add_item("dtype", torch.float16)
else:
if gpc.config.model.dtype == "torch.bfloat16":
gpc.config.model.dtype = torch.bfloat16
elif gpc.config.model.dtype in ("torch.float16", "torch.half"):
gpc.config.model.dtype = torch.float16
if model.dtype == "torch.bfloat16":
model.dtype = torch.bfloat16
elif model.dtype in ("torch.float16", "torch.half"):
model.dtype = torch.float16
else:
assert gpc.config.model.dtype in ["torch.float16", "torch.half", "torch.bfloat16"]
assert model.dtype in ["torch.float16", "torch.half", "torch.bfloat16"]
if "checkpoint_fraction" in model:
if model.checkpoint_fraction <= 0:
model._add_item("checkpoint", False)
elif model.checkpoint_fraction <= 1:
model._add_item("checkpoint", True)
else:
raise RuntimeError("checkpoint_fraction must between [0-1]")
else:
if "checkpoint" in model:
if model.checkpoint is True:
model._add_item("checkpoint_fraction", 1)
else:
model._add_item("checkpoint_fraction", 0)
else:
model._add_item("checkpoint", False)
model._add_item("checkpoint_fraction", 0)
if gpc.is_rank_for_log():
logger.info("+" * 15 + " Model Info " + "+" * 15) # pylint: disable=W1201

View File

@ -258,7 +258,7 @@ class PackedFlashInternLm1D(nn.Module):
drop_rate: float = 0.0,
dtype: torch.dtype = torch.float,
checkpoint: bool = False,
checkpoint_fraction: float = 1.0,
checkpoint_fraction: float = 0,
layer_norm_epsilon: float = 1e-5,
first: bool = False,
last: bool = False,
@ -276,11 +276,9 @@ class PackedFlashInternLm1D(nn.Module):
):
super().__init__()
if checkpoint_fraction <= 0:
checkpoint = False
if not checkpoint:
checkpoint_fraction = 0
checkpoint_layer_num = num_layers * checkpoint_fraction
checkpoint_layer_num = num_layers * checkpoint_fraction if checkpoint else 0
print(f"checkpoint_layer_num: {checkpoint_layer_num}", flush=True)
if is_reward:
head_cls = RewardModelLinear
else:
@ -408,11 +406,6 @@ def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"),
models = []
if kwargs["checkpoint"] is True:
kwargs["checkpoint_fraction"] = 1.0
else:
kwargs["checkpoint_fraction"] = 0
for start, end in parts:
kwargs["num_layers"] = end - start
kwargs["first"] = start == 0
@ -436,6 +429,7 @@ def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"),
def build_model_with_cfg(
num_chunks=1,
checkpoint=False,
checkpoint_fraction=0,
dtype=torch.float,
embed_split_hidden=False,
num_layers=48,
@ -491,6 +485,7 @@ def build_model_with_cfg(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
checkpoint=checkpoint,
checkpoint_fraction=checkpoint_fraction,
dtype=dtype,
embed_split_hidden=embed_split_hidden,
vocab_size=vocab_size,