2022-08-19 07:51:54 +00:00
|
|
|
from dataclasses import dataclass
|
2022-09-20 06:17:21 +00:00
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
from enum import Enum
|
2022-09-21 04:23:21 +00:00
|
|
|
import operator
|
|
|
|
import torch
|
|
|
|
from functools import reduce
|
|
|
|
|
2022-09-20 06:17:21 +00:00
|
|
|
from colossalai.device.device_mesh import DeviceMesh
|
2022-08-19 07:51:54 +00:00
|
|
|
from colossalai.tensor.sharding_spec import ShardingSpec
|
2022-09-20 03:20:54 +00:00
|
|
|
from typing import Dict, List, Union, Tuple, Any
|
2022-08-23 06:23:08 +00:00
|
|
|
from torch.fx.node import Node
|
2022-08-30 08:32:09 +00:00
|
|
|
from .constants import *
|
2022-08-23 06:23:08 +00:00
|
|
|
|
|
|
|
# Names exported by ``from <this module> import *``.
__all__ = [
    'ShardingStrategy',
    'StrategiesVector',
]
|
2022-08-19 07:51:54 +00:00
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class ShardingStrategy:
    '''
    ShardingStrategy is a structure containing the sharding strategies of the inputs and output of a node,
    together with the cost information used by the solver.

    Argument:
        name (str): expresses the sharding strategy as a string, such as 'S0S1 = S0R x RS1'.
        output_sharding_spec (Union[ShardingSpec, Tuple[ShardingSpec]]): ShardingSpec of the output node, or a
            tuple of ShardingSpecs when the node produces several outputs.
        compute_cost (float): computation cost to complete this strategy. (default to 0)
        communication_cost (float): communication cost to complete this strategy. (default to 0)
        memory_cost (float): memory cost of the output node using this strategy. (default to 0)
        resharding_costs (Dict[Node, List[float]]): resharding_costs[node][j] is the cost of transforming the
            j-th strategy in the strategies_vector of the input ``node`` into the sharding spec wanted by this
            strategy. (default to None)
        input_shardings (List[ShardingSpec]): the ShardingSpecs of the input nodes. (default to None)
    '''

    name: str
    # TODO: output of fx node, such as torch.var_mean, could be a tuple, so we cannot simply suppose it is a tensor.
    output_sharding_spec: Union[ShardingSpec, Tuple[ShardingSpec]]
    compute_cost: float = 0.
    communication_cost: float = 0.
    memory_cost: float = 0.
    resharding_costs: Dict[Node, List[float]] = None
    # sometimes the input node could be a tuple of nodes, but most of op won't accept tuple of node as input.
    # Therefore, we could process them at the specific op(operator.getitem)
    input_shardings: List[ShardingSpec] = None
|
2022-08-19 06:57:23 +00:00
|
|
|
|
|
|
|
|
2022-09-21 04:23:21 +00:00
|
|
|
class OperationDataType(Enum):
    """
    Categorizes where a piece of operation-related data comes from.

    The data is either taken from the argument list of an operator, taken from the parameter
    list of a module, or is the output of the operation.
    """
    # taken from the operator's argument list
    ARG = 0
    # taken from the module's parameter list
    PARAM = 1
    # produced by the operation
    OUTPUT = 2
|
2022-09-20 06:17:21 +00:00
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class OperationData:
    """
    OperationData is the data related to an operator, the data can be the operand or the output.

    Args:
        name (str): the name of the operation-related data
        type (OperationDataType): the type of the operation data
        data (torch.Tensor): the value for this data, usually it is a meta tensor.
        logical_shape (Tuple[int]): the logical shape of the data, it can be different from its actual shape
            in memory. When omitted it defaults to ``data.shape``.
    """
    name: str
    type: OperationDataType
    data: torch.Tensor
    logical_shape: Tuple[int] = None

    def __post_init__(self):
        # if no logical shape is specified, use the data shape as the logical shape
        if self.logical_shape is None:
            self.logical_shape = self.data.shape

    def __hash__(self) -> int:
        # @dataclass(eq=True) without frozen=True sets __hash__ to None, which would make instances
        # unusable as dict keys (ShardingStrategy_V2 maps OperationData -> spec/cost/actions).
        # Hash on the name only: objects that compare equal always share the same name.
        return hash(self.name)
