[autoparallel] support fucntion in operator handler (#1529)

pull/1583/head
YuliangLiu0306 2022-09-07 11:18:41 +08:00 committed by GitHub
parent 44c866a3e3
commit 1a3599410d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 22 additions and 12 deletions

View File

@ -9,7 +9,7 @@ __all__ = ['ConvHandler']
class ConvHandler(OperatorHandler):
"""
A OperatorHandler which deals with the sharding strategies of linear matrix multiplication.
An OperatorHandler which deals with the sharding strategies of linear matrix multiplication.
"""
def __init__(self, *args, **kwargs):
@ -67,7 +67,7 @@ class ConvHandler(OperatorHandler):
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
@ -106,7 +106,7 @@ class ConvHandler(OperatorHandler):
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
@ -145,7 +145,7 @@ class ConvHandler(OperatorHandler):
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
@ -184,7 +184,7 @@ class ConvHandler(OperatorHandler):
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
bs = self.input_data.shape[0]
@ -223,7 +223,7 @@ class ConvHandler(OperatorHandler):
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
bs = self.input_data.shape[0]
@ -261,7 +261,7 @@ class ConvHandler(OperatorHandler):
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
bs = self.input_data.shape[0]
@ -301,7 +301,7 @@ class ConvHandler(OperatorHandler):
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
bs = self.input_data.shape[0]
@ -340,7 +340,7 @@ class ConvHandler(OperatorHandler):
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
bs = self.input_data.shape[0] // (self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1])
@ -380,7 +380,7 @@ class ConvHandler(OperatorHandler):
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
bs = self.input_data.shape[0]

View File

@ -15,7 +15,7 @@ __all__ = ['OperatorHandler']
class OperatorHandler(ABC):
'''
The OperatorHandler is an abstract class used to generate every possible strategies for a operator node.
The OperatorHandler is an abstract class used to generate every possible strategies for an operator node.
Argument:
input_node(Node): the input node in node argument list.
@ -43,6 +43,10 @@ class OperatorHandler(ABC):
named_parameters = list(module.named_parameters(recurse=False))
# convert named parameters from list to dict
named_parameters = {k: v for k, v in named_parameters}
elif self.node.op == 'call_function':
module = None
parameters = list(self.node.args)[1]
named_parameters = {'weight': parameters._meta_data}
else:
module = None
named_parameters = None

View File

@ -27,7 +27,13 @@ class StrategiesConstructor:
Generate the sharding spec of the tensor based on the given dim_partition_dict
where the key is the tensor dimension and the value is the mesh dimension for sharding.
"""
meta_tensor = node._meta_data
if hasattr(node, '_meta_data'):
meta_tensor = node._meta_data
elif isinstance(node, torch.Tensor):
meta_tensor = node
else:
raise RuntimeError(f'We cannot generate sharding spec for {type(node)} type.')
sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=meta_tensor.shape,
dim_partition_dict=dim_partition_dict)