mirror of https://github.com/hpcaitech/ColossalAI
[fx] fixed tracing with apex-based T5 model (#1252)
* [fx] fixed tracing with apex-based T5 model
* polish code
* polish code

pull/1271/head
parent 7531c6271f
commit 4a09fc0947
tracer.py

@@ -7,6 +7,7 @@
 import enum
 import inspect
 import functools
+from colossalai.fx.tracer.meta_patch import meta_patched_module
 import torch
 import torch.nn as nn
 from torch import Tensor
@@ -181,7 +182,16 @@ class ColoTracer(Tracer):
 
     def call_module(self, m, forward, args, kwargs):
         self.orig_forward = forward
-        return super().call_module(m, forward, args, kwargs)
+        module_qualified_name = self.path_of_module(m)
+
+        # a leaf module is a torch.nn.Module subclass defined under `torch.nn`,
+        # which means customized modules are not leaf modules by default;
+        # if a customized or third-party module like apex.normalization.FusedRMSNorm is patched,
+        # we should treat it as a leaf module as well
+        if meta_patched_module.has(m.__class__) or self.is_leaf_module(m, module_qualified_name):
+            return self.create_proxy('call_module', module_qualified_name, args, kwargs)
+        else:
+            return forward(*args, **kwargs)
 
     def proxy(self, node) -> Proxy:
         """
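Effect of the change above: any module class registered with meta_patched_module is now traced as a single opaque call_module node instead of having its forward inlined, so the tracer never steps into fused kernels that cannot execute on meta tensors. Below is a minimal sketch of the pattern, using a hypothetical MyFusedNorm module to stand in for apex.normalization.FusedRMSNorm; the ColoTracer import path and the meta_args tracing interface are assumed from the surrounding test suite, not shown in this diff.

import torch
import torch.nn as nn
from colossalai.fx import ColoTracer
from colossalai.fx.tracer.meta_patch import meta_patched_module

# hypothetical third-party module standing in for apex.normalization.FusedRMSNorm;
# imagine its forward dispatching to a fused CUDA kernel that fails on meta tensors
class MyFusedNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return x * self.weight

# the meta patch only needs to produce an output of the right shape on the
# meta device; the real forward is never executed while tracing
@meta_patched_module.register(MyFusedNorm)
def my_fused_norm_meta(self, input):
    return torch.empty(input.shape, device='meta')

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = MyFusedNorm(16)

    def forward(self, x):
        return self.norm(x)

tracer = ColoTracer()
graph = tracer.trace(Net(), meta_args={'x': torch.rand(2, 16, device='meta')})
# MyFusedNorm shows up as one call_module node rather than being inlined
print([(node.op, node.target) for node in graph.nodes])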
@@ -1,8 +1,18 @@
 import pytest
 import transformers
 import torch
+from colossalai.fx.tracer.meta_patch import meta_patched_module
 from utils import trace_model_and_compare_output
 
+try:
+    import apex
+
+    @meta_patched_module.register(apex.normalization.FusedRMSNorm)
+    def apex_fused_layernorm(self, input):
+        return torch.empty(input.shape, device='meta')
+except ImportError:
+    pass
+
 BATCH_SIZE = 1
 SEQ_LENGHT = 16
 
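Context for the test hunk: when apex is installed, transformers swaps T5's own T5LayerNorm for apex.normalization.FusedRMSNorm, whose fused forward cannot run on meta tensors; the try/except block registers a shape-only meta implementation (an RMSNorm preserves the input shape, so torch.empty(input.shape, device='meta') suffices), and degrades gracefully to plain T5 tracing when apex is absent. A sketch of how these pieces would be exercised follows; the test name, the model configuration, and the (model, data_gen) signature of trace_model_and_compare_output are assumptions based on the sibling hf tests, not part of this diff.

def test_t5():
    # small hypothetical config to keep the traced model cheap
    config = transformers.T5Config(d_model=128, num_layers=2, num_heads=4, d_ff=256)
    model = transformers.T5Model(config)

    def data_gen():
        # T5 is an encoder-decoder model, so it needs both input streams
        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
        decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
        return dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids)

    # traces the model with ColoTracer and checks the traced graph's output
    # against eager execution
    trace_model_and_compare_output(model, data_gen)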