import operator

import torch

from ..registry import meta_patched_function


@meta_patched_function.register(operator.getitem)
def operator_getitem(a, b):
    """Meta-patched counterpart of ``operator.getitem``.

    When ``a`` is a tensor, the indexing is actually executed on a CPU
    stand-in for ``a`` (created with ``torch.empty_like``), using CPU
    stand-ins for any tensor indices, and the result is moved to the
    ``"meta"`` device. When ``a`` is not a tensor, the call is forwarded to
    ``operator.getitem`` unchanged.

    Args:
        a: The object being indexed.
        b: The index. If it is a tuple, each element is handled
            individually; tensor elements are replaced by concrete CPU
            tensors, everything else is passed through as-is.

    Returns:
        A tensor on the ``"meta"`` device when ``a`` is a tensor, otherwise
        whatever ``operator.getitem(a, b)`` returns.
    """

    # copied from huggingface.utils.fx
    def _concretize(item):
        """Return a CPU tensor of ones shaped like ``item`` if it is a tensor.

        Non-tensor values are returned untouched.
        """
        if not isinstance(item, torch.Tensor):
            return item

        ones = torch.ones_like(item, device="cpu")
        # These dtypes are converted to int64 before being used as an index.
        needs_int64 = (torch.float16, torch.float32, torch.float64, torch.int32)
        if ones.dtype in needs_int64:
            ones = ones.to(torch.int64)
        return ones

    # Anything that is not a tensor is indexed directly.
    if not isinstance(a, torch.Tensor):
        return operator.getitem(a, b)

    # TODO: infer shape without performing the computation.
    index = tuple(_concretize(part) for part in b) if isinstance(b, tuple) else _concretize(b)
    cpu_stand_in = torch.empty_like(a, device="cpu")
    indexed = operator.getitem(cpu_stand_in, index)
    return indexed.to("meta")