[autoparallel] add strategies constructor (#1505)

* [autoparallel] add strategies constructor

* remove duplicated strategies

* polish code

* adapt cost graph with StrategiesConstructor

* polish
pull/1522/head
YuliangLiu0306 2022-08-30 16:32:09 +08:00 committed by GitHub
parent a0436a62ee
commit 3345c6d352
8 changed files with 633 additions and 32 deletions

View File

@@ -0,0 +1,22 @@
import torch
import operator
__all__ = [
'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP', 'LINEAR_MODULE_OP',
'LINEAR_FUNC_OP'
]
ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
ELEMENTWISE_FUNC_OP = [
torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv,
operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout
]
CONV_MODULE_OP = [
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
torch.nn.ConvTranspose3d
]
CONV_FUNC_OP = [
torch.conv1d, torch.conv2d, torch.conv3d, torch.conv_transpose1d, torch.conv_transpose2d, torch.conv_transpose3d
]
LINEAR_MODULE_OP = [torch.nn.Linear]
LINEAR_FUNC_OP = [torch.nn.functional.linear, torch.matmul, torch.bmm]
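These lists drive how traced torch.fx nodes are routed to their handlers in the strategies constructor below. A minimal dispatch sketch using the lists defined above (the helper name classify_node is illustrative only, not part of this PR):

def classify_node(node, root_module):
    # call_module targets are resolved to their module class before matching against the lists.
    if node.op == 'call_module':
        submod_type = type(root_module.get_submodule(node.target))
        if submod_type in CONV_MODULE_OP:
            return 'conv'
        if submod_type in LINEAR_MODULE_OP:
            return 'linear'
        if submod_type in ELEMENTWISE_MODULE_OP:
            return 'elementwise'
    elif node.op == 'call_function':
        if node.target in CONV_FUNC_OP:
            return 'conv'
        if node.target in LINEAR_FUNC_OP:
            return 'linear'
        if node.target in ELEMENTWISE_FUNC_OP:
            return 'elementwise'
    return 'unsupported'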

View File

@@ -494,3 +494,10 @@ class ConvHandler(OperatorHandler):
self.split_1d_parallel_on_in_channel(0, 1)
return self.strategies_vector
CONV_STRATEGIES_LIST = [
'S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0R x RR', 'S1R = S1R x RR', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R',
'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RR = RS0 x S0R', 'RR = RS1 x S1R', 'RS0 = RR x RS0', 'RS1 = RR x RS1',
'RR = RR x RR', 'S01R = S01R x RR', 'RR = RS01 x S01R'
]
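Each name in CONV_STRATEGIES_LIST reads as output = input x weight, where S0/S1 mean sharded along device-mesh dimension 0/1, S01 means sharded along both mesh dimensions, and R means replicated. A small illustrative parser for this naming scheme (hypothetical helper, not part of this PR):

def parse_strategy_name(name: str):
    # 'S0S1 = S0R x RS1' -> ('S0S1', 'S0R', 'RS1')
    output_part, operands = name.split(' = ')
    input_part, weight_part = operands.split(' x ')
    return output_part, input_part, weight_part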

View File

@@ -39,11 +39,11 @@ class CostGraph:
dst_node = strategies_vector.node
for src_node in strategies_vector.predecessor_nodes:
node_pair = (src_node, dst_node)
src_index = strategies_vector.predecessor_nodes.index(src_node)
# src_index = strategies_vector.predecessor_nodes.index(src_node)
edge_cost = {}
for i in range(len(strategies_vector)):
for j in range(len(src_node.stategy_vector)):
edge_cost[(i, j)] = strategies_vector[i].resharding_costs[src_index][j]
for j in range(len(src_node.strategies_vector)):
edge_cost[(j, i)] = strategies_vector[i].resharding_costs[src_node][j]
self.edge_costs[node_pair] = edge_cost
# add parents and children attribute to node
setattr(dst_node, 'parents', strategies_vector.predecessor_nodes)
@@ -83,33 +83,19 @@ class CostGraph:
merge_map = {}
for dst_strate_index, strategy in enumerate(dst_node.strategies_vector):
resharding_costs = strategy.resharding_costs
resharding_cost_for_src = resharding_costs[src_node_index]
resharding_cost_for_src = resharding_costs[src_node]
lowest_cost_index = resharding_cost_for_src.index(min(resharding_cost_for_src))
merge_map[dst_strate_index] = lowest_cost_index
# extra_node_cost for dst node
extra_node_costs[dst_node] = [0.0 for _ in range(self.node_lens[dst_node])]
self.extra_node_costs[dst_node] = [0.0 for _ in range(self.node_lens[dst_node])]
for dst_strate_index, strategy in enumerate(dst_node.strategies_vector):
target_strate_index = merge_map[dst_strate_index]
extra_node_costs[dst_node][dst_strate_index] += strategy.resharding_costs[src_node_index][
self.extra_node_costs[dst_node][dst_strate_index] += strategy.resharding_costs[src_node][
target_strate_index]
if src_node in extra_node_costs:
extra_node_costs[dst_node][dst_strate_index] += extra_node_costs[src_node][target_strate_index]
# connect dst node and parents of src node
dst_node.parents.remove(src_node)
src_node.children.remove(dst_node)
node_pair_to_remove = [(src_node, dst_node)]
for parent_node in src_node.parents:
if parent_node not in dst_node.parents:
dst_node.parents.append(parent)
if dst_node not in parent_node.children:
parent_node.children.append(dst_node)
# remove src node from cost graph when src node has no consumer.
if len(src_node.children) == 0:
parent_node.children.remove(src_node)
node_pair = (parent_node, src_node)
self.edge_costs.pop(node_pair)
if src_node in self.extra_node_costs:
self.extra_node_costs[dst_node][dst_strate_index] += self.extra_node_costs[src_node][
target_strate_index]
# add new node pair to cost graph
for parent_node in src_node.parents:
@@ -121,9 +107,24 @@ class CostGraph:
for i in range(self.node_lens[dst_node]):
for j in range(self.node_lens[parent_node]):
src_strate_index = merge_map[i]
edge_cost[(i, j)] = self.edge_costs[old_node_pair][(j, src_strate_index)]
edge_cost[(j, i)] = self.edge_costs[old_node_pair][(j, src_strate_index)]
self.edge_costs[new_node_pair] = edge_cost
# connect dst node and parents of src node
dst_node.parents.remove(src_node)
src_node.children.remove(dst_node)
self.edge_costs.pop((src_node, dst_node))
for parent_node in src_node.parents:
if parent_node not in dst_node.parents:
dst_node.parents.append(parent_node)
if dst_node not in parent_node.children:
parent_node.children.append(dst_node)
# remove src node from cost graph when src node has no consumer.
if len(src_node.children) == 0:
parent_node.children.remove(src_node)
node_pair = (parent_node, src_node)
self.edge_costs.pop(node_pair)
def simplify_graph(self):
if not self.simplify:
return
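After this change, the edge-cost dictionary for every (src_node, dst_node) pair is keyed by (src_strategy_index, dst_strategy_index), and resharding costs are looked up by predecessor Node instead of by positional index. A minimal sketch of how one such dictionary is assembled, assuming each destination strategy stores resharding_costs keyed by its predecessor Node (simplified stand-ins for the real classes):

def build_edge_cost(src_node, src_strategies, dst_strategies):
    # edge_cost[(j, i)] is the cost of feeding the j-th strategy of src_node
    # into the i-th strategy of the destination node.
    edge_cost = {}
    for i, dst_strategy in enumerate(dst_strategies):
        for j in range(len(src_strategies)):
            edge_cost[(j, i)] = dst_strategy.resharding_costs[src_node][j]
    return edge_cost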

View File

@@ -3,7 +3,7 @@ import torch
import torch.nn as nn
from abc import ABC, abstractmethod
from torch.fx.node import Node
from typing import Dict
from typing import Dict, List
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
@@ -56,7 +56,7 @@ class OperatorHandler(ABC):
"""
pass
def _generate_sharding_spec(self, tensor: torch.Tensor, dim_partition_dict: Dict[int, int]) -> ShardingSpec:
def _generate_sharding_spec(self, tensor: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
"""
Generate the sharding spec of the tensor based on the given dim_partition_dict
where the key is the tensor dimension and the value is the list of mesh dimensions used for sharding.
@@ -84,7 +84,9 @@ class OperatorHandler(ABC):
for input_node, input_spec in zip(self.predecessor_node, sharding_spec_for_input):
resharding_costs[input_node] = []
for strategy in input_node.strategies_vector:
input_sharding_spec = strategy.output_sharding_spec
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT output a tuple of tensors.'
_, _, resharding_cost = self.shape_consistency_manager.shape_consistency(
strategy.output_sharding_spec, input_spec)
input_sharding_spec, input_spec)
resharding_costs[input_node].append(resharding_cost)
return resharding_costs
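The returned dictionary maps each predecessor Node to a list of floats, where resharding_costs[input_node][k] is the cost of converting the output sharding spec of that node's k-th strategy into the spec required here. A hedged usage sketch (the helper name is illustrative, not part of this PR):

def cheapest_predecessor_strategy(resharding_costs, input_node):
    # Index of the predecessor strategy that is cheapest to reshard from.
    costs = resharding_costs[input_node]
    return min(range(len(costs)), key=costs.__getitem__)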

View File

@@ -1,7 +1,8 @@
from dataclasses import dataclass
from colossalai.tensor.sharding_spec import ShardingSpec
from typing import Dict, List
from typing import Dict, List, Union, Tuple
from torch.fx.node import Node
from .constants import *
__all__ = ['ShardingStrategy', 'StrategiesVector']
@@ -25,12 +26,15 @@ class ShardingStrategy:
'''
name: str
output_sharding_spec: ShardingSpec
# TODO: the output of an fx node, such as torch.var_mean, could be a tuple, so we cannot simply assume it is a single tensor.
output_sharding_spec: Union[ShardingSpec, Tuple[ShardingSpec]]
compute_cost: float = 0.
communication_cost: float = 0.
memory_cost: float = 0.
resharding_costs: Dict[int, List[float]] = None
input_shardings: ShardingSpec = None
resharding_costs: Dict[Node, List[float]] = None
# sometimes the input could be a tuple of nodes, but most ops won't accept a tuple of nodes as input.
# Therefore, we handle them at the specific op (operator.getitem).
input_shardings: List[ShardingSpec] = None
class StrategiesVector(list):
@@ -46,8 +50,23 @@ class StrategiesVector(list):
super().__init__()
self.node = node
# fetch its input and output nodes
# TODO: placeholder input nodes
self.predecessor_nodes = list(node._input_nodes.keys())
self.successor_nodes = list(node.users.keys())
def check_merge(self):
pass
merge_label = False
if self.node.op == 'call_module':
target = self.node.target
root_module = self.node.graph.owning_module
submod = root_module.get_submodule(target)
submod_type = type(submod)
# merge elementwise module node into following nodes
if submod_type in ELEMENTWISE_MODULE_OP:
merge_label = True
if self.node.op == 'call_function':
if self.node.target in ELEMENTWISE_FUNC_OP:
merge_label = True
return merge_label
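check_merge marks element-wise call_module / call_function nodes so that the cost graph can later fold their strategies into neighbouring nodes. A short sketch of how a caller might collect such nodes (illustrative only):

def collect_mergeable_nodes(leaf_strategies):
    # leaf_strategies is a list of StrategiesVector, one per fx node.
    return [vector.node for vector in leaf_strategies if vector.check_merge()]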

View File

@@ -0,0 +1,355 @@
from torch.fx import Graph, Node
from colossalai.tensor.sharding_spec import ShardingSpec
from .sharding_strategy import ShardingStrategy, StrategiesVector
from .conv_handler import ConvHandler
from .constants import *
from copy import deepcopy
import math
import torch
import operator
from typing import Dict, List
class StrategiesConstructor:
def __init__(self, graph, device_mesh, shape_consistency_manager, solver_options):
self.graph = graph
self.root_module = self.graph.owning_module
self.nodes = list(graph.nodes)
self.device_mesh = device_mesh
self.leaf_strategies = []
self.strategy_map = {}
self.shape_consistency_manager = shape_consistency_manager
self.solver_options = solver_options
def _generate_sharding_spec(self, node: Node, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
"""
Generate the sharding spec of the tensor based on the given dim_partition_dict
where the key is the tensor dimension and the value is the list of mesh dimensions used for sharding.
"""
meta_tensor = node._meta_data
sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=meta_tensor.shape,
dim_partition_dict=dim_partition_dict)
return sharding_spec
def _generate_resharding_costs(self, input_nodes, target_sharding_specs):
'''
Compute the resharding costs from every strategy of each input node to the corresponding target sharding spec.
Argument:
input_nodes(List[Node]): the input nodes of the node being processed.
target_sharding_specs(List[ShardingSpec]): the sharding specs required for the corresponding input nodes.
'''
resharding_costs = {}
for input_node, target_sharding_spec in zip(input_nodes, target_sharding_specs):
resharding_costs[input_node] = []
for strategy in input_node.strategies_vector:
input_sharding_spec = strategy.output_sharding_spec
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT output a tuple of tensors.'
_, _, resharding_cost = self.shape_consistency_manager.shape_consistency(
input_sharding_spec, target_sharding_spec)
resharding_costs[input_node].append(resharding_cost)
return resharding_costs
def remove_duplicated_strategy(self, strategies_vector):
'''
The build_strategies_and_cost method may produce duplicated strategies.
This method removes the duplicates based on the strategy name.
'''
name_checklist = []
remove_list = []
for strategy in strategies_vector:
if strategy.name not in name_checklist:
name_checklist.append(strategy.name)
else:
remove_list.append(strategy)
for strategy in remove_list:
strategies_vector.remove(strategy)
def build_strategies_and_cost(self):
for node in self.nodes:
strategies_vector = StrategiesVector(node)
# placeholder node
if node.op == 'placeholder':
# For placeholder nodes, if solver_options['fast_mode'] is True, we simply keep them in
# fully replicated status, so the strategies of the following nodes are treated equally,
# because the replicated status has no resharding cost to any other status. At the same time,
# the search space is smaller than enumerating all possible sharding specs for the placeholder node.
# Otherwise, all possible sharding specs for the placeholder node will be enumerated.
if self.solver_options['fast_mode']:
# create sharding strategy for placeholder
name = 'Replica Placeholder'
dim_partition_dict = {}
output_sharding_spec = self._generate_sharding_spec(node, dim_partition_dict)
# TODO: use meta_info_prop to profile memory cost
memory_cost = 0
sharding_strategy_placeholder = ShardingStrategy(name,
output_sharding_spec,
memory_cost=memory_cost)
strategies_vector.append(sharding_strategy_placeholder)
# get_attr node
if node.op == 'get_attr':
# Same as placeholder nodes: if solver_options['fast_mode'] is True, we simply keep them in
# fully replicated status, so the strategies of the following nodes are treated equally,
# because the replicated status has no resharding cost to any other status. At the same time,
# the search space is smaller than enumerating all possible sharding specs for the get_attr node.
# Otherwise, all possible sharding specs for the get_attr node will be enumerated.
if self.solver_options['fast_mode']:
# create sharding strategy for get_attr
name = 'Replica Attribute'
dim_partition_dict = {}
output_sharding_spec = self._generate_sharding_spec(node, dim_partition_dict)
# TODO: use meta_info_prop to profile memory cost
memory_cost = 0
sharding_strategy_attribute = ShardingStrategy(name, output_sharding_spec, memory_cost=memory_cost)
strategies_vector.append(sharding_strategy_attribute)
# call_module node
if node.op == 'call_module':
target = node.target
submod = self.root_module.get_submodule(target)
submod_type = type(submod)
# conv module
if submod_type in CONV_MODULE_OP:
# use ConvHandler to create sharding strategies for conv module node
conv_handler = ConvHandler(node, self.device_mesh, strategies_vector,
self.shape_consistency_manager)
conv_handler.register_strategy()
# linear module
elif submod_type in LINEAR_MODULE_OP:
# use DotHandler to create sharding strategies for linear module node
dot_handler = DotHandler(node, self.device_mesh, strategies_vector, self.shape_consistency_manager)
dot_handler.register_strategy()
# element-wise module
elif submod_type in ELEMENTWISE_MODULE_OP:
# create sharding strategy for element-wise module
assert len(strategies_vector.predecessor_nodes
) == 1, f'Currently, we only support single-input element-wise ops.'
input_node = strategies_vector.predecessor_nodes[0]
# For an element-wise module, we keep the sharding spec of the output the same as
# that of the input. Therefore, different strategies of the input node that share the
# same output sharding spec generate the same strategy for the element-wise module.
sharding_spec_checklist = []
for strategy in input_node.strategies_vector:
# Although it may look confusing, the input of the node being processed
# is the output of the input_node.
input_sharding_spec = strategy.output_sharding_spec
assert isinstance(input_sharding_spec,
ShardingSpec), f'The input node should NOT output a tuple of tensors.'
if input_sharding_spec in sharding_spec_checklist:
continue
sharding_spec_checklist.append(input_sharding_spec)
dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
output_sharding_spec = self._generate_sharding_spec(node, dim_partition_dict)
name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}'
# TODO: use meta_info_prop to profile memory cost and compute cost
compute_cost = node._meta_data.numel()
memory_cost = 0
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
[input_sharding_spec])
# to prevent resharding from happening, set their resharding costs to inf.
resharding_costs[input_node] = [
cost if cost == 0 else math.inf for cost in resharding_costs[input_node]
]
sharding_strategy = ShardingStrategy(name,
output_sharding_spec,
compute_cost=compute_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=[input_sharding_spec])
strategies_vector.append(sharding_strategy)
# other module
else:
raise RuntimeError(f'{submod_type} module is NOT supported now.')
# call_function node
if node.op == 'call_function':
target = node.target
# conv function
if target in CONV_FUNC_OP:
# use ConvHandler to create sharding strategies for conv node
# TODO: the operator_handler does NOT support function node processing now.
conv_handler = ConvHandler(node, self.device_mesh, strategies_vector,
self.shape_consistency_manager)
conv_handler.register_strategy()
# linear function
elif target in LINEAR_FUNC_OP:
# use DotHandler to create sharding strategies for linear node
# TODO: the operator_handler does NOT support function node processing now.
linear_handler = DotHandler(node, self.device_mesh, strategies_vector,
self.shape_consistency_manager)
linear_handler.register_strategy()
# element-wise function
elif target in ELEMENTWISE_FUNC_OP:
# TODO: integrate element-wise func and module together
# create sharding strategy for element-wise function
assert len(strategies_vector.predecessor_nodes
) == 1, f'Currently, we only support single-input element-wise ops.'
input_node = strategies_vector.predecessor_nodes[0]
# For an element-wise function, we keep the sharding spec of the output the same as
# that of the input. Therefore, different strategies of the input node that share the
# same output sharding spec generate the same strategy for the element-wise function.
sharding_spec_checklist = []
for strategy in input_node.strategies_vector:
# Although it may look confusing, the input of the node being processed
# is the output of the input_node.
input_sharding_spec = strategy.output_sharding_spec
assert isinstance(input_sharding_spec,
ShardingSpec), f'The input node should NOT output a tuple of tensors.'
if input_sharding_spec in sharding_spec_checklist:
continue
sharding_spec_checklist.append(input_sharding_spec)
dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
output_sharding_spec = self._generate_sharding_spec(node, dim_partition_dict)
name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}'
# TODO: use meta_info_prop to profile memory cost and compute cost
compute_cost = node._meta_data.numel()
memory_cost = 0
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
[input_sharding_spec])
# to prevent resharding from happening, set their resharding costs to inf.
resharding_costs[input_node] = [
0 if cost == 0 else math.inf for cost in resharding_costs[input_node]
]
sharding_strategy = ShardingStrategy(name,
output_sharding_spec,
compute_cost=compute_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=[input_sharding_spec])
strategies_vector.append(sharding_strategy)
# torch.var_mean
elif target == torch.var_mean:
dim = node.kwargs['dim']
input_tensor_node = strategies_vector.predecessor_nodes[0]
for strategy in input_tensor_node.strategies_vector:
input_sharding_spec = strategy.output_sharding_spec
assert isinstance(input_sharding_spec,
ShardingSpec), f'The input node should NOT output a tuple of tensors.'
entire_shape_input = input_sharding_spec.entire_shape
dim_partition_dict_input = input_sharding_spec.dim_partition_dict
if dim in dim_partition_dict_input:
# the dimension being reduced needs to be set to replicate status
dim_partition_dict_for_input = deepcopy(dim_partition_dict_input)
dim_partition_dict_for_input.pop(dim)
new_input_sharding_spec = ShardingSpec(self.device_mesh,
entire_shape_input,
dim_partition_dict=dim_partition_dict_for_input)
entire_shape_output = deepcopy(entire_shape_input)
entire_shape_output.pop(dim)
dim_partition_dict_for_output = deepcopy(dim_partition_dict_for_input)
output_sharding_spec = ShardingSpec(self.device_mesh,
entire_shape_output,
dim_partition_dict=dim_partition_dict_for_output)
# the strategy name can only be built after the output sharding spec is known
name = f'{new_input_sharding_spec.sharding_sequence} -> ({output_sharding_spec.sharding_sequence}, {output_sharding_spec.sharding_sequence})'
# TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
compute_cost = 0
memory_cost = 0
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
[new_input_sharding_spec])
sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec),
compute_cost=compute_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=[new_input_sharding_spec])
else:
entire_shape_output = deepcopy(entire_shape_input)
entire_shape_output.pop(dim)
dim_partition_dict_for_output = deepcopy(dim_partition_dict_input)
output_sharding_spec = ShardingSpec(self.device_mesh,
entire_shape_output,
dim_partition_dict=dim_partition_dict_for_output)
name = f'{input_sharding_spec.sharding_sequence} -> ({output_sharding_spec.sharding_sequence}, {output_sharding_spec.sharding_sequence})'
# TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
compute_cost = 0
memory_cost = 0
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
[input_sharding_spec])
sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec),
compute_cost=compute_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=[input_sharding_spec])
strategies_vector.append(sharding_strategy)
# operator.getitem
elif target == operator.getitem:
index = node.args[1]
input_tensor_node = strategies_vector.predecessor_nodes[0]
for strategy in input_tensor_node.strategies_vector:
input_sharding_spec = strategy.output_sharding_spec[index]
assert isinstance(input_sharding_spec, ShardingSpec), f'getitem is expected to extract a single ShardingSpec from the tuple output.'
dim_partition_dict_for_output = deepcopy(input_sharding_spec.dim_partition_dict)
entire_shape_output = deepcopy(input_sharding_spec.entire_shape)
output_sharding_spec = ShardingSpec(self.device_mesh,
entire_shape_output,
dim_partition_dict=dim_partition_dict_for_output)
name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}'
# TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
compute_cost = 0
memory_cost = 0
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
[input_sharding_spec])
# to prevent resharding from happening, set their resharding costs to inf.
resharding_costs[input_tensor_node] = [
cost if cost == 0 else math.inf for cost in resharding_costs[input_tensor_node]
]
sharding_strategy = ShardingStrategy(name,
output_sharding_spec,
compute_cost=compute_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=[strategy.output_sharding_spec])
strategies_vector.append(sharding_strategy)
# other function
else:
raise RuntimeError(f'{target} function is NOT supported now.')
# output node
if node.op == 'output':
if self.solver_options['fast_mode']:
# create sharding strategy for output
name = 'Replica Output'
input_nodes = strategies_vector.predecessor_nodes
input_sharding_specs = []
for input_node in input_nodes:
dim_partition_dict_for_input = {}
entire_shape = input_node._meta_data.shape
sharding_spec = ShardingSpec(self.device_mesh,
entire_shape,
dim_partition_dict=dim_partition_dict_for_input)
input_sharding_specs.append(sharding_spec)
dim_partition_dict = {}
output_sharding_spec = input_sharding_specs
# TODO: use meta_info_prop to profile memory cost
memory_cost = 0
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
input_sharding_specs)
sharding_strategy_attribute = ShardingStrategy(name,
output_sharding_spec,
memory_cost=memory_cost,
resharding_costs=resharding_costs)
strategies_vector.append(sharding_strategy_attribute)
self.remove_duplicated_strategy(strategies_vector)
setattr(node, 'strategies_vector', strategies_vector)
self.leaf_strategies.append(strategies_vector)
self.strategy_map[node] = strategies_vector
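Overall, build_strategies_and_cost walks graph.nodes once, attaches a strategies_vector to every node, and records it both in leaf_strategies (in node order) and in strategy_map (keyed by node). A minimal driver, mirroring the tests below (assuming a graph traced with ColoTracer and a 2x2 DeviceMesh as in those tests):

solver_options = {'fast_mode': True}
constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
constructor.build_strategies_and_cost()
cost_graph = CostGraph(constructor.leaf_strategies)
cost_graph.simplify_graph()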

View File

@@ -0,0 +1,97 @@
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest
from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.solver.conv_handler import ConvHandler, CONV_STRATEGIES_LIST
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from colossalai.auto_parallel.solver.cost_graph import CostGraph
from copy import deepcopy
class ConvModel(nn.Module):
def __init__(self, c_in, c_out):
super().__init__()
self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=3)
self.relu = nn.ReLU()
def forward(self, x):
x = x * 2
x = self.conv1(x)
x = x / 2
x = self.relu(x)
return x
def test_cost_graph():
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
entire_shape = torch.Size((4, 16, 64, 64))
shape_consistency_manager = ShapeConsistencyManager()
tracer = ColoTracer()
model = ConvModel(16, 32)
input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
# %conv1 : [#users=1] = call_module[target=conv1](args = (%mul,), kwargs = {})
# %truediv : [#users=1] = call_function[target=operator.truediv](args = (%conv1, 2), kwargs = {})
# %relu : [#users=1] = call_module[target=relu](args = (%truediv,), kwargs = {})
# return relu
graph = tracer.trace(root=model, meta_args=input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
solver_options = {'fast_mode': True}
strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
strategies_constructor.build_strategies_and_cost()
# (x, mul): {(0, 0): 0}
# (mul, conv1): {(0, 0): 0, (0, 1): 0, (0, 2): 0, (0, 3): 0, (0, 4): 0, (0, 5): 0, (0, 6): 0, (0, 7): 0, (0, 8): 0, (0, 9): 0, (0, 10): 0, (0, 11): 0, (0, 12): 0, (0, 13): 0, (0, 14): 0}
# (conv1, truediv): {(0, 0): 0, (1, 0): inf, (2, 0): 0, (3, 0): inf, (4, 0): 0, (5, 0): inf, (6, 0): inf, (7, 0): 0, (8, 0): inf, (9, 0): 0, (10, 0): 0, (11, 0): 0, (12, 0): 0, (13, 0): inf, (14, 0): inf, (0, 1): inf, (1, 1): 0, (2, 1): inf, (3, 1): 0, (4, 1): inf, (5, 1): 0, (6, 1): 0, (7, 1): inf, (8, 1): 0, (9, 1): inf, (10, 1): 0, (11, 1): 0, (12, 1): 0, (13, 1): inf, (14, 1): inf, (0, 2): inf, (1, 2): inf, (2, 2): 0, (3, 2): inf, (4, 2): inf, (5, 2): inf, (6, 2): inf, (7, 2): inf, (8, 2): inf, (9, 2): inf, (10, 2): 0, (11, 2): 0, (12, 2): 0, (13, 2): inf, (14, 2): inf, (0, 3): inf, (1, 3): inf, (2, 3): inf, (3, 3): 0, (4, 3): inf, (5, 3): inf, (6, 3): inf, (7, 3): inf, (8, 3): inf, (9, 3): inf, (10, 3): 0, (11, 3): 0, (12, 3): 0, (13, 3): inf, (14, 3): inf, (0, 4): inf, (1, 4): inf, (2, 4): inf, (3, 4): inf, (4, 4): inf, (5, 4): inf, (6, 4): 0, (7, 4): inf, (8, 4): 0, (9, 4): inf, (10, 4): 0, (11, 4): 0, (12, 4): 0, (13, 4): inf, (14, 4): inf, (0, 5): inf, (1, 5): inf, (2, 5): inf, (3, 5): inf, (4, 5): inf, (5, 5): inf, (6, 5): inf, (7, 5): 0, (8, 5): inf, (9, 5): 0, (10, 5): 0, (11, 5): 0, (12, 5): 0, (13, 5): inf, (14, 5): inf, (0, 6): inf, (1, 6): inf, (2, 6): inf, (3, 6): inf, (4, 6): inf, (5, 6): inf, (6, 6): inf, (7, 6): inf, (8, 6): inf, (9, 6): inf, (10, 6): 0, (11, 6): 0, (12, 6): 0, (13, 6): inf, (14, 6): inf, (0, 7): inf, (1, 7): inf, (2, 7): 0, (3, 7): inf, (4, 7): inf, (5, 7): inf, (6, 7): inf, (7, 7): inf, (8, 7): inf, (9, 7): inf, (10, 7): 0, (11, 7): 0, (12, 7): 0, (13, 7): 0, (14, 7): inf, (0, 8): inf, (1, 8): inf, (2, 8): inf, (3, 8): inf, (4, 8): inf, (5, 8): inf, (6, 8): 0, (7, 8): inf, (8, 8): 0, (9, 8): inf, (10, 8): 0, (11, 8): 0, (12, 8): 0, (13, 8): inf, (14, 8): 0}
# (truediv, relu): {(0, 0): 0, (1, 0): inf, (2, 0): 0, (3, 0): inf, (4, 0): inf, (5, 0): 0, (6, 0): 0, (7, 0): inf, (8, 0): inf, (0, 1): inf, (1, 1): 0, (2, 1): inf, (3, 1): 0, (4, 1): 0, (5, 1): inf, (6, 1): 0, (7, 1): inf, (8, 1): inf, (0, 2): inf, (1, 2): inf, (2, 2): 0, (3, 2): inf, (4, 2): inf, (5, 2): inf, (6, 2): 0, (7, 2): inf, (8, 2): inf, (0, 3): inf, (1, 3): inf, (2, 3): inf, (3, 3): 0, (4, 3): inf, (5, 3): inf, (6, 3): 0, (7, 3): inf, (8, 3): inf, (0, 4): inf, (1, 4): inf, (2, 4): inf, (3, 4): inf, (4, 4): 0, (5, 4): inf, (6, 4): 0, (7, 4): inf, (8, 4): inf, (0, 5): inf, (1, 5): inf, (2, 5): inf, (3, 5): inf, (4, 5): inf, (5, 5): 0, (6, 5): 0, (7, 5): inf, (8, 5): inf, (0, 6): inf, (1, 6): inf, (2, 6): inf, (3, 6): inf, (4, 6): inf, (5, 6): inf, (6, 6): 0, (7, 6): inf, (8, 6): inf, (0, 7): inf, (1, 7): inf, (2, 7): 0, (3, 7): inf, (4, 7): inf, (5, 7): inf, (6, 7): 0, (7, 7): 0, (8, 7): inf, (0, 8): inf, (1, 8): inf, (2, 8): inf, (3, 8): inf, (4, 8): 0, (5, 8): inf, (6, 8): 0, (7, 8): inf, (8, 8): 0}
# (relu, output): {(0, 0): 246019.30000000002, (1, 0): 246019.30000000002, (2, 0): 123009.1, (3, 0): 123009.1, (4, 0): 123009.1, (5, 0): 123009.1, (6, 0): 0, (7, 0): 246019.30000000002, (8, 0): 246019.30000000002}
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
# construct all node pairs
all_node_pairs = []
for node in graph.nodes:
if node.op == 'output':
continue
all_node_pairs.append((node, node.next))
for node_pair in all_node_pairs:
assert node_pair in cost_graph.edge_costs
# construct merged node pairs
merged_node_pairs = []
node_list = list(graph.nodes)
# add (x, conv1) and (conv1, output) into the check node pairs
merged_node_pairs.append((node_list[0], node_list[2]))
merged_node_pairs.append((node_list[2], node_list[-1]))
# (x, conv1): {(0, 0): 0, (0, 1): 0, (0, 2): 0, (0, 3): 0, (0, 4): 0, (0, 5): 0, (0, 6): 0, (0, 7): 0, (0, 8): 0, (0, 9): 0, (0, 10): 0, (0, 11): 0, (0, 12): 0, (0, 13): 0, (0, 14): 0}
# (conv1, output): {(0, 0): inf, (1, 0): inf, (2, 0): inf, (3, 0): inf, (4, 0): inf, (5, 0): inf, (6, 0): inf, (7, 0): inf, (8, 0): inf, (9, 0): inf, (10, 0): 0, (11, 0): 0, (12, 0): 0, (13, 0): inf, (14, 0): inf}
cost_graph.simplify_graph()
for node_pair in all_node_pairs:
if node_pair in merged_node_pairs:
assert node_pair in cost_graph.edge_costs
else:
assert node_pair not in cost_graph.edge_costs
if __name__ == '__main__':
test_cost_graph()

View File

@@ -0,0 +1,98 @@
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest
from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.solver.conv_handler import ConvHandler, CONV_STRATEGIES_LIST
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from copy import deepcopy
class ConvModel(nn.Module):
def __init__(self, c_in, c_out):
super().__init__()
self.conv = nn.Conv2d(c_in, c_out, kernel_size=3)
def forward(self, x):
x = x * 2
x = self.conv(x)
return x
def test_strategies_constructor():
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
entire_shape = torch.Size((4, 16, 64, 64))
shape_consistency_manager = ShapeConsistencyManager()
tracer = ColoTracer()
model = ConvModel(16, 32)
input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
# %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {})
# return conv
graph = tracer.trace(root=model, meta_args=input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
solver_options = {'fast_mode': True}
strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
assert strategies_constructor.leaf_strategies == []
assert strategies_constructor.strategy_map == {}
strategies_constructor.build_strategies_and_cost()
# check leaf_strategies
# In fast mode, the placeholder node only has the replica strategy.
assert strategies_constructor.leaf_strategies[0][0].name == 'Replica Placeholder'
# The second node is mul, which is an element-wise node, so its output sharding spec is the same as its input sharding spec.
assert strategies_constructor.leaf_strategies[1][0].name == '[R, R, R, R] -> [R, R, R, R]'
# The third node is conv.
conv_check_list = deepcopy(CONV_STRATEGIES_LIST)
for strategy in strategies_constructor.leaf_strategies[2]:
conv_check_list.remove(strategy.name)
assert len(conv_check_list) == 0
# In fast mode, the output node only has the replica strategy.
assert strategies_constructor.leaf_strategies[3][0].name == 'Replica Output'
# check strategy_map
nodes = [node for node in graph.nodes]
# In fast mode, the placeholder node only has the replica strategy.
x = nodes[0]
assert strategies_constructor.strategy_map[x][0].name == 'Replica Placeholder'
# The second node is mul, which is an element-wise node, so its output sharding spec is the same as its input sharding spec.
mul = nodes[1]
assert strategies_constructor.strategy_map[mul][0].name == '[R, R, R, R] -> [R, R, R, R]'
# The third node is conv.
conv = nodes[2]
conv_check_list = deepcopy(CONV_STRATEGIES_LIST)
for strategy in strategies_constructor.strategy_map[conv]:
conv_check_list.remove(strategy.name)
assert len(conv_check_list) == 0
# In fast mode, the output node only has the replica strategy.
output = nodes[3]
assert strategies_constructor.strategy_map[output][0].name == 'Replica Output'
if __name__ == '__main__':
test_strategies_constructor()