mirror of https://github.com/hpcaitech/ColossalAI
[fx]get communication size between partitions (#1224)
* [CLI] add CLI launcher
* Revert "[CLI] add CLI launcher"
This reverts commit df7e6506d4
.
* [fx]get communication size between partitions.
* polish
pull/1222/head^2
parent
4951f7d80c
commit
2b7dca44b5
|
@ -0,0 +1,101 @@
|
|||
import torch
|
||||
import torch.fx
|
||||
from torch.fx.node import Node, map_aggregate
|
||||
from typing import Any, Tuple, NamedTuple, Optional, Dict
|
||||
from functools import reduce
|
||||
from torch.fx._compatibility import compatibility
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=True)
class TensorMetadata(NamedTuple):
    """Pertinent information about a tensor within a PyTorch program.

    ``shape``, ``dtype`` and ``stride`` are ``None`` (and ``numel`` is 0) in
    the placeholder record that ``MetaInfoProp`` stores for nodes whose
    result contains no tensor.
    """

    # Size of every dimension of the tensor.
    shape: Optional[torch.Size]
    # Element type of the tensor.
    dtype: Optional[torch.dtype]
    # Whether autograd tracks the tensor.
    requires_grad: bool
    # One stride per dimension, hence a variable-length tuple.
    stride: Optional[Tuple[int, ...]]
    # Total number of elements in the tensor.
    numel: int
    # TODO: we can add a list of sharding spec here, and record the sharding
    # behaviour by appending sharding spec into list.
|
||||
|
||||
|
||||
def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata:
    """
    Build the TensorMetadata record for the tensor `result`.

    The record captures the tensor's shape, dtype, requires_grad flag,
    stride and element count.
    """
    return TensorMetadata(
        shape=result.shape,
        dtype=result.dtype,
        requires_grad=result.requires_grad,
        stride=result.stride(),
        numel=result.numel(),
    )
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=True)
class MetaInfoProp(torch.fx.Interpreter):
    """
    Interpret an FX graph one Node at a time and store, on every node,
    the metadata of the value that node produced.

    After a run each node carries two entries in ``node.meta``:
    ``'tensor_meta'`` (a TensorMetadata, or an aggregate with the same
    structure as the result in which tensors are replaced by
    TensorMetadata) and ``'type'`` (the Python type of the result).

    Usage:
        BATCH_SIZE = 2
        DIM_IN = 4
        DIM_OUT = 16
        model = torch.nn.Linear(DIM_IN, DIM_OUT)
        input_sample = torch.rand(BATCH_SIZE, DIM_IN)
        orig_output = model(input_sample)
        gm = symbolic_trace(model)
        MetaInfoProp(gm).run(input_sample)

        for node in gm.graph.nodes:
            print(node.name, node.meta['tensor_meta'].dtype,
                  node.meta['tensor_meta'].shape, node.meta['tensor_meta'].numel)

        # output of above code is
        # input_1 torch.float32 torch.Size([2, 4]) 8
        # weight torch.float32 torch.Size([16, 4]) 64
        # bias torch.float32 torch.Size([16]) 16
        # linear torch.float32 torch.Size([2, 16]) 32
        # output torch.float32 torch.Size([2, 16]) 32

    Args:
        module (GraphModule): The module to be executed
    """

    def run_node(self, n: Node) -> Any:
        """Execute node `n`, annotate `n.meta`, and return the node's result."""
        result = super().run_node(n)

        # Filled with one marker per tensor encountered while walking the
        # (possibly nested) result; an empty list means "no tensor at all".
        tensors_seen = []

        def to_metadata(obj):
            if not isinstance(obj, torch.Tensor):
                return obj
            tensors_seen.append(True)
            return _extract_tensor_metadata(obj)

        converted = map_aggregate(result, to_metadata)
        # Results without any tensor get an empty placeholder record so that
        # 'tensor_meta' is present on every node.
        n.meta['tensor_meta'] = converted if tensors_seen else TensorMetadata(None, None, False, None, 0)
        n.meta['type'] = type(result)
        return result

    def propagate(self, *args):
        """
        Interpret `module` on the sample input, recording the shape and
        type of each node along the way.

        Args:
            *args (Tensor): the sample input.

        Returns:
            Any: The value returned from executing the Module
        """
        return super().run(*args)
|
|
@ -0,0 +1,27 @@
|
|||
import torch
|
||||
from typing import Dict, Set
|
||||
from torch.fx.node import Node, map_arg
|
||||
|
||||
|
||||
def get_comm_size(prev_partition, next_partition):
    """Compute the communication size between two partitions.

    Every node of ``next_partition`` (the child) is inspected; each of its
    input nodes whose name also appears in ``prev_partition`` (the parent)
    contributes the ``numel`` recorded in its ``meta['tensor_meta']``.
    Each such input node is counted once, even when several child nodes
    consume it.

    Args:
        prev_partition: the parent partition; must expose ``graph.nodes``.
        next_partition: the child partition; must expose ``graph.nodes``.

    Returns:
        int: total number of elements sent from the parent to the child.

    Note:
        Counted input nodes must carry ``meta['tensor_meta']`` with a
        ``numel`` attribute (as written by ``MetaInfoProp`` for
        single-tensor results).
    """
    # Running total of the communication size between parent and child.
    comm_size = 0
    # Input nodes already accounted for, so shared inputs are not counted twice.
    visited_nodes: Set[Node] = set()
    # A set gives O(1) membership tests inside the loops below.
    parent_node_names = {n.name for n in prev_partition.graph.nodes}
    for node in next_partition.graph.nodes:
        # Collect every Node referenced from args/kwargs; a dict preserves
        # first-seen order while de-duplicating.
        input_nodes: Dict[Node, None] = {}
        map_arg(node.args, lambda n: input_nodes.setdefault(n))
        map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
        for n in input_nodes:
            if n.name in parent_node_names and n not in visited_nodes:
                comm_size += n.meta['tensor_meta'].numel
                visited_nodes.add(n)
    return comm_size
|
|
@ -0,0 +1,46 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import colossalai
|
||||
import colossalai.nn as col_nn
|
||||
from torch.fx import symbolic_trace
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, uniform_split_pass
|
||||
from colossalai.fx.passes.utils import get_comm_size
|
||||
|
||||
MODEL_DIM = 16
|
||||
BATCH_SIZE = 8
|
||||
PIPELINE_SIZE = 2
|
||||
|
||||
|
||||
class MLP(torch.nn.Module):
    """Four square linear layers (``linear1`` .. ``linear4``) applied in sequence."""

    def __init__(self, dim: int):
        super().__init__()
        # Register linear1..linear4 in ascending order so that parameter
        # initialisation and state-dict layout follow the layer numbering.
        for index in range(1, 5):
            setattr(self, f'linear{index}', torch.nn.Linear(dim, dim))

    def forward(self, x):
        """Feed `x` through linear1, linear2, linear3 and linear4 in turn."""
        return self.linear4(self.linear3(self.linear2(self.linear1(x))))
|
||||
|
||||
|
||||
def test_comm_size_compute():
    """Split a traced MLP into pipeline stages and check the size of the
    tensor handed from the first stage to the second."""
    mlp = MLP(MODEL_DIM)
    sample = torch.rand(BATCH_SIZE, MODEL_DIM)
    traced = symbolic_trace(mlp)
    # Annotate every node with tensor metadata before the graph is split.
    MetaInfoProp(traced).run(sample)
    split_model, _ = split_with_split_nodes_pass(uniform_split_pass(traced, PIPELINE_SIZE))
    stages = list(split_model.children())
    # the shape of tensor send from partition 0 to partition 1 is (8, 16)
    assert get_comm_size(stages[0], stages[1]) == BATCH_SIZE * MODEL_DIM
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_comm_size_compute()
|
|
@ -0,0 +1,35 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import colossalai
|
||||
import colossalai.nn as col_nn
|
||||
from torch.fx import symbolic_trace
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata
|
||||
|
||||
BATCH_SIZE = 2
|
||||
DIM_IN = 4
|
||||
DIM_OUT = 16
|
||||
|
||||
|
||||
def meta_check(meta_info_spec: TensorMetadata, orig_tensor: torch.Tensor):
    """Assert that every field of `meta_info_spec` matches `orig_tensor`."""
    # Field name -> value read from the real tensor, in the order the
    # fields are compared.
    expected = {
        'shape': orig_tensor.shape,
        'dtype': orig_tensor.dtype,
        'requires_grad': orig_tensor.requires_grad,
        'stride': orig_tensor.stride(),
        'numel': orig_tensor.numel(),
    }
    for field, value in expected.items():
        assert getattr(meta_info_spec, field) == value
|
||||
|
||||
|
||||
def test_meta_info_prop():
    """Run MetaInfoProp over a traced Linear layer and compare the metadata
    stored on the placeholder and output nodes with the real tensors."""
    linear = torch.nn.Linear(DIM_IN, DIM_OUT)
    sample = torch.rand(BATCH_SIZE, DIM_IN)
    reference_output = linear(sample)
    traced = symbolic_trace(linear)
    MetaInfoProp(traced).run(sample)
    # Only these two node kinds have a known reference tensor to compare with.
    expected_by_op = {'placeholder': sample, 'output': reference_output}
    for node in traced.graph.nodes:
        if node.op in expected_by_op:
            meta_check(node.meta['tensor_meta'], expected_by_op[node.op])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_meta_info_prop()
|
Loading…
Reference in New Issue