mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] support fucntion in operator handler (#1529)
parent 44c866a3e3
commit 1a3599410d
@@ -9,7 +9,7 @@ __all__ = ['ConvHandler']

 class ConvHandler(OperatorHandler):
     """
-    A OperatorHandler which deals with the sharding strategies of linear matrix multiplication.
+    An OperatorHandler which deals with the sharding strategies of linear matrix multiplication.
     """

     def __init__(self, *args, **kwargs):
@@ -67,7 +67,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

         # compute the computation cost of this strategy
         bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
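The same one-line change repeats in every ConvHandler strategy hunk below: the resharding cost is now computed over both operands, not just the activation. A minimal sketch of the idea; generate_resharding_costs and compute_conversion_cost here are illustrative stand-ins, not ColossalAI's actual shape-consistency API:

# Illustrative sketch: every operand spec a strategy requires should
# contribute to its resharding cost. Before this commit the weight's
# conversion cost was silently dropped.
def generate_resharding_costs(target_specs, candidate_specs, compute_conversion_cost):
    # target_specs:    specs this strategy requires, e.g.
    #                  [sharding_spec_for_input, sharding_spec_for_weight]
    # candidate_specs: per-operand lists of specs each operand may arrive in
    # compute_conversion_cost: estimates converting one spec into another
    resharding_costs = {}
    for i, target_spec in enumerate(target_specs):
        resharding_costs[i] = [compute_conversion_cost(src, target_spec)
                               for src in candidate_specs[i]]
    return resharding_costs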
@@ -106,7 +106,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

         # compute the computation cost of this strategy
         bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
@@ -145,7 +145,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

         # compute the computation cost of this strategy
         bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
@@ -184,7 +184,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

         # compute the computation cost of this strategy
         bs = self.input_data.shape[0]
@@ -223,7 +223,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

         # compute the computation cost of this strategy
         bs = self.input_data.shape[0]
@@ -261,7 +261,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

         # compute the computation cost of this strategy
         bs = self.input_data.shape[0]
@@ -301,7 +301,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

         # compute the computation cost of this strategy
         bs = self.input_data.shape[0]
@@ -340,7 +340,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

         # compute the computation cost of this strategy
         bs = self.input_data.shape[0] // (self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1])
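The context line above is the one strategy that divides the batch across both mesh dimensions rather than one. A hypothetical worked example with made-up sizes:

# Hypothetical numbers: a 64-sample batch, batch dimension sharded over
# both dimensions of a 4 x 2 device mesh.
batch, mesh_shape = 64, (4, 2)
bs = batch // (mesh_shape[0] * mesh_shape[1])   # 64 // 8 -> 8 samples per device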
@@ -380,7 +380,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

         # compute the computation cost of this strategy
         bs = self.input_data.shape[0]
@@ -15,7 +15,7 @@ __all__ = ['OperatorHandler']

 class OperatorHandler(ABC):
     '''
-    The OperatorHandler is an abstract class used to generate every possible strategies for a operator node.
+    The OperatorHandler is an abstract class used to generate every possible strategies for an operator node.

     Argument:
         input_node(Node): the input node in node argument list.
@@ -43,6 +43,10 @@ class OperatorHandler(ABC):
             named_parameters = list(module.named_parameters(recurse=False))
             # convert named parameters from list to dict
             named_parameters = {k: v for k, v in named_parameters}
+        elif self.node.op == 'call_function':
+            module = None
+            parameters = list(self.node.args)[1]
+            named_parameters = {'weight': parameters._meta_data}
         else:
             module = None
             named_parameters = None
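This new branch is what the commit title refers to: conv invoked as a function rather than as an nn.Module. A small torch.fx probe showing why node.args[1] is the weight; this uses only stock torch.fx, while the _meta_data attribute the handler reads afterwards is attached by ColossalAI's own tracer:

import torch
import torch.nn.functional as F
from torch.fx import symbolic_trace

def forward(x, weight):
    return F.conv2d(x, weight)

# Tracing records F.conv2d as a call_function node whose positional
# args follow the functional signature, so args[1] is the weight.
gm = symbolic_trace(forward)
for node in gm.graph.nodes:
    if node.op == 'call_function':
        print(node.target, node.args)   # conv2d target, args == (x, weight)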
@@ -27,7 +27,13 @@ class StrategiesConstructor:
         Generate the sharding spec of the tensor based on the given dim_partition_dict
         where the key is the tensor dimension and the value is the mesh dimension for sharding.
         """
-        meta_tensor = node._meta_data
+        if hasattr(node, '_meta_data'):
+            meta_tensor = node._meta_data
+        elif isinstance(node, torch.Tensor):
+            meta_tensor = node
+        else:
+            raise RuntimeError(f'We cannot generate sharding spec for {type(node)} type.')

         sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
                                      entire_shape=meta_tensor.shape,
                                      dim_partition_dict=dim_partition_dict)
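The call_function branch above hands over a raw meta tensor in named_parameters, so sharding-spec generation can no longer assume every argument is an fx Node. A standalone sketch of the new dispatch; resolve_meta_tensor is a hypothetical name used only for illustration:

import torch

# Module parameters arrive as fx Nodes carrying a ColossalAI
# '_meta_data' tensor, while the call_function path supplies a plain
# meta tensor; both must resolve to a tensor whose shape seeds the
# ShardingSpec.
def resolve_meta_tensor(node_or_tensor):
    if hasattr(node_or_tensor, '_meta_data'):
        return node_or_tensor._meta_data
    elif isinstance(node_or_tensor, torch.Tensor):
        return node_or_tensor
    raise RuntimeError(f'We cannot generate sharding spec for {type(node_or_tensor)} type.')

weight = torch.empty(8, 4, 3, 3, device='meta')   # hypothetical conv weight
assert resolve_meta_tensor(weight).shape == (8, 4, 3, 3)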