import operator

import torch

from colossalai.fx.proxy import ColoProxy

from ...registry import meta_patched_function


@meta_patched_function.register(operator.getitem)
def operator_getitem(a, b):
    # copied from huggingface.utils.fx
    def to_concrete(t):
        # Replace tensor indices with concrete CPU tensors of ones so the
        # indexing can actually be executed; float and int32 index tensors are
        # cast to int64 so they remain valid indices.
        if isinstance(t, torch.Tensor):
            concrete = torch.ones_like(t, device="cpu")
            if concrete.dtype in [torch.float16, torch.float32, torch.float64, torch.int32]:
                concrete = concrete.to(torch.int64)
            return concrete
        return t

    def _slice_convert(slice_obj):
        # Rebuild a slice whose start/stop/step may be ColoProxy objects.
        attrs = {"start": slice_obj.start, "stop": slice_obj.stop, "step": slice_obj.step}
        new_attrs = _slice_attr_convert(attrs)
        attr_dict_to_tuple = (new_attrs["start"], new_attrs["stop"], new_attrs["step"])
        return slice(*attr_dict_to_tuple)

    def _slice_attr_convert(attrs):
        # Replace ColoProxy attributes with their underlying meta data.
        new_attrs = {}
        for key, value in attrs.items():
            if isinstance(value, ColoProxy):
                new_attrs[key] = value.meta_data
            else:
                new_attrs[key] = value
        return new_attrs

    # Convert any slices in the index so they no longer contain proxies.
    if isinstance(b, tuple):
        b = list(b)
        for index, element in enumerate(b):
            if isinstance(element, slice):
                b[index] = _slice_convert(element)
        b = tuple(b)
    elif isinstance(b, slice):
        b = _slice_convert(b)

    if isinstance(a, torch.Tensor):
        # Run the indexing on a concrete CPU tensor of the same shape, then move
        # the result back to the meta device so only shape/dtype is kept.
        # TODO: infer shape without performing the computation.
        if isinstance(b, tuple):
            b = tuple(map(to_concrete, b))
        else:
            b = to_concrete(b)
        return operator.getitem(torch.empty_like(a, device="cpu"), b).to("meta")

    if isinstance(a, ColoProxy):
        # Same as above, but the shape comes from the proxy's meta data.
        # TODO: infer shape without performing the computation.
        if isinstance(b, tuple):
            b = tuple(map(to_concrete, b))
        else:
            b = to_concrete(b)
        return operator.getitem(torch.empty_like(a.meta_data, device="cpu"), b).to("meta")

    return operator.getitem(a, b)
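# Minimal usage sketch, assuming a meta-device tensor input (illustration only,
# shapes and values chosen here as an example): the patched getitem performs
# the indexing on a CPU dummy of the same shape and returns a meta tensor, so
# only shape/dtype information is propagated.
#
#     x = torch.empty(4, 8, device="meta")
#     out = operator_getitem(x, (slice(None), 0))
#     # out is a meta tensor of shape torch.Size([4])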