mirror of https://github.com/hpcaitech/ColossalAI

[autoparallel] introduced baseclass for op handler and reduced code redundancy (#1471)

* [autoparallel] introduced baseclass for op handler and reduced code redundancy
* polish code

branch: pull/1472/head
parent: 3a54e1c9b7
commit: 9dae9bb2bc
@@ -1,35 +1,19 @@
-from lib2to3.pytree import Base
 import operator
 from functools import reduce
 import torch
-from colossalai.tensor.sharding_spec import ShardingSpec
 from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
+from .operator_handler import OperatorHanlder
 
 
-class ConvHandler:
-    '''
-    The ConvHandler is used to generate every possible strategy for a Conv node.
-
-    Argument:
-        input_node(Node): the input node in the conv node argument list.
-        input_index(int): the index of the input node in the conv node argument list.
-        weight(torch.Tensor): weight of the conv node.
-        output_node(Node): the output of the conv node.
-        device_mesh(DeviceMesh): a logical view of a physical mesh.
-        strategies_vector(StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
-        shape_consistency_manager(ShapeConsistencyManager): ShapeConsistencyManager will give the resharding costs of the different sharding specs.
-    '''
-
-    def __init__(self, input_node, input_index, weight, output_node, device_mesh, strategies_vector,
-                 shape_consistency_manager):
-        self.input_node = input_node
-        self.input_data = self.input_node._meta_data
-        self.weight = weight
-        self.input_index = input_index
-        self.output_node = output_node
-        self.output = self.output_node._meta_data
-        self.device_mesh = device_mesh
-        self.strategies_vector = strategies_vector
-        self.shape_consistency_manager = shape_consistency_manager
+class ConvHandler(OperatorHanlder):
+    """
+    An OperatorHandler which deals with the sharding strategies of convolution.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._sanity_check()
 
     def _sanity_check(self):
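The hunk above moves the nine duplicated attribute assignments into the base class, so a concrete handler's __init__ only forwards its arguments and runs its own checks. A minimal standalone sketch of that delegation pattern (BaseHandler and ConvLikeHandler are toy stand-ins, not the colossalai classes):

class BaseHandler:
    # shared construction logic lives once, in the base class
    def __init__(self, input_data, weight, output):
        self.input_data = input_data
        self.weight = weight
        self.output = output

class ConvLikeHandler(BaseHandler):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)   # forward everything to the base
        self._sanity_check()                # subclass-specific validation only

    def _sanity_check(self):
        assert self.input_data is not None

handler = ConvLikeHandler(input_data=[1.0], weight=[0.5], output=[0.5])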
@@ -42,36 +26,6 @@ class ConvHandler:
         assert self.input_data.dim() in (3, 4,
                                          5), f'We assume the dim of the input fed into the conv op is in the range [3, 5].'
 
-    def _generate_sharding_spec_for_input(self, dim_partition_dict_for_input):
-        '''
-        Generate sharding spec for the input node.
-        '''
-        entire_shape_for_input = self.input_data.shape
-        sharding_spec_for_input = ShardingSpec(device_mesh=self.device_mesh,
-                                               entire_shape=entire_shape_for_input,
-                                               dim_partition_dict=dim_partition_dict_for_input)
-        return sharding_spec_for_input
-
-    def _generate_sharding_spec_for_weight(self, dim_partition_dict_for_weight):
-        '''
-        Generate sharding spec for the weight.
-        '''
-        entire_shape_for_weight = self.weight.shape
-        sharding_spec_for_weight = ShardingSpec(device_mesh=self.device_mesh,
-                                                entire_shape=entire_shape_for_weight,
-                                                dim_partition_dict=dim_partition_dict_for_weight)
-        return sharding_spec_for_weight
-
-    def _generate_sharding_spec_for_output(self, dim_partition_dict_for_output):
-        '''
-        Generate sharding spec for the output node.
-        '''
-        entire_shape_for_output = self.output.shape
-        sharding_spec_for_output = ShardingSpec(device_mesh=self.device_mesh,
-                                                entire_shape=entire_shape_for_output,
-                                                dim_partition_dict=dim_partition_dict_for_output)
-        return sharding_spec_for_output
-
     def _generate_resharding_costs(self, resharding_costs, sharding_spec_for_input):
         '''
         Compute the resharding costs with this specific strategy.
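The three helpers deleted above differ only in which tensor's shape they read, which is why they collapse into the single _generate_sharding_spec(tensor, dim_partition_dict) on the base class (shown at the bottom of this diff). A toy sketch of the consolidation, with a namedtuple standing in for a real tensor and a plain dict standing in for ShardingSpec:

from collections import namedtuple

Tensor = namedtuple('Tensor', ['shape'])   # stand-in for a real tensor

def generate_sharding_spec(device_mesh, tensor, dim_partition_dict):
    # one code path now serves the input, weight and output specs alike
    return {'mesh': device_mesh, 'shape': tensor.shape, 'partition': dim_partition_dict}

input_spec = generate_sharding_spec('2x2 mesh', Tensor((4, 8, 32, 32)), {0: [0]})
weight_spec = generate_sharding_spec('2x2 mesh', Tensor((16, 8, 3, 3)), {1: [1]})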
@@ -120,13 +74,13 @@ class ConvHandler:
         name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
 
         dim_partition_dict_for_input = {0: [mesh_dim_0]}
-        sharding_spec_for_input = self._generate_sharding_spec_for_input(dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
 
         dim_partition_dict_for_weight = {1: [mesh_dim_1]}
-        sharding_spec_for_weight = self._generate_sharding_spec_for_weight(dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
 
         dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_output = self._generate_sharding_spec_for_output(dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output, dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
         resharding_costs = {}
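As used here, dim_partition_dict maps a tensor dimension to the list of device-mesh axes it is sharded along: {0: [mesh_dim_0], 1: [mesh_dim_1]} shards dim 0 across mesh axis 0 and dim 1 across mesh axis 1, with all other dims replicated. A hypothetical helper (not part of this commit) that computes the per-device shard shape such a dict implies, assuming a 2x2 mesh and even divisibility:

def local_shape(global_shape, dim_partition_dict, mesh_shape=(2, 2)):
    # shrink each listed tensor dim by the size of every mesh axis sharding it
    shape = list(global_shape)
    for dim, mesh_dims in dim_partition_dict.items():
        for mesh_dim in mesh_dims:
            assert shape[dim] % mesh_shape[mesh_dim] == 0
            shape[dim] //= mesh_shape[mesh_dim]
    return tuple(shape)

# S0S1 output: batch sharded over mesh axis 0, out-channels over mesh axis 1
print(local_shape((8, 16, 32, 32), {0: [0], 1: [1]}))  # (4, 8, 32, 32)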
@@ -160,13 +114,13 @@ class ConvHandler:
         name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
 
         dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_input = self._generate_sharding_spec_for_input(dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
 
         dim_partition_dict_for_weight = {0: [mesh_dim_0]}
-        sharding_spec_for_weight = self._generate_sharding_spec_for_weight(dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
 
         dim_partition_dict_for_output = {0: [mesh_dim_0]}
-        sharding_spec_for_output = self._generate_sharding_spec_for_output(dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output, dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
         resharding_costs = {}
@@ -200,13 +154,13 @@ class ConvHandler:
         name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
 
         dim_partition_dict_for_input = {1: [mesh_dim_0]}
-        sharding_spec_for_input = self._generate_sharding_spec_for_input(dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
 
         dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_weight = self._generate_sharding_spec_for_weight(dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
 
         dim_partition_dict_for_output = {1: [mesh_dim_1]}
-        sharding_spec_for_output = self._generate_sharding_spec_for_output(dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output, dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
         resharding_costs = {}
@@ -240,13 +194,13 @@ class ConvHandler:
         name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'
 
         dim_partition_dict_for_input = {}
-        sharding_spec_for_input = self._generate_sharding_spec_for_input(dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
 
         dim_partition_dict_for_weight = {1: [mesh_dim_0]}
-        sharding_spec_for_weight = self._generate_sharding_spec_for_weight(dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
 
         dim_partition_dict_for_output = {1: [mesh_dim_0]}
-        sharding_spec_for_output = self._generate_sharding_spec_for_output(dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output, dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
         resharding_costs = {}
@@ -281,13 +235,13 @@ class ConvHandler:
         name = f'RR = RR x RR'
 
         dim_partition_dict_for_input = {}
-        sharding_spec_for_input = self._generate_sharding_spec_for_input(dim_partition_dict_for_input)
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
 
         dim_partition_dict_for_weight = {}
-        sharding_spec_for_weight = self._generate_sharding_spec_for_weight(dim_partition_dict_for_weight)
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
 
         dim_partition_dict_for_output = {}
-        sharding_spec_for_output = self._generate_sharding_spec_for_output(dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output, dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
         resharding_costs = {}
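The strategy names in these hunks follow one convention: one letter per logical tensor dimension, S<k> for sharded along mesh axis k and R for replicated, read as output = input x weight. Taking mesh_dim_0 = 0 and mesh_dim_1 = 1, the five conv strategies touched by this diff pair up with their partition dicts as follows (a summary reconstructed from the hunks above, not code from the commit):

# strategy name -> dim_partition_dicts for (input, weight, output)
conv_strategies = {
    'S0S1 = S0R x RS1': ({0: [0]},         {1: [1]},         {0: [0], 1: [1]}),
    'S0R = S0S1 x S1R': ({0: [0], 1: [1]}, {0: [0]},         {0: [0]}),
    'RS1 = RS0 x S0S1': ({1: [0]},         {0: [0], 1: [1]}, {1: [1]}),
    'RS0 = RR x RS0':   ({},               {1: [0]},         {1: [0]}),
    'RR = RR x RR':     ({},               {},               {}),
}
print(conv_strategies['RS0 = RR x RS0'])  # ({}, {1: [0]}, {1: [0]})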
@@ -0,0 +1,12 @@
+from .operator_handler import OperatorHanlder
+
+
+class DotHandler(OperatorHanlder):
+    """
+    An OperatorHandler which deals with the sharding strategies of linear matrix multiplication.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    # TODO: refactor the dot handler in my local branch to align with the latest main branch
@@ -0,0 +1,45 @@
+from abc import ABC, abstractmethod
+from torch.fx.node import Node
+import torch.nn as nn
+from colossalai.device.device_mesh import DeviceMesh
+from .sharding_strategy import StrategiesVector
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager
+from colossalai.tensor.sharding_spec import ShardingSpec
+
+
+class OperatorHanlder(ABC):
+    '''
+    The OperatorHanlder is an abstract class used to generate every possible strategy for an operator node.
+
+    Argument:
+        input_node(Node): the input node in the node argument list.
+        input_index(int): the index of the input node in the node argument list.
+        weight(torch.Tensor): weight of the node.
+        output_node(Node): the output of the node.
+        device_mesh(DeviceMesh): a logical view of a physical mesh.
+        strategies_vector(StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
+        shape_consistency_manager(ShapeConsistencyManager): ShapeConsistencyManager will give the resharding costs of the different sharding specs.
+    '''
+
+    def __init__(self, input_node: Node, input_index: int, weight: nn.Parameter, output_node: Node,
+                 device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
+                 shape_consistency_manager: ShapeConsistencyManager):
+        self.input_node = input_node
+        self.input_data = self.input_node._meta_data
+        self.weight = weight
+        self.input_index = input_index
+        self.output_node = output_node
+        self.output = self.output_node._meta_data
+        self.device_mesh = device_mesh
+        self.strategies_vector = strategies_vector
+        self.shape_consistency_manager = shape_consistency_manager
+
+    @abstractmethod
+    def register_strategy_into_strategies_vector(self):
+        pass
+
+    def _generate_sharding_spec(self, tensor, dim_partition_dict):
+        sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
+                                     entire_shape=tensor.shape,
+                                     dim_partition_dict=dim_partition_dict)
+        return sharding_spec
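For reference, a minimal usage sketch of the abstract-base-class contract introduced above. OperatorHanlderSketch and ToyConvHandler are hypothetical stand-ins: the real handlers receive fx Node and DeviceMesh instances and register ShardingStrategy objects rather than strings.

from abc import ABC, abstractmethod

class OperatorHanlderSketch(ABC):
    def __init__(self, strategies_vector):
        self.strategies_vector = strategies_vector

    @abstractmethod
    def register_strategy_into_strategies_vector(self):
        # every concrete handler must provide its operator-specific strategies
        pass

class ToyConvHandler(OperatorHanlderSketch):
    def register_strategy_into_strategies_vector(self):
        self.strategies_vector.append('S0S1 = S0R x RS1')

handler = ToyConvHandler(strategies_vector=[])
handler.register_strategy_into_strategies_vector()
print(handler.strategies_vector)  # ['S0S1 = S0R x RS1']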