From 9dae9bb2bc8000b5f73eed8e7c262d1acf8841c7 Mon Sep 17 00:00:00 2001
From: Frank Lee
Date: Fri, 19 Aug 2022 16:51:38 +0800
Subject: [PATCH] [autoparallel] introduced baseclass for op handler and reduced code redundancy (#1471)

* [autoparallel] introduced baseclass for op handler and reduced code redundancy

* polish code
---
 .../auto_parallel/solver/conv_handler.py     | 93 +++++--------------
 .../auto_parallel/solver/dot_handler.py      | 12 +++
 .../auto_parallel/solver/operator_handler.py | 46 ++++++++++
 3 files changed, 81 insertions(+), 70 deletions(-)
 create mode 100644 colossalai/auto_parallel/solver/dot_handler.py
 create mode 100644 colossalai/auto_parallel/solver/operator_handler.py

diff --git a/colossalai/auto_parallel/solver/conv_handler.py b/colossalai/auto_parallel/solver/conv_handler.py
index a6f3a682b..228471870 100644
--- a/colossalai/auto_parallel/solver/conv_handler.py
+++ b/colossalai/auto_parallel/solver/conv_handler.py
@@ -1,35 +1,18 @@
 import operator
 from functools import reduce
 import torch
 from colossalai.tensor.sharding_spec import ShardingSpec
 from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
+from .operator_handler import OperatorHandler


-class ConvHandler:
-    '''
-    The ConvHandler is used to generate every possible strategies for a Conv node.
-
-    Argument:
-        input_node(Node): the input node in conv node argument list.
-        input_index(int): the index of input node in the conv node argument list.
-        weight(torch.Tensor): Weight of the conv node.
-        output_node(Node): Output_node is the output of the conv node.
-        device_mesh(DeviceMesh): A logical view of a physical mesh.
-        strategies_vector(StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
-        shape_consistency_manager(ShapeConsistencyManager): ShapeConsistencyManager will give the resharding costs of the different sharding specs.
-    '''
-
-    def __init__(self, input_node, input_index, weight, output_node, device_mesh, strategies_vector,
-                 shape_consistency_manager):
-        self.input_node = input_node
-        self.input_data = self.input_node._meta_data
-        self.weight = weight
-        self.input_index = input_index
-        self.output_node = output_node
-        self.output = self.output_node._meta_data
-        self.device_mesh = device_mesh
-        self.strategies_vector = strategies_vector
-        self.shape_consistency_manager = shape_consistency_manager
+class ConvHandler(OperatorHandler):
+    """
+    An OperatorHandler which deals with the sharding strategies of convolution.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
         self._sanity_check()

     def _sanity_check(self):
@@ -42,36 +25,6 @@ class ConvHandler:
         assert self.input_data.dim() in (3, 4, 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].'

-    def _generate_sharding_spec_for_input(self, dim_partition_dict_for_input):
-        '''
-        Generate sharding spec for the input node.
-        '''
-        entire_shape_for_input = self.input_data.shape
-        sharding_spec_for_input = ShardingSpec(device_mesh=self.device_mesh,
-                                               entire_shape=entire_shape_for_input,
-                                               dim_partition_dict=dim_partition_dict_for_input)
-        return sharding_spec_for_input
-
-    def _generate_sharding_spec_for_weight(self, dim_partition_dict_for_weight):
-        '''
-        Generate sharding spec for the weight.
-        '''
-        entire_shape_for_weight = self.weight.shape
-        sharding_spec_for_weight = ShardingSpec(device_mesh=self.device_mesh,
-                                                entire_shape=entire_shape_for_weight,
-                                                dim_partition_dict=dim_partition_dict_for_weight)
-        return sharding_spec_for_weight
-
-    def _generate_sharding_spec_for_output(self, dim_partition_dict_for_output):
-        '''
-        Generate sharding spec for the output node.
-        '''
-        entire_shape_for_output = self.output.shape
-        sharding_spec_for_output = ShardingSpec(device_mesh=self.device_mesh,
-                                                entire_shape=entire_shape_for_output,
-                                                dim_partition_dict=dim_partition_dict_for_output)
-        return sharding_spec_for_output
-
     def _generate_resharding_costs(self, resharding_costs, sharding_spec_for_input):
         '''
         Compute the resharding costs with this specific strategy.
@@ -120,13 +73,13 @@ class ConvHandler:
         name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'

         dim_partition_dict_for_input = {0: [mesh_dim_0]}
-        sharding_spec_for_input = self._generate_sharding_spec_for_input(dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {1: [mesh_dim_1]}
-        sharding_spec_for_weight = self._generate_sharding_spec_for_weight(dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_ouput = self._generate_sharding_spec_for_output(dim_partition_dict_for_output)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = {}
@@ -160,13 +113,13 @@ class ConvHandler:
         name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'

         dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_input = self._generate_sharding_spec_for_input(dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {0: [mesh_dim_0]}
-        sharding_spec_for_weight = self._generate_sharding_spec_for_weight(dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0]}
-        sharding_spec_for_ouput = self._generate_sharding_spec_for_output(dim_partition_dict_for_output)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = {}
@@ -200,13 +153,13 @@ class ConvHandler:
         name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'

         dim_partition_dict_for_input = {1: [mesh_dim_0]}
-        sharding_spec_for_input = self._generate_sharding_spec_for_input(dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_weight = self._generate_sharding_spec_for_weight(dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {1: [mesh_dim_1]}
-        sharding_spec_for_ouput = self._generate_sharding_spec_for_output(dim_partition_dict_for_output)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = {}
@@ -240,13 +193,13 @@ class ConvHandler:
         name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'

         dim_partition_dict_for_input = {}
-        sharding_spec_for_input = self._generate_sharding_spec_for_input(dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {1: [mesh_dim_0]}
-        sharding_spec_for_weight = self._generate_sharding_spec_for_weight(dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {1: [mesh_dim_0]}
-        sharding_spec_for_ouput = self._generate_sharding_spec_for_output(dim_partition_dict_for_output)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = {}
@@ -281,13 +234,13 @@ class ConvHandler:
         name = f'RR = RR x RR'

         dim_partition_dict_for_input = {}
-        sharding_spec_for_input = self._generate_sharding_spec_for_input(dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {}
-        sharding_spec_for_weight = self._generate_sharding_spec_for_weight(dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {}
-        sharding_spec_for_ouput = self._generate_sharding_spec_for_output(dim_partition_dict_for_output)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = {}
diff --git a/colossalai/auto_parallel/solver/dot_handler.py b/colossalai/auto_parallel/solver/dot_handler.py
new file mode 100644
index 000000000..a25466ea3
--- /dev/null
+++ b/colossalai/auto_parallel/solver/dot_handler.py
@@ -0,0 +1,12 @@
+from .operator_handler import OperatorHandler
+
+
+class DotHandler(OperatorHandler):
+    """
+    An OperatorHandler which deals with the sharding strategies of linear matrix multiplication.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    # TODO: refactor the dot handler in my local branch to align with the latest main branch
diff --git a/colossalai/auto_parallel/solver/operator_handler.py b/colossalai/auto_parallel/solver/operator_handler.py
new file mode 100644
index 000000000..24027e996
--- /dev/null
+++ b/colossalai/auto_parallel/solver/operator_handler.py
@@ -0,0 +1,46 @@
+from abc import ABC, abstractmethod
+from torch.fx.node import Node
+import torch.nn as nn
+from colossalai.device.device_mesh import DeviceMesh
+from .sharding_strategy import StrategiesVector
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager
+from colossalai.tensor.sharding_spec import ShardingSpec
+
+
+class OperatorHandler(ABC):
+    '''
+    The OperatorHandler is an abstract class used to generate every possible strategy for an operator node.
+
+    Argument:
+        input_node(Node): the input node in the node argument list.
+        input_index(int): the index of the input node in the node argument list.
+        weight(nn.Parameter): the weight of the operator node.
+        output_node(Node): the output node of the operator node.
+        device_mesh(DeviceMesh): A logical view of a physical mesh.
+        strategies_vector(StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
+        shape_consistency_manager(ShapeConsistencyManager): ShapeConsistencyManager will give the resharding costs between different sharding specs.
+    '''
+
+    def __init__(self, input_node: Node, input_index: int, weight: nn.Parameter, output_node: Node,
+                 device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
+                 shape_consistency_manager: ShapeConsistencyManager):
+        self.input_node = input_node
+        self.input_data = self.input_node._meta_data
+        self.weight = weight
+        self.input_index = input_index
+        self.output_node = output_node
+        self.output = self.output_node._meta_data
+        self.device_mesh = device_mesh
+        self.strategies_vector = strategies_vector
+        self.shape_consistency_manager = shape_consistency_manager
+
+    @abstractmethod
+    def register_strategy_into_strategies_vector(self):
+        pass
+
+    def _generate_sharding_spec(self, tensor, dim_partition_dict):
+        '''Generate the sharding spec of the given tensor based on the dim partition dictionary.'''
+        sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
+                                     entire_shape=tensor.shape,
+                                     dim_partition_dict=dim_partition_dict)
+        return sharding_spec
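
For context, below is a minimal sketch (not part of this patch) of how a concrete handler is expected to build on the new OperatorHandler base class: a subclass implements the abstract register_strategy_into_strategies_vector hook and reuses the shared _generate_sharding_spec helper for each tensor. The ReLUHandler name and the dim partition here are hypothetical illustrations only.

from colossalai.auto_parallel.solver.operator_handler import OperatorHandler


class ReLUHandler(OperatorHandler):
    '''A hypothetical element-wise handler, for illustration only.'''

    def register_strategy_into_strategies_vector(self):
        # For an element-wise op, the input and output can share one sharding
        # layout, e.g. split the batch dimension across mesh axis 0.
        dim_partition_dict = {0: [0]}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict)
        sharding_spec_for_output = self._generate_sharding_spec(self.output, dim_partition_dict)
        # A ShardingStrategy built from these specs (plus compute and
        # resharding cost estimates) would then be appended to
        # self.strategies_vector, mirroring what ConvHandler does.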