mirror of https://github.com/InternLM/InternLM

feat(*): support no apex (#166)

* support no-apex
* add default for use_apex
* fix lint
* modify the RMSNormTorch
* remove some comments
* remove use_apex parameter
* remove some unnecessary code
parent 66a23e326a
commit 1c397f523f
@@ -5,7 +5,6 @@ import math
 from typing import Optional
 
 import torch
-from apex.normalization.fused_layer_norm import MixedFusedRMSNorm as RMSNorm
 from flash_attn.modules.embedding import ParallelGPT2Embeddings
 from flash_attn.modules.mlp import ParallelFusedMLP
 from torch import nn
@@ -20,7 +19,7 @@ from internlm.model.linear import (
     ScaleColumnParallelLinear,
 )
 from internlm.model.multi_head_attention import MHA
-from internlm.model.utils import gather_forward_split_backward
+from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm
 from internlm.solver.pipeline_utils import partition_uniform
 from internlm.utils.checkpoint import activation_checkpoint
 from internlm.utils.common import filter_kwargs
@@ -97,6 +96,7 @@ class PackedFlashBaseLayer1D(nn.Module):
 
         self.dropout1 = nn.Dropout(drop_rate)
         if norm_type == "rmsnorm":
+            RMSNorm = try_import_RMSNorm()
             self.norm1 = RMSNorm(hidden_size, eps=layer_norm_epsilon)
             self.norm2 = RMSNorm(hidden_size, eps=layer_norm_epsilon)
         else:
@@ -335,6 +335,7 @@ class PackedFlashInternLm1D(nn.Module):
         )
         if last:
             if norm_type == "rmsnorm":
+                RMSNorm = try_import_RMSNorm()
                 self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon)
             else:
                 self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
@@ -0,0 +1,44 @@
+# adopted from https://github.com/NVIDIA/apex/blob/master/apex/normalization/fused_layer_norm
+
+import numbers
+
+import torch
+from torch.nn import init
+from torch.nn.parameter import Parameter
+
+
+def manual_rms_norm(input, normalized_shape, weight, eps):
+    # layer norm should always be calculated in float32
+    dims = tuple(i for i in range(-1, -len(normalized_shape) - 1, -1))
+    variance = input.to(torch.float32).pow(2).mean(dims, keepdim=True)
+    input = input * torch.rsqrt(variance + eps)
+
+    if weight is None:
+        return input
+
+    # convert into half-precision if necessary
+    if weight.dtype in [torch.float16, torch.bfloat16]:
+        input = input.to(weight.dtype)
+
+    return weight * input
+
+
+class RMSNormTorch(torch.nn.Module):
+    def __init__(self, normalized_shape, eps=1e-5):
+        super().__init__()
+
+        if isinstance(normalized_shape, numbers.Integral):
+            normalized_shape = (normalized_shape,)
+        self.normalized_shape = torch.Size(normalized_shape)
+        self.eps = eps
+        self.weight = Parameter(torch.empty(*normalized_shape))
+        self.reset_parameters()
+
+    def forward(self, input: torch.Tensor):
+        return manual_rms_norm(input, self.normalized_shape, self.weight, self.eps)
+
+    def reset_parameters(self):
+        init.ones_(self.weight)
+
+    def extra_repr(self):
+        return "{normalized_shape}, eps={eps}, ".format(**self.__dict__)
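For orientation, a minimal sketch (not part of the commit) of how the pure-torch fallback above behaves; the tensor shapes and the equality check are illustrative assumptions. With the default all-ones weight, the module reduces to x * rsqrt(mean(x^2) + eps):

# Illustrative only: RMSNormTorch with its default ones-initialised weight
# matches a hand-rolled RMS normalisation.
import torch

from internlm.model.norm import RMSNormTorch

x = torch.randn(2, 8, 4096)          # (batch, seq, hidden); sizes are arbitrary
norm = RMSNormTorch(4096, eps=1e-5)
y = norm(x)

ref = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-5)
print(torch.allclose(y, ref))        # True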
@@ -71,3 +71,18 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
 
 def gather_forward_split_backward(input_, parallel_mode, dim):
     return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim)
+
+
+def try_import_RMSNorm():
+    """
+    Try to import MixedFusedRMSNorm from apex; if that fails, return our torch implementation.
+
+    """
+    try:
+        from apex.normalization.fused_layer_norm import MixedFusedRMSNorm as RMSNorm
+        return RMSNorm
+    except ModuleNotFoundError:
+        import warnings
+        warnings.warn("The torch implementation of RMSNorm is slower than apex's MixedFusedRMSNorm. Please note this!")
+        from internlm.model.norm import RMSNormTorch as RMSNorm
+        return RMSNorm
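As a usage note (illustrative, not part of the diff), callers resolve the class once and then construct it exactly as they would the apex original; the hidden size below is an arbitrary example:

from internlm.model.utils import try_import_RMSNorm

RMSNorm = try_import_RMSNorm()    # apex MixedFusedRMSNorm if available, else RMSNormTorch
norm = RMSNorm(4096, eps=1e-5)    # same constructor signature either way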
@@ -5,10 +5,8 @@ import math
 from abc import ABC, abstractmethod
 from typing import Dict, Optional
 
-import amp_C
 import torch
 import torch.distributed as dist
-from apex.multi_tensor_apply import multi_tensor_applier
 from torch import Tensor
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 
@@ -156,17 +154,37 @@ def sync_param(flat_tensor, tensor_list):
     for p, q in zip(tensor_list, updated_params):
         p.data = q.data
 
 
+def multi_tensor_l2norm_torch(tensor_list, per_tensor):
+    # Convert tensor_list elements to torch.float32
+    tensor_list = [tensor.float() for tensor in tensor_list]
+    norms_tensor = torch.stack([torch.norm(tensor, p=2) for tensor in tensor_list])
+    l2_norm = torch.norm(norms_tensor, p=2).unsqueeze(0)
+
+    if per_tensor:
+        per_tensor_norm = norms_tensor
+    else:
+        per_tensor_norm = torch.Tensor([]).to(norms_tensor.device)
+
+    return l2_norm, per_tensor_norm
+
+
 def calc_l2_norm(grads):
     norm = 0.0
     if len(grads) > 0:
-        dummy_overflow_buf = torch.cuda.IntTensor([0])
-        norm, _ = multi_tensor_applier(
-            amp_C.multi_tensor_l2norm, dummy_overflow_buf, [grads], False  # no per-parameter norm
-        )
+        try:
+            import amp_C
+            from apex.multi_tensor_apply import multi_tensor_applier
+
+            dummy_overflow_buf = torch.cuda.IntTensor([0])
+            norm, _ = multi_tensor_applier(
+                amp_C.multi_tensor_l2norm, dummy_overflow_buf, [grads], False  # no per-parameter norm
+            )
+        except ModuleNotFoundError:
+            import warnings
+
+            warnings.warn("The torch implementation of calc_l2_norm is slower than apex. Please note this!")
+            norm, _ = multi_tensor_l2norm_torch(grads, False)
     return norm
 
 
 def calc_lp(grads, norm_type):
     norm = 0.0
     for grad in grads:
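As a quick sanity check (illustrative, not part of the commit, and assuming multi_tensor_l2norm_torch from the hunk above is in scope): stacking the per-tensor norms and taking their L2 norm equals the L2 norm of all gradients flattened together, which is the quantity the apex multi_tensor_l2norm path computes.

# Illustrative only: the torch fallback agrees with a flattened global L2 norm.
import torch

grads = [torch.randn(10, 10), torch.randn(5)]    # toy "gradients"
l2_norm, per_tensor = multi_tensor_l2norm_torch(grads, per_tensor=True)

flat = torch.cat([g.reshape(-1) for g in grads])
print(torch.allclose(l2_norm, torch.norm(flat, p=2).unsqueeze(0)))   # True
print(per_tensor)                                # one L2 norm per gradient tensor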