|
2022-09-20 06:17:21 +00:00
|
|
|
|
|
|
|
|
2022-09-20 03:20:54 +00:00
|
|
|
@dataclass
class TrainCycleItem:
    """
    TrainCycleItem is a dataclass to store the items which have different values for the forward and backward pass
    in a training iteration.

    Args:
        fwd (Any): the item for the forward pass
        bwd (Any): the item for the backward pass
        total (Any): the item aggregated over the whole training iteration. It is a required field and is
            NOT derived automatically from ``fwd`` and ``bwd``; the caller must supply it.
    """
    fwd: Any
    bwd: Any
    total: Any
|
|
|
|
|
|
|
|
|
2022-09-21 04:23:21 +00:00
|
|
|
class CommunicationType(Enum):
    """
    Kind of collective communication a sharding strategy may require.

    The member names indicate the pass (forward or backward) in which the all-reduce is issued.
    """
    FWD_ALL_REDUCE = 0    # all-reduce in the forward pass
    BWD_ALL_REDUCE = 1    # all-reduce in the backward pass
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class CommunicationAction:
    """
    A single communication step that a sharding strategy requires for one piece of operation data.

    Args:
        type (CommunicationType): the kind of collective communication to perform.
        mesh_dim (int): the device mesh dimension along which the communication is performed.
    """
    type: CommunicationType
    mesh_dim: int
|
|
|
|
|
|
|
|
|
2022-09-20 03:20:54 +00:00
|
|
|
@dataclass
class ShardingStrategy_V2:
    """
    ShardingStrategy is a dataclass to store the meta information on tensor sharding for a node.

    Args:
        name (str): express the sharding strategies in string, such as 'S0S1 = S0R x RS1'.
        sharding_specs (Dict[OperationData, ShardingSpec]): the sharding spec of every operation data of the node
            (arguments, parameters and outputs, distinguished by ``OperationData.type``). (default to None)
        compute_cost (TrainCycleItem): Computation cost to complete this strategy. (default to None)
        communication_cost (TrainCycleItem): Communication cost to complete this strategy. (default to None)
        memory_cost (TrainCycleItem): Memory cost of the output node using this strategy. (default to None)
        input_resharding_costs (Dict[OperationData, List[float]]): input_resharding_costs[data][j] is the cost of
            transforming the j-th strategy in the strategies_vector of the input ``data`` into the sharding spec
            wanted in this strategy. (default to None)
        communication_actions (Dict[OperationData, List[CommunicationAction]]): the communication actions required
            for each operation data under this strategy. (default to None)
    """
    name: str
    sharding_specs: Dict[OperationData, ShardingSpec] = None
    compute_cost: TrainCycleItem = None
    communication_cost: TrainCycleItem = None
    memory_cost: TrainCycleItem = None
    input_resharding_costs: Dict[OperationData, List[float]] = None
    communication_actions: Dict[OperationData, List[CommunicationAction]] = None

    @property
    def input_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
        """Sharding specs of the arguments followed by those of the parameters."""
        specs = {}
        specs.update(self._get_sharding_spec(OperationDataType.ARG))
        specs.update(self._get_sharding_spec(OperationDataType.PARAM))
        return specs

    @property
    def argument_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
        """Sharding specs of the data whose type is ``OperationDataType.ARG``."""
        return self._get_sharding_spec(OperationDataType.ARG)

    @property
    def param_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
        """Sharding specs of the data whose type is ``OperationDataType.PARAM``."""
        return self._get_sharding_spec(OperationDataType.PARAM)

    @property
    def output_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
        """Sharding specs of the data whose type is ``OperationDataType.OUTPUT``."""
        return self._get_sharding_spec(OperationDataType.OUTPUT)

    def _get_sharding_spec(self, operation_data_type: OperationDataType) -> Dict[OperationData, ShardingSpec]:
        """
        Return the entries of ``sharding_specs`` whose data type equals ``operation_data_type``.

        ``sharding_specs`` defaults to ``None``; that case is treated as "no specs" and yields an
        empty dict instead of raising ``AttributeError``.
        """
        if self.sharding_specs is None:
            return {}
        return {data: spec for data, spec in self.sharding_specs.items() if data.type == operation_data_type}
