[autoparallel] added new linear module handler (#1616)

pull/1625/head
Frank Lee 2022-09-21 12:23:21 +08:00 committed by GitHub
parent 170fa81095
commit d925122020
4 changed files with 393 additions and 18 deletions

@ -0,0 +1,139 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .node_handler import ModuleHandler, NodeHandler
from ..sharding_strategy import ShardingStrategy_V2, StrategyGenerator_V2, OperationDataType, OperationData
from typing import List, Dict
from .registry import operator_registry
__all__ = ['LinearModuleHandler']
class DotProductStrategyGenerator(StrategyGenerator_V2):
"""TODO: to be implemented"""
pass
class MatVecStrategyGenerator(StrategyGenerator_V2):
"""TODO: to be implemented"""
pass
class LinearProjectionStrategyGenerator(StrategyGenerator_V2):
def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
"""TODO: to be implemented"""
pass
def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
"""TODO: to be implemented"""
pass
def generate(self, operand_mapping: Dict[str, OperationData]) -> List[ShardingStrategy_V2]:
"""TODO: to be implemented"""
pass
def validate(self, *args, **kwargs) -> bool:
"""TODO: to be implemented"""
pass
class BatchedMatMulStrategyGenerator(StrategyGenerator_V2):
"""TODO: to be implemented"""
pass
@operator_registry.register(torch.nn.Linear)
class LinearModuleHandler(ModuleHandler):
"""
A LinearModuleHandler which deals with the sharding strategies for the nn.Linear module.
"""
def register_strategy_generator(self) -> List[StrategyGenerator_V2]:
generators = []
generators.append(LinearProjectionStrategyGenerator(self.device_mesh))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use the transposed shape for the strategies;
# the sharding specs will be transformed back to the original shape in self.post_process
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.node.args[0]._meta_data)
physical_other_operand = OperationData(name="weight",
type=OperationDataType.PARAM,
data=self.named_parameters['weight'],
logical_shape=self.named_parameters['weight'].shape[::-1])
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
if self.named_parameters.get('bias') is not None:
physical_bias_operand = OperationData(name="bias",
type=OperationDataType.PARAM,
data=self.named_parameters['bias'])
mapping['bias'] = physical_bias_operand
return mapping
def post_process(self, strategy: ShardingStrategy_V2):
"""
Convert the sharding spec of the weight parameter back to its original shape.
"""
for op_data, sharding_spec in strategy.input_sharding_specs.items():
if op_data.name == "weight":
assert op_data.logical_shape != op_data.data.shape
dim_partition_dict = sharding_spec.dim_partition_dict
# switch first and last dim of the linear module weight
dim_partition_dict[0], dim_partition_dict[-1] = dim_partition_dict[-1], dim_partition_dict[0]
# re-init the sharding spec
sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
return strategy
@operator_registry.register(F.linear)
class LinearFunctionHandler(NodeHandler):
"""
A LinearFunctionHandler which deals with the sharding strategies for the F.linear function.
"""
def register_strategy_generator(self) -> List[StrategyGenerator_V2]:
generators = []
generators.append(LinearProjectionStrategyGenerator(self.device_mesh))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use the transposed shape for the strategies;
# the sharding specs will be transformed back to the original shape in self.post_process
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.node.args[0]._meta_data)
physical_other_operand = OperationData(name=str(self.node.args[1]),
type=OperationDataType.ARG,
data=self.node.args[1]._meta_data,
logical_shape=self.node.args[1]._meta_data.shape[::-1])
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
if len(self.node.args) == 3 and self.node.args[2] is not None:
physical_bias_operand = OperationData(name=str(self.node.args[2]),
type=OperationDataType.ARG,
data=self.node.args[2]._meta_data)
mapping['bias'] = physical_bias_operand
return mapping
def post_process(self, strategy: ShardingStrategy_V2):
"""
Convert the sharding spec of the weight parameter back to its original shape.
"""
for op_data, sharding_spec in strategy.input_sharding_specs.items():
if op_data.name == str(self.node.args[1]):
assert op_data.logical_shape != op_data.data.shape
dim_partition_dict = sharding_spec.dim_partition_dict
# switch first and last dim of the linear module weight
dim_partition_dict[0], dim_partition_dict[-1] = dim_partition_dict[-1], dim_partition_dict[0]
# re-init the sharding spec
sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
return strategy
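For reference, the following minimal sketch shows how the new LinearModuleHandler is exercised; it mirrors the unit test added in this commit, and the node index, mesh shape, and meta-tensor setup are taken from that test rather than from a documented API.

import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.op_handler.dot_handler_v2 import LinearModuleHandler
from colossalai.auto_parallel.solver.sharding_strategy import StrategiesVector
from colossalai.device.device_mesh import DeviceMesh

# trace a single linear layer on meta tensors so no real memory is allocated
model = nn.Sequential(nn.Linear(10, 20).to('meta'))
tracer = ColoTracer()
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 10).to('meta')})
gm = ColoGraphModule(model, graph)  # constructed as in the unit test, not used directly below

# a 2 x 2 logical device mesh over 4 physical devices
device_mesh = DeviceMesh(torch.arange(0, 4), (2, 2))

# the second node of the traced graph is the call_module node for the linear layer
linear_mod_node = list(graph.nodes)[1]
strategies_vector = StrategiesVector(linear_mod_node)

handler = LinearModuleHandler(node=linear_mod_node,
                              device_mesh=device_mesh,
                              strategies_vector=strategies_vector)

# the weight is exposed to the strategy generators with a transposed logical shape
mapping = handler.get_operation_data_mapping()
print(mapping['other'].data.shape)     # torch.Size([20, 10]) -- physical shape
print(mapping['other'].logical_shape)  # torch.Size([10, 20]) -- logical shape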

@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
from torch.fx.node import Node
from colossalai.device.device_mesh import DeviceMesh
from typing import Dict, List
from ..sharding_strategy import StrategiesVector, Operand, StrategyGenerator_V2
from ..sharding_strategy import ShardingStrategy, ShardingStrategy_V2, StrategiesVector, OperationData, StrategyGenerator_V2
class NodeHandler(ABC):
@ -36,8 +36,15 @@ class NodeHandler(ABC):
for generator in self.strategy_generator:
strategies = generator.generate(operand_mapping)
self.strategies_vector.extend(strategies)
self.strategies_vector = list(map(self.post_process, self.strategies_vector))
return self.strategies_vector
def post_process(self, strategy: ShardingStrategy_V2):
# transform the generated strategy,
# e.g. to post-process the sharding strategy for the transposed weights
return strategy
@abstractmethod
def register_strategy_generator(self) -> List[StrategyGenerator_V2]:
"""
@ -46,21 +53,40 @@ class NodeHandler(ABC):
pass
@abstractmethod
def get_operand_mapping(self) -> Dict[str, Operand]:
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
"""
Returns the mapping between the logical operand name to its physical operands.
A logical operand is defined by the strategy generator, for example, a matrix multiplication
operation has two operands "input" and "other". For a nn.Linear module, the physical operand for "input" is
the module input and the physical operand for "other" is the module weight.
Returns the mapping from each logical operation data name to its physical data.
A logical operation data is a piece of data associated with an operation, such as an input or an output.
It is defined by the strategy generator; for example, a matrix multiplication operation has two operands,
"input" and "other", and one result, "output". For an nn.Linear module, the physical operand for "input"
is the module input, the physical operand for "other" is the module weight, and the physical result for
"output" is the module output.
Note that the operand name is specified by the StrategyGenerator object.
For example:
# for a linear layer
mapping = {
"input": Operand(name=str(self.node.args[0]), type=OperandType.ARG),
"other": Operand(name="weight", type=OperandType.PARAM),
"bias": Operand(name="bias", type=OperandType.PARAM)
"input": Operand(name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data),
"other": Operand(name="weight", type=OperationDataType.PARAM, data=self.named_parameters['weight']),
"bias": Operand(name="bias", type=OperationDataType.PARAM, data=self.named_parameters['bias']),
"output": Operand(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data),
}
"""
pass
class ModuleHandler(NodeHandler):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
# set attributes to access module parameters for convenience
assert self.node.graph.owning_module is not None, \
f'The graph is not associated with a module, please make sure it can be used to instantiate a GraphModule object.'
module = self.node.graph.owning_module.get_submodule(self.node.target)
named_parameters = list(module.named_parameters(recurse=False))
# convert named parameters from list to dict
named_parameters = {k: v for k, v in named_parameters}
self.module = module
self.named_parameters = named_parameters
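To summarise the contract that the refactored NodeHandler imposes, here is a hypothetical minimal subclass; the IdentityLikeHandler name and the node_handler import path are illustrative assumptions, not part of this commit.

from typing import Dict, List

from colossalai.auto_parallel.solver.op_handler.node_handler import NodeHandler
from colossalai.auto_parallel.solver.sharding_strategy import (OperationData, OperationDataType,
                                                               StrategyGenerator_V2)


class IdentityLikeHandler(NodeHandler):
    """Hypothetical handler for a unary op whose input and output share the same sharding."""

    def register_strategy_generator(self) -> List[StrategyGenerator_V2]:
        # return the strategy generators that enumerate sharding strategies for this op
        return []

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # map the logical names used by the strategy generators to the physical node data
        physical_input = OperationData(name=str(self.node.args[0]),
                                       type=OperationDataType.ARG,
                                       data=self.node.args[0]._meta_data)
        physical_output = OperationData(name=str(self.node),
                                        type=OperationDataType.OUTPUT,
                                        data=self.node._meta_data)
        return {"input": physical_input, "output": physical_output}

    # post_process is optional: the default implementation returns the strategy unchanged

Module-based handlers would instead subclass ModuleHandler, which additionally exposes self.module and self.named_parameters as shown above.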

@ -1,6 +1,10 @@
from dataclasses import dataclass
from abc import ABC, abstractmethod
from enum import Enum
import operator
import torch
from functools import reduce
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
from typing import Dict, List, Union, Tuple, Any
@ -40,18 +44,35 @@ class ShardingStrategy:
input_shardings: List[ShardingSpec] = None
class OperandType(Enum):
class OperationDataType(Enum):
"""
An operand can come from the argument list of an operator or the parameter list of a module.
The data of an operation can come from the argument list of an operator, the parameter list of a module, or the output of the operation.
"""
ARG = 0
PARAM = 1
OUTPUT = 2
@dataclass
class Operand:
class OperationData:
"""
OperationData is the data related to an operator. It can be an operand or the output.
Args:
name (str): the name of the operation-related data
type (OperationDataType): the type of the operation data
data (torch.Tensor): the value for this data, usually it is a meta tensor.
logical_shape (Tuple[int]): the logical shape of the data, which can differ from its actual shape in memory.
"""
name: str
type: OperandType
type: OperationDataType
data: torch.Tensor
logical_shape: Tuple[int] = None
def __post_init__(self):
# if no logical shape is specified, use the data shape as the logical shape
if self.logical_shape is None:
self.logical_shape = self.data.shape
@dataclass
@ -69,6 +90,20 @@ class TrainCycleItem:
total: Any
class CommunicationType(Enum):
FWD_ALL_REDUCE = 0
BWD_ALL_REDUCE = 1
@dataclass
class CommunicationAction:
"""
A communication action to be performed for an operation data, consisting of the communication type and the device mesh dimension over which the collective runs.
"""
type: CommunicationType
mesh_dim: int
@dataclass
class ShardingStrategy_V2:
"""
@ -86,12 +121,35 @@ class ShardingStrategy_V2:
strategy.(default to None)
"""
name: str
output_sharding_spec: ShardingSpec
sharding_specs: Dict[OperationData, ShardingSpec] = None
compute_cost: TrainCycleItem = None
communication_cost: TrainCycleItem = None
memory_cost: TrainCycleItem = None
input_sharding_specs: Dict[Operand, ShardingSpec] = None
input_resharding_costs: Dict[Operand, List[float]] = None
input_resharding_costs: Dict[OperationData, List[float]] = None
communication_actions: Dict[OperationData, List[CommunicationAction]] = None
@property
def input_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
specs = {}
specs.update(self._get_sharding_spec(OperationDataType.ARG))
specs.update(self._get_sharding_spec(OperationDataType.PARAM))
return specs
@property
def argument_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
return self._get_sharding_spec(OperationDataType.ARG)
@property
def param_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
return self._get_sharding_spec(OperationDataType.PARAM)
@property
def output_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
return self._get_sharding_spec(OperationDataType.OUTPUT)
def _get_sharding_spec(self, operation_data_type: OperationDataType):
specs = {k: v for k, v in self.sharding_specs.items() if k.type == operation_data_type}
return specs
class StrategyGenerator_V2(ABC):
@ -104,9 +162,57 @@ class StrategyGenerator_V2(ABC):
def __init__(self, device_mesh: DeviceMesh):
self.device_mesh = device_mesh
@abstractmethod
def generate(self, operand_mapping: Dict[str, Operand]) -> List[ShardingStrategy_V2]:
def update_communication_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
"""
Compute the communication cost involved in the forward and backward iteration.
"""
comm_cost = TrainCycleItem(fwd=0, bwd=0, total=0)
def _compute_and_add(data: OperationData, action: CommunicationAction):
sharded_shape = strategy.sharding_specs[data].get_sharded_shape_per_device()
dtype = data.data.dtype
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
num_bytes = size_per_elem_bytes * reduce(operator.mul, sharded_shape)
cost = self.device_mesh.all_reduce_cost(num_bytes=num_bytes, mesh_dim=action.mesh_dim)
# accumulate the cost into the forward or backward communication cost
if action.type == CommunicationType.FWD_ALL_REDUCE:
comm_cost.fwd += cost
elif action.type == CommunicationType.BWD_ALL_REDUCE:
comm_cost.bwd += cost
else:
raise ValueError(f"Found unknown CommunicationType {action.type}")
# check if communication action exists
# if so, loop over each action and compute the cost of each action
if strategy.communication_actions is not None:
for operand, actions in strategy.communication_actions.items():
for action in actions:
_compute_and_add(operand, action)
# update the communication cost attribute in-place
strategy.communication_cost = comm_cost
return strategy
@abstractmethod
def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
"""
Customize this method to compute the computation flops.
"""
pass
@abstractmethod
def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
"""
Customize this method to compute the memory cost in bytes.
"""
pass
@abstractmethod
def generate(self, operand_mapping: Dict[str, OperationData]) -> List[ShardingStrategy_V2]:
"""
Generate all possible sharding strategies for this operation.
"""
pass
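As a concrete illustration of the byte count that update_communication_cost feeds into device_mesh.all_reduce_cost, consider a weight with logical shape (10, 20) sharded over a 2 x 2 mesh; the shard shape below is an assumed example, and the all_reduce_cost call itself is omitted because its internals are outside this diff.

import operator
from functools import reduce

import torch

# assumed example: each device holds a (5, 10) fp32 shard of a (10, 20) weight
sharded_shape = (5, 10)
dtype = torch.float32

# element size in bytes for the dtype (4 bytes for fp32)
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
num_bytes = size_per_elem_bytes * reduce(operator.mul, sharded_shape)
print(num_bytes)  # 200

# num_bytes would then be passed to device_mesh.all_reduce_cost(num_bytes=num_bytes, mesh_dim=...)
# once per CommunicationAction, and accumulated into TrainCycleItem.fwd or TrainCycleItem.bwd
# depending on whether the action is FWD_ALL_REDUCE or BWD_ALL_REDUCE.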

@ -0,0 +1,104 @@
from colossalai.fx.tracer.meta_patch.patched_module import linear
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.op_handler.dot_handler_v2 import LinearModuleHandler, LinearFunctionHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
def test_linear_module_handler():
model = nn.Sequential(nn.Linear(10, 20).to('meta'))
tracer = ColoTracer()
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 10).to('meta')})
gm = ColoGraphModule(model, graph)
physical_mesh_id = torch.arange(0, 4)
print(graph)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
linear_mod_node = list(graph.nodes)[1]
strategies_vector = StrategiesVector(linear_mod_node)
# build handler
handler = LinearModuleHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
# check operation data mapping
mapping = handler.get_operation_data_mapping()
for name, op_data in mapping.items():
op_data: OperationData
# make sure they have valid values
assert op_data.logical_shape is not None
assert op_data.data is not None
assert mapping['input'].name == "input_1"
assert mapping['input'].data.is_meta
assert mapping['input'].data.shape == torch.Size([4, 10])
assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == torch.Size([4, 10])
assert mapping['other'].name == "weight"
assert mapping['other'].data.is_meta
assert mapping['other'].data.shape == torch.Size([20, 10])
assert mapping['other'].type == OperationDataType.PARAM
assert mapping['other'].logical_shape == torch.Size([10, 20])
assert mapping['bias'].name == "bias"
assert mapping['bias'].data.is_meta
assert mapping['bias'].data.shape == torch.Size([20])
assert mapping['bias'].type == OperationDataType.PARAM
assert mapping['bias'].logical_shape == torch.Size([20])
assert mapping['output'].name == "_0"
assert mapping['output'].data.is_meta
assert mapping['output'].data.shape == torch.Size([4, 20])
assert mapping['output'].type == OperationDataType.OUTPUT
def test_linear_function_handler():
model = nn.Linear(10, 20).to('meta')
tracer = ColoTracer()
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 10).to('meta')})
gm = ColoGraphModule(model, graph)
physical_mesh_id = torch.arange(0, 4)
print(graph)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
linear_func_node = list(graph.nodes)[3]
strategies_vector = StrategiesVector(linear_func_node)
# build handler
handler = LinearFunctionHandler(node=linear_func_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
# check operation data mapping
mapping = handler.get_operation_data_mapping()
assert mapping['input'].name == "input_1"
assert mapping['input'].data.is_meta
assert mapping['input'].data.shape == torch.Size([4, 10])
assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == torch.Size([4, 10])
assert mapping['other'].name == "weight"
assert mapping['other'].data.is_meta
assert mapping['other'].data.shape == torch.Size([20, 10])
assert mapping['other'].type == OperationDataType.ARG
assert mapping['other'].logical_shape == torch.Size([10, 20])
assert mapping['bias'].name == "bias"
assert mapping['bias'].data.is_meta
assert mapping['bias'].data.shape == torch.Size([20])
assert mapping['bias'].type == OperationDataType.ARG
assert mapping['bias'].logical_shape == torch.Size([20])
assert mapping['output'].name == "linear"
assert mapping['output'].data.is_meta
assert mapping['output'].data.shape == torch.Size([4, 20])
assert mapping['output'].type == OperationDataType.OUTPUT
if __name__ == '__main__':
test_linear_module_handler()
test_linear_function_handler()