ColossalAI/colossalai/fx/tracer/meta_patch/patched_function/activation_function.py

7 lines
214 B
Python
Raw Normal View History

import torch
from ..registry import meta_patched_function
@meta_patched_function.register(torch.nn.functional.relu)
def torch_nn_func_relu(input, inplace=False):
return torch.empty(input.shape, device='meta')