mirror of https://github.com/InternLM/InternLM
Feat add checkpoint fraction (#151)
* feat(config): add checkpoint_fraction into config
* feat: remove checkpoint_fraction from configs/7B_sft.py

---------

Co-authored-by: wangguoteng.p <wangguoteng925@qq.com>

branch: pull/159/head
parent: 2fee4220a6
commit: 6b6295aea3
@@ -97,7 +97,7 @@ beta2_scheduler = dict(
 )

 model = dict(
-    checkpoint=False,
+    checkpoint=False,  # The proportion of layers for activation checkpointing; optional values are True/False/[0-1]
     num_attention_heads=NUM_ATTENTION_HEAD,
     embed_split_hidden=True,
     vocab_size=VOCAB_SIZE,
@@ -140,7 +140,7 @@ HIDDEN_SIZE = 4096
 NUM_LAYER = 32
 MLP_RATIO = 8 / 3
 model = dict(
-    checkpoint=False,
+    checkpoint=False,  # The proportion of layers for activation checkpointing; optional values are True/False/[0-1]
     num_attention_heads=NUM_ATTENTION_HEAD,
     embed_split_hidden=True,
     vocab_size=VOCAB_SIZE,
@@ -126,7 +126,7 @@ HIDDEN_SIZE = 4096
 NUM_LAYER = 32
 MLP_RATIO = 8 / 3
 model = dict(
-    checkpoint=False,
+    checkpoint=False,  # The proportion of model layers to recompute (activation checkpointing); optional values are True/False/[0-1]
     num_attention_heads=NUM_ATTENTION_HEAD,
     embed_split_hidden=True,
     vocab_size=VOCAB_SIZE,
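Taken together, the config hunks above change the meaning of model.checkpoint from a plain on/off switch to a fraction. A minimal sketch of the new semantics follows; the constant values are illustrative placeholders, not the shipped config:

# Illustrative config sketch (placeholder values, not the real 7B config).
NUM_ATTENTION_HEAD = 32
VOCAB_SIZE = 103168

model = dict(
    checkpoint=0.5,  # True -> checkpoint all layers, False -> none, 0.5 -> roughly half the layers
    num_attention_heads=NUM_ATTENTION_HEAD,
    embed_split_hidden=True,
    vocab_size=VOCAB_SIZE,
)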
@@ -138,16 +138,27 @@ def args_sanity_check():
     logger.info(f"cudnn.deterministic: {torch.backends.cudnn.deterministic }")
     logger.info(f"clip_grad_norm: {clip_grad_norm}")

-    if "dtype" not in gpc.config.model:
+    model = gpc.config.model
+    if "dtype" not in model:
         logger.warning("dtype is not set, use torch.float16 by defalut!")
-        gpc.config.model._add_item("dtype", torch.float16)
+        model._add_item("dtype", torch.float16)
     else:
-        if gpc.config.model.dtype == "torch.bfloat16":
-            gpc.config.model.dtype = torch.bfloat16
-        elif gpc.config.model.dtype in ("torch.float16", "torch.half"):
-            gpc.config.model.dtype = torch.float16
+        if model.dtype == "torch.bfloat16":
+            model.dtype = torch.bfloat16
+        elif model.dtype in ("torch.float16", "torch.half"):
+            model.dtype = torch.float16
         else:
-            assert gpc.config.model.dtype in ["torch.float16", "torch.half", "torch.bfloat16"]
+            assert model.dtype in ["torch.float16", "torch.half", "torch.bfloat16"]

+    if "checkpoint" in model:
+        if model.checkpoint is True:
+            model.checkpoint = 1
+        elif model.checkpoint is False:
+            model.checkpoint = 0
+        else:
+            assert (
+                model.checkpoint >= 0 and model.checkpoint <= 1
+            ), f'model.checkpoint: "{model.checkpoint}" should >=0 and <=1'
+
     if gpc.is_rank_for_log():
         logger.info("+" * 15 + " Model Info " + "+" * 15)  # pylint: disable=W1201
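The normalization above can be read as a small standalone helper. The sketch below is illustrative only; normalize_checkpoint is a hypothetical name and is not part of the codebase:

def normalize_checkpoint(value):
    # Mirrors the sanity check: booleans collapse to the endpoints of [0, 1],
    # anything else must already be a fraction in that range.
    if value is True:
        return 1
    if value is False:
        return 0
    assert 0 <= value <= 1, f'checkpoint: "{value}" should be >=0 and <=1'
    return value

print(normalize_checkpoint(True))   # 1    -> checkpoint every layer
print(normalize_checkpoint(False))  # 0    -> no activation checkpointing
print(normalize_checkpoint(0.25))   # 0.25 -> checkpoint a quarter of the layers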
@@ -230,9 +230,8 @@ class PackedFlashInternLm1D(nn.Module):
         attn_drop_rate (float): The dropout rate of attention module. 0.0 by default.
         drop_rate (float): The dropout rate of input hidden state. 0.0 by default.
         dtype (torch.dtype): The type of data. torch.float by default.
-        checkpoint (bool): Whether to use checkpointing to save VRAM. True by default.
-        checkpoint_fraction (float): The proportion of layers that need to be checkpointed compared to the total number
-            of layers. 1.0 by default.
+        checkpoint (float): The proportion of layers that need to be checkpointed compared to the total number
+            of layers. 0.0 by default.
         layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default.
         first (bool): Whether input embedding layer or not. False by default.
         last (bool): Whether output embedding layer or not. False by default.
@@ -257,8 +256,7 @@ class PackedFlashInternLm1D(nn.Module):
         attn_drop_rate: float = 0.0,
         drop_rate: float = 0.0,
         dtype: torch.dtype = torch.float,
-        checkpoint: bool = False,
-        checkpoint_fraction: float = 1.0,
+        checkpoint: float = 0.0,
         layer_norm_epsilon: float = 1e-5,
         first: bool = False,
         last: bool = False,
@@ -276,11 +274,8 @@ class PackedFlashInternLm1D(nn.Module):
     ):
         super().__init__()

-        if checkpoint_fraction <= 0:
-            checkpoint = False
-        if not checkpoint:
-            checkpoint_fraction = 0
-        checkpoint_layer_num = num_layers * checkpoint_fraction
+        checkpoint_layer_num = int(num_layers * checkpoint)
+
         if is_reward:
             head_cls = RewardModelLinear
         else:
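With checkpoint now a fraction, checkpoint_layer_num = int(num_layers * checkpoint) decides how many transformer blocks run with activation checkpointing. The loop below is only a sketch of one plausible way the flag is applied per block; it assumes the first checkpoint_layer_num blocks are the ones checkpointed, which this diff does not show:

# Illustrative only: 32 layers with checkpoint=0.25 -> the first 8 blocks recompute activations.
num_layers = 32
checkpoint = 0.25
checkpoint_layer_num = int(num_layers * checkpoint)  # 8

for layer_idx in range(num_layers):
    use_checkpoint = layer_idx < checkpoint_layer_num  # hypothetical per-block flag
    # ... construct the block with activation checkpointing enabled iff use_checkpoint ...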
@@ -408,11 +403,6 @@ def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"),

     models = []

-    if kwargs["checkpoint"] is True:
-        kwargs["checkpoint_fraction"] = 1.0
-    else:
-        kwargs["checkpoint_fraction"] = 0
-
     for start, end in parts:
         kwargs["num_layers"] = end - start
         kwargs["first"] = start == 0
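Because args_sanity_check already turns gpc.config.model.checkpoint into a float in [0, 1], the per-chunk conversion removed here is redundant: each pipeline chunk inherits the fraction directly through kwargs. A tiny illustration with made-up values:

# Hypothetical kwargs for one pipeline chunk after the sanity check has run.
kwargs = {"checkpoint": 0.5, "num_layers": 8, "first": True, "last": False}

# Inside PackedFlashInternLm1D.__init__ this becomes:
checkpoint_layer_num = int(kwargs["num_layers"] * kwargs["checkpoint"])  # 4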
@@ -435,7 +425,7 @@ def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"),
 @MODEL_INITIALIZER.register_module(module_name=MODEL_TYPE)
 def build_model_with_cfg(
     num_chunks=1,
-    checkpoint=False,
+    checkpoint=0.0,
     dtype=torch.float,
     embed_split_hidden=False,
     num_layers=48,