mirror of https://github.com/hpcaitech/ColossalAI
import functools
from typing import Callable, Dict

# Registry of custom sharded ops. Keys are the original torch functions
# (e.g. torch.nn.functional.linear); values are their sharded implementations.
_COLOSSAL_OPS: Dict[Callable, Callable] = {}


def _register_colo_op(op: Callable, func: Callable):
    global _COLOSSAL_OPS
    _COLOSSAL_OPS[op] = func

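# How the registry is consumed (a minimal sketch, assuming a
# ``ColoTensor.__torch_function__`` handler as described in the docstring
# below; this is illustrative, not the actual ColoTensor source). The handler
# checks ``_COLOSSAL_OPS`` before falling back to the stock torch behavior:
#
#     @classmethod
#     def __torch_function__(cls, func, types, args=(), kwargs=None):
#         kwargs = kwargs if kwargs is not None else {}
#         if func in _COLOSSAL_OPS:
#             # Route to the user-registered sharded implementation.
#             return _COLOSSAL_OPS[func](types, args, kwargs, None)
#         return super().__torch_function__(func, types, args, kwargs)
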
def colo_op_impl(func):
    """
    Provides a way for users to write their own custom operator. This
    can be used to override existing ColoTensor operators or write a new
    one not supported by ColoTensor. If the operator in question is covered
    by ``__torch_function__`` dispatch and has a ColoTensor as any of its
    parameters, the function provided will be invoked for that operator.

    Example:
        >>> @colo_op_impl(torch.nn.functional.linear)
        >>> def my_custom_linear(types, args, kwargs, process_group):
        >>>     ...
        >>>
        >>> input = torch.rand(10, 32)
        >>> weight = ColoTensor(torch.rand(32, 16))
        >>> bias = ColoTensor(torch.rand(16))
        >>> # This will call `my_custom_linear` instead of the default.
        >>> torch.nn.functional.linear(input, weight, bias)

    The types, args and kwargs parameters are the same parameters that are
    passed to the ``__torch_function__`` dispatch API
    (https://pytorch.org/docs/stable/notes/extending.html#extending-torch).

    Args:
        func (Callable): torch function for which we want to provide a sharded
            implementation (e.g. torch.nn.functional.linear).
    """

    def decorator_sharded_func(wrapped_func):
        # Register the sharded implementation under the original torch function.
        _register_colo_op(func, wrapped_func)

        @functools.wraps(wrapped_func)
        def wrapper(*args, **kwargs):
            return wrapped_func(*args, **kwargs)

        return wrapper

    return decorator_sharded_func
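
# Usage sketch (illustrative only): registering a sharded implementation for
# ``torch.nn.functional.linear``. The name ``my_sharded_linear`` and its body
# are assumptions for demonstration; a real implementation would shard the
# matmul across ``process_group`` rather than computing it densely.
#
#     import torch
#
#     @colo_op_impl(torch.nn.functional.linear)
#     def my_sharded_linear(types, args=(), kwargs=None, process_group=None):
#         input_, weight = args[0], args[1]
#         bias = args[2] if len(args) > 2 else None
#         # Dense fallback, standing in for the sharded computation.
#         return torch.nn.functional.linear(input_, weight, bias)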