mirror of https://github.com/InternLM/InternLM
feat(*): support no apex (#166)
* support no-apex
* add default for use_apex
* fix lint
* modify the RMSNormTorch
* remove some comments
* remove use_apex parameter
* remove some unnecessary code
parent 66a23e326a
commit 1c397f523f
@@ -5,7 +5,6 @@ import math
 from typing import Optional
 
 import torch
-from apex.normalization.fused_layer_norm import MixedFusedRMSNorm as RMSNorm
 from flash_attn.modules.embedding import ParallelGPT2Embeddings
 from flash_attn.modules.mlp import ParallelFusedMLP
 from torch import nn
@@ -20,7 +19,7 @@ from internlm.model.linear import (
     ScaleColumnParallelLinear,
 )
 from internlm.model.multi_head_attention import MHA
-from internlm.model.utils import gather_forward_split_backward
+from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm
 from internlm.solver.pipeline_utils import partition_uniform
 from internlm.utils.checkpoint import activation_checkpoint
 from internlm.utils.common import filter_kwargs
@@ -97,6 +96,7 @@ class PackedFlashBaseLayer1D(nn.Module):
 
         self.dropout1 = nn.Dropout(drop_rate)
         if norm_type == "rmsnorm":
+            RMSNorm = try_import_RMSNorm()
             self.norm1 = RMSNorm(hidden_size, eps=layer_norm_epsilon)
             self.norm2 = RMSNorm(hidden_size, eps=layer_norm_epsilon)
         else:
@@ -335,6 +335,7 @@ class PackedFlashInternLm1D(nn.Module):
         )
         if last:
             if norm_type == "rmsnorm":
+                RMSNorm = try_import_RMSNorm()
                 self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon)
             else:
                 self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
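The two hunks above swap the module-level apex import for a lookup at layer construction time, so importing the model no longer requires apex. A minimal sketch of the same pattern in isolation follows; TinyBlock and its arguments are illustrative stand-ins, not code from this commit:

import torch
from torch import nn

from internlm.model.utils import try_import_RMSNorm


class TinyBlock(nn.Module):
    """Hypothetical layer showing the lazy norm selection used in the hunks above."""

    def __init__(self, hidden_size, layer_norm_epsilon=1e-5, norm_type="rmsnorm"):
        super().__init__()
        if norm_type == "rmsnorm":
            # Resolved here: apex MixedFusedRMSNorm if available, torch fallback otherwise.
            RMSNorm = try_import_RMSNorm()
            self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon)
        else:
            self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.norm(x)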
@@ -0,0 +1,44 @@
+# adopted from https://github.com/NVIDIA/apex/blob/master/apex/normalization/fused_layer_norm
+
+import numbers
+
+import torch
+from torch.nn import init
+from torch.nn.parameter import Parameter
+
+
+def manual_rms_norm(input, normalized_shape, weight, eps):
+    # layer norm should always be calculated in float32
+    dims = tuple(i for i in range(-1, -len(normalized_shape) - 1, -1))
+    variance = input.to(torch.float32).pow(2).mean(dims, keepdim=True)
+    input = input * torch.rsqrt(variance + eps)
+
+    if weight is None:
+        return input
+
+    # convert into half-precision if necessary
+    if weight.dtype in [torch.float16, torch.bfloat16]:
+        input = input.to(weight.dtype)
+
+    return weight * input
+
+
+class RMSNormTorch(torch.nn.Module):
+    def __init__(self, normalized_shape, eps=1e-5):
+        super().__init__()
+
+        if isinstance(normalized_shape, numbers.Integral):
+            normalized_shape = (normalized_shape,)
+        self.normalized_shape = torch.Size(normalized_shape)
+        self.eps = eps
+        self.weight = Parameter(torch.empty(*normalized_shape))
+        self.reset_parameters()
+
+    def forward(self, input: torch.Tensor):
+        return manual_rms_norm(input, self.normalized_shape, self.weight, self.eps)
+
+    def reset_parameters(self):
+        init.ones_(self.weight)
+
+    def extra_repr(self):
+        return "{normalized_shape}, eps={eps}, ".format(**self.__dict__)
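The new module (internlm.model.norm, judging by the import used in try_import_RMSNorm below) implements RMSNorm as y = x * rsqrt(mean(x^2) + eps) * weight, i.e. LayerNorm without mean centering or bias, with the statistic computed in float32. A quick standalone check of that behaviour, as a sketch with toy shapes on CPU in float32:

import torch

from internlm.model.norm import RMSNormTorch

hidden = 8
x = torch.randn(2, 4, hidden)
norm = RMSNormTorch(hidden, eps=1e-5)  # weight is initialized to ones

# Direct computation of x * rsqrt(mean(x**2, dim=-1) + eps)
expected = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-5)
assert torch.allclose(norm(x), expected, atol=1e-6)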
@@ -71,3 +71,18 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
 
 def gather_forward_split_backward(input_, parallel_mode, dim):
     return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim)
+
+
+def try_import_RMSNorm():
+    """
+    Try import MixFusedRMSNorm from apex, if failed, return our RMSNorm
+
+    """
+    try:
+        from apex.normalization.fused_layer_norm import MixedFusedRMSNorm as RMSNorm
+        return RMSNorm
+    except ModuleNotFoundError as e:
+        import warnings
+        warnings.warn("The torch implementation for MixFusedRMSNorm is slower than apex. Please note this!")
+        from internlm.model.norm import RMSNormTorch as RMSNorm
+        return RMSNorm
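Both branches of try_import_RMSNorm return an nn.Module class with the same constructor signature, so callers only decide once. A small usage sketch (the hidden size is arbitrary; with apex installed the fused norm expects CUDA tensors, so run the forward pass on GPU in that case):

import torch

from internlm.model.utils import try_import_RMSNorm

RMSNorm = try_import_RMSNorm()  # MixedFusedRMSNorm if apex is present, else RMSNormTorch
norm = RMSNorm(4096, eps=1e-5)  # same constructor either way
x = torch.randn(2, 16, 4096)    # (batch, seq, hidden); move to CUDA for the apex path
print(type(norm).__name__, norm(x).shape)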
@@ -5,10 +5,8 @@ import math
 from abc import ABC, abstractmethod
 from typing import Dict, Optional
 
-import amp_C
 import torch
 import torch.distributed as dist
-from apex.multi_tensor_apply import multi_tensor_applier
 from torch import Tensor
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 
@@ -156,17 +154,37 @@ def sync_param(flat_tensor, tensor_list):
     for p, q in zip(tensor_list, updated_params):
         p.data = q.data
 
 
+def multi_tensor_l2norm_torch(tensor_list, per_tensor):
+    # Convert tensor_list elements to torch.float32
+    tensor_list = [tensor.float() for tensor in tensor_list]
+    norms_tensor = torch.stack([torch.norm(tensor, p=2) for tensor in tensor_list])
+    l2_norm = torch.norm(norms_tensor, p=2).unsqueeze(0)
+
+    if per_tensor:
+        per_tensor_norm = norms_tensor
+    else:
+        per_tensor_norm = torch.Tensor([]).to(norms_tensor.device)
+
+    return l2_norm, per_tensor_norm
+
+
 def calc_l2_norm(grads):
     norm = 0.0
     if len(grads) > 0:
-        dummy_overflow_buf = torch.cuda.IntTensor([0])
-        norm, _ = multi_tensor_applier(
-            amp_C.multi_tensor_l2norm, dummy_overflow_buf, [grads], False  # no per-parameter norm
-        )
+        try:
+            import amp_C
+            from apex.multi_tensor_apply import multi_tensor_applier
+
+            dummy_overflow_buf = torch.cuda.IntTensor([0])
+            norm, _ = multi_tensor_applier(
+                amp_C.multi_tensor_l2norm, dummy_overflow_buf, [grads], False  # no per-parameter norm
+            )
+        except ModuleNotFoundError as e:
+            import warnings
+            warnings.warn("The torch implementation for cal_l2norm is slower than apex. Please note this!")
+
+            norm, _ = multi_tensor_l2norm_torch(grads, False)
     return norm
 
 
 def calc_lp(grads, norm_type):
     norm = 0.0
     for grad in grads:
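The fallback computes the same scalar as amp_C.multi_tensor_l2norm: the l2 norm of all gradients taken together, since the square root of the sum of squared per-tensor norms equals the norm of the flattened concatenation. A standalone equivalence check, with the function body copied from the hunk above so the snippet runs without the repo installed:

import torch


def multi_tensor_l2norm_torch(tensor_list, per_tensor):
    # Same logic as the fallback added above.
    tensor_list = [tensor.float() for tensor in tensor_list]
    norms_tensor = torch.stack([torch.norm(tensor, p=2) for tensor in tensor_list])
    l2_norm = torch.norm(norms_tensor, p=2).unsqueeze(0)
    per_tensor_norm = norms_tensor if per_tensor else torch.Tensor([]).to(norms_tensor.device)
    return l2_norm, per_tensor_norm


grads = [torch.randn(3, 4), torch.randn(10)]
l2, per_tensor = multi_tensor_l2norm_torch(grads, per_tensor=True)
reference = torch.cat([g.reshape(-1) for g in grads]).norm(p=2)
assert torch.allclose(l2, reference.unsqueeze(0))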