[autoparallel] update binary elementwise handler (#2451)

* [autoparallel] update binary elementwise handler

* polish
pull/2396/head
YuliangLiu0306 2023-01-12 09:35:10 +08:00 committed by GitHub
parent c9ec5190a0
commit 8221fd7485
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 74 additions and 23 deletions

View File

@ -32,20 +32,32 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
return OperationDataType.ARG return OperationDataType.ARG
def _get_arg_value(idx): def _get_arg_value(idx):
non_tensor = False
if isinstance(self.node.args[idx], Node): if isinstance(self.node.args[idx], Node):
meta_data = self.node.args[idx]._meta_data meta_data = self.node.args[idx]._meta_data
# The meta_data of node type argument could also possibly be a non-tensor object.
if not isinstance(meta_data, torch.Tensor):
assert isinstance(meta_data, (int, float))
meta_data = torch.Tensor([meta_data]).to('meta')
non_tensor = True
else: else:
# this is in fact a real data like int 1 # this is in fact a real data like int 1
# but we can deem it as meta data # but we can deem it as meta data
# as it won't affect the strategy generation # as it won't affect the strategy generation
assert isinstance(self.node.args[idx], (int, float)) assert isinstance(self.node.args[idx], (int, float))
meta_data = torch.Tensor([self.node.args[idx]]).to('meta') meta_data = torch.Tensor([self.node.args[idx]]).to('meta')
return meta_data non_tensor = True
input_meta_data = _get_arg_value(0) return meta_data, non_tensor
other_meta_data = _get_arg_value(1)
input_meta_data, non_tensor_input = _get_arg_value(0)
other_meta_data, non_tensor_other = _get_arg_value(1)
output_meta_data = self.node._meta_data output_meta_data = self.node._meta_data
# we need record op_data with non-tensor data in this list,
# and filter the non-tensor op_data in post_process.
self.non_tensor_list = []
# assert False
input_op_data = OperationData(name=str(self.node.args[0]), input_op_data = OperationData(name=str(self.node.args[0]),
type=_get_op_data_type(input_meta_data), type=_get_op_data_type(input_meta_data),
data=input_meta_data, data=input_meta_data,
@ -58,6 +70,10 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
type=OperationDataType.OUTPUT, type=OperationDataType.OUTPUT,
data=output_meta_data, data=output_meta_data,
logical_shape=bcast_shape) logical_shape=bcast_shape)
if non_tensor_input:
self.non_tensor_list.append(input_op_data)
if non_tensor_other:
self.non_tensor_list.append(other_op_data)
mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data} mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data}
return mapping return mapping
@ -73,9 +89,10 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
op_data_mapping = self.get_operation_data_mapping() op_data_mapping = self.get_operation_data_mapping()
for op_name, op_data in op_data_mapping.items(): for op_name, op_data in op_data_mapping.items():
if not isinstance(op_data.data, torch.Tensor): if op_data in self.non_tensor_list:
# remove the sharding spec if the op_data is not a tensor, e.g. torch.pow(tensor, 2) # remove the sharding spec if the op_data is not a tensor, e.g. torch.pow(tensor, 2)
strategy.sharding_specs.pop(op_data) strategy.sharding_specs.pop(op_data)
else: else:
# convert the logical sharding spec to physical sharding spec if broadcast # convert the logical sharding spec to physical sharding spec if broadcast
# e.g. torch.rand(4, 4) + torch.rand(4) # e.g. torch.rand(4, 4) + torch.rand(4)

View File

