feat(*): support no apex (#166)

* support no-apex

* add default for use_apex

* fix lint

* modify the RMSNormTorch

* remove some comments

* remove use_apex parameter

* remove some unnecessary code
pull/170/head
ytxiong 2023-08-02 20:32:38 +08:00 committed by GitHub
parent 66a23e326a
commit 1c397f523f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 87 additions and 9 deletions

View File

@ -5,7 +5,6 @@ import math
from typing import Optional from typing import Optional
import torch import torch
from apex.normalization.fused_layer_norm import MixedFusedRMSNorm as RMSNorm
from flash_attn.modules.embedding import ParallelGPT2Embeddings from flash_attn.modules.embedding import ParallelGPT2Embeddings
from flash_attn.modules.mlp import ParallelFusedMLP from flash_attn.modules.mlp import ParallelFusedMLP
from torch import nn from torch import nn
@ -20,7 +19,7 @@ from internlm.model.linear import (
ScaleColumnParallelLinear, ScaleColumnParallelLinear,
) )
from internlm.model.multi_head_attention import MHA from internlm.model.multi_head_attention import MHA
from internlm.model.utils import gather_forward_split_backward from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm
from internlm.solver.pipeline_utils import partition_uniform from internlm.solver.pipeline_utils import partition_uniform
from internlm.utils.checkpoint import activation_checkpoint from internlm.utils.checkpoint import activation_checkpoint
from internlm.utils.common import filter_kwargs from internlm.utils.common import filter_kwargs
@ -97,6 +96,7 @@ class PackedFlashBaseLayer1D(nn.Module):
self.dropout1 = nn.Dropout(drop_rate) self.dropout1 = nn.Dropout(drop_rate)
if norm_type == "rmsnorm": if norm_type == "rmsnorm":
RMSNorm = try_import_RMSNorm()
self.norm1 = RMSNorm(hidden_size, eps=layer_norm_epsilon) self.norm1 = RMSNorm(hidden_size, eps=layer_norm_epsilon)
self.norm2 = RMSNorm(hidden_size, eps=layer_norm_epsilon) self.norm2 = RMSNorm(hidden_size, eps=layer_norm_epsilon)
else: else:
@ -335,6 +335,7 @@ class PackedFlashInternLm1D(nn.Module):
) )
if last: if last:
if norm_type == "rmsnorm": if norm_type == "rmsnorm":
RMSNorm = try_import_RMSNorm()
self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon) self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon)
else: else:
self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)

44
internlm/model/norm.py Normal file
View File

@ -0,0 +1,44 @@
# adopted from https://github.com/NVIDIA/apex/blob/master/apex/normalization/fused_layer_norm
import numbers
import torch
from torch.nn import init
from torch.nn.parameter import Parameter
def manual_rms_norm(input, normalized_shape, weight, eps):
# layer norm should always be calculated in float32
dims = tuple(i for i in range(-1, -len(normalized_shape) - 1, -1))
variance = input.to(torch.float32).pow(2).mean(dims, keepdim=True)
input = input * torch.rsqrt(variance + eps)
if weight is None:
return input
# convert into half-precision if necessary
if weight.dtype in [torch.float16, torch.bfloat16]:
input = input.to(weight.dtype)
return weight * input
class RMSNormTorch(torch.nn.Module):
def __init__(self, normalized_shape, eps=1e-5):
super().__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
self.normalized_shape = torch.Size(normalized_shape)
self.eps = eps
self.weight = Parameter(torch.empty(*normalized_shape))
self.reset_parameters()
def forward(self, input: torch.Tensor):
return manual_rms_norm(input, self.normalized_shape, self.weight, self.eps)
def reset_parameters(self):
init.ones_(self.weight)
def extra_repr(self):
return "{normalized_shape}, eps={eps}, ".format(**self.__dict__)

View File

@ -71,3 +71,18 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
def gather_forward_split_backward(input_, parallel_mode, dim): def gather_forward_split_backward(input_, parallel_mode, dim):
return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim) return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim)
def try_import_RMSNorm():
"""
Try import MixFusedRMSNorm from apex, if failed, return our RMSNorm
"""
try:
from apex.normalization.fused_layer_norm import MixedFusedRMSNorm as RMSNorm
return RMSNorm
except ModuleNotFoundError as e:
import warnings
warnings.warn("The torch implementation for MixFusedRMSNorm is slower than apex. Please note this!")
from internlm.model.norm import RMSNormTorch as RMSNorm
return RMSNorm

View File

@ -5,10 +5,8 @@ import math
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, Optional from typing import Dict, Optional
import amp_C
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from apex.multi_tensor_apply import multi_tensor_applier
from torch import Tensor from torch import Tensor
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
@ -156,17 +154,37 @@ def sync_param(flat_tensor, tensor_list):
for p, q in zip(tensor_list, updated_params): for p, q in zip(tensor_list, updated_params):
p.data = q.data p.data = q.data
def multi_tensor_l2norm_torch(tensor_list, per_tensor):
# Convert tensor_list elements to torch.float32
tensor_list = [tensor.float() for tensor in tensor_list]
norms_tensor = torch.stack([torch.norm(tensor, p=2) for tensor in tensor_list])
l2_norm = torch.norm(norms_tensor, p=2).unsqueeze(0)
if per_tensor:
per_tensor_norm = norms_tensor
else:
per_tensor_norm = torch.Tensor([]).to(norms_tensor.device)
return l2_norm, per_tensor_norm
def calc_l2_norm(grads): def calc_l2_norm(grads):
norm = 0.0 norm = 0.0
if len(grads) > 0: if len(grads) > 0:
dummy_overflow_buf = torch.cuda.IntTensor([0]) try:
norm, _ = multi_tensor_applier( import amp_C
amp_C.multi_tensor_l2norm, dummy_overflow_buf, [grads], False # no per-parameter norm from apex.multi_tensor_apply import multi_tensor_applier
)
dummy_overflow_buf = torch.cuda.IntTensor([0])
norm, _ = multi_tensor_applier(
amp_C.multi_tensor_l2norm, dummy_overflow_buf, [grads], False # no per-parameter norm
)
except ModuleNotFoundError as e:
import warnings
warnings.warn("The torch implementation for cal_l2norm is slower than apex. Please note this!")
norm, _ = multi_tensor_l2norm_torch(grads, False)
return norm return norm
def calc_lp(grads, norm_type): def calc_lp(grads, norm_type):
norm = 0.0 norm = 0.0
for grad in grads: for grad in grads: