[autoparallel] added new node handler (#1612)

pull/1614/head
Frank Lee 2022-09-20 14:17:21 +08:00 committed by GitHub
parent 7d1bb71d5d
commit d397842fa8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 137 additions and 5 deletions

View File

@ -0,0 +1,66 @@
from abc import ABC, abstractmethod
from torch.fx.node import Node
from colossalai.device.device_mesh import DeviceMesh
from typing import Dict, List
from ..sharding_strategy import StrategiesVector, Operand, StrategyGenerator_V2
class NodeHandler(ABC):
    '''
    The NodeHandler is an abstract class used to generate every possible sharding strategy for an operator node.

    Args:
        node (Node): the input node in node argument list.
        device_mesh (DeviceMesh): A logical view of a physical mesh.
        strategies_vector (StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
    '''

    def __init__(self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector) -> None:
        self.node = node
        # Producer and consumer nodes of this node in the traced fx graph.
        self.predecessor_node = list(node._input_nodes.keys())
        self.successor_node = list(node.users.keys())
        self.device_mesh = device_mesh
        self.strategies_vector = strategies_vector
        # The concrete subclass decides which generators apply to this node.
        self.strategy_generator = self.register_strategy_generator()

    def register_strategy(self) -> StrategiesVector:
        """
        Register different sharding strategies for the current node: run every
        generator returned by ``register_strategy_generator`` and append its
        output to ``self.strategies_vector``.
        """
        mapping = self.get_operand_mapping()
        for generator in self.strategy_generator:
            self.strategies_vector.extend(generator.generate(mapping))
        return self.strategies_vector

    @abstractmethod
    def register_strategy_generator(self) -> List[StrategyGenerator_V2]:
        """
        Define which generators should be used by this NodeHandler object.
        """
        pass

    @abstractmethod
    def get_operand_mapping(self) -> Dict[str, Operand]:
        """
        Returns the mapping between the logical operand name and its physical operands.

        A logical operand is defined by the strategy generator. For example, a matrix
        multiplication operation has two operands "input" and "other". For a nn.Linear
        module, the physical operand for "input" is the module input and the physical
        operand for "other" is the module weight. Note that the operand name is
        specified by the StrategyGenerator object.

        For example:

            # for a linear layer
            mapping = {
                "input": Operand(name=str(self.node.args[0]), type=OperandType.ARG),
                "other": Operand(name="weight", type=OperandType.PARAM),
                "bias": Operand(name="bias", type=OperandType.PARAM)
            }
        """
        pass

View File

@ -0,0 +1,25 @@
class Registry:
    """
    A minimal name -> target registry used to look up the implementation
    registered for a given source (e.g. an operator or module type).

    Args:
        name (str): a human-readable name for this registry, used in error messages.
    """
    # TODO: refactor the registry classes used in colossalai.registry, colossalai.fx and here

    def __init__(self, name):
        self.name = name
        self.store = {}

    def register(self, source):
        """
        Decorator factory: register the decorated object under ``source``.

        Returns:
            The decorator, which returns the decorated object unchanged so it
            can still be used normally after registration.
        """

        def wrapper(func):
            self.store[source] = func
            return func

        return wrapper

    def get(self, source):
        """
        Look up the target registered under ``source``.

        Raises:
            KeyError: if ``source`` was never registered. An explicit raise
                replaces the original ``assert`` so the check still fires
                (with a useful message) under ``python -O``, where asserts
                are stripped and the dict lookup would raise a bare KeyError
                anyway.
        """
        try:
            return self.store[source]
        except KeyError:
            raise KeyError(f'{source} not found in the {self.name} registry') from None

    def has(self, source):
        """Return True if ``source`` has been registered."""
        return source in self.store


operator_registry = Registry('operator')

View File

@ -1,4 +1,7 @@
from dataclasses import dataclass
from abc import ABC, abstractmethod
from enum import Enum
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
from typing import Dict, List, Union, Tuple, Any
from torch.fx.node import Node
@ -37,6 +40,20 @@ class ShardingStrategy:
input_shardings: List[ShardingSpec] = None
class OperandType(Enum):
    """
    An operand can come from the argument list of an operator or the parameter list of a module.
    """
    # The operand comes from the node's argument list (e.g. the module input).
    ARG = 0
    # The operand comes from the module's parameter list (e.g. weight or bias).
    PARAM = 1
@dataclass
class Operand:
    """
    A physical operand of an operator node, identified by its name and by
    whether it originates from the argument list or from module parameters.
    """
    # Operand identifier, e.g. str(node.args[0]) for an argument or "weight"/"bias" for a parameter.
    name: str
    # Whether this operand is an argument (OperandType.ARG) or a module parameter (OperandType.PARAM).
    type: OperandType
@dataclass
class TrainCycleItem:
"""
@ -44,9 +61,8 @@ class TrainCycleItem:
in a training iteration.
Args:
fwd (Any): the item for the forward pass
bwd (Any): the item for the backward pass
total (Any): the total value for the forward and backward pass
fwd (float): the item for the forward pass
bwd (float): the item for the backward pass
"""
fwd: Any
bwd: Any
@ -74,8 +90,33 @@ class ShardingStrategy_V2:
compute_cost: TrainCycleItem = None
communication_cost: TrainCycleItem = None
memory_cost: TrainCycleItem = None
input_sharding_specs: List[ShardingSpec] = None
input_resharding_costs: Dict[Node, List[float]] = None
input_sharding_specs: Dict[Operand, ShardingSpec] = None
input_resharding_costs: Dict[Operand, List[float]] = None
class StrategyGenerator_V2(ABC):
    """
    StrategyGenerator is used to generate the same group of sharding strategies.

    TODO: remove the original strategy_generator.py after refactoring

    Args:
        device_mesh (DeviceMesh): A logical view of a physical mesh.
    """

    def __init__(self, device_mesh: DeviceMesh):
        self.device_mesh = device_mesh

    @abstractmethod
    # NOTE: the original annotation was ``Dict[str:Operand]``, which is a slice
    # expression and not a valid generic parameterization; fixed to Dict[str, Operand].
    def generate(self, operand_mapping: Dict[str, Operand]) -> List[ShardingStrategy_V2]:
        """
        Generate the sharding strategies for the given operands.

        Args:
            operand_mapping (Dict[str, Operand]): mapping from the logical operand
                name defined by this generator to its physical operand.

        Returns:
            List[ShardingStrategy_V2]: all strategies produced by this generator.
        """
        pass

    @abstractmethod
    def validate(self, *args, **kwargs) -> bool:
        """
        Validate if the operands are of desired shape.
        If True, means this generator can be used for the current operation.
        """
        pass
class StrategiesVector(list):