import torch

from ...registry import meta_patched_module


@meta_patched_module.register(torch.nn.Linear)
def torch_nn_linear(self, input):
    # Shape-only patch for torch.nn.Linear: validate the input's hidden
    # dimension against in_features, then return an empty meta tensor with
    # the output shape instead of materializing any data.
    last_dim = input.shape[-1]
    assert (
        last_dim == self.in_features
    ), f"Expected hidden size {self.in_features} but got {last_dim} for the torch.nn.Linear patch"
    return torch.empty(input.shape[:-1] + (self.out_features,), device="meta")
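
# A minimal usage sketch (assumption: the registered patch can also be called
# directly; in practice the tracer dispatches to it when a meta tensor reaches
# an nn.Linear). It only checks that the output shape is propagated correctly:
#
#   linear = torch.nn.Linear(16, 32, device="meta")
#   x = torch.empty(4, 16, device="meta")
#   out = torch_nn_linear(linear, x)
#   assert out.shape == (4, 32) and out.is_meta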