[autoparallel] added node handler for bmm (#1655)

pull/1664/head
Frank Lee 2022-09-28 11:32:16 +08:00 committed by GitHub
parent 095854477f
commit 3a4d6f63a8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 207 additions and 24 deletions

View File

@ -2,11 +2,11 @@ import torch
import torch.nn.functional as F
from .node_handler import ModuleHandler, NodeHandler
from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData
from ..strategy import LinearProjectionStrategyGenerator, StrategyGenerator_V2
from ..strategy import LinearProjectionStrategyGenerator, StrategyGenerator_V2, BatchedMatMulStrategyGenerator
from typing import List, Dict
from .registry import operator_registry
__all__ = ['LinearModuleHandler', 'LinearFunctionHandler']
__all__ = ['LinearModuleHandler', 'LinearFunctionHandler', 'BMMFunctionHandler']
@operator_registry.register(torch.nn.Linear)
@ -133,3 +133,30 @@ class LinearFunctionHandler(NodeHandler):
# re-init the sharding spec
sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
return strategy
@operator_registry.register(torch.bmm)
@operator_registry.register(torch.Tensor.bmm)
class BMMFunctionHandler(NodeHandler):
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.node.args[0]._meta_data)
physical_other_operand = OperationData(name=str(self.node.args[1]),
type=OperationDataType.ARG,
data=self.node.args[1]._meta_data)
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
return mapping
def get_strategy_generator(self) -> List[StrategyGenerator_V2]:
generators = []
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh))
return generators

View File

@ -483,6 +483,9 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
other_op_data = self.op_data['other']
assert input_op_data.data.dim() > 2 or other_op_data.data.dim() > 2
def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
return self.op_data['input'].data.shape[-1] * reduce(operator.mul, self.op_data['output'].data.shape)
def split_one_batch_dim(self):
device_mesh_is_1d = True
if len(self.device_mesh.mesh_shape) == 1:
@ -552,7 +555,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
},
"bias": {},
"output": {
0: mesh_dim_0,
0: [mesh_dim_0],
-2: [mesh_dim_1]
}
}
@ -635,25 +638,27 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
# can be None as it is only for 1D device mesh
strategy = self.split_one_batch_dim()
if strategy:
# only for 1D device mesh
strategy_list.append(strategy)
else:
# for 2D device mesh
# split batch dim of two inputs and the i dim of the first tensor
# SbSi = SbSi x Sb
strategy_list.append(self.split_batch_dim_lhs_space(0, 1))
strategy_list.append(self.split_batch_dim_lhs_space(1, 0))
# split batch dim of two inputs and the i dim of the first tensor
# SbSi = SbSi x Sb
strategy_list.append(self.split_batch_dim_lhs_space(0, 1))
strategy_list.append(self.split_batch_dim_lhs_space(1, 0))
# split batch dim of two inputs and the j of the second tensor
# SbSj = Sb x SbSj
strategy_list.append(self.split_batch_dim_rhs_space(0, 1))
strategy_list.append(self.split_batch_dim_rhs_space(1, 0))
# split batch dim of two inputs and the j of the second tensor
# SbSj = Sb x SbSj
strategy_list.append(self.split_batch_dim_rhs_space(0, 1))
strategy_list.append(self.split_batch_dim_rhs_space(1, 0))
# split batch dim of two inputs and the k dim of two inputs
# Sb = SbSk x SbSk, need to all-reduce by k dim
strategy_list.append(self.split_batch_dim_both_contract(0, 1))
strategy_list.append(self.split_batch_dim_both_contract(1, 0))
# split batch dim of two inputs and the k dim of two inputs
# Sb = SbSk x SbSk, need to all-reduce by k dim
strategy_list.append(self.split_batch_dim_both_contract(0, 1))
strategy_list.append(self.split_batch_dim_both_contract(1, 0))
# split two batch dim
strategy_list.append(self.split_two_batch_dim(0, 1))
strategy_list.append(self.split_two_batch_dim(1, 0))
# split two batch dim
strategy_list.append(self.split_two_batch_dim(0, 1))
strategy_list.append(self.split_two_batch_dim(1, 0))
return strategy_list

View File

@ -49,11 +49,12 @@ class StrategyGenerator_V2(ABC):
"""
results = {}
for op_data_name, dim_partition_dict in mapping.items():
op_data = self.op_data[op_data_name]
sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=op_data.logical_shape,
dim_partition_dict=dim_partition_dict)
results[op_data_name] = sharding_spec
if op_data_name in self.op_data:
op_data = self.op_data[op_data_name]
sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=op_data.logical_shape,
dim_partition_dict=dim_partition_dict)
results[op_data_name] = sharding_spec
return results
def replace_op_name_with_op_data(self, mapping: Dict[str, Any]):

View File

@ -0,0 +1,150 @@
import pytest
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.op_handler.dot_handler_v2 import BMMFunctionHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
class BMMTensorMethodModule(nn.Module):
def forward(self, x1, x2):
return x1.bmm(x2)
class BMMTorchFunctionModule(nn.Module):
def forward(self, x1, x2):
return torch.bmm(x1, x2)
@pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
def test_2d_device_mesh(module):
model = module()
tracer = ColoTracer()
graph = tracer.trace(model,
meta_args={
"x1": torch.rand(4, 8, 16).to('meta'),
'x2': torch.rand(4, 16, 8).to('meta')
})
print(graph)
gm = ColoGraphModule(model, graph)
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
linear_mod_node = list(graph.nodes)[2]
strategies_vector = StrategiesVector(linear_mod_node)
# build handler
handler = BMMFunctionHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
# check operation data mapping
mapping = handler.get_operation_data_mapping()
for name, op_data in mapping.items():
op_data: OperationData
# make sure they have valid values
assert op_data.logical_shape is not None
assert op_data.data is not None
assert mapping['input'].name == "x1"
assert mapping['input'].data.is_meta
assert mapping['input'].data.shape == torch.Size([4, 8, 16])
assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == torch.Size([4, 8, 16])
assert mapping['other'].name == "x2"
assert mapping['other'].data.is_meta
assert mapping['other'].data.shape == torch.Size([4, 16, 8])
assert mapping['other'].type == OperationDataType.ARG
assert mapping['other'].logical_shape == torch.Size([4, 16, 8])
assert mapping['output'].name == "bmm"
assert mapping['output'].data.is_meta
assert mapping['output'].data.shape == torch.Size([4, 8, 8])
assert mapping['output'].type == OperationDataType.OUTPUT
strategies_vector = handler.register_strategy()
strategy_name_list = [val.name for val in strategies_vector]
# one batch dim
assert 'Sb0 = Sb0 x Sb0' not in strategy_name_list
# two batch dim
assert 'Sb01 = Sb01 x Sb01' in strategy_name_list
# SbSi = SbSi x Sb
assert 'Sb0Si1 = Sb0Si1 x Sb0' in strategy_name_list
assert 'Sb1Si0 = Sb1Si0 x Sb1' in strategy_name_list
# SbSj = SbR x SbSj
assert 'Sb0Sj1 = Sb0R x Sb0Sj1' in strategy_name_list
assert 'Sb1Sj0 = Sb1R x Sb1Sj0' in strategy_name_list
# SbR = SbSk x SbSk
assert 'Sb0R = Sb0Sk1 x Sb0Sk1' in strategy_name_list
assert 'Sb1R = Sb1Sk0 x Sb1Sk0' in strategy_name_list
@pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
def test_1d_device_mesh(module):
model = module()
tracer = ColoTracer()
graph = tracer.trace(model,
meta_args={
"x1": torch.rand(4, 8, 16).to('meta'),
'x2': torch.rand(4, 16, 8).to('meta')
})
print(graph)
gm = ColoGraphModule(model, graph)
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (1, 4)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
linear_mod_node = list(graph.nodes)[2]
strategies_vector = StrategiesVector(linear_mod_node)
# build handler
handler = BMMFunctionHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
# check operation data mapping
mapping = handler.get_operation_data_mapping()
for name, op_data in mapping.items():
op_data: OperationData
# make sure they have valid values
assert op_data.logical_shape is not None
assert op_data.data is not None
assert mapping['input'].name == "x1"
assert mapping['input'].data.is_meta
assert mapping['input'].data.shape == torch.Size([4, 8, 16])
assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == torch.Size([4, 8, 16])
assert mapping['other'].name == "x2"
assert mapping['other'].data.is_meta
assert mapping['other'].data.shape == torch.Size([4, 16, 8])
assert mapping['other'].type == OperationDataType.ARG
assert mapping['other'].logical_shape == torch.Size([4, 16, 8])
assert mapping['output'].name == "bmm"
assert mapping['output'].data.is_meta
assert mapping['output'].data.shape == torch.Size([4, 8, 8])
assert mapping['output'].type == OperationDataType.OUTPUT
strategies_vector = handler.register_strategy()
strategy_name_list = [val.name for val in strategies_vector]
assert len(strategy_name_list) == 1
# one batch dim
assert 'Sb0 = Sb0 x Sb0' in strategy_name_list
if __name__ == '__main__':
test_1d_device_mesh(BMMTensorMethodModule)
test_1d_device_mesh(BMMTorchFunctionModule)
test_2d_device_mesh(BMMTensorMethodModule)
test_2d_device_mesh(BMMTorchFunctionModule)