[autoparallel] added addbmm handler (#1751)

pull/1758/head
Frank Lee 2022-10-21 18:55:48 +08:00 committed by GitHub
parent 980ed21723
commit 262652c8bc
8 changed files with 353 additions and 35 deletions
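
For context, torch.addbmm fuses a bias addition with a batch matrix product whose batch dimension is summed out: out = beta * bias + alpha * sum_i(batch1[i] @ batch2[i]). The 2D output is why the new handler gives the bias a 2D logical shape. A minimal illustration of the shapes involved (standalone snippet, not part of this commit; shapes mirror the new tests):

import torch

bias = torch.rand(8, 8)        # broadcastable to the [n, p] output; may also be 1D
batch1 = torch.rand(4, 8, 16)  # [b, n, m]
batch2 = torch.rand(4, 16, 8)  # [b, m, p]

out = torch.addbmm(bias, batch1, batch2)
print(out.shape)  # torch.Size([8, 8]) -- the batch dimension b is reduced away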

View File

@@ -1,5 +1,5 @@
 from .batch_norm_handler import BatchNormModuleHandler
-from .bmm_handler import BMMFunctionHandler
+from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler
 from .conv_handler import ConvFunctionHandler, ConvModuleHandler
 from .layer_norm_handler import LayerNormModuleHandler
 from .linear_handler import LinearFunctionHandler, LinearModuleHandler
@@ -12,7 +12,8 @@ from .unary_elementwise_handler import UnaryElementwiseHandler
 from .where_handler import WhereHandler

 __all__ = [
-    'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'LayerNormModuleHandler',
-    'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler', 'UnaryElementwiseHandler', 'ReshapeHandler',
-    'PlacehodlerHandler', 'OuputHandler', 'WhereHandler', 'NormPoolingHandler', 'operator_registry'
+    'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler',
+    'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
+    'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler',
+    'NormPoolingHandler', 'operator_registry'
 ]

View File

@@ -1,33 +1,97 @@
-from typing import Dict, List
+from typing import Dict, List, Union

 import torch

-from ..sharding_strategy import OperationData, OperationDataType
+from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
+from ..utils import recover_sharding_spec_for_broadcast_shape
 from .node_handler import NodeHandler
 from .registry import operator_registry
 from .strategy import BatchedMatMulStrategyGenerator, StrategyGenerator

+__all__ = ['BMMFunctionHandler', 'AddBMMFunctionHandler']
+
+
+def _get_data_mapping_for_bmm_op(node, input_idx, other_idx, bias_idx=None):
+    """
+    A helper function which extracts the logic shared by the `bmm` and `addbmm`
+    node handlers to reduce code redundancy.
+    """
+    # input operand
+    physical_input_operand = OperationData(name=str(node.args[input_idx]),
+                                           type=OperationDataType.ARG,
+                                           data=node.args[input_idx]._meta_data)
+
+    # other operand
+    physical_other_operand = OperationData(name=str(node.args[other_idx]),
+                                           type=OperationDataType.ARG,
+                                           data=node.args[other_idx]._meta_data)
+
+    # output
+    physical_output = OperationData(name=str(node), type=OperationDataType.OUTPUT, data=node._meta_data)
+    mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
+
+    if bias_idx is not None:
+        # the bias logical shape is the output shape
+        bias_logical_shape = node._meta_data.shape
+        physical_bias_operand = OperationData(name=str(node.args[bias_idx]),
+                                              type=OperationDataType.ARG,
+                                              data=node.args[bias_idx]._meta_data,
+                                              logical_shape=bias_logical_shape)
+        mapping['bias'] = physical_bias_operand
+    return mapping
+
+
 @operator_registry.register(torch.bmm)
 @operator_registry.register(torch.Tensor.bmm)
 class BMMFunctionHandler(NodeHandler):
+    """
+    This is a NodeHandler class which deals with the batched matrix multiplication operation in PyTorch.
+    Such operations, including `torch.bmm` and `torch.Tensor.bmm`, require the tensors to be 3D, so there is
+    no logical-physical shape conversion in this handler.
+    """

     def get_operation_data_mapping(self) -> Dict[str, OperationData]:
-        physical_input_operand = OperationData(name=str(self.node.args[0]),
-                                               type=OperationDataType.ARG,
-                                               data=self.node.args[0]._meta_data)
-        physical_other_operand = OperationData(name=str(self.node.args[1]),
-                                               type=OperationDataType.ARG,
-                                               data=self.node.args[1]._meta_data)
-        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
-        mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
+        mapping = _get_data_mapping_for_bmm_op(node=self.node, input_idx=0, other_idx=1)
         return mapping

     def get_strategy_generator(self) -> List[StrategyGenerator]:
+        generators = []
         op_data_mapping = self.get_operation_data_mapping()
         generators = []
         generators.append(BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh))
         return generators
+
+
+@operator_registry.register(torch.addbmm)
+@operator_registry.register(torch.Tensor.addbmm)
+class AddBMMFunctionHandler(NodeHandler):
+    """
+    This is a NodeHandler class which deals with the addition + batched matrix multiplication operation in PyTorch.
+    Such operations, including `torch.addbmm` and `torch.Tensor.addbmm`, require the two matmul tensors to be 3D.
+    However, due to the addition, a logical-physical shape conversion is required for the bias term.
+    As the addbmm operation reduces the batch dimension, the bias is at most 2D.
+    """

+    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
+        mapping = _get_data_mapping_for_bmm_op(node=self.node, input_idx=1, other_idx=2, bias_idx=0)
+        return mapping
+
+    def get_strategy_generator(self) -> List[StrategyGenerator]:
+        op_data_mapping = self.get_operation_data_mapping()
+        generators = []
+        generators.append(BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh))
+        return generators
+
+    def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
+        # convert bias from its logical sharding spec to its physical sharding spec
+        op_data_mapping = self.get_operation_data_mapping()
+
+        if 'bias' in op_data_mapping:
+            bias_op_data = op_data_mapping['bias']
+            bias_physical_shape = bias_op_data.data.shape
+            bias_logical_shape = bias_op_data.logical_shape
+            bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name)
+            bias_sharding_spec = recover_sharding_spec_for_broadcast_shape(bias_sharding_spec, bias_logical_shape,
+                                                                           bias_physical_shape)
+            strategy.sharding_specs[bias_op_data] = bias_sharding_spec
+        return strategy
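
The post_process step relies on the bias being stored in its physical (possibly 1D) shape while strategy generation works with its logical 2D shape, i.e. the output shape. A small sketch of the broadcast relationship this step has to undo, using plain PyTorch only (illustrative, not part of the commit; variable names are made up):

import torch

bias_physical = torch.rand(8)       # physical bias shape [8], as in the new tests
logical_shape = torch.Size([8, 8])  # logical bias shape == addbmm output shape

# Broadcasting aligns trailing dims ([8] -> [1, 8] -> [8, 8]), so only the last
# logical dim corresponds to a real bias dim; a shard placed on logical dim 0 has
# no physical dim to live on and must be dropped when recovering the bias spec.
print(torch.broadcast_shapes(bias_physical.shape, logical_shape))  # torch.Size([8, 8])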

View File

@@ -514,23 +514,60 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
     A batched matrix multiplication can be viewed as
     [b, i, k] x [b, k, j] -> [b, i, j]
+    The bias term is considered to have a 2D logical shape.
     """

+    def __init__(self, *args, **kwargs):
+        self.squeeze_batch_dim = False
+        super().__init__(*args, **kwargs)
+
+    def _pop_batch_dim_sharding_for_output(self, dim_partition_dict):
+        # remove partition dict for dim 0
+        dim_partition_dict['output'].pop(0, None)
+
+        # decrease the remaining dim index by 1
+        temp_dim_partition = {}
+        keys = list(dim_partition_dict['output'].keys())
+        for key in keys:
+            val = dim_partition_dict['output'].pop(key)
+            temp_dim_partition[key - 1] = val
+        dim_partition_dict['output'].update(temp_dim_partition)
+
     def validate(self) -> bool:
         input_op_data = self.op_data['input']
         other_op_data = self.op_data['other']
-        assert input_op_data.data.dim() > 2 or other_op_data.data.dim() > 2
+        assert input_op_data.data.dim() == 3 or other_op_data.data.dim() == 3
+
+        if 'bias' in self.op_data:
+            bias_op_data = self.op_data['bias']
+            assert bias_op_data.data.dim() < 3 and len(bias_op_data.logical_shape) == 2
+
+        if self.op_data['output'].data.dim() == 2:
+            # addbmm will shrink the first batch dim
+            self.squeeze_batch_dim = True

     def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
-        return self.op_data['input'].data.shape[-1] * reduce(operator.mul, self.op_data['output'].data.shape)
+        fwd_compute_cost = self.op_data['input'].data.shape[-1] * reduce(operator.mul,
+                                                                         self.op_data['output'].data.shape)
+        bwd_compute_cost = fwd_compute_cost * 2
+        compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
+                                      bwd=bwd_compute_cost,
+                                      total=fwd_compute_cost + bwd_compute_cost)
+        strategy.compute_cost = compute_cost

+    @ignore_sharding_exception
     def split_one_batch_dim(self, mesh_dim):
         name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}'

         # get sharding_spec
         dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "bias": {}, "output": {0: [mesh_dim]}}
+        if self.squeeze_batch_dim:
+            self._pop_batch_dim_sharding_for_output(dim_partition_dict)
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
+        print(sharding_spec_mapping)

         # get communication actions
         communication_action_mapping = {}
         if self.has_bias:
@@ -543,6 +580,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
             sharding_spec_mapping=sharding_spec_mapping,
             communication_action_mapping=communication_action_mapping)

+    @ignore_sharding_exception
     def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1):
         name = f'Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}'
         dim_partition_dict = {
@@ -557,6 +595,8 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
                 0: [mesh_dim_0, mesh_dim_1]
             }
         }
+        if self.squeeze_batch_dim:
+            self._pop_batch_dim_sharding_for_output(dim_partition_dict)
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

         # get communication actions
@@ -572,22 +612,27 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
             sharding_spec_mapping=sharding_spec_mapping,
             communication_action_mapping=communication_action_mapping)

+    @ignore_sharding_exception
     def split_batch_dim_lhs_space(self, mesh_dim_0, mesh_dim_1):
         name = f'Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}'
         dim_partition_dict = {
             "input": {
                 0: [mesh_dim_0],
-                -2: [mesh_dim_1]
+                1: [mesh_dim_1]
             },
             "other": {
                 0: [mesh_dim_0]
             },
-            "bias": {},
+            "bias": {
+                0: [mesh_dim_1]
+            },
             "output": {
                 0: [mesh_dim_0],
-                -2: [mesh_dim_1]
+                1: [mesh_dim_1]
             }
         }
+        if self.squeeze_batch_dim:
+            self._pop_batch_dim_sharding_for_output(dim_partition_dict)
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

         # get communication actions
@@ -609,6 +654,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
             sharding_spec_mapping=sharding_spec_mapping,
             communication_action_mapping=communication_action_mapping)

+    @ignore_sharding_exception
     def split_batch_dim_rhs_space(self, mesh_dim_0, mesh_dim_1):
         name = f'Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}'
         dim_partition_dict = {
@@ -617,16 +663,18 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
             },
             "other": {
                 0: [mesh_dim_0],
-                -1: [mesh_dim_1]
+                2: [mesh_dim_1]
             },
             "bias": {
-                -1: [mesh_dim_1]
+                1: [mesh_dim_1]
             },
             "output": {
                 0: [mesh_dim_0],
-                -1: [mesh_dim_1]
+                2: [mesh_dim_1]
             }
         }
+        if self.squeeze_batch_dim:
+            self._pop_batch_dim_sharding_for_output(dim_partition_dict)
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

         # get communication actions
@@ -648,6 +696,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
             sharding_spec_mapping=sharding_spec_mapping,
             communication_action_mapping=communication_action_mapping)

+    @ignore_sharding_exception
     def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1):
         name = f'Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}'
         dim_partition_dict = {
@@ -664,6 +713,8 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
                 0: [mesh_dim_0],
             }
         }
+        if self.squeeze_batch_dim:
+            self._pop_batch_dim_sharding_for_output(dim_partition_dict)
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

         # get communication actions
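
When squeeze_batch_dim is set, _pop_batch_dim_sharding_for_output drops the batch entry from the output partition and shifts the remaining dim indices down by one, matching the 3D-to-2D output of addbmm. A standalone sketch of that transformation on a plain dict (hypothetical mesh names, not part of the commit):

# a 3D bmm output sharded on dims 0 and 2; after the batch reduction the
# output is 2D, so dim 0 is dropped and dim 2 becomes dim 1
dim_partition = {"output": {0: ["mesh_0"], 2: ["mesh_1"]}}
dim_partition["output"].pop(0, None)
dim_partition["output"] = {dim - 1: mesh for dim, mesh in dim_partition["output"].items()}
print(dim_partition)  # {'output': {1: ['mesh_1']}}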

View File

@@ -4,7 +4,6 @@ from functools import reduce
 from typing import Any, Dict, List, Union

 import torch
-from torch.fx import Node

 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
@@ -15,11 +14,9 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     ShardingStrategy,
     TrainCycleItem,
 )

 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
-from torch.fx import Node


 class StrategyGenerator(ABC):

View File

@@ -1,6 +1,8 @@
-import torch
 from enum import Enum, auto
 from typing import List

+import torch
+
 from colossalai.tensor.sharding_spec import ShardingSpec

 __all__ = ['BroadcastType', 'is_broadcastable', 'get_broadcast_shape', 'recover_sharding_spec_for_broadcast_shape']
@@ -56,6 +58,9 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec,
     logical_num_dims = len(logical_shape)
     physical_num_dims = len(physical_shape)
+    assert logical_num_dims >= physical_num_dims, \
+        'The logical shape has fewer dimensions than the physical shape; this tensor cannot be the result of broadcasting!'

     # track the dim and its broadcasting type
     logical_dim_broadcast_info = {}

View File

@@ -1,4 +1,5 @@
 import torch
 from ..registry import meta_patched_function
@@ -56,6 +57,16 @@ def torch_bmm(input, mat2, *, out=None):
     return torch.empty(batch_size, n, p, device="meta")


+@meta_patched_function.register(torch.addbmm)
+@meta_patched_function.register(torch.Tensor.addbmm)
+def torch_addbmm(input, mat1, mat2, *, beta=1, alpha=1, out=None):
+    if out is not None:
+        raise ValueError("Don't support in-place addbmm for MetaTensor analysis")
+    batch_size, n, m = mat1.shape
+    _, _, p = mat2.shape
+    return torch.empty(n, p, device="meta")
+
+
 @meta_patched_function.register(torch.var_mean)
 def torch_var_mean(input, dim, unbiased=True, keepdim=False, *, out=None):
     assert out is None, 'saving to out is not supported yet'
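
The meta patch only needs to reproduce the operator's output shape: whatever the bias shape, addbmm returns a 2D [n, p] tensor. A quick consistency check against eager PyTorch (illustrative only, not part of the commit):

import torch

mat1, mat2 = torch.rand(4, 8, 16), torch.rand(4, 16, 8)
eager_out = torch.addbmm(torch.rand(8, 8), mat1, mat2)
meta_out = torch.empty(mat1.shape[1], mat2.shape[2], device="meta")  # what the patch returns
assert eager_out.shape == meta_out.shape == torch.Size([8, 8])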

View File

@@ -0,0 +1,189 @@
import torch
import torch.nn as nn

from colossalai.auto_parallel.tensor_shard.node_handler import AddBMMFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing import parameterize


class AddBMMTensorMethodModule(nn.Module):

    def forward(self, bias, x1, x2):
        return bias.addbmm(x1, x2)


class AddBMMTorchFunctionModule(nn.Module):

    def forward(self, bias, x1, x2):
        return torch.addbmm(bias, x1, x2)


@parameterize('module', [AddBMMTorchFunctionModule, AddBMMTensorMethodModule])
@parameterize('bias_shape', [[8], [1, 8], [8, 8]])
def test_2d_device_mesh(module, bias_shape):
    model = module()
    tracer = ColoTracer()
    graph = tracer.trace(model,
                         meta_args={
                             'bias': torch.rand(*bias_shape).to('meta'),
                             "x1": torch.rand(4, 8, 16).to('meta'),
                             'x2': torch.rand(4, 16, 8).to('meta')
                         })
    print(graph)
    gm = ColoGraphModule(model, graph)
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    linear_mod_node = list(graph.nodes)[3]
    strategies_vector = StrategiesVector(linear_mod_node)

    # build handler
    handler = AddBMMFunctionHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)

    # check operation data mapping
    mapping = handler.get_operation_data_mapping()

    for name, op_data in mapping.items():
        op_data: OperationData
        # make sure they have valid values
        assert op_data.logical_shape is not None
        assert op_data.data is not None

    assert mapping['input'].name == "x1"
    assert mapping['input'].data.is_meta
    assert mapping['input'].data.shape == torch.Size([4, 8, 16])
    assert mapping['input'].type == OperationDataType.ARG
    assert mapping['input'].logical_shape == torch.Size([4, 8, 16])

    assert mapping['other'].name == "x2"
    assert mapping['other'].data.is_meta
    assert mapping['other'].data.shape == torch.Size([4, 16, 8])
    assert mapping['other'].type == OperationDataType.ARG
    assert mapping['other'].logical_shape == torch.Size([4, 16, 8])

    assert mapping['bias'].name == "bias"
    assert mapping['bias'].data.is_meta
    assert mapping['bias'].data.shape == torch.Size(bias_shape)
    assert mapping['bias'].type == OperationDataType.ARG
    assert mapping['bias'].logical_shape == torch.Size([8, 8])

    assert mapping['output'].name == "addbmm"
    assert mapping['output'].data.is_meta
    assert mapping['output'].data.shape == torch.Size([8, 8])
    assert mapping['output'].type == OperationDataType.OUTPUT

    strategies_vector = handler.register_strategy(compute_resharding_cost=False)
    strategy_name_list = [val.name for val in strategies_vector]

    # one batch dim
    assert 'Sb0 = Sb0 x Sb0' not in strategy_name_list

    # two batch dim
    assert 'Sb01 = Sb01 x Sb01' in strategy_name_list

    # SbSi = SbSi x Sb
    assert 'Sb0Si1 = Sb0Si1 x Sb0' in strategy_name_list
    assert 'Sb1Si0 = Sb1Si0 x Sb1' in strategy_name_list

    # SbSj = SbR x SbSj
    assert 'Sb0Sj1 = Sb0R x Sb0Sj1' in strategy_name_list
    assert 'Sb1Sj0 = Sb1R x Sb1Sj0' in strategy_name_list

    # SbR = SbSk x SbSk
    assert 'Sb0R = Sb0Sk1 x Sb0Sk1' in strategy_name_list
    assert 'Sb1R = Sb1Sk0 x Sb1Sk0' in strategy_name_list

    for strategy in strategies_vector:
        input_sharding_spec = strategy.get_sharding_spec_by_name('x1')
        other_sharding_spec = strategy.get_sharding_spec_by_name('x2')
        bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
        output_sharding_spec = strategy.get_sharding_spec_by_name('addbmm')

        # make sure the sharding matches across different operation data
        assert input_sharding_spec.sharding_sequence[1] == output_sharding_spec.sharding_sequence[0]
        assert other_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
        assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
        assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]


@parameterize('module', [AddBMMTorchFunctionModule, AddBMMTensorMethodModule])
@parameterize('bias_shape', [[8], [1, 8], [8, 8]])
def test_1d_device_mesh(module, bias_shape):
    model = module()
    tracer = ColoTracer()
    graph = tracer.trace(model,
                         meta_args={
                             'bias': torch.rand(*bias_shape).to('meta'),
                             "x1": torch.rand(4, 8, 16).to('meta'),
                             'x2': torch.rand(4, 16, 8).to('meta')
                         })
    print(graph)
    gm = ColoGraphModule(model, graph)
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (1, 4)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    linear_mod_node = list(graph.nodes)[3]
    strategies_vector = StrategiesVector(linear_mod_node)

    # build handler
    handler = AddBMMFunctionHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)

    # check operation data mapping
    mapping = handler.get_operation_data_mapping()

    for name, op_data in mapping.items():
        op_data: OperationData
        # make sure they have valid values
        assert op_data.logical_shape is not None
        assert op_data.data is not None

    assert mapping['input'].name == "x1"
    assert mapping['input'].data.is_meta
    assert mapping['input'].data.shape == torch.Size([4, 8, 16])
    assert mapping['input'].type == OperationDataType.ARG
    assert mapping['input'].logical_shape == torch.Size([4, 8, 16])

    assert mapping['other'].name == "x2"
    assert mapping['other'].data.is_meta
    assert mapping['other'].data.shape == torch.Size([4, 16, 8])
    assert mapping['other'].type == OperationDataType.ARG
    assert mapping['other'].logical_shape == torch.Size([4, 16, 8])

    assert mapping['bias'].name == "bias"
    assert mapping['bias'].data.is_meta
    assert mapping['bias'].data.shape == torch.Size(bias_shape)
    assert mapping['bias'].type == OperationDataType.ARG
    assert mapping['bias'].logical_shape == torch.Size([8, 8])

    assert mapping['output'].name == "addbmm"
    assert mapping['output'].data.is_meta
    assert mapping['output'].data.shape == torch.Size([8, 8])
    assert mapping['output'].type == OperationDataType.OUTPUT

    strategies_vector = handler.register_strategy(compute_resharding_cost=False)
    strategy_name_list = [val.name for val in strategies_vector]
    assert len(strategy_name_list) == 1

    # one batch dim
    assert 'Sb0 = Sb0 x Sb0' in strategy_name_list

    for strategy in strategies_vector:
        input_sharding_spec = strategy.get_sharding_spec_by_name('x1')
        other_sharding_spec = strategy.get_sharding_spec_by_name('x2')
        bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
        output_sharding_spec = strategy.get_sharding_spec_by_name('addbmm')

        # make sure the sharding matches across different operation data
        assert input_sharding_spec.sharding_sequence[1] == output_sharding_spec.sharding_sequence[0]
        assert other_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
        assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
        assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]


if __name__ == '__main__':
    test_1d_device_mesh()
    # test_2d_device_mesh()

View File

@@ -6,6 +6,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler import BMMFunctionHandler
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.testing import parameterize


 class BMMTensorMethodModule(nn.Module):
@@ -20,7 +21,7 @@ class BMMTorchFunctionModule(nn.Module):
         return torch.bmm(x1, x2)


-@pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
+@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
 def test_2d_device_mesh(module):
     model = module()
@@ -95,12 +96,13 @@ def test_2d_device_mesh(module):
         output_sharding_spec = strategy.get_sharding_spec_by_name('bmm')

         # make sure the sharding matches across different operation data
+        print(input_sharding_spec.sharding_sequence, output_sharding_spec.sharding_sequence)
         assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
         assert other_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
         assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]


-@pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
+@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
 def test_1d_device_mesh(module):
     model = module()
     tracer = ColoTracer()
@@ -165,7 +167,5 @@ def test_1d_device_mesh(module):

 if __name__ == '__main__':
-    test_1d_device_mesh(BMMTensorMethodModule)
-    test_1d_device_mesh(BMMTorchFunctionModule)
-    test_2d_device_mesh(BMMTensorMethodModule)
-    test_2d_device_mesh(BMMTorchFunctionModule)
+    test_1d_device_mesh()
+    test_2d_device_mesh()