#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch
import torch.nn as nn

from colossalai.lazy import LazyInitContext

__all__ = ['FusedLayerNorm', 'FusedRMSNorm']

FAST_LAYERNORM_SUPPORTED_SIZE = [
    1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, 24576,
    25600, 30720, 32768, 40960, 49152, 65536
]


def _fast_layernorm_hidden_size(normalized_shape, has_weight_and_bias: bool):
    """Return the int hidden size to hand to apex's FastLayerNorm, or ``None`` if it cannot be used.

    ``nn.LayerNorm`` stores ``normalized_shape`` as a tuple (e.g. ``(1024,)``) while
    ``FAST_LAYERNORM_SUPPORTED_SIZE`` lists plain ints, so the shape must be unwrapped before the
    membership test. The fast kernel is only considered for a single normalized dimension and when
    the source module carries both a weight and a bias, since the fast module always owns both.

    Args:
        normalized_shape: an int or a sequence of ints, as found on ``nn.LayerNorm``.
        has_weight_and_bias: whether the source module has non-``None`` weight and bias.
    """
    if not has_weight_and_bias:
        return None
    if isinstance(normalized_shape, int):
        hidden_size = normalized_shape
    elif len(normalized_shape) == 1:
        hidden_size = normalized_shape[0]
    else:
        # normalizing over more than one trailing dimension: not handled by the fast kernel
        return None
    return hidden_size if hidden_size in FAST_LAYERNORM_SUPPORTED_SIZE else None


class FusedLayerNorm:
    r"""
    This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface.
    """

    def __init__(self) -> None:
        raise NotImplementedError(
            'FusedLayerNorm is not implemented as a physical class. '
            'It is meant to be used only with the from_native_module interface to wrap the fused layernorm implementation provided by apex.'
        )

    @staticmethod
    def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module:
        r"""
        Convert a native pytorch layer norm module to colossalai layer norm module.

        The apex FastLayerNorm kernel is used when the normalized size is one of
        ``FAST_LAYERNORM_SUPPORTED_SIZE`` and the kernel is built. Otherwise apex's regular
        FusedLayerNorm is used.

        Args:
            module: the layer norm to convert. It is passed through ``LazyInitContext.materialize``
                before its attributes are read.
            *args, **kwargs: accepted for interface compatibility and ignored.

        Returns:
            An apex layer norm module whose ``weight`` and ``bias`` are the very same Parameter
            objects as ``module``'s (shared, not copied).

        Raises:
            ImportError: if apex is not installed.
        """
        # check if apex is installed
        try:
            import apex  # noqa: F401
        except ImportError as err:
            raise ImportError(
                'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel'
            ) from err

        LazyInitContext.materialize(module)

        # get the attributes of the module
        normalized_shape = module.normalized_shape
        eps = module.eps
        elementwise_affine = module.elementwise_affine
        weight = module.weight
        bias = module.bias

        # pick the suitable layernorm implementation
        layernorm = None
        fast_hidden_size = _fast_layernorm_hidden_size(normalized_shape, weight is not None and bias is not None)
        if fast_hidden_size is not None:
            try:
                from apex.contrib.layer_norm.layer_norm import FastLayerNorm
            except ImportError:
                # the fast kernel is not built: fall back to the normal fused layernorm below
                pass
            else:
                # FastLayerNorm takes an int hidden size and has no elementwise_affine switch
                layernorm = FastLayerNorm(fast_hidden_size, eps=eps)

        if layernorm is None:
            from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm
            layernorm = ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine)

        # Follow the source module's dtype/device. Without an affine weight there is nothing to read
        # them from (module.weight is None), and the new module then has no parameters to move.
        if weight is not None:
            layernorm = layernorm.to(dtype=weight.dtype, device=weight.device)

        # share the original parameters so existing references/optimizer state stay valid
        layernorm.weight = weight
        layernorm.bias = bias
        return layernorm
class FusedRMSNorm:
    """
    This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface.
    """

    def __init__(self) -> None:
        raise NotImplementedError(
            'FusedRMSNorm is not implemented as a physical class. '
            'It is meant to be used only with the from_native_module interface to wrap the fused rms norm implementation provided by apex.'
        )

    @staticmethod
    def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
        """
        Convert an RMS norm module to apex's fused RMS norm.

        Args:
            module: the RMS norm to convert. A module whose class is named ``LlamaRMSNorm`` is read through
                ``weight`` and ``variance_epsilon``. Any other module is read through
                ``normalized_shape``, ``eps`` and ``elementwise_affine``. ``module`` is passed through
                ``LazyInitContext.materialize`` before its attributes are read.
            *args, **kwargs: accepted for interface compatibility and ignored.

        Returns:
            An apex ``FusedRMSNorm`` whose ``weight`` is the very same Parameter object as
            ``module.weight`` (shared, not copied).

        Raises:
            ImportError: if apex is not installed.
        """
        try:
            from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
        except ImportError as err:
            raise ImportError(
                'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel'
            ) from err

        LazyInitContext.materialize(module)

        # to check if it is huggingface LlamaRMSNorm
        if module.__class__.__name__ == "LlamaRMSNorm":
            # LlamaRMSNorm exposes neither normalized_shape nor elementwise_affine; derive them
            normalized_shape = module.weight.shape[0]
            eps = module.variance_epsilon
            elementwise_affine = True
        else:
            # get the attributes of the module
            normalized_shape = module.normalized_shape
            eps = module.eps
            elementwise_affine = module.elementwise_affine

        rmsnorm = ApexFusedRMSNorm(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine)

        # share the original parameter so existing references/optimizer state stay valid
        rmsnorm.weight = module.weight
        return rmsnorm