mirror of https://github.com/hpcaitech/ColossalAI
ai big-model data-parallelism deep-learning distributed-computing foundation-models heterogeneous-training hpc inference large-scale model-parallelism pipeline-parallelism
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
93 lines
3.3 KiB
93 lines
3.3 KiB
from typing import Optional |
|
|
|
import torch |
|
|
|
from colossalai.tensor.colo_tensor import ColoTensor |
|
from colossalai.tensor.param_op_hook import ColoParamOpHookManager |
|
|
|
from .colo_tensor import _convert_output |
|
|
|
# Double-underscore tensor functions that are still eligible for the
# parameter op hooks; ``is_no_hook_op`` treats every other function whose
# name starts with "__" as a no-hook op.
WHITE_LIST_FUNCS = {torch.Tensor.__getitem__}
|
|
|
|
|
def is_no_hook_op(func) -> bool:
    """Tell whether ``func`` is excluded from the parameter op hooks.

    Args:
        func: A callable exposing a ``__name__`` attribute.

    Returns:
        ``True`` when the name of ``func`` starts with a double underscore
        and ``func`` is not a member of ``WHITE_LIST_FUNCS``; ``False``
        otherwise.
    """
    name_is_dunder = func.__name__.startswith("__")
    if not name_is_dunder:
        # Ordinary names are never exempt; the whitelist is not consulted.
        return False
    return func not in WHITE_LIST_FUNCS
|
|
|
|
|
def filter_colo_parameters(*args, **kwargs):
    """Gather every ``ColoParameter`` present in the given arguments.

    Positional arguments are scanned first, followed by the keyword
    argument values. Lists and tuples are searched recursively; any value
    that is not a list, tuple, dict or ``ColoParameter`` is ignored.

    Returns:
        list: The ``ColoParameter`` objects found, in the order they were
        encountered.

    Raises:
        RuntimeError: If a dict is encountered at any nesting level.
    """
    found = []

    def _collect(item) -> None:
        # Containers are walked depth-first so nested parameters are kept
        # in left-to-right order.
        if isinstance(item, (list, tuple)):
            for child in item:
                _collect(child)
            return
        if isinstance(item, dict):
            raise RuntimeError("Found Dict: ColoParameter can't deal with complicated arguments.")
        if isinstance(item, ColoParameter):
            found.append(item)

    for candidate in (*args, *kwargs.values()):
        _collect(candidate)

    return found
|
|
|
|
|
def replace_args(args, kwargs, new_args):
    """Spread a flat ``new_args`` sequence back over ``args`` and ``kwargs``.

    Args:
        args: The original positional arguments; only their count is used.
        kwargs: The original keyword arguments. Updated in place: its
            existing keys, in iteration order, receive the entries of
            ``new_args`` that follow the positional ones.
        new_args: Flat sequence holding the replacement positional values
            followed by the replacement keyword values.

    Returns:
        tuple: ``(positional, kwargs)`` where ``positional`` is a tuple of
        the leading entries of ``new_args`` and ``kwargs`` is the same dict
        object that was passed in.
    """
    positional = new_args[: len(args)]
    # The remainder starts after whatever was actually taken as positional.
    remainder = new_args[len(positional) :]
    for key, value in zip(kwargs, remainder):
        kwargs[key] = value
    return tuple(positional), kwargs
|
|
|
|
|
class ColoParameter(ColoTensor, torch.nn.Parameter):
    r"""A kind of ColoTensor to be considered as a module parameter.

    Torch functions that involve a ``ColoParameter`` are routed through
    ``ColoParamOpHookManager`` (``pre_op`` before the call, ``post_op``
    after it) whenever a hook is registered and the function is not a
    no-hook op according to ``is_no_hook_op``.
    """

    def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad: bool = True) -> "ColoParameter":
        """Create a parameter that wraps ``data``.

        Args:
            data: Tensor to wrap. An empty tensor is used when ``None``.
            requires_grad: Forwarded to ``torch.Tensor._make_subclass``.
        """
        if data is None:
            data = torch.empty(0)
        return torch.Tensor._make_subclass(cls, data, requires_grad)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Dispatch ``func``, wrapping it with the parameter op hooks.

        Args:
            func: The torch function being invoked.
            types: Types passed through to the parent ``__torch_function__``.
            args: Positional arguments of the call. Defaults to an empty
                tuple so that the ``*args`` unpacking below is always valid.
            kwargs: Keyword arguments of the call; ``None`` means no
                keyword arguments.

        Returns:
            The result of the parent ``__torch_function__``. When hooks were
            applied, the result is additionally passed through
            ``ColoParamOpHookManager.post_op`` and ``_convert_output``.
        """
        if kwargs is None:
            kwargs = {}
        if ColoParamOpHookManager.has_hook() and not is_no_hook_op(func):
            params = filter_colo_parameters(*args, **kwargs)
            if params:
                # The hooks run with torch-function dispatch disabled so the
                # tensor ops they perform do not re-enter this method.
                with torch._C.DisableTorchFunction():
                    new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())
                # pre_op returns one flat sequence; split it back into
                # positional and keyword arguments.
                args, kwargs = replace_args(args, kwargs, new_args)
                ret = super().__torch_function__(func, types, args, kwargs)
                with torch._C.DisableTorchFunction():
                    ret = ColoParamOpHookManager.post_op(params, ret)
                return _convert_output(ret, func)
        return super().__torch_function__(func, types, args, kwargs)

    def __deepcopy__(self, memo):
        """Return a copy backed by a clone of ``self.data``.

        Args:
            memo: The ``copy.deepcopy`` memo dict, keyed by ``id``. An
                existing entry for this object is returned as-is; otherwise
                the new copy is recorded in it.
        """
        key = id(self)
        if key in memo:
            return memo[key]
        # Clone with torch-function dispatch disabled so the clone itself
        # does not go through ``__torch_function__``.
        with torch._C.DisableTorchFunction():
            data = self.data.clone()
        tensor = ColoParameter(data, self.requires_grad)
        memo[key] = tensor
        return tensor

    def __reduce_ex__(self, proto):
        """Reject pickling.

        Raises:
            NotImplementedError: Always.
        """
        # Sketch of a possible implementation, adapted from
        # torch._utils._rebuild_parameter:
        #
        # def _rebuild_colo_parameter(data, requires_grad, backward_hooks):
        #     colo_param = ColoParameter(data, requires_grad)
        #     colo_param._backward_hooks = backward_hooks
        #     return colo_param
        #
        # return (
        #     _rebuild_colo_parameter,
        #     (self.data, self.requires_grad, OrderedDict())
        # )

        # TODO(jzy) we don't support object reflection now.
        # distspec cannot be pickled or rebuilt because it's tightly connected to runtime attribute `process_group`.
        raise NotImplementedError(
            "ColoParameter cannot be pickled: its distspec is tied to the runtime `process_group`."
        )
|
|
|