[autoparallel] update binary elementwise handler (#2451)

* [autoparallel] update binary elementwise handler

* polish
YuliangLiu0306 2023-01-12 09:35:10 +08:00 committed by GitHub
parent c9ec5190a0
commit 8221fd7485
3 changed files with 74 additions and 23 deletions
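
This change teaches BinaryElementwiseHandler to cope with binary elementwise ops whose operand is not a tensor. Two cases are distinguished: the operand is a traced Node whose meta data is a scalar (e.g. the result of x.dim()), or a plain Python scalar baked directly into node.args. A minimal sketch of the two cases, using vanilla torch.fx rather than ColossalAI's ColoTracer; the class names below are illustrative only and not part of the commit:

    import torch
    import torch.nn as nn
    from torch.fx import symbolic_trace

    class AddNodeConst(nn.Module):
        def forward(self, x):
            # x.dim() is traced as its own node; its meta data is an int, not a tensor
            return torch.add(x, x.dim())

    class AddIntConst(nn.Module):
        def forward(self, x):
            # the literal 2 never becomes a node; it sits in node.args as a Python int
            return torch.add(x, 2)

    print(symbolic_trace(AddNodeConst()).graph)    # placeholder -> dim -> add -> output
    print(symbolic_trace(AddIntConst()).graph)     # placeholder -> add -> output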

@@ -32,20 +32,32 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
                 return OperationDataType.ARG

         def _get_arg_value(idx):
+            non_tensor = False
             if isinstance(self.node.args[idx], Node):
                 meta_data = self.node.args[idx]._meta_data
+                # the meta data of a node-type argument may itself be a non-tensor object
+                if not isinstance(meta_data, torch.Tensor):
+                    assert isinstance(meta_data, (int, float))
+                    meta_data = torch.Tensor([meta_data]).to('meta')
+                    non_tensor = True
             else:
                 # this is in fact real data, like the int 1,
                 # but we can treat it as meta data
                 # since it won't affect strategy generation
                 assert isinstance(self.node.args[idx], (int, float))
                 meta_data = torch.Tensor([self.node.args[idx]]).to('meta')
-            return meta_data
+                non_tensor = True
+
+            return meta_data, non_tensor

-        input_meta_data = _get_arg_value(0)
-        other_meta_data = _get_arg_value(1)
+        input_meta_data, non_tensor_input = _get_arg_value(0)
+        other_meta_data, non_tensor_other = _get_arg_value(1)
         output_meta_data = self.node._meta_data
+        # record op_data with non-tensor data in this list
+        # and filter the non-tensor op_data out in post_process
+        self.non_tensor_list = []
         input_op_data = OperationData(name=str(self.node.args[0]),
                                       type=_get_op_data_type(input_meta_data),
                                       data=input_meta_data,
@@ -58,6 +70,10 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
                                        type=OperationDataType.OUTPUT,
                                        data=output_meta_data,
                                        logical_shape=bcast_shape)
+        if non_tensor_input:
+            self.non_tensor_list.append(input_op_data)
+        if non_tensor_other:
+            self.non_tensor_list.append(other_op_data)

         mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data}
         return mapping
@@ -73,9 +89,10 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
         op_data_mapping = self.get_operation_data_mapping()

         for op_name, op_data in op_data_mapping.items():
-            if not isinstance(op_data.data, torch.Tensor):
+            if op_data in self.non_tensor_list:
                 # remove the sharding spec if the op_data is not a tensor, e.g. torch.pow(tensor, 2)
                 strategy.sharding_specs.pop(op_data)
             else:
                 # convert the logical sharding spec to a physical sharding spec in the broadcast case,
                 # e.g. torch.rand(4, 4) + torch.rand(4)
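
The switch from the old isinstance check to membership in self.non_tensor_list matters because, after the new _get_arg_value, every operand carries a tensor-typed meta stand-in, so a type check can no longer detect scalars. A small illustration, not taken from the commit:

    import torch

    # _get_arg_value wraps a scalar operand such as the 2 in torch.pow(tensor, 2)
    # into a meta tensor before post_process ever sees it:
    meta_data = torch.Tensor([2]).to('meta')

    # the old condition is therefore always False for wrapped scalars...
    print(not isinstance(meta_data, torch.Tensor))    # False
    # ...which is why non_tensor_list now drives the sharding-spec removal.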
@@ -122,25 +122,41 @@ def check_binary_elementwise_handler_with_tensor(rank, op, other_dim, world_size, port):
     assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1]


-def check_binary_elementwise_handler_with_int(rank, op, other_dim, world_size, port):
+class BEOpModelWithNodeConst(nn.Module):
+
+    def __init__(self, op):
+        super().__init__()
+        self.op = op
+
+    def forward(self, x1):
+        const = x1.dim()
+        out = self.op(x1, const)
+        return out
+
+
+class BEOpModelWithIntConst(nn.Module):
+
+    def __init__(self, op, const):
+        super().__init__()
+        self.op = op
+        self.const = const
+
+    def forward(self, x1):
+        out = self.op(x1, self.const)
+        return out
+
+
+def check_binary_elementwise_handler_with_int(rank, op, other_dim, model_cls, world_size, port):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

-    class BinaryElementwiseOpModel(nn.Module):
-
-        def __init__(self, op, const):
-            super().__init__()
-            self.op = op
-            self.const = const
-
-        def forward(self, x1):
-            out = self.op(x1, self.const)
-            return out
-
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-    model = BinaryElementwiseOpModel(op, other_dim).cuda()
+    if model_cls == BEOpModelWithNodeConst:
+        model = model_cls(op).cuda()
+    else:
+        model = model_cls(op, other_dim).cuda()
     x1 = torch.rand(4, 4).cuda()
     # the index of the binary-elementwise node in the computation graph
     node_index = 1
@@ -159,9 +175,14 @@ def check_binary_elementwise_handler_with_int(rank, op, other_dim, world_size, port):
     tracer = ColoTracer()
     meta_args = {'x1': torch.rand(4, 4).to('meta')}
     graph = tracer.trace(model, meta_args=meta_args)
     gm = ColoGraphModule(model, graph)
-    op_node = list(graph.nodes)[1]
+    if model_cls == BEOpModelWithNodeConst:
+        # tracing x1.dim() adds an extra node before the binary elementwise op
+        op_node = list(graph.nodes)[2]
+    else:
+        op_node = list(graph.nodes)[1]
     strategies_vector = StrategiesVector(op_node)

     # build handler
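
The index branch above exists because tracing x1.dim() materializes an extra call_method node ahead of the binary op. A quick way to see the shift, sketched with vanilla torch.fx; node names may differ under ColoTracer:

    import torch
    import torch.nn as nn
    from torch.fx import symbolic_trace

    class WithNodeConst(nn.Module):
        def forward(self, x1):
            return torch.add(x1, x1.dim())

    for i, node in enumerate(symbolic_trace(WithNodeConst()).graph.nodes):
        print(i, node.op, node.name)
    # 0 placeholder x1
    # 1 call_method dim
    # 2 call_function add    <- the binary elementwise node, hence index 2
    # 3 output output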
@@ -212,7 +233,7 @@ def check_binary_elementwise_handler_with_int(rank, op, other_dim, world_size, port):
 @parameterize('other_dim', [1, 2])
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
-def test_binary_elementwise_handler(op, other_dim):
+def test_binary_elementwise_handler_with_tensor(op, other_dim):
     world_size = 4
     run_func_tensor = partial(check_binary_elementwise_handler_with_tensor,
                               op=op,
@@ -220,8 +241,19 @@ def test_binary_elementwise_handler_with_tensor(op, other_dim):
                               world_size=world_size,
                               port=free_port())
     mp.spawn(run_func_tensor, nprocs=world_size)
+
+
+@run_on_environment_flag(name='AUTO_PARALLEL')
+@parameterize('op', [torch.add])
+@parameterize('other_dim', [1, 2])
+@parameterize('model_cls', [BEOpModelWithNodeConst, BEOpModelWithIntConst])
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_binary_elementwise_handler_with_int(op, model_cls, other_dim):
+    world_size = 4
     run_func_int = partial(check_binary_elementwise_handler_with_int,
                            op=op,
+                           model_cls=model_cls,
                            other_dim=other_dim,
                            world_size=world_size,
                            port=free_port())
@@ -229,4 +261,5 @@ def test_binary_elementwise_handler_with_int(op, model_cls, other_dim):

 if __name__ == '__main__':
-    test_binary_elementwise_handler()
+    test_binary_elementwise_handler_with_tensor()
+    test_binary_elementwise_handler_with_int()
@@ -90,7 +90,8 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
     solver_options = SolverOptions()
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
-    target_node = list(graph.nodes)[node_index]
+    target_node = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies
+                   ][node_index]
     if node_type == 'normal':
         solution_len = len(strategies_constructor.leaf_strategies)
         solution = [0] * solution_len
@@ -112,7 +113,7 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
     ret = solver.call_solver_serialized_args()
     solution = list(ret[0])
     gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
-        gm, solution, device_mesh)
+        gm, solution, device_mesh, strategies_constructor)
     gm = runtime_apply_pass(gm)
     gm.recompile()
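
Indexing into leaf_strategies instead of graph.nodes keeps node_index aligned with the solver's solution vector: both now enumerate only nodes that actually carry sharding strategies, so strategy-less nodes such as the traced dim call no longer shift the target. A toy illustration of the mismatch; the node lists below are assumed for the sketch, not taken from the commit:

    # hypothetical node lists for a graph traced from BEOpModelWithNodeConst
    graph_nodes = ['x1', 'dim', 'add', 'output']    # all fx nodes
    strategy_nodes = ['x1', 'add']                  # nodes with strategies (assumed)

    node_index = 1
    print(graph_nodes[node_index])       # 'dim' - the old lookup grabs the wrong node
    print(strategy_nodes[node_index])    # 'add' - the new lookup matches the solver order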