mirror of https://github.com/hpcaitech/ColossalAI
import functools
from typing import Callable, List

import torch

from colossalai.nn._ops._utils import convert_to_colo_tensor
from colossalai.tensor import ColoTensor
def register_colo_graph(input_pos: List[int], param_pos: List[int]) -> Callable:
    """register_colo_graph

    Register an Op (Layer) to the ColoGraph and record its args of type
    ColoTensor in the Graph.

    Args:
        input_pos (List[int]): positions in ``args`` of the Op's input (activation) tensors.
        param_pos (List[int]): positions in ``args`` of the Op's parameter tensors.

    Returns:
        Callable: a decorator that wraps the function implementing the Op.
    """

    def register_colo_graph_decorator(func):
        from colossalai.nn.graph import GraphGlobalEnv, GraphOpNode

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            param_list = []
            input_list = []
            # Split the positional tensor args into the Op's inputs and its
            # parameters according to input_pos and param_pos.
            for idx, arg in enumerate(args):
                if isinstance(arg, torch.Tensor) and idx in input_pos:
                    input_list.append(convert_to_colo_tensor(arg))
                if isinstance(arg, torch.Tensor) and idx in param_pos:
                    param_list.append(convert_to_colo_tensor(arg))

            # building the computing graph, inputs -> op
            if GraphGlobalEnv().graph_building:
                # the node type is currently hardcoded as 'linear'
                cur_op_node = GraphOpNode('linear', param_list)
                # TODO supports a list of ColoTensor as args
                if len(input_list) > 0:
                    cur_op_node.add_prev_tensor(input_list[0])

            outputs = func(*args, **kwargs)

            # building the computing graph, op -> output
            if GraphGlobalEnv().graph_building:
                # TODO supports a list of ColoTensor as args
                if isinstance(outputs[0], ColoTensor):
                    cur_op_node.add_post_tensor(outputs[0])
            return outputs

        return wrapper

    return register_colo_graph_decorator
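

# A minimal usage sketch, assuming a hypothetical Op `example_linear` whose
# activation sits at positional index 0 and whose weight at index 1; neither
# the name nor the argument layout comes from this module. With these
# positions, the wrapper records arg 0 as the graph input and arg 1 as a
# parameter whenever an enclosing ColoGraph is being built.
@register_colo_graph(input_pos=[0], param_pos=[1])
def example_linear(input_tensor, weight):
    # return a tuple so that the wrapper's `outputs[0]` check sees the
    # full result tensor rather than a slice of it
    return (torch.nn.functional.linear(input_tensor, weight),)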