ColossalAI/colossalai/fx/tracer/meta_patch/patched_function/activation_function.py

9 lines
217 B
Python

import torch
from ...registry import meta_patched_function
@meta_patched_function.register(torch.nn.functional.relu)
def torch_nn_func_relu(input, inplace=False):
    """Meta-device stand-in for ``torch.nn.functional.relu``.

    ReLU is elementwise, so its result has the same shape and dtype as
    ``input``. This patch only builds a placeholder tensor describing that
    result; no values are computed.

    Args:
        input: Tensor that would be passed to ``relu``.
        inplace: Accepted to mirror ``relu``'s signature. It is not used
            here; the returned tensor is always a new placeholder.

    Returns:
        An uninitialized tensor on the ``meta`` device with ``input``'s
        shape and dtype.
    """
    # empty_like (unlike torch.empty(input.shape, ...)) carries over the
    # input's dtype, so e.g. float16/bfloat16 inputs are not reported as
    # the global default dtype.
    return torch.empty_like(input, device='meta')