from typing import List, Union, Any
from ..proxy import ColoProxy, ColoAttribute
import torch
from .meta_patch import meta_patched_function, meta_patched_module

# Public helpers exported by ``from <this module> import *``.
__all__ = [
    "is_element_in_list",
    "extract_meta",
]


def is_element_in_list(elements: Union[List[Any], Any], list_: List[Any]):
    """Check that every item of ``elements`` appears in ``list_``.

    ``elements`` may be a list, tuple or set, in which case each of its items
    is tested for membership. Any other value (strings included) is tested as
    a single item.

    Returns:
        ``(True, None)`` when every item is found in ``list_``; otherwise
        ``(False, item)`` for the first missing item met during iteration.
    """
    is_collection = isinstance(elements, (tuple, list, set))
    # Wrap a lone value so both cases share the same membership loop.
    candidates = elements if is_collection else (elements,)
    for candidate in candidates:
        if candidate not in list_:
            return False, candidate
    return True, None


def extract_meta(*args, **kwargs):
    """Replace every ``ColoProxy`` in the call arguments with its ``meta_data``.

    Proxies nested inside lists and tuples are converted recursively, and the
    original container type is preserved (named tuples included). Any other
    value is passed through unchanged.

    Args:
        *args: positional arguments that may contain ``ColoProxy`` objects.
        **kwargs: keyword arguments that may contain ``ColoProxy`` objects.

    Returns:
        A ``(new_args, new_kwargs)`` pair: ``new_args`` is a list and
        ``new_kwargs`` a dict, both holding the converted values.
    """

    def _convert(val):
        if isinstance(val, ColoProxy):
            return val.meta_data
        if isinstance(val, (list, tuple)):
            converted = [_convert(ele) for ele in val]
            if isinstance(val, tuple) and hasattr(val, '_fields'):
                # Named tuples take their fields positionally; passing one
                # list (as for plain tuples/lists) raises TypeError or, with
                # field defaults, silently builds the wrong object.
                return type(val)(*converted)
            return type(val)(converted)
        return val

    new_args = [_convert(val) for val in args]
    new_kwargs = {k: _convert(v) for k, v in kwargs.items()}
    return new_args, new_kwargs


def compute_meta_data_for_functions_proxy(target, args, kwargs):
    """Run ``target`` (or its registered meta patch) on meta versions of the inputs.

    Every ``ColoProxy`` in ``args``/``kwargs`` is first swapped for its
    ``meta_data`` via ``extract_meta``. The callable that is executed is chosen
    in this order:

    1. a patch registered in ``meta_patched_function`` for ``target`` itself;
    2. a patch registered under ``target.__name__``, when ``target`` has a
       ``__name__``;
    3. ``target`` unchanged.

    Args:
        target: the function being traced.
        args: positional arguments, possibly containing ``ColoProxy`` objects.
        kwargs: mapping of keyword arguments, possibly containing
            ``ColoProxy`` objects.

    Returns:
        Whatever the selected callable returns. A ``torch.Tensor`` result is
        moved to the ``meta`` device; other results (e.g. tuples) are returned
        as produced.
    """
    args_metas, kwargs_metas = extract_meta(*args, **kwargs)

    # fetch patched function; callables such as functools.partial objects or
    # callable instances have no ``__name__``, so only use the name-based
    # lookup when the attribute actually exists.
    target_name = getattr(target, '__name__', None)
    if meta_patched_function.has(target):
        meta_target = meta_patched_function.get(target)
    elif target_name is not None and meta_patched_function.has(target_name):
        meta_target = meta_patched_function.get(target_name)
    else:
        meta_target = target

    meta_out = meta_target(*args_metas, **kwargs_metas)
    if isinstance(meta_out, torch.Tensor):
        # NOTE(review): presumably guards against patches/ops that create a
        # real (non-meta) tensor, e.g. factory functions with no meta inputs.
        meta_out = meta_out.to(device="meta")

    return meta_out