mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] added new node handler (#1612)
parent
7d1bb71d5d
commit
d397842fa8
|
@ -0,0 +1,66 @@
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from torch.fx.node import Node
|
||||||
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
|
from typing import Dict, List
|
||||||
|
from ..sharding_strategy import StrategiesVector, Operand, StrategyGenerator_V2
|
||||||
|
|
||||||
|
|
||||||
|
class NodeHandler(ABC):
    """
    Abstract base class that enumerates every possible sharding strategy for one operator node.

    Args:
        node (Node): the input node in node argument list.
        device_mesh (DeviceMesh): A logical view of a physical mesh.
        strategies_vector (StrategiesVector): all the strategies generated in this handler
            will be recorded into the strategies_vector.
    """

    def __init__(self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector) -> None:
        self.node = node
        # cache the producer and consumer nodes of this node for later lookups
        self.predecessor_node = list(node._input_nodes.keys())
        self.successor_node = list(node.users.keys())
        self.device_mesh = device_mesh
        self.strategies_vector = strategies_vector
        # subclasses decide which generators apply to this operator
        self.strategy_generator = self.register_strategy_generator()

    def register_strategy(self) -> StrategiesVector:
        """
        Run every registered generator and accumulate the resulting sharding
        strategies for the current node into the strategies vector.
        """
        mapping = self.get_operand_mapping()
        for gen in self.strategy_generator:
            self.strategies_vector.extend(gen.generate(mapping))
        return self.strategies_vector

    @abstractmethod
    def register_strategy_generator(self) -> List[StrategyGenerator_V2]:
        """
        Define which generators should be used by this NodeHandler object.
        """
        pass

    @abstractmethod
    def get_operand_mapping(self) -> Dict[str, Operand]:
        """
        Returns the mapping between the logical operand name to its physical operands.

        A logical operand is defined by the strategy generator; e.g. a matrix multiplication
        operation has two operands "input" and "other". For a nn.Linear module, the physical
        operand for "input" is the module input and the physical operand for "other" is the
        module weight. Note that the operand name is specified by the StrategyGenerator object.

        For example:

            # for a linear layer
            mapping = {
                "input": Operand(name=str(self.node.args[0]), type=OperandType.ARG),
                "other": Operand(name="weight", type=OperandType.PARAM),
                "bias": Operand(name="bias", type=OperandType.PARAM)
            }
        """
        pass
|
|
@ -0,0 +1,25 @@
|
||||||
|
class Registry:
    """
    A simple decorator-driven registry mapping a source key to a callable.

    # TODO: refactor the registry classes used in colossalai.registry, colossalai.fx and here

    Args:
        name: a human-readable name identifying this registry instance.
    """

    def __init__(self, name):
        self.name = name
        self.store = {}

    def register(self, source):
        """Return a decorator that registers the decorated function under ``source``."""

        def wrapper(func):
            self.store[source] = func
            return func

        return wrapper

    def get(self, source):
        """
        Look up the function registered under ``source``.

        Raises:
            KeyError: if ``source`` was never registered.
        """
        # Raise an explicit error instead of `assert`: asserts are stripped
        # under `python -O`, which would turn a missing entry into a KeyError
        # with no useful message anyway.
        if source not in self.store:
            raise KeyError(f'{source} is not registered in the {self.name} registry')
        return self.store[source]

    def has(self, source):
        """Check whether ``source`` has been registered."""
        return source in self.store
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level registry instance; presumably used to register per-operator
# handlers/strategies elsewhere in the package — confirm against callers.
operator_registry = Registry('operator')
|
|
@ -1,4 +1,7 @@
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from enum import Enum
|
||||||
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||||
from typing import Dict, List, Union, Tuple, Any
|
from typing import Dict, List, Union, Tuple, Any
|
||||||
from torch.fx.node import Node
|
from torch.fx.node import Node
|
||||||
|
@ -37,6 +40,20 @@ class ShardingStrategy:
|
||||||
input_shardings: List[ShardingSpec] = None
|
input_shardings: List[ShardingSpec] = None
|
||||||
|
|
||||||
|
|
||||||
|
class OperandType(Enum):
    """
    Classifies where an operand originates: the argument list of an operator
    (``ARG``) or the parameter list of a module (``PARAM``).
    """
    ARG = 0
    PARAM = 1
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Operand:
    # Identifier of the physical operand — judging by usage examples in this
    # file, either an fx argument node's name or a module parameter name such
    # as "weight".
    name: str
    # Whether the operand comes from an operator's argument list (ARG) or a
    # module's parameter list (PARAM).
    type: OperandType
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TrainCycleItem:
|
class TrainCycleItem:
|
||||||
"""
|
"""
|
||||||
|
@ -44,9 +61,8 @@ class TrainCycleItem:
|
||||||
in a training iteration.
|
in a training iteration.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
fwd (Any): the item for the forward pass
|
fwd (float): the item for the forward pass
|
||||||
bwd (Any): the item for the backward pass
|
bwd (float): the item for the backward pass
|
||||||
total (Any): the total value for the forward and backward pass
|
|
||||||
"""
|
"""
|
||||||
fwd: Any
|
fwd: Any
|
||||||
bwd: Any
|
bwd: Any
|
||||||
|
@ -74,8 +90,33 @@ class ShardingStrategy_V2:
|
||||||
compute_cost: TrainCycleItem = None
|
compute_cost: TrainCycleItem = None
|
||||||
communication_cost: TrainCycleItem = None
|
communication_cost: TrainCycleItem = None
|
||||||
memory_cost: TrainCycleItem = None
|
memory_cost: TrainCycleItem = None
|
||||||
input_sharding_specs: List[ShardingSpec] = None
|
input_sharding_specs: Dict[Operand, ShardingSpec] = None
|
||||||
input_resharding_costs: Dict[Node, List[float]] = None
|
input_resharding_costs: Dict[Operand, List[float]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class StrategyGenerator_V2(ABC):
    """
    StrategyGenerator is used to generate the same group of sharding strategies.

    TODO: remove the original strategy_generator.py after refactoring

    Args:
        device_mesh (DeviceMesh): A logical view of a physical mesh.
    """

    def __init__(self, device_mesh: DeviceMesh):
        self.device_mesh = device_mesh

    @abstractmethod
    # BUG FIX: the annotation was `Dict[str:Operand]`, which subscripts
    # typing.Dict with a slice and raises TypeError when the annotation is
    # evaluated at function-definition time; `Dict[str, Operand]` is correct.
    def generate(self, operand_mapping: Dict[str, Operand]) -> List[ShardingStrategy_V2]:
        """
        Generate the group of sharding strategies for the given operands.

        Args:
            operand_mapping (Dict[str, Operand]): mapping from a logical operand
                name to its physical operand.
        """
        pass

    @abstractmethod
    def validate(self, *args, **kwargs) -> bool:
        """
        Validate if the operands are of desired shape.
        If True, means this generator can be used for the current operation.
        """
        pass
|
||||||
|
|
||||||
|
|
||||||
class StrategiesVector(list):
|
class StrategiesVector(list):
|
||||||
|
|
Loading…
Reference in New Issue