diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py
index 2175c7c..ce5e993 100644
--- a/internlm/model/modeling_internlm.py
+++ b/internlm/model/modeling_internlm.py
@@ -29,7 +29,7 @@ from internlm.utils.registry import MODEL_INITIALIZER
 MODEL_TYPE = "INTERNLM"
 
 logger = get_logger(__file__)
-
+RMSNorm = try_import_RMSNorm()
 
 class PackedFlashBaseLayer1D(nn.Module):
     """
@@ -96,7 +96,6 @@ class PackedFlashBaseLayer1D(nn.Module):
 
         self.dropout1 = nn.Dropout(drop_rate)
         if norm_type == "rmsnorm":
-            RMSNorm = try_import_RMSNorm()
             self.norm1 = RMSNorm(hidden_size, eps=layer_norm_epsilon)
             self.norm2 = RMSNorm(hidden_size, eps=layer_norm_epsilon)
         else:
@@ -335,7 +334,6 @@ class PackedFlashInternLm1D(nn.Module):
         )
         if last:
             if norm_type == "rmsnorm":
-                RMSNorm = try_import_RMSNorm()
                 self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon)
             else:
                 self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
diff --git a/internlm/model/utils.py b/internlm/model/utils.py
index c0bb084..ba00a27 100644
--- a/internlm/model/utils.py
+++ b/internlm/model/utils.py
@@ -82,7 +82,8 @@ def try_import_RMSNorm():
         from apex.normalization.fused_layer_norm import MixedFusedRMSNorm as RMSNorm
         return RMSNorm
     except ModuleNotFoundError as e:
-        import warnings
-        warnings.warn("The torch implementation for MixFusedRMSNorm is slower than apex. Please note this!")
+        from internlm.utils.logger import get_logger
+        logger = get_logger(__file__)
+        logger.warn("The torch implementation for MixFusedRMSNorm is slower than apex. Please note this!")
         from internlm.model.norm import RMSNormTorch as RMSNorm
         return RMSNorm
\ No newline at end of file
diff --git a/internlm/solver/optimizer/utils.py b/internlm/solver/optimizer/utils.py
index 91a234b..184441b 100644
--- a/internlm/solver/optimizer/utils.py
+++ b/internlm/solver/optimizer/utils.py
@@ -16,10 +16,18 @@ from internlm.utils.common import get_tensor_norm, move_norm_to_cuda
 from internlm.utils.logger import get_logger
 from internlm.utils.parallel import is_model_parallel_parameter
 
-inf = math.inf
-
 logger = get_logger(__file__)
 
+try:
+    import amp_C
+    from apex.multi_tensor_apply import multi_tensor_applier
+    APEX_AVAILABLE = True
+except (ModuleNotFoundError, ImportError):
+    logger.warn("The torch implementation for cal_l2norm is slower than apex. Please note this!")
+    APEX_AVAILABLE = False
+
+inf = math.inf
+
 
 def flatten(input_):
     return _flatten_dense_tensors(input_)
@@ -170,18 +178,12 @@ def multi_tensor_l2norm_torch(tensor_list, per_tensor):
 def calc_l2_norm(grads):
     norm = 0.0
     if len(grads) > 0:
-        try:
-            import amp_C
-            from apex.multi_tensor_apply import multi_tensor_applier
-
+        if APEX_AVAILABLE:
            dummy_overflow_buf = torch.cuda.IntTensor([0])
            norm, _ = multi_tensor_applier(
                amp_C.multi_tensor_l2norm, dummy_overflow_buf, [grads], False  # no per-parameter norm
            )
-        except ModuleNotFoundError as e:
-            import warnings
-            warnings.warn("The torch implementation for cal_l2norm is slower than apex. Please note this!")
-
+        else:
            norm, _ = multi_tensor_l2norm_torch(grads, False)
     return norm
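
The sketch below is not part of the patch; it is a minimal, self-contained illustration (assuming only `torch`, with apex as an optional install) of the pattern the last two hunks introduce: probe the optional apex dependency once at import time, cache the result in a module-level `APEX_AVAILABLE` flag, and branch on that flag in the hot path instead of re-running a try/except import on every call. The pure-torch branch uses a plain `torch.norm` reduction as an illustrative stand-in for the repository's `multi_tensor_l2norm_torch` helper.

```python
# Standalone sketch of the import-once-and-cache pattern applied in this diff.
# Assumes torch is installed; apex (amp_C, multi_tensor_applier) is optional.
import torch

try:
    import amp_C
    from apex.multi_tensor_apply import multi_tensor_applier

    APEX_AVAILABLE = True
except (ModuleNotFoundError, ImportError):
    # No apex: fall back to plain torch, which is slower than the fused kernel.
    APEX_AVAILABLE = False


def calc_l2_norm(grads):
    """Global L2 norm over a list of gradient tensors."""
    if len(grads) == 0:
        return 0.0
    if APEX_AVAILABLE:
        dummy_overflow_buf = torch.cuda.IntTensor([0])
        # multi_tensor_l2norm returns (total_norm, per_tensor_norms).
        norm, _ = multi_tensor_applier(
            amp_C.multi_tensor_l2norm, dummy_overflow_buf, [grads], False
        )
        return norm
    # Illustrative torch fallback (stand-in for multi_tensor_l2norm_torch):
    # reduce per-tensor norms into one global norm.
    return torch.norm(torch.stack([torch.norm(g.float(), 2) for g in grads]), 2)
```

A side effect of probing at import time is that the "slower than apex" warning is emitted once per process rather than on every gradient-norm call or layer construction.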