ColossalAI/colossalai/tensor/colo_parameter.py

94 lines
3.3 KiB
Python
Raw Normal View History

from typing import Optional
import torch
from colossalai.tensor.colo_tensor import ColoTensor
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
[gemini] improve compatibility and add static placement policy (#4479) * [gemini] remove distributed-related part from colotensor (#4379) * [gemini] remove process group dependency * [gemini] remove tp part from colo tensor * [gemini] patch inplace op * [gemini] fix param op hook and update tests * [test] remove useless tests * [test] remove useless tests * [misc] fix requirements * [test] fix model zoo * [test] fix model zoo * [test] fix model zoo * [test] fix model zoo * [test] fix model zoo * [misc] update requirements * [gemini] refactor gemini optimizer and gemini ddp (#4398) * [gemini] update optimizer interface * [gemini] renaming gemini optimizer * [gemini] refactor gemini ddp class * [example] update gemini related example * [example] update gemini related example * [plugin] fix gemini plugin args * [test] update gemini ckpt tests * [gemini] fix checkpoint io * [example] fix opt example requirements * [example] fix opt example * [example] fix opt example * [example] fix opt example * [gemini] add static placement policy (#4443) * [gemini] add static placement policy * [gemini] fix param offload * [test] update gemini tests * [plugin] update gemini plugin * [plugin] update gemini plugin docstr * [misc] fix flash attn requirement * [test] fix gemini checkpoint io test * [example] update resnet example result (#4457) * [example] update bert example result (#4458) * [doc] update gemini doc (#4468) * [example] update gemini related examples (#4473) * [example] update gpt example * [example] update dreambooth example * [example] update vit * [example] update opt * [example] update palm * [example] update vit and opt benchmark * [hotfix] fix bert in model zoo (#4480) * [hotfix] fix bert in model zoo * [test] remove chatglm gemini test * [test] remove sam gemini test * [test] remove vit gemini test * [hotfix] fix opt tutorial example (#4497) * [hotfix] fix opt tutorial example * [hotfix] fix opt tutorial example
2023-08-24 01:29:25 +00:00
from .colo_tensor import _convert_output
WHITE_LIST_FUNCS = {torch.Tensor.__getitem__}
def is_no_hook_op(func) -> bool:
return func.__name__.startswith("__") and func not in WHITE_LIST_FUNCS
2022-05-06 04:57:14 +00:00
def filter_colo_parameters(*args, **kwargs):
param_list = []
def get_colo_parameters(element) -> None:
if isinstance(element, list) or isinstance(element, tuple):
for e in element:
get_colo_parameters(e)
elif isinstance(element, dict):
raise RuntimeError("Found Dict: ColoParameter can't deal with complicated arguments.")
elif isinstance(element, ColoParameter):
param_list.append(element)
return
for a in args:
get_colo_parameters(a)
for v in kwargs.values():
get_colo_parameters(v)
return param_list
def replace_args(args, kwargs, new_args):
args = new_args[: len(args)]
for k, v in zip(kwargs.keys(), new_args[len(args) :]):
kwargs[k] = v
return tuple(args), kwargs
class ColoParameter(ColoTensor, torch.nn.Parameter):
r"""A kind of ColoTensor to be considered as a module parameter."""
2022-05-06 04:57:14 +00:00
def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad: bool = True) -> "ColoParameter":
if data is None:
data = torch.empty(0)
return torch.Tensor._make_subclass(cls, data, requires_grad)
2022-05-06 04:57:14 +00:00
@classmethod
def __torch_function__(cls, func, types, args=..., kwargs=None):
[gemini] improve compatibility and add static placement policy (#4479) * [gemini] remove distributed-related part from colotensor (#4379) * [gemini] remove process group dependency * [gemini] remove tp part from colo tensor * [gemini] patch inplace op * [gemini] fix param op hook and update tests * [test] remove useless tests * [test] remove useless tests * [misc] fix requirements * [test] fix model zoo * [test] fix model zoo * [test] fix model zoo * [test] fix model zoo * [test] fix model zoo * [misc] update requirements * [gemini] refactor gemini optimizer and gemini ddp (#4398) * [gemini] update optimizer interface * [gemini] renaming gemini optimizer * [gemini] refactor gemini ddp class * [example] update gemini related example * [example] update gemini related example * [plugin] fix gemini plugin args * [test] update gemini ckpt tests * [gemini] fix checkpoint io * [example] fix opt example requirements * [example] fix opt example * [example] fix opt example * [example] fix opt example * [gemini] add static placement policy (#4443) * [gemini] add static placement policy * [gemini] fix param offload * [test] update gemini tests * [plugin] update gemini plugin * [plugin] update gemini plugin docstr * [misc] fix flash attn requirement * [test] fix gemini checkpoint io test * [example] update resnet example result (#4457) * [example] update bert example result (#4458) * [doc] update gemini doc (#4468) * [example] update gemini related examples (#4473) * [example] update gpt example * [example] update dreambooth example * [example] update vit * [example] update opt * [example] update palm * [example] update vit and opt benchmark * [hotfix] fix bert in model zoo (#4480) * [hotfix] fix bert in model zoo * [test] remove chatglm gemini test * [test] remove sam gemini test * [test] remove vit gemini test * [hotfix] fix opt tutorial example (#4497) * [hotfix] fix opt tutorial example * [hotfix] fix opt tutorial example
2023-08-24 01:29:25 +00:00
if kwargs is None:
kwargs = {}
if ColoParamOpHookManager.has_hook() and not is_no_hook_op(func):
params = filter_colo_parameters(*args, **kwargs)
if len(params) > 0:
with torch._C.DisableTorchFunction():
new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())
args, kwargs = replace_args(args, kwargs, new_args)
ret = super().__torch_function__(func, types, args, kwargs)
with torch._C.DisableTorchFunction():
ret = ColoParamOpHookManager.post_op(params, ret)
return _convert_output(ret, func)
return super().__torch_function__(func, types, args, kwargs)
def __deepcopy__(self, memo):
if id(self) in memo:
return memo[id(self)]
else:
with torch._C.DisableTorchFunction():
data = self.data.clone()
[gemini] improve compatibility and add static placement policy (#4479) * [gemini] remove distributed-related part from colotensor (#4379) * [gemini] remove process group dependency * [gemini] remove tp part from colo tensor * [gemini] patch inplace op * [gemini] fix param op hook and update tests * [test] remove useless tests * [test] remove useless tests * [misc] fix requirements * [test] fix model zoo * [test] fix model zoo * [test] fix model zoo * [test] fix model zoo * [test] fix model zoo * [misc] update requirements * [gemini] refactor gemini optimizer and gemini ddp (#4398) * [gemini] update optimizer interface * [gemini] renaming gemini optimizer * [gemini] refactor gemini ddp class * [example] update gemini related example * [example] update gemini related example * [plugin] fix gemini plugin args * [test] update gemini ckpt tests * [gemini] fix checkpoint io * [example] fix opt example requirements * [example] fix opt example * [example] fix opt example * [example] fix opt example * [gemini] add static placement policy (#4443) * [gemini] add static placement policy * [gemini] fix param offload * [test] update gemini tests * [plugin] update gemini plugin * [plugin] update gemini plugin docstr * [misc] fix flash attn requirement * [test] fix gemini checkpoint io test * [example] update resnet example result (#4457) * [example] update bert example result (#4458) * [doc] update gemini doc (#4468) * [example] update gemini related examples (#4473) * [example] update gpt example * [example] update dreambooth example * [example] update vit * [example] update opt * [example] update palm * [example] update vit and opt benchmark * [hotfix] fix bert in model zoo (#4480) * [hotfix] fix bert in model zoo * [test] remove chatglm gemini test * [test] remove sam gemini test * [test] remove vit gemini test * [hotfix] fix opt tutorial example (#4497) * [hotfix] fix opt tutorial example * [hotfix] fix opt tutorial example
2023-08-24 01:29:25 +00:00
tensor = ColoParameter(data, self.requires_grad)
memo[id(self)] = tensor
return tensor
def __reduce_ex__(self, proto):
# Adapted from torch._utils._rebuild_parameter
# def _rebuild_colo_parameter(data, requires_grad, backward_hooks):
# colo_param = ColoParameter(data, requires_grad)
# colo_param._backward_hooks = backward_hooks
# return colo_param
# return (
# _rebuild_colo_parameter,
# (self.data, self.requires_grad, OrderedDict())
# )
# TODO(jzy) we don't support object reflection now.
# distspec cannot be pickled or rebuilt because it's tightly connected to runtime attribute `process_group`.
raise NotImplementedError