diff --git a/colossalai/auto_parallel/solver/conv_handler.py b/colossalai/auto_parallel/solver/conv_handler.py
index a00fe1862..6526e1018 100644
--- a/colossalai/auto_parallel/solver/conv_handler.py
+++ b/colossalai/auto_parallel/solver/conv_handler.py
@@ -9,7 +9,7 @@ __all__ = ['ConvHandler']
 
 class ConvHandler(OperatorHandler):
     """
-    A OperatorHandler which deals with the sharding strategies of linear matrix multiplication.
+    An OperatorHandler which deals with the sharding strategies of linear matrix multiplication.
     """
 
     def __init__(self, *args, **kwargs):
@@ -67,7 +67,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
 
         # compute the computation cost of this strategy
         bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
@@ -106,7 +106,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
 
         # compute the computation cost of this strategy
         bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
@@ -145,7 +145,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
 
         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
 
         # compute the computation cost of this strategy
         bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
@@ -184,7 +184,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
 
         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
 
         # compute the computation cost of this strategy
         bs = self.input_data.shape[0]
@@ -223,7 +223,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
 
         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
 
         # compute the computation cost of this strategy
         bs = self.input_data.shape[0]
@@ -261,7 +261,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
 
         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
 
         # compute the computation cost of this strategy
         bs = self.input_data.shape[0]
@@ -301,7 +301,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
 
         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
 
         # compute the computation cost of this strategy
         bs = self.input_data.shape[0]
@@ -340,7 +340,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
 
         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
 
         # compute the computation cost of this strategy
         bs = self.input_data.shape[0] // (self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1])
@@ -380,7 +380,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
 
         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
 
         # compute the computation cost of this strategy
         bs = self.input_data.shape[0]
diff --git a/colossalai/auto_parallel/solver/operator_handler.py b/colossalai/auto_parallel/solver/operator_handler.py
index 675b71982..5c4cc7def 100644
--- a/colossalai/auto_parallel/solver/operator_handler.py
+++ b/colossalai/auto_parallel/solver/operator_handler.py
@@ -15,7 +15,7 @@ __all__ = ['OperatorHandler']
 
 class OperatorHandler(ABC):
     '''
-    The OperatorHandler is an abstract class used to generate every possible strategies for a operator node.
+    The OperatorHandler is an abstract class used to generate every possible strategies for an operator node.
 
     Argument:
         input_node(Node): the input node in node argument list.
@@ -43,6 +43,10 @@ class OperatorHandler(ABC):
             named_parameters = list(module.named_parameters(recurse=False))
             # convert named parameters from list to dict
             named_parameters = {k: v for k, v in named_parameters}
+        elif self.node.op == 'call_function':
+            module = None
+            parameters = list(self.node.args)[1]
+            named_parameters = {'weight': parameters._meta_data}
         else:
             module = None
             named_parameters = None
diff --git a/colossalai/auto_parallel/solver/strategies_constructor.py b/colossalai/auto_parallel/solver/strategies_constructor.py
index eca20ef3b..98cc43976 100644
--- a/colossalai/auto_parallel/solver/strategies_constructor.py
+++ b/colossalai/auto_parallel/solver/strategies_constructor.py
@@ -27,7 +27,13 @@ class StrategiesConstructor:
         Generate the sharding spec of the tensor based on the given dim_partition_dict
         where the key is the tensor dimension and the value is the mesh dimension for sharding.
         """
-        meta_tensor = node._meta_data
+        if hasattr(node, '_meta_data'):
+            meta_tensor = node._meta_data
+        elif isinstance(node, torch.Tensor):
+            meta_tensor = node
+        else:
+            raise RuntimeError(f'We cannot generate sharding spec for {type(node)} type.')
+
         sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
                                      entire_shape=meta_tensor.shape,
                                      dim_partition_dict=dim_partition_dict)
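
The recurring conv_handler.py hunk passes the weight's sharding spec to `_generate_resharding_costs` alongside the input's: before this change the cost of redistributing the weight operand was never counted, so the solver could pick strategies with hidden weight-resharding communication. The operator_handler.py hunk supports this for `call_function` nodes (e.g. `F.conv2d`), where the weight is the second positional argument rather than a module parameter and must be fetched from `node.args`. The sketch below is a minimal, self-contained illustration of the per-operand cost table, with hypothetical names and a placeholder cost model; it is not ColossalAI's actual implementation.

from dataclasses import dataclass
from typing import Dict, List


@dataclass
class ToyShardingSpec:
    # One entry per tensor dimension, e.g. ['S0', 'R', 'R', 'R'] means dim 0
    # is sharded along mesh axis 0 and the other dims are replicated.
    sharding_sequence: List[str]


def toy_resharding_cost(src: ToyShardingSpec, dst: ToyShardingSpec) -> float:
    # Placeholder cost model: one unit per dimension whose sharding differs.
    # A real cost model would derive communication volume from the device
    # mesh and the tensor shape instead.
    return float(sum(a != b for a, b in zip(src.sharding_sequence, dst.sharding_sequence)))


def generate_resharding_costs(operand_specs: List[ToyShardingSpec],
                              upstream_choices: List[List[ToyShardingSpec]]) -> Dict[int, List[float]]:
    # For every operand (input AND weight), record the cost of converting each
    # sharding a predecessor node may already have chosen into the sharding
    # this strategy requires. Omitting an operand, as the old call did for
    # the weight, makes the solver treat its resharding as free.
    costs: Dict[int, List[float]] = {}
    for idx, (required, choices) in enumerate(zip(operand_specs, upstream_choices)):
        costs[idx] = [toy_resharding_cost(src, required) for src in choices]
    return costs


if __name__ == '__main__':
    input_spec = ToyShardingSpec(['S0', 'R', 'R', 'R'])   # batch-sharded activation
    weight_spec = ToyShardingSpec(['S1', 'R', 'R', 'R'])  # out-channel-sharded weight
    replicated = ToyShardingSpec(['R', 'R', 'R', 'R'])
    print(generate_resharding_costs([input_spec, weight_spec],
                                    [[replicated], [replicated]]))
    # {0: [1.0], 1: [1.0]}: the weight now contributes a resharding cost too.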