mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] added binary elementwise node handler (#1758)
* [autoparallel] added binary elementwise node handler
* polish code

pull/1759/head
parent d2fc067231
commit f9a613d660
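
For context, the handler added below targets two-operand elementwise ops whose operands may be broadcast against each other, such as torch.add; the broadcast result shape is what the handler calls the logical shape. A minimal sketch of that broadcasting behaviour in plain PyTorch (illustrative only, not part of the diff; meta tensors are used purely for shape inference):

import torch

# x1 has shape (4, 4); x2 has shape (4,) and is broadcast along the leading dim
x1 = torch.rand(4, 4).to('meta')
x2 = torch.rand(4).to('meta')
out = torch.add(x1, x2)
assert out.shape == torch.Size([4, 4])    # the broadcast ("logical") shape used by the handler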
@@ -0,0 +1,86 @@
from typing import Dict, List, Union

import torch
from torch.fx.node import Node

from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, ShardingStrategy

from ..constants import BCAST_FUNC_OP
from ..utils import recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import BinaryElementwiseStrategyGenerator, StrategyGenerator

__all__ = ['BinaryElementwiseHandler']


@operator_registry.register(BCAST_FUNC_OP)
class BinaryElementwiseHandler(NodeHandler):
    """
    A BinaryElementwiseHandler is a node handler which deals with operations that have two
    operands and involve broadcasting, such as torch.add.
    """

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        bcast_shape = self.node._meta_data.shape

        def _get_op_data_type(tensor):
            if isinstance(tensor, torch.nn.parameter.Parameter):
                return OperationDataType.PARAM
            else:
                return OperationDataType.ARG

        def _get_arg_value(idx):
            if isinstance(self.node.args[idx], Node):
                meta_data = self.node.args[idx]._meta_data
            else:
                # this argument is a concrete value such as the int 1,
                # but we can treat it as meta data
                # because it does not affect strategy generation
                assert isinstance(self.node.args[idx], (int, float))
                meta_data = torch.Tensor([self.node.args[idx]]).to('meta')
            return meta_data

        input_meta_data = _get_arg_value(0)
        other_meta_data = _get_arg_value(1)
        output_meta_data = self.node._meta_data

        input_op_data = OperationData(name=str(self.node.args[0]),
                                      type=_get_op_data_type(input_meta_data),
                                      data=input_meta_data,
                                      logical_shape=bcast_shape)
        other_op_data = OperationData(name=str(self.node.args[1]),
                                      type=_get_op_data_type(other_meta_data),
                                      data=other_meta_data,
                                      logical_shape=bcast_shape)
        output_op_data = OperationData(name=str(self.node),
                                       type=OperationDataType.OUTPUT,
                                       data=output_meta_data,
                                       logical_shape=bcast_shape)

        mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data}
        return mapping
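
    # Note (added for illustration): every operand shares the broadcast shape as its
    # logical_shape, while `data` keeps the physical shape. For example, assuming
    # torch.rand(4, 4) + torch.rand(4), both 'input' and 'other' get logical_shape
    # (4, 4) even though 'other' physically has shape (4,); post_process below maps
    # the sharding spec back onto the physical shape.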

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(BinaryElementwiseStrategyGenerator(op_data_mapping, self.device_mesh))
        return generators

    def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
        # convert each operand from its logical sharding spec to its physical sharding spec
        op_data_mapping = self.get_operation_data_mapping()

        for op_name, op_data in op_data_mapping.items():
            if not isinstance(op_data.data, torch.Tensor):
                # remove the sharding spec if the op_data is not a tensor, e.g. torch.pow(tensor, 2)
                strategy.sharding_specs.pop(op_data)
            else:
                # convert the logical sharding spec to a physical sharding spec if broadcasting occurred,
                # e.g. torch.rand(4, 4) + torch.rand(4)
                physical_shape = op_data.data.shape
                logical_shape = op_data.logical_shape
                sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
                sharding_spec = recover_sharding_spec_for_broadcast_shape(sharding_spec, logical_shape, physical_shape)
                strategy.sharding_specs[op_data] = sharding_spec
        return strategy
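
The broadcast recovery in post_process can be pictured with a small, self-contained sketch. The helper recover_for_broadcast below is hypothetical and only approximates what recover_sharding_spec_for_broadcast_shape does (the real utility operates on ShardingSpec objects and may differ in detail); it assumes right-aligned NumPy-style broadcasting and drops the sharding of any dimension that does not physically exist on the operand:

# Illustrative sketch only: operates on plain lists of dim specs such as ['S0', 'S1'].
def recover_for_broadcast(logical_spec, logical_shape, physical_shape):
    # broadcasting aligns shapes from the right, so the leading
    # (len(logical_shape) - len(physical_shape)) dims were added by broadcasting
    num_broadcast_dims = len(logical_shape) - len(physical_shape)
    physical_spec = []
    for i, dim_spec in enumerate(logical_spec[num_broadcast_dims:]):
        # a dim of size 1 stretched to a larger size cannot stay sharded either
        if physical_shape[i] == 1 and logical_shape[num_broadcast_dims + i] != 1:
            physical_spec.append('R')
        else:
            physical_spec.append(dim_spec)
    return physical_spec

# torch.rand(4, 4) + torch.rand(4): the logical spec [S0, S1] of 'other' becomes [S1]
# on its physical shape (4,), which is why the test below only compares the last-dim
# sharding when 'other' is 1-D.
assert recover_for_broadcast(['S0', 'S1'], (4, 4), (4,)) == ['S1']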
@@ -0,0 +1,111 @@
import operator
from functools import reduce
from typing import List

import torch

from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
from colossalai.auto_parallel.tensor_shard.utils import (
    enumerate_all_possible_1d_sharding,
    enumerate_all_possible_2d_sharding,
    ignore_sharding_exception,
)
from colossalai.tensor.sharding_spec import ShardingSpecException

from .strategy_generator import StrategyGenerator

__all__ = ['BinaryElementwiseStrategyGenerator']


class BinaryElementwiseStrategyGenerator(StrategyGenerator):
    """
    A BinaryElementwiseStrategyGenerator is a strategy generator which deals with elementwise
    operations that have two operands and involve broadcasting, such as torch.add.

    The logical shape for this operation will be `input <op> other`.
    """

    def validate(self) -> bool:
        assert len(self.op_data) == 3, \
            f'BinaryElementwiseStrategyGenerator only accepts three operation data (input, other and output), but got {len(self.op_data)}'
        for name, op_data in self.op_data.items():
            if not isinstance(op_data.data, (torch.Tensor, int, float)):
                raise TypeError(f'The operation data {name} is not a torch.Tensor/int/float.')

    def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
        shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()

        # since elementwise ops are not compute-intensive,
        # we approximate the backward compute cost
        # to be twice the forward compute cost
        fwd_compute_cost = reduce(operator.mul, shape)
        bwd_compute_cost = fwd_compute_cost * 2
        compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
                                      bwd=bwd_compute_cost,
                                      total=fwd_compute_cost + bwd_compute_cost)
        strategy.compute_cost = compute_cost

    def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
        # input, other and output all share the same logical shape
        shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()

        # compute the forward memory cost in bytes;
        # as elementwise ops are not memory-intensive,
        # we approximate the forward memory cost to be the output activation
        # and the backward memory cost to be the gradients of input and other
        input_bytes = self._compute_size_in_bytes(strategy, 'input')
        other_bytes = self._compute_size_in_bytes(strategy, 'other')
        output_bytes = self._compute_size_in_bytes(strategy, 'output')
        fwd_memory_cost = MemoryCost(activation=output_bytes)
        bwd_memory_cost = MemoryCost(activation=input_bytes + other_bytes)
        total_memory_cost = MemoryCost(activation=input_bytes + other_bytes + output_bytes)
        memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_memory_cost)
        strategy.memory_cost = memory_cost
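
    # Worked example (added for illustration, assuming a 2 x 2 device mesh and fp32 tensors):
    # a (4, 4) tensor sharded as [S0, S1] has a per-device sharded shape of (2, 2), so
    # fwd_compute_cost = 2 * 2 = 4, bwd_compute_cost = 8 and the total compute cost is 12.
    # With 4 bytes per element, each operand holds roughly 2 * 2 * 4 = 16 bytes per device,
    # giving an approximate fwd activation memory of 16 bytes and bwd activation memory of
    # 32 bytes under the approximation used above.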

    @ignore_sharding_exception
    def enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
        # we check the output logical shape to get the number of dimensions
        dim_partition_list = []
        dim_size = len(self.op_data['output'].logical_shape)

        # enumerate all the 2D sharding cases
        sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size)
        dim_partition_list.extend(sharding_list_2d)

        # enumerate all the 1D sharding cases
        sharding_list_1d_on_dim_0 = enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size)
        dim_partition_list.extend(sharding_list_1d_on_dim_0)
        sharding_list_1d_on_dim_1 = enumerate_all_possible_1d_sharding(mesh_dim_1, dim_size)
        dim_partition_list.extend(sharding_list_1d_on_dim_1)

        # add an empty dict for the fully replicated case
        dim_partition_list.append({})

        # sharding strategy bookkeeping
        strategy_list = []

        # convert these dim partition dicts to sharding strategies
        for dim_partition_dict in dim_partition_list:
            dim_partition_dict_mapping = dict(input=dim_partition_dict,
                                              other=dim_partition_dict,
                                              output=dim_partition_dict)

            try:
                sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
                communication_action_mapping = {}

                # build the strategy name from the sharding sequence
                sharding_seq = sharding_spec_mapping['input'].sharding_sequence
                name = f'{sharding_seq} = {sharding_seq} <binary-elementwise-op> {sharding_seq}'
                sharding_strategy = self.get_sharding_strategy(
                    name=name,
                    sharding_spec_mapping=sharding_spec_mapping,
                    communication_action_mapping=communication_action_mapping)
                strategy_list.append(sharding_strategy)
            except ShardingSpecException:
                continue
        return strategy_list

    def collate_strategies(self) -> List[ShardingStrategy]:
        strategy_list = self.enumerate_all_possible_output(0, 1)
        return strategy_list
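
To see where the nine strategies checked by the test below come from, here is a self-contained sketch of the enumeration for a 2-D output on a 2 x 2 device mesh. The helper sharding_sequences_2d is hypothetical: it only mimics the behaviour suggested by the utility names above (enumerate_all_possible_2d_sharding / enumerate_all_possible_1d_sharding), which return dim-partition dicts and may differ in detail:

from itertools import permutations

def sharding_sequences_2d(mesh_dims=(0, 1)):
    """Enumerate sharding sequences for a 2-D tensor on a 2-D device mesh (sketch)."""
    seqs = []
    # 2D sharding: both mesh dims used, on different tensor dims or flattened onto one dim
    for a, b in permutations(mesh_dims, 2):
        seqs.append([f'S{a}', f'S{b}'])                  # [S0, S1], [S1, S0]
    seqs.extend([['S01', 'R'], ['R', 'S01']])            # both mesh dims on one tensor dim
    # 1D sharding: only one mesh dim used
    for m in mesh_dims:
        seqs.extend([[f'S{m}', 'R'], ['R', f'S{m}']])    # [S0, R], [R, S0], [S1, R], [R, S1]
    # fully replicated
    seqs.append(['R', 'R'])
    return seqs

assert len(sharding_sequences_2d()) == 9    # matches the test's `len(strategy_name_list) == 9`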
@@ -0,0 +1,173 @@
import torch
import torch.nn as nn

from colossalai.auto_parallel.tensor_shard.node_handler import BinaryElementwiseHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing import parameterize


@parameterize('op', [torch.add])
@parameterize('other_dim', [1, 2])
def test_binary_elementwise_handler_with_tensor(op, other_dim):

    class BinaryElementwiseOpModel(nn.Module):

        def __init__(self, op):
            super().__init__()
            self.op = op

        def forward(self, x1, x2):
            out = self.op(x1, x2)
            return out

    model = BinaryElementwiseOpModel(op)
    tracer = ColoTracer()

    meta_args = {'x1': torch.rand(4, 4).to('meta'), 'x2': torch.rand([4] * other_dim).to('meta')}
    graph = tracer.trace(model, meta_args=meta_args)
    print(graph)
    gm = ColoGraphModule(model, graph)
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    op_node = list(graph.nodes)[2]
    strategies_vector = StrategiesVector(op_node)

    # build handler
    handler = BinaryElementwiseHandler(node=op_node, device_mesh=device_mesh, strategies_vector=strategies_vector)

    # check operation data mapping
    mapping = handler.get_operation_data_mapping()

    for name, op_data in mapping.items():
        op_data: OperationData
        # make sure they have valid values
        assert op_data.logical_shape is not None
        assert op_data.data is not None

    assert mapping['input'].name == "x1"
    assert mapping['input'].data.is_meta
    assert mapping['input'].data.shape == torch.Size([4, 4])
    assert mapping['input'].type == OperationDataType.ARG
    assert mapping['input'].logical_shape == torch.Size([4, 4])

    assert mapping['other'].name == "x2"
    assert mapping['other'].data.is_meta
    assert mapping['other'].data.shape == torch.Size([4] * other_dim)
    assert mapping['other'].type == OperationDataType.ARG
    assert mapping['other'].logical_shape == torch.Size([4, 4])

    assert mapping['output'].name == str(op_node)
    assert mapping['output'].data.is_meta
    assert mapping['output'].data.shape == torch.Size([4, 4])
    assert mapping['output'].type == OperationDataType.OUTPUT
    assert mapping['output'].logical_shape == torch.Size([4, 4])

    strategies_vector = handler.register_strategy(compute_resharding_cost=False)
    strategy_name_list = [val.name for val in strategies_vector]

    # one logical strategy may be converted to different physical sharding specs for different operands
    assert len(strategy_name_list) == 9

    # check if the sharding strategies are correct
    assert '[S0, S1] = [S0, S1] <binary-elementwise-op> [S0, S1]' in strategy_name_list
    assert '[S1, S0] = [S1, S0] <binary-elementwise-op> [S1, S0]' in strategy_name_list
    assert '[S01, R] = [S01, R] <binary-elementwise-op> [S01, R]' in strategy_name_list
    assert '[R, S01] = [R, S01] <binary-elementwise-op> [R, S01]' in strategy_name_list
    assert '[S0, R] = [S0, R] <binary-elementwise-op> [S0, R]' in strategy_name_list
    assert '[R, S0] = [R, S0] <binary-elementwise-op> [R, S0]' in strategy_name_list
    assert '[S1, R] = [S1, R] <binary-elementwise-op> [S1, R]' in strategy_name_list
    assert '[R, S1] = [R, S1] <binary-elementwise-op> [R, S1]' in strategy_name_list
    assert '[R, R] = [R, R] <binary-elementwise-op> [R, R]' in strategy_name_list

    for strategy in strategies_vector:
        input_sharding_spec = strategy.get_sharding_spec_by_name('x1')
        other_sharding_spec = strategy.get_sharding_spec_by_name('x2')
        output_sharding_spec = strategy.get_sharding_spec_by_name(str(op_node))

        # make sure the sharding spec is the same for input and output
        assert input_sharding_spec.sharding_sequence == output_sharding_spec.sharding_sequence

        # since the number of dims of the other operand can vary, we only require that its last-dim sharding matches
        if len(other_sharding_spec.sharding_sequence) == 2:
            assert input_sharding_spec.sharding_sequence == other_sharding_spec.sharding_sequence
        elif len(other_sharding_spec.sharding_sequence) == 1:
            assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1]


@parameterize('op', [torch.add])
@parameterize('other', [1, 2])
def test_binary_elementwise_handler_with_int(op, other):

    class BinaryElementwiseOpModel(nn.Module):

        def __init__(self, op, const):
            super().__init__()
            self.op = op
            self.const = const

        def forward(self, x1):
            out = self.op(x1, self.const)
            return out

    model = BinaryElementwiseOpModel(op, other)
    tracer = ColoTracer()

    meta_args = {'x1': torch.rand(4, 4).to('meta')}
    graph = tracer.trace(model, meta_args=meta_args)
    print(graph)
    gm = ColoGraphModule(model, graph)
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    op_node = list(graph.nodes)[1]
    strategies_vector = StrategiesVector(op_node)

    # build handler
    handler = BinaryElementwiseHandler(node=op_node, device_mesh=device_mesh, strategies_vector=strategies_vector)

    # check operation data mapping
    mapping = handler.get_operation_data_mapping()

    assert mapping['input'].name == "x1"
    assert mapping['input'].data.is_meta
    assert mapping['input'].data.shape == torch.Size([4, 4])
    assert mapping['input'].type == OperationDataType.ARG
    assert mapping['input'].logical_shape == torch.Size([4, 4])

    assert mapping['output'].name == str(op_node)
    assert mapping['output'].data.is_meta
    assert mapping['output'].data.shape == torch.Size([4, 4])
    assert mapping['output'].type == OperationDataType.OUTPUT
    assert mapping['output'].logical_shape == torch.Size([4, 4])

    strategies_vector = handler.register_strategy(compute_resharding_cost=False)
    strategy_name_list = [val.name for val in strategies_vector]

    # one logical strategy may be converted to different physical sharding specs for different operands
    assert len(strategy_name_list) == 9

    # check if the sharding strategies are correct
    assert '[S0, S1] = [S0, S1] <binary-elementwise-op> [S0, S1]' in strategy_name_list
    assert '[S1, S0] = [S1, S0] <binary-elementwise-op> [S1, S0]' in strategy_name_list
    assert '[S01, R] = [S01, R] <binary-elementwise-op> [S01, R]' in strategy_name_list
    assert '[R, S01] = [R, S01] <binary-elementwise-op> [R, S01]' in strategy_name_list
    assert '[S0, R] = [S0, R] <binary-elementwise-op> [S0, R]' in strategy_name_list
    assert '[R, S0] = [R, S0] <binary-elementwise-op> [R, S0]' in strategy_name_list
    assert '[S1, R] = [S1, R] <binary-elementwise-op> [S1, R]' in strategy_name_list
    assert '[R, S1] = [R, S1] <binary-elementwise-op> [R, S1]' in strategy_name_list
    assert '[R, R] = [R, R] <binary-elementwise-op> [R, R]' in strategy_name_list

    for strategy in strategies_vector:
        input_sharding_spec = strategy.get_sharding_spec_by_name('x1')
        output_sharding_spec = strategy.get_sharding_spec_by_name(str(op_node))

        # make sure the sharding spec is the same for input and output
        assert input_sharding_spec.sharding_sequence == output_sharding_spec.sharding_sequence


if __name__ == '__main__':
    test_binary_elementwise_handler_with_tensor()
    test_binary_elementwise_handler_with_int()