mirror of https://github.com/hpcaitech/ColossalAI
[graph] improve the graph building. (#1157)
parent 22717a856f
commit 07f9c781f9
@@ -1,14 +1,14 @@
 import torch.nn.functional as F
 from typing import Optional
+from ._utils import GeneralTensor, convert_to_colo_tensor
 from colossalai.tensor.op_wrapper import colo_op_impl
 from colossalai.nn.layer.parallel_1d._utils import reduce_input, reduce_grad
 from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, distspec
-from colossalai.tensor.graph import GraphOpNode, GraphGlobalEnv
 from colossalai.context import ParallelMode
-from ._utils import GeneralTensor, convert_to_colo_tensor
+from colossalai.nn.graph import register_colo_graph, GraphOpNode, GraphGlobalEnv


-def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
+def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
     # Input:S[1] x Weight:S[0] = Output:P
     # All-Reduce(Output) + bias = res
     # Input:S[1]
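The comments above sketch the row-parallel ("1Drow") scheme: the input is sharded along its last dimension, the weight along the matching in_features dimension, and each rank's partial product is summed with an all-reduce before the bias is added. A minimal single-process sketch of that arithmetic (plain torch, no process group; the chunking stands in for the distributed shards):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    x = torch.randn(2, 8)    # full input
    w = torch.randn(4, 8)    # full weight, shape (out_features, in_features)
    b = torch.randn(4)

    # Shard the input and the weight along in_features, one shard per "rank".
    x0, x1 = x.chunk(2, dim=1)
    w0, w1 = w.chunk(2, dim=1)

    # Each rank holds a partial (P) output; the sum stands in for All-Reduce.
    out = F.linear(x0, w0) + F.linear(x1, w1) + b
    assert torch.allclose(out, F.linear(x, w, b), atol=1e-5)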
@@ -28,7 +28,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     return output


-def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
+def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
     # Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
     # All-Gather(Output)
     # Input:B
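The column-parallel ("1Dcol") counterpart keeps the input replicated (B) and shards both the weight and the bias along out_features, so each rank produces a slice of the output that an all-gather reassembles. The same kind of single-process sketch:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    x = torch.randn(2, 8)    # replicated input
    w = torch.randn(4, 8)
    b = torch.randn(4)

    # Shard weight and bias along out_features.
    w0, w1 = w.chunk(2, dim=0)
    b0, b1 = b.chunk(2, dim=0)

    # Each rank computes an output slice; cat stands in for All-Gather.
    out = torch.cat([F.linear(x, w0, b0), F.linear(x, w1, b1)], dim=-1)
    assert torch.allclose(out, F.linear(x, w, b), atol=1e-5)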
@@ -48,23 +48,21 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     return output


-def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
+def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
     assert mode in ('row', 'col')
     funcs = {'row': colo_linear_1Drow, 'col': colo_linear_1Dcol}
     return funcs[mode](input_tensor, weight, bias)


-@colo_op_impl(F.linear)
-def colo_linear(input_tensor: GeneralTensor, weight: GeneralTensor, bias: Optional[GeneralTensor] = None):
+@register_colo_graph(input_pos=[1], param_pos=[2, 3])
+def colo_linear_imp(input_tensor: GeneralTensor,
+                    weight: GeneralTensor,
+                    bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
     """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
     This method computes a linear.
     """
     input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias)))

-    # building the computing graph, inputs -> op
-    if GraphGlobalEnv().graph_building:
-        cur_op_node = GraphOpNode('linear', [weight, bias])
-        cur_op_node.add_prev_tensor(input_tensor)
     # Add communication logic before and after linear call.
     ret_tensor = None
     if not weight.has_spec(): # No Model Parallel Applied
@@ -82,7 +80,11 @@ def colo_linear(input_tensor: GeneralTensor, weight: GeneralTensor, bias: Option
     else:
         raise NotImplementedError

-    # building the computing graph, op -> output
-    if GraphGlobalEnv().graph_building:
-        cur_op_node.add_post_tensor(ret_tensor)
     return ret_tensor
+
+
+@colo_op_impl(F.linear)
+def colo_linear(input_tensor: GeneralTensor,
+                weight: GeneralTensor,
+                bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
+    return colo_linear_imp(input_tensor, weight, bias)
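With this split, @colo_op_impl(F.linear) still intercepts torch.nn.functional.linear via __torch_function__, but the interception is now a thin shim over colo_linear_imp, whose @register_colo_graph wrapper owns all graph bookkeeping. From user code the call site is unchanged; a sketch, assuming the ColoTensor API of this version:

    import torch
    import torch.nn.functional as F
    from colossalai.tensor import ColoTensor

    x = ColoTensor.from_torch_tensor(torch.randn(2, 4))
    w = ColoTensor.from_torch_tensor(torch.randn(8, 4))
    b = ColoTensor.from_torch_tensor(torch.randn(8))

    # __torch_function__ dispatches to colo_linear, which forwards to
    # colo_linear_imp; graph nodes are recorded only while a GraphContext
    # has graph building switched on.
    y = F.linear(x, w, b)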
@@ -0,0 +1,4 @@
+from .utils import register_colo_graph
+from .graph_node import GraphContext, GraphGlobalEnv, GraphOpNode
+
+__all__ = ['register_colo_graph', 'GraphContext', 'GraphGlobalEnv', 'GraphOpNode']
@@ -74,7 +74,6 @@ class GraphOpNode(GraphNode):
             assert isinstance(colo_tensor, ColoTensor)
             if colo_tensor._graph_node is None:
                 colo_tensor._graph_node = GraphNode()
-
             prev_ops = colo_tensor._graph_node.prev_nodes
             for op_node in prev_ops:
                 self.add_prev_node(op_node)
@@ -85,7 +84,7 @@ class GraphOpNode(GraphNode):
         Op <- Activation (colo_tensor)
         """
         if GraphGlobalEnv().graph_building:
-            assert isinstance(colo_tensor, ColoTensor)
+            assert isinstance(colo_tensor, ColoTensor), f'type {type(colo_tensor)}'
             if colo_tensor._graph_node is None:
                 colo_tensor._graph_node = GraphNode()

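add_prev_tensor (previous hunk) makes an op node inherit the producer ops cached on its input activation's _graph_node, and add_post_tensor registers the op as the producer of its output; together they give each activation a doubly linked view of the surrounding ops. A stripped-down sketch of that linking pattern (names illustrative, not ColossalAI's exact classes):

    class Node:
        def __init__(self):
            self.prev_nodes, self.post_nodes = [], []

        def add_prev_node(self, node):
            # Link both directions, so the graph can be walked forward.
            self.prev_nodes.append(node)
            node.post_nodes.append(self)

    def add_prev_tensor(op, activation_node):
        # Op <- Activation: inherit the ops that produced the input.
        for producer in activation_node.prev_nodes:
            op.add_prev_node(producer)

    def add_post_tensor(op, activation_node):
        # Op -> Activation: record this op as the output's producer.
        activation_node.prev_nodes.append(op)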
@@ -0,0 +1,50 @@
+import functools
+import torch
+from colossalai.tensor import ColoTensor
+from typing import Callable, List
+from colossalai.nn._ops._utils import convert_to_colo_tensor
+
+
+def register_colo_graph(input_pos: List[int], param_pos: List[int]) -> Callable:
+    """register_colo_graph
+    Register an Op (Layer) to ColoGraph.
+    Records the input args of ColoTensor type to the Graph.
+    Args:
+        func (Callable): a function that implements the Op.
+
+    Returns:
+        Callable: wrapper function.
+    """
+
+    def register_colo_graph_decorator(func):
+        from colossalai.nn.graph import GraphOpNode, GraphGlobalEnv
+
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            param_list = []
+            input_list = []
+            for idx, arg in enumerate(args):
+                if isinstance(arg, torch.Tensor) and idx in input_pos:
+                    input_list.append(convert_to_colo_tensor(arg))
+                if isinstance(arg, torch.Tensor) and idx in param_pos:
+                    param_list.append(convert_to_colo_tensor(arg))
+            print(f'Op {func}')
+            # building the computing graph, inputs -> op
+            if GraphGlobalEnv().graph_building:
+                cur_op_node = GraphOpNode('linear', param_list)
+                # TODO supports a list of ColoTensor as args
+                if len(input_list) > 0:
+                    cur_op_node.add_prev_tensor(input_list[0])
+
+            outputs = func(*args, **kwargs)
+
+            # building the computing graph, op -> output
+            if GraphGlobalEnv().graph_building:
+                # TODO supports a list of ColoTensor as args
+                if isinstance(outputs[0], ColoTensor):
+                    cur_op_node.add_post_tensor(outputs[0])
+            return outputs
+
+        return wrapper
+
+    return register_colo_graph_decorator
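The decorator generalizes the hooks that were removed from colo_linear: input_pos and param_pos name, by positional index, which torch.Tensor arguments count as activations and which as parameters (the [1]/[2, 3] indices used for colo_linear_imp above reflect how its call site passes arguments). A hedged usage sketch for hooking another op into the graph; the op itself is hypothetical:

    import torch
    from colossalai.nn.graph import register_colo_graph

    @register_colo_graph(input_pos=[0], param_pos=[1, 2])
    def my_affine_imp(input_tensor: torch.Tensor,
                      weight: torch.Tensor,
                      bias: torch.Tensor):
        # Returned as a tuple so the wrapper's outputs[0] check sees the
        # first result directly.
        return (torch.addmm(bias, input_tensor, weight.t()),)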
@@ -17,6 +17,13 @@ class _DistSpec:
                  dist_placement_pattern: DistPlacementPattern,
                  process_group: Optional[ProcessGroup] = None,
                  **meta_info):
+        """_DistSpec, Distributed Specification
+
+        Args:
+            dist_placement_pattern (DistPlacementPattern): the pattern describing how tensors are distributed among processes.
+                The dist_placement_pattern is picked from a limited set, now including two patterns: replicate and shard.
+            process_group (Optional[ProcessGroup], optional): the process group containing the processes. Defaults to None.
+        """
         self.placement = dist_placement_pattern
         self.process_group = process_group
         for k, v in meta_info.items():
@@ -37,6 +44,7 @@ class _DistSpec:
             res += f'{attr}: {str(getattr(self, attr))}\n\t'
         return res

+
 def replicate(process_group: Optional[ProcessGroup] = None) -> _DistSpec:
     # process_group=None means global process group
     return _DistSpec(DistPlacementPattern.REPLICATE, process_group)
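replicate() is one of the two placement patterns the new docstring names; shard() is the other. A hedged sketch of building both specs in this version, where the shard() signature (process_group, dims, num_partitions) is an assumption based on the surrounding API:

    from colossalai.context import ParallelMode
    from colossalai.core import global_context as gpc
    from colossalai.tensor import distspec

    # Replicated placement; process_group=None means the global group.
    replica_spec = distspec.replicate()

    # Sharded placement: split dim 0 across the 1D tensor-parallel group.
    tp_group = gpc.get_group(ParallelMode.PARALLEL_1D)
    tp_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
    shard_spec = distspec.shard(tp_group, dims=[0], num_partitions=[tp_size])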
@ -1,3 +0,0 @@
|
||||||
from .graph_node import GraphNode, GraphOpNode, GraphContext, GraphGlobalEnv
|
|
||||||
|
|
||||||
__all__ = ['GraphNode', 'GraphOpNode', 'GraphContext', 'GraphGlobalEnv']
|
|
|
@ -1,84 +0,0 @@
|
||||||
import pytest
|
|
||||||
from torch import nn
|
|
||||||
import torch
|
|
||||||
from colossalai.tensor import ColoTensor
|
|
||||||
from colossalai.tensor.graph import GraphContext
|
|
||||||
import gc
|
|
||||||
|
|
||||||
|
|
||||||
class SimpleNet(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.proj1 = nn.Linear(4, 8)
|
|
||||||
self.proj2 = nn.Linear(8, 4)
|
|
||||||
self.proj3 = nn.Linear(4, 4)
|
|
||||||
self.proj4 = nn.Linear(4, 4)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.proj1(x)
|
|
||||||
x = self.proj2(x)
|
|
||||||
x = self.proj3(x)
|
|
||||||
x = self.proj4(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def _visit_graph(start_node):
|
|
||||||
if start_node is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
start_node.print()
|
|
||||||
|
|
||||||
post_node_list = start_node.post_nodes
|
|
||||||
for node in post_node_list:
|
|
||||||
_visit_graph(node)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_tensors():
|
|
||||||
for obj in gc.get_objects():
|
|
||||||
try:
|
|
||||||
if torch.is_tensor(obj):
|
|
||||||
yield obj
|
|
||||||
except Exception as e:
|
|
||||||
print('A trivial exception occured: {}'.format(e))
|
|
||||||
|
|
||||||
|
|
||||||
def _count_tensors():
|
|
||||||
cnt = 0
|
|
||||||
for t in _get_tensors():
|
|
||||||
cnt += 1
|
|
||||||
return cnt
|
|
||||||
|
|
||||||
|
|
||||||
def count_tensors(use_colossal):
|
|
||||||
model = SimpleNet()
|
|
||||||
|
|
||||||
model.eval()
|
|
||||||
with torch.no_grad():
|
|
||||||
if use_colossal:
|
|
||||||
colo_input = ColoTensor.from_torch_tensor(torch.randn(4))
|
|
||||||
graph_ctx = GraphContext()
|
|
||||||
with graph_ctx:
|
|
||||||
output = model(colo_input)
|
|
||||||
output = model(colo_input)
|
|
||||||
ret = _count_tensors()
|
|
||||||
|
|
||||||
_visit_graph(graph_ctx.graph_nodes[0])
|
|
||||||
|
|
||||||
del graph_ctx
|
|
||||||
return ret
|
|
||||||
else:
|
|
||||||
input_t = torch.randn(4)
|
|
||||||
output = model(input_t)
|
|
||||||
output = model(input_t)
|
|
||||||
return _count_tensors()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip
|
|
||||||
# FIXME(ver217)
|
|
||||||
def test_check_activation_tensors():
|
|
||||||
assert count_tensors(False) == count_tensors(True)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
count_tensors(True)
|
|
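The deleted test imported GraphContext from colossalai.tensor.graph, the package this commit removes; GraphContext now ships from colossalai.nn.graph (see the new __init__ above). A minimal sketch of the test's core loop against the new import path:

    import torch
    from torch import nn
    from colossalai.tensor import ColoTensor
    from colossalai.nn.graph import GraphContext

    model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
    colo_input = ColoTensor.from_torch_tensor(torch.randn(4))
    with torch.no_grad():
        graph_ctx = GraphContext()
        with graph_ctx:
            output = model(colo_input)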