@ -122,25 +122,41 @@ def check_binary_elementwise_handler_with_tensor(rank, op, other_dim, world_size
assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1] assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1]
def check_binary_elementwise_handler_with_int(rank, op, other_dim, world_size, port): class BEOpModelWithNodeConst(nn.Module):
def __init__(self, op):
super().__init__()
self.op = op
def forward(self, x1):
const = x1.dim()
out = self.op(x1, const)
return out
class BEOpModelWithIntConst(nn.Module):
def __init__(self, op, const):
super().__init__()
self.op = op
self.const = const
def forward(self, x1):
out = self.op(x1, self.const)
return out
def check_binary_elementwise_handler_with_int(rank, op, other_dim, model_cls, world_size, port):
disable_existing_loggers() disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
class BinaryElementwiseOpModel(nn.Module):
def __init__(self, op, const):
super().__init__()
self.op = op
self.const = const
def forward(self, x1):
out = self.op(x1, self.const)
return out
physical_mesh_id = torch.arange(0, 4) physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2) mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
model = BinaryElementwiseOpModel(op, other_dim).cuda() if model_cls == BEOpModelWithNodeConst:
model = model_cls(op).cuda()
else:
model = model_cls(op, other_dim).cuda()
x1 = torch.rand(4, 4).cuda() x1 = torch.rand(4, 4).cuda()
# the index of binary-elementwise node in computation graph # the index of binary-elementwise node in computation graph
node_index = 1 node_index = 1
@ -159,9 +175,14 @@ def check_binary_elementwise_handler_with_int(rank, op, other_dim, world_size, p
tracer = ColoTracer() tracer = ColoTracer()
meta_args = {'x1': torch.rand(4, 4).to('meta')} meta_args = {'x1': torch.rand(4, 4).to('meta')}
graph = tracer.trace(model, meta_args=meta_args) graph = tracer.trace(model, meta_args=meta_args)
print(graph)
# assert False
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
op_node = list(graph.nodes)[1] if model_cls == BEOpModelWithNodeConst:
op_node = list(graph.nodes)[2]
else:
op_node = list(graph.nodes)[1]
strategies_vector = StrategiesVector(op_node) strategies_vector = StrategiesVector(op_node)
# build handler # build handler
@ -212,7 +233,7 @@ def check_binary_elementwise_handler_with_int(rank, op, other_dim, world_size, p
@parameterize('other_dim', [1, 2]) @parameterize('other_dim', [1, 2])
@pytest.mark.dist @pytest.mark.dist
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_binary_elementwise_handler(op, other_dim): def test_binary_elementwise_handler_with_tensor(op, other_dim):
world_size = 4 world_size = 4
run_func_tensor = partial(check_binary_elementwise_handler_with_tensor, run_func_tensor = partial(check_binary_elementwise_handler_with_tensor,
op=op, op=op,
@ -220,8 +241,19 @@ def test_binary_elementwise_handler(op, other_dim):
world_size=world_size, world_size=world_size,
port=free_port()) port=free_port())
mp.spawn(run_func_tensor, nprocs=world_size) mp.spawn(run_func_tensor, nprocs=world_size)
@run_on_environment_flag(name='AUTO_PARALLEL')
@parameterize('op', [torch.add])
@parameterize('other_dim', [1, 2])
@parameterize('model_cls', [BEOpModelWithNodeConst, BEOpModelWithIntConst])
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_binary_elementwise_handler_with_int(op, model_cls, other_dim):
world_size = 4
run_func_int = partial(check_binary_elementwise_handler_with_int, run_func_int = partial(check_binary_elementwise_handler_with_int,
op=op, op=op,
model_cls=model_cls,
other_dim=other_dim, other_dim=other_dim,
world_size=world_size, world_size=world_size,
port=free_port()) port=free_port())
@ -229,4 +261,5 @@ def test_binary_elementwise_handler(op, other_dim):
if __name__ == '__main__': if __name__ == '__main__':
test_binary_elementwise_handler() test_binary_elementwise_handler_with_tensor()
test_binary_elementwise_handler_with_int()

View File

@ -90,7 +90,8 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
solver_options = SolverOptions() solver_options = SolverOptions()
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost() strategies_constructor.build_strategies_and_cost()
target_node = list(graph.nodes)[node_index] target_node = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies
][node_index]
if node_type == 'normal': if node_type == 'normal':
solution_len = len(strategies_constructor.leaf_strategies) solution_len = len(strategies_constructor.leaf_strategies)
solution = [0] * solution_len solution = [0] * solution_len
@ -112,7 +113,7 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
ret = solver.call_solver_serialized_args() ret = solver.call_solver_serialized_args()
solution = list(ret[0]) solution = list(ret[0])
gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass( gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
gm, solution, device_mesh) gm, solution, device_mesh, strategies_constructor)
gm = runtime_apply_pass(gm) gm = runtime_apply_pass(gm)
gm.recompile() gm.recompile()