diff --git a/colossalai/tensor/_ops/linear.py b/colossalai/tensor/_ops/linear.py
index 5e3f4934b..9b4225ecc 100644
--- a/colossalai/tensor/_ops/linear.py
+++ b/colossalai/tensor/_ops/linear.py
@@ -1,15 +1,14 @@
 import torch
 from colossalai.tensor.op_wrapper import colo_op_impl
-from colossalai.context import ParallelMode
-from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward, reduce_input, \
-    gather_forward_split_backward, reduce_grad
+from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward, reduce_input, reduce_grad
 from colossalai.nn.layer.utils import divide
 from colossalai.core import global_context as gpc
 from packaging import version
 from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, ShardPattern
+from colossalai.tensor.graph import GraphOpNode, GraphGlobalEnv


-def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias:ColoTensor) -> ColoTensor:
+def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: ColoTensor) -> ColoTensor:
     parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow_Linear)
     # Input:S[1] x Weight:S[0] = Output:P
     # All-Reduce(Output) + bias = res
@@ -99,20 +98,32 @@ def colo_linear(types, args, kwargs, pg):
     if bias is not None and not isinstance(bias, ColoTensor):
         bias = ColoTensor.init_from_torch_tensor(bias)

+    # building the computing graph, inputs -> op
+    if GraphGlobalEnv().graph_building:
+        cur_op_node = GraphOpNode('linear', [weight, bias])
+        cur_op_node.add_prev_tensor(input_tensor)
+
     # Add communication logic before and after linear call.
+    ret_tensor = None
     if not weight.has_spec():    # No Model Parallel Applied
         assert not bias.has_spec(), 'Invalid bias spec for native Linear op'
         input_tensor = input_tensor.torch_tensor()
         weight = weight.torch_tensor()
         bias = bias.torch_tensor()
-        return ColoTensor.init_from_torch_tensor(torch.nn.functional.linear(input_tensor, weight, bias))
+        ret_tensor = ColoTensor.init_from_torch_tensor(torch.nn.functional.linear(input_tensor, weight, bias))
     elif weight.shard_spec.num_action == 1:    # Single Model Parallel Applied
         compute_patterns = weight.shard_spec.compute_patterns
         if ComputePattern.TP1DRow_Linear in compute_patterns:
-            return colo_linear_1Drow(input_tensor, weight, bias)
+            ret_tensor = colo_linear_1Drow(input_tensor, weight, bias)
         elif ComputePattern.TP1DCol_Linear in compute_patterns:
-            return colo_linear_1Dcol(input_tensor, weight, bias)
+            ret_tensor = colo_linear_1Dcol(input_tensor, weight, bias)
         else:
             raise NotImplementedError
     else:
         raise NotImplementedError
+
+    # building the computing graph, op -> output
+    if GraphGlobalEnv().graph_building:
+        cur_op_node.add_post_tensor(ret_tensor)
+
+    return ret_tensor
diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index 5a7c06d64..7dc1e78f7 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -38,6 +38,7 @@ class ColoTensor(object):
         self._shard_spec = shard_spec
         self._shard_pattern = ShardPattern.NA
         self._type = TensorType.NONMODEL
+        self._graph_node = None

     def __getitem__(self, key):
         return ColoTensor.init_from_torch_tensor(self.torch_tensor()[key])
diff --git a/colossalai/tensor/graph/__init__.py b/colossalai/tensor/graph/__init__.py
new file mode 100644
index 000000000..b0788062f
--- /dev/null
+++ b/colossalai/tensor/graph/__init__.py
@@ -0,0 +1,3 @@
+from .graph_node import GraphNode, GraphOpNode, GraphContext, GraphGlobalEnv
+
+__all__ = ['GraphNode', 'GraphOpNode', 'GraphContext', 'GraphGlobalEnv']
diff --git a/colossalai/tensor/graph/graph_node.py b/colossalai/tensor/graph/graph_node.py
new file mode 100644
index 000000000..d637d2bed
--- /dev/null
+++ b/colossalai/tensor/graph/graph_node.py
@@ -0,0 +1,97 @@
+from colossalai.tensor import ColoTensor
+from colossalai.context.singleton_meta import SingletonMeta
+
+
+class GraphGlobalEnv(metaclass=SingletonMeta):
+
+    def __init__(self) -> None:
+        self.graph_building = False
+        self.graph_node_list = []
+        self.node_id = -1
+
+    def get_node_id(self):
+        self.node_id += 1
+        return self.node_id
+
+    def add_graph_node(self, node):
+        self.graph_node_list.append(node)
+
+
+class GraphContext():
+    """
+
+    Building the computing graph under the context
+
+    >>> with GraphContext():
+    >>>     output = model(colo_input_tensor)
+    """
+    graph_nodes = []
+
+    def __enter__(self):
+        GraphGlobalEnv().graph_building = True
+        GraphGlobalEnv().graph_node_list = []
+
+    def __exit__(self, *exc_info):
+        GraphGlobalEnv().graph_building = False
+        GraphGlobalEnv().node_id = -1
+        self.graph_nodes = GraphGlobalEnv().graph_node_list
+
+
+class GraphNode(object):
+
+    def __init__(self) -> None:
+        self.prev_nodes = []
+        self.post_nodes = []
+        self.id = GraphGlobalEnv().get_node_id()
+
+    def add_prev_node(self, node):
+        if GraphGlobalEnv().graph_building:
+            self.prev_nodes.append(node)
+
+    def add_post_node(self, node):
+        if GraphGlobalEnv().graph_building:
+            self.post_nodes.append(node)
+
+    def post_node_empty(self) -> bool:
+        return len(self.post_nodes) == 0
+
+
+class GraphOpNode(GraphNode):
+
+    def __init__(self, op_type, param_list) -> None:
+        super().__init__()
+        self._op_type = op_type
+        self._param_list = param_list
+        GraphGlobalEnv().add_graph_node(self)
+
+    def add_prev_tensor(self, colo_tensor: ColoTensor):
+        r"""
+        Link the current graph op node to previous graph op.
+        Op1 <- Activation (colo_tensor) Op2
+        Op1 <- Op2
+        """
+        if GraphGlobalEnv().graph_building:
+            assert isinstance(colo_tensor, ColoTensor)
+            if colo_tensor._graph_node is None:
+                colo_tensor._graph_node = GraphNode()
+
+            prev_ops = colo_tensor._graph_node.prev_nodes
+            for op_node in prev_ops:
+                self.add_prev_node(op_node)
+                op_node.add_post_node(self)
+
+    def add_post_tensor(self, colo_tensor: ColoTensor):
+        """
+        Op <- Activation (colo_tensor)
+        """
+        if GraphGlobalEnv().graph_building:
+            assert isinstance(colo_tensor, ColoTensor)
+            if colo_tensor._graph_node is None:
+                colo_tensor._graph_node = GraphNode()
+
+            colo_tensor._graph_node.add_prev_node(self)
+
+    def print(self):
+        print(
+            f'GraphOpNode {self._op_type} {self.id}, post nodes {[node.id for node in self.post_nodes]}, prev nodes {[node.id for node in self.prev_nodes]}'
+        )
diff --git a/tests/test_tensor/test_graph.py b/tests/test_tensor/test_graph.py
new file mode 100644
index 000000000..861c74301
--- /dev/null
+++ b/tests/test_tensor/test_graph.py
@@ -0,0 +1,81 @@
+from torch import nn
+import torch
+from colossalai.tensor import ColoTensor
+from colossalai.tensor.graph import GraphContext
+import gc
+
+
+class SimpleNet(nn.Module):
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.proj1 = nn.Linear(4, 8)
+        self.proj2 = nn.Linear(8, 4)
+        self.proj3 = nn.Linear(4, 4)
+        self.proj4 = nn.Linear(4, 4)
+
+    def forward(self, x):
+        x = self.proj1(x)
+        x = self.proj2(x)
+        x = self.proj3(x)
+        x = self.proj4(x)
+        return x
+
+
+def _visit_graph(start_node):
+    if start_node is None:
+        return
+
+    start_node.print()
+
+    post_node_list = start_node.post_nodes
+    for node in post_node_list:
+        _visit_graph(node)
+
+
+def _get_tensors():
+    for obj in gc.get_objects():
+        try:
+            if torch.is_tensor(obj):
+                yield obj
+        except Exception as e:
+            print('A trivial exception occurred: {}'.format(e))
+
+
+def _count_tensors():
+    cnt = 0
+    for t in _get_tensors():
+        cnt += 1
+    return cnt
+
+
+def count_tensors(use_colossal):
+    model = SimpleNet()
+
+    model.eval()
+    with torch.no_grad():
+        if use_colossal:
+            colo_input = ColoTensor.init_from_torch_tensor(torch.randn(4))
+            graph_ctx = GraphContext()
+            with graph_ctx:
+                output = model(colo_input)
+            output = model(colo_input)
+            ret = _count_tensors()
+
+            _visit_graph(graph_ctx.graph_nodes[0])
+
+            del graph_ctx
+            return ret
+        else:
+            input_t = torch.randn(4)
+            output = model(input_t)
+            output = model(input_t)
+            return _count_tensors()
+
+
+def test_check_activation_tensors():
+    assert count_tensors(False) == count_tensors(True)
+
+
+if __name__ == "__main__":
+    count_tensors(True)
diff --git a/tests/test_tensor/test_model.py b/tests/test_tensor/test_model.py
index f8366516e..aabf4c7f6 100644
--- a/tests/test_tensor/test_model.py
+++ b/tests/test_tensor/test_model.py
@@ -26,7 +26,7 @@ from dataclasses import fields
 def _post_init_colo(self):
     class_fields = fields(self)
     # Safety and consistency checks
-    if not len(class_fields):
+    if len(class_fields) == 0:
         raise ValueError(f"{self.__class__.__name__} has no fields.")
     if not all(field.default is None for field in class_fields[1:]):
         raise ValueError(f"{self.__class__.__name__} should not have more than one required field.")
@@ -361,7 +361,7 @@ def run_dist(rank, world_size, port):
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     for name in ['simple_net']:
         run_1d_row_tp(name)
-    for name in ['bert', 'simple_net']: 
+    for name in ['bert', 'simple_net']:
        run_1d_hybrid_tp(name)
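
The snippet below, outside the patch itself, sketches the op-wrapping pattern this diff introduces in colo_linear: an op registers a GraphOpNode, links it to the ops that produced its inputs with add_prev_tensor, runs the real computation, and then tags its output with add_post_tensor. The colo_relu wrapper is a hypothetical example (it is not part of this diff); only the GraphOpNode, GraphGlobalEnv, GraphContext and ColoTensor APIs it calls come from the changes above.

# Hypothetical sketch, assuming the graph APIs added in this diff.
import torch
from colossalai.tensor import ColoTensor
from colossalai.tensor.graph import GraphOpNode, GraphGlobalEnv, GraphContext


def colo_relu(input_tensor: ColoTensor) -> ColoTensor:
    # inputs -> op: create the op node and connect it to the ops that
    # produced the input activation (mirrors the hook added to colo_linear).
    if GraphGlobalEnv().graph_building:
        cur_op_node = GraphOpNode('relu', [])
        cur_op_node.add_prev_tensor(input_tensor)

    # The actual computation is unchanged by graph building.
    ret_tensor = ColoTensor.init_from_torch_tensor(torch.relu(input_tensor.torch_tensor()))

    # op -> output: mark the output activation as produced by this op node,
    # so the next op's add_prev_tensor can find this node through the tensor.
    if GraphGlobalEnv().graph_building:
        cur_op_node.add_post_tensor(ret_tensor)
    return ret_tensor


# Usage, mirroring tests/test_tensor/test_graph.py: run one forward pass
# inside a GraphContext to record the op-to-op edges.
with GraphContext():
    out = colo_relu(ColoTensor.init_from_torch_tensor(torch.randn(4)))

Linking ops through the activation's _graph_node attribute, rather than through module hooks, is what lets a later traversal such as _visit_graph in the new test walk from one GraphOpNode to the next.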