ColossalAI/colossalai/fx/tracer/meta_patch/patched_function/python_ops.py

58 lines
2.0 KiB
Python
Raw Normal View History

import operator
import torch
from ..registry import meta_patched_function
2022-07-19 07:50:42 +00:00
from colossalai.fx.proxy import ColoProxy
def _to_concrete_index(t):
    """Replace a tensor used as an index with a same-shape CPU placeholder.

    For integer (advanced) indexing the result shape depends only on the
    index tensor's shape, while for ``bool``/``uint8`` masks it depends on the
    mask's values. Therefore:

    * ``bool``/``uint8`` masks become all-ones (every element selected);
    * every other dtype becomes an all-zeros ``int64`` tensor. Index 0 is in
      bounds for any non-empty dimension (an all-ones index would fail on
      size-1 dimensions), and ``int64`` is always accepted as an index dtype
      (unlike e.g. floating-point or ``int16`` tensors).

    Non-tensor values are returned unchanged.
    """
    if not isinstance(t, torch.Tensor):
        return t
    if t.dtype in (torch.bool, torch.uint8):
        return torch.ones_like(t, device="cpu")
    return torch.zeros_like(t, dtype=torch.int64, device="cpu")


def _unwrap_proxy(value):
    """Return ``value.meta_data`` if ``value`` is a ``ColoProxy``, else ``value``."""
    return value.meta_data if isinstance(value, ColoProxy) else value


def _unwrap_index_element(element):
    """Unwrap ``ColoProxy`` objects in one index element, including slice bounds."""
    if isinstance(element, slice):
        return slice(
            _unwrap_proxy(element.start),
            _unwrap_proxy(element.stop),
            _unwrap_proxy(element.step),
        )
    return _unwrap_proxy(element)


@meta_patched_function.register(operator.getitem)
def operator_getitem(a, b):
    """Meta implementation of ``operator.getitem`` (``a[b]``).

    Registered in ``meta_patched_function`` for ``operator.getitem``.
    Shape-inference approach adapted from huggingface.utils.fx.

    Any ``ColoProxy`` in ``a`` or in the index (a bare index, a tuple element,
    or a slice bound) is replaced by its ``meta_data`` first. If the resulting
    ``a`` is a tensor, the indexing is performed on a CPU placeholder of the
    same shape (``torch.empty_like``), with index tensors replaced by
    placeholders as well, and the result is moved to the ``meta`` device.
    Only its shape/dtype are meaningful. Any other ``a`` is indexed directly.

    Args:
        a: The object being indexed: a ``torch.Tensor``, a ``ColoProxy``, or
            any other object supporting ``[]``.
        b: The index: a single index object, a ``slice``, or a tuple of them.

    Returns:
        A ``meta`` tensor shaped like ``a[b]`` when ``a`` (after unwrapping) is
        a tensor; otherwise ``operator.getitem(a, b)``.
    """
    if isinstance(b, tuple):
        b = tuple(_unwrap_index_element(element) for element in b)
    else:
        b = _unwrap_index_element(b)

    # A proxy whose meta_data is not a tensor (e.g. a tuple of outputs) is
    # indexed directly below instead of being passed to torch.empty_like.
    a = _unwrap_proxy(a)

    if isinstance(a, torch.Tensor):
        # TODO: infer shape without performing the computation.
        if isinstance(b, tuple):
            b = tuple(_to_concrete_index(element) for element in b)
        else:
            b = _to_concrete_index(b)
        return operator.getitem(torch.empty_like(a, device="cpu"), b).to("meta")

    return operator.getitem(a, b)