mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
52 lines
1.4 KiB
52 lines
1.4 KiB
from typing import Any, List, Union
|
|
|
|
import torch
|
|
|
|
from ..proxy import ColoProxy
|
|
from .meta_patch import meta_patched_function
|
|
|
|
__all__ = ["is_element_in_list", "extract_meta"]
|
|
|
|
|
|
def is_element_in_list(elements: Union[List[Any], Any], list_: List[Any]):
|
|
if isinstance(elements, (tuple, list, set)):
|
|
for ele in elements:
|
|
if ele not in list_:
|
|
return False, ele
|
|
else:
|
|
if elements not in list_:
|
|
return False, elements
|
|
|
|
return True, None
|
|
|
|
|
|
def extract_meta(*args, **kwargs):
|
|
def _convert(val):
|
|
if isinstance(val, ColoProxy):
|
|
return val.meta_data
|
|
elif isinstance(val, (list, tuple)):
|
|
return type(val)([_convert(ele) for ele in val])
|
|
|
|
return val
|
|
|
|
new_args = [_convert(val) for val in args]
|
|
new_kwargs = {k: _convert(v) for k, v in kwargs.items()}
|
|
return new_args, new_kwargs
|
|
|
|
|
|
def compute_meta_data_for_functions_proxy(target, args, kwargs):
|
|
args_metas, kwargs_metas = extract_meta(*args, **kwargs)
|
|
|
|
# fetch patched function
|
|
if meta_patched_function.has(target):
|
|
meta_target = meta_patched_function.get(target)
|
|
elif meta_patched_function.has(target.__name__):
|
|
meta_target = meta_patched_function.get(target.__name__)
|
|
else:
|
|
meta_target = target
|
|
meta_out = meta_target(*args_metas, **kwargs_metas)
|
|
if isinstance(meta_out, torch.Tensor):
|
|
meta_out = meta_out.to(device="meta")
|
|
|
|
return meta_out
|