[autoparallel] support function in operator handler (#1529)

YuliangLiu0306 2022-09-07 11:18:41 +08:00 committed by GitHub
parent 44c866a3e3
commit 1a3599410d
3 changed files with 22 additions and 12 deletions
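
The change is two-fold: every conv sharding strategy now passes the weight's
sharding spec alongside the input's when generating resharding costs, and the
OperatorHandler/StrategiesConstructor pair learns to handle call_function
nodes (e.g. a direct F.conv2d call), whose weight arrives as a node argument
rather than a module parameter.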


@@ -9,7 +9,7 @@ __all__ = ['ConvHandler']
 class ConvHandler(OperatorHandler):
     """
-    A OperatorHandler which deals with the sharding strategies of linear matrix multiplication.
+    An OperatorHandler which deals with the sharding strategies of linear matrix multiplication.
     """

     def __init__(self, *args, **kwargs):
@@ -67,7 +67,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

         # compute the computation cost of this strategy
         bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
@@ -106,7 +106,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

         # compute the computation cost of this strategy
         bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
@@ -145,7 +145,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

         # compute the computation cost of this strategy
         bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
@@ -184,7 +184,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

         # compute the computation cost of this strategy
         bs = self.input_data.shape[0]
@@ -223,7 +223,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

         # compute the computation cost of this strategy
         bs = self.input_data.shape[0]
@@ -261,7 +261,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

         # compute the computation cost of this strategy
         bs = self.input_data.shape[0]
@@ -301,7 +301,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

         # compute the computation cost of this strategy
         bs = self.input_data.shape[0]
@@ -340,7 +340,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

         # compute the computation cost of this strategy
         bs = self.input_data.shape[0] // (self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1])
@@ -380,7 +380,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

         # compute the computation cost of this strategy
         bs = self.input_data.shape[0]
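
The recurring one-line change above fixes the same omission in every strategy
method: resharding costs were derived from the input's sharding spec alone, so
converting the weight into the layout a strategy expects was implicitly priced
at zero. A minimal sketch of the idea, with hypothetical names
(estimate_resharding_cost, convert_cost_fn) standing in for the real cost
model, not ColossalAI's actual implementation:

def estimate_resharding_cost(operand_specs, target_specs, convert_cost_fn):
    """Sum the cost of converting each operand (input AND weight) from its
    current sharding spec to the spec this strategy expects; dropping an
    operand from the list silently prices its conversion at zero."""
    total = 0.0
    for current_spec, target_spec in zip(operand_specs, target_specs):
        total += convert_cost_fn(current_spec, target_spec)
    return total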


@@ -15,7 +15,7 @@ __all__ = ['OperatorHandler']
 class OperatorHandler(ABC):
     '''
-    The OperatorHandler is an abstract class used to generate every possible strategies for a operator node.
+    The OperatorHandler is an abstract class used to generate every possible strategies for an operator node.

     Argument:
         input_node(Node): the input node in node argument list.
@@ -43,6 +43,10 @@ class OperatorHandler(ABC):
             named_parameters = list(module.named_parameters(recurse=False))
             # convert named parameters from list to dict
             named_parameters = {k: v for k, v in named_parameters}
+        elif self.node.op == 'call_function':
+            module = None
+            parameters = list(self.node.args)[1]
+            named_parameters = {'weight': parameters._meta_data}
         else:
             module = None
             named_parameters = None
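
For context on the new call_function branch: when a model invokes F.conv2d
directly, torch.fx records a call_function node whose second positional
argument is the weight, which is why the handler reads list(self.node.args)[1]
and takes its _meta_data (the meta tensor attached during shape propagation).
A small trace, independent of ColossalAI, showing where that argument sits:

import torch
import torch.nn.functional as F
from torch.fx import symbolic_trace

class ConvFn(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(8, 4, 3, 3))

    def forward(self, x):
        # traced as a call_function node with args (x, self.weight)
        return F.conv2d(x, self.weight)

gm = symbolic_trace(ConvFn())
for node in gm.graph.nodes:
    if node.op == 'call_function':
        # node.args[1] is the fx Node standing in for the weight
        print(node.target, node.args)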


@@ -27,7 +27,13 @@ class StrategiesConstructor:
         Generate the sharding spec of the tensor based on the given dim_partition_dict
         where the key is the tensor dimension and the value is the mesh dimension for sharding.
         """
-        meta_tensor = node._meta_data
+        if hasattr(node, '_meta_data'):
+            meta_tensor = node._meta_data
+        elif isinstance(node, torch.Tensor):
+            meta_tensor = node
+        else:
+            raise RuntimeError(f'We cannot generate sharding spec for {type(node)} type.')
         sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
                                      entire_shape=meta_tensor.shape,
                                      dim_partition_dict=dim_partition_dict)
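
This last change is what lets the call_function path work end to end: the
weight recorded by OperatorHandler is the node's _meta_data, i.e. a plain
torch.Tensor, not an fx Node carrying _meta_data itself. A standalone sketch
mirroring the new dispatch (resolve_meta_tensor is a hypothetical
free-function form of it):

import torch

def resolve_meta_tensor(node):
    # accept an fx Node carrying a _meta_data tensor, or a raw tensor
    if hasattr(node, '_meta_data'):
        return node._meta_data
    elif isinstance(node, torch.Tensor):
        return node
    raise RuntimeError(f'We cannot generate sharding spec for {type(node)} type.')

# a raw meta tensor (as passed for function weights) now works directly
weight = torch.empty(8, 4, 3, 3, device='meta')
print(resolve_meta_tensor(weight).shape)    # torch.Size([8, 4, 3, 3])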