|
2022-09-20 06:17:21 +00:00
|
|
|
|
|
|
|
|
|
|
|
class StrategyGenerator_V2(ABC):
    """
    StrategyGenerator is used to generate the same group of sharding strategies.

    TODO: remove the original strategy_generator.py after refactoring
    """

    def __init__(self, device_mesh: DeviceMesh):
        """
        Args:
            device_mesh (DeviceMesh): the device mesh the strategies target; its ``all_reduce_cost`` is used
                to estimate communication costs.
        """
        self.device_mesh = device_mesh

    def update_communication_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
        """
        Compute the communication cost involved in the forward and backward iteration.

        Each ``CommunicationAction`` in ``strategy.communication_actions`` is priced as an all-reduce of
        its operand's per-device sharded data along ``action.mesh_dim``. The cost goes to the forward
        bucket for ``FWD_ALL_REDUCE`` and to the backward bucket for ``BWD_ALL_REDUCE``. ``total`` is
        their sum.

        Args:
            strategy (ShardingStrategy_V2): the strategy to price; ``communication_cost`` is overwritten in place.

        Returns:
            ShardingStrategy_V2: the same ``strategy`` object.

        Raises:
            ValueError: if an action carries an unknown ``CommunicationType``.
        """
        # TrainCycleItem requires ``total`` as well; it is filled in once all actions are accumulated
        comm_cost = TrainCycleItem(fwd=0, bwd=0, total=0)

        def _compute_and_add(data: OperationData, action: CommunicationAction):
            sharded_shape = strategy.sharding_specs[data].get_sharded_shape_per_device()
            # use the dtype of the operand this action belongs to, not a variable from the enclosing loop
            dtype = data.data.dtype
            size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
            # the initializer 1 keeps 0-dim (scalar) shapes from raising inside reduce
            num_bytes = size_per_elem_bytes * reduce(operator.mul, sharded_shape, 1)
            cost = self.device_mesh.all_reduce_cost(num_bytes=num_bytes, mesh_dim=action.mesh_dim)

            if action.type == CommunicationType.FWD_ALL_REDUCE:
                comm_cost.fwd += cost
            elif action.type == CommunicationType.BWD_ALL_REDUCE:
                comm_cost.bwd += cost
            else:
                raise ValueError(f"Found unknown CommunicationType {action.type}")

        # check if communication action exists
        # if so, loop over each action and compute the cost of each action
        if strategy.communication_actions is not None:
            for operand, actions in strategy.communication_actions.items():
                for action in actions:
                    _compute_and_add(operand, action)

        comm_cost.total = comm_cost.fwd + comm_cost.bwd

        # update the communication cost attribute in-place
        strategy.communication_cost = comm_cost
        return strategy

    @abstractmethod
    def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
        """
        Customize this method to compute the computation flops.
        """
        pass

    @abstractmethod
    def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
        """
        Customize this method to compute the memory cost in bytes.
        """
        pass

    @abstractmethod
    def generate(self, operand_mapping: Dict[str, OperationData]) -> List[ShardingStrategy_V2]:
        """
        Generate all possible sharding strategies for this operation.
        """
        pass

    @abstractmethod
    def validate(self, *args, **kwargs) -> bool:
        """
        Validate if the operands are of desired shape.
        If True, means this generator can be used for the current operation.
        """
        pass
|
2022-09-20 03:20:54 +00:00
|
|
|
|
|
|
|
|
2022-08-23 06:23:08 +00:00
|
|
|
class StrategiesVector(list):
    '''
    A list that holds every candidate sharding strategy of one fx graph node.

    Each node in the fx graph owns one StrategiesVector. The strategies themselves are stored as the
    list items; the vector also records the node plus its predecessor and successor nodes.

    Argument:
        node (Node): the node whose sharding strategies are collected in this vector.
    '''

    def __init__(self, node: Node):
        super().__init__()
        self.node = node
        # producers feeding this node and consumers of its output
        # TODO: placeholder input nodes
        self.predecessor_nodes = list(node._input_nodes)
        self.successor_nodes = list(node.users)

    def check_merge(self):
        '''
        Decide whether this node can be merged into its source node(s).

        Returns:
            bool: True for element-wise modules and functions, broadcast functions with a single
            predecessor, and reshape functions; False for everything else.
        '''
        node = self.node

        if node.op == 'call_module':
            # element-wise modules keep the input sharding spec on the output, so they can be merged
            submodule = node.graph.owning_module.get_submodule(node.target)
            return type(submodule) in ELEMENTWISE_MODULE_OP

        if node.op == 'call_function':
            fn = node.target
            # element-wise function: the output sharding spec always equals the input one
            if fn in ELEMENTWISE_FUNC_OP:
                return True
            # broadcast op with a scalar rhs falls back to the element-wise case
            if fn in BCAST_FUNC_OP and len(self.predecessor_nodes) == 1:
                return True
            # reshape op: its output sharding spec is always fully replicated
            if fn in RESHAPE_FUNC_OP:
                return True

        return False
|