mirror of https://github.com/hpcaitech/ColossalAI
fix colo parameter torch function (#1117)
parent
e1620ddac2
commit
f99f56dff4
|
@ -7,6 +7,23 @@ from colossalai.tensor.param_op_hook import ParamOpHookManager
|
|||
from typing import Optional
|
||||
|
||||
|
||||
def filter_args(func, *args):
|
||||
return [arg for arg in args if func(arg)]
|
||||
|
||||
|
||||
def unpack_args(*args):
|
||||
if len(args) == 1:
|
||||
return args[0]
|
||||
return args
|
||||
|
||||
|
||||
def replace_args(args, kwargs, new_args):
|
||||
args = new_args[:len(args)]
|
||||
for k, v in zip(kwargs.keys(), new_args[len(args):]):
|
||||
kwargs[k] = v
|
||||
return unpack_args(args), kwargs
|
||||
|
||||
|
||||
class ColoParameter(ColoTensor, torch.nn.Parameter):
|
||||
r"""A kind of ColoTensor to be considered as a module parameter.
|
||||
|
||||
|
@ -50,12 +67,13 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
|
|||
def __torch_function__(cls, func, types, args=..., kwargs=None):
|
||||
if ParamOpHookManager.has_hook():
|
||||
if not func.__name__.startswith('__'):
|
||||
params = list(filter(lambda arg: isinstance(arg, ColoParameter), args))
|
||||
if kwargs is not None:
|
||||
params.extend(list(filter(lambda arg: isinstance(arg, ColoParameter), kwargs.values())))
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
params = filter_args(lambda arg: isinstance(arg, ColoParameter), *args, *kwargs.values())
|
||||
if len(params) > 0:
|
||||
with torch._C.DisableTorchFunction():
|
||||
args = ParamOpHookManager.pre_op(params, *args)
|
||||
new_args = ParamOpHookManager.pre_op(params, *args, *kwargs.values())
|
||||
args, kwargs = replace_args(args, kwargs, new_args)
|
||||
ret = super().__torch_function__(func, types, args, kwargs)
|
||||
with torch._C.DisableTorchFunction():
|
||||
ret = ParamOpHookManager.post_op(params, ret)
|
||||
|
|
Loading…
Reference in New Issue