mirror of https://github.com/hpcaitech/ColossalAI
[autoparellel]add strategies constructor (#1505)
* [autoparellel]add strategies constructor * remove duplicated strategies * polish code * adapt cost graph with StrategiesConstructor * polishpull/1522/head
parent
a0436a62ee
commit
3345c6d352
|
@ -0,0 +1,22 @@
|
||||||
|
import torch
|
||||||
|
import operator
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP', 'LINEAR_MODULE_OP',
|
||||||
|
'LINEAR_FUNC_OP'
|
||||||
|
]
|
||||||
|
|
||||||
|
ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
|
||||||
|
ELEMENTWISE_FUNC_OP = [
|
||||||
|
torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv,
|
||||||
|
operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout
|
||||||
|
]
|
||||||
|
CONV_MODULE_OP = [
|
||||||
|
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
|
||||||
|
torch.nn.ConvTranspose3d
|
||||||
|
]
|
||||||
|
CONV_FUNC_OP = [
|
||||||
|
torch.conv1d, torch.conv2d, torch.conv3d, torch.conv_transpose1d, torch.conv_transpose2d, torch.conv_transpose3d
|
||||||
|
]
|
||||||
|
LINEAR_MODULE_OP = [torch.nn.Linear]
|
||||||
|
LINEAR_FUNC_OP = [torch.nn.functional.linear, torch.matmul, torch.bmm]
|
|
@ -494,3 +494,10 @@ class ConvHandler(OperatorHandler):
|
||||||
self.split_1d_parallel_on_in_channel(0, 1)
|
self.split_1d_parallel_on_in_channel(0, 1)
|
||||||
|
|
||||||
return self.strategies_vector
|
return self.strategies_vector
|
||||||
|
|
||||||
|
|
||||||
|
CONV_STRATEGIES_LIST = [
|
||||||
|
'S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0R x RR', 'S1R = S1R x RR', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R',
|
||||||
|
'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RR = RS0 x S0R', 'RR = RS1 x S1R', 'RS0 = RR x RS0', 'RS1 = RR x RS1',
|
||||||
|
'RR = RR x RR', 'S01R = S01R x RR', 'RR = RS01 x S01R'
|
||||||
|
]
|
||||||
|
|
|
@ -39,11 +39,11 @@ class CostGraph:
|
||||||
dst_node = strategies_vector.node
|
dst_node = strategies_vector.node
|
||||||
for src_node in strategies_vector.predecessor_nodes:
|
for src_node in strategies_vector.predecessor_nodes:
|
||||||
node_pair = (src_node, dst_node)
|
node_pair = (src_node, dst_node)
|
||||||
src_index = strategies_vector.predecessor_nodes.index(src_node)
|
# src_index = strategies_vector.predecessor_nodes.index(src_node)
|
||||||
edge_cost = {}
|
edge_cost = {}
|
||||||
for i in range(len(strategies_vector)):
|
for i in range(len(strategies_vector)):
|
||||||
for j in range(len(src_node.stategy_vector)):
|
for j in range(len(src_node.strategies_vector)):
|
||||||
edge_cost[(i, j)] = strategies_vector[i].resharding_costs[src_index][j]
|
edge_cost[(j, i)] = strategies_vector[i].resharding_costs[src_node][j]
|
||||||
self.edge_costs[node_pair] = edge_cost
|
self.edge_costs[node_pair] = edge_cost
|
||||||
# add parents and children attribute to node
|
# add parents and children attribute to node
|
||||||
setattr(dst_node, 'parents', strategies_vector.predecessor_nodes)
|
setattr(dst_node, 'parents', strategies_vector.predecessor_nodes)
|
||||||
|
@ -83,33 +83,19 @@ class CostGraph:
|
||||||
merge_map = {}
|
merge_map = {}
|
||||||
for dst_strate_index, strategy in enumerate(dst_node.strategies_vector):
|
for dst_strate_index, strategy in enumerate(dst_node.strategies_vector):
|
||||||
resharding_costs = strategy.resharding_costs
|
resharding_costs = strategy.resharding_costs
|
||||||
resharding_cost_for_src = resharding_costs[src_node_index]
|
resharding_cost_for_src = resharding_costs[src_node]
|
||||||
lowest_cost_index = resharding_cost_for_src.index(min(resharding_cost_for_src))
|
lowest_cost_index = resharding_cost_for_src.index(min(resharding_cost_for_src))
|
||||||
merge_map[dst_strate_index] = lowest_cost_index
|
merge_map[dst_strate_index] = lowest_cost_index
|
||||||
|
|
||||||
# extra_node_cost for dst node
|
# extra_node_cost for dst node
|
||||||
extra_node_costs[dst_node] = [0.0 for _ in range(self.node_lens[dst_node])]
|
self.extra_node_costs[dst_node] = [0.0 for _ in range(self.node_lens[dst_node])]
|
||||||
for dst_strate_index, strategy in enumerate(dst_node.strategies_vector):
|
for dst_strate_index, strategy in enumerate(dst_node.strategies_vector):
|
||||||
target_strate_index = merge_map[dst_strate_index]
|
target_strate_index = merge_map[dst_strate_index]
|
||||||
extra_node_costs[dst_node][dst_strate_index] += strategy.resharding_costs[src_node_index][
|
self.extra_node_costs[dst_node][dst_strate_index] += strategy.resharding_costs[src_node][
|
||||||
target_strate_index]
|
target_strate_index]
|
||||||
if src_node in extra_node_costs:
|
if src_node in self.extra_node_costs:
|
||||||
extra_node_costs[dst_node][dst_strate_index] += extra_node_costs[src_node][target_strate_index]
|
self.extra_node_costs[dst_node][dst_strate_index] += self.extra_node_costs[src_node][
|
||||||
|
target_strate_index]
|
||||||
# connect dst node and parents of src node
|
|
||||||
dst_node.parents.remove(src_node)
|
|
||||||
src_node.children.remove(dst_node)
|
|
||||||
node_pair_to_remove = [(src_node, dst_node)]
|
|
||||||
for parent_node in src_node.parents:
|
|
||||||
if parent_node not in dst_node.parents:
|
|
||||||
dst_node.parents.append(parent)
|
|
||||||
if dst_node not in parent_node.children:
|
|
||||||
parent_node.children.append(dst_node)
|
|
||||||
# remove src node from cost graph when src node has no consumer.
|
|
||||||
if len(src_node.children) == 0:
|
|
||||||
parent_node.children.remove(src_node)
|
|
||||||
node_pair = (parent_node, src_node)
|
|
||||||
self.edge_costs.pop(node_pair)
|
|
||||||
|
|
||||||
# add new node pair to cost graph
|
# add new node pair to cost graph
|
||||||
for parent_node in src_node.parents:
|
for parent_node in src_node.parents:
|
||||||
|
@ -121,9 +107,24 @@ class CostGraph:
|
||||||
for i in range(self.node_lens[dst_node]):
|
for i in range(self.node_lens[dst_node]):
|
||||||
for j in range(self.node_lens[parent_node]):
|
for j in range(self.node_lens[parent_node]):
|
||||||
src_strate_index = merge_map[i]
|
src_strate_index = merge_map[i]
|
||||||
edge_cost[(i, j)] = self.edge_costs[old_node_pair][(j, src_strate_index)]
|
edge_cost[(j, i)] = self.edge_costs[old_node_pair][(j, src_strate_index)]
|
||||||
self.edge_costs[new_node_pair] = edge_cost
|
self.edge_costs[new_node_pair] = edge_cost
|
||||||
|
|
||||||
|
# connect dst node and parents of src node
|
||||||
|
dst_node.parents.remove(src_node)
|
||||||
|
src_node.children.remove(dst_node)
|
||||||
|
self.edge_costs.pop((src_node, dst_node))
|
||||||
|
for parent_node in src_node.parents:
|
||||||
|
if parent_node not in dst_node.parents:
|
||||||
|
dst_node.parents.append(parent_node)
|
||||||
|
if dst_node not in parent_node.children:
|
||||||
|
parent_node.children.append(dst_node)
|
||||||
|
# remove src node from cost graph when src node has no consumer.
|
||||||
|
if len(src_node.children) == 0:
|
||||||
|
parent_node.children.remove(src_node)
|
||||||
|
node_pair = (parent_node, src_node)
|
||||||
|
self.edge_costs.pop(node_pair)
|
||||||
|
|
||||||
def simplify_graph(self):
|
def simplify_graph(self):
|
||||||
if not self.simplify:
|
if not self.simplify:
|
||||||
return
|
return
|
||||||
|
|
|
@ -3,7 +3,7 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from torch.fx.node import Node
|
from torch.fx.node import Node
|
||||||
from typing import Dict
|
from typing import Dict, List
|
||||||
from colossalai.device.device_mesh import DeviceMesh
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||||
|
@ -56,7 +56,7 @@ class OperatorHandler(ABC):
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _generate_sharding_spec(self, tensor: torch.Tensor, dim_partition_dict: Dict[int, int]) -> ShardingSpec:
|
def _generate_sharding_spec(self, tensor: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
|
||||||
"""
|
"""
|
||||||
Generate the sharding spec of the tensor based on the given dim_partition_dict
|
Generate the sharding spec of the tensor based on the given dim_partition_dict
|
||||||
where the key is the tensor dimension and the value is the mesh dimension for sharding.
|
where the key is the tensor dimension and the value is the mesh dimension for sharding.
|
||||||
|
@ -84,7 +84,9 @@ class OperatorHandler(ABC):
|
||||||
for input_node, input_spec in zip(self.predecessor_node, sharding_spec_for_input):
|
for input_node, input_spec in zip(self.predecessor_node, sharding_spec_for_input):
|
||||||
resharding_costs[input_node] = []
|
resharding_costs[input_node] = []
|
||||||
for strategy in input_node.strategies_vector:
|
for strategy in input_node.strategies_vector:
|
||||||
|
input_sharding_spec = strategy.output_sharding_spec
|
||||||
|
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
|
||||||
_, _, resharding_cost = self.shape_consistency_manager.shape_consistency(
|
_, _, resharding_cost = self.shape_consistency_manager.shape_consistency(
|
||||||
strategy.output_sharding_spec, input_spec)
|
input_sharding_spec, input_spec)
|
||||||
resharding_costs[input_node].append(resharding_cost)
|
resharding_costs[input_node].append(resharding_cost)
|
||||||
return resharding_costs
|
return resharding_costs
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||||
from typing import Dict, List
|
from typing import Dict, List, Union, Tuple
|
||||||
from torch.fx.node import Node
|
from torch.fx.node import Node
|
||||||
|
from .constants import *
|
||||||
|
|
||||||
__all__ = ['ShardingStrategy', 'StrategiesVector']
|
__all__ = ['ShardingStrategy', 'StrategiesVector']
|
||||||
|
|
||||||
|
@ -25,12 +26,15 @@ class ShardingStrategy:
|
||||||
'''
|
'''
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
output_sharding_spec: ShardingSpec
|
# TODO: output of fx node,such as torch.var_mean, could be a tuple, so we cannot simply suppose it is a tensor.
|
||||||
|
output_sharding_spec: Union[ShardingSpec, Tuple[ShardingSpec]]
|
||||||
compute_cost: float = 0.
|
compute_cost: float = 0.
|
||||||
communication_cost: float = 0.
|
communication_cost: float = 0.
|
||||||
memory_cost: float = 0.
|
memory_cost: float = 0.
|
||||||
resharding_costs: Dict[int, List[float]] = None
|
resharding_costs: Dict[Node, List[float]] = None
|
||||||
input_shardings: ShardingSpec = None
|
# sometimes the input node could be a tuple of nodes, but most of op won't accept tuple of node as input.
|
||||||
|
# Therefore, we could process them at the specific op(operator.getitem)
|
||||||
|
input_shardings: List[ShardingSpec] = None
|
||||||
|
|
||||||
|
|
||||||
class StrategiesVector(list):
|
class StrategiesVector(list):
|
||||||
|
@ -46,8 +50,23 @@ class StrategiesVector(list):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.node = node
|
self.node = node
|
||||||
# fetch its input and output nodes
|
# fetch its input and output nodes
|
||||||
|
# TODO: placeholder input nodes
|
||||||
self.predecessor_nodes = list(node._input_nodes.keys())
|
self.predecessor_nodes = list(node._input_nodes.keys())
|
||||||
self.successor_nodes = list(node.users.keys())
|
self.successor_nodes = list(node.users.keys())
|
||||||
|
|
||||||
def check_merge(self):
|
def check_merge(self):
|
||||||
pass
|
merge_label = False
|
||||||
|
if self.node.op == 'call_module':
|
||||||
|
target = self.node.target
|
||||||
|
root_module = self.node.graph.owning_module
|
||||||
|
submod = root_module.get_submodule(target)
|
||||||
|
submod_type = type(submod)
|
||||||
|
# merge elementwise module node into following nodes
|
||||||
|
if submod_type in ELEMENTWISE_MODULE_OP:
|
||||||
|
merge_label = True
|
||||||
|
|
||||||
|
if self.node.op == 'call_function':
|
||||||
|
if self.node.target in ELEMENTWISE_FUNC_OP:
|
||||||
|
merge_label = True
|
||||||
|
|
||||||
|
return merge_label
|
||||||
|
|
|
@ -0,0 +1,355 @@
|
||||||
|
from torch.fx import Graph, Node
|
||||||
|
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||||
|
from .sharding_strategy import ShardingStrategy, StrategiesVector
|
||||||
|
from .conv_handler import ConvHandler
|
||||||
|
from .constants import *
|
||||||
|
from copy import deepcopy
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import operator
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
|
||||||
|
class StrategiesConstructor:
|
||||||
|
|
||||||
|
def __init__(self, graph, device_mesh, shape_consistency_manager, solver_options):
|
||||||
|
self.graph = graph
|
||||||
|
self.root_module = self.graph.owning_module
|
||||||
|
self.nodes = list(graph.nodes)
|
||||||
|
self.device_mesh = device_mesh
|
||||||
|
self.leaf_strategies = []
|
||||||
|
self.strategy_map = {}
|
||||||
|
self.shape_consistency_manager = shape_consistency_manager
|
||||||
|
self.solver_options = solver_options
|
||||||
|
|
||||||
|
def _generate_sharding_spec(self, node: Node, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
|
||||||
|
"""
|
||||||
|
Generate the sharding spec of the tensor based on the given dim_partition_dict
|
||||||
|
where the key is the tensor dimension and the value is the mesh dimension for sharding.
|
||||||
|
"""
|
||||||
|
meta_tensor = node._meta_data
|
||||||
|
sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
|
||||||
|
entire_shape=meta_tensor.shape,
|
||||||
|
dim_partition_dict=dim_partition_dict)
|
||||||
|
return sharding_spec
|
||||||
|
|
||||||
|
def _generate_resharding_costs(self, input_nodes, target_sharding_specs):
|
||||||
|
'''
|
||||||
|
Compute the resharding costs with this specific strategy.
|
||||||
|
|
||||||
|
Argument:
|
||||||
|
sharding_spec_for_input(ShardingSpec): ShardingSpec of the input node.
|
||||||
|
'''
|
||||||
|
resharding_costs = {}
|
||||||
|
for input_node, target_sharding_spec in zip(input_nodes, target_sharding_specs):
|
||||||
|
resharding_costs[input_node] = []
|
||||||
|
for strategy in input_node.strategies_vector:
|
||||||
|
input_sharding_spec = strategy.output_sharding_spec
|
||||||
|
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
|
||||||
|
_, _, resharding_cost = self.shape_consistency_manager.shape_consistency(
|
||||||
|
input_sharding_spec, target_sharding_spec)
|
||||||
|
resharding_costs[input_node].append(resharding_cost)
|
||||||
|
return resharding_costs
|
||||||
|
|
||||||
|
def remove_duplicated_strategy(self, strategies_vector):
|
||||||
|
'''
|
||||||
|
In build_strategies_and_cost method, we may produce some duplicated strategies.
|
||||||
|
In this method, we will remove the duplicated strategies depending on the strategies name.
|
||||||
|
'''
|
||||||
|
name_checklist = []
|
||||||
|
remove_list = []
|
||||||
|
for strategy in strategies_vector:
|
||||||
|
if strategy.name not in name_checklist:
|
||||||
|
name_checklist.append(strategy.name)
|
||||||
|
else:
|
||||||
|
remove_list.append(strategy)
|
||||||
|
for strategy in remove_list:
|
||||||
|
strategies_vector.remove(strategy)
|
||||||
|
|
||||||
|
def build_strategies_and_cost(self):
|
||||||
|
for node in self.nodes:
|
||||||
|
strategies_vector = StrategiesVector(node)
|
||||||
|
# placeholder node
|
||||||
|
if node.op == 'placeholder':
|
||||||
|
# For placeholder nodes, if solver_options['fast_mode'] is True, we just let them in
|
||||||
|
# fully replicate status, then strategies of following node will be treated equally due
|
||||||
|
# to replicate status has no resharding cost to other status. At the same time, the searching
|
||||||
|
# space is smaller than enumerating all the possible sharding spec for the placeholder node.
|
||||||
|
# Otherwise, all the possible sharding spec for the placeholder node will be enumerated.
|
||||||
|
|
||||||
|
if self.solver_options['fast_mode']:
|
||||||
|
# create sharding strategy for placeholder
|
||||||
|
name = 'Replica Placeholder'
|
||||||
|
dim_partition_dict = {}
|
||||||
|
output_sharding_spec = self._generate_sharding_spec(node, dim_partition_dict)
|
||||||
|
# TODO: use meta_info_prop to profile memory cost
|
||||||
|
memory_cost = 0
|
||||||
|
sharding_strategy_placeholder = ShardingStrategy(name,
|
||||||
|
output_sharding_spec,
|
||||||
|
memory_cost=memory_cost)
|
||||||
|
strategies_vector.append(sharding_strategy_placeholder)
|
||||||
|
|
||||||
|
# get_attr node
|
||||||
|
if node.op == 'get_attr':
|
||||||
|
# Same as placeholder nodes, if solver_options['fast_mode'] is True, we just let them in
|
||||||
|
# fully replicate status, then strategies of following node will be treated equally due
|
||||||
|
# to replicate status has no resharding cost to other status. At the same time, the searching
|
||||||
|
# space is smaller than enumerating all the possible sharding spec for the get_attr node.
|
||||||
|
# Otherwise, all the possible sharding spec for the get_attr node will be enumerated.
|
||||||
|
if self.solver_options['fast_mode']:
|
||||||
|
# create sharding strategy for get_attr
|
||||||
|
name = 'Replica Attribute'
|
||||||
|
dim_partition_dict = {}
|
||||||
|
output_sharding_spec = self._generate_sharding_spec(node, dim_partition_dict)
|
||||||
|
# TODO: use meta_info_prop to profile memory cost
|
||||||
|
memory_cost = 0
|
||||||
|
sharding_strategy_attribute = ShardingStrategy(name, output_sharding_spec, memory_cost=memory_cost)
|
||||||
|
strategies_vector.append(sharding_strategy_attribute)
|
||||||
|
|
||||||
|
# call_module node
|
||||||
|
if node.op == 'call_module':
|
||||||
|
|
||||||
|
target = node.target
|
||||||
|
submod = self.root_module.get_submodule(target)
|
||||||
|
submod_type = type(submod)
|
||||||
|
|
||||||
|
# conv module
|
||||||
|
if submod_type in CONV_MODULE_OP:
|
||||||
|
# use ConvHandler to create sharding strategies for conv module node
|
||||||
|
conv_handler = ConvHandler(node, self.device_mesh, strategies_vector,
|
||||||
|
self.shape_consistency_manager)
|
||||||
|
conv_handler.register_strategy()
|
||||||
|
|
||||||
|
# linear module
|
||||||
|
elif submod_type in LINEAR_MODULE_OP:
|
||||||
|
# use DotHandler to create sharding strategies for linear module node
|
||||||
|
dot_handler = DotHandler(node, self.device_mesh, strategies_vector, self.shape_consistency_manager)
|
||||||
|
dot_handler.register_strategy()
|
||||||
|
|
||||||
|
# element-wise module
|
||||||
|
elif submod_type in ELEMENTWISE_MODULE_OP:
|
||||||
|
# create sharding strategy for element-wise module
|
||||||
|
assert len(strategies_vector.predecessor_nodes
|
||||||
|
) == 1, f'Temporally, we just support single input element-wise op.'
|
||||||
|
input_node = strategies_vector.predecessor_nodes[0]
|
||||||
|
# For element-wise module, we keep the sharding spec of output node same as
|
||||||
|
# the input. Therefore, the different strategies of input node with same
|
||||||
|
# output sharding spec will generate same strategy for element-wise module.
|
||||||
|
sharding_spec_checklist = []
|
||||||
|
for strategy in input_node.strategies_vector:
|
||||||
|
# It looks a little bit confusing, the input of the processing node
|
||||||
|
# is the output of the input_node.
|
||||||
|
input_sharding_spec = strategy.output_sharding_spec
|
||||||
|
assert isinstance(input_sharding_spec,
|
||||||
|
ShardingSpec), f'The input node should NOT be a tuple of tensor.'
|
||||||
|
if input_sharding_spec in sharding_spec_checklist:
|
||||||
|
continue
|
||||||
|
|
||||||
|
sharding_spec_checklist.append(input_sharding_spec)
|
||||||
|
dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
|
||||||
|
output_sharding_spec = self._generate_sharding_spec(node, dim_partition_dict)
|
||||||
|
|
||||||
|
name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}'
|
||||||
|
|
||||||
|
# TODO: use meta_info_prop to profile memory cost and compute cost
|
||||||
|
compute_cost = node._meta_data.numel()
|
||||||
|
memory_cost = 0
|
||||||
|
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||||
|
[input_sharding_spec])
|
||||||
|
|
||||||
|
# to prevent the resharding happening, set their resharding cost to inf.
|
||||||
|
resharding_costs[input_node] = [
|
||||||
|
cost if cost == 0 else math.inf for cost in resharding_costs[input_node]
|
||||||
|
]
|
||||||
|
sharding_strategy = ShardingStrategy(name,
|
||||||
|
output_sharding_spec,
|
||||||
|
compute_cost=compute_cost,
|
||||||
|
memory_cost=memory_cost,
|
||||||
|
resharding_costs=resharding_costs,
|
||||||
|
input_shardings=[input_sharding_spec])
|
||||||
|
strategies_vector.append(sharding_strategy)
|
||||||
|
|
||||||
|
# other module
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f'{submod_type} module is NOT supported now.')
|
||||||
|
|
||||||
|
# call_function node
|
||||||
|
if node.op == 'call_function':
|
||||||
|
target = node.target
|
||||||
|
# conv function
|
||||||
|
if target in CONV_FUNC_OP:
|
||||||
|
# use ConvHandler to create sharding strategies for conv node
|
||||||
|
# TODO: the operator_handler does NOT support function node processing now.
|
||||||
|
conv_handler = ConvHandler(node, self.device_mesh, strategies_vector,
|
||||||
|
self.shape_consistency_manager)
|
||||||
|
conv_handler.register_strategy()
|
||||||
|
|
||||||
|
# linear function
|
||||||
|
elif target in LINEAR_FUNC_OP:
|
||||||
|
# use DotHandler to create sharding strategies for linear node
|
||||||
|
# TODO: the operator_handler does NOT support function node processing now.
|
||||||
|
linear_handler = DotHandler(node, self.device_mesh, strategies_vector,
|
||||||
|
self.shape_consistency_manager)
|
||||||
|
linear_handler.register_strategy()
|
||||||
|
|
||||||
|
# element-wise function
|
||||||
|
elif target in ELEMENTWISE_FUNC_OP:
|
||||||
|
# TODO: integrate element-wise func and module together
|
||||||
|
# create sharding strategy for element-wise function
|
||||||
|
assert len(strategies_vector.predecessor_nodes
|
||||||
|
) == 1, f'Temporally, we just support single input element-wise op.'
|
||||||
|
input_node = strategies_vector.predecessor_nodes[0]
|
||||||
|
# For element-wise function, we keep the sharding spec of output node same as
|
||||||
|
# the input. Therefore, the different strategies of input node with same
|
||||||
|
# output sharding spec will generate same strategy for element-wise function.
|
||||||
|
sharding_spec_checklist = []
|
||||||
|
for strategy in input_node.strategies_vector:
|
||||||
|
# It looks a little bit confusing, the input of the processing node
|
||||||
|
# is the output of the input_node.
|
||||||
|
input_sharding_spec = strategy.output_sharding_spec
|
||||||
|
assert isinstance(input_sharding_spec,
|
||||||
|
ShardingSpec), f'The input node should NOT be a tuple of tensor.'
|
||||||
|
if input_sharding_spec in sharding_spec_checklist:
|
||||||
|
continue
|
||||||
|
sharding_spec_checklist.append(input_sharding_spec)
|
||||||
|
dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
|
||||||
|
output_sharding_spec = self._generate_sharding_spec(node, dim_partition_dict)
|
||||||
|
name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}'
|
||||||
|
# TODO: use meta_info_prop to profile memory cost and compute cost
|
||||||
|
compute_cost = node._meta_data.numel()
|
||||||
|
memory_cost = 0
|
||||||
|
|
||||||
|
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||||
|
[input_sharding_spec])
|
||||||
|
|
||||||
|
# to prevent the resharding happening, set their resharding cost to inf.
|
||||||
|
resharding_costs[input_node] = [
|
||||||
|
0 if cost == 0 else math.inf for cost in resharding_costs[input_node]
|
||||||
|
]
|
||||||
|
sharding_strategy = ShardingStrategy(name,
|
||||||
|
output_sharding_spec,
|
||||||
|
compute_cost=compute_cost,
|
||||||
|
memory_cost=memory_cost,
|
||||||
|
resharding_costs=resharding_costs,
|
||||||
|
input_shardings=[input_sharding_spec])
|
||||||
|
strategies_vector.append(sharding_strategy)
|
||||||
|
|
||||||
|
# torch.var_mean
|
||||||
|
elif target == torch.var_mean:
|
||||||
|
dim = node.kwargs['dim']
|
||||||
|
input_tensor_node = strategies_vector.predecessor_nodes[0]
|
||||||
|
for strategy in input_tensor_node.strategies_vector:
|
||||||
|
input_sharding_spec = strategy.output_sharding_spec
|
||||||
|
assert isinstance(input_sharding_spec,
|
||||||
|
ShardingSpec), f'The input node should NOT be a tuple of tensor.'
|
||||||
|
entire_shape_input = input_sharding_spec.entire_shape
|
||||||
|
dim_partition_dict_input = input_sharding_spec.dim_partition_dict
|
||||||
|
name = f'{new_input_sharding_spec.sharding_sequence} -> ({output_sharding_spec.sharding_sequence}, {output_sharding_spec.sharding_sequence})'
|
||||||
|
if dim in dim_partition_dict_input:
|
||||||
|
# We need to make the action dimension in replicate status
|
||||||
|
dim_partition_dict_for_input = deepcopy(dim_partition_dict_input)
|
||||||
|
dim_partition_dict_for_input.pop(dim)
|
||||||
|
new_input_sharding_spec = ShardingSpec(self.device_mesh,
|
||||||
|
entire_shape_input,
|
||||||
|
dim_partition_dict=dim_partition_dict_for_input)
|
||||||
|
entire_shape_output = deepcopy(entire_shape_input)
|
||||||
|
entire_shape_output.pop(dim)
|
||||||
|
dim_partition_dict_for_output = deepcopy(dim_partition_dict_for_input)
|
||||||
|
output_sharding_spec = ShardingSpec(self.device_mesh,
|
||||||
|
entire_shape_output,
|
||||||
|
dim_partition_dict=dim_partition_dict_for_input)
|
||||||
|
# TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
|
||||||
|
compute_cost = 0
|
||||||
|
memory_cost = 0
|
||||||
|
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||||
|
[new_input_sharding_spec])
|
||||||
|
sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec),
|
||||||
|
compute_cost=compute_cost,
|
||||||
|
memory_cost=memory_cost,
|
||||||
|
resharding_costs=resharding_costs,
|
||||||
|
input_shardings=[new_input_sharding_spec])
|
||||||
|
|
||||||
|
else:
|
||||||
|
entire_shape_output = deepcopy(entire_shape_input)
|
||||||
|
entire_shape_output.pop(dim)
|
||||||
|
dim_partition_dict_for_output = deepcopy(dim_partition_dict_input)
|
||||||
|
output_sharding_spec = ShardingSpec(self.device_mesh,
|
||||||
|
entire_shape_output,
|
||||||
|
dim_partion_dict=dim_partition_dict_input)
|
||||||
|
# TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
|
||||||
|
compute_cost = 0
|
||||||
|
memory_cost = 0
|
||||||
|
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||||
|
[input_sharding_spec])
|
||||||
|
sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec),
|
||||||
|
compute_cost=compute_cost,
|
||||||
|
memory_cost=memory_cost,
|
||||||
|
resharding_costs=resharding_costs,
|
||||||
|
input_shardings=[input_sharding_spec])
|
||||||
|
|
||||||
|
strategies_vector.append(sharding_strategy)
|
||||||
|
|
||||||
|
# operator.getitem
|
||||||
|
elif target == operator.getitem:
|
||||||
|
index = node.args[1]
|
||||||
|
input_tensor_node = strategies_vector.predecessor_nodes[0]
|
||||||
|
for strategy in input_tensor_node.strategies_vector:
|
||||||
|
input_sharding_spec = input_tensor_node.output_sharding_spec[index]
|
||||||
|
assert isinstance(input_sharding_spec, ShardingSpec), f'This assertion is used to debug.'
|
||||||
|
dim_partition_dict_for_output = deepcopy(input_sharding_spec.dim_partition_dict)
|
||||||
|
entire_shape_output = deepcopy(input_sharding_spec.entire_shape)
|
||||||
|
output_sharding_spec = ShardingSpec(self.device_mesh,
|
||||||
|
entire_shape_output,
|
||||||
|
dim_partition_dict=dim_partition_dict_for_output)
|
||||||
|
# TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
|
||||||
|
compute_cost = 0
|
||||||
|
memory_cost = 0
|
||||||
|
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||||
|
[input_sharding_spec])
|
||||||
|
# to prevent the resharding happening, set their resharding cost to inf.
|
||||||
|
resharding_costs[input_tensor_node] = [
|
||||||
|
cost if cost == 0 else math.inf for cost in resharding_costs[input_tensor_node]
|
||||||
|
]
|
||||||
|
sharding_strategy = ShardingStrategy(name,
|
||||||
|
output_sharding_spec,
|
||||||
|
compute_cost=compute_cost,
|
||||||
|
memory_cost=memory_cost,
|
||||||
|
resharding_costs=resharding_costs,
|
||||||
|
input_shardings=[input_tensor_node.output_sharding_spec])
|
||||||
|
strategies_vector.append(sharding_strategy)
|
||||||
|
|
||||||
|
# other function
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f'{target} function is NOT supported now.')
|
||||||
|
|
||||||
|
# output node
|
||||||
|
if node.op == 'output':
|
||||||
|
if self.solver_options['fast_mode']:
|
||||||
|
# create sharding strategy for output
|
||||||
|
name = 'Replica Output'
|
||||||
|
input_nodes = strategies_vector.predecessor_nodes
|
||||||
|
input_sharding_specs = []
|
||||||
|
for input_node in input_nodes:
|
||||||
|
dim_partition_dict_for_input = {}
|
||||||
|
entire_shape = input_node._meta_data.shape
|
||||||
|
sharding_spec = ShardingSpec(self.device_mesh,
|
||||||
|
entire_shape,
|
||||||
|
dim_partition_dict=dim_partition_dict_for_input)
|
||||||
|
input_sharding_specs.append(sharding_spec)
|
||||||
|
|
||||||
|
dim_partition_dict = {}
|
||||||
|
output_sharding_spec = input_sharding_specs
|
||||||
|
# TODO: use meta_info_prop to profile memory cost
|
||||||
|
memory_cost = 0
|
||||||
|
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||||
|
input_sharding_specs)
|
||||||
|
sharding_strategy_attribute = ShardingStrategy(name,
|
||||||
|
output_sharding_spec,
|
||||||
|
memory_cost=memory_cost,
|
||||||
|
resharding_costs=resharding_costs)
|
||||||
|
strategies_vector.append(sharding_strategy_attribute)
|
||||||
|
|
||||||
|
self.remove_duplicated_strategy(strategies_vector)
|
||||||
|
setattr(node, 'strategies_vector', strategies_vector)
|
||||||
|
self.leaf_strategies.append(strategies_vector)
|
||||||
|
self.strategy_map[node] = strategies_vector
|
|
@ -0,0 +1,97 @@
|
||||||
|
import torch
|
||||||
|
from torch.fx import GraphModule
|
||||||
|
import torch.nn as nn
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from colossalai.fx.proxy import ColoProxy
|
||||||
|
from colossalai.fx.tracer.tracer import ColoTracer
|
||||||
|
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
|
||||||
|
from colossalai.auto_parallel.solver.conv_handler import ConvHandler, CONV_STRATEGIES_LIST
|
||||||
|
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||||
|
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||||
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
|
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
|
||||||
|
from colossalai.auto_parallel.solver.cost_graph import CostGraph
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
|
|
||||||
|
class ConvModel(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, c_in, c_out):
|
||||||
|
super().__init__()
|
||||||
|
self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=3)
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = x * 2
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = x / 2
|
||||||
|
x = self.relu(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def test_cost_graph():
|
||||||
|
physical_mesh_id = torch.arange(0, 4)
|
||||||
|
mesh_shape = (2, 2)
|
||||||
|
# [[0, 1]
|
||||||
|
# [2, 3]]
|
||||||
|
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||||
|
entire_shape = torch.Size((4, 16, 64, 64))
|
||||||
|
shape_consistency_manager = ShapeConsistencyManager()
|
||||||
|
|
||||||
|
tracer = ColoTracer()
|
||||||
|
model = ConvModel(16, 32)
|
||||||
|
input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
|
||||||
|
|
||||||
|
# graph():
|
||||||
|
# %x : torch.Tensor [#users=1] = placeholder[target=x]
|
||||||
|
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
|
||||||
|
# %conv1 : [#users=1] = call_module[target=conv1](args = (%mul,), kwargs = {})
|
||||||
|
# %truediv : [#users=1] = call_function[target=operator.truediv](args = (%conv1, 2), kwargs = {})
|
||||||
|
# %relu : [#users=1] = call_module[target=relu](args = (%truediv,), kwargs = {})
|
||||||
|
# return relu
|
||||||
|
graph = tracer.trace(root=model, meta_args=input_sample)
|
||||||
|
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||||
|
gm.recompile()
|
||||||
|
|
||||||
|
solver_options = {'fast_mode': True}
|
||||||
|
strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
|
||||||
|
strategies_constructor.build_strategies_and_cost()
|
||||||
|
|
||||||
|
# (x, mul): {(0, 0): 0}
|
||||||
|
# (mul, conv1): {(0, 0): 0, (0, 1): 0, (0, 2): 0, (0, 3): 0, (0, 4): 0, (0, 5): 0, (0, 6): 0, (0, 7): 0, (0, 8): 0, (0, 9): 0, (0, 10): 0, (0, 11): 0, (0, 12): 0, (0, 13): 0, (0, 14): 0}
|
||||||
|
# (conv1, truediv): {(0, 0): 0, (1, 0): inf, (2, 0): 0, (3, 0): inf, (4, 0): 0, (5, 0): inf, (6, 0): inf, (7, 0): 0, (8, 0): inf, (9, 0): 0, (10, 0): 0, (11, 0): 0, (12, 0): 0, (13, 0): inf, (14, 0): inf, (0, 1): inf, (1, 1): 0, (2, 1): inf, (3, 1): 0, (4, 1): inf, (5, 1): 0, (6, 1): 0, (7, 1): inf, (8, 1): 0, (9, 1): inf, (10, 1): 0, (11, 1): 0, (12, 1): 0, (13, 1): inf, (14, 1): inf, (0, 2): inf, (1, 2): inf, (2, 2): 0, (3, 2): inf, (4, 2): inf, (5, 2): inf, (6, 2): inf, (7, 2): inf, (8, 2): inf, (9, 2): inf, (10, 2): 0, (11, 2): 0, (12, 2): 0, (13, 2): inf, (14, 2): inf, (0, 3): inf, (1, 3): inf, (2, 3): inf, (3, 3): 0, (4, 3): inf, (5, 3): inf, (6, 3): inf, (7, 3): inf, (8, 3): inf, (9, 3): inf, (10, 3): 0, (11, 3): 0, (12, 3): 0, (13, 3): inf, (14, 3): inf, (0, 4): inf, (1, 4): inf, (2, 4): inf, (3, 4): inf, (4, 4): inf, (5, 4): inf, (6, 4): 0, (7, 4): inf, (8, 4): 0, (9, 4): inf, (10, 4): 0, (11, 4): 0, (12, 4): 0, (13, 4): inf, (14, 4): inf, (0, 5): inf, (1, 5): inf, (2, 5): inf, (3, 5): inf, (4, 5): inf, (5, 5): inf, (6, 5): inf, (7, 5): 0, (8, 5): inf, (9, 5): 0, (10, 5): 0, (11, 5): 0, (12, 5): 0, (13, 5): inf, (14, 5): inf, (0, 6): inf, (1, 6): inf, (2, 6): inf, (3, 6): inf, (4, 6): inf, (5, 6): inf, (6, 6): inf, (7, 6): inf, (8, 6): inf, (9, 6): inf, (10, 6): 0, (11, 6): 0, (12, 6): 0, (13, 6): inf, (14, 6): inf, (0, 7): inf, (1, 7): inf, (2, 7): 0, (3, 7): inf, (4, 7): inf, (5, 7): inf, (6, 7): inf, (7, 7): inf, (8, 7): inf, (9, 7): inf, (10, 7): 0, (11, 7): 0, (12, 7): 0, (13, 7): 0, (14, 7): inf, (0, 8): inf, (1, 8): inf, (2, 8): inf, (3, 8): inf, (4, 8): inf, (5, 8): inf, (6, 8): 0, (7, 8): inf, (8, 8): 0, (9, 8): inf, (10, 8): 0, (11, 8): 0, (12, 8): 0, (13, 8): inf, (14, 8): 0}
|
||||||
|
# (truediv, relu): {(0, 0): 0, (1, 0): inf, (2, 0): 0, (3, 0): inf, (4, 0): inf, (5, 0): 0, (6, 0): 0, (7, 0): inf, (8, 0): inf, (0, 1): inf, (1, 1): 0, (2, 1): inf, (3, 1): 0, (4, 1): 0, (5, 1): inf, (6, 1): 0, (7, 1): inf, (8, 1): inf, (0, 2): inf, (1, 2): inf, (2, 2): 0, (3, 2): inf, (4, 2): inf, (5, 2): inf, (6, 2): 0, (7, 2): inf, (8, 2): inf, (0, 3): inf, (1, 3): inf, (2, 3): inf, (3, 3): 0, (4, 3): inf, (5, 3): inf, (6, 3): 0, (7, 3): inf, (8, 3): inf, (0, 4): inf, (1, 4): inf, (2, 4): inf, (3, 4): inf, (4, 4): 0, (5, 4): inf, (6, 4): 0, (7, 4): inf, (8, 4): inf, (0, 5): inf, (1, 5): inf, (2, 5): inf, (3, 5): inf, (4, 5): inf, (5, 5): 0, (6, 5): 0, (7, 5): inf, (8, 5): inf, (0, 6): inf, (1, 6): inf, (2, 6): inf, (3, 6): inf, (4, 6): inf, (5, 6): inf, (6, 6): 0, (7, 6): inf, (8, 6): inf, (0, 7): inf, (1, 7): inf, (2, 7): 0, (3, 7): inf, (4, 7): inf, (5, 7): inf, (6, 7): 0, (7, 7): 0, (8, 7): inf, (0, 8): inf, (1, 8): inf, (2, 8): inf, (3, 8): inf, (4, 8): 0, (5, 8): inf, (6, 8): 0, (7, 8): inf, (8, 8): 0}
|
||||||
|
# (relu, output): {(0, 0): 246019.30000000002, (1, 0): 246019.30000000002, (2, 0): 123009.1, (3, 0): 123009.1, (4, 0): 123009.1, (5, 0): 123009.1, (6, 0): 0, (7, 0): 246019.30000000002, (8, 0): 246019.30000000002}
|
||||||
|
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
|
||||||
|
|
||||||
|
# construct all node pairs
|
||||||
|
all_node_pairs = []
|
||||||
|
|
||||||
|
for node in graph.nodes:
|
||||||
|
if node.op == 'output':
|
||||||
|
continue
|
||||||
|
all_node_pairs.append((node, node.next))
|
||||||
|
|
||||||
|
for node_pair in all_node_pairs:
|
||||||
|
assert node_pair in cost_graph.edge_costs
|
||||||
|
|
||||||
|
# construct merged node pairs
|
||||||
|
merged_node_pairs = []
|
||||||
|
node_list = list(graph.nodes)
|
||||||
|
|
||||||
|
# add (x, conv) and (conv, output) into check node pairs
|
||||||
|
merged_node_pairs.append((node_list[0], node_list[2]))
|
||||||
|
merged_node_pairs.append((node_list[2], node_list[-1]))
|
||||||
|
# (x, conv1): {(0, 0): 0, (0, 1): 0, (0, 2): 0, (0, 3): 0, (0, 4): 0, (0, 5): 0, (0, 6): 0, (0, 7): 0, (0, 8): 0, (0, 9): 0, (0, 10): 0, (0, 11): 0, (0, 12): 0, (0, 13): 0, (0, 14): 0}
|
||||||
|
# (conv1, output): {(0, 0): inf, (1, 0): inf, (2, 0): inf, (3, 0): inf, (4, 0): inf, (5, 0): inf, (6, 0): inf, (7, 0): inf, (8, 0): inf, (9, 0): inf, (10, 0): 0, (11, 0): 0, (12, 0): 0, (13, 0): inf, (14, 0): inf}
|
||||||
|
cost_graph.simplify_graph()
|
||||||
|
for node_pair in all_node_pairs:
|
||||||
|
if node_pair in merged_node_pairs:
|
||||||
|
assert node_pair in cost_graph.edge_costs
|
||||||
|
else:
|
||||||
|
assert node_pair not in cost_graph.edge_costs
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_cost_graph()
|
|
@ -0,0 +1,98 @@
|
||||||
|
import torch
|
||||||
|
from torch.fx import GraphModule
|
||||||
|
import torch.nn as nn
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from colossalai.fx.proxy import ColoProxy
|
||||||
|
from colossalai.fx.tracer.tracer import ColoTracer
|
||||||
|
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
|
||||||
|
from colossalai.auto_parallel.solver.conv_handler import ConvHandler, CONV_STRATEGIES_LIST
|
||||||
|
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||||
|
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||||
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
|
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
|
|
||||||
|
class ConvModel(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, c_in, c_out):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = nn.Conv2d(c_in, c_out, kernel_size=3)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = x * 2
|
||||||
|
x = self.conv(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def test_strategies_constructor():
|
||||||
|
physical_mesh_id = torch.arange(0, 4)
|
||||||
|
mesh_shape = (2, 2)
|
||||||
|
# [[0, 1]
|
||||||
|
# [2, 3]]
|
||||||
|
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||||
|
entire_shape = torch.Size((4, 16, 64, 64))
|
||||||
|
shape_consistency_manager = ShapeConsistencyManager()
|
||||||
|
|
||||||
|
tracer = ColoTracer()
|
||||||
|
model = ConvModel(16, 32)
|
||||||
|
input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
|
||||||
|
# graph():
|
||||||
|
# %x : torch.Tensor [#users=1] = placeholder[target=x]
|
||||||
|
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
|
||||||
|
# %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {})
|
||||||
|
# return conv
|
||||||
|
graph = tracer.trace(root=model, meta_args=input_sample)
|
||||||
|
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||||
|
gm.recompile()
|
||||||
|
|
||||||
|
solver_options = {'fast_mode': True}
|
||||||
|
strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
|
||||||
|
|
||||||
|
assert strategies_constructor.leaf_strategies == []
|
||||||
|
assert strategies_constructor.strategy_map == {}
|
||||||
|
strategies_constructor.build_strategies_and_cost()
|
||||||
|
|
||||||
|
# check leaf_strategies
|
||||||
|
|
||||||
|
# In fast mode, placeholder node only has replica strategy.
|
||||||
|
assert strategies_constructor.leaf_strategies[0][0].name == 'Replica Placeholder'
|
||||||
|
|
||||||
|
# Second node is mul which is a element-wise node, therefore the output sharding spec is same as input sharding spec.
|
||||||
|
assert strategies_constructor.leaf_strategies[1][0].name == '[R, R, R, R] -> [R, R, R, R]'
|
||||||
|
|
||||||
|
# Third node is conv.
|
||||||
|
conv_check_list = deepcopy(CONV_STRATEGIES_LIST)
|
||||||
|
for strategy in strategies_constructor.leaf_strategies[2]:
|
||||||
|
conv_check_list.remove(strategy.name)
|
||||||
|
assert len(conv_check_list) == 0
|
||||||
|
|
||||||
|
# In fast mode, output node only has replica strategy.
|
||||||
|
assert strategies_constructor.leaf_strategies[3][0].name == 'Replica Output'
|
||||||
|
|
||||||
|
# check strategy_map
|
||||||
|
|
||||||
|
nodes = [node for node in graph.nodes]
|
||||||
|
# In fast mode, placeholder node only has replica strategy.
|
||||||
|
x = nodes[0]
|
||||||
|
assert strategies_constructor.strategy_map[x][0].name == 'Replica Placeholder'
|
||||||
|
|
||||||
|
# Second node is mul which is a element-wise node, therefore the output sharding spec is same as input sharding spec.
|
||||||
|
mul = nodes[1]
|
||||||
|
assert strategies_constructor.strategy_map[mul][0].name == '[R, R, R, R] -> [R, R, R, R]'
|
||||||
|
|
||||||
|
# Third node is conv.
|
||||||
|
conv = nodes[2]
|
||||||
|
conv_check_list = deepcopy(CONV_STRATEGIES_LIST)
|
||||||
|
for strategy in strategies_constructor.strategy_map[conv]:
|
||||||
|
conv_check_list.remove(strategy.name)
|
||||||
|
assert len(conv_check_list) == 0
|
||||||
|
|
||||||
|
# In fast mode, output node only has replica strategy.
|
||||||
|
output = nodes[3]
|
||||||
|
assert strategies_constructor.strategy_map[output][0].name == 'Replica Output'
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_strategies_constructor()
|
Loading…
Reference in New Issue