mirror of https://github.com/InternLM/InternLM
refactor(*): refactor the code with no-apex (#170)
* support no-apex
* add default for use_apex
* fix lint
* modify the RMSNormTorch
* remove some comments
* remove use_apex parameter
* remove some unnecessary code
* optimize the code including import
* remove the import RMSNorm
* remove warnings
parent 1c397f523f
commit d67be17f96
@@ -29,7 +29,7 @@ from internlm.utils.registry import MODEL_INITIALIZER
MODEL_TYPE = "INTERNLM"

logger = get_logger(__file__)

RMSNorm = try_import_RMSNorm()

class PackedFlashBaseLayer1D(nn.Module):
    """

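The hunk above resolves the RMSNorm implementation once, at module import time, so every layer reuses the class returned by try_import_RMSNorm() instead of re-resolving it in each constructor (the per-constructor calls are dropped in the next two hunks). A minimal sketch of that idiom, with a placeholder resolver and a hypothetical layer standing in for the real helpers:

import torch.nn as nn

def resolve_norm():
    # placeholder for try_import_RMSNorm(); the real helper prefers the apex kernel
    return nn.LayerNorm  # fall back to a class that is always available

Norm = resolve_norm()  # resolved exactly once, when the module is imported

class Block(nn.Module):
    # hypothetical layer for illustration, not PackedFlashBaseLayer1D
    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()
        self.norm1 = Norm(hidden_size, eps=eps)  # constructors just reuse the module-level class
        self.norm2 = Norm(hidden_size, eps=eps)
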
@@ -96,7 +96,6 @@ class PackedFlashBaseLayer1D(nn.Module):

        self.dropout1 = nn.Dropout(drop_rate)
        if norm_type == "rmsnorm":
            RMSNorm = try_import_RMSNorm()
            self.norm1 = RMSNorm(hidden_size, eps=layer_norm_epsilon)
            self.norm2 = RMSNorm(hidden_size, eps=layer_norm_epsilon)
        else:

@@ -335,7 +334,6 @@ class PackedFlashInternLm1D(nn.Module):
        )
        if last:
            if norm_type == "rmsnorm":
                RMSNorm = try_import_RMSNorm()
                self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon)
            else:
                self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)

@@ -82,7 +82,8 @@ def try_import_RMSNorm():
        from apex.normalization.fused_layer_norm import MixedFusedRMSNorm as RMSNorm
        return RMSNorm
    except ModuleNotFoundError as e:
        import warnings
        warnings.warn("The torch implementation for MixFusedRMSNorm is slower than apex. Please note this!")
        from internlm.utils.logger import get_logger
        logger = get_logger(__file__)
        logger.warn("The torch implementation for MixFusedRMSNorm is slower than apex. Please note this!")
        from internlm.model.norm import RMSNormTorch as RMSNorm
        return RMSNorm

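This hunk swaps the warnings.warn call inside try_import_RMSNorm for the project logger and keeps the pure-PyTorch RMSNormTorch as the fallback when apex is missing. For context, a torch-only RMSNorm usually looks like the generic sketch below; this is an illustration of the technique, not the repository's RMSNormTorch:

import torch
from torch import nn

class SimpleRMSNorm(nn.Module):
    """Generic RMSNorm sketch: rescale the last dimension by its reciprocal root-mean-square."""

    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, x):
        # accumulate the statistics in float32 for numerical stability
        rms = x.float().pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return self.weight * (x.float() * rms).type_as(x)
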
@@ -16,10 +16,18 @@ from internlm.utils.common import get_tensor_norm, move_norm_to_cuda
from internlm.utils.logger import get_logger
from internlm.utils.parallel import is_model_parallel_parameter

inf = math.inf

logger = get_logger(__file__)

try:
    import amp_C
    from apex.multi_tensor_apply import multi_tensor_applier
    APEX_AVAILABLE = True
except (ModuleNotFoundError, ImportError):
    logger.warn("The torch implementation for cal_l2norm is slower than apex. Please note this!")
    APEX_AVAILABLE = False

inf = math.inf


def flatten(input_):
    return _flatten_dense_tensors(input_)

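The key change here is probing for apex once at import time and recording the outcome in APEX_AVAILABLE, so later code branches on a boolean instead of wrapping every call site in try/except. The same pattern in a self-contained sketch, using a generic optional dependency in place of apex (all names here are illustrative):

import logging

logger = logging.getLogger(__name__)  # stand-in for get_logger(__file__)

try:
    import numpy as _np  # stands in for an optional accelerated backend such as apex
    FAST_PATH_AVAILABLE = True
except ImportError:
    logger.warning("optional dependency missing, using the slower fallback")
    FAST_PATH_AVAILABLE = False

def dot(a, b):
    # branch on the module-level flag instead of re-trying the import on every call
    if FAST_PATH_AVAILABLE:
        return float(_np.dot(a, b))
    return float(sum(x * y for x, y in zip(a, b)))
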
@@ -170,18 +178,12 @@ def multi_tensor_l2norm_torch(tensor_list, per_tensor):
def calc_l2_norm(grads):
    norm = 0.0
    if len(grads) > 0:
        try:
            import amp_C
            from apex.multi_tensor_apply import multi_tensor_applier

        if APEX_AVAILABLE:
            dummy_overflow_buf = torch.cuda.IntTensor([0])
            norm, _ = multi_tensor_applier(
                amp_C.multi_tensor_l2norm, dummy_overflow_buf, [grads], False  # no per-parameter norm
            )
        except ModuleNotFoundError as e:
            import warnings
            warnings.warn("The torch implementation for cal_l2norm is slower than apex. Please note this!")

        else:
            norm, _ = multi_tensor_l2norm_torch(grads, False)
    return norm
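With the import probe moved to module scope, calc_l2_norm only picks between apex's fused amp_C.multi_tensor_l2norm and the torch fallback. The fallback computes the same quantity: the global L2 norm over a list of gradients. A hedged sketch of how such a torch-only helper can be written (not necessarily the repository's multi_tensor_l2norm_torch):

import torch

def l2_norm_torch(tensor_list):
    """Global L2 norm of a list of tensors: sqrt of the sum of squared per-tensor norms."""
    if len(tensor_list) == 0:
        return torch.tensor(0.0)
    # the norm of the vector of per-tensor norms equals the norm of the concatenation
    per_tensor = torch.stack([t.detach().float().norm(2) for t in tensor_list])
    return per_tensor.norm(2)

# example: l2_norm_torch([p.grad for p in model.parameters() if p.grad is not None])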