[graph] improve the graph building. (#1157)

pull/1160/head
Jiarui Fang 2022-06-22 16:47:20 +08:00 committed by GitHub
parent 22717a856f
commit 07f9c781f9
7 changed files with 79 additions and 103 deletions

View File

@@ -1,14 +1,14 @@
 import torch.nn.functional as F
 from typing import Optional
+from ._utils import GeneralTensor, convert_to_colo_tensor
 from colossalai.tensor.op_wrapper import colo_op_impl
 from colossalai.nn.layer.parallel_1d._utils import reduce_input, reduce_grad
 from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, distspec
-from colossalai.tensor.graph import GraphOpNode, GraphGlobalEnv
 from colossalai.context import ParallelMode
-from ._utils import GeneralTensor, convert_to_colo_tensor
+from colossalai.nn.graph import register_colo_graph, GraphOpNode, GraphGlobalEnv


-def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
+def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
     # Input:S[1] x Weight:S[0] = Output:P
     # All-Reduce(Output) + bias = res
     # Input:S[1]
@@ -28,7 +28,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     return output


-def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
+def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
     # Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
     # All-Gather(Output)
     # Input:B
@@ -48,23 +48,21 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     return output


-def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
+def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
     assert mode in ('row', 'col')
     funcs = {'row': colo_linear_1Drow, 'col': colo_linear_1Dcol}
     return funcs[mode](input_tensor, weight, bias)


-@colo_op_impl(F.linear)
-def colo_linear(input_tensor: GeneralTensor, weight: GeneralTensor, bias: Optional[GeneralTensor] = None):
+@register_colo_graph(input_pos=[1], param_pos=[2, 3])
+def colo_linear_imp(input_tensor: GeneralTensor,
+                    weight: GeneralTensor,
+                    bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
     """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
     This method computes a linear.
     """
     input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias)))
-    # building the computing graph, inputs -> op
-    if GraphGlobalEnv().graph_building:
-        cur_op_node = GraphOpNode('linear', [weight, bias])
-        cur_op_node.add_prev_tensor(input_tensor)

     # Add communication logic before and after linear call.
     ret_tensor = None
     if not weight.has_spec():    # No Model Parallel Applied
@@ -82,7 +80,11 @@ def colo_linear(input_tensor: GeneralTensor, weight: GeneralTensor, bias: Option
     else:
         raise NotImplementedError

-    # building the computing graph, op -> output
-    if GraphGlobalEnv().graph_building:
-        cur_op_node.add_post_tensor(ret_tensor)
     return ret_tensor
+
+
+@colo_op_impl(F.linear)
+def colo_linear(input_tensor: GeneralTensor,
+                weight: GeneralTensor,
+                bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
+    return colo_linear_imp(input_tensor, weight, bias)
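
Note on the hunks above: the F.linear handler is now split in two. colo_linear keeps the @colo_op_impl(F.linear) registration and simply forwards to colo_linear_imp, while the @register_colo_graph decorator on colo_linear_imp takes over the graph bookkeeping that used to be inlined in the function body. A minimal sketch of how this path is expected to be exercised (assuming colossalai is installed and a parallel context has been launched; tensor shapes are arbitrary and chosen only for illustration):

import torch
import torch.nn.functional as F
from colossalai.tensor import ColoTensor

# plain tensors wrapped as ColoTensors so that __torch_function__ routes
# F.linear through colo_linear -> colo_linear_imp defined above
input_t = ColoTensor.from_torch_tensor(torch.randn(2, 4))
weight = ColoTensor.from_torch_tensor(torch.randn(8, 4))
bias = ColoTensor.from_torch_tensor(torch.randn(8))

out = F.linear(input_t, weight, bias)    # graph recording only happens inside a graph-building context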

View File

@@ -0,0 +1,4 @@
+from .utils import register_colo_graph
+from .graph_node import GraphContext, GraphGlobalEnv, GraphOpNode
+
+__all__ = ['register_colo_graph', 'GraphContext', 'GraphGlobalEnv', 'GraphOpNode']
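
The new colossalai.nn.graph package consolidates the graph-building utilities. A minimal usage sketch, adapted from the test removed at the end of this commit: run a forward pass on a ColoTensor input inside a GraphContext so that ops decorated with register_colo_graph record GraphOpNodes, then inspect the captured nodes.

import torch
from torch import nn
from colossalai.tensor import ColoTensor
from colossalai.nn.graph import GraphContext

model = nn.Linear(4, 8).eval()
colo_input = ColoTensor.from_torch_tensor(torch.randn(4))

graph_ctx = GraphContext()
with torch.no_grad():
    with graph_ctx:
        model(colo_input)    # F.linear inside nn.Linear is recorded as a GraphOpNode

# the removed test below walks the graph the same way, via node.print() and node.post_nodes
graph_ctx.graph_nodes[0].print()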

View File

@@ -74,7 +74,6 @@ class GraphOpNode(GraphNode):
             assert isinstance(colo_tensor, ColoTensor)
             if colo_tensor._graph_node is None:
                 colo_tensor._graph_node = GraphNode()
-
             prev_ops = colo_tensor._graph_node.prev_nodes
             for op_node in prev_ops:
                 self.add_prev_node(op_node)
@@ -85,7 +84,7 @@ class GraphOpNode(GraphNode):
         Op <- Activation (colo_tensor)
         """
         if GraphGlobalEnv().graph_building:
-            assert isinstance(colo_tensor, ColoTensor)
+            assert isinstance(colo_tensor, ColoTensor), f'type {type(colo_tensor)}'
            if colo_tensor._graph_node is None:
                colo_tensor._graph_node = GraphNode()

View File

@@ -0,0 +1,50 @@
+import functools
+import torch
+from colossalai.tensor import ColoTensor
+from typing import Callable, List
+from colossalai.nn._ops._utils import convert_to_colo_tensor
+
+
+def register_colo_graph(input_pos: List[int], param_pos: List[int]) -> Callable:
+    """register_colo_graph
+    Register an op (layer) to the ColoGraph.
+    Records the args of type ColoTensor in the graph.
+
+    Args:
+        input_pos (List[int]): positions of the activation (input) tensors among the args.
+        param_pos (List[int]): positions of the parameter tensors among the args.
+
+    Returns:
+        Callable: a decorator that wraps the function implementing the op.
+    """
+
+    def register_colo_graph_decorator(func):
+        from colossalai.nn.graph import GraphOpNode, GraphGlobalEnv
+
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            param_list = []
+            input_list = []
+            for idx, arg in enumerate(args):
+                if isinstance(arg, torch.Tensor) and idx in input_pos:
+                    input_list.append(convert_to_colo_tensor(arg))
+                if isinstance(arg, torch.Tensor) and idx in param_pos:
+                    param_list.append(convert_to_colo_tensor(arg))
+            print(f'Op {func}')
+            # building the computing graph, inputs -> op
+            if GraphGlobalEnv().graph_building:
+                cur_op_node = GraphOpNode('linear', param_list)
+                # TODO: support a list of ColoTensor as args
+                if len(input_list) > 0:
+                    cur_op_node.add_prev_tensor(input_list[0])
+
+            outputs = func(*args, **kwargs)
+
+            # building the computing graph, op -> output
+            if GraphGlobalEnv().graph_building:
+                # TODO: support a list of ColoTensor as args
+                if isinstance(outputs[0], ColoTensor):
+                    cur_op_node.add_post_tensor(outputs[0])
+            return outputs
+
+        return wrapper
+
+    return register_colo_graph_decorator
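
For reference, the decorator above is meant to be applied to other op implementations the same way linear.py uses it earlier in this commit. A hedged sketch with a hypothetical op name (colo_custom_imp is not part of this commit); the position arguments simply copy the convention used for the linear registration:

import torch
import torch.nn.functional as F
from typing import Optional
from colossalai.nn.graph import register_colo_graph


@register_colo_graph(input_pos=[1], param_pos=[2, 3])
def colo_custom_imp(input_tensor: torch.Tensor,
                    weight: torch.Tensor,
                    bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # the wrapper handles GraphOpNode creation and prev/post tensor linking
    # whenever GraphGlobalEnv().graph_building is True; the body only computes the op
    return F.linear(input_tensor, weight, bias)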

View File

@@ -17,6 +17,13 @@ class _DistSpec:
                  dist_placement_pattern: DistPlacementPattern,
                  process_group: Optional[ProcessGroup] = None,
                  **meta_info):
+        """_DistSpec, Distributed Specification
+
+        Args:
+            dist_placement_pattern (DistPlacementPattern): the pattern describing how tensors are distributed among processes.
+                The dist_placement_pattern is picked from a limited set, currently including two patterns: replicate and shard.
+            process_group (Optional[ProcessGroup], optional): the process group containing the participating processes. Defaults to None.
+        """
         self.placement = dist_placement_pattern
         self.process_group = process_group
         for k, v in meta_info.items():
@@ -37,6 +44,7 @@
             res += f'{attr}: {str(getattr(self, attr))}\n\t'
         return res

+
 def replicate(process_group: Optional[ProcessGroup] = None) -> _DistSpec:
     # process_group=None means global process group
     return _DistSpec(DistPlacementPattern.REPLICATE, process_group)
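
The replicate() helper shown above is the simplest way to obtain a _DistSpec; a small sketch (how such a spec is combined with a TensorSpec on a ColoTensor is outside this hunk):

from colossalai.tensor import distspec

# with no process_group argument, the spec refers to the global process group,
# per the comment in the hunk above
replica_spec = distspec.replicate()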

View File

@@ -1,3 +0,0 @@
-from .graph_node import GraphNode, GraphOpNode, GraphContext, GraphGlobalEnv
-
-__all__ = ['GraphNode', 'GraphOpNode', 'GraphContext', 'GraphGlobalEnv']

View File

@@ -1,84 +0,0 @@
-import pytest
-from torch import nn
-import torch
-from colossalai.tensor import ColoTensor
-from colossalai.tensor.graph import GraphContext
-import gc
-
-
-class SimpleNet(nn.Module):
-
-    def __init__(self) -> None:
-        super().__init__()
-        self.proj1 = nn.Linear(4, 8)
-        self.proj2 = nn.Linear(8, 4)
-        self.proj3 = nn.Linear(4, 4)
-        self.proj4 = nn.Linear(4, 4)
-
-    def forward(self, x):
-        x = self.proj1(x)
-        x = self.proj2(x)
-        x = self.proj3(x)
-        x = self.proj4(x)
-        return x
-
-
-def _visit_graph(start_node):
-    if start_node is None:
-        return
-    start_node.print()
-    post_node_list = start_node.post_nodes
-    for node in post_node_list:
-        _visit_graph(node)
-
-
-def _get_tensors():
-    for obj in gc.get_objects():
-        try:
-            if torch.is_tensor(obj):
-                yield obj
-        except Exception as e:
-            print('A trivial exception occured: {}'.format(e))
-
-
-def _count_tensors():
-    cnt = 0
-    for t in _get_tensors():
-        cnt += 1
-    return cnt
-
-
-def count_tensors(use_colossal):
-    model = SimpleNet()
-    model.eval()
-    with torch.no_grad():
-        if use_colossal:
-            colo_input = ColoTensor.from_torch_tensor(torch.randn(4))
-            graph_ctx = GraphContext()
-            with graph_ctx:
-                output = model(colo_input)
-            output = model(colo_input)
-            ret = _count_tensors()
-            _visit_graph(graph_ctx.graph_nodes[0])
-            del graph_ctx
-            return ret
-        else:
-            input_t = torch.randn(4)
-            output = model(input_t)
-            output = model(input_t)
-            return _count_tensors()
-
-
-@pytest.mark.skip
-# FIXME(ver217)
-def test_check_activation_tensors():
-    assert count_tensors(False) == count_tensors(True)
-
-
-if __name__ == "__main__":
-    count_tensors(True)