[Graph] building computing graph with ColoTensor, Linear only (#917)

pull/920/head
Jiarui Fang 2022-05-07 17:10:37 +08:00 committed by GitHub
parent 75d221918a
commit 845856ea29
6 changed files with 202 additions and 9 deletions

View File

@@ -1,15 +1,14 @@
import torch
from colossalai.tensor.op_wrapper import colo_op_impl
-from colossalai.context import ParallelMode
-from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward, reduce_input, \
-    gather_forward_split_backward, reduce_grad
+from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward, reduce_input, reduce_grad
from colossalai.nn.layer.utils import divide
from colossalai.core import global_context as gpc
from packaging import version
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, ShardPattern
+from colossalai.tensor.graph import GraphOpNode, GraphGlobalEnv

-def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias:ColoTensor) -> ColoTensor:
+def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: ColoTensor) -> ColoTensor:
    parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow_Linear)
    # Input:S[1] x Weight:S[0] = Output:P
    # All-Reduce(Output) + bias = res
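The two comments above describe the row-parallel (TP1DRow) layout for Linear: the input is sharded along its last dimension, the weight along its in_features dimension, each rank produces a partial output, and an all-reduce plus the bias gives the result. A minimal single-process sketch of that arithmetic (plain PyTorch only; the world size, shapes, and the explicit sum that stands in for the all-reduce are illustrative assumptions, not part of this commit):

```python
import torch

torch.manual_seed(0)
world_size = 2
x = torch.randn(3, 8)        # full input, last dim = in_features
weight = torch.randn(4, 8)   # F.linear weight: (out_features, in_features)
bias = torch.randn(4)

# Shard the input along its last dim and the weight along in_features.
x_shards = x.chunk(world_size, dim=-1)
w_shards = weight.chunk(world_size, dim=-1)

# Each "rank" computes a partial output; summing the partials emulates the all-reduce.
partials = [torch.nn.functional.linear(xs, ws) for xs, ws in zip(x_shards, w_shards)]
res = sum(partials) + bias

assert torch.allclose(res, torch.nn.functional.linear(x, weight, bias), atol=1e-5)
```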
@@ -99,20 +98,32 @@ def colo_linear(types, args, kwargs, pg):
    if bias is not None and not isinstance(bias, ColoTensor):
        bias = ColoTensor.init_from_torch_tensor(bias)

+    # building the computing graph, inputs -> op
+    if GraphGlobalEnv().graph_building:
+        cur_op_node = GraphOpNode('linear', [weight, bias])
+        cur_op_node.add_prev_tensor(input_tensor)

    # Add communication logic before and after linear call.
+    ret_tensor = None
    if not weight.has_spec():  # No Model Parallel Applied
        assert not bias.has_spec(), 'Invalid bias spec for native Linear op'
        input_tensor = input_tensor.torch_tensor()
        weight = weight.torch_tensor()
        bias = bias.torch_tensor()
-        return ColoTensor.init_from_torch_tensor(torch.nn.functional.linear(input_tensor, weight, bias))
+        ret_tensor = ColoTensor.init_from_torch_tensor(torch.nn.functional.linear(input_tensor, weight, bias))
    elif weight.shard_spec.num_action == 1:  # Single Model Parallel Applied
        compute_patterns = weight.shard_spec.compute_patterns
        if ComputePattern.TP1DRow_Linear in compute_patterns:
-            return colo_linear_1Drow(input_tensor, weight, bias)
+            ret_tensor = colo_linear_1Drow(input_tensor, weight, bias)
        elif ComputePattern.TP1DCol_Linear in compute_patterns:
-            return colo_linear_1Dcol(input_tensor, weight, bias)
+            ret_tensor = colo_linear_1Dcol(input_tensor, weight, bias)
        else:
            raise NotImplementedError
    else:
        raise NotImplementedError

+    # building the computing graph, op -> output
+    if GraphGlobalEnv().graph_building:
+        cur_op_node.add_post_tensor(ret_tensor)

+    return ret_tensor

View File

@@ -38,6 +38,7 @@ class ColoTensor(object):
        self._shard_spec = shard_spec
        self._shard_pattern = ShardPattern.NA
        self._type = TensorType.NONMODEL
+        self._graph_node = None

    def __getitem__(self, key):
        return ColoTensor.init_from_torch_tensor(self.torch_tensor()[key])

View File

@@ -0,0 +1,3 @@
from .graph_node import GraphNode, GraphOpNode, GraphContext, GraphGlobalEnv

__all__ = ['GraphNode', 'GraphOpNode', 'GraphContext', 'GraphGlobalEnv']

View File

@@ -0,0 +1,97 @@
from colossalai.tensor import ColoTensor
from colossalai.context.singleton_meta import SingletonMeta


class GraphGlobalEnv(metaclass=SingletonMeta):

    def __init__(self) -> None:
        self.graph_building = False
        self.graph_node_list = []
        self.node_id = -1

    def get_node_id(self):
        self.node_id += 1
        return self.node_id

    def add_graph_node(self, node):
        self.graph_node_list.append(node)


class GraphContext():
    """
    Building the computing graph under the context

    >>> with GraphContext():
    >>>     output = model(colo_input_tensor)
    """
    graph_nodes = []

    def __enter__(self):
        GraphGlobalEnv().graph_building = True
        GraphGlobalEnv().graph_node_list = []

    def __exit__(self, *exc_info):
        GraphGlobalEnv().graph_building = False
        GraphGlobalEnv().node_id = -1
        self.graph_nodes = GraphGlobalEnv().graph_node_list


class GraphNode(object):

    def __init__(self) -> None:
        self.prev_nodes = []
        self.post_nodes = []
        self.id = GraphGlobalEnv().get_node_id()

    def add_prev_node(self, node):
        if GraphGlobalEnv().graph_building:
            self.prev_nodes.append(node)

    def add_post_node(self, node):
        if GraphGlobalEnv().graph_building:
            self.post_nodes.append(node)

    def post_node_empty(self) -> bool:
        return len(self.post_nodes) == 0


class GraphOpNode(GraphNode):

    def __init__(self, op_type, param_list) -> None:
        super().__init__()
        self._op_type = op_type
        self._param_list = param_list
        GraphGlobalEnv().add_graph_node(self)

    def add_prev_tensor(self, colo_tensor: ColoTensor):
        r"""
        Link the current graph op node to previous graph op.
        Op1 <- Activation (colo_tensor) Op2
        Op1 <- Op2
        """
        if GraphGlobalEnv().graph_building:
            assert isinstance(colo_tensor, ColoTensor)
            if colo_tensor._graph_node is None:
                colo_tensor._graph_node = GraphNode()
            prev_ops = colo_tensor._graph_node.prev_nodes
            for op_node in prev_ops:
                self.add_prev_node(op_node)
                op_node.add_post_node(self)

    def add_post_tensor(self, colo_tensor: ColoTensor):
        """
        Op <- Activation (colo_tensor)
        """
        if GraphGlobalEnv().graph_building:
            assert isinstance(colo_tensor, ColoTensor)
            if colo_tensor._graph_node is None:
                colo_tensor._graph_node = GraphNode()
            colo_tensor._graph_node.add_prev_node(self)

    def print(self):
        print(
            f'GraphOpNode {self._op_type} {self.id}, post nodes {[node.id for node in self.post_nodes]}, prev node number {[node.id for node in self.prev_nodes]}'
        )
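The op-to-op edges in `add_prev_tensor` / `add_post_tensor` above are wired indirectly through the activation: the producer op registers itself in the tensor's `_graph_node.prev_nodes`, and a later consumer reads that list to link itself after the producer. A minimal hedged sketch of that linking (it assumes colossalai at this commit is importable; the names `op1`, `op2`, `x`, `y` are illustrative only, not from the commit):

```python
import torch
from colossalai.tensor import ColoTensor
from colossalai.tensor.graph import GraphContext, GraphOpNode

x = ColoTensor.init_from_torch_tensor(torch.randn(4))
y = ColoTensor.init_from_torch_tensor(torch.randn(4))

ctx = GraphContext()
with ctx:
    op1 = GraphOpNode('op1', [])
    op1.add_prev_tensor(x)   # x has no producer yet, so op1 gets no prev edge
    op1.add_post_tensor(y)   # records op1 as the producer of y

    op2 = GraphOpNode('op2', [])
    op2.add_prev_tensor(y)   # reads y's producer list and links op1 -> op2

assert op1.post_nodes[0] is op2
assert op2.prev_nodes[0] is op1
```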

View File

@@ -0,0 +1,81 @@
from torch import nn
import torch
from colossalai.tensor import ColoTensor
from colossalai.tensor.graph import GraphContext
import gc


class SimpleNet(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.proj1 = nn.Linear(4, 8)
        self.proj2 = nn.Linear(8, 4)
        self.proj3 = nn.Linear(4, 4)
        self.proj4 = nn.Linear(4, 4)

    def forward(self, x):
        x = self.proj1(x)
        x = self.proj2(x)
        x = self.proj3(x)
        x = self.proj4(x)
        return x


def _visit_graph(start_node):
    if start_node is None:
        return
    start_node.print()
    post_node_list = start_node.post_nodes
    for node in post_node_list:
        _visit_graph(node)


def _get_tensors():
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                yield obj
        except Exception as e:
            print('A trivial exception occured: {}'.format(e))


def _count_tensors():
    cnt = 0
    for t in _get_tensors():
        cnt += 1
    return cnt


def count_tensors(use_colossal):
    model = SimpleNet()
    model.eval()
    with torch.no_grad():
        if use_colossal:
            colo_input = ColoTensor.init_from_torch_tensor(torch.randn(4))
            graph_ctx = GraphContext()
            with graph_ctx:
                output = model(colo_input)
            output = model(colo_input)
            ret = _count_tensors()
            _visit_graph(graph_ctx.graph_nodes[0])
            del graph_ctx
            return ret
        else:
            input_t = torch.randn(4)
            output = model(input_t)
            output = model(input_t)
            return _count_tensors()


def test_check_activation_tensors():
    assert count_tensors(False) == count_tensors(True)


if __name__ == "__main__":
    count_tensors(True)

View File

@@ -26,7 +26,7 @@ from dataclasses import fields
def _post_init_colo(self):
    class_fields = fields(self)
    # Safety and consistency checks
-    if not len(class_fields):
+    if len(class_fields) == 0:
        raise ValueError(f"{self.__class__.__name__} has no fields.")
    if not all(field.default is None for field in class_fields[1:]):
        raise ValueError(f"{self.__class__.__name__} should not have more than one required field.")

@@ -361,7 +361,7 @@ def run_dist(rank, world_size, port):
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    for name in ['simple_net']:
        run_1d_row_tp(name)
    for name in ['bert', 'simple_net']:
        run_1d_hybrid_tp(name)