mirror of https://github.com/hpcaitech/ColossalAI
[graph] improve the graph building. (#1157)
parent 22717a856f
commit 07f9c781f9
@@ -1,14 +1,14 @@
 import torch.nn.functional as F
 from typing import Optional
-from ._utils import GeneralTensor, convert_to_colo_tensor
 from colossalai.tensor.op_wrapper import colo_op_impl
 from colossalai.nn.layer.parallel_1d._utils import reduce_input, reduce_grad
 from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, distspec
-from colossalai.tensor.graph import GraphOpNode, GraphGlobalEnv
 from colossalai.context import ParallelMode
+from ._utils import GeneralTensor, convert_to_colo_tensor
+from colossalai.nn.graph import register_colo_graph, GraphOpNode, GraphGlobalEnv


-def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
+def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
     # Input:S[1] x Weight:S[0] = Output:P
     # All-Reduce(Output) + bias = res
     # Input:S[1]
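
As an aside, the sharding annotation above (Input:S[1] x Weight:S[0] = Output:P, followed by an all-reduce) can be checked with a few lines of plain PyTorch. This is only a single-process illustration of the math, not the ColoTensor code path, and the shapes are arbitrary.

import torch

x = torch.randn(2, 4)    # activation; its last dim (size 4) is the one annotated S[1]
w = torch.randn(4, 8)    # weight viewed as (in, out); its dim 0 is the one annotated S[0]

x0, x1 = x.chunk(2, dim=1)    # shard the input along the contraction dim
w0, w1 = w.chunk(2, dim=0)    # shard the weight along the same contraction dim

partial_sum = x0 @ w0 + x1 @ w1    # summing the partial products plays the role of the all-reduce
assert torch.allclose(partial_sum, x @ w, atol=1e-5)
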
@@ -28,7 +28,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     return output


-def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
+def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
     # Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
     # All-Gather(Output)
     # Input:B
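
The column-parallel annotation above is the dual case: the input stays replicated (B), the weight is split along the output dimension (S[1]), and concatenating the per-rank slices plays the role of the all-gather. Again a single-process sketch with arbitrary shapes, for illustration only.

import torch

x = torch.randn(2, 4)            # replicated input
w = torch.randn(4, 8)            # weight viewed as (in, out)

w0, w1 = w.chunk(2, dim=1)       # shard the output dimension
y = torch.cat([x @ w0, x @ w1], dim=1)    # "all-gather" of the per-rank output slices
assert torch.allclose(y, x @ w, atol=1e-5)
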
@@ -48,23 +48,21 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     return output


-def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
+def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
     assert mode in ('row', 'col')
     funcs = {'row': colo_linear_1Drow, 'col': colo_linear_1Dcol}
     return funcs[mode](input_tensor, weight, bias)


-@colo_op_impl(F.linear)
-def colo_linear(input_tensor: GeneralTensor, weight: GeneralTensor, bias: Optional[GeneralTensor] = None):
+@register_colo_graph(input_pos=[1], param_pos=[2, 3])
+def colo_linear_imp(input_tensor: GeneralTensor,
+                    weight: GeneralTensor,
+                    bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
     """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
     This method computes a linear.
     """
     input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias)))

-    # building the computing graph, inputs -> op
-    if GraphGlobalEnv().graph_building:
-        cur_op_node = GraphOpNode('linear', [weight, bias])
-        cur_op_node.add_prev_tensor(input_tensor)
     # Add communication logic before and after linear call.
     ret_tensor = None
     if not weight.has_spec():    # No Model Parallel Applied
@@ -82,7 +80,11 @@ def colo_linear(input_tensor: GeneralTensor, weight: GeneralTensor, bias: Option
     else:
         raise NotImplementedError

-    # building the computing graph, op -> output
-    if GraphGlobalEnv().graph_building:
-        cur_op_node.add_post_tensor(ret_tensor)
     return ret_tensor
+
+
+@colo_op_impl(F.linear)
+def colo_linear(input_tensor: GeneralTensor,
+                weight: GeneralTensor,
+                bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
+    return colo_linear_imp(input_tensor, weight, bias)
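
With the hunks above, graph recording moves out of colo_linear and into the register_colo_graph decorator, while the thin colo_linear wrapper stays registered for __torch_function__ dispatch. A hedged sketch of what that dispatch looks like from user code; the shapes are illustrative and no model parallelism is assumed, so the weight has no spec and the plain path runs.

import torch
import torch.nn.functional as F
from colossalai.tensor import ColoTensor

x = ColoTensor.from_torch_tensor(torch.randn(2, 4))
w = ColoTensor.from_torch_tensor(torch.randn(8, 4))    # F.linear weights are (out_features, in_features)
y = F.linear(x, w)    # intercepted by __torch_function__ and routed to colo_linear -> colo_linear_imp
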
@@ -0,0 +1,4 @@
+from .utils import register_colo_graph
+from .graph_node import GraphContext, GraphGlobalEnv, GraphOpNode
+
+__all__ = ['register_colo_graph', 'GraphContext', 'GraphGlobalEnv', 'GraphOpNode']
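
A hedged sketch of how this re-exported API is meant to be used, condensed from the test this commit removes (see the bottom of the diff): run a forward pass under a GraphContext so every registered op records a node, then walk the captured graph. The module and tensor shapes are illustrative only.

import torch
from torch import nn
from colossalai.tensor import ColoTensor
from colossalai.nn.graph import GraphContext

model = nn.Linear(4, 4)
colo_input = ColoTensor.from_torch_tensor(torch.randn(4))

with torch.no_grad():
    graph_ctx = GraphContext()
    with graph_ctx:
        model(colo_input)    # the F.linear inside nn.Linear is recorded as a GraphOpNode

start_node = graph_ctx.graph_nodes[0]    # first recorded op node
start_node.print()
for node in start_node.post_nodes:       # downstream op nodes, if any
    node.print()
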
@@ -74,7 +74,6 @@ class GraphOpNode(GraphNode):
             assert isinstance(colo_tensor, ColoTensor)
             if colo_tensor._graph_node is None:
                 colo_tensor._graph_node = GraphNode()

             prev_ops = colo_tensor._graph_node.prev_nodes
             for op_node in prev_ops:
                 self.add_prev_node(op_node)
@@ -85,7 +84,7 @@ class GraphOpNode(GraphNode):
         Op <- Activation (colo_tensor)
         """
         if GraphGlobalEnv().graph_building:
-            assert isinstance(colo_tensor, ColoTensor)
+            assert isinstance(colo_tensor, ColoTensor), f'type {type(colo_tensor)}'
             if colo_tensor._graph_node is None:
                 colo_tensor._graph_node = GraphNode()
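
Reading the two hunks above together, the linking protocol appears to be: a producing op calls add_post_tensor on its output activation, and a consuming op later calls add_prev_tensor on the same activation, which walks the activation's _graph_node.prev_nodes and wires the consumer to the producer. A rough sketch under those assumptions; an active GraphContext is required so that GraphGlobalEnv().graph_building is true, and the node names are arbitrary.

import torch
from colossalai.tensor import ColoTensor
from colossalai.nn.graph import GraphContext, GraphOpNode

activation = ColoTensor.from_torch_tensor(torch.randn(4))
with GraphContext():
    producer = GraphOpNode('linear', [])
    producer.add_post_tensor(activation)    # producer -> activation
    consumer = GraphOpNode('linear', [])
    consumer.add_prev_tensor(activation)    # activation -> consumer, so producer precedes consumer
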
@@ -0,0 +1,50 @@
+import functools
+import torch
+from colossalai.tensor import ColoTensor
+from typing import Callable, List
+from colossalai.nn._ops._utils import convert_to_colo_tensor
+
+
+def register_colo_graph(input_pos: List[int], param_pos: List[int]) -> Callable:
+    """register_colo_graph
+    Register an Op (Layer) to ColoGraph.
+    Records the input args of ColoTensor type to the Graph.
+    Args:
+        func (Callable): a function that implements the Op.
+
+    Returns:
+        Callable: wrapper function.
+    """
+
+    def register_colo_graph_decorator(func):
+        from colossalai.nn.graph import GraphOpNode, GraphGlobalEnv
+
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            param_list = []
+            input_list = []
+            for idx, arg in enumerate(args):
+                if isinstance(arg, torch.Tensor) and idx in input_pos:
+                    input_list.append(convert_to_colo_tensor(arg))
+                if isinstance(arg, torch.Tensor) and idx in param_pos:
+                    param_list.append(convert_to_colo_tensor(arg))
+            print(f'Op {func}')
+            # building the computing graph, inputs -> op
+            if GraphGlobalEnv().graph_building:
+                cur_op_node = GraphOpNode('linear', param_list)
+                # TODO supports a list of ColoTensor as args
+                if len(input_list) > 0:
+                    cur_op_node.add_prev_tensor(input_list[0])
+
+            outputs = func(*args, **kwargs)
+
+            # building the computing graph, op -> output
+            if GraphGlobalEnv().graph_building:
+                # TODO supports a list of ColoTensor as args
+                if isinstance(outputs[0], ColoTensor):
+                    cur_op_node.add_post_tensor(outputs[0])
+            return outputs
+
+        return wrapper
+
+    return register_colo_graph_decorator
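
A hedged usage sketch mirroring the colo_linear_imp change earlier in this diff: the decorator converts the tensor arguments at the listed positions to ColoTensor and, while a GraphContext is active, records a GraphOpNode linking the op's first input to its first output. my_linear_imp is hypothetical, not part of the library, and the position convention simply copies the library's own usage shown above.

import torch
from colossalai.nn.graph import register_colo_graph

@register_colo_graph(input_pos=[1], param_pos=[2, 3])
def my_linear_imp(input_tensor, weight, bias=None):
    # the wrapper still calls this with the original (possibly plain torch) arguments
    out = torch.matmul(input_tensor, weight.t())
    if bias is not None:
        out = out + bias
    return out
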
@@ -17,6 +17,13 @@ class _DistSpec:
                  dist_placement_pattern: DistPlacementPattern,
                  process_group: Optional[ProcessGroup] = None,
                  **meta_info):
+        """_DistSpec, Distributed Specification
+
+        Args:
+            dist_placement_pattern (DistPlacementPattern): the pattern describing how tensors are distributed among processes.
+                The dist_placement_pattern is picked from a limited set, now including two patterns: replicate and shard.
+            process_group (Optional[ProcessGroup], optional): the process group containing the processes. Defaults to None.
+        """
         self.placement = dist_placement_pattern
         self.process_group = process_group
         for k, v in meta_info.items():
@@ -37,6 +44,7 @@ class _DistSpec:
             res += f'{attr}: {str(getattr(self, attr))}\n\t'
         return res


 def replicate(process_group: Optional[ProcessGroup] = None) -> _DistSpec:
+    # process_group=None means global process group
     return _DistSpec(DistPlacementPattern.REPLICATE, process_group)
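
For reference, the replicate pattern returned above is one of the two placement patterns named in the new docstring (replicate and shard). A minimal sketch of constructing and inspecting one; the import path follows the `from colossalai.tensor import ... distspec` line at the top of this diff.

from colossalai.tensor import distspec

replica_spec = distspec.replicate()    # process_group=None means the global process group
print(replica_spec)    # __repr__ lists placement and process_group, per the hunk above
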
@@ -1,3 +0,0 @@
-from .graph_node import GraphNode, GraphOpNode, GraphContext, GraphGlobalEnv
-
-__all__ = ['GraphNode', 'GraphOpNode', 'GraphContext', 'GraphGlobalEnv']
@@ -1,84 +0,0 @@
-import pytest
-from torch import nn
-import torch
-from colossalai.tensor import ColoTensor
-from colossalai.tensor.graph import GraphContext
-import gc
-
-
-class SimpleNet(nn.Module):
-
-    def __init__(self) -> None:
-        super().__init__()
-        self.proj1 = nn.Linear(4, 8)
-        self.proj2 = nn.Linear(8, 4)
-        self.proj3 = nn.Linear(4, 4)
-        self.proj4 = nn.Linear(4, 4)
-
-    def forward(self, x):
-        x = self.proj1(x)
-        x = self.proj2(x)
-        x = self.proj3(x)
-        x = self.proj4(x)
-        return x
-
-
-def _visit_graph(start_node):
-    if start_node is None:
-        return
-
-    start_node.print()
-
-    post_node_list = start_node.post_nodes
-    for node in post_node_list:
-        _visit_graph(node)
-
-
-def _get_tensors():
-    for obj in gc.get_objects():
-        try:
-            if torch.is_tensor(obj):
-                yield obj
-        except Exception as e:
-            print('A trivial exception occurred: {}'.format(e))
-
-
-def _count_tensors():
-    cnt = 0
-    for t in _get_tensors():
-        cnt += 1
-    return cnt
-
-
-def count_tensors(use_colossal):
-    model = SimpleNet()
-
-    model.eval()
-    with torch.no_grad():
-        if use_colossal:
-            colo_input = ColoTensor.from_torch_tensor(torch.randn(4))
-            graph_ctx = GraphContext()
-            with graph_ctx:
-                output = model(colo_input)
-            output = model(colo_input)
-            ret = _count_tensors()
-
-            _visit_graph(graph_ctx.graph_nodes[0])
-
-            del graph_ctx
-            return ret
-        else:
-            input_t = torch.randn(4)
-            output = model(input_t)
-            output = model(input_t)
-            return _count_tensors()
-
-
-@pytest.mark.skip
-# FIXME(ver217)
-def test_check_activation_tensors():
-    assert count_tensors(False) == count_tensors(True)
-
-
-if __name__ == "__main__":
-    count_tensors(True)