mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
51 lines
1.5 KiB
51 lines
1.5 KiB
from typing import List, Union, Any
|
|
from ..proxy import ColoProxy, ColoAttribute
|
|
import torch
|
|
from .meta_patch import meta_patched_function, meta_patched_module
|
|
|
|
# Names exported by `from <this module> import *`.
# NOTE(review): compute_meta_data_for_functions_proxy is defined in this module
# but is not listed here — confirm whether that omission is intentional.
__all__ = ['is_element_in_list', 'extract_meta']
|
|
|
|
|
|
def is_element_in_list(elements: Union[List[Any], Any], list_: List[Any]):
    """Check that every element of ``elements`` is contained in ``list_``.

    If ``elements`` is a tuple, list or set, each of its items is checked in
    iteration order. Any other value (including strings) is checked as one
    single element.

    Args:
        elements: a single value, or a tuple/list/set of values to look up.
        list_: the container searched with the ``in`` operator.

    Returns:
        ``(True, None)`` when everything is found. Otherwise ``(False, item)``,
        where ``item`` is the first value that is not in ``list_``. The result
        is always a non-empty tuple and therefore always truthy, so callers must
        unpack it rather than test it directly.
    """
    if isinstance(elements, (tuple, list, set)):
        candidates = elements
    else:
        # Treat a scalar as a one-item sequence so a single loop handles both cases.
        candidates = (elements,)

    for candidate in candidates:
        if candidate not in list_:
            return False, candidate
    return True, None
|
|
|
|
|
|
def extract_meta(*args, **kwargs):
    """Replace every ``ColoProxy`` in the call arguments with its meta data.

    Lists and tuples, including nested ones, are traversed recursively and
    rebuilt with their original container type. Every other value is returned
    unchanged.

    Args:
        *args: positional call arguments, possibly containing ``ColoProxy`` objects.
        **kwargs: keyword call arguments, possibly containing ``ColoProxy`` objects.

    Returns:
        A ``(new_args, new_kwargs)`` pair. ``new_args`` is a list and
        ``new_kwargs`` is a dict, and each proxy is replaced by ``proxy.meta_data``.
    """

    def _convert(val):
        if isinstance(val, ColoProxy):
            return val.meta_data
        if isinstance(val, (list, tuple)):
            converted = [_convert(ele) for ele in val]
            # A namedtuple constructor takes one positional argument per field,
            # so it cannot be rebuilt from a single list the way plain tuples,
            # lists or torch.Size can. This mirrors torch.fx's map_aggregate.
            if isinstance(val, tuple) and hasattr(val, '_fields'):
                return type(val)(*converted)
            return type(val)(converted)
        return val

    new_args = [_convert(val) for val in args]
    new_kwargs = {k: _convert(v) for k, v in kwargs.items()}
    return new_args, new_kwargs
|
|
|
|
|
|
def compute_meta_data_for_functions_proxy(target, args, kwargs):
    """Infer the meta output of a function call by running it on meta data.

    Proxies in ``args``/``kwargs`` are replaced by their meta data via
    ``extract_meta``. If a meta-patched implementation is registered in
    ``meta_patched_function``, it is called instead of ``target``. The registry
    is searched first by the callable itself and then by its ``__name__``.

    Args:
        target: the callable being traced.
        args: positional arguments of the call (may contain ``ColoProxy``).
        kwargs: keyword arguments of the call (may contain ``ColoProxy``).

    Returns:
        The output of the (possibly patched) callable. A ``torch.Tensor``
        output is moved to the ``meta`` device; any other output is returned
        unchanged.
    """
    args_metas, kwargs_metas = extract_meta(*args, **kwargs)

    # Fetch the patched function, by identity first and then by name. Some
    # callables (e.g. functools.partial objects) have no ``__name__``; for
    # those, skip the name lookup instead of raising AttributeError.
    target_name = getattr(target, '__name__', None)
    if meta_patched_function.has(target):
        meta_target = meta_patched_function.get(target)
    elif target_name is not None and meta_patched_function.has(target_name):
        meta_target = meta_patched_function.get(target_name)
    else:
        meta_target = target

    meta_out = meta_target(*args_metas, **kwargs_metas)
    # Force a tensor result onto the meta device. Presumably this guards against
    # patched functions that allocate outputs on a real device — TODO confirm.
    if isinstance(meta_out, torch.Tensor):
        meta_out = meta_out.to(device="meta")

    return meta_out
|