mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] add elementwise handler (#1622)
* [autoparallel] add elementwise handler * polish code * polish code * reduce skipped strategies range * polish code
pull/1638/head
parent
3a46215135
commit
c7ac0f4ab2
|
@ -4,5 +4,9 @@ from .conv_handler import ConvHandler
|
|||
from .batch_norm_handler import BatchNormHandler
|
||||
from .reshape_handler import ReshapeHandler
|
||||
from .bcast_op_handler import BcastOpHandler
|
||||
from .unary_elementwise_handler import UnaryElementwiseHandler
|
||||
|
||||
__all__ = ['OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler', 'BcastOpHandler']
|
||||
__all__ = [
|
||||
'OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler', 'BcastOpHandler',
|
||||
'UnaryElementwiseHandler'
|
||||
]
|
||||
|
|
|
@ -47,7 +47,10 @@ class OperatorHandler(ABC):
|
|||
elif self.node.op == 'call_function' and self.node.target not in NON_PARAM_FUNC_OP:
|
||||
module = None
|
||||
parameters = list(self.node.args)[1]
|
||||
named_parameters = {'weight': parameters._meta_data}
|
||||
if isinstance(parameters, Node):
|
||||
named_parameters = {'weight': parameters._meta_data}
|
||||
else:
|
||||
named_parameters = {}
|
||||
else:
|
||||
module = None
|
||||
named_parameters = None
|
||||
|
|
|
@ -0,0 +1,83 @@
|
|||
import operator
|
||||
from functools import reduce
|
||||
import warnings
|
||||
import torch
|
||||
from colossalai.auto_parallel.solver.constants import INFINITY_COST
|
||||
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from .operator_handler import OperatorHandler
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List
|
||||
import math
|
||||
from colossalai.auto_parallel.solver._utils import exception_handler
|
||||
|
||||
__all__ = ['UnaryElementwiseHandler']
|
||||
|
||||
|
||||
class UnaryElementwiseHandler(OperatorHandler):
    """
    An OperatorHandler which deals with the sharding strategies of UnaryElementwiseOp.

    The handler accepts a node with exactly one tensor-valued predecessor and
    registers one strategy per distinct output sharding spec of that
    predecessor, reusing the predecessor's dim partition for the node's own
    output.
    """

    def __init__(self, *args, **kwargs):
        """
        Forward all arguments to ``OperatorHandler`` and cache the input/output meta data.

        Raises:
            AssertionError: if the number of predecessors carrying tensor
                meta data is not exactly one.
        """
        super().__init__(*args, **kwargs)

        # Only predecessors whose meta data is a tensor count as inputs of the op.
        input_nodes_len = sum(
            1 for check_node in self.predecessor_node if isinstance(check_node._meta_data, torch.Tensor))
        assert input_nodes_len == 1, f'Temporally, we just support single input element-wise op, node name is {self.node}, node args is {self.node.args}.'

        # NOTE(review): the first predecessor is taken as the input, while the
        # check above counts tensor predecessors at any position - confirm the
        # tensor input is always predecessor_node[0].
        self.input_node = self.predecessor_node[0]
        self.input_data = self.input_node._meta_data
        self.output_data = self.node._meta_data

    def _generate_compute_cost(self, *args, **kwargs):
        """Delegate the compute cost estimation to the parent class unchanged."""
        return super()._generate_compute_cost(*args, **kwargs)

    @exception_handler
    def register_strategy(self):
        """
        Append one ShardingStrategy per distinct input sharding spec to ``self.strategies_vector``.

        Strategies of the input node that share an output sharding spec are
        collapsed into one. A spec for which ``_generate_sharding_spec``
        raises an AssertionError is reported through ``warnings.warn`` and
        skipped.
        """
        # TODO: integrate element-wise func and module together
        # create sharding strategy for element-wise function

        # For element-wise function, we keep the sharding spec of output node same as
        # the input. Therefore, the different strategies of input node with same
        # output sharding spec will generate same strategy for element-wise function.
        sharding_spec_checklist = []
        for strategy in self.input_node.strategies_vector:
            # It looks a little bit confusing, the input of the processing node
            # is the output of the input_node.
            input_sharding_spec = strategy.output_sharding_spec
            assert isinstance(input_sharding_spec, ShardingSpec), 'The input node should NOT be a tuple of tensor.'
            if input_sharding_spec in sharding_spec_checklist:
                continue
            sharding_spec_checklist.append(input_sharding_spec)

            # Copy the partition so building the output spec works on its own dict.
            dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
            try:
                output_sharding_spec = self._generate_sharding_spec(self.output_data, dim_partition_dict)
            except AssertionError as e:
                warnings.warn(str(e))
                continue

            name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}'

            # TODO: use meta_info_prop to profile memory cost and compute cost
            compute_cost = self.output_data.numel()
            memory_cost = 0

            resharding_costs = self._generate_resharding_costs([input_sharding_spec])

            # to prevent the resharding happening, set their resharding cost to inf.
            resharding_costs[self.input_node] = [
                0 if cost == 0 else INFINITY_COST for cost in resharding_costs[self.input_node]
            ]
            sharding_strategy = ShardingStrategy(name,
                                                 output_sharding_spec,
                                                 compute_cost=compute_cost,
                                                 memory_cost=memory_cost,
                                                 resharding_costs=resharding_costs,
                                                 input_shardings=[input_sharding_spec])
            self.strategies_vector.append(sharding_strategy)
|
|
@ -14,6 +14,7 @@ import torch
|
|||
import operator
|
||||
from typing import Dict, List
|
||||
from ._utils import generate_sharding_spec, generate_resharding_costs
|
||||
import builtins
|
||||
|
||||
|
||||
class StrategiesConstructor:
|
||||
|
@ -63,7 +64,11 @@ class StrategiesConstructor:
|
|||
def build_strategies_and_cost(self):
|
||||
for node in self.nodes:
|
||||
strategies_vector = StrategiesVector(node)
|
||||
input_nodes_len = len(strategies_vector.predecessor_nodes)
|
||||
input_nodes_len = 0
|
||||
for check_node in strategies_vector.predecessor_nodes:
|
||||
if isinstance(check_node._meta_data, torch.Tensor):
|
||||
input_nodes_len += 1
|
||||
# input_nodes_len = len(strategies_vector.predecessor_nodes)
|
||||
# placeholder node
|
||||
if node.op == 'placeholder':
|
||||
# For placeholder nodes, if solver_options.fast is True, we just let them in
|
||||
|
@ -122,53 +127,12 @@ class StrategiesConstructor:
|
|||
|
||||
# element-wise module
|
||||
elif submod_type in ELEMENTWISE_MODULE_OP:
|
||||
# create sharding strategy for element-wise module
|
||||
assert len(strategies_vector.predecessor_nodes
|
||||
) == 1, f'Temporally, we just support single input element-wise op.'
|
||||
input_node = strategies_vector.predecessor_nodes[0]
|
||||
# For element-wise module, we keep the sharding spec of output node same as
|
||||
# the input. Therefore, the different strategies of input node with same
|
||||
# output sharding spec will generate same strategy for element-wise module.
|
||||
sharding_spec_checklist = []
|
||||
for strategy in input_node.strategies_vector:
|
||||
# It looks a little bit confusing, the input of the processing node
|
||||
# is the output of the input_node.
|
||||
input_sharding_spec = strategy.output_sharding_spec
|
||||
assert isinstance(input_sharding_spec,
|
||||
ShardingSpec), f'The input node should NOT be a tuple of tensor.'
|
||||
if input_sharding_spec in sharding_spec_checklist:
|
||||
continue
|
||||
|
||||
sharding_spec_checklist.append(input_sharding_spec)
|
||||
dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
|
||||
output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)
|
||||
|
||||
name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}'
|
||||
|
||||
# TODO: use meta_info_prop to profile memory cost and compute cost
|
||||
compute_cost = node._meta_data.numel()
|
||||
memory_cost = 0
|
||||
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||
[input_sharding_spec])
|
||||
|
||||
# to prevent the resharding happening, set their resharding cost to inf.
|
||||
resharding_costs[input_node] = [
|
||||
cost if cost == 0 else math.inf for cost in resharding_costs[input_node]
|
||||
]
|
||||
sharding_strategy = ShardingStrategy(name,
|
||||
output_sharding_spec,
|
||||
compute_cost=compute_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=[input_sharding_spec])
|
||||
strategies_vector.append(sharding_strategy)
|
||||
unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
|
||||
unary_elementwise_handler.register_strategy()
|
||||
|
||||
# BatchNormNd module
|
||||
elif submod_type in BATCHNORM_MODULE_OP:
|
||||
# bn1 call_module bn1 (conv1,)
|
||||
# print(node, node.op, node.target, node.args)
|
||||
# create sharding strategy for element-wise module
|
||||
# input_node = strategies_vector.predecessor_nodes[0]
|
||||
norm_handler = BatchNormHandler(node, self.device_mesh, strategies_vector)
|
||||
norm_handler.register_strategy()
|
||||
# for strategy in norm_handler.strategies_vector:
|
||||
|
@ -181,8 +145,7 @@ class StrategiesConstructor:
|
|||
# e.g.: for a 2D pooling input NCHW, we should promise no sharding happens on H and W dimension
|
||||
|
||||
# create sharding strategy for element-wise module
|
||||
assert len(strategies_vector.predecessor_nodes
|
||||
) == 1, f'Temporally, we just support single input element-wise op.'
|
||||
assert input_nodes_len == 1, f'Temporally, we just support single input element-wise op.'
|
||||
input_node = strategies_vector.predecessor_nodes[0]
|
||||
# For element-wise module, we keep the sharding spec of output node same as
|
||||
# the input. Therefore, the different strategies of input node with same
|
||||
|
@ -255,50 +218,15 @@ class StrategiesConstructor:
|
|||
|
||||
# element-wise function
|
||||
elif target in ELEMENTWISE_FUNC_OP or (target in BCAST_FUNC_OP and input_nodes_len == 1):
|
||||
# TODO: integrate element-wise func and module together
|
||||
# create sharding strategy for element-wise function
|
||||
assert len(strategies_vector.predecessor_nodes
|
||||
) == 1, f'Temporally, we just support single input element-wise op, node name is {node}.'
|
||||
input_node = strategies_vector.predecessor_nodes[0]
|
||||
# For element-wise function, we keep the sharding spec of output node same as
|
||||
# the input. Therefore, the different strategies of input node with same
|
||||
# output sharding spec will generate same strategy for element-wise function.
|
||||
sharding_spec_checklist = []
|
||||
for strategy in input_node.strategies_vector:
|
||||
# It looks a little bit confusing, the input of the processing node
|
||||
# is the output of the input_node.
|
||||
input_sharding_spec = strategy.output_sharding_spec
|
||||
assert isinstance(input_sharding_spec,
|
||||
ShardingSpec), f'The input node should NOT be a tuple of tensor.'
|
||||
if input_sharding_spec in sharding_spec_checklist:
|
||||
continue
|
||||
sharding_spec_checklist.append(input_sharding_spec)
|
||||
dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
|
||||
output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)
|
||||
name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}'
|
||||
# TODO: use meta_info_prop to profile memory cost and compute cost
|
||||
compute_cost = node._meta_data.numel()
|
||||
memory_cost = 0
|
||||
|
||||
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||
[input_sharding_spec])
|
||||
|
||||
# to prevent the resharding happening, set their resharding cost to inf.
|
||||
resharding_costs[input_node] = [
|
||||
0 if cost == 0 else math.inf for cost in resharding_costs[input_node]
|
||||
]
|
||||
sharding_strategy = ShardingStrategy(name,
|
||||
output_sharding_spec,
|
||||
compute_cost=compute_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=[input_sharding_spec])
|
||||
strategies_vector.append(sharding_strategy)
|
||||
if isinstance(node._meta_data, torch.Tensor):
|
||||
unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
|
||||
unary_elementwise_handler.register_strategy()
|
||||
|
||||
# bcast op
|
||||
elif target in BCAST_FUNC_OP:
|
||||
bcast_op_handler = BcastOpHandler(node, self.device_mesh, strategies_vector)
|
||||
bcast_op_handler.register_strategy()
|
||||
if isinstance(node._meta_data, torch.Tensor):
|
||||
bcast_op_handler = BcastOpHandler(node, self.device_mesh, strategies_vector)
|
||||
bcast_op_handler.register_strategy()
|
||||
|
||||
# torch.var_mean
|
||||
elif target == torch.var_mean:
|
||||
|
@ -421,7 +349,10 @@ class StrategiesConstructor:
|
|||
elif method in RESHAPE_METHOD_OP:
|
||||
reshape_handler = ReshapeHandler(node, self.device_mesh, strategies_vector)
|
||||
reshape_handler.register_strategy()
|
||||
|
||||
# print(strategies_vector)
|
||||
# if len(strategies_vector) == 0:
|
||||
# print(node)
|
||||
# assert False
|
||||
else:
|
||||
raise RuntimeError(f'{method} function is NOT supported now.')
|
||||
|
||||
|
|
Loading…
Reference in New Issue