mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] support fucntion in operator handler (#1529)
parent
44c866a3e3
commit
1a3599410d
|
@ -9,7 +9,7 @@ __all__ = ['ConvHandler']
|
|||
|
||||
class ConvHandler(OperatorHandler):
|
||||
"""
|
||||
A OperatorHandler which deals with the sharding strategies of linear matrix multiplication.
|
||||
An OperatorHandler which deals with the sharding strategies of linear matrix multiplication.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
@ -67,7 +67,7 @@ class ConvHandler(OperatorHandler):
|
|||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
|
||||
|
@ -106,7 +106,7 @@ class ConvHandler(OperatorHandler):
|
|||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
|
||||
|
@ -145,7 +145,7 @@ class ConvHandler(OperatorHandler):
|
|||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
|
||||
|
@ -184,7 +184,7 @@ class ConvHandler(OperatorHandler):
|
|||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
bs = self.input_data.shape[0]
|
||||
|
@ -223,7 +223,7 @@ class ConvHandler(OperatorHandler):
|
|||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
bs = self.input_data.shape[0]
|
||||
|
@ -261,7 +261,7 @@ class ConvHandler(OperatorHandler):
|
|||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
bs = self.input_data.shape[0]
|
||||
|
@ -301,7 +301,7 @@ class ConvHandler(OperatorHandler):
|
|||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
bs = self.input_data.shape[0]
|
||||
|
@ -340,7 +340,7 @@ class ConvHandler(OperatorHandler):
|
|||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
bs = self.input_data.shape[0] // (self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1])
|
||||
|
@ -380,7 +380,7 @@ class ConvHandler(OperatorHandler):
|
|||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
bs = self.input_data.shape[0]
|
||||
|
|
|
@ -15,7 +15,7 @@ __all__ = ['OperatorHandler']
|
|||
|
||||
class OperatorHandler(ABC):
|
||||
'''
|
||||
The OperatorHandler is an abstract class used to generate every possible strategies for a operator node.
|
||||
The OperatorHandler is an abstract class used to generate every possible strategies for an operator node.
|
||||
|
||||
Argument:
|
||||
input_node(Node): the input node in node argument list.
|
||||
|
@ -43,6 +43,10 @@ class OperatorHandler(ABC):
|
|||
named_parameters = list(module.named_parameters(recurse=False))
|
||||
# convert named parameters from list to dict
|
||||
named_parameters = {k: v for k, v in named_parameters}
|
||||
elif self.node.op == 'call_function':
|
||||
module = None
|
||||
parameters = list(self.node.args)[1]
|
||||
named_parameters = {'weight': parameters._meta_data}
|
||||
else:
|
||||
module = None
|
||||
named_parameters = None
|
||||
|
|
|
@ -27,7 +27,13 @@ class StrategiesConstructor:
|
|||
Generate the sharding spec of the tensor based on the given dim_partition_dict
|
||||
where the key is the tensor dimension and the value is the mesh dimension for sharding.
|
||||
"""
|
||||
meta_tensor = node._meta_data
|
||||
if hasattr(node, '_meta_data'):
|
||||
meta_tensor = node._meta_data
|
||||
elif isinstance(node, torch.Tensor):
|
||||
meta_tensor = node
|
||||
else:
|
||||
raise RuntimeError(f'We cannot generate sharding spec for {type(node)} type.')
|
||||
|
||||
sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
|
||||
entire_shape=meta_tensor.shape,
|
||||
dim_partition_dict=dim_partition_dict)
|
||||
|
|
Loading…
Reference in New Issue