[fx] added apex normalization to patched modules (#1300)

* [fx] added apex normalization to patched modules

* remove unused imports
pull/1309/head
Frank Lee 2022-07-14 14:24:13 +08:00 committed by GitHub
parent 4165eabb1e
commit 4f4d8c3656
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 11 additions and 20 deletions

View File

@@ -17,4 +17,14 @@ def torch_nn_normalize(self, input):
assert input.dim() == 5
# normalization maintain the same shape as the input
return input.clone()
return input.clone()
# apex is an optional dependency: when it is importable, patch its fused
# normalization modules so the meta tracer knows their output shapes; when
# it is not, skip registration silently.
try:
    import apex
except ImportError:
    pass
else:
    _APEX_NORM_CLASSES = (
        apex.normalization.FusedLayerNorm,
        apex.normalization.FusedRMSNorm,
        apex.normalization.MixedFusedLayerNorm,
        apex.normalization.MixedFusedRMSNorm,
    )
    # All four layers share the same shape-preserving meta implementation.
    for _norm_cls in _APEX_NORM_CLASSES:
        meta_patched_module.register(_norm_cls)(torch_nn_normalize)

View File

@@ -2,15 +2,6 @@ import pytest
import transformers
import torch
from hf_utils import split_model_and_compare_output
from colossalai.fx.tracer.meta_patch import meta_patched_module
# Best-effort patch: only when apex is installed, teach the meta tracer the
# output shape of FusedRMSNorm; otherwise do nothing.
try:
    import apex
except ImportError:
    pass
else:
    @meta_patched_module.register(apex.normalization.FusedRMSNorm)
    def apex_fused_layernorm(self, input):
        """Shape-only stand-in: an empty meta tensor matching the input's shape."""
        return torch.empty(input.shape, device='meta')
# Dimensions of the dummy input fed to the traced model.
BATCH_SIZE = 1
# NOTE(review): "SEQ_LENGHT" is a typo of SEQ_LENGTH — left as-is because
# renaming would require updating every reference in this module.
SEQ_LENGHT = 16

View File

@@ -1,18 +1,8 @@
import pytest
import transformers
import torch
from colossalai.fx.tracer.meta_patch import meta_patched_module
from utils import trace_model_and_compare_output
# Optional apex support: register a meta-tensor implementation for
# FusedRMSNorm if apex can be imported, and silently skip otherwise.
try:
    import apex
except ImportError:
    pass
else:
    @meta_patched_module.register(apex.normalization.FusedRMSNorm)
    def apex_fused_layernorm(self, input):
        """Return an empty tensor on the meta device with the input's shape."""
        return torch.empty(input.shape, device='meta')
# Dimensions of the dummy input fed to the traced model.
BATCH_SIZE = 1
# NOTE(review): "SEQ_LENGHT" is a typo of SEQ_LENGTH — left as-is because
# renaming would require updating every reference in this module.
SEQ_LENGHT = 16