[graph] improve the graph building. (#1157)

pull/1160/head
Jiarui Fang 2022-06-22 16:47:20 +08:00 committed by GitHub
parent 22717a856f
commit 07f9c781f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 79 additions and 103 deletions

View File

@ -1,14 +1,14 @@
import torch.nn.functional as F
from typing import Optional
from ._utils import GeneralTensor, convert_to_colo_tensor
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.nn.layer.parallel_1d._utils import reduce_input, reduce_grad
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, distspec
from colossalai.tensor.graph import GraphOpNode, GraphGlobalEnv
from colossalai.context import ParallelMode
from ._utils import GeneralTensor, convert_to_colo_tensor
from colossalai.nn.graph import register_colo_graph, GraphOpNode, GraphGlobalEnv
def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
# Input:S[1] x Weight:S[0] = Output:P
# All-Reduce(Output) + bias = res
# Input:S[1]
@ -28,7 +28,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
return output
def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
# Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
# All-Gather(Output)
# Input:B
@ -48,23 +48,21 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
return output
def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
assert mode in ('row', 'col')
funcs = {'row': colo_linear_1Drow, 'col': colo_linear_1Dcol}
return funcs[mode](input_tensor, weight, bias)
@colo_op_impl(F.linear)
def colo_linear(input_tensor: GeneralTensor, weight: GeneralTensor, bias: Optional[GeneralTensor] = None):
@register_colo_graph(input_pos=[1], param_pos=[2, 3])
def colo_linear_imp(input_tensor: GeneralTensor,
weight: GeneralTensor,
bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
"""Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
This method computes a linear.
"""
input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias)))
# building the computing graph, inputs -> op
if GraphGlobalEnv().graph_building:
cur_op_node = GraphOpNode('linear', [weight, bias])
cur_op_node.add_prev_tensor(input_tensor)
# Add communication logic before and after linear call.
ret_tensor = None
if not weight.has_spec(): # No Model Parallel Applied
@ -82,7 +80,11 @@ def colo_linear(input_tensor: GeneralTensor, weight: GeneralTensor, bias: Option
else:
raise NotImplementedError
# building the computing graph, op -> output
if GraphGlobalEnv().graph_building:
cur_op_node.add_post_tensor(ret_tensor)
return ret_tensor
@colo_op_impl(F.linear)
def colo_linear(input_tensor: GeneralTensor,
weight: GeneralTensor,
bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
return colo_linear_imp(input_tensor, weight, bias)

View File

@ -0,0 +1,4 @@
from .utils import register_colo_graph
from .graph_node import GraphContext, GraphGlobalEnv, GraphOpNode
__all__ = ['register_colo_graph', 'GraphContext', 'GraphGlobalEnv', 'GraphOpNode']

View File

@ -74,7 +74,6 @@ class GraphOpNode(GraphNode):
assert isinstance(colo_tensor, ColoTensor)
if colo_tensor._graph_node is None:
colo_tensor._graph_node = GraphNode()
prev_ops = colo_tensor._graph_node.prev_nodes
for op_node in prev_ops:
self.add_prev_node(op_node)
@ -85,7 +84,7 @@ class GraphOpNode(GraphNode):
Op <- Activation (colo_tensor)
"""
if GraphGlobalEnv().graph_building:
assert isinstance(colo_tensor, ColoTensor)
assert isinstance(colo_tensor, ColoTensor), f'type {type(colo_tensor)}'
if colo_tensor._graph_node is None:
colo_tensor._graph_node = GraphNode()

View File

@ -0,0 +1,50 @@
import functools
import torch
from colossalai.tensor import ColoTensor
from typing import Callable, List
from colossalai.nn._ops._utils import convert_to_colo_tensor
def register_colo_graph(input_pos: List[int], param_pos: List[int]) -> Callable:
"""register_colo_graph
Register a Op (Layer) to ColoGraph.
Recoders the input args in types of ColoTensor to the Graph.
Args:
func (Callable): a function implements the Op.
Returns:
Callable: wrapper function.
"""
def register_colo_graph_decorator(func):
from colossalai.nn.graph import GraphOpNode, GraphGlobalEnv
@functools.wraps(func)
def wrapper(*args, **kwargs):
param_list = []
input_list = []
for idx, arg in enumerate(args):
if isinstance(arg, torch.Tensor) and idx in input_pos:
input_list.append(convert_to_colo_tensor(arg))
if isinstance(arg, torch.Tensor) and idx in param_pos:
param_list.append(convert_to_colo_tensor(arg))
print(f'Op {func}')
# building the computing graph, inputs -> op
if GraphGlobalEnv().graph_building:
cur_op_node = GraphOpNode('linear', param_list)
# TODO supports a list of ColoTensor as args
if len(input_list) > 0:
cur_op_node.add_prev_tensor(input_list[0])
outputs = func(*args, **kwargs)
# building the computing graph, op -> output
if GraphGlobalEnv().graph_building:
# TODO supports a list of ColoTensor as args
if isinstance(outputs[0], ColoTensor):
cur_op_node.add_post_tensor(outputs[0])
return outputs
return wrapper
return register_colo_graph_decorator

View File

@ -17,6 +17,13 @@ class _DistSpec:
dist_placement_pattern: DistPlacementPattern,
process_group: Optional[ProcessGroup] = None,
**meta_info):
"""_DistSpec, Distributed Specification
Args:
dist_placement_pattern (DistPlacementPattern): the pattern describing how tensors are distributed among processes.
The dist_placement_pattern is picked from a limited set, now including two patterns: replicate and shard.
process_group (Optional[ProcessGroup], optional): the process group contains processes. Defaults to None.
"""
self.placement = dist_placement_pattern
self.process_group = process_group
for k, v in meta_info.items():
@ -37,6 +44,7 @@ class _DistSpec:
res += f'{attr}: {str(getattr(self, attr))}\n\t'
return res
def replicate(process_group: Optional[ProcessGroup] = None) -> _DistSpec:
# process_group=None means global process group
return _DistSpec(DistPlacementPattern.REPLICATE, process_group)

View File

@ -1,3 +0,0 @@
from .graph_node import GraphNode, GraphOpNode, GraphContext, GraphGlobalEnv
__all__ = ['GraphNode', 'GraphOpNode', 'GraphContext', 'GraphGlobalEnv']

View File

@ -1,84 +0,0 @@
import pytest
from torch import nn
import torch
from colossalai.tensor import ColoTensor
from colossalai.tensor.graph import GraphContext
import gc
class SimpleNet(nn.Module):
def __init__(self) -> None:
super().__init__()
self.proj1 = nn.Linear(4, 8)
self.proj2 = nn.Linear(8, 4)
self.proj3 = nn.Linear(4, 4)
self.proj4 = nn.Linear(4, 4)
def forward(self, x):
x = self.proj1(x)
x = self.proj2(x)
x = self.proj3(x)
x = self.proj4(x)
return x
def _visit_graph(start_node):
if start_node is None:
return
start_node.print()
post_node_list = start_node.post_nodes
for node in post_node_list:
_visit_graph(node)
def _get_tensors():
for obj in gc.get_objects():
try:
if torch.is_tensor(obj):
yield obj
except Exception as e:
print('A trivial exception occured: {}'.format(e))
def _count_tensors():
cnt = 0
for t in _get_tensors():
cnt += 1
return cnt
def count_tensors(use_colossal):
model = SimpleNet()
model.eval()
with torch.no_grad():
if use_colossal:
colo_input = ColoTensor.from_torch_tensor(torch.randn(4))
graph_ctx = GraphContext()
with graph_ctx:
output = model(colo_input)
output = model(colo_input)
ret = _count_tensors()
_visit_graph(graph_ctx.graph_nodes[0])
del graph_ctx
return ret
else:
input_t = torch.randn(4)
output = model(input_t)
output = model(input_t)
return _count_tensors()
@pytest.mark.skip
# FIXME(ver217)
def test_check_activation_tensors():
assert count_tensors(False) == count_tensors(True)
if __name__ == "__main__":
count_tensors(True)