Mirror of https://github.com/hpcaitech/ColossalAI
Topics: ai, big-model, data-parallelism, deep-learning, distributed-computing, foundation-models, heterogeneous-training, hpc, inference, large-scale, model-parallelism, pipeline-parallelism
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
51 lines
1.4 KiB
51 lines
1.4 KiB
from typing import Any, List, Union |
|
|
|
import torch |
|
|
|
from ..proxy import ColoProxy |
|
from .meta_patch import meta_patched_function |
|
|
|
# Names exported by a wildcard import of this module. Note that
# compute_meta_data_for_functions_proxy is defined in this module but is not
# listed here, so it must be imported explicitly by name.
__all__ = ["is_element_in_list", "extract_meta"]
|
|
|
|
|
def is_element_in_list(elements: Union[List[Any], Any], list_: List[Any]):
    """Report whether every item of ``elements`` is contained in ``list_``.

    A tuple, list or set passed as ``elements`` is checked member by member in
    its iteration order; any other value (including a string) is treated as a
    single item to look up.

    Args:
        elements: One item, or a tuple/list/set of items, to look up.
        list_: Container searched with the ``in`` operator.

    Returns:
        ``(True, None)`` when every item is present, otherwise
        ``(False, item)`` where ``item`` is the first one found missing.
    """
    # Wrap a lone value in a 1-tuple so both cases share a single loop.
    candidates = elements if isinstance(elements, (tuple, list, set)) else (elements,)
    for candidate in candidates:
        if candidate not in list_:
            return False, candidate
    return True, None
|
|
|
|
|
def extract_meta(*args, **kwargs):
    """Replace every ``ColoProxy`` in the call arguments with its ``meta_data``.

    Proxies nested inside lists and tuples (including namedtuples and other
    tuple subclasses) are converted recursively, preserving the container
    type. Every other value is passed through unchanged.

    Args:
        *args: Positional arguments of a traced call.
        **kwargs: Keyword arguments of a traced call.

    Returns:
        A ``(new_args, new_kwargs)`` pair: ``new_args`` is a list and
        ``new_kwargs`` a dict, mirroring the inputs with proxies replaced.
    """

    def _convert(val):
        if isinstance(val, ColoProxy):
            return val.meta_data
        if isinstance(val, (list, tuple)):
            converted = [_convert(ele) for ele in val]
            # A namedtuple's constructor takes one positional argument per
            # field, so it cannot be rebuilt from a single iterable the way
            # list/tuple can; unpack the converted items instead.
            if isinstance(val, tuple) and hasattr(type(val), "_fields"):
                return type(val)(*converted)
            return type(val)(converted)
        return val

    new_args = [_convert(val) for val in args]
    new_kwargs = {k: _convert(v) for k, v in kwargs.items()}
    return new_args, new_kwargs
|
|
|
|
|
def compute_meta_data_for_functions_proxy(target, args, kwargs):
    """Call ``target`` on meta data to infer the output of a traced function call.

    Any ``ColoProxy`` in ``args``/``kwargs`` is first replaced by its
    ``meta_data`` via ``extract_meta``. If a meta-patched replacement is
    registered for ``target`` in ``meta_patched_function`` (looked up by the
    callable itself first, then by its ``__name__``), that replacement is
    called instead of ``target``.

    Args:
        target: The callable recorded for the traced call.
        args: Positional arguments of the call; may contain proxies.
        kwargs: Keyword arguments of the call; may contain proxies.

    Returns:
        The result of the (possibly patched) call. A ``torch.Tensor`` result is
        moved to the ``"meta"`` device; any other result is returned as-is.
    """
    args_metas, kwargs_metas = extract_meta(*args, **kwargs)

    # Fetch the patched function: prefer a patch registered for the callable
    # object itself, then one registered under its name. Some callables
    # (e.g. functools.partial objects or callable instances) have no
    # __name__, so skip the name lookup for them instead of raising
    # AttributeError.
    target_name = getattr(target, "__name__", None)
    if meta_patched_function.has(target):
        meta_target = meta_patched_function.get(target)
    elif target_name is not None and meta_patched_function.has(target_name):
        meta_target = meta_patched_function.get(target_name)
    else:
        meta_target = target

    meta_out = meta_target(*args_metas, **kwargs_metas)
    if isinstance(meta_out, torch.Tensor):
        meta_out = meta_out.to(device="meta")

    return meta_out
|
|
|