You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/colossalai/tensor/op_wrapper.py

54 lines
1.7 KiB

from typing import (
Callable,
Dict,
)
import functools
# Custom sharded ops.
# Module-level registry filled by `_register_colo_op`: each key is the op being
# overridden (`colo_op_impl` passes the torch function object itself, not its
# name) and each value is the user-supplied implementation for that op.
_COLOSSAL_OPS: Dict[Callable, Callable] = {}
def _register_colo_op(op, func):
    """Store ``func`` in the op registry as the implementation for ``op``.

    Registering the same ``op`` again replaces the previously stored function.

    Args:
        op: Key identifying the operator being overridden.
        func: Callable to associate with ``op``.
    """
    # Item assignment mutates the existing module-level dict in place; the
    # name is never rebound here, so no ``global`` declaration is required.
    _COLOSSAL_OPS[op] = func
def colo_op_impl(func):
    """Build a decorator that registers a custom implementation for ``func``.

    Lets users write their own operator, either overriding an existing
    ColoTensor operator or adding one ColoTensor does not support. Applying
    ``@colo_op_impl(op)`` to a function records that function under ``op``
    through ``_register_colo_op`` at decoration time. When the operator is
    covered by ``__torch_function__`` dispatch and receives a ColoTensor as
    any of its parameters, the registered function is the one meant to be
    invoked (the dispatch code itself is not part of this section).

    Example:
        >>> @colo_op_impl(torch.nn.functional.linear)
        >>> def my_custom_linear(types, args, kwargs, process_group):
        >>>     ....
        >>>
        >>> input = torch.rand(10, 32)
        >>> weight = ColoTensor(torch.rand(32, 16))
        >>> bias = ColoTensor(torch.rand(16))
        >>> # This will call `my_custom_linear` instead of the default.
        >>> torch.nn.functional.linear(input, weight, bias)

    The ``types``, ``args`` and ``kwargs`` parameters are the same ones that
    are passed to the ``__torch_function__`` dispatch API
    (https://pytorch.org/docs/stable/notes/extending.html#extending-torch).

    Args:
        func (Callable): Torch function for which a sharded implementation
            is being provided (ex: ``torch.nn.functional.linear``).

    Returns:
        A decorator. It registers the decorated function for ``func`` and
        returns a pass-through wrapper that carries the decorated function's
        metadata and forwards every call to it unchanged.
    """

    def register_and_forward(impl):
        # Registration runs once, when the decorator is applied. The raw
        # implementation is stored, not the forwarding wrapper built below.
        _register_colo_op(func, impl)

        def forward(*call_args, **call_kwargs):
            return impl(*call_args, **call_kwargs)

        # Same effect as decorating ``forward`` with ``functools.wraps(impl)``:
        # copies name/doc/module metadata and sets ``__wrapped__``.
        return functools.update_wrapper(forward, impl)

    return register_and_forward