[autoparallel] find repeat blocks (#2854)

* [autoparallel] find repeat blocks

* polish

* polish

* polish
pull/2912/head
YuliangLiu0306 2023-02-23 17:28:19 +08:00 committed by GitHub
parent 2e16f842a9
commit 0f392d7403
3 changed files with 229 additions and 3 deletions


@@ -1,13 +1,16 @@
import copy
import operator
import warnings
from functools import reduce
from typing import Dict, List, Optional, Union

import torch
from torch.fx.node import Node
from torch.utils._pytree import tree_map

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec

from ..constants import INFINITY_COST
@@ -88,3 +91,116 @@ def generate_resharding_costs(nodes: List[Node],
                    resharding_cost = INFINITY_COST
                resharding_costs[input_node].append(resharding_cost)
    return resharding_costs

def find_repeat_blocks(node_list: List[torch.fx.Node], root_module, common_length_threshold: int = 20):
    '''
    Find the largest repeat blocks in the graph whose length is larger than the threshold.

    Args:
        node_list (List[torch.fx.Node]): the nodes of the graph to be analyzed.
        root_module: the root module owning the graph, used to resolve call_module targets.
        common_length_threshold (int): the minimum length of a repeat block.
    '''

    def _process_args(args):
        # normalize node arguments into hashable metadata (tensor shapes, slice fields, constants)
        new_args = []
        for arg in args:
            if hasattr(arg, '_meta_data'):
                meta_data = arg._meta_data
            else:
                meta_data = arg

            def _process_arg(data):
                if isinstance(data, torch.Tensor):
                    data = data.size()
                elif isinstance(data, slice):
                    data = (data.start, data.step, data.stop)
                return data

            new_meta_data = tree_map(_process_arg, meta_data)
            new_args.append(new_meta_data)

        return new_args

    def _all_equal(check_list, check_fn):
        base_value = check_list[-1]
        for e in check_list:
            if not check_fn(e, base_value):
                return False
        return True

    def _check_node_list_equal(l1, l2):
        if len(l1) != len(l2):
            return False
        for node1, node2 in zip(l1, l2):
            if hash(node1.hash_key) != hash(node2.hash_key):
                return False
        return True

    def _check_node_equal(node1, node2):
        return hash(node1.hash_key) == hash(node2.hash_key)

    # assign every node a structural hash key: (op, target or submodule type, processed args)
    for index, node in enumerate(node_list):
        if node.op == 'call_module':
            target = node.target
            submod = root_module.get_submodule(target)
            submod_type = type(submod)
            target = submod_type
        else:
            target = node.target

        new_args = _process_args(node.args)

        if node.op != 'get_attr':
            hash_key = (node.op, target, *new_args)
        else:
            hash_key = (node.op,)
        setattr(node, 'hash_key', hash_key)

    # group node indices by hash value so repeated nodes can be located quickly
    hash_value_to_node_dict = {}
    for index, node in enumerate(node_list):
        hash_value = hash(node.hash_key)
        if hash_value not in hash_value_to_node_dict:
            hash_value_to_node_dict[hash_value] = []
        hash_value_to_node_dict[hash_value].append(index)

    node_list_start = 0
    max_common_length = common_length_threshold
    common_blocks_index = []
    for index, node in enumerate(node_list):
        # the comparison will be triggered if a common node appears
        if len(hash_value_to_node_dict[hash(node.hash_key)]) >= 2:
            start_index_list = hash_value_to_node_dict[hash(node.hash_key)]
            check_block_list = [node_list[start:start + max_common_length] for start in start_index_list]
            common_label = True
            if not _all_equal(check_block_list, _check_node_list_equal):
                common_label = False

            if common_label:
                common_blocks_index = copy.deepcopy(start_index_list)
                # max_step bounds the extension so the last block never runs past the end of node_list
                max_step = len(node_list) - common_blocks_index[-1] - max_common_length - 1
                for i in range(max_step):
                    next_node_list = [node_list[index + max_common_length + i] for index in start_index_list]
                    if not _all_equal(next_node_list, _check_node_equal):
                        max_step = i
                        break
                max_common_length += max_step
                node_list_start += max_common_length

    # recover the common subgraphs from the recorded start indices
    common_blocks = []
    for start in common_blocks_index:
        common_blocks.append(node_list[start:start + max_common_length])

    return common_blocks
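
A minimal usage sketch of the helper, mirroring what the test below does; the model class and input shape here are placeholders and not part of this commit:

import torch
from torch.fx import GraphModule

from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks
from colossalai.fx.tracer.tracer import ColoTracer

model = MyRepeatModel()    # placeholder: any nn.Module with repeated sub-blocks
tracer = ColoTracer()
# trace on meta tensors so no real memory is allocated
meta_args = {'x': torch.rand(1, 32, 384).to('meta')}
graph = tracer.trace(root=model, meta_args=meta_args)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()

# hand the flattened node list and the owning module to the helper;
# each returned block is a list of torch.fx nodes, all blocks of the same length
node_list = list(graph.nodes)
common_blocks = find_repeat_blocks(node_list, graph.owning_module, common_length_threshold=10)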


@@ -0,0 +1,110 @@
from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch.fx import GraphModule
from transformers.pytorch_utils import Conv1D

from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.testing import parameterize
from colossalai.testing.pytest_wrapper import run_on_environment_flag

NUM_REPEAT_BLOCKS = 4
BATCH_SIZE = 1
SEQ_LENGTH = 32
HIDDEN_DIM = 384

class RepeatBlock(nn.Module):

    def __init__(self, intermediate_size, hidden_size):
        super().__init__()
        self.c_fc = Conv1D(intermediate_size, hidden_size)
        self.c_proj = Conv1D(hidden_size, intermediate_size)
        self.act = torch.nn.ReLU()

    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        return hidden_states


class RepeatModel(nn.Module):

    def __init__(self, intermediate_size, hidden_size, num_layers):
        super().__init__()
        self.blocks = nn.ModuleList([RepeatBlock(intermediate_size, hidden_size) for i in range(num_layers)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


class NonRepeatBlock(nn.Module):

    def __init__(self, intermediate_size, hidden_size, layer_index):
        super().__init__()
        # shrink the intermediate size per layer so the blocks are not identical
        intermediate_size //= (layer_index + 1)
        self.c_fc = Conv1D(intermediate_size, hidden_size)
        self.c_proj = Conv1D(hidden_size, intermediate_size)
        self.act = torch.nn.ReLU()

    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        return hidden_states


class NonRepeatModel(nn.Module):

    def __init__(self, intermediate_size, hidden_size, num_layers):
        super().__init__()
        self.blocks = nn.ModuleList([NonRepeatBlock(intermediate_size, hidden_size, i) for i in range(num_layers)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


@run_on_environment_flag(name='AUTO_PARALLEL')
@parameterize('model_cls', [RepeatModel, NonRepeatModel])
def test_repeat_blocks(model_cls):
    model = model_cls(4 * HIDDEN_DIM, HIDDEN_DIM, NUM_REPEAT_BLOCKS)

    tracer = ColoTracer()
    input_sample = {'x': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta')}
    graph = tracer.trace(root=model, meta_args=input_sample)
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()

    node_list = list(graph.nodes)
    root_module = graph.owning_module
    common_blocks = find_repeat_blocks(node_list, root_module, common_length_threshold=10)

    total_num_nodes = len(list(graph.nodes))
    # exclude the input placeholder node and the output node
    num_repeat_nodes_per_block = (total_num_nodes - 2) // NUM_REPEAT_BLOCKS

    for common_block in common_blocks:
        print(common_block)

    if model_cls == RepeatModel:
        assert len(common_blocks) == NUM_REPEAT_BLOCKS
        assert len(common_blocks[0]) == num_repeat_nodes_per_block
    elif model_cls == NonRepeatModel:
        assert len(common_blocks) == 0


if __name__ == '__main__':
    test_repeat_blocks()