diff --git a/colossalai/nn/_ops/linear.py b/colossalai/nn/_ops/linear.py
index 21c5ff280..eccb1b467 100644
--- a/colossalai/nn/_ops/linear.py
+++ b/colossalai/nn/_ops/linear.py
@@ -1,14 +1,14 @@
 import torch.nn.functional as F
 from typing import Optional
+from ._utils import GeneralTensor, convert_to_colo_tensor
 from colossalai.tensor.op_wrapper import colo_op_impl
 from colossalai.nn.layer.parallel_1d._utils import reduce_input, reduce_grad
 from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, distspec
-from colossalai.tensor.graph import GraphOpNode, GraphGlobalEnv
 from colossalai.context import ParallelMode
-from ._utils import GeneralTensor, convert_to_colo_tensor
+from colossalai.nn.graph import register_colo_graph, GraphOpNode, GraphGlobalEnv
 
 
-def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
+def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
     # Input:S[1] x Weight:S[0] = Output:P
     # All-Reduce(Output) + bias = res
     # Input:S[1]
@@ -28,7 +28,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     return output
 
 
-def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
+def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
     # Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
     # All-Gather(Output)
     # Input:B
@@ -48,23 +48,21 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     return output
 
 
-def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
+def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
     assert mode in ('row', 'col')
     funcs = {'row': colo_linear_1Drow, 'col': colo_linear_1Dcol}
     return funcs[mode](input_tensor, weight, bias)
 
 
-@colo_op_impl(F.linear)
-def colo_linear(input_tensor: GeneralTensor, weight: GeneralTensor, bias: Optional[GeneralTensor] = None):
+@register_colo_graph(input_pos=[1], param_pos=[2, 3])
+def colo_linear_imp(input_tensor: GeneralTensor,
+                    weight: GeneralTensor,
+                    bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
     """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
     This method computes a linear.
     """
     input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias)))
-    # building the computing graph, inputs -> op
-    if GraphGlobalEnv().graph_building:
-        cur_op_node = GraphOpNode('linear', [weight, bias])
-        cur_op_node.add_prev_tensor(input_tensor)
 
     # Add communication logic before and after linear call.
     ret_tensor = None
     if not weight.has_spec():    # No Model Parallel Applied
@@ -82,7 +80,11 @@ def colo_linear(input_tensor: GeneralTensor, weight: GeneralTensor, bias: Option
     else:
         raise NotImplementedError
 
-    # building the computing graph, op -> output
-    if GraphGlobalEnv().graph_building:
-        cur_op_node.add_post_tensor(ret_tensor)
     return ret_tensor
+
+
+@colo_op_impl(F.linear)
+def colo_linear(input_tensor: GeneralTensor,
+                weight: GeneralTensor,
+                bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
+    return colo_linear_imp(input_tensor, weight, bias)
diff --git a/colossalai/nn/graph/__init__.py b/colossalai/nn/graph/__init__.py
new file mode 100644
index 000000000..0cfecf8b4
--- /dev/null
+++ b/colossalai/nn/graph/__init__.py
@@ -0,0 +1,4 @@
+from .utils import register_colo_graph
+from .graph_node import GraphContext, GraphGlobalEnv, GraphOpNode
+
+__all__ = ['register_colo_graph', 'GraphContext', 'GraphGlobalEnv', 'GraphOpNode']
\ No newline at end of file
diff --git a/colossalai/tensor/graph/graph_node.py b/colossalai/nn/graph/graph_node.py
similarity index 97%
rename from colossalai/tensor/graph/graph_node.py
rename to colossalai/nn/graph/graph_node.py
index d637d2bed..32653ad98 100644
--- a/colossalai/tensor/graph/graph_node.py
+++ b/colossalai/nn/graph/graph_node.py
@@ -74,7 +74,6 @@ class GraphOpNode(GraphNode):
             assert isinstance(colo_tensor, ColoTensor)
             if colo_tensor._graph_node is None:
                 colo_tensor._graph_node = GraphNode()
-
             prev_ops = colo_tensor._graph_node.prev_nodes
             for op_node in prev_ops:
                 self.add_prev_node(op_node)
@@ -85,7 +84,7 @@ class GraphOpNode(GraphNode):
         Op <- Activation (colo_tensor)
         """
         if GraphGlobalEnv().graph_building:
-            assert isinstance(colo_tensor, ColoTensor)
+            assert isinstance(colo_tensor, ColoTensor), f'type {type(colo_tensor)}'
             if colo_tensor._graph_node is None:
                 colo_tensor._graph_node = GraphNode()
 
diff --git a/colossalai/nn/graph/utils.py b/colossalai/nn/graph/utils.py
new file mode 100644
index 000000000..9218bf994
--- /dev/null
+++ b/colossalai/nn/graph/utils.py
@@ -0,0 +1,50 @@
+import functools
+import torch
+from colossalai.tensor import ColoTensor
+from typing import Callable, List
+from colossalai.nn._ops._utils import convert_to_colo_tensor
+
+
+def register_colo_graph(input_pos: List[int], param_pos: List[int]) -> Callable:
+    """register_colo_graph
+    Register an Op (Layer) to ColoGraph.
+    Records the input args of ColoTensor type to the Graph.
+    Args:
+        func (Callable): a function that implements the Op.
+
+    Returns:
+        Callable: wrapper function.
+    """
+
+    def register_colo_graph_decorator(func):
+        from colossalai.nn.graph import GraphOpNode, GraphGlobalEnv
+
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            param_list = []
+            input_list = []
+            for idx, arg in enumerate(args):
+                if isinstance(arg, torch.Tensor) and idx in input_pos:
+                    input_list.append(convert_to_colo_tensor(arg))
+                if isinstance(arg, torch.Tensor) and idx in param_pos:
+                    param_list.append(convert_to_colo_tensor(arg))
+            print(f'Op {func}')
+            # building the computing graph, inputs -> op
+            if GraphGlobalEnv().graph_building:
+                cur_op_node = GraphOpNode('linear', param_list)
+                # TODO supports a list of ColoTensor as args
+                if len(input_list) > 0:
+                    cur_op_node.add_prev_tensor(input_list[0])
+
+            outputs = func(*args, **kwargs)
+
+            # building the computing graph, op -> output
+            if GraphGlobalEnv().graph_building:
+                # TODO supports a list of ColoTensor as args
+                if isinstance(outputs[0], ColoTensor):
+                    cur_op_node.add_post_tensor(outputs[0])
+            return outputs
+
+        return wrapper
+
+    return register_colo_graph_decorator
diff --git a/colossalai/tensor/distspec.py b/colossalai/tensor/distspec.py
index 4ded8fe45..99f041fd3 100644
--- a/colossalai/tensor/distspec.py
+++ b/colossalai/tensor/distspec.py
@@ -17,6 +17,13 @@ class _DistSpec:
                  dist_placement_pattern: DistPlacementPattern,
                  process_group: Optional[ProcessGroup] = None,
                  **meta_info):
+        """_DistSpec, Distributed Specification
+
+        Args:
+            dist_placement_pattern (DistPlacementPattern): the pattern describing how tensors are distributed among processes.
+                The dist_placement_pattern is picked from a limited set, now including two patterns: replicate and shard.
+            process_group (Optional[ProcessGroup], optional): the process group that contains the processes. Defaults to None.
+        """
         self.placement = dist_placement_pattern
         self.process_group = process_group
         for k, v in meta_info.items():
@@ -37,6 +44,7 @@ class _DistSpec:
             res += f'{attr}: {str(getattr(self, attr))}\n\t'
         return res
 
+
 def replicate(process_group: Optional[ProcessGroup] = None) -> _DistSpec:
     # process_group=None means global process group
     return _DistSpec(DistPlacementPattern.REPLICATE, process_group)
diff --git a/colossalai/tensor/graph/__init__.py b/colossalai/tensor/graph/__init__.py
deleted file mode 100644
index b0788062f..000000000
--- a/colossalai/tensor/graph/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .graph_node import GraphNode, GraphOpNode, GraphContext, GraphGlobalEnv
-
-__all__ = ['GraphNode', 'GraphOpNode', 'GraphContext', 'GraphGlobalEnv']
diff --git a/tests/test_tensor/test_graph.py b/tests/test_tensor/test_graph.py
deleted file mode 100644
index 1b5505c07..000000000
--- a/tests/test_tensor/test_graph.py
+++ /dev/null
@@ -1,84 +0,0 @@
-import pytest
-from torch import nn
-import torch
-from colossalai.tensor import ColoTensor
-from colossalai.tensor.graph import GraphContext
-import gc
-
-
-class SimpleNet(nn.Module):
-
-    def __init__(self) -> None:
-        super().__init__()
-        self.proj1 = nn.Linear(4, 8)
-        self.proj2 = nn.Linear(8, 4)
-        self.proj3 = nn.Linear(4, 4)
-        self.proj4 = nn.Linear(4, 4)
-
-    def forward(self, x):
-        x = self.proj1(x)
-        x = self.proj2(x)
-        x = self.proj3(x)
-        x = self.proj4(x)
-        return x
-
-
-def _visit_graph(start_node):
-    if start_node is None:
-        return
-
-    start_node.print()
-
-    post_node_list = start_node.post_nodes
-    for node in post_node_list:
-        _visit_graph(node)
-
-
-def _get_tensors():
-    for obj in gc.get_objects():
-        try:
-            if torch.is_tensor(obj):
-                yield obj
-        except Exception as e:
-            print('A trivial exception occured: {}'.format(e))
-
-
-def _count_tensors():
-    cnt = 0
-    for t in _get_tensors():
-        cnt += 1
-    return cnt
-
-
-def count_tensors(use_colossal):
-    model = SimpleNet()
-
-    model.eval()
-    with torch.no_grad():
-        if use_colossal:
-            colo_input = ColoTensor.from_torch_tensor(torch.randn(4))
-            graph_ctx = GraphContext()
-            with graph_ctx:
-                output = model(colo_input)
-                output = model(colo_input)
-            ret = _count_tensors()
-
-            _visit_graph(graph_ctx.graph_nodes[0])
-
-            del graph_ctx
-            return ret
-        else:
-            input_t = torch.randn(4)
-            output = model(input_t)
-            output = model(input_t)
-            return _count_tensors()
-
-
-@pytest.mark.skip
-# FIXME(ver217)
-def test_check_activation_tensors():
-    assert count_tensors(False) == count_tensors(True)
-
-
-if __name__ == "__main__":
-    count_tensors(True)
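
A minimal usage sketch of the relocated graph-building API, not part of the patch itself. It mirrors the removed tests/test_tensor/test_graph.py, with GraphContext now imported from colossalai.nn.graph as exported by the new __init__.py; the names TinyNet and build_linear_graph are illustrative only, and the sketch assumes GraphContext keeps the same context-manager semantics it had before the move.

# Illustrative only, not part of this diff. Mirrors the deleted
# tests/test_tensor/test_graph.py, but imports GraphContext from its new
# location in colossalai.nn.graph.
import torch
from torch import nn
from colossalai.tensor import ColoTensor
from colossalai.nn.graph import GraphContext


class TinyNet(nn.Module):
    # Two stacked linear layers; their F.linear calls are dispatched to
    # colo_linear, whose colo_linear_imp body is wrapped by @register_colo_graph.

    def __init__(self) -> None:
        super().__init__()
        self.proj1 = nn.Linear(4, 8)
        self.proj2 = nn.Linear(8, 4)

    def forward(self, x):
        return self.proj2(self.proj1(x))


def build_linear_graph():
    model = TinyNet()
    model.eval()
    with torch.no_grad():
        colo_input = ColoTensor.from_torch_tensor(torch.randn(4))
        graph_ctx = GraphContext()
        # Entering the context is expected to enable GraphGlobalEnv().graph_building,
        # so each wrapped linear call records a GraphOpNode into the graph.
        with graph_ctx:
            model(colo_input)
    return graph_ctx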