refactor(*): refactor the code with no-apex (#170)

* support no-apex

* add default for use_apex

* fix lint

* modify the RMSNormTorch

* remove some comments

* remove use_apex parameter

* remove some unnecessary code

* optimize the code including import

* remove the import RMSNorm

* remove warnings
ytxiong 2023-08-03 11:24:12 +08:00 committed by GitHub
parent 1c397f523f
commit d67be17f96
3 changed files with 16 additions and 15 deletions

@@ -29,7 +29,7 @@ from internlm.utils.registry import MODEL_INITIALIZER
MODEL_TYPE = "INTERNLM"
logger = get_logger(__file__)
RMSNorm = try_import_RMSNorm()


class PackedFlashBaseLayer1D(nn.Module):
    """
@@ -96,7 +96,6 @@ class PackedFlashBaseLayer1D(nn.Module):
        self.dropout1 = nn.Dropout(drop_rate)
        if norm_type == "rmsnorm":
            RMSNorm = try_import_RMSNorm()
            self.norm1 = RMSNorm(hidden_size, eps=layer_norm_epsilon)
            self.norm2 = RMSNorm(hidden_size, eps=layer_norm_epsilon)
        else:
@@ -335,7 +334,6 @@ class PackedFlashInternLm1D(nn.Module):
            )
        if last:
            if norm_type == "rmsnorm":
                RMSNorm = try_import_RMSNorm()
                self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon)
            else:
                self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)

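The change in this file hoists RMSNorm = try_import_RMSNorm() to module scope, so both layer classes pick up one shared symbol instead of re-resolving it inside every __init__. Below is a minimal self-contained sketch of that pattern; the stub resolver and ExampleLayer are illustrative stand-ins, not the real InternLM code:

import torch.nn as nn

def try_import_RMSNorm():
    # Stub resolver for this sketch only: prefer the apex fused kernel,
    # otherwise fall back to a plain torch norm so the module still imports.
    try:
        from apex.normalization.fused_layer_norm import MixedFusedRMSNorm as RMSNorm
        return RMSNorm
    except (ModuleNotFoundError, ImportError):
        return nn.LayerNorm  # placeholder fallback, only for this example

# Resolved once at import time; every layer class below reuses the same symbol.
RMSNorm = try_import_RMSNorm()

class ExampleLayer(nn.Module):  # illustrative, not PackedFlashBaseLayer1D
    def __init__(self, hidden_size, layer_norm_epsilon, norm_type="rmsnorm"):
        super().__init__()
        if norm_type == "rmsnorm":
            self.norm1 = RMSNorm(hidden_size, eps=layer_norm_epsilon)
        else:
            self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
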
@@ -82,7 +82,8 @@ def try_import_RMSNorm():
        from apex.normalization.fused_layer_norm import MixedFusedRMSNorm as RMSNorm
        return RMSNorm
    except ModuleNotFoundError as e:
        import warnings
        warnings.warn("The torch implementation for MixFusedRMSNorm is slower than apex. Please note this!")
        from internlm.utils.logger import get_logger
        logger = get_logger(__file__)
        logger.warn("The torch implementation for MixFusedRMSNorm is slower than apex. Please note this!")
        from internlm.model.norm import RMSNormTorch as RMSNorm
        return RMSNorm

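RMSNormTorch itself lives in internlm.model.norm and is not shown in this diff; the slower-than-apex warning refers to a pure-torch implementation along these lines (an illustrative sketch, not the actual class):

import torch
import torch.nn as nn

class RMSNormTorchSketch(nn.Module):
    # Illustrative pure-torch RMSNorm: scale by the inverse root-mean-square,
    # with a learned weight and no mean subtraction or bias.
    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        variance = x.pow(2).mean(-1, keepdim=True)
        return self.weight * x * torch.rsqrt(variance + self.eps)
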
@@ -16,10 +16,18 @@ from internlm.utils.common import get_tensor_norm, move_norm_to_cuda
from internlm.utils.logger import get_logger
from internlm.utils.parallel import is_model_parallel_parameter

inf = math.inf
logger = get_logger(__file__)

try:
    import amp_C
    from apex.multi_tensor_apply import multi_tensor_applier
    APEX_AVAILABLE = True
except (ModuleNotFoundError, ImportError):
    logger.warn("The torch implementation for cal_l2norm is slower than apex. Please note this!")
    APEX_AVAILABLE = False

inf = math.inf


def flatten(input_):
    return _flatten_dense_tensors(input_)
@@ -170,18 +178,12 @@ def multi_tensor_l2norm_torch(tensor_list, per_tensor):
def calc_l2_norm(grads):
    norm = 0.0
    if len(grads) > 0:
        try:
            import amp_C
            from apex.multi_tensor_apply import multi_tensor_applier
        if APEX_AVAILABLE:
            dummy_overflow_buf = torch.cuda.IntTensor([0])
            norm, _ = multi_tensor_applier(
                amp_C.multi_tensor_l2norm, dummy_overflow_buf, [grads], False  # no per-parameter norm
            )
        except ModuleNotFoundError as e:
            import warnings
            warnings.warn("The torch implementation for cal_l2norm is slower than apex. Please note this!")
        else:
            norm, _ = multi_tensor_l2norm_torch(grads, False)
    return norm
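
The torch path taken when APEX_AVAILABLE is False relies on multi_tensor_l2norm_torch, which is defined earlier in this file and not part of the hunk; its effect is roughly the following sketch (an equivalent illustration with a hypothetical function name, not the actual helper):

import torch

def l2_norm_of_grads_sketch(grads):
    # Global L2 norm over a list of gradient tensors, computed with plain
    # torch ops instead of the fused amp_C.multi_tensor_l2norm kernel.
    if len(grads) == 0:
        return torch.tensor(0.0)
    per_tensor = torch.stack([g.detach().float().norm(2) for g in grads])
    return torch.norm(per_tensor, 2)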