mirror of https://github.com/hpcaitech/ColossalAI
[fx]get communication size between partitions (#1224)
* [CLI] add CLI launcher
* Revert "[CLI] add CLI launcher"
This reverts commit df7e6506d4.
* [fx]get communication size between partitions.
* polish
parent 4951f7d80c
commit 2b7dca44b5
@@ -0,0 +1,101 @@ colossalai/fx/passes/meta_info_prop.py
import torch
import torch.fx
from torch.fx.node import Node, map_aggregate
from typing import Any, Tuple, NamedTuple, Optional, Dict
from functools import reduce
from torch.fx._compatibility import compatibility

@compatibility(is_backward_compatible=True)
class TensorMetadata(NamedTuple):
    # TensorMetadata is a structure containing pertinent information
    # about a tensor within a PyTorch program.

    shape: torch.Size
    dtype: torch.dtype
    requires_grad: bool
    stride: Tuple[int]
    numel: int
    # TODO: we can add a list of sharding specs here and record the
    # sharding behaviour by appending each sharding spec to the list.


def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata:
    """
    Extract a TensorMetadata NamedTuple describing `result`.
    """
    shape = result.shape
    dtype = result.dtype
    requires_grad = result.requires_grad
    stride = result.stride()
    numel = result.numel()

    return TensorMetadata(shape, dtype, requires_grad, stride, numel)
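
A quick illustration of the helper above (a hedged sketch; the values shown are what a fresh, contiguous 2x4 float tensor produces):

meta = _extract_tensor_metadata(torch.rand(2, 4))
# -> TensorMetadata(shape=torch.Size([2, 4]), dtype=torch.float32,
#    requires_grad=False, stride=(4, 1), numel=8)
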

@compatibility(is_backward_compatible=True)
class MetaInfoProp(torch.fx.Interpreter):
    """
    Execute an FX graph Node-by-Node and record the shape and type of
    the result into the corresponding node.

    Usage:
        BATCH_SIZE = 2
        DIM_IN = 4
        DIM_OUT = 16
        model = torch.nn.Linear(DIM_IN, DIM_OUT)
        input_sample = torch.rand(BATCH_SIZE, DIM_IN)
        orig_output = model(input_sample)
        gm = symbolic_trace(model)
        MetaInfoProp(gm).run(input_sample)

        for node in gm.graph.nodes:
            print(node.name, node.meta['tensor_meta'].dtype,
                  node.meta['tensor_meta'].shape, node.meta['tensor_meta'].numel)

        # output of the above code is
        # input_1 torch.float32 torch.Size([2, 4]) 8
        # weight torch.float32 torch.Size([16, 4]) 64
        # bias torch.float32 torch.Size([16]) 16
        # linear torch.float32 torch.Size([2, 16]) 32
        # output torch.float32 torch.Size([2, 16]) 32

    Args:
        module (GraphModule): The module to be executed
    """

    def run_node(self, n: Node) -> Any:
        result = super().run_node(n)

        found_tensor = False

        def extract_tensor_meta(obj):
            if isinstance(obj, torch.Tensor):
                nonlocal found_tensor
                found_tensor = True
                return _extract_tensor_metadata(obj)
            else:
                return obj

        meta = map_aggregate(result, extract_tensor_meta)
        if found_tensor:
            n.meta['tensor_meta'] = meta
        else:
            n.meta['tensor_meta'] = TensorMetadata(None, None, False, None, 0)

        n.meta['type'] = type(result)
        return result

    def propagate(self, *args):
        """
        Run `module` via interpretation, record the shape and type of
        each node, and return the execution result.

        Args:
            *args (Tensor): the sample input.

        Returns:
            Any: The value returned from executing the Module
        """
        return super().run(*args)
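
A hedged usage sketch of the annotations recorded above (it mirrors the class docstring; all names come from this file): propagate() behaves exactly like run() while annotating every node, and run_node() also records each result's Python type under node.meta['type'].

gm = torch.fx.symbolic_trace(torch.nn.Linear(4, 16))
output = MetaInfoProp(gm).propagate(torch.rand(2, 4))
for node in gm.graph.nodes:
    print(node.name, node.meta['type'], node.meta['tensor_meta'].shape)
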
@@ -0,0 +1,27 @@ colossalai/fx/passes/utils.py
import torch
from typing import Dict, Set
from torch.fx.node import Node, map_arg

def get_comm_size(prev_partition, next_partition):
    """Given two partitions (parent and child),
    calculate the communication size between the two.
    """
    # Track the communication size between parent and child.
    comm_size = 0
    # Track all nodes that have already been counted.
    visited_nodes = set()
    # Go through all nodes in the child partition. If a node has input
    # nodes from the parent partition, the output size of those input
    # nodes will be counted and added to comm_size.
    parent_node_names = [n.name for n in prev_partition.graph.nodes]
    for node in next_partition.graph.nodes:
        input_nodes: Dict[Node, None] = {}
        map_arg(node.args, lambda n: input_nodes.setdefault(n))
        map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
        for n in input_nodes:
            if n.name in parent_node_names and n not in visited_nodes:
                comm_size += n.meta['tensor_meta'].numel
                visited_nodes.add(n)
    return comm_size
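
A minimal follow-on sketch (total_comm_size is a hypothetical helper, on the assumption that the split GraphModule's direct children are the pipeline partitions in execution order): summing the traffic over every adjacent pair of partitions.

def total_comm_size(split_gm):
    partitions = list(split_gm.children())
    return sum(get_comm_size(a, b) for a, b in zip(partitions, partitions[1:]))
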

@@ -0,0 +1,46 @@ (test: communication size between partitions)
import torch
import torch.nn as nn
import colossalai
import colossalai.nn as col_nn
from torch.fx import symbolic_trace
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, uniform_split_pass
from colossalai.fx.passes.utils import get_comm_size

MODEL_DIM = 16
BATCH_SIZE = 8
PIPELINE_SIZE = 2


class MLP(torch.nn.Module):

    def __init__(self, dim: int):
        super().__init__()
        self.linear1 = torch.nn.Linear(dim, dim)
        self.linear2 = torch.nn.Linear(dim, dim)
        self.linear3 = torch.nn.Linear(dim, dim)
        self.linear4 = torch.nn.Linear(dim, dim)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)
        x = self.linear4(x)
        return x


def test_comm_size_compute():
    model = MLP(MODEL_DIM)
    input_sample = torch.rand(BATCH_SIZE, MODEL_DIM)
    gm = symbolic_trace(model)
    MetaInfoProp(gm).run(input_sample)
    annotated_model = uniform_split_pass(gm, PIPELINE_SIZE)
    split_model, split_submodules = split_with_split_nodes_pass(annotated_model)
    submodule_list = list(split_model.children())
    comm_size = get_comm_size(submodule_list[0], submodule_list[1])
    # The tensor sent from partition 0 to partition 1 has shape (8, 16),
    # so the communication size is 8 * 16 = 128 elements.
    assert comm_size == 128


if __name__ == '__main__':
    test_comm_size_compute()
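
A hedged variant of the test above (illustrative, not part of the original commit): with one partition per linear layer, each boundary is expected to carry a single (BATCH_SIZE, MODEL_DIM) activation, i.e. 8 * 16 = 128 elements. It prints rather than asserts, since the exact placement of split nodes is an assumption here.

def check_comm_size_four_way_split():
    model = MLP(MODEL_DIM)
    gm = symbolic_trace(model)
    MetaInfoProp(gm).run(torch.rand(BATCH_SIZE, MODEL_DIM))
    annotated_model = uniform_split_pass(gm, 4)
    split_model, _ = split_with_split_nodes_pass(annotated_model)
    partitions = list(split_model.children())
    for prev_p, next_p in zip(partitions, partitions[1:]):
        print(get_comm_size(prev_p, next_p))  # expected: 128 per boundary (assumption)
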
@@ -0,0 +1,35 @@ (test: MetaInfoProp)
import torch
import torch.nn as nn
import colossalai
import colossalai.nn as col_nn
from torch.fx import symbolic_trace
from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata

BATCH_SIZE = 2
DIM_IN = 4
DIM_OUT = 16


def meta_check(meta_info_spec: TensorMetadata, orig_tensor: torch.Tensor):
    assert meta_info_spec.shape == orig_tensor.shape
    assert meta_info_spec.dtype == orig_tensor.dtype
    assert meta_info_spec.requires_grad == orig_tensor.requires_grad
    assert meta_info_spec.stride == orig_tensor.stride()
    assert meta_info_spec.numel == orig_tensor.numel()


def test_meta_info_prop():
    model = torch.nn.Linear(DIM_IN, DIM_OUT)
    input_sample = torch.rand(BATCH_SIZE, DIM_IN)
    orig_output = model(input_sample)
    gm = symbolic_trace(model)
    MetaInfoProp(gm).run(input_sample)
    for node in gm.graph.nodes:
        if node.op == 'placeholder':
            meta_check(node.meta['tensor_meta'], input_sample)
        if node.op == 'output':
            meta_check(node.meta['tensor_meta'], orig_output)


if __name__ == '__main__':
    test_meta_info_prop()
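
MetaInfoProp also stores each result's Python type in node.meta['type'] (set in run_node above), so a hedged extension of this test (type_check is an illustrative name) could assert:

def type_check(gm: torch.fx.GraphModule):
    for node in gm.graph.nodes:
        if node.op in ('placeholder', 'output'):
            assert node.meta['type'] is torch.Tensor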