ColossalAI/colossalai/fx/tracer/meta_patch/patched_module.py

8 lines
217 B
Python
Raw Normal View History

import torch
from .registry import meta_patched_module
@meta_patched_module.register(torch.nn.Linear)
def torch_nn_linear(self, input):
    """Meta-device stand-in for ``torch.nn.Linear.forward``.

    Returns an uninitialized tensor on the ``meta`` device whose shape is
    what ``nn.Linear`` would produce. This lets shapes be propagated
    without performing the actual matrix multiply.

    Args:
        self: the ``torch.nn.Linear`` module being patched; only its
            ``in_features`` and ``out_features`` attributes are read.
        input: the module input; only its ``shape`` is inspected.

    Returns:
        An empty ``meta`` tensor of shape
        ``input.shape[:-1] + (self.out_features,)``.

    Raises:
        RuntimeError: if ``input`` is 0-dimensional, or if its last
            dimension differs from ``self.in_features``. The real layer
            fails on such inputs, so the meta path fails too rather than
            silently producing a bogus output shape.
    """
    in_shape = input.shape
    if len(in_shape) == 0:
        # A scalar has no feature dimension; the real linear op rejects it.
        raise RuntimeError(
            f"Expected input with at least 1 dimension for Linear, but got a 0-dimensional input"
        )
    last_dim = in_shape[-1]
    if last_dim != self.in_features:
        raise RuntimeError(
            f"Expected hidden dim of {self.in_features}, but got {last_dim} for the input"
        )
    # Linear maps only the last dimension; all leading (batch) dims pass through.
    return torch.empty(in_shape[:-1] + (self.out_features,), device="meta")