fix(model/modeling_internlm.py): fix checkpoint=False runtime error

pull/293/head
huangting4201 2023-09-27 11:18:04 +08:00
parent c703938fb3
commit 59b7530129
2 changed files with 2 additions and 2 deletions

View File

@@ -450,7 +450,7 @@ def build_model_with_cfg(
apply_post_layer_norm=False, # pylint: disable=W0613
layer_norm_epsilon=1e-5,
is_reward=False,
-dropout_selective_checkpoint=True,
+dropout_selective_checkpoint=False,
use_scaled_init: bool = True,
use_swiglu: bool = True,
use_flash_attn: bool = True,

View File

@@ -66,7 +66,7 @@ class FSDPadaptOptimizer(BaseOptimizer):
self._fp16_param_groups[group_idx] = group_params
# create copy of fp32 weight
-fp32_tensor_param = [param.data.float().requires_grad_(True) for param in group_params]
+fp32_tensor_param = [param.data.float() for param in group_params]
self._fp32_param_tensor_groups[group_idx] = fp32_tensor_param
# replace