import functools
from typing import Callable, Dict

# Registry of custom sharded ops: maps a torch function to the user-provided
# implementation that should replace it under ``__torch_function__`` dispatch.
_COLOSSAL_OPS: Dict[Callable, Callable] = {}


def _register_colo_op(op: Callable, func: Callable) -> None:
    """Record ``func`` as the custom sharded implementation of ``op``."""
    _COLOSSAL_OPS[op] = func


def colo_op_impl(func: Callable):
    """
    Provides a way for users to write their own custom operator. This
    can be used to override existing ColoTensor operators or write a new
    one not supported by ColoTensor. If the operator in question is covered
    by ``__torch_function__`` dispatch and has a ColoTensor as any of its
    parameters, the function provided will be invoked for that operator.

    Example:
        >>> @colo_op_impl(torch.nn.functional.linear)
        ... def my_custom_linear(types, args, kwargs, process_group):
        ...     ...
        >>>
        >>> input = torch.rand(10, 32)
        >>> weight = ColoTensor(torch.rand(32, 16))
        >>> bias = ColoTensor(torch.rand(16))
        >>> # This will call `my_custom_linear` instead of the default.
        >>> torch.nn.functional.linear(input, weight, bias)

    The ``types``, ``args`` and ``kwargs`` parameters are the same parameters
    that are passed to the ``__torch_function__`` dispatch API
    (https://pytorch.org/docs/stable/notes/extending.html#extending-torch).

    Args:
        func (Callable): Torch function for which we want to provide a sharded
            implementation (e.g. ``torch.nn.functional.linear``).
    """

    def decorator_sharded_func(wrapped_func: Callable) -> Callable:
        # Register the implementation once, at decoration time.
        _register_colo_op(func, wrapped_func)

        # Return a thin pass-through wrapper so the decorated function keeps
        # its original name, docstring and signature metadata.
        @functools.wraps(wrapped_func)
        def wrapper(*args, **kwargs):
            return wrapped_func(*args, **kwargs)

        return wrapper

    return decorator_sharded_func
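

# ---------------------------------------------------------------------------
# Illustrative sketch (an assumption, not ColoTensor's actual dispatch code):
# it shows how a ``torch.Tensor`` subclass could consult ``_COLOSSAL_OPS``
# from its ``__torch_function__`` hook, and how ``colo_op_impl`` plugs into
# that path. ``_DemoColoTensor``, ``my_linear`` and the ``None`` passed for
# ``process_group`` are hypothetical names/values used only for this demo.
if __name__ == "__main__":
    import torch

    class _DemoColoTensor(torch.Tensor):

        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            # Prefer a registered custom implementation when one exists.
            if func in _COLOSSAL_OPS:
                return _COLOSSAL_OPS[func](types, args, kwargs, None)
            # Otherwise fall back to the default torch behavior.
            return super().__torch_function__(func, types, args, kwargs)

    @colo_op_impl(torch.nn.functional.linear)
    def my_linear(types, args, kwargs, process_group):
        inp, weight, bias = args
        # A plain (unsharded) linear, just to make the dispatch visible.
        return torch.matmul(inp, weight.t()) + bias

    x = torch.rand(10, 32)
    w = torch.rand(16, 32).as_subclass(_DemoColoTensor)
    b = torch.rand(16).as_subclass(_DemoColoTensor)
    # Dispatches to ``my_linear`` because ``w`` and ``b`` are _DemoColoTensor.
    out = torch.nn.functional.linear(x, w, b)
    print(out.shape)  # torch.Size([10, 16])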