ColossalAI/colossalai/auto_parallel/solver/strategies_constructor.py

397 lines
24 KiB
Python

from torch.fx import Graph, Node
from colossalai.auto_parallel.solver.op_handler.bcast_op_handler import BcastOpHandler
from colossalai.auto_parallel.solver.op_handler.layer_norm_handler import LayerNormHandler
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from .options import SolverOptions
from . import ShardingStrategy, StrategiesVector
from .op_handler import *
from .constants import *
from copy import deepcopy
import math
import torch
import operator
from typing import Dict, List
from ._utils import generate_sharding_spec, generate_resharding_costs
import builtins
class StrategiesConstructor:
"""
StrategiesConstructor is used to construct the parallelization plan for the model execution.
Args:
graph (Graph): a Graph object used for analysis and strategy generation.
device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster.
solver_options (SolverOptions): a SolverOptions object which specifies the preferences for plan searching.
"""
def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: SolverOptions):
self.graph = graph
assert graph.owning_module is not None, 'The given graph is not associated with a owning_module'
self.root_module = self.graph.owning_module
self.nodes = list(graph.nodes)
self.device_mesh = device_mesh
self.leaf_strategies = []
self.strategy_map = {}
self.solver_options = solver_options
def remove_duplicated_strategy(self, strategies_vector):
'''
In build_strategies_and_cost method, we may produce some duplicated strategies.
In this method, we will remove the duplicated strategies depending on the strategies name.
'''
name_checklist = []
remove_list = []
for strategy in strategies_vector:
if strategy.name not in name_checklist:
name_checklist.append(strategy.name)
else:
remove_list.append(strategy)
for strategy in remove_list:
strategies_vector.remove(strategy)
def _is_bcast_matmul(self, node):
is_bcast_matmul = False
if node.target is torch.matmul and len(node.args) == 2:
lhs_data = node.args[0]._meta_data
rhs_data = node.args[1]._meta_data
if lhs_data.dim() >= 3 and rhs_data.dim() >= 3:
is_bcast_matmul = True
return is_bcast_matmul
def build_strategies_and_cost(self):
for node in self.nodes:
strategies_vector = StrategiesVector(node)
input_nodes_len = 0
for check_node in strategies_vector.predecessor_nodes:
if isinstance(check_node._meta_data, torch.Tensor):
input_nodes_len += 1
# input_nodes_len = len(strategies_vector.predecessor_nodes)
# placeholder node
if node.op == 'placeholder':
# For placeholder nodes, if solver_options.fast is True, we just let them in
# fully replicate status, then strategies of following node will be treated equally due
# to replicate status has no resharding cost to other status. At the same time, the searching
# space is smaller than enumerating all the possible sharding spec for the placeholder node.
# Otherwise, all the possible sharding spec for the placeholder node will be enumerated.
if self.solver_options.fast:
# create sharding strategy for placeholder
name = 'Replica Placeholder'
dim_partition_dict = {}
output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)
# TODO: use meta_info_prop to profile memory cost
memory_cost = 0
sharding_strategy_placeholder = ShardingStrategy(name,
output_sharding_spec,
memory_cost=memory_cost)
strategies_vector.append(sharding_strategy_placeholder)
# get_attr node
if node.op == 'get_attr':
# Same as placeholder nodes, if solver_options.fast is True, we just let them in
# fully replicate status, then strategies of following node will be treated equally due
# to replicate status has no resharding cost to other status. At the same time, the searching
# space is smaller than enumerating all the possible sharding spec for the get_attr node.
# Otherwise, all the possible sharding spec for the get_attr node will be enumerated.
if self.solver_options.fast:
# create sharding strategy for get_attr
name = 'Replica Attribute'
dim_partition_dict = {}
output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)
# TODO: use meta_info_prop to profile memory cost
memory_cost = 0
sharding_strategy_attribute = ShardingStrategy(name, output_sharding_spec, memory_cost=memory_cost)
strategies_vector.append(sharding_strategy_attribute)
# call_module node
if node.op == 'call_module':
target = node.target
submod = self.root_module.get_submodule(target)
submod_type = type(submod)
# conv module
if submod_type in CONV_MODULE_OP:
# use ConvHandler to create sharding strategies for conv module node
conv_handler = ConvHandler(node, self.device_mesh, strategies_vector)
conv_handler.register_strategy()
# linear module
elif submod_type in LINEAR_MODULE_OP:
# use DotHandler to create sharding strategies for linear module node
dot_handler = DotHandler(node, self.device_mesh, strategies_vector)
dot_handler.register_strategy()
# element-wise module
elif submod_type in ELEMENTWISE_MODULE_OP:
unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
unary_elementwise_handler.register_strategy()
# BatchNormNd module
elif submod_type in BATCHNORM_MODULE_OP:
# create sharding strategy for element-wise module
norm_handler = BatchNormHandler(node, self.device_mesh, strategies_vector)
norm_handler.register_strategy()
# for strategy in norm_handler.strategies_vector:
# print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}')
# assert False
# MaxPool module
elif submod_type in POOL_MODULE_OP:
# TODO: add sharding constraints on image dimension
# e.g.: for a 2D pooling input NCHW, we should promise no sharding happens on H and W dimension
# create sharding strategy for element-wise module
assert input_nodes_len == 1, f'Temporally, we just support single input element-wise op.'
input_node = strategies_vector.predecessor_nodes[0]
# For element-wise module, we keep the sharding spec of output node same as
# the input. Therefore, the different strategies of input node with same
# output sharding spec will generate same strategy for element-wise module.
sharding_spec_checklist = []
for strategy in input_node.strategies_vector:
# It looks a little bit confusing, the input of the processing node
# is the output of the input_node.
input_sharding_spec = strategy.output_sharding_spec
assert isinstance(input_sharding_spec,
ShardingSpec), f'The input node should NOT be a tuple of tensor.'
if input_sharding_spec in sharding_spec_checklist:
continue
sharding_spec_checklist.append(input_sharding_spec)
dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)
name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}'
# TODO: use meta_info_prop to profile memory cost and compute cost
compute_cost = node._meta_data.numel()
memory_cost = 0
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
[input_sharding_spec])
sharding_strategy = ShardingStrategy(name,
output_sharding_spec,
compute_cost=compute_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=[input_sharding_spec])
strategies_vector.append(sharding_strategy)
# embedding module
elif submod_type in EMBEDDING_MODULE_OP:
embedding_handler = EmbeddingHandler(node, self.device_mesh, strategies_vector)
embedding_handler.register_strategy()
# layernorm module
elif submod_type in LAYERNORM_MODULE_OP:
layernorm_handler = LayerNormHandler(node, self.device_mesh, strategies_vector)
layernorm_handler.register_strategy()
# other module
else:
raise RuntimeError(f'{submod_type} module is NOT supported now.')
# call_function node
if node.op == 'call_function':
target = node.target
# conv function
if target in CONV_FUNC_OP:
# use ConvHandler to create sharding strategies for conv node
# TODO: the operator_handler does NOT support function node processing now.
conv_handler = ConvHandler(node, self.device_mesh, strategies_vector)
conv_handler.register_strategy()
# linear function
elif target in LINEAR_FUNC_OP and not self._is_bcast_matmul(node):
# use DotHandler to create sharding strategies for linear node
# TODO: the operator_handler does NOT support function node processing now.
linear_handler = DotHandler(node, self.device_mesh, strategies_vector)
linear_handler.register_strategy()
# reshape function
elif target in RESHAPE_FUNC_OP:
# use ReshapeHandler to create sharding strategies for rehsape node
reshape_handler = ReshapeHandler(node, self.device_mesh, strategies_vector)
reshape_handler.register_strategy()
# element-wise function
elif target in ELEMENTWISE_FUNC_OP or (target in BCAST_FUNC_OP and input_nodes_len == 1):
if isinstance(node._meta_data, torch.Tensor):
unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
unary_elementwise_handler.register_strategy()
# bcast op
elif target in BCAST_FUNC_OP:
if isinstance(node._meta_data, torch.Tensor):
bcast_op_handler = BcastOpHandler(node, self.device_mesh, strategies_vector)
bcast_op_handler.register_strategy()
# torch.var_mean
elif target == torch.var_mean:
dim = node.kwargs['dim']
input_tensor_node = strategies_vector.predecessor_nodes[0]
for strategy in input_tensor_node.strategies_vector:
input_sharding_spec = strategy.output_sharding_spec
assert isinstance(input_sharding_spec,
ShardingSpec), f'The input node should NOT be a tuple of tensor.'
entire_shape_input = input_sharding_spec.entire_shape
dim_partition_dict_input = input_sharding_spec.dim_partition_dict
name = f'{new_input_sharding_spec.sharding_sequence} -> ({output_sharding_spec.sharding_sequence}, {output_sharding_spec.sharding_sequence})'
if dim in dim_partition_dict_input:
# We need to make the action dimension in replicate status
dim_partition_dict_for_input = deepcopy(dim_partition_dict_input)
dim_partition_dict_for_input.pop(dim)
new_input_sharding_spec = ShardingSpec(self.device_mesh,
entire_shape_input,
dim_partition_dict=dim_partition_dict_for_input)
entire_shape_output = deepcopy(entire_shape_input)
entire_shape_output.pop(dim)
dim_partition_dict_for_output = deepcopy(dim_partition_dict_for_input)
output_sharding_spec = ShardingSpec(self.device_mesh,
entire_shape_output,
dim_partition_dict=dim_partition_dict_for_input)
# TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
compute_cost = 0
memory_cost = 0
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
[new_input_sharding_spec])
sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec),
compute_cost=compute_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=[new_input_sharding_spec])
else:
entire_shape_output = deepcopy(entire_shape_input)
entire_shape_output.pop(dim)
dim_partition_dict_for_output = deepcopy(dim_partition_dict_input)
output_sharding_spec = ShardingSpec(self.device_mesh,
entire_shape_output,
dim_partion_dict=dim_partition_dict_input)
# TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
compute_cost = 0
memory_cost = 0
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
[input_sharding_spec])
sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec),
compute_cost=compute_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=[input_sharding_spec])
strategies_vector.append(sharding_strategy)
# operator.getitem
elif target == operator.getitem:
index = node.args[1]
input_tensor_node = strategies_vector.predecessor_nodes[0]
if isinstance(input_tensor_node, torch.Tensor):
for strategy in input_tensor_node.strategies_vector:
input_sharding_spec = strategy.output_sharding_spec[index]
assert isinstance(input_sharding_spec, ShardingSpec), f'This assertion is used to debug.'
dim_partition_dict_for_output = deepcopy(input_sharding_spec.dim_partition_dict)
entire_shape_output = deepcopy(input_sharding_spec.entire_shape)
output_sharding_spec = ShardingSpec(self.device_mesh,
entire_shape_output,
dim_partition_dict=dim_partition_dict_for_output)
# TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
compute_cost = 0
memory_cost = 0
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
[input_sharding_spec])
# to prevent the resharding happening, set their resharding cost to inf.
resharding_costs[input_tensor_node] = [
cost if cost == 0 else math.inf for cost in resharding_costs[input_tensor_node]
]
sharding_strategy = ShardingStrategy(
name,
output_sharding_spec,
compute_cost=compute_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=[input_tensor_node.output_sharding_spec])
strategies_vector.append(sharding_strategy)
# torch.arange function
elif target == torch.arange:
name = f'FULLY REPLICATED ARANGE'
entire_shape_output = node._meta_data.shape
dim_partition_dict_for_output = {}
output_sharding_spec = ShardingSpec(self.device_mesh,
entire_shape_output,
dim_partition_dict=dim_partition_dict_for_output)
memory_cost = node._meta_data.numel()
sharding_strategy = ShardingStrategy(name,
output_sharding_spec,
compute_cost=0,
memory_cost=memory_cost)
strategies_vector.append(sharding_strategy)
# op list to be processed to support gpt2
elif target in (builtins.getattr, operator.le, torch.addmm, operator.pow, torch.where, torch.softmax,
torch.nn.functional.softmax, torch.pow, torch.tanh):
pass
# other function
else:
raise RuntimeError(f'{target} function is NOT supported now.')
# call_method node
if node.op == 'call_method':
method = getattr(node.args[0]._meta_data.__class__, node.target)
if method in (torch.Tensor.size, torch.Tensor.contiguous):
pass
elif method in ELEMENTWISE_METHOD_OP:
unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
unary_elementwise_handler.register_strategy()
elif method in RESHAPE_METHOD_OP:
reshape_handler = ReshapeHandler(node, self.device_mesh, strategies_vector)
reshape_handler.register_strategy()
# print(strategies_vector)
# if len(strategies_vector) == 0:
# print(node)
# assert False
else:
raise RuntimeError(f'{method} function is NOT supported now.')
# output node
if node.op == 'output':
if self.solver_options.fast:
# create sharding strategy for output
name = 'Replica Output'
input_nodes = strategies_vector.predecessor_nodes
input_sharding_specs = []
for input_node in input_nodes:
dim_partition_dict_for_input = {}
entire_shape = input_node._meta_data.shape
sharding_spec = ShardingSpec(self.device_mesh,
entire_shape,
dim_partition_dict=dim_partition_dict_for_input)
input_sharding_specs.append(sharding_spec)
dim_partition_dict = {}
output_sharding_spec = input_sharding_specs
# TODO: use meta_info_prop to profile memory cost
memory_cost = 0
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
input_sharding_specs)
# clear the resharding cost for the output node
# TODO: we may remove this in final version
for prev_node, resharding_cost_list in resharding_costs.items():
resharding_costs[prev_node] = [0] * len(resharding_cost_list)
sharding_strategy_attribute = ShardingStrategy(name,
output_sharding_spec,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=tuple(input_sharding_specs))
strategies_vector.append(sharding_strategy_attribute)
self.remove_duplicated_strategy(strategies_vector)
setattr(node, 'strategies_vector', strategies_vector)
self.leaf_strategies.append(strategies_vector)
self.strategy_map[node] = strategies_vector