mirror of https://github.com/hpcaitech/ColossalAI
[autoparellel]add strategies constructor (#1505)
* [autoparellel]add strategies constructor * remove duplicated strategies * polish code * adapt cost graph with StrategiesConstructor * polishpull/1522/head
parent
a0436a62ee
commit
3345c6d352
|
@ -0,0 +1,22 @@
|
|||
import torch
|
||||
import operator
|
||||
|
||||
__all__ = [
|
||||
'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP', 'LINEAR_MODULE_OP',
|
||||
'LINEAR_FUNC_OP'
|
||||
]
|
||||
|
||||
ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
|
||||
ELEMENTWISE_FUNC_OP = [
|
||||
torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv,
|
||||
operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout
|
||||
]
|
||||
CONV_MODULE_OP = [
|
||||
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
|
||||
torch.nn.ConvTranspose3d
|
||||
]
|
||||
CONV_FUNC_OP = [
|
||||
torch.conv1d, torch.conv2d, torch.conv3d, torch.conv_transpose1d, torch.conv_transpose2d, torch.conv_transpose3d
|
||||
]
|
||||
LINEAR_MODULE_OP = [torch.nn.Linear]
|
||||
LINEAR_FUNC_OP = [torch.nn.functional.linear, torch.matmul, torch.bmm]
|
|
@ -494,3 +494,10 @@ class ConvHandler(OperatorHandler):
|
|||
self.split_1d_parallel_on_in_channel(0, 1)
|
||||
|
||||
return self.strategies_vector
|
||||
|
||||
|
||||
CONV_STRATEGIES_LIST = [
|
||||
'S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0R x RR', 'S1R = S1R x RR', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R',
|
||||
'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RR = RS0 x S0R', 'RR = RS1 x S1R', 'RS0 = RR x RS0', 'RS1 = RR x RS1',
|
||||
'RR = RR x RR', 'S01R = S01R x RR', 'RR = RS01 x S01R'
|
||||
]
|
||||
|
|
|
@ -39,11 +39,11 @@ class CostGraph:
|
|||
dst_node = strategies_vector.node
|
||||
for src_node in strategies_vector.predecessor_nodes:
|
||||
node_pair = (src_node, dst_node)
|
||||
src_index = strategies_vector.predecessor_nodes.index(src_node)
|
||||
# src_index = strategies_vector.predecessor_nodes.index(src_node)
|
||||
edge_cost = {}
|
||||
for i in range(len(strategies_vector)):
|
||||
for j in range(len(src_node.stategy_vector)):
|
||||
edge_cost[(i, j)] = strategies_vector[i].resharding_costs[src_index][j]
|
||||
for j in range(len(src_node.strategies_vector)):
|
||||
edge_cost[(j, i)] = strategies_vector[i].resharding_costs[src_node][j]
|
||||
self.edge_costs[node_pair] = edge_cost
|
||||
# add parents and children attribute to node
|
||||
setattr(dst_node, 'parents', strategies_vector.predecessor_nodes)
|
||||
|
@ -83,33 +83,19 @@ class CostGraph:
|
|||
merge_map = {}
|
||||
for dst_strate_index, strategy in enumerate(dst_node.strategies_vector):
|
||||
resharding_costs = strategy.resharding_costs
|
||||
resharding_cost_for_src = resharding_costs[src_node_index]
|
||||
resharding_cost_for_src = resharding_costs[src_node]
|
||||
lowest_cost_index = resharding_cost_for_src.index(min(resharding_cost_for_src))
|
||||
merge_map[dst_strate_index] = lowest_cost_index
|
||||
|
||||
# extra_node_cost for dst node
|
||||
extra_node_costs[dst_node] = [0.0 for _ in range(self.node_lens[dst_node])]
|
||||
self.extra_node_costs[dst_node] = [0.0 for _ in range(self.node_lens[dst_node])]
|
||||
for dst_strate_index, strategy in enumerate(dst_node.strategies_vector):
|
||||
target_strate_index = merge_map[dst_strate_index]
|
||||
extra_node_costs[dst_node][dst_strate_index] += strategy.resharding_costs[src_node_index][
|
||||
self.extra_node_costs[dst_node][dst_strate_index] += strategy.resharding_costs[src_node][
|
||||
target_strate_index]
|
||||
if src_node in extra_node_costs:
|
||||
extra_node_costs[dst_node][dst_strate_index] += extra_node_costs[src_node][target_strate_index]
|
||||
|
||||
# connect dst node and parents of src node
|
||||
dst_node.parents.remove(src_node)
|
||||
src_node.children.remove(dst_node)
|
||||
node_pair_to_remove = [(src_node, dst_node)]
|
||||
for parent_node in src_node.parents:
|
||||
if parent_node not in dst_node.parents:
|
||||
dst_node.parents.append(parent)
|
||||
if dst_node not in parent_node.children:
|
||||
parent_node.children.append(dst_node)
|
||||
# remove src node from cost graph when src node has no consumer.
|
||||
if len(src_node.children) == 0:
|
||||
parent_node.children.remove(src_node)
|
||||
node_pair = (parent_node, src_node)
|
||||
self.edge_costs.pop(node_pair)
|
||||
if src_node in self.extra_node_costs:
|
||||
self.extra_node_costs[dst_node][dst_strate_index] += self.extra_node_costs[src_node][
|
||||
target_strate_index]
|
||||
|
||||
# add new node pair to cost graph
|
||||
for parent_node in src_node.parents:
|
||||
|
@ -121,9 +107,24 @@ class CostGraph:
|
|||
for i in range(self.node_lens[dst_node]):
|
||||
for j in range(self.node_lens[parent_node]):
|
||||
src_strate_index = merge_map[i]
|
||||
edge_cost[(i, j)] = self.edge_costs[old_node_pair][(j, src_strate_index)]
|
||||
edge_cost[(j, i)] = self.edge_costs[old_node_pair][(j, src_strate_index)]
|
||||
self.edge_costs[new_node_pair] = edge_cost
|
||||
|
||||
# connect dst node and parents of src node
|
||||
dst_node.parents.remove(src_node)
|
||||
src_node.children.remove(dst_node)
|
||||
self.edge_costs.pop((src_node, dst_node))
|
||||
for parent_node in src_node.parents:
|
||||
if parent_node not in dst_node.parents:
|
||||
dst_node.parents.append(parent_node)
|
||||
if dst_node not in parent_node.children:
|
||||
parent_node.children.append(dst_node)
|
||||
# remove src node from cost graph when src node has no consumer.
|
||||
if len(src_node.children) == 0:
|
||||
parent_node.children.remove(src_node)
|
||||
node_pair = (parent_node, src_node)
|
||||
self.edge_costs.pop(node_pair)
|
||||
|
||||
def simplify_graph(self):
|
||||
if not self.simplify:
|
||||
return
|
||||
|
|
|
@ -3,7 +3,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
from abc import ABC, abstractmethod
|
||||
from torch.fx.node import Node
|
||||
from typing import Dict
|
||||
from typing import Dict, List
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
@ -56,7 +56,7 @@ class OperatorHandler(ABC):
|
|||
"""
|
||||
pass
|
||||
|
||||
def _generate_sharding_spec(self, tensor: torch.Tensor, dim_partition_dict: Dict[int, int]) -> ShardingSpec:
|
||||
def _generate_sharding_spec(self, tensor: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
|
||||
"""
|
||||
Generate the sharding spec of the tensor based on the given dim_partition_dict
|
||||
where the key is the tensor dimension and the value is the mesh dimension for sharding.
|
||||
|
@ -84,7 +84,9 @@ class OperatorHandler(ABC):
|
|||
for input_node, input_spec in zip(self.predecessor_node, sharding_spec_for_input):
|
||||
resharding_costs[input_node] = []
|
||||
for strategy in input_node.strategies_vector:
|
||||
input_sharding_spec = strategy.output_sharding_spec
|
||||
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
|
||||
_, _, resharding_cost = self.shape_consistency_manager.shape_consistency(
|
||||
strategy.output_sharding_spec, input_spec)
|
||||
input_sharding_spec, input_spec)
|
||||
resharding_costs[input_node].append(resharding_cost)
|
||||
return resharding_costs
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
from dataclasses import dataclass
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Union, Tuple
|
||||
from torch.fx.node import Node
|
||||
from .constants import *
|
||||
|
||||
__all__ = ['ShardingStrategy', 'StrategiesVector']
|
||||
|
||||
|
@ -25,12 +26,15 @@ class ShardingStrategy:
|
|||
'''
|
||||
|
||||
name: str
|
||||
output_sharding_spec: ShardingSpec
|
||||
# TODO: output of fx node,such as torch.var_mean, could be a tuple, so we cannot simply suppose it is a tensor.
|
||||
output_sharding_spec: Union[ShardingSpec, Tuple[ShardingSpec]]
|
||||
compute_cost: float = 0.
|
||||
communication_cost: float = 0.
|
||||
memory_cost: float = 0.
|
||||
resharding_costs: Dict[int, List[float]] = None
|
||||
input_shardings: ShardingSpec = None
|
||||
resharding_costs: Dict[Node, List[float]] = None
|
||||
# sometimes the input node could be a tuple of nodes, but most of op won't accept tuple of node as input.
|
||||
# Therefore, we could process them at the specific op(operator.getitem)
|
||||
input_shardings: List[ShardingSpec] = None
|
||||
|
||||
|
||||
class StrategiesVector(list):
|
||||
|
@ -46,8 +50,23 @@ class StrategiesVector(list):
|
|||
super().__init__()
|
||||
self.node = node
|
||||
# fetch its input and output nodes
|
||||
# TODO: placeholder input nodes
|
||||
self.predecessor_nodes = list(node._input_nodes.keys())
|
||||
self.successor_nodes = list(node.users.keys())
|
||||
|
||||
def check_merge(self):
|
||||
pass
|
||||
merge_label = False
|
||||
if self.node.op == 'call_module':
|
||||
target = self.node.target
|
||||
root_module = self.node.graph.owning_module
|
||||
submod = root_module.get_submodule(target)
|
||||
submod_type = type(submod)
|
||||
# merge elementwise module node into following nodes
|
||||
if submod_type in ELEMENTWISE_MODULE_OP:
|
||||
merge_label = True
|
||||
|
||||
if self.node.op == 'call_function':
|
||||
if self.node.target in ELEMENTWISE_FUNC_OP:
|
||||
merge_label = True
|
||||
|
||||
return merge_label
|
||||
|
|
|
@ -0,0 +1,355 @@
|
|||
from torch.fx import Graph, Node
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
from .sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from .conv_handler import ConvHandler
|
||||
from .constants import *
|
||||
from copy import deepcopy
|
||||
import math
|
||||
import torch
|
||||
import operator
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
class StrategiesConstructor:
|
||||
|
||||
def __init__(self, graph, device_mesh, shape_consistency_manager, solver_options):
|
||||
self.graph = graph
|
||||
self.root_module = self.graph.owning_module
|
||||
self.nodes = list(graph.nodes)
|
||||
self.device_mesh = device_mesh
|
||||
self.leaf_strategies = []
|
||||
self.strategy_map = {}
|
||||
self.shape_consistency_manager = shape_consistency_manager
|
||||
self.solver_options = solver_options
|
||||
|
||||
def _generate_sharding_spec(self, node: Node, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
|
||||
"""
|
||||
Generate the sharding spec of the tensor based on the given dim_partition_dict
|
||||
where the key is the tensor dimension and the value is the mesh dimension for sharding.
|
||||
"""
|
||||
meta_tensor = node._meta_data
|
||||
sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
|
||||
entire_shape=meta_tensor.shape,
|
||||
dim_partition_dict=dim_partition_dict)
|
||||
return sharding_spec
|
||||
|
||||
def _generate_resharding_costs(self, input_nodes, target_sharding_specs):
|
||||
'''
|
||||
Compute the resharding costs with this specific strategy.
|
||||
|
||||
Argument:
|
||||
sharding_spec_for_input(ShardingSpec): ShardingSpec of the input node.
|
||||
'''
|
||||
resharding_costs = {}
|
||||
for input_node, target_sharding_spec in zip(input_nodes, target_sharding_specs):
|
||||
resharding_costs[input_node] = []
|
||||
for strategy in input_node.strategies_vector:
|
||||
input_sharding_spec = strategy.output_sharding_spec
|
||||
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
|
||||
_, _, resharding_cost = self.shape_consistency_manager.shape_consistency(
|
||||
input_sharding_spec, target_sharding_spec)
|
||||
resharding_costs[input_node].append(resharding_cost)
|
||||
return resharding_costs
|
||||
|
||||
def remove_duplicated_strategy(self, strategies_vector):
|
||||
'''
|
||||
In build_strategies_and_cost method, we may produce some duplicated strategies.
|
||||
In this method, we will remove the duplicated strategies depending on the strategies name.
|
||||
'''
|
||||
name_checklist = []
|
||||
remove_list = []
|
||||
for strategy in strategies_vector:
|
||||
if strategy.name not in name_checklist:
|
||||
name_checklist.append(strategy.name)
|
||||
else:
|
||||
remove_list.append(strategy)
|
||||
for strategy in remove_list:
|
||||
strategies_vector.remove(strategy)
|
||||
|
||||
def build_strategies_and_cost(self):
|
||||
for node in self.nodes:
|
||||
strategies_vector = StrategiesVector(node)
|
||||
# placeholder node
|
||||
if node.op == 'placeholder':
|
||||
# For placeholder nodes, if solver_options['fast_mode'] is True, we just let them in
|
||||
# fully replicate status, then strategies of following node will be treated equally due
|
||||
# to replicate status has no resharding cost to other status. At the same time, the searching
|
||||
# space is smaller than enumerating all the possible sharding spec for the placeholder node.
|
||||
# Otherwise, all the possible sharding spec for the placeholder node will be enumerated.
|
||||
|
||||
if self.solver_options['fast_mode']:
|
||||
# create sharding strategy for placeholder
|
||||
name = 'Replica Placeholder'
|
||||
dim_partition_dict = {}
|
||||
output_sharding_spec = self._generate_sharding_spec(node, dim_partition_dict)
|
||||
# TODO: use meta_info_prop to profile memory cost
|
||||
memory_cost = 0
|
||||
sharding_strategy_placeholder = ShardingStrategy(name,
|
||||
output_sharding_spec,
|
||||
memory_cost=memory_cost)
|
||||
strategies_vector.append(sharding_strategy_placeholder)
|
||||
|
||||
# get_attr node
|
||||
if node.op == 'get_attr':
|
||||
# Same as placeholder nodes, if solver_options['fast_mode'] is True, we just let them in
|
||||
# fully replicate status, then strategies of following node will be treated equally due
|
||||
# to replicate status has no resharding cost to other status. At the same time, the searching
|
||||
# space is smaller than enumerating all the possible sharding spec for the get_attr node.
|
||||
# Otherwise, all the possible sharding spec for the get_attr node will be enumerated.
|
||||
if self.solver_options['fast_mode']:
|
||||
# create sharding strategy for get_attr
|
||||
name = 'Replica Attribute'
|
||||
dim_partition_dict = {}
|
||||
output_sharding_spec = self._generate_sharding_spec(node, dim_partition_dict)
|
||||
# TODO: use meta_info_prop to profile memory cost
|
||||
memory_cost = 0
|
||||
sharding_strategy_attribute = ShardingStrategy(name, output_sharding_spec, memory_cost=memory_cost)
|
||||
strategies_vector.append(sharding_strategy_attribute)
|
||||
|
||||
# call_module node
|
||||
if node.op == 'call_module':
|
||||
|
||||
target = node.target
|
||||
submod = self.root_module.get_submodule(target)
|
||||
submod_type = type(submod)
|
||||
|
||||
# conv module
|
||||
if submod_type in CONV_MODULE_OP:
|
||||
# use ConvHandler to create sharding strategies for conv module node
|
||||
conv_handler = ConvHandler(node, self.device_mesh, strategies_vector,
|
||||
self.shape_consistency_manager)
|
||||
conv_handler.register_strategy()
|
||||
|
||||
# linear module
|
||||
elif submod_type in LINEAR_MODULE_OP:
|
||||
# use DotHandler to create sharding strategies for linear module node
|
||||
dot_handler = DotHandler(node, self.device_mesh, strategies_vector, self.shape_consistency_manager)
|
||||
dot_handler.register_strategy()
|
||||
|
||||
# element-wise module
|
||||
elif submod_type in ELEMENTWISE_MODULE_OP:
|
||||
# create sharding strategy for element-wise module
|
||||
assert len(strategies_vector.predecessor_nodes
|
||||
) == 1, f'Temporally, we just support single input element-wise op.'
|
||||
input_node = strategies_vector.predecessor_nodes[0]
|
||||
# For element-wise module, we keep the sharding spec of output node same as
|
||||
# the input. Therefore, the different strategies of input node with same
|
||||
# output sharding spec will generate same strategy for element-wise module.
|
||||
sharding_spec_checklist = []
|
||||
for strategy in input_node.strategies_vector:
|
||||
# It looks a little bit confusing, the input of the processing node
|
||||
# is the output of the input_node.
|
||||
input_sharding_spec = strategy.output_sharding_spec
|
||||
assert isinstance(input_sharding_spec,
|
||||
ShardingSpec), f'The input node should NOT be a tuple of tensor.'
|
||||
if input_sharding_spec in sharding_spec_checklist:
|
||||
continue
|
||||
|
||||
sharding_spec_checklist.append(input_sharding_spec)
|
||||
dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
|
||||
output_sharding_spec = self._generate_sharding_spec(node, dim_partition_dict)
|
||||
|
||||
name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}'
|
||||
|
||||
# TODO: use meta_info_prop to profile memory cost and compute cost
|
||||
compute_cost = node._meta_data.numel()
|
||||
memory_cost = 0
|
||||
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||
[input_sharding_spec])
|
||||
|
||||
# to prevent the resharding happening, set their resharding cost to inf.
|
||||
resharding_costs[input_node] = [
|
||||
cost if cost == 0 else math.inf for cost in resharding_costs[input_node]
|
||||
]
|
||||
sharding_strategy = ShardingStrategy(name,
|
||||
output_sharding_spec,
|
||||
compute_cost=compute_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=[input_sharding_spec])
|
||||
strategies_vector.append(sharding_strategy)
|
||||
|
||||
# other module
|
||||
else:
|
||||
raise RuntimeError(f'{submod_type} module is NOT supported now.')
|
||||
|
||||
# call_function node
|
||||
if node.op == 'call_function':
|
||||
target = node.target
|
||||
# conv function
|
||||
if target in CONV_FUNC_OP:
|
||||
# use ConvHandler to create sharding strategies for conv node
|
||||
# TODO: the operator_handler does NOT support function node processing now.
|
||||
conv_handler = ConvHandler(node, self.device_mesh, strategies_vector,
|
||||
self.shape_consistency_manager)
|
||||
conv_handler.register_strategy()
|
||||
|
||||
# linear function
|
||||
elif target in LINEAR_FUNC_OP:
|
||||
# use DotHandler to create sharding strategies for linear node
|
||||
# TODO: the operator_handler does NOT support function node processing now.
|
||||
linear_handler = DotHandler(node, self.device_mesh, strategies_vector,
|
||||
self.shape_consistency_manager)
|
||||
linear_handler.register_strategy()
|
||||
|
||||
# element-wise function
|
||||
elif target in ELEMENTWISE_FUNC_OP:
|
||||
# TODO: integrate element-wise func and module together
|
||||
# create sharding strategy for element-wise function
|
||||
assert len(strategies_vector.predecessor_nodes
|
||||
) == 1, f'Temporally, we just support single input element-wise op.'
|
||||
input_node = strategies_vector.predecessor_nodes[0]
|
||||
# For element-wise function, we keep the sharding spec of output node same as
|
||||
# the input. Therefore, the different strategies of input node with same
|
||||
# output sharding spec will generate same strategy for element-wise function.
|
||||
sharding_spec_checklist = []
|
||||
for strategy in input_node.strategies_vector:
|
||||
# It looks a little bit confusing, the input of the processing node
|
||||
# is the output of the input_node.
|
||||
input_sharding_spec = strategy.output_sharding_spec
|
||||
assert isinstance(input_sharding_spec,
|
||||
ShardingSpec), f'The input node should NOT be a tuple of tensor.'
|
||||
if input_sharding_spec in sharding_spec_checklist:
|
||||
continue
|
||||
sharding_spec_checklist.append(input_sharding_spec)
|
||||
dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
|
||||
output_sharding_spec = self._generate_sharding_spec(node, dim_partition_dict)
|
||||
name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}'
|
||||
# TODO: use meta_info_prop to profile memory cost and compute cost
|
||||
compute_cost = node._meta_data.numel()
|
||||
memory_cost = 0
|
||||
|
||||
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||
[input_sharding_spec])
|
||||
|
||||
# to prevent the resharding happening, set their resharding cost to inf.
|
||||
resharding_costs[input_node] = [
|
||||
0 if cost == 0 else math.inf for cost in resharding_costs[input_node]
|
||||
]
|
||||
sharding_strategy = ShardingStrategy(name,
|
||||
output_sharding_spec,
|
||||
compute_cost=compute_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=[input_sharding_spec])
|
||||
strategies_vector.append(sharding_strategy)
|
||||
|
||||
# torch.var_mean
|
||||
elif target == torch.var_mean:
|
||||
dim = node.kwargs['dim']
|
||||
input_tensor_node = strategies_vector.predecessor_nodes[0]
|
||||
for strategy in input_tensor_node.strategies_vector:
|
||||
input_sharding_spec = strategy.output_sharding_spec
|
||||
assert isinstance(input_sharding_spec,
|
||||
ShardingSpec), f'The input node should NOT be a tuple of tensor.'
|
||||
entire_shape_input = input_sharding_spec.entire_shape
|
||||
dim_partition_dict_input = input_sharding_spec.dim_partition_dict
|
||||
name = f'{new_input_sharding_spec.sharding_sequence} -> ({output_sharding_spec.sharding_sequence}, {output_sharding_spec.sharding_sequence})'
|
||||
if dim in dim_partition_dict_input:
|
||||
# We need to make the action dimension in replicate status
|
||||
dim_partition_dict_for_input = deepcopy(dim_partition_dict_input)
|
||||
dim_partition_dict_for_input.pop(dim)
|
||||
new_input_sharding_spec = ShardingSpec(self.device_mesh,
|
||||
entire_shape_input,
|
||||
dim_partition_dict=dim_partition_dict_for_input)
|
||||
entire_shape_output = deepcopy(entire_shape_input)
|
||||
entire_shape_output.pop(dim)
|
||||
dim_partition_dict_for_output = deepcopy(dim_partition_dict_for_input)
|
||||
output_sharding_spec = ShardingSpec(self.device_mesh,
|
||||
entire_shape_output,
|
||||
dim_partition_dict=dim_partition_dict_for_input)
|
||||
# TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
|
||||
compute_cost = 0
|
||||
memory_cost = 0
|
||||
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||
[new_input_sharding_spec])
|
||||
sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec),
|
||||
compute_cost=compute_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=[new_input_sharding_spec])
|
||||
|
||||
else:
|
||||
entire_shape_output = deepcopy(entire_shape_input)
|
||||
entire_shape_output.pop(dim)
|
||||
dim_partition_dict_for_output = deepcopy(dim_partition_dict_input)
|
||||
output_sharding_spec = ShardingSpec(self.device_mesh,
|
||||
entire_shape_output,
|
||||
dim_partion_dict=dim_partition_dict_input)
|
||||
# TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
|
||||
compute_cost = 0
|
||||
memory_cost = 0
|
||||
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||
[input_sharding_spec])
|
||||
sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec),
|
||||
compute_cost=compute_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=[input_sharding_spec])
|
||||
|
||||
strategies_vector.append(sharding_strategy)
|
||||
|
||||
# operator.getitem
|
||||
elif target == operator.getitem:
|
||||
index = node.args[1]
|
||||
input_tensor_node = strategies_vector.predecessor_nodes[0]
|
||||
for strategy in input_tensor_node.strategies_vector:
|
||||
input_sharding_spec = input_tensor_node.output_sharding_spec[index]
|
||||
assert isinstance(input_sharding_spec, ShardingSpec), f'This assertion is used to debug.'
|
||||
dim_partition_dict_for_output = deepcopy(input_sharding_spec.dim_partition_dict)
|
||||
entire_shape_output = deepcopy(input_sharding_spec.entire_shape)
|
||||
output_sharding_spec = ShardingSpec(self.device_mesh,
|
||||
entire_shape_output,
|
||||
dim_partition_dict=dim_partition_dict_for_output)
|
||||
# TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
|
||||
compute_cost = 0
|
||||
memory_cost = 0
|
||||
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||
[input_sharding_spec])
|
||||
# to prevent the resharding happening, set their resharding cost to inf.
|
||||
resharding_costs[input_tensor_node] = [
|
||||
cost if cost == 0 else math.inf for cost in resharding_costs[input_tensor_node]
|
||||
]
|
||||
sharding_strategy = ShardingStrategy(name,
|
||||
output_sharding_spec,
|
||||
compute_cost=compute_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=[input_tensor_node.output_sharding_spec])
|
||||
strategies_vector.append(sharding_strategy)
|
||||
|
||||
# other function
|
||||
else:
|
||||
raise RuntimeError(f'{target} function is NOT supported now.')
|
||||
|
||||
# output node
|
||||
if node.op == 'output':
|
||||
if self.solver_options['fast_mode']:
|
||||
# create sharding strategy for output
|
||||
name = 'Replica Output'
|
||||
input_nodes = strategies_vector.predecessor_nodes
|
||||
input_sharding_specs = []
|
||||
for input_node in input_nodes:
|
||||
dim_partition_dict_for_input = {}
|
||||
entire_shape = input_node._meta_data.shape
|
||||
sharding_spec = ShardingSpec(self.device_mesh,
|
||||
entire_shape,
|
||||
dim_partition_dict=dim_partition_dict_for_input)
|
||||
input_sharding_specs.append(sharding_spec)
|
||||
|
||||
dim_partition_dict = {}
|
||||
output_sharding_spec = input_sharding_specs
|
||||
# TODO: use meta_info_prop to profile memory cost
|
||||
memory_cost = 0
|
||||
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||
input_sharding_specs)
|
||||
sharding_strategy_attribute = ShardingStrategy(name,
|
||||
output_sharding_spec,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs)
|
||||
strategies_vector.append(sharding_strategy_attribute)
|
||||
|
||||
self.remove_duplicated_strategy(strategies_vector)
|
||||
setattr(node, 'strategies_vector', strategies_vector)
|
||||
self.leaf_strategies.append(strategies_vector)
|
||||
self.strategy_map[node] = strategies_vector
|
|
@ -0,0 +1,97 @@
|
|||
import torch
|
||||
from torch.fx import GraphModule
|
||||
import torch.nn as nn
|
||||
import pytest
|
||||
|
||||
from colossalai.fx.proxy import ColoProxy
|
||||
from colossalai.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
|
||||
from colossalai.auto_parallel.solver.conv_handler import ConvHandler, CONV_STRATEGIES_LIST
|
||||
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
|
||||
from colossalai.auto_parallel.solver.cost_graph import CostGraph
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
class ConvModel(nn.Module):
|
||||
|
||||
def __init__(self, c_in, c_out):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=3)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def forward(self, x):
|
||||
x = x * 2
|
||||
x = self.conv1(x)
|
||||
x = x / 2
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
|
||||
def test_cost_graph():
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
# [[0, 1]
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
entire_shape = torch.Size((4, 16, 64, 64))
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
|
||||
tracer = ColoTracer()
|
||||
model = ConvModel(16, 32)
|
||||
input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
|
||||
|
||||
# graph():
|
||||
# %x : torch.Tensor [#users=1] = placeholder[target=x]
|
||||
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
|
||||
# %conv1 : [#users=1] = call_module[target=conv1](args = (%mul,), kwargs = {})
|
||||
# %truediv : [#users=1] = call_function[target=operator.truediv](args = (%conv1, 2), kwargs = {})
|
||||
# %relu : [#users=1] = call_module[target=relu](args = (%truediv,), kwargs = {})
|
||||
# return relu
|
||||
graph = tracer.trace(root=model, meta_args=input_sample)
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
gm.recompile()
|
||||
|
||||
solver_options = {'fast_mode': True}
|
||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
|
||||
strategies_constructor.build_strategies_and_cost()
|
||||
|
||||
# (x, mul): {(0, 0): 0}
|
||||
# (mul, conv1): {(0, 0): 0, (0, 1): 0, (0, 2): 0, (0, 3): 0, (0, 4): 0, (0, 5): 0, (0, 6): 0, (0, 7): 0, (0, 8): 0, (0, 9): 0, (0, 10): 0, (0, 11): 0, (0, 12): 0, (0, 13): 0, (0, 14): 0}
|
||||
# (conv1, truediv): {(0, 0): 0, (1, 0): inf, (2, 0): 0, (3, 0): inf, (4, 0): 0, (5, 0): inf, (6, 0): inf, (7, 0): 0, (8, 0): inf, (9, 0): 0, (10, 0): 0, (11, 0): 0, (12, 0): 0, (13, 0): inf, (14, 0): inf, (0, 1): inf, (1, 1): 0, (2, 1): inf, (3, 1): 0, (4, 1): inf, (5, 1): 0, (6, 1): 0, (7, 1): inf, (8, 1): 0, (9, 1): inf, (10, 1): 0, (11, 1): 0, (12, 1): 0, (13, 1): inf, (14, 1): inf, (0, 2): inf, (1, 2): inf, (2, 2): 0, (3, 2): inf, (4, 2): inf, (5, 2): inf, (6, 2): inf, (7, 2): inf, (8, 2): inf, (9, 2): inf, (10, 2): 0, (11, 2): 0, (12, 2): 0, (13, 2): inf, (14, 2): inf, (0, 3): inf, (1, 3): inf, (2, 3): inf, (3, 3): 0, (4, 3): inf, (5, 3): inf, (6, 3): inf, (7, 3): inf, (8, 3): inf, (9, 3): inf, (10, 3): 0, (11, 3): 0, (12, 3): 0, (13, 3): inf, (14, 3): inf, (0, 4): inf, (1, 4): inf, (2, 4): inf, (3, 4): inf, (4, 4): inf, (5, 4): inf, (6, 4): 0, (7, 4): inf, (8, 4): 0, (9, 4): inf, (10, 4): 0, (11, 4): 0, (12, 4): 0, (13, 4): inf, (14, 4): inf, (0, 5): inf, (1, 5): inf, (2, 5): inf, (3, 5): inf, (4, 5): inf, (5, 5): inf, (6, 5): inf, (7, 5): 0, (8, 5): inf, (9, 5): 0, (10, 5): 0, (11, 5): 0, (12, 5): 0, (13, 5): inf, (14, 5): inf, (0, 6): inf, (1, 6): inf, (2, 6): inf, (3, 6): inf, (4, 6): inf, (5, 6): inf, (6, 6): inf, (7, 6): inf, (8, 6): inf, (9, 6): inf, (10, 6): 0, (11, 6): 0, (12, 6): 0, (13, 6): inf, (14, 6): inf, (0, 7): inf, (1, 7): inf, (2, 7): 0, (3, 7): inf, (4, 7): inf, (5, 7): inf, (6, 7): inf, (7, 7): inf, (8, 7): inf, (9, 7): inf, (10, 7): 0, (11, 7): 0, (12, 7): 0, (13, 7): 0, (14, 7): inf, (0, 8): inf, (1, 8): inf, (2, 8): inf, (3, 8): inf, (4, 8): inf, (5, 8): inf, (6, 8): 0, (7, 8): inf, (8, 8): 0, (9, 8): inf, (10, 8): 0, (11, 8): 0, (12, 8): 0, (13, 8): inf, (14, 8): 0}
|
||||
# (truediv, relu): {(0, 0): 0, (1, 0): inf, (2, 0): 0, (3, 0): inf, (4, 0): inf, (5, 0): 0, (6, 0): 0, (7, 0): inf, (8, 0): inf, (0, 1): inf, (1, 1): 0, (2, 1): inf, (3, 1): 0, (4, 1): 0, (5, 1): inf, (6, 1): 0, (7, 1): inf, (8, 1): inf, (0, 2): inf, (1, 2): inf, (2, 2): 0, (3, 2): inf, (4, 2): inf, (5, 2): inf, (6, 2): 0, (7, 2): inf, (8, 2): inf, (0, 3): inf, (1, 3): inf, (2, 3): inf, (3, 3): 0, (4, 3): inf, (5, 3): inf, (6, 3): 0, (7, 3): inf, (8, 3): inf, (0, 4): inf, (1, 4): inf, (2, 4): inf, (3, 4): inf, (4, 4): 0, (5, 4): inf, (6, 4): 0, (7, 4): inf, (8, 4): inf, (0, 5): inf, (1, 5): inf, (2, 5): inf, (3, 5): inf, (4, 5): inf, (5, 5): 0, (6, 5): 0, (7, 5): inf, (8, 5): inf, (0, 6): inf, (1, 6): inf, (2, 6): inf, (3, 6): inf, (4, 6): inf, (5, 6): inf, (6, 6): 0, (7, 6): inf, (8, 6): inf, (0, 7): inf, (1, 7): inf, (2, 7): 0, (3, 7): inf, (4, 7): inf, (5, 7): inf, (6, 7): 0, (7, 7): 0, (8, 7): inf, (0, 8): inf, (1, 8): inf, (2, 8): inf, (3, 8): inf, (4, 8): 0, (5, 8): inf, (6, 8): 0, (7, 8): inf, (8, 8): 0}
|
||||
# (relu, output): {(0, 0): 246019.30000000002, (1, 0): 246019.30000000002, (2, 0): 123009.1, (3, 0): 123009.1, (4, 0): 123009.1, (5, 0): 123009.1, (6, 0): 0, (7, 0): 246019.30000000002, (8, 0): 246019.30000000002}
|
||||
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
|
||||
|
||||
# construct all node pairs
|
||||
all_node_pairs = []
|
||||
|
||||
for node in graph.nodes:
|
||||
if node.op == 'output':
|
||||
continue
|
||||
all_node_pairs.append((node, node.next))
|
||||
|
||||
for node_pair in all_node_pairs:
|
||||
assert node_pair in cost_graph.edge_costs
|
||||
|
||||
# construct merged node pairs
|
||||
merged_node_pairs = []
|
||||
node_list = list(graph.nodes)
|
||||
|
||||
# add (x, conv) and (conv, output) into check node pairs
|
||||
merged_node_pairs.append((node_list[0], node_list[2]))
|
||||
merged_node_pairs.append((node_list[2], node_list[-1]))
|
||||
# (x, conv1): {(0, 0): 0, (0, 1): 0, (0, 2): 0, (0, 3): 0, (0, 4): 0, (0, 5): 0, (0, 6): 0, (0, 7): 0, (0, 8): 0, (0, 9): 0, (0, 10): 0, (0, 11): 0, (0, 12): 0, (0, 13): 0, (0, 14): 0}
|
||||
# (conv1, output): {(0, 0): inf, (1, 0): inf, (2, 0): inf, (3, 0): inf, (4, 0): inf, (5, 0): inf, (6, 0): inf, (7, 0): inf, (8, 0): inf, (9, 0): inf, (10, 0): 0, (11, 0): 0, (12, 0): 0, (13, 0): inf, (14, 0): inf}
|
||||
cost_graph.simplify_graph()
|
||||
for node_pair in all_node_pairs:
|
||||
if node_pair in merged_node_pairs:
|
||||
assert node_pair in cost_graph.edge_costs
|
||||
else:
|
||||
assert node_pair not in cost_graph.edge_costs
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_cost_graph()
|
|
@ -0,0 +1,98 @@
|
|||
import torch
|
||||
from torch.fx import GraphModule
|
||||
import torch.nn as nn
|
||||
import pytest
|
||||
|
||||
from colossalai.fx.proxy import ColoProxy
|
||||
from colossalai.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
|
||||
from colossalai.auto_parallel.solver.conv_handler import ConvHandler, CONV_STRATEGIES_LIST
|
||||
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
class ConvModel(nn.Module):
|
||||
|
||||
def __init__(self, c_in, c_out):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(c_in, c_out, kernel_size=3)
|
||||
|
||||
def forward(self, x):
|
||||
x = x * 2
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
def test_strategies_constructor():
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
# [[0, 1]
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
entire_shape = torch.Size((4, 16, 64, 64))
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
|
||||
tracer = ColoTracer()
|
||||
model = ConvModel(16, 32)
|
||||
input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
|
||||
# graph():
|
||||
# %x : torch.Tensor [#users=1] = placeholder[target=x]
|
||||
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
|
||||
# %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {})
|
||||
# return conv
|
||||
graph = tracer.trace(root=model, meta_args=input_sample)
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
gm.recompile()
|
||||
|
||||
solver_options = {'fast_mode': True}
|
||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
|
||||
|
||||
assert strategies_constructor.leaf_strategies == []
|
||||
assert strategies_constructor.strategy_map == {}
|
||||
strategies_constructor.build_strategies_and_cost()
|
||||
|
||||
# check leaf_strategies
|
||||
|
||||
# In fast mode, placeholder node only has replica strategy.
|
||||
assert strategies_constructor.leaf_strategies[0][0].name == 'Replica Placeholder'
|
||||
|
||||
# Second node is mul which is a element-wise node, therefore the output sharding spec is same as input sharding spec.
|
||||
assert strategies_constructor.leaf_strategies[1][0].name == '[R, R, R, R] -> [R, R, R, R]'
|
||||
|
||||
# Third node is conv.
|
||||
conv_check_list = deepcopy(CONV_STRATEGIES_LIST)
|
||||
for strategy in strategies_constructor.leaf_strategies[2]:
|
||||
conv_check_list.remove(strategy.name)
|
||||
assert len(conv_check_list) == 0
|
||||
|
||||
# In fast mode, output node only has replica strategy.
|
||||
assert strategies_constructor.leaf_strategies[3][0].name == 'Replica Output'
|
||||
|
||||
# check strategy_map
|
||||
|
||||
nodes = [node for node in graph.nodes]
|
||||
# In fast mode, placeholder node only has replica strategy.
|
||||
x = nodes[0]
|
||||
assert strategies_constructor.strategy_map[x][0].name == 'Replica Placeholder'
|
||||
|
||||
# Second node is mul which is a element-wise node, therefore the output sharding spec is same as input sharding spec.
|
||||
mul = nodes[1]
|
||||
assert strategies_constructor.strategy_map[mul][0].name == '[R, R, R, R] -> [R, R, R, R]'
|
||||
|
||||
# Third node is conv.
|
||||
conv = nodes[2]
|
||||
conv_check_list = deepcopy(CONV_STRATEGIES_LIST)
|
||||
for strategy in strategies_constructor.strategy_map[conv]:
|
||||
conv_check_list.remove(strategy.name)
|
||||
assert len(conv_check_list) == 0
|
||||
|
||||
# In fast mode, output node only has replica strategy.
|
||||
output = nodes[3]
|
||||
assert strategies_constructor.strategy_map[output][0].name == 'Replica Output'
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_strategies_constructor()
|
Loading…
Reference in New Issue