2022-10-12 03:16:18 +00:00
|
|
|
from copy import deepcopy
|
2022-08-19 07:51:54 +00:00
|
|
|
from dataclasses import dataclass
|
2022-09-20 06:17:21 +00:00
|
|
|
from enum import Enum
|
2022-10-14 05:27:00 +00:00
|
|
|
from typing import Any, Dict, List, Tuple, Union
|
2022-09-21 04:23:21 +00:00
|
|
|
|
2022-10-14 05:27:00 +00:00
|
|
|
import torch
|
2022-10-20 10:48:18 +00:00
|
|
|
from torch.fx.node import Node
|
|
|
|
|
2022-10-14 05:27:00 +00:00
|
|
|
from colossalai.tensor.shape_consistency import CommSpec
|
2022-08-19 07:51:54 +00:00
|
|
|
from colossalai.tensor.sharding_spec import ShardingSpec
|
2022-10-14 05:27:00 +00:00
|
|
|
|
2022-10-20 10:48:18 +00:00
|
|
|
from .constants import BCAST_FUNC_OP, ELEMENTWISE_FUNC_OP, ELEMENTWISE_MODULE_OP, RESHAPE_FUNC_OP
|
2022-08-23 06:23:08 +00:00
|
|
|
|
2022-10-13 10:24:11 +00:00
|
|
|
# Public API of this module. CommType and CommAction are public dataclass/enum types defined here
# (ShardingStrategy.communication_actions holds CommAction values), so they are exported as well.
__all__ = [
    'OperationDataType', 'OperationData', 'TrainCycleItem', 'MemoryCost', 'CommType', 'CommAction',
    'ShardingStrategy', 'StrategiesVector'
]
|
2022-08-19 06:57:23 +00:00
|
|
|
|
|
|
|
|
2022-09-21 04:23:21 +00:00
|
|
|
class OperationDataType(Enum):
    """
    Category of a piece of data involved in an operation.

    An operation can come from the argument list of an operator or the parameter list of a module.
    The category is stored in ``OperationData.type`` and is used by ``ShardingStrategy`` to split
    its sharding specs into argument / parameter / output groups.

    Members:
        INPUT: NOTE(review): not referenced in this file; presumably a graph-level input — confirm.
        ARG: data coming from the argument list of an operator.
        PARAM: data coming from the parameter list of a module.
        BUFFER: presumably a module buffer (non-parameter state) — confirm against the node handlers.
        OUTPUT: the output produced by the operation.
    """
    # the integer values are part of the enum's identity; do not renumber
    INPUT = 0
    ARG = 1
    PARAM = 2
    BUFFER = 3
    OUTPUT = 4
|
2022-09-20 06:17:21 +00:00
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class OperationData:
    """
    OperationData is the data related to an operator, the data can be the operand or the output.

    Identity is defined by ``name`` alone: two OperationData objects compare equal and hash
    identically whenever their names are equal, regardless of ``type``, ``data`` or ``logical_shape``.
    This lets them be used as dictionary keys (e.g. in ``ShardingStrategy.sharding_specs``).

    Args:
        name (str): the name of the operation-related data
        type (OperationDataType): the type of the operation data
        data (Any): the value for this data, usually it is a meta tensor.
        logical_shape (Tuple[int]): the logical shape of the data, it can be different from its actual shape
            in memory. If omitted (None) and ``data`` is a ``torch.Tensor``, it defaults to ``data.shape``.
    """
    name: str
    type: OperationDataType
    data: Any
    logical_shape: Tuple[int] = None

    def __post_init__(self):
        # if no logical shape is specified, use the data shape as the logical shape
        if self.logical_shape is None and isinstance(self.data, torch.Tensor):
            self.logical_shape = self.data.shape

    def __repr__(self) -> str:
        return f'OperationData(name={self.name}, type={self.type})'

    def __eq__(self, other) -> bool:
        # Comparing against an unrelated object (e.g. None) used to raise AttributeError on
        # `other.name`; returning NotImplemented lets Python fall back to its default comparison.
        if not isinstance(other, OperationData):
            return NotImplemented
        return other.name == self.name

    def __hash__(self) -> int:
        # must stay consistent with __eq__: equal names always produce equal hashes
        return hash(f'{self.name}')
|
2022-09-27 04:06:25 +00:00
|
|
|
|
2022-09-20 06:17:21 +00:00
|
|
|
|
2022-09-20 03:20:54 +00:00
|
|
|
@dataclass
class TrainCycleItem:
    """
    TrainCycleItem is a dataclass to store the items which have different values for the forward and backward pass
    in a training iteration.

    The fields are untyped (``Any``): the same container is used for costs of different kinds
    (e.g. compute, communication or memory costs).

    Args:
        fwd (Any): the item for the forward pass
        bwd (Any): the item for the backward pass
        total (Any): the item for the whole training iteration. NOTE(review): presumably the combination
            of ``fwd`` and ``bwd``; this class stores it independently and does not enforce any relation.
    """
    fwd: Any
    bwd: Any
    total: Any
|
|
|
|
|
|
|
|
|
2022-09-21 04:23:21 +00:00
|
|
|
@dataclass
class MemoryCost:
    """
    Record of memory usage, in bytes, broken down by category.

    Every category defaults to 0, so ``MemoryCost()`` represents "no memory".

    Args:
        activation (int): bytes occupied by activations.
        parameter (int): bytes occupied by module parameters.
        temp (int): bytes occupied by temporary tensors.
        buffer (int): bytes occupied by module buffers.
    """

    # NOTE: declaration order defines the positional constructor signature; keep it stable.
    activation: int = 0
    parameter: int = 0
    temp: int = 0
    buffer: int = 0
|
2022-09-21 04:23:21 +00:00
|
|
|
|
|
|
|
|
2022-10-20 10:48:18 +00:00
|
|
|
class CommType(Enum):
    """
    CommType describes the sequential order of a communication action and a computation action.

    It is stored in ``CommAction.comm_type`` to say where a communication action is placed
    relative to the computation it belongs to.

    Meaning:
        BEFORE: the communication action happens just before the computation operation.
        AFTER: the communication action happens after the computation operation.
        HOOK: the communication action is used to do the grad all reduce.
        IMPLICIT: the communication action happens during the kernel execution, such as SyncBatchNorm
    """
    # the integer values are part of the enum's identity; do not renumber
    BEFORE = 0
    AFTER = 1
    HOOK = 2
    IMPLICIT = 3
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class CommAction:
    """
    CommAction is used to record the communication action.

    Args:
        comm_spec: express the communication pattern and the process groups to execute the communication action.
        comm_type: describes the sequential order of a communication action and a computation action.
        arg_index: record the location of tensor which join the communication, we cannot use name of node or op_data
            at runtime, because the args of node may be changed by graph transform passes. Defaults to -1.
        key_for_kwarg: NOTE(review): presumably the keyword-argument name used to locate the tensor when it is
            passed as a kwarg rather than positionally — confirm against callers. Defaults to None.
    """
    comm_spec: CommSpec = None
    comm_type: CommType = None
    arg_index: int = -1
    # was annotated with the builtin function `any`; `typing.Any` is the intended type
    key_for_kwarg: Any = None
|
2022-10-20 10:48:18 +00:00
|
|
|
|
|
|
|
|
2022-09-20 03:20:54 +00:00
|
|
|
@dataclass
class ShardingStrategy:
    """
    ShardingStrategy is a dataclass to store the meta information on tensor sharding for a node.

    Args:
        name (str): express the sharding strategies in string, such as 'S0S1 = S0R x RS1'.
        sharding_specs (Dict[OperationData, Union[ShardingSpec, Tuple[ShardingSpec]]]): the sharding spec of each
            operation data involved in this strategy; entries are grouped by ``OperationData.type``
            (see ``argument_sharding_specs``, ``param_sharding_specs``, ``output_sharding_specs``). (default to None)
        compute_cost (TrainCycleItem): Computation cost to complete this strategy. (default to None)
        communication_cost (TrainCycleItem): Communication cost to complete this strategy. (default to None)
        memory_cost (TrainCycleItem): Memory cost of the output node using this strategy. (default to None)
        communication_actions (Dict[OperationData, CommAction]): communication actions attached to the
            operation data of this strategy. (default to None)
        resharding_costs (Dict[Node, List[TrainCycleItem]]): resharding costs keyed by fx Node. (default to None)
    """
    name: str
    sharding_specs: Dict[OperationData, Union[ShardingSpec, Tuple[ShardingSpec]]] = None
    compute_cost: TrainCycleItem = None
    communication_cost: TrainCycleItem = None
    memory_cost: TrainCycleItem = None
    communication_actions: Dict[OperationData, CommAction] = None
    resharding_costs: Dict[Node, List[TrainCycleItem]] = None

    @property
    def input_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
        """Sharding specs of all ARG and PARAM operation data."""
        specs = {}
        specs.update(self._get_sharding_spec(OperationDataType.ARG))
        specs.update(self._get_sharding_spec(OperationDataType.PARAM))
        return specs

    @property
    def argument_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
        """Sharding specs of the ARG operation data."""
        return self._get_sharding_spec(OperationDataType.ARG)

    @property
    def param_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
        """Sharding specs of the PARAM operation data."""
        return self._get_sharding_spec(OperationDataType.PARAM)

    @property
    def output_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
        """Sharding specs of the OUTPUT operation data."""
        return self._get_sharding_spec(OperationDataType.OUTPUT)

    def _get_sharding_spec(self, operation_data_type: OperationDataType):
        """Return the subset of ``sharding_specs`` whose key has the given type (empty if none are set)."""
        # sharding_specs defaults to None; treat that as "no specs" instead of raising AttributeError
        if not self.sharding_specs:
            return {}
        return {k: v for k, v in self.sharding_specs.items() if k.type == operation_data_type}

    def get_op_data_by_name(self, name: str):
        """
        Return the OperationData key of ``sharding_specs`` with the given name.

        Raises:
            KeyError: if no such OperationData exists (including when ``sharding_specs`` is None).
        """
        for op_data in (self.sharding_specs or {}):
            if op_data.name == name:
                return op_data
        raise KeyError(f"Could not find the OperationData with name {name}")

    def get_sharding_spec_by_name(self, name: str):
        """
        Return the sharding spec stored for the OperationData with the given name.

        Raises:
            KeyError: if no such OperationData exists (including when ``sharding_specs`` is None).
        """
        for op_data, sharding_spec in (self.sharding_specs or {}).items():
            if op_data.name == name:
                return sharding_spec
        raise KeyError(f"Could not find the ShardingSpec for OperationData with name {name}")

    def clone(self):
        """
        Return a copy of this strategy.

        Dict values and cost items are deep-copied, while dict keys (OperationData / Node objects)
        are shared with the original.
        """

        def _deepcopy_dict_vals(data: Dict):
            # `is None` (not truthiness) so an empty dict is cloned as an empty dict, not turned into None
            if data is None:
                return None
            return {k: deepcopy(v) for k, v in data.items()}

        sharding_specs = _deepcopy_dict_vals(self.sharding_specs)
        communication_actions = _deepcopy_dict_vals(self.communication_actions)
        resharding_costs = _deepcopy_dict_vals(self.resharding_costs)
        compute_cost = deepcopy(self.compute_cost)
        communication_cost = deepcopy(self.communication_cost)
        memory_cost = deepcopy(self.memory_cost)

        return ShardingStrategy(name=self.name,
                                sharding_specs=sharding_specs,
                                compute_cost=compute_cost,
                                communication_cost=communication_cost,
                                memory_cost=memory_cost,
                                communication_actions=communication_actions,
                                resharding_costs=resharding_costs)
|
2022-10-12 03:16:18 +00:00
|
|
|
|
2022-09-20 06:17:21 +00:00
|
|
|
|
2022-08-23 06:23:08 +00:00
|
|
|
class StrategiesVector(list):
    '''
    A list holding every candidate sharding strategy of one fx graph node.

    Besides the strategies themselves (the list items), it remembers the node and
    the node's neighbours in the graph.

    Argument:
        node (Node): the fx node whose sharding strategies are collected in this vector.
    '''

    def __init__(self, node: Node):
        super().__init__()
        self.node = node
        # record the neighbouring nodes of the graph
        # TODO: placeholder input nodes
        inputs = list(node._input_nodes.keys())
        # an 'output' node keeps only its first input as predecessor
        self.predecessor_nodes = inputs[:1] if node.op == 'output' else inputs
        self.successor_nodes = list(node.users.keys())

    def check_merge(self):
        '''
        Decide whether this node can be merged into its source node(s).

        Returns:
            bool: True when the node is an element-wise module/function, a broadcast function with a
            single predecessor node, or a reshape function; False otherwise.
        '''
        op_kind = self.node.op

        if op_kind == 'call_module':
            owner = self.node.graph.owning_module
            module_cls = type(owner.get_submodule(self.node.target))
            # element-wise modules keep the output sharding spec identical to the input one
            return module_cls in ELEMENTWISE_MODULE_OP

        if op_kind == 'call_function':
            fn = self.node.target
            # element-wise functions keep the output sharding spec identical to the input one
            if fn in ELEMENTWISE_FUNC_OP:
                return True
            # a broadcast op whose rhs is a scalar falls back to the element-wise case
            if fn in BCAST_FUNC_OP and len(self.predecessor_nodes) == 1:
                return True
            # the output sharding spec of a reshape op is always fully replicated
            return fn in RESHAPE_FUNC_OP

        return False
|