mirror of https://github.com/InternLM/InternLM
refactor(*): refactor the code with no-apex (#170)
* support no-apex * add default for use_apex * fix lint * modify the RMSNormTorch * remove some comments * remove use_apex parameter * remove some unnecessary code * optimize the code including import * remove the import RMSNorm * remove warningspull/155/head
parent
1c397f523f
commit
d67be17f96
|
@ -29,7 +29,7 @@ from internlm.utils.registry import MODEL_INITIALIZER
|
||||||
MODEL_TYPE = "INTERNLM"
|
MODEL_TYPE = "INTERNLM"
|
||||||
|
|
||||||
logger = get_logger(__file__)
|
logger = get_logger(__file__)
|
||||||
|
RMSNorm = try_import_RMSNorm()
|
||||||
|
|
||||||
class PackedFlashBaseLayer1D(nn.Module):
|
class PackedFlashBaseLayer1D(nn.Module):
|
||||||
"""
|
"""
|
||||||
|
@ -96,7 +96,6 @@ class PackedFlashBaseLayer1D(nn.Module):
|
||||||
|
|
||||||
self.dropout1 = nn.Dropout(drop_rate)
|
self.dropout1 = nn.Dropout(drop_rate)
|
||||||
if norm_type == "rmsnorm":
|
if norm_type == "rmsnorm":
|
||||||
RMSNorm = try_import_RMSNorm()
|
|
||||||
self.norm1 = RMSNorm(hidden_size, eps=layer_norm_epsilon)
|
self.norm1 = RMSNorm(hidden_size, eps=layer_norm_epsilon)
|
||||||
self.norm2 = RMSNorm(hidden_size, eps=layer_norm_epsilon)
|
self.norm2 = RMSNorm(hidden_size, eps=layer_norm_epsilon)
|
||||||
else:
|
else:
|
||||||
|
@ -335,7 +334,6 @@ class PackedFlashInternLm1D(nn.Module):
|
||||||
)
|
)
|
||||||
if last:
|
if last:
|
||||||
if norm_type == "rmsnorm":
|
if norm_type == "rmsnorm":
|
||||||
RMSNorm = try_import_RMSNorm()
|
|
||||||
self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon)
|
self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon)
|
||||||
else:
|
else:
|
||||||
self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
|
self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
|
||||||
|
|
|
@ -82,7 +82,8 @@ def try_import_RMSNorm():
|
||||||
from apex.normalization.fused_layer_norm import MixedFusedRMSNorm as RMSNorm
|
from apex.normalization.fused_layer_norm import MixedFusedRMSNorm as RMSNorm
|
||||||
return RMSNorm
|
return RMSNorm
|
||||||
except ModuleNotFoundError as e:
|
except ModuleNotFoundError as e:
|
||||||
import warnings
|
from internlm.utils.logger import get_logger
|
||||||
warnings.warn("The torch implementation for MixFusedRMSNorm is slower than apex. Please note this!")
|
logger = get_logger(__file__)
|
||||||
|
logger.warn("The torch implementation for MixFusedRMSNorm is slower than apex. Please note this!")
|
||||||
from internlm.model.norm import RMSNormTorch as RMSNorm
|
from internlm.model.norm import RMSNormTorch as RMSNorm
|
||||||
return RMSNorm
|
return RMSNorm
|
|
@ -16,10 +16,18 @@ from internlm.utils.common import get_tensor_norm, move_norm_to_cuda
|
||||||
from internlm.utils.logger import get_logger
|
from internlm.utils.logger import get_logger
|
||||||
from internlm.utils.parallel import is_model_parallel_parameter
|
from internlm.utils.parallel import is_model_parallel_parameter
|
||||||
|
|
||||||
inf = math.inf
|
|
||||||
|
|
||||||
logger = get_logger(__file__)
|
logger = get_logger(__file__)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import amp_C
|
||||||
|
from apex.multi_tensor_apply import multi_tensor_applier
|
||||||
|
APEX_AVAILABLE = True
|
||||||
|
except (ModuleNotFoundError, ImportError):
|
||||||
|
logger.warn("The torch implementation for cal_l2norm is slower than apex. Please note this!")
|
||||||
|
APEX_AVAILABLE = False
|
||||||
|
|
||||||
|
inf = math.inf
|
||||||
|
|
||||||
|
|
||||||
def flatten(input_):
|
def flatten(input_):
|
||||||
return _flatten_dense_tensors(input_)
|
return _flatten_dense_tensors(input_)
|
||||||
|
@ -170,18 +178,12 @@ def multi_tensor_l2norm_torch(tensor_list, per_tensor):
|
||||||
def calc_l2_norm(grads):
|
def calc_l2_norm(grads):
|
||||||
norm = 0.0
|
norm = 0.0
|
||||||
if len(grads) > 0:
|
if len(grads) > 0:
|
||||||
try:
|
if APEX_AVAILABLE:
|
||||||
import amp_C
|
|
||||||
from apex.multi_tensor_apply import multi_tensor_applier
|
|
||||||
|
|
||||||
dummy_overflow_buf = torch.cuda.IntTensor([0])
|
dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||||
norm, _ = multi_tensor_applier(
|
norm, _ = multi_tensor_applier(
|
||||||
amp_C.multi_tensor_l2norm, dummy_overflow_buf, [grads], False # no per-parameter norm
|
amp_C.multi_tensor_l2norm, dummy_overflow_buf, [grads], False # no per-parameter norm
|
||||||
)
|
)
|
||||||
except ModuleNotFoundError as e:
|
else:
|
||||||
import warnings
|
|
||||||
warnings.warn("The torch implementation for cal_l2norm is slower than apex. Please note this!")
|
|
||||||
|
|
||||||
norm, _ = multi_tensor_l2norm_torch(grads, False)
|
norm, _ = multi_tensor_l2norm_torch(grads, False)
|
||||||
return norm
|
return norm
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue