mirror of https://github.com/InternLM/InternLM
feat(*): support no apex (#166)
* support no-apex * add default for use_apex * fix lint * modify the RMSNormTorch * remove some comments * remove use_apex parameter * remove some unnecessary codepull/170/head
parent
66a23e326a
commit
1c397f523f
|
@ -5,7 +5,6 @@ import math
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from apex.normalization.fused_layer_norm import MixedFusedRMSNorm as RMSNorm
|
||||
from flash_attn.modules.embedding import ParallelGPT2Embeddings
|
||||
from flash_attn.modules.mlp import ParallelFusedMLP
|
||||
from torch import nn
|
||||
|
@ -20,7 +19,7 @@ from internlm.model.linear import (
|
|||
ScaleColumnParallelLinear,
|
||||
)
|
||||
from internlm.model.multi_head_attention import MHA
|
||||
from internlm.model.utils import gather_forward_split_backward
|
||||
from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm
|
||||
from internlm.solver.pipeline_utils import partition_uniform
|
||||
from internlm.utils.checkpoint import activation_checkpoint
|
||||
from internlm.utils.common import filter_kwargs
|
||||
|
@ -97,6 +96,7 @@ class PackedFlashBaseLayer1D(nn.Module):
|
|||
|
||||
self.dropout1 = nn.Dropout(drop_rate)
|
||||
if norm_type == "rmsnorm":
|
||||
RMSNorm = try_import_RMSNorm()
|
||||
self.norm1 = RMSNorm(hidden_size, eps=layer_norm_epsilon)
|
||||
self.norm2 = RMSNorm(hidden_size, eps=layer_norm_epsilon)
|
||||
else:
|
||||
|
@ -335,6 +335,7 @@ class PackedFlashInternLm1D(nn.Module):
|
|||
)
|
||||
if last:
|
||||
if norm_type == "rmsnorm":
|
||||
RMSNorm = try_import_RMSNorm()
|
||||
self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon)
|
||||
else:
|
||||
self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
# adopted from https://github.com/NVIDIA/apex/blob/master/apex/normalization/fused_layer_norm
|
||||
|
||||
import numbers
|
||||
|
||||
import torch
|
||||
from torch.nn import init
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
|
||||
def manual_rms_norm(input, normalized_shape, weight, eps):
|
||||
# layer norm should always be calculated in float32
|
||||
dims = tuple(i for i in range(-1, -len(normalized_shape) - 1, -1))
|
||||
variance = input.to(torch.float32).pow(2).mean(dims, keepdim=True)
|
||||
input = input * torch.rsqrt(variance + eps)
|
||||
|
||||
if weight is None:
|
||||
return input
|
||||
|
||||
# convert into half-precision if necessary
|
||||
if weight.dtype in [torch.float16, torch.bfloat16]:
|
||||
input = input.to(weight.dtype)
|
||||
|
||||
return weight * input
|
||||
|
||||
|
||||
class RMSNormTorch(torch.nn.Module):
|
||||
def __init__(self, normalized_shape, eps=1e-5):
|
||||
super().__init__()
|
||||
|
||||
if isinstance(normalized_shape, numbers.Integral):
|
||||
normalized_shape = (normalized_shape,)
|
||||
self.normalized_shape = torch.Size(normalized_shape)
|
||||
self.eps = eps
|
||||
self.weight = Parameter(torch.empty(*normalized_shape))
|
||||
self.reset_parameters()
|
||||
|
||||
def forward(self, input: torch.Tensor):
|
||||
return manual_rms_norm(input, self.normalized_shape, self.weight, self.eps)
|
||||
|
||||
def reset_parameters(self):
|
||||
init.ones_(self.weight)
|
||||
|
||||
def extra_repr(self):
|
||||
return "{normalized_shape}, eps={eps}, ".format(**self.__dict__)
|
|
@ -71,3 +71,18 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
|
|||
|
||||
def gather_forward_split_backward(input_, parallel_mode, dim):
|
||||
return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim)
|
||||
|
||||
|
||||
def try_import_RMSNorm():
|
||||
"""
|
||||
Try import MixFusedRMSNorm from apex, if failed, return our RMSNorm
|
||||
|
||||
"""
|
||||
try:
|
||||
from apex.normalization.fused_layer_norm import MixedFusedRMSNorm as RMSNorm
|
||||
return RMSNorm
|
||||
except ModuleNotFoundError as e:
|
||||
import warnings
|
||||
warnings.warn("The torch implementation for MixFusedRMSNorm is slower than apex. Please note this!")
|
||||
from internlm.model.norm import RMSNormTorch as RMSNorm
|
||||
return RMSNorm
|
|
@ -5,10 +5,8 @@ import math
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Optional
|
||||
|
||||
import amp_C
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from apex.multi_tensor_apply import multi_tensor_applier
|
||||
from torch import Tensor
|
||||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||
|
||||
|
@ -156,17 +154,37 @@ def sync_param(flat_tensor, tensor_list):
|
|||
for p, q in zip(tensor_list, updated_params):
|
||||
p.data = q.data
|
||||
|
||||
def multi_tensor_l2norm_torch(tensor_list, per_tensor):
|
||||
# Convert tensor_list elements to torch.float32
|
||||
tensor_list = [tensor.float() for tensor in tensor_list]
|
||||
norms_tensor = torch.stack([torch.norm(tensor, p=2) for tensor in tensor_list])
|
||||
l2_norm = torch.norm(norms_tensor, p=2).unsqueeze(0)
|
||||
|
||||
if per_tensor:
|
||||
per_tensor_norm = norms_tensor
|
||||
else:
|
||||
per_tensor_norm = torch.Tensor([]).to(norms_tensor.device)
|
||||
|
||||
return l2_norm, per_tensor_norm
|
||||
|
||||
def calc_l2_norm(grads):
|
||||
norm = 0.0
|
||||
if len(grads) > 0:
|
||||
dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||
norm, _ = multi_tensor_applier(
|
||||
amp_C.multi_tensor_l2norm, dummy_overflow_buf, [grads], False # no per-parameter norm
|
||||
)
|
||||
try:
|
||||
import amp_C
|
||||
from apex.multi_tensor_apply import multi_tensor_applier
|
||||
|
||||
dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||
norm, _ = multi_tensor_applier(
|
||||
amp_C.multi_tensor_l2norm, dummy_overflow_buf, [grads], False # no per-parameter norm
|
||||
)
|
||||
except ModuleNotFoundError as e:
|
||||
import warnings
|
||||
warnings.warn("The torch implementation for cal_l2norm is slower than apex. Please note this!")
|
||||
|
||||
norm, _ = multi_tensor_l2norm_torch(grads, False)
|
||||
return norm
|
||||
|
||||
|
||||
def calc_lp(grads, norm_type):
|
||||
norm = 0.0
|
||||
for grad in grads:
|
||||
|
|
Loading…
Reference in New Issue