mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] added new node handler (#1612)
parent 7d1bb71d5d
commit d397842fa8
@ -0,0 +1,66 @@
from abc import ABC, abstractmethod
from torch.fx.node import Node
from colossalai.device.device_mesh import DeviceMesh
from typing import Dict, List
from ..sharding_strategy import StrategiesVector, Operand, StrategyGenerator_V2


class NodeHandler(ABC):
    '''
    The NodeHandler is an abstract class used to generate all possible strategies for an operator node.

    Args:
        node (Node): the input node in the node argument list.
        device_mesh (DeviceMesh): a logical view of a physical mesh.
        strategies_vector (StrategiesVector): all the strategies generated by this handler will be recorded in the strategies_vector.
    '''

    def __init__(
        self,
        node: Node,
        device_mesh: DeviceMesh,
        strategies_vector: StrategiesVector,
    ) -> None:
        self.node = node
        self.predecessor_node = list(node._input_nodes.keys())
        self.successor_node = list(node.users.keys())
        self.device_mesh = device_mesh
        self.strategies_vector = strategies_vector
        self.strategy_generator = self.register_strategy_generator()

    def register_strategy(self) -> StrategiesVector:
        """
        Register different sharding strategies for the current node.
        """
        operand_mapping = self.get_operand_mapping()
        for generator in self.strategy_generator:
            strategies = generator.generate(operand_mapping)
            self.strategies_vector.extend(strategies)
        return self.strategies_vector

    @abstractmethod
    def register_strategy_generator(self) -> List[StrategyGenerator_V2]:
        """
        Define which generators should be used by this NodeHandler object.
        """
        pass

    @abstractmethod
    def get_operand_mapping(self) -> Dict[str, Operand]:
        """
        Returns the mapping from each logical operand name to its physical operand.
        A logical operand is defined by the strategy generator; for example, a matrix multiplication
        operation has two operands, "input" and "other". For an nn.Linear module, the physical operand for "input" is
        the module input and the physical operand for "other" is the module weight.
        Note that the operand names are specified by the StrategyGenerator object.

        For example:

            # for a linear layer
            mapping = {
                "input": Operand(name=str(self.node.args[0]), type=OperandType.ARG),
                "other": Operand(name="weight", type=OperandType.PARAM),
                "bias": Operand(name="bias", type=OperandType.PARAM)
            }
        """
        pass
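To make the extension point concrete, here is a minimal sketch (not part of this commit) of how a handler might subclass NodeHandler. The LinearModuleHandler name and the DotProductStrategyGenerator it returns are illustrative assumptions (a rough generator sketch appears at the end of this diff); the operand mapping follows the docstring example above.

# hypothetical handler sketch; OperandType comes from the same sharding_strategy module
from ..sharding_strategy import OperandType


class LinearModuleHandler(NodeHandler):

    def register_strategy_generator(self) -> List[StrategyGenerator_V2]:
        # a single generator that emits matmul-style sharding strategies
        return [DotProductStrategyGenerator(self.device_mesh)]

    def get_operand_mapping(self) -> Dict[str, Operand]:
        # map the generator's logical operand names to the nn.Linear module's
        # physical argument ("input") and parameters ("weight", "bias")
        return {
            "input": Operand(name=str(self.node.args[0]), type=OperandType.ARG),
            "other": Operand(name="weight", type=OperandType.PARAM),
            "bias": Operand(name="bias", type=OperandType.PARAM),
        }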
@ -0,0 +1,25 @@
class Registry:
    # TODO: refactor the registry classes used in colossalai.registry, colossalai.fx and here

    def __init__(self, name):
        self.name = name
        self.store = {}

    def register(self, source):

        def wrapper(func):
            self.store[source] = func
            return func

        return wrapper

    def get(self, source):
        assert source in self.store
        target = self.store[source]
        return target

    def has(self, source):
        return source in self.store


operator_registry = Registry('operator')
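A usage sketch for the registry (the handler class and the torch.nn.Linear key are illustrative, not from this commit): register acts as a decorator that binds an operator to its handler, which can later be looked up with get or probed with has.

import torch

# hypothetical registration of the handler sketched above (illustrative only)
@operator_registry.register(torch.nn.Linear)
class LinearModuleHandler(NodeHandler):
    ...


assert operator_registry.has(torch.nn.Linear)
handler_cls = operator_registry.get(torch.nn.Linear)    # -> LinearModuleHandler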
@ -1,4 +1,7 @@
from dataclasses import dataclass
from abc import ABC, abstractmethod
from enum import Enum
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
from typing import Dict, List, Union, Tuple, Any
from torch.fx.node import Node
@ -37,6 +40,20 @@ class ShardingStrategy:
    input_shardings: List[ShardingSpec] = None


class OperandType(Enum):
    """
    An operand can come from the argument list of an operator or the parameter list of a module.
    """
    ARG = 0
    PARAM = 1


@dataclass
class Operand:
    name: str
    type: OperandType


@dataclass
class TrainCycleItem:
    """
@ -44,9 +61,8 @@ class TrainCycleItem:
    in a training iteration.

    Args:
        fwd (Any): the item for the forward pass
        bwd (Any): the item for the backward pass
        total (Any): the total value for the forward and backward pass
        fwd (float): the item for the forward pass
        bwd (float): the item for the backward pass
    """
    fwd: Any
    bwd: Any
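As a quick construction sketch for the dataclasses above (illustrative values; it assumes fwd and bwd are the only fields of TrainCycleItem, as shown in this hunk):

# an operand backed by a module parameter, and a forward/backward cost item
weight_operand = Operand(name="weight", type=OperandType.PARAM)
compute_cost = TrainCycleItem(fwd=1.5e6, bwd=3.0e6)    # illustrative FLOP counts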
@ -74,8 +90,33 @@ class ShardingStrategy_V2:
    compute_cost: TrainCycleItem = None
    communication_cost: TrainCycleItem = None
    memory_cost: TrainCycleItem = None
    input_sharding_specs: List[ShardingSpec] = None
    input_resharding_costs: Dict[Node, List[float]] = None
    input_sharding_specs: Dict[Operand, ShardingSpec] = None
    input_resharding_costs: Dict[Operand, List[float]] = None


class StrategyGenerator_V2(ABC):
    """
    StrategyGenerator is used to generate a group of sharding strategies of the same kind.

    TODO: remove the original strategy_generator.py after refactoring
    """

    def __init__(self, device_mesh: DeviceMesh):
        self.device_mesh = device_mesh

    @abstractmethod
    def generate(self, operand_mapping: Dict[str, Operand]) -> List[ShardingStrategy_V2]:
        """
        Generate all possible sharding strategies for this operation.
        """
        pass

    @abstractmethod
    def validate(self, *args, **kwargs) -> bool:
        """
        Validate whether the operands have the desired shapes.
        If this returns True, this generator can be used for the current operation.
        """
        pass


class StrategiesVector(list):
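Finally, a rough sketch of a concrete generator conforming to the StrategyGenerator_V2 interface. The class name and the single placeholder strategy are assumptions, and it further assumes ShardingStrategy_V2 keeps name and output_sharding_spec fields above the hunk shown here.

class DotProductStrategyGenerator(StrategyGenerator_V2):
    # hypothetical generator that emits one trivial strategy for illustration

    def validate(self, *args, **kwargs) -> bool:
        # a real generator would check the operand shapes here
        return True

    def generate(self, operand_mapping: Dict[str, Operand]) -> List[ShardingStrategy_V2]:
        # a real implementation would derive sharding specs and costs from
        # operand_mapping and self.device_mesh; everything below is a placeholder
        strategy = ShardingStrategy_V2(name='R = R x R',
                                       output_sharding_spec=None,
                                       compute_cost=TrainCycleItem(fwd=0, bwd=0),
                                       communication_cost=TrainCycleItem(fwd=0, bwd=0))
        return [strategy]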