fix colo parameter torch function (#1117)

ver217 2022-06-15 14:23:27 +08:00 committed by GitHub
parent e1620ddac2
commit f99f56dff4
1 changed file with 22 additions and 4 deletions

@@ -7,6 +7,23 @@ from colossalai.tensor.param_op_hook import ParamOpHookManager
 from typing import Optional
 
 
+def filter_args(func, *args):
+    return [arg for arg in args if func(arg)]
+
+
+def unpack_args(*args):
+    if len(args) == 1:
+        return args[0]
+    return args
+
+
+def replace_args(args, kwargs, new_args):
+    args = new_args[:len(args)]
+    for k, v in zip(kwargs.keys(), new_args[len(args):]):
+        kwargs[k] = v
+    return unpack_args(args), kwargs
+
+
 class ColoParameter(ColoTensor, torch.nn.Parameter):
     r"""A kind of ColoTensor to be considered as a module parameter.
@@ -50,12 +67,13 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
     def __torch_function__(cls, func, types, args=..., kwargs=None):
         if ParamOpHookManager.has_hook():
             if not func.__name__.startswith('__'):
-                params = list(filter(lambda arg: isinstance(arg, ColoParameter), args))
-                if kwargs is not None:
-                    params.extend(list(filter(lambda arg: isinstance(arg, ColoParameter), kwargs.values())))
+                if kwargs is None:
+                    kwargs = {}
+                params = filter_args(lambda arg: isinstance(arg, ColoParameter), *args, *kwargs.values())
                 if len(params) > 0:
                     with torch._C.DisableTorchFunction():
-                        args = ParamOpHookManager.pre_op(params, *args)
+                        new_args = ParamOpHookManager.pre_op(params, *args, *kwargs.values())
+                    args, kwargs = replace_args(args, kwargs, new_args)
                     ret = super().__torch_function__(func, types, args, kwargs)
                     with torch._C.DisableTorchFunction():
                         ret = ParamOpHookManager.post_op(params, ret)
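
This hunk is the actual fix: previously pre_op received only *args, so a ColoParameter passed by keyword was collected into params but its hook-transformed replacement never made it back into kwargs. The new code feeds both streams through pre_op and uses replace_args to redistribute the results. The following is a self-contained model of the fixed control flow in plain Python (no torch): Param, pre_op, dispatch, and linear are hypothetical stand-ins for ColoParameter, ParamOpHookManager.pre_op, the __torch_function__ above, and a torch op; the three helpers are copied from the first hunk so the snippet runs on its own.

def filter_args(func, *args):
    return [arg for arg in args if func(arg)]

def unpack_args(*args):
    if len(args) == 1:
        return args[0]
    return args

def replace_args(args, kwargs, new_args):
    args = new_args[:len(args)]
    for k, v in zip(kwargs.keys(), new_args[len(args):]):
        kwargs[k] = v
    return unpack_args(args), kwargs

class Param(str):
    """Hypothetical stand-in for ColoParameter."""

def pre_op(params, *args):
    # Stand-in for ParamOpHookManager.pre_op: rewrite every Param it sees.
    return [Param(f'hooked:{a}') if isinstance(a, Param) else a for a in args]

def dispatch(func, args, kwargs=None):
    # Mirrors the fixed __torch_function__: flatten args AND kwargs values,
    # run them through pre_op, then put the results back where they came from.
    if kwargs is None:
        kwargs = {}
    params = filter_args(lambda a: isinstance(a, Param), *args, *kwargs.values())
    if len(params) > 0:
        new_args = pre_op(params, *args, *kwargs.values())
        args, kwargs = replace_args(args, kwargs, new_args)
    return func(*args, **kwargs)

def linear(x, weight=None):
    return x, weight

# The parameter arrives via kwargs yet is still transformed before the op runs.
assert dispatch(linear, ('x',), {'weight': Param('w')}) == ('x', 'hooked:w')

Under the old flow, pre_op would have received only 'x', and kwargs['weight'] would have reached the op untransformed.
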