mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] added new linear module handler (#1616)
parent 170fa81095
commit d925122020
colossalai/auto_parallel/solver/op_handler/dot_handler_v2.py (new file)
@@ -0,0 +1,139 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .node_handler import ModuleHandler, NodeHandler
from ..sharding_strategy import ShardingStrategy_V2, StrategyGenerator_V2, OperationDataType, OperationData
from typing import List, Dict
from .registry import operator_registry

__all__ = ['LinearModuleHandler', 'LinearFunctionHandler']


class DotProductStrategyGenerator(StrategyGenerator_V2):
    """TODO: to be implemented"""
    pass


class MatVecStrategyGenerator(StrategyGenerator_V2):
    """TODO: to be implemented"""
    pass


class LinearProjectionStrategyGenerator(StrategyGenerator_V2):

    def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
        """TODO: to be implemented"""
        pass

    def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
        """TODO: to be implemented"""
        pass

    def generate(self, operand_mapping: Dict[str, OperationData]) -> List[ShardingStrategy_V2]:
        """TODO: to be implemented"""
        pass

    def validate(self, *args, **kwargs) -> bool:
        """TODO: to be implemented"""
        pass


class BatchedMatMulStrategyGenerator(StrategyGenerator_V2):
    """TODO: to be implemented"""
    pass


@operator_registry.register(torch.nn.Linear)
class LinearModuleHandler(ModuleHandler):
    """
    A LinearModuleHandler which deals with the sharding strategies for the nn.Linear module.
    """

    def register_strategy_generator(self) -> List[StrategyGenerator_V2]:
        generators = []
        generators.append(LinearProjectionStrategyGenerator(self.device_mesh))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # use the transposed shape for the strategies
        # the strategies will be transformed back to the original shape in self.post_process
        physical_input_operand = OperationData(name=str(self.node.args[0]),
                                               type=OperationDataType.ARG,
                                               data=self.node.args[0]._meta_data)
        physical_other_operand = OperationData(name="weight",
                                               type=OperationDataType.PARAM,
                                               data=self.named_parameters['weight'],
                                               logical_shape=self.named_parameters['weight'].shape[::-1])
        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)

        mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}

        # nn.Linear may be built with bias=False, in which case no bias parameter is registered
        if self.named_parameters.get('bias', None) is not None:
            physical_bias_operand = OperationData(name="bias",
                                                  type=OperationDataType.PARAM,
                                                  data=self.named_parameters['bias'])
            mapping['bias'] = physical_bias_operand
        return mapping

    def post_process(self, strategy: ShardingStrategy_V2):
        """
        Convert the sharding spec of the weight parameter back to its original shape.
        """
        for op_data, sharding_spec in strategy.input_sharding_specs.items():
            if op_data.name == "weight":
                assert op_data.logical_shape != op_data.data.shape
                dim_partition_dict = sharding_spec.dim_partition_dict
                # switch the first and the last dim of the linear module weight
                dim_partition_dict[0], dim_partition_dict[-1] = dim_partition_dict[-1], dim_partition_dict[0]

                # re-init the sharding spec
                sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
        return strategy


@operator_registry.register(F.linear)
class LinearFunctionHandler(NodeHandler):
    """
    A LinearFunctionHandler which deals with the sharding strategies for the F.linear function.
    """

    def register_strategy_generator(self) -> List[StrategyGenerator_V2]:
        generators = []
        generators.append(LinearProjectionStrategyGenerator(self.device_mesh))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # use the transposed shape for the strategies
        # the strategies will be transformed back to the original shape in self.post_process
        physical_input_operand = OperationData(name=str(self.node.args[0]),
                                               type=OperationDataType.ARG,
                                               data=self.node.args[0]._meta_data)
        physical_other_operand = OperationData(name=str(self.node.args[1]),
                                               type=OperationDataType.ARG,
                                               data=self.node.args[1]._meta_data,
                                               logical_shape=self.node.args[1]._meta_data.shape[::-1])
        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)

        mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}

        # F.linear may be called without a bias argument
        if len(self.node.args) == 3 and self.node.args[2] is not None:
            physical_bias_operand = OperationData(name=str(self.node.args[2]),
                                                  type=OperationDataType.ARG,
                                                  data=self.node.args[2]._meta_data)
            mapping['bias'] = physical_bias_operand
        return mapping

    def post_process(self, strategy: ShardingStrategy_V2):
        """
        Convert the sharding spec of the weight operand back to its original shape.
        """
        for op_data, sharding_spec in strategy.input_sharding_specs.items():
            if op_data.name == str(self.node.args[1]):
                assert op_data.logical_shape != op_data.data.shape
                dim_partition_dict = sharding_spec.dim_partition_dict
                # switch the first and the last dim of the weight
                dim_partition_dict[0], dim_partition_dict[-1] = dim_partition_dict[-1], dim_partition_dict[0]

                # re-init the sharding spec
                sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
        return strategy
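The transposed-weight convention used by both handlers above can be summarised with a short illustration (not part of the commit): nn.Linear stores its weight as (out_features, in_features), the strategies reason about the logical (in_features, out_features) shape, and post_process swaps the partitions of the first and last dimensions back. A minimal sketch, assuming a hypothetical dim_partition_dict with both dims sharded:

import torch

weight = torch.empty(20, 10, device='meta')    # physical shape of nn.Linear(10, 20).weight
logical_shape = weight.shape[::-1]             # (10, 20), the shape the strategies reason about
print(tuple(logical_shape))                    # (10, 20)

# hypothetical partition produced against the logical (transposed) shape:
# logical dim 0 sharded over mesh dim 0, logical last dim sharded over mesh dim 1
dim_partition_dict = {0: [0], -1: [1]}

# the same swap performed in post_process to map it back to the physical layout
dim_partition_dict[0], dim_partition_dict[-1] = dim_partition_dict[-1], dim_partition_dict[0]
print(dim_partition_dict)                      # {0: [1], -1: [0]}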
colossalai/auto_parallel/solver/op_handler/node_handler.py
@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
 from torch.fx.node import Node
 from colossalai.device.device_mesh import DeviceMesh
 from typing import Dict, List
-from ..sharding_strategy import StrategiesVector, Operand, StrategyGenerator_V2
+from ..sharding_strategy import ShardingStrategy, ShardingStrategy_V2, StrategiesVector, OperationData, StrategyGenerator_V2


 class NodeHandler(ABC):

@@ -36,8 +36,15 @@ class NodeHandler(ABC):
         for generator in self.strategy_generator:
             strategies = generator.generate(operand_mapping)
             self.strategies_vector.extend(strategies)

+        # post-process every generated strategy in place,
+        # e.g. to restore the sharding spec of transposed weights
+        for index, strategy in enumerate(self.strategies_vector):
+            self.strategies_vector[index] = self.post_process(strategy)
         return self.strategies_vector

+    def post_process(self, strategy: ShardingStrategy_V2):
+        # transform the generated strategy,
+        # e.g. to process the sharding strategy for the transposed weights
+        return strategy
+
     @abstractmethod
     def register_strategy_generator(self) -> List[StrategyGenerator_V2]:
         """

@@ -46,21 +53,40 @@ class NodeHandler(ABC):
         pass

     @abstractmethod
-    def get_operand_mapping(self) -> Dict[str, Operand]:
+    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
         """
-        Returns the mapping between the logical operand name to its physical operands.
-        A logical operand is defined by the strategy generator, for example, a matrix multiplication
-        operation has two operands "input" and "other". For a nn.Linear module, the physical operand for "input" is
-        the module input and the physical operand for "other" is the module weight.
+        Returns the mapping from the logical operation data name to its physical data.
+        A logical operation data is a piece of data associated with an operation, which can be an input or an output.
+        It is defined by the strategy generator, for example, a matrix multiplication operation has two operands
+        "input" and "other" and one result "output". For a nn.Linear module, the physical operand for "input" is
+        the module input, the physical operand for "other" is the module weight, and the physical result for "output"
+        is the module output.
         Note that the operand name is specified by the StrategyGenerator object.

         For example:

             # for a linear layer
             mapping = {
-                "input": Operand(name=str(self.node.args[0]), type=OperandType.ARG),
-                "other": Operand(name="weight", type=OperandType.PARAM),
-                "bias": Operand(name="bias", type=OperandType.PARAM)
+                "input": OperationData(name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data),
+                "other": OperationData(name="weight", type=OperationDataType.PARAM, data=self.named_parameters['weight']),
+                "bias": OperationData(name="bias", type=OperationDataType.PARAM, data=self.named_parameters['bias']),
+                "output": OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data),
             }
         """
         pass


+class ModuleHandler(NodeHandler):
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+
+        # set attributes to access module parameters for convenience
+        assert self.node.graph.owning_module is not None, \
+            'The graph is not associated with a module, please make sure it can be used to instantiate a GraphModule object.'
+        module = self.node.graph.owning_module.get_submodule(self.node.target)
+        # convert the named parameters to a dict for easy access by name
+        named_parameters = dict(module.named_parameters(recurse=False))
+        self.module = module
+        self.named_parameters = named_parameters
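As a quick illustration of what the new ModuleHandler.__init__ collects (a minimal sketch using a plain nn.Linear instead of a traced node): the module's own parameters are gathered into a dict so that subclasses such as LinearModuleHandler can look them up by name.

import torch.nn as nn

module = nn.Linear(10, 20)
named_parameters = dict(module.named_parameters(recurse=False))
print(sorted(named_parameters))            # ['bias', 'weight']
print(named_parameters['weight'].shape)    # torch.Size([20, 10])
print(named_parameters['bias'].shape)      # torch.Size([20])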
colossalai/auto_parallel/solver/sharding_strategy.py
@@ -1,6 +1,10 @@
 from dataclasses import dataclass
 from abc import ABC, abstractmethod
 from enum import Enum
+import operator
+import torch
+from functools import reduce
+
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.sharding_spec import ShardingSpec
 from typing import Dict, List, Union, Tuple, Any
@@ -40,18 +44,35 @@ class ShardingStrategy:
     input_shardings: List[ShardingSpec] = None


-class OperandType(Enum):
+class OperationDataType(Enum):
     """
-    An operand can come from the argument list of an operator or the parameter list of a module.
+    The data of an operation can come from the argument list of an operator, the parameter list of a module, or the output of the operator.
     """
     ARG = 0
     PARAM = 1
+    OUTPUT = 2


 @dataclass
-class Operand:
+class OperationData:
+    """
+    OperationData is the data related to an operator; it can be an operand or the output.
+
+    Args:
+        name (str): the name of the operation-related data
+        type (OperationDataType): the type of the operation data
+        data (torch.Tensor): the value of this data, usually a meta tensor
+        logical_shape (Tuple[int]): the logical shape of the data, which can differ from its actual shape in memory
+    """
     name: str
-    type: OperandType
+    type: OperationDataType
+    data: torch.Tensor
+    logical_shape: Tuple[int] = None
+
+    def __post_init__(self):
+        # if no logical shape is specified, use the data shape as the logical shape
+        if self.logical_shape is None:
+            self.logical_shape = self.data.shape


 @dataclass
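A small usage sketch for the new OperationData dataclass (assuming the module is importable as colossalai.auto_parallel.solver.sharding_strategy, as the test file below does): when logical_shape is omitted, __post_init__ falls back to the shape of data.

import torch
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType

weight = torch.empty(20, 10, device='meta')
# transposed logical shape, as used by the linear handlers
op_data = OperationData(name='weight',
                        type=OperationDataType.PARAM,
                        data=weight,
                        logical_shape=weight.shape[::-1])
print(tuple(op_data.logical_shape))    # (10, 20)

# without an explicit logical_shape, __post_init__ uses data.shape
output = OperationData(name='output', type=OperationDataType.OUTPUT, data=torch.empty(4, 20, device='meta'))
print(output.logical_shape)            # torch.Size([4, 20])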
@@ -69,6 +90,20 @@ class TrainCycleItem:
     total: Any


+class CommunicationType(Enum):
+    FWD_ALL_REDUCE = 0
+    BWD_ALL_REDUCE = 1
+
+
+@dataclass
+class CommunicationAction:
+    """
+    The communication action to perform for an operation data, given by its type and the device mesh dimension along which it is applied.
+    """
+    type: CommunicationType
+    mesh_dim: int
+
+
 @dataclass
 class ShardingStrategy_V2:
     """
@@ -86,12 +121,35 @@ class ShardingStrategy_V2:
     strategy. (default to None)
     """
     name: str
-    output_sharding_spec: ShardingSpec
+    sharding_specs: Dict[OperationData, ShardingSpec] = None
     compute_cost: TrainCycleItem = None
     communication_cost: TrainCycleItem = None
     memory_cost: TrainCycleItem = None
-    input_sharding_specs: Dict[Operand, ShardingSpec] = None
-    input_resharding_costs: Dict[Operand, List[float]] = None
+    input_resharding_costs: Dict[OperationData, List[float]] = None
+    communication_actions: Dict[OperationData, List[CommunicationAction]] = None
+
+    @property
+    def input_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
+        specs = {}
+        specs.update(self._get_sharding_spec(OperationDataType.ARG))
+        specs.update(self._get_sharding_spec(OperationDataType.PARAM))
+        return specs
+
+    @property
+    def argument_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
+        return self._get_sharding_spec(OperationDataType.ARG)
+
+    @property
+    def param_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
+        return self._get_sharding_spec(OperationDataType.PARAM)
+
+    @property
+    def output_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
+        return self._get_sharding_spec(OperationDataType.OUTPUT)
+
+    def _get_sharding_spec(self, operation_data_type: OperationDataType):
+        specs = {k: v for k, v in self.sharding_specs.items() if k.type == operation_data_type}
+        return specs


 class StrategyGenerator_V2(ABC):
@@ -104,9 +162,57 @@ class StrategyGenerator_V2:
     def __init__(self, device_mesh: DeviceMesh):
         self.device_mesh = device_mesh

-    @abstractmethod
-    def generate(self, operand_mapping: Dict[str, Operand]) -> List[ShardingStrategy_V2]:
+    def update_communication_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
+        """
+        Compute the communication cost involved in the forward and backward iteration.
+        """
+
+        comm_cost = TrainCycleItem(fwd=0, bwd=0)
+
+        def _compute_and_add(data: OperationData, action: CommunicationAction):
+            sharded_shape = strategy.sharding_specs[data].get_sharded_shape_per_device()
+            dtype = data.data.dtype
+            size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
+            num_bytes = size_per_elem_bytes * reduce(operator.mul, sharded_shape)
+            cost = self.device_mesh.all_reduce_cost(num_bytes=num_bytes, mesh_dim=action.mesh_dim)
+
+            # add the cost to the forward or the backward bucket depending on the action type
+            if action.type == CommunicationType.FWD_ALL_REDUCE:
+                comm_cost.fwd += cost
+            elif action.type == CommunicationType.BWD_ALL_REDUCE:
+                comm_cost.bwd += cost
+            else:
+                raise ValueError(f"Found unknown CommunicationType {action.type}")
+
+        # check if communication actions exist
+        # if so, loop over each action and compute its cost
+        if strategy.communication_actions is not None:
+            for operand, actions in strategy.communication_actions.items():
+                for action in actions:
+                    _compute_and_add(operand, action)
+
+        # update the communication cost attribute in-place
+        strategy.communication_cost = comm_cost
+        return strategy
+
+    @abstractmethod
+    def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
+        """
+        Customize this method to compute the computation flops.
+        """
+        pass
+
+    @abstractmethod
+    def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
+        """
+        Customize this method to compute the memory cost in bytes.
+        """
+        pass
+
+    @abstractmethod
+    def generate(self, operand_mapping: Dict[str, OperationData]) -> List[ShardingStrategy_V2]:
         """
         Generate all possible sharding strategies for this operation.
         """
         pass
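The byte count handed to device_mesh.all_reduce_cost in update_communication_cost is just the element size multiplied by the number of elements in the shard held by each device; a tiny worked example with a made-up shard shape and dtype:

import operator
from functools import reduce
import torch

sharded_shape = (4, 10)    # hypothetical shape of the shard held by each device
size_per_elem_bytes = torch.tensor([], dtype=torch.float32).element_size()    # 4 bytes for fp32
num_bytes = size_per_elem_bytes * reduce(operator.mul, sharded_shape)
print(num_bytes)           # 160 bytes would be passed to device_mesh.all_reduce_cost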
(new test file for LinearModuleHandler and LinearFunctionHandler)
@@ -0,0 +1,104 @@
from colossalai.fx.tracer.meta_patch.patched_module import linear
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.op_handler.dot_handler_v2 import LinearModuleHandler, LinearFunctionHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh


def test_linear_module_handler():
    model = nn.Sequential(nn.Linear(10, 20).to('meta'))
    tracer = ColoTracer()
    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 10).to('meta')})
    gm = ColoGraphModule(model, graph)

    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    linear_mod_node = list(graph.nodes)[1]
    strategies_vector = StrategiesVector(linear_mod_node)

    # build handler
    handler = LinearModuleHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)

    # check operation data mapping
    mapping = handler.get_operation_data_mapping()

    for name, op_data in mapping.items():
        op_data: OperationData
        # make sure they have valid values
        assert op_data.logical_shape is not None
        assert op_data.data is not None

    assert mapping['input'].name == "input_1"
    assert mapping['input'].data.is_meta
    assert mapping['input'].data.shape == torch.Size([4, 10])
    assert mapping['input'].type == OperationDataType.ARG
    assert mapping['input'].logical_shape == torch.Size([4, 10])

    assert mapping['other'].name == "weight"
    assert mapping['other'].data.is_meta
    assert mapping['other'].data.shape == torch.Size([20, 10])
    assert mapping['other'].type == OperationDataType.PARAM
    assert mapping['other'].logical_shape == torch.Size([10, 20])

    assert mapping['bias'].name == "bias"
    assert mapping['bias'].data.is_meta
    assert mapping['bias'].data.shape == torch.Size([20])
    assert mapping['bias'].type == OperationDataType.PARAM
    assert mapping['bias'].logical_shape == torch.Size([20])

    assert mapping['output'].name == "_0"
    assert mapping['output'].data.is_meta
    assert mapping['output'].data.shape == torch.Size([4, 20])
    assert mapping['output'].type == OperationDataType.OUTPUT


def test_linear_function_handler():
    model = nn.Linear(10, 20).to('meta')
    tracer = ColoTracer()
    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 10).to('meta')})
    gm = ColoGraphModule(model, graph)

    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    linear_func_node = list(graph.nodes)[3]
    strategies_vector = StrategiesVector(linear_func_node)

    # build handler
    handler = LinearFunctionHandler(node=linear_func_node, device_mesh=device_mesh, strategies_vector=strategies_vector)

    # check operation data mapping
    mapping = handler.get_operation_data_mapping()

    assert mapping['input'].name == "input_1"
    assert mapping['input'].data.is_meta
    assert mapping['input'].data.shape == torch.Size([4, 10])
    assert mapping['input'].type == OperationDataType.ARG
    assert mapping['input'].logical_shape == torch.Size([4, 10])

    assert mapping['other'].name == "weight"
    assert mapping['other'].data.is_meta
    assert mapping['other'].data.shape == torch.Size([20, 10])
    assert mapping['other'].type == OperationDataType.ARG
    assert mapping['other'].logical_shape == torch.Size([10, 20])

    assert mapping['bias'].name == "bias"
    assert mapping['bias'].data.is_meta
    assert mapping['bias'].data.shape == torch.Size([20])
    assert mapping['bias'].type == OperationDataType.ARG
    assert mapping['bias'].logical_shape == torch.Size([20])

    assert mapping['output'].name == "linear"
    assert mapping['output'].data.is_meta
    assert mapping['output'].data.shape == torch.Size([4, 20])
    assert mapping['output'].type == OperationDataType.OUTPUT


if __name__ == '__main__':
    test_linear_module_handler()
    test_linear_function_handler()