[autoparellel]add strategies constructor (#1505)

* [autoparellel]add strategies constructor

* remove duplicated strategies

* polish code

* adapt cost graph with StrategiesConstructor

* polish
pull/1522/head
YuliangLiu0306 2022-08-30 16:32:09 +08:00 committed by GitHub
parent a0436a62ee
commit 3345c6d352
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 633 additions and 32 deletions

View File

@ -0,0 +1,22 @@
import torch
import operator
__all__ = [
'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP', 'LINEAR_MODULE_OP',
'LINEAR_FUNC_OP'
]
ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
ELEMENTWISE_FUNC_OP = [
torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv,
operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout
]
CONV_MODULE_OP = [
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
torch.nn.ConvTranspose3d
]
CONV_FUNC_OP = [
torch.conv1d, torch.conv2d, torch.conv3d, torch.conv_transpose1d, torch.conv_transpose2d, torch.conv_transpose3d
]
LINEAR_MODULE_OP = [torch.nn.Linear]
LINEAR_FUNC_OP = [torch.nn.functional.linear, torch.matmul, torch.bmm]

View File

@ -494,3 +494,10 @@ class ConvHandler(OperatorHandler):
self.split_1d_parallel_on_in_channel(0, 1) self.split_1d_parallel_on_in_channel(0, 1)
return self.strategies_vector return self.strategies_vector
CONV_STRATEGIES_LIST = [
'S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0R x RR', 'S1R = S1R x RR', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R',
'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RR = RS0 x S0R', 'RR = RS1 x S1R', 'RS0 = RR x RS0', 'RS1 = RR x RS1',
'RR = RR x RR', 'S01R = S01R x RR', 'RR = RS01 x S01R'
]

View File

@ -39,11 +39,11 @@ class CostGraph:
dst_node = strategies_vector.node dst_node = strategies_vector.node
for src_node in strategies_vector.predecessor_nodes: for src_node in strategies_vector.predecessor_nodes:
node_pair = (src_node, dst_node) node_pair = (src_node, dst_node)
src_index = strategies_vector.predecessor_nodes.index(src_node) # src_index = strategies_vector.predecessor_nodes.index(src_node)
edge_cost = {} edge_cost = {}
for i in range(len(strategies_vector)): for i in range(len(strategies_vector)):
for j in range(len(src_node.stategy_vector)): for j in range(len(src_node.strategies_vector)):
edge_cost[(i, j)] = strategies_vector[i].resharding_costs[src_index][j] edge_cost[(j, i)] = strategies_vector[i].resharding_costs[src_node][j]
self.edge_costs[node_pair] = edge_cost self.edge_costs[node_pair] = edge_cost
# add parents and children attribute to node # add parents and children attribute to node
setattr(dst_node, 'parents', strategies_vector.predecessor_nodes) setattr(dst_node, 'parents', strategies_vector.predecessor_nodes)
@ -83,33 +83,19 @@ class CostGraph:
merge_map = {} merge_map = {}
for dst_strate_index, strategy in enumerate(dst_node.strategies_vector): for dst_strate_index, strategy in enumerate(dst_node.strategies_vector):
resharding_costs = strategy.resharding_costs resharding_costs = strategy.resharding_costs
resharding_cost_for_src = resharding_costs[src_node_index] resharding_cost_for_src = resharding_costs[src_node]
lowest_cost_index = resharding_cost_for_src.index(min(resharding_cost_for_src)) lowest_cost_index = resharding_cost_for_src.index(min(resharding_cost_for_src))
merge_map[dst_strate_index] = lowest_cost_index merge_map[dst_strate_index] = lowest_cost_index
# extra_node_cost for dst node # extra_node_cost for dst node
extra_node_costs[dst_node] = [0.0 for _ in range(self.node_lens[dst_node])] self.extra_node_costs[dst_node] = [0.0 for _ in range(self.node_lens[dst_node])]
for dst_strate_index, strategy in enumerate(dst_node.strategies_vector): for dst_strate_index, strategy in enumerate(dst_node.strategies_vector):
target_strate_index = merge_map[dst_strate_index] target_strate_index = merge_map[dst_strate_index]
extra_node_costs[dst_node][dst_strate_index] += strategy.resharding_costs[src_node_index][ self.extra_node_costs[dst_node][dst_strate_index] += strategy.resharding_costs[src_node][
target_strate_index] target_strate_index]
if src_node in extra_node_costs: if src_node in self.extra_node_costs:
extra_node_costs[dst_node][dst_strate_index] += extra_node_costs[src_node][target_strate_index] self.extra_node_costs[dst_node][dst_strate_index] += self.extra_node_costs[src_node][
target_strate_index]
# connect dst node and parents of src node
dst_node.parents.remove(src_node)
src_node.children.remove(dst_node)
node_pair_to_remove = [(src_node, dst_node)]
for parent_node in src_node.parents:
if parent_node not in dst_node.parents:
dst_node.parents.append(parent)
if dst_node not in parent_node.children:
parent_node.children.append(dst_node)
# remove src node from cost graph when src node has no consumer.
if len(src_node.children) == 0:
parent_node.children.remove(src_node)
node_pair = (parent_node, src_node)
self.edge_costs.pop(node_pair)
# add new node pair to cost graph # add new node pair to cost graph
for parent_node in src_node.parents: for parent_node in src_node.parents:
@ -121,9 +107,24 @@ class CostGraph:
for i in range(self.node_lens[dst_node]): for i in range(self.node_lens[dst_node]):
for j in range(self.node_lens[parent_node]): for j in range(self.node_lens[parent_node]):
src_strate_index = merge_map[i] src_strate_index = merge_map[i]
edge_cost[(i, j)] = self.edge_costs[old_node_pair][(j, src_strate_index)] edge_cost[(j, i)] = self.edge_costs[old_node_pair][(j, src_strate_index)]
self.edge_costs[new_node_pair] = edge_cost self.edge_costs[new_node_pair] = edge_cost
# connect dst node and parents of src node
dst_node.parents.remove(src_node)
src_node.children.remove(dst_node)
self.edge_costs.pop((src_node, dst_node))
for parent_node in src_node.parents:
if parent_node not in dst_node.parents:
dst_node.parents.append(parent_node)
if dst_node not in parent_node.children:
parent_node.children.append(dst_node)
# remove src node from cost graph when src node has no consumer.
if len(src_node.children) == 0:
parent_node.children.remove(src_node)
node_pair = (parent_node, src_node)
self.edge_costs.pop(node_pair)
def simplify_graph(self): def simplify_graph(self):
if not self.simplify: if not self.simplify:
return return

View File

@ -3,7 +3,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from torch.fx.node import Node from torch.fx.node import Node
from typing import Dict from typing import Dict, List
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec from colossalai.tensor.sharding_spec import ShardingSpec
@ -56,7 +56,7 @@ class OperatorHandler(ABC):
""" """
pass pass
def _generate_sharding_spec(self, tensor: torch.Tensor, dim_partition_dict: Dict[int, int]) -> ShardingSpec: def _generate_sharding_spec(self, tensor: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
""" """
Generate the sharding spec of the tensor based on the given dim_partition_dict Generate the sharding spec of the tensor based on the given dim_partition_dict
where the key is the tensor dimension and the value is the mesh dimension for sharding. where the key is the tensor dimension and the value is the mesh dimension for sharding.
@ -84,7 +84,9 @@ class OperatorHandler(ABC):
for input_node, input_spec in zip(self.predecessor_node, sharding_spec_for_input): for input_node, input_spec in zip(self.predecessor_node, sharding_spec_for_input):
resharding_costs[input_node] = [] resharding_costs[input_node] = []
for strategy in input_node.strategies_vector: for strategy in input_node.strategies_vector:
input_sharding_spec = strategy.output_sharding_spec
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
_, _, resharding_cost = self.shape_consistency_manager.shape_consistency( _, _, resharding_cost = self.shape_consistency_manager.shape_consistency(
strategy.output_sharding_spec, input_spec) input_sharding_spec, input_spec)
resharding_costs[input_node].append(resharding_cost) resharding_costs[input_node].append(resharding_cost)
return resharding_costs return resharding_costs

View File

@ -1,7 +1,8 @@
from dataclasses import dataclass from dataclasses import dataclass
from colossalai.tensor.sharding_spec import ShardingSpec from colossalai.tensor.sharding_spec import ShardingSpec
from typing import Dict, List from typing import Dict, List, Union, Tuple
from torch.fx.node import Node from torch.fx.node import Node
from .constants import *
__all__ = ['ShardingStrategy', 'StrategiesVector'] __all__ = ['ShardingStrategy', 'StrategiesVector']
@ -25,12 +26,15 @@ class ShardingStrategy:
''' '''
name: str name: str
output_sharding_spec: ShardingSpec # TODO: output of fx node,such as torch.var_mean, could be a tuple, so we cannot simply suppose it is a tensor.
output_sharding_spec: Union[ShardingSpec, Tuple[ShardingSpec]]
compute_cost: float = 0. compute_cost: float = 0.
communication_cost: float = 0. communication_cost: float = 0.
memory_cost: float = 0. memory_cost: float = 0.
resharding_costs: Dict[int, List[float]] = None resharding_costs: Dict[Node, List[float]] = None
input_shardings: ShardingSpec = None # sometimes the input node could be a tuple of nodes, but most of op won't accept tuple of node as input.
# Therefore, we could process them at the specific op(operator.getitem)
input_shardings: List[ShardingSpec] = None
class StrategiesVector(list): class StrategiesVector(list):
@ -46,8 +50,23 @@ class StrategiesVector(list):
super().__init__() super().__init__()
self.node = node self.node = node
# fetch its input and output nodes # fetch its input and output nodes
# TODO: placeholder input nodes
self.predecessor_nodes = list(node._input_nodes.keys()) self.predecessor_nodes = list(node._input_nodes.keys())
self.successor_nodes = list(node.users.keys()) self.successor_nodes = list(node.users.keys())
def check_merge(self): def check_merge(self):
pass merge_label = False
if self.node.op == 'call_module':
target = self.node.target
root_module = self.node.graph.owning_module
submod = root_module.get_submodule(target)
submod_type = type(submod)
# merge elementwise module node into following nodes
if submod_type in ELEMENTWISE_MODULE_OP:
merge_label = True
if self.node.op == 'call_function':
if self.node.target in ELEMENTWISE_FUNC_OP:
merge_label = True
return merge_label

View File

@ -0,0 +1,355 @@
from torch.fx import Graph, Node
from colossalai.tensor.sharding_spec import ShardingSpec
from .sharding_strategy import ShardingStrategy, StrategiesVector
from .conv_handler import ConvHandler
from .constants import *
from copy import deepcopy
import math
import torch
import operator
from typing import Dict, List
class StrategiesConstructor:
def __init__(self, graph, device_mesh, shape_consistency_manager, solver_options):
self.graph = graph
self.root_module = self.graph.owning_module
self.nodes = list(graph.nodes)
self.device_mesh = device_mesh
self.leaf_strategies = []
self.strategy_map = {}
self.shape_consistency_manager = shape_consistency_manager
self.solver_options = solver_options
def _generate_sharding_spec(self, node: Node, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
"""
Generate the sharding spec of the tensor based on the given dim_partition_dict
where the key is the tensor dimension and the value is the mesh dimension for sharding.
"""
meta_tensor = node._meta_data
sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=meta_tensor.shape,
dim_partition_dict=dim_partition_dict)
return sharding_spec
def _generate_resharding_costs(self, input_nodes, target_sharding_specs):
'''
Compute the resharding costs with this specific strategy.
Argument:
sharding_spec_for_input(ShardingSpec): ShardingSpec of the input node.
'''
resharding_costs = {}
for input_node, target_sharding_spec in zip(input_nodes, target_sharding_specs):
resharding_costs[input_node] = []
for strategy in input_node.strategies_vector:
input_sharding_spec = strategy.output_sharding_spec
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
_, _, resharding_cost = self.shape_consistency_manager.shape_consistency(
input_sharding_spec, target_sharding_spec)
resharding_costs[input_node].append(resharding_cost)
return resharding_costs
def remove_duplicated_strategy(self, strategies_vector):
'''
In build_strategies_and_cost method, we may produce some duplicated strategies.
In this method, we will remove the duplicated strategies depending on the strategies name.
'''
name_checklist = []
remove_list = []
for strategy in strategies_vector:
if strategy.name not in name_checklist:
name_checklist.append(strategy.name)
else:
remove_list.append(strategy)
for strategy in remove_list:
strategies_vector.remove(strategy)
def build_strategies_and_cost(self):
for node in self.nodes:
strategies_vector = StrategiesVector(node)
# placeholder node
if node.op == 'placeholder':
# For placeholder nodes, if solver_options['fast_mode'] is True, we just let them in
# fully replicate status, then strategies of following node will be treated equally due
# to replicate status has no resharding cost to other status. At the same time, the searching
# space is smaller than enumerating all the possible sharding spec for the placeholder node.
# Otherwise, all the possible sharding spec for the placeholder node will be enumerated.
if self.solver_options['fast_mode']:
# create sharding strategy for placeholder
name = 'Replica Placeholder'
dim_partition_dict = {}
output_sharding_spec = self._generate_sharding_spec(node, dim_partition_dict)
# TODO: use meta_info_prop to profile memory cost
memory_cost = 0
sharding_strategy_placeholder = ShardingStrategy(name,
output_sharding_spec,
memory_cost=memory_cost)
strategies_vector.append(sharding_strategy_placeholder)
# get_attr node
if node.op == 'get_attr':
# Same as placeholder nodes, if solver_options['fast_mode'] is True, we just let them in
# fully replicate status, then strategies of following node will be treated equally due
# to replicate status has no resharding cost to other status. At the same time, the searching
# space is smaller than enumerating all the possible sharding spec for the get_attr node.
# Otherwise, all the possible sharding spec for the get_attr node will be enumerated.
if self.solver_options['fast_mode']:
# create sharding strategy for get_attr
name = 'Replica Attribute'
dim_partition_dict = {}
output_sharding_spec = self._generate_sharding_spec(node, dim_partition_dict)
# TODO: use meta_info_prop to profile memory cost
memory_cost = 0
sharding_strategy_attribute = ShardingStrategy(name, output_sharding_spec, memory_cost=memory_cost)
strategies_vector.append(sharding_strategy_attribute)
# call_module node
if node.op == 'call_module':
target = node.target
submod = self.root_module.get_submodule(target)
submod_type = type(submod)
# conv module
if submod_type in CONV_MODULE_OP:
# use ConvHandler to create sharding strategies for conv module node
conv_handler = ConvHandler(node, self.device_mesh, strategies_vector,
self.shape_consistency_manager)
conv_handler.register_strategy()
# linear module
elif submod_type in LINEAR_MODULE_OP:
# use DotHandler to create sharding strategies for linear module node
dot_handler = DotHandler(node, self.device_mesh, strategies_vector, self.shape_consistency_manager)
dot_handler.register_strategy()
# element-wise module
elif submod_type in ELEMENTWISE_MODULE_OP:
# create sharding strategy for element-wise module
assert len(strategies_vector.predecessor_nodes
) == 1, f'Temporally, we just support single input element-wise op.'
input_node = strategies_vector.predecessor_nodes[0]
# For element-wise module, we keep the sharding spec of output node same as
# the input. Therefore, the different strategies of input node with same
# output sharding spec will generate same strategy for element-wise module.
sharding_spec_checklist = []
for strategy in input_node.strategies_vector:
# It looks a little bit confusing, the input of the processing node
# is the output of the input_node.
input_sharding_spec = strategy.output_sharding_spec
assert isinstance(input_sharding_spec,
ShardingSpec), f'The input node should NOT be a tuple of tensor.'
if input_sharding_spec in sharding_spec_checklist:
continue
sharding_spec_checklist.append(input_sharding_spec)
dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
output_sharding_spec = self._generate_sharding_spec(node, dim_partition_dict)
name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}'
# TODO: use meta_info_prop to profile memory cost and compute cost
compute_cost = node._meta_data.numel()
memory_cost = 0
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
[input_sharding_spec])
# to prevent the resharding happening, set their resharding cost to inf.
resharding_costs[input_node] = [
cost if cost == 0 else math.inf for cost in resharding_costs[input_node]
]
sharding_strategy = ShardingStrategy(name,
output_sharding_spec,
compute_cost=compute_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=[input_sharding_spec])
strategies_vector.append(sharding_strategy)
# other module
else:
raise RuntimeError(f'{submod_type} module is NOT supported now.')
# call_function node
if node.op == 'call_function':
target = node.target
# conv function
if target in CONV_FUNC_OP:
# use ConvHandler to create sharding strategies for conv node
# TODO: the operator_handler does NOT support function node processing now.
conv_handler = ConvHandler(node, self.device_mesh, strategies_vector,
self.shape_consistency_manager)
conv_handler.register_strategy()
# linear function
elif target in LINEAR_FUNC_OP:
# use DotHandler to create sharding strategies for linear node
# TODO: the operator_handler does NOT support function node processing now.
linear_handler = DotHandler(node, self.device_mesh, strategies_vector,
self.shape_consistency_manager)
linear_handler.register_strategy()
# element-wise function
elif target in ELEMENTWISE_FUNC_OP:
# TODO: integrate element-wise func and module together
# create sharding strategy for element-wise function
assert len(strategies_vector.predecessor_nodes
) == 1, f'Temporally, we just support single input element-wise op.'
input_node = strategies_vector.predecessor_nodes[0]
# For element-wise function, we keep the sharding spec of output node same as
# the input. Therefore, the different strategies of input node with same
# output sharding spec will generate same strategy for element-wise function.
sharding_spec_checklist = []
for strategy in input_node.strategies_vector:
# It looks a little bit confusing, the input of the processing node
# is the output of the input_node.
input_sharding_spec = strategy.output_sharding_spec
assert isinstance(input_sharding_spec,
ShardingSpec), f'The input node should NOT be a tuple of tensor.'
if input_sharding_spec in sharding_spec_checklist:
continue
sharding_spec_checklist.append(input_sharding_spec)
dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
output_sharding_spec = self._generate_sharding_spec(node, dim_partition_dict)
name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}'
# TODO: use meta_info_prop to profile memory cost and compute cost
compute_cost = node._meta_data.numel()
memory_cost = 0
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
[input_sharding_spec])
# to prevent the resharding happening, set their resharding cost to inf.
resharding_costs[input_node] = [
0 if cost == 0 else math.inf for cost in resharding_costs[input_node]
]
sharding_strategy = ShardingStrategy(name,
output_sharding_spec,
compute_cost=compute_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=[input_sharding_spec])
strategies_vector.append(sharding_strategy)
# torch.var_mean
elif target == torch.var_mean:
dim = node.kwargs['dim']
input_tensor_node = strategies_vector.predecessor_nodes[0]
for strategy in input_tensor_node.strategies_vector:
input_sharding_spec = strategy.output_sharding_spec
assert isinstance(input_sharding_spec,
ShardingSpec), f'The input node should NOT be a tuple of tensor.'
entire_shape_input = input_sharding_spec.entire_shape
dim_partition_dict_input = input_sharding_spec.dim_partition_dict
name = f'{new_input_sharding_spec.sharding_sequence} -> ({output_sharding_spec.sharding_sequence}, {output_sharding_spec.sharding_sequence})'
if dim in dim_partition_dict_input:
# We need to make the action dimension in replicate status
dim_partition_dict_for_input = deepcopy(dim_partition_dict_input)
dim_partition_dict_for_input.pop(dim)
new_input_sharding_spec = ShardingSpec(self.device_mesh,
entire_shape_input,
dim_partition_dict=dim_partition_dict_for_input)
entire_shape_output = deepcopy(entire_shape_input)
entire_shape_output.pop(dim)
dim_partition_dict_for_output = deepcopy(dim_partition_dict_for_input)
output_sharding_spec = ShardingSpec(self.device_mesh,
entire_shape_output,
dim_partition_dict=dim_partition_dict_for_input)
# TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
compute_cost = 0
memory_cost = 0
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
[new_input_sharding_spec])
sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec),
compute_cost=compute_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=[new_input_sharding_spec])
else:
entire_shape_output = deepcopy(entire_shape_input)
entire_shape_output.pop(dim)
dim_partition_dict_for_output = deepcopy(dim_partition_dict_input)
output_sharding_spec = ShardingSpec(self.device_mesh,
entire_shape_output,
dim_partion_dict=dim_partition_dict_input)
# TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
compute_cost = 0
memory_cost = 0
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
[input_sharding_spec])
sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec),
compute_cost=compute_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=[input_sharding_spec])
strategies_vector.append(sharding_strategy)
# operator.getitem
elif target == operator.getitem:
index = node.args[1]
input_tensor_node = strategies_vector.predecessor_nodes[0]
for strategy in input_tensor_node.strategies_vector:
input_sharding_spec = input_tensor_node.output_sharding_spec[index]
assert isinstance(input_sharding_spec, ShardingSpec), f'This assertion is used to debug.'
dim_partition_dict_for_output = deepcopy(input_sharding_spec.dim_partition_dict)
entire_shape_output = deepcopy(input_sharding_spec.entire_shape)
output_sharding_spec = ShardingSpec(self.device_mesh,
entire_shape_output,
dim_partition_dict=dim_partition_dict_for_output)
# TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
compute_cost = 0
memory_cost = 0
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
[input_sharding_spec])
# to prevent the resharding happening, set their resharding cost to inf.
resharding_costs[input_tensor_node] = [
cost if cost == 0 else math.inf for cost in resharding_costs[input_tensor_node]
]
sharding_strategy = ShardingStrategy(name,
output_sharding_spec,
compute_cost=compute_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=[input_tensor_node.output_sharding_spec])
strategies_vector.append(sharding_strategy)
# other function
else:
raise RuntimeError(f'{target} function is NOT supported now.')
# output node
if node.op == 'output':
if self.solver_options['fast_mode']:
# create sharding strategy for output
name = 'Replica Output'
input_nodes = strategies_vector.predecessor_nodes
input_sharding_specs = []
for input_node in input_nodes:
dim_partition_dict_for_input = {}
entire_shape = input_node._meta_data.shape
sharding_spec = ShardingSpec(self.device_mesh,
entire_shape,
dim_partition_dict=dim_partition_dict_for_input)
input_sharding_specs.append(sharding_spec)
dim_partition_dict = {}
output_sharding_spec = input_sharding_specs
# TODO: use meta_info_prop to profile memory cost
memory_cost = 0
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
input_sharding_specs)
sharding_strategy_attribute = ShardingStrategy(name,
output_sharding_spec,
memory_cost=memory_cost,
resharding_costs=resharding_costs)
strategies_vector.append(sharding_strategy_attribute)
self.remove_duplicated_strategy(strategies_vector)
setattr(node, 'strategies_vector', strategies_vector)
self.leaf_strategies.append(strategies_vector)
self.strategy_map[node] = strategies_vector

View File

@ -0,0 +1,97 @@
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest
from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.solver.conv_handler import ConvHandler, CONV_STRATEGIES_LIST
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from colossalai.auto_parallel.solver.cost_graph import CostGraph
from copy import deepcopy
class ConvModel(nn.Module):
def __init__(self, c_in, c_out):
super().__init__()
self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=3)
self.relu = nn.ReLU()
def forward(self, x):
x = x * 2
x = self.conv1(x)
x = x / 2
x = self.relu(x)
return x
def test_cost_graph():
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
entire_shape = torch.Size((4, 16, 64, 64))
shape_consistency_manager = ShapeConsistencyManager()
tracer = ColoTracer()
model = ConvModel(16, 32)
input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
# %conv1 : [#users=1] = call_module[target=conv1](args = (%mul,), kwargs = {})
# %truediv : [#users=1] = call_function[target=operator.truediv](args = (%conv1, 2), kwargs = {})
# %relu : [#users=1] = call_module[target=relu](args = (%truediv,), kwargs = {})
# return relu
graph = tracer.trace(root=model, meta_args=input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
solver_options = {'fast_mode': True}
strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
strategies_constructor.build_strategies_and_cost()
# (x, mul): {(0, 0): 0}
# (mul, conv1): {(0, 0): 0, (0, 1): 0, (0, 2): 0, (0, 3): 0, (0, 4): 0, (0, 5): 0, (0, 6): 0, (0, 7): 0, (0, 8): 0, (0, 9): 0, (0, 10): 0, (0, 11): 0, (0, 12): 0, (0, 13): 0, (0, 14): 0}
# (conv1, truediv): {(0, 0): 0, (1, 0): inf, (2, 0): 0, (3, 0): inf, (4, 0): 0, (5, 0): inf, (6, 0): inf, (7, 0): 0, (8, 0): inf, (9, 0): 0, (10, 0): 0, (11, 0): 0, (12, 0): 0, (13, 0): inf, (14, 0): inf, (0, 1): inf, (1, 1): 0, (2, 1): inf, (3, 1): 0, (4, 1): inf, (5, 1): 0, (6, 1): 0, (7, 1): inf, (8, 1): 0, (9, 1): inf, (10, 1): 0, (11, 1): 0, (12, 1): 0, (13, 1): inf, (14, 1): inf, (0, 2): inf, (1, 2): inf, (2, 2): 0, (3, 2): inf, (4, 2): inf, (5, 2): inf, (6, 2): inf, (7, 2): inf, (8, 2): inf, (9, 2): inf, (10, 2): 0, (11, 2): 0, (12, 2): 0, (13, 2): inf, (14, 2): inf, (0, 3): inf, (1, 3): inf, (2, 3): inf, (3, 3): 0, (4, 3): inf, (5, 3): inf, (6, 3): inf, (7, 3): inf, (8, 3): inf, (9, 3): inf, (10, 3): 0, (11, 3): 0, (12, 3): 0, (13, 3): inf, (14, 3): inf, (0, 4): inf, (1, 4): inf, (2, 4): inf, (3, 4): inf, (4, 4): inf, (5, 4): inf, (6, 4): 0, (7, 4): inf, (8, 4): 0, (9, 4): inf, (10, 4): 0, (11, 4): 0, (12, 4): 0, (13, 4): inf, (14, 4): inf, (0, 5): inf, (1, 5): inf, (2, 5): inf, (3, 5): inf, (4, 5): inf, (5, 5): inf, (6, 5): inf, (7, 5): 0, (8, 5): inf, (9, 5): 0, (10, 5): 0, (11, 5): 0, (12, 5): 0, (13, 5): inf, (14, 5): inf, (0, 6): inf, (1, 6): inf, (2, 6): inf, (3, 6): inf, (4, 6): inf, (5, 6): inf, (6, 6): inf, (7, 6): inf, (8, 6): inf, (9, 6): inf, (10, 6): 0, (11, 6): 0, (12, 6): 0, (13, 6): inf, (14, 6): inf, (0, 7): inf, (1, 7): inf, (2, 7): 0, (3, 7): inf, (4, 7): inf, (5, 7): inf, (6, 7): inf, (7, 7): inf, (8, 7): inf, (9, 7): inf, (10, 7): 0, (11, 7): 0, (12, 7): 0, (13, 7): 0, (14, 7): inf, (0, 8): inf, (1, 8): inf, (2, 8): inf, (3, 8): inf, (4, 8): inf, (5, 8): inf, (6, 8): 0, (7, 8): inf, (8, 8): 0, (9, 8): inf, (10, 8): 0, (11, 8): 0, (12, 8): 0, (13, 8): inf, (14, 8): 0}
# (truediv, relu): {(0, 0): 0, (1, 0): inf, (2, 0): 0, (3, 0): inf, (4, 0): inf, (5, 0): 0, (6, 0): 0, (7, 0): inf, (8, 0): inf, (0, 1): inf, (1, 1): 0, (2, 1): inf, (3, 1): 0, (4, 1): 0, (5, 1): inf, (6, 1): 0, (7, 1): inf, (8, 1): inf, (0, 2): inf, (1, 2): inf, (2, 2): 0, (3, 2): inf, (4, 2): inf, (5, 2): inf, (6, 2): 0, (7, 2): inf, (8, 2): inf, (0, 3): inf, (1, 3): inf, (2, 3): inf, (3, 3): 0, (4, 3): inf, (5, 3): inf, (6, 3): 0, (7, 3): inf, (8, 3): inf, (0, 4): inf, (1, 4): inf, (2, 4): inf, (3, 4): inf, (4, 4): 0, (5, 4): inf, (6, 4): 0, (7, 4): inf, (8, 4): inf, (0, 5): inf, (1, 5): inf, (2, 5): inf, (3, 5): inf, (4, 5): inf, (5, 5): 0, (6, 5): 0, (7, 5): inf, (8, 5): inf, (0, 6): inf, (1, 6): inf, (2, 6): inf, (3, 6): inf, (4, 6): inf, (5, 6): inf, (6, 6): 0, (7, 6): inf, (8, 6): inf, (0, 7): inf, (1, 7): inf, (2, 7): 0, (3, 7): inf, (4, 7): inf, (5, 7): inf, (6, 7): 0, (7, 7): 0, (8, 7): inf, (0, 8): inf, (1, 8): inf, (2, 8): inf, (3, 8): inf, (4, 8): 0, (5, 8): inf, (6, 8): 0, (7, 8): inf, (8, 8): 0}
# (relu, output): {(0, 0): 246019.30000000002, (1, 0): 246019.30000000002, (2, 0): 123009.1, (3, 0): 123009.1, (4, 0): 123009.1, (5, 0): 123009.1, (6, 0): 0, (7, 0): 246019.30000000002, (8, 0): 246019.30000000002}
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
# construct all node pairs
all_node_pairs = []
for node in graph.nodes:
if node.op == 'output':
continue
all_node_pairs.append((node, node.next))
for node_pair in all_node_pairs:
assert node_pair in cost_graph.edge_costs
# construct merged node pairs
merged_node_pairs = []
node_list = list(graph.nodes)
# add (x, conv) and (conv, output) into check node pairs
merged_node_pairs.append((node_list[0], node_list[2]))
merged_node_pairs.append((node_list[2], node_list[-1]))
# (x, conv1): {(0, 0): 0, (0, 1): 0, (0, 2): 0, (0, 3): 0, (0, 4): 0, (0, 5): 0, (0, 6): 0, (0, 7): 0, (0, 8): 0, (0, 9): 0, (0, 10): 0, (0, 11): 0, (0, 12): 0, (0, 13): 0, (0, 14): 0}
# (conv1, output): {(0, 0): inf, (1, 0): inf, (2, 0): inf, (3, 0): inf, (4, 0): inf, (5, 0): inf, (6, 0): inf, (7, 0): inf, (8, 0): inf, (9, 0): inf, (10, 0): 0, (11, 0): 0, (12, 0): 0, (13, 0): inf, (14, 0): inf}
cost_graph.simplify_graph()
for node_pair in all_node_pairs:
if node_pair in merged_node_pairs:
assert node_pair in cost_graph.edge_costs
else:
assert node_pair not in cost_graph.edge_costs
if __name__ == '__main__':
test_cost_graph()

View File

@ -0,0 +1,98 @@
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest
from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.solver.conv_handler import ConvHandler, CONV_STRATEGIES_LIST
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from copy import deepcopy
class ConvModel(nn.Module):
def __init__(self, c_in, c_out):
super().__init__()
self.conv = nn.Conv2d(c_in, c_out, kernel_size=3)
def forward(self, x):
x = x * 2
x = self.conv(x)
return x
def test_strategies_constructor():
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
entire_shape = torch.Size((4, 16, 64, 64))
shape_consistency_manager = ShapeConsistencyManager()
tracer = ColoTracer()
model = ConvModel(16, 32)
input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
# %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {})
# return conv
graph = tracer.trace(root=model, meta_args=input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
solver_options = {'fast_mode': True}
strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
assert strategies_constructor.leaf_strategies == []
assert strategies_constructor.strategy_map == {}
strategies_constructor.build_strategies_and_cost()
# check leaf_strategies
# In fast mode, placeholder node only has replica strategy.
assert strategies_constructor.leaf_strategies[0][0].name == 'Replica Placeholder'
# Second node is mul which is a element-wise node, therefore the output sharding spec is same as input sharding spec.
assert strategies_constructor.leaf_strategies[1][0].name == '[R, R, R, R] -> [R, R, R, R]'
# Third node is conv.
conv_check_list = deepcopy(CONV_STRATEGIES_LIST)
for strategy in strategies_constructor.leaf_strategies[2]:
conv_check_list.remove(strategy.name)
assert len(conv_check_list) == 0
# In fast mode, output node only has replica strategy.
assert strategies_constructor.leaf_strategies[3][0].name == 'Replica Output'
# check strategy_map
nodes = [node for node in graph.nodes]
# In fast mode, placeholder node only has replica strategy.
x = nodes[0]
assert strategies_constructor.strategy_map[x][0].name == 'Replica Placeholder'
# Second node is mul which is a element-wise node, therefore the output sharding spec is same as input sharding spec.
mul = nodes[1]
assert strategies_constructor.strategy_map[mul][0].name == '[R, R, R, R] -> [R, R, R, R]'
# Third node is conv.
conv = nodes[2]
conv_check_list = deepcopy(CONV_STRATEGIES_LIST)
for strategy in strategies_constructor.strategy_map[conv]:
conv_check_list.remove(strategy.name)
assert len(conv_check_list) == 0
# In fast mode, output node only has replica strategy.
output = nodes[3]
assert strategies_constructor.strategy_map[output][0].name == 'Replica Output'
if __name__ == '__main__':
test_strategies_constructor()