mirror of https://github.com/hpcaitech/ColossalAI
25 lines
840 B
Python
25 lines
840 B
Python
|
import operator
|
||
|
import torch
|
||
|
from ..registry import meta_patched_function
|
||
|
|
||
|
|
||
|
@meta_patched_function.register(operator.getitem)
def operator_getitem(a, b):
    """Replacement for ``operator.getitem`` registered via ``meta_patched_function``.

    When ``a`` is a ``torch.Tensor``, the indexing is replayed on an
    uninitialized CPU tensor created with ``torch.empty_like(a, device="cpu")``.
    Every tensor used in the index is replaced by a concrete CPU placeholder
    (see ``to_concrete``), and the result is moved to the ``"meta"`` device.
    Only the shape/dtype of the returned tensor are meaningful; its values are not.

    For any other ``a`` the call is forwarded to ``operator.getitem`` unchanged.

    Args:
        a: The object being indexed.
        b: The index/key. A tensor is replaced by a placeholder. If ``b`` is a
            tuple or list, each element is converted individually and the
            container type is preserved.

    Returns:
        A ``"meta"`` tensor with the indexed result's shape when ``a`` is a
        tensor, otherwise whatever ``operator.getitem(a, b)`` returns.
    """
    # Adapted from huggingface.utils.fx

    def to_concrete(t):
        """Replace a tensor index with a CPU placeholder; pass other values through.

        PyTorch treats bool and uint8 index tensors as masks. They become
        all-ones (every element selected), which gives the largest possible
        result shape.

        Every other tensor acts as an integer index and becomes an int64 tensor
        of zeros. Integer-indexing result shapes do not depend on the index
        values, and 0 is in bounds for any non-empty dimension (a placeholder
        of 1 would be out of bounds on size-1 dimensions). Using int64 also
        makes floating-point (including bfloat16) and int8/int16 indices legal.

        Non-tensor values (ints, slices, None, Ellipsis, ...) are returned as-is.
        """
        if not isinstance(t, torch.Tensor):
            return t
        if t.dtype in (torch.bool, torch.uint8):
            return torch.ones_like(t, device="cpu")
        return torch.zeros_like(t, dtype=torch.int64, device="cpu")

    if isinstance(a, torch.Tensor):
        # TODO: infer shape without performing the computation.
        if isinstance(b, tuple):
            b = tuple(map(to_concrete, b))
        elif isinstance(b, list):
            # A list may contain tensors (e.g. meta tensors) that cannot index
            # the CPU stand-in directly; plain ints pass through unchanged.
            b = [to_concrete(item) for item in b]
        else:
            b = to_concrete(b)
        return operator.getitem(torch.empty_like(a, device="cpu"), b).to("meta")
    return operator.getitem(a, b)
|