from copy import deepcopy
from typing import List, Dict, Union

import torch
import torch.nn.functional as F

from colossalai.tensor.sharding_spec import ShardingException

from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData
from ..strategy import LinearProjectionStrategyGenerator, StrategyGenerator_V2, BatchedMatMulStrategyGenerator
from .node_handler import ModuleHandler, NodeHandler
from .registry import operator_registry
from .utils import switch_partition_dim, update_partition_dim

__all__ = ['LinearModuleHandler', 'LinearFunctionHandler', 'BMMFunctionHandler']


@operator_registry.register(torch.nn.Linear)
class LinearModuleHandler(ModuleHandler):
    """
    A LinearModuleHandler which deals with the sharding strategies for the nn.Linear module.
    """

    def get_strategy_generator(self) -> List[StrategyGenerator_V2]:
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # use transposed shapes for the strategies
        # the strategies will be transformed back to their original shapes in self.post_process
        input_meta_data = self.node.args[0]._meta_data
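        # the input to nn.Linear may have more than two dimensions; for strategy generation
        # all leading dimensions are flattened into one, e.g. an input of shape
        # (batch, seq, in_features) is viewed logically as (batch * seq, in_features)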
        input_logical_shape = input_meta_data.view(-1, input_meta_data.shape[-1]).shape
        physical_input_operand = OperationData(name=str(self.node.args[0]),
                                               type=OperationDataType.ARG,
                                               data=input_meta_data,
                                               logical_shape=input_logical_shape)
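        # nn.Linear stores its weight as (out_features, in_features); the strategies work on
        # the transposed logical shape (in_features, out_features), hence the reversed shape below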
        physical_other_operand = OperationData(name="weight",
                                               type=OperationDataType.PARAM,
                                               data=self.named_parameters['weight'],
                                               logical_shape=self.named_parameters['weight'].shape[::-1])
        output_meta_data = self.node._meta_data
        output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape
        physical_output = OperationData(name=str(self.node),
                                        type=OperationDataType.OUTPUT,
                                        data=output_meta_data,
                                        logical_shape=output_logical_shape)

        mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}

        if self.named_parameters['bias'] is not None:
            physical_bias_operand = OperationData(name="bias",
                                                  type=OperationDataType.PARAM,
                                                  data=self.named_parameters['bias'])
            mapping['bias'] = physical_bias_operand
        return mapping

    def post_process(self, strategy: ShardingStrategy_V2) -> Union[ShardingStrategy_V2, List[ShardingStrategy_V2]]:
        """
        Convert the sharding spec from the logical shape to the physical shape.
        """
        # switch the dimensions of the transposed weight
        for op_data, sharding_spec in strategy.input_sharding_specs.items():
            if op_data.name == "weight":
                assert op_data.logical_shape != op_data.data.shape
                switch_partition_dim(sharding_spec, 0, -1)

        # create multiple sharding strategies for the input
        # as the input can be multi-dimensional while the logical shape used for strategy
        # generation is only 2D, we need to map the partition at logical dim 0 to one of
        # the leading dimensions of the physical input
        sharding_strategies = []
        input_op_data = strategy.get_op_data_by_name(str(self.node.args[0]))
        output_op_data = strategy.get_op_data_by_name(str(self.node))
        num_input_dims = input_op_data.data.dim()
        input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name)
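
        # only physical dims 0 .. num_input_dims - 2 are candidates for the logical dim-0
        # partition, since the last physical dim is the feature dim being projected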
        if 0 in input_sharding_spec.dim_partition_dict:
            for i in range(num_input_dims - 1):
                new_strategy = strategy.clone()
                input_sharding_spec = new_strategy.get_sharding_spec_by_name(input_op_data.name)
                output_sharding_spec = new_strategy.get_sharding_spec_by_name(output_op_data.name)
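                # if the partition cannot be mapped onto physical dim i, update_partition_dim
                # raises a ShardingException and this candidate strategy is simply skipped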
                try:
                    update_partition_dim(sharding_spec=input_sharding_spec,
                                         dim_mapping={0: i},
                                         physical_shape=input_op_data.data.shape,
                                         inplace=True)
                    update_partition_dim(sharding_spec=output_sharding_spec,
                                         dim_mapping={0: i},
                                         physical_shape=output_op_data.data.shape,
                                         inplace=True)
                    sharding_strategies.append(new_strategy)
                except ShardingException:
                    pass
        else:
            sharding_strategies.append(strategy)

        return sharding_strategies


@operator_registry.register(F.linear)
class LinearFunctionHandler(NodeHandler):
    """
    A LinearFunctionHandler which deals with the sharding strategies for the F.linear function.
    """

    def get_strategy_generator(self) -> List[StrategyGenerator_V2]:
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # use transposed shapes for the strategies
        # the strategies will be transformed back to their original shapes in self.post_process
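        # the traced call is assumed to be F.linear(input, weight, bias): args[0] is the
        # input tensor, args[1] the weight and args[2] the (possibly None) bias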
        physical_input_operand = OperationData(name=str(self.node.args[0]),
                                               type=OperationDataType.ARG,
                                               data=self.node.args[0]._meta_data)

        # check if the other operand is a parameter
        if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter):
            data_type = OperationDataType.PARAM
        else:
            data_type = OperationDataType.ARG
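
        # as with the module handler, the weight's logical shape is its transposed physical
        # shape, so the strategies see it as (in_features, out_features)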
        physical_other_operand = OperationData(name=str(self.node.args[1]),
                                               type=data_type,
                                               data=self.node.args[1]._meta_data,
                                               logical_shape=self.node.args[1]._meta_data.shape[::-1])
        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)

        mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}

        if self.node.args[2] is not None:
            # check if the bias operand is a parameter
            if isinstance(self.node.args[2]._meta_data, torch.nn.parameter.Parameter):
                data_type = OperationDataType.PARAM
            else:
                data_type = OperationDataType.ARG
            physical_bias_operand = OperationData(name=str(self.node.args[2]),
                                                  type=data_type,
                                                  data=self.node.args[2]._meta_data)
            mapping['bias'] = physical_bias_operand
        return mapping

    def post_process(self, strategy: ShardingStrategy_V2) -> Union[ShardingStrategy_V2, List[ShardingStrategy_V2]]:
        """
        Convert the sharding spec of the weight parameter back to its original shape.
        """
        for op_data, sharding_spec in strategy.input_sharding_specs.items():
            if op_data.name == str(self.node.args[1]):
                assert op_data.logical_shape != op_data.data.shape
                switch_partition_dim(sharding_spec, 0, -1)

        # create multiple sharding strategies for the input
        # as the input can be multi-dimensional while the logical shape used for strategy
        # generation is only 2D, we need to map the partition at logical dim 0 to one of
        # the leading dimensions of the physical input
        sharding_strategies = []
        input_op_data = strategy.get_op_data_by_name(str(self.node.args[0]))
        output_op_data = strategy.get_op_data_by_name(str(self.node))
        num_input_dims = input_op_data.data.dim()
        input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name)

        if 0 in input_sharding_spec.dim_partition_dict:
            for i in range(num_input_dims - 1):
                new_strategy = strategy.clone()
                input_sharding_spec = new_strategy.get_sharding_spec_by_name(input_op_data.name)
                output_sharding_spec = new_strategy.get_sharding_spec_by_name(output_op_data.name)
                try:
                    update_partition_dim(sharding_spec=input_sharding_spec,
                                         dim_mapping={0: i},
                                         physical_shape=input_op_data.data.shape,
                                         inplace=True)
                    update_partition_dim(sharding_spec=output_sharding_spec,
                                         dim_mapping={0: i},
                                         physical_shape=output_op_data.data.shape,
                                         inplace=True)
                    sharding_strategies.append(new_strategy)
                except ShardingException:
                    pass
        else:
            sharding_strategies.append(strategy)

        return sharding_strategies


@operator_registry.register(torch.bmm)
@operator_registry.register(torch.Tensor.bmm)
class BMMFunctionHandler(NodeHandler):
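    """
    A BMMFunctionHandler which deals with the sharding strategies for torch.bmm.
    """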

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
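        # both operands of torch.bmm are activation tensors, so both are treated as ARG;
        # their shapes are expected to be (b, n, m) and (b, m, p) respectively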
        physical_input_operand = OperationData(name=str(self.node.args[0]),
                                               type=OperationDataType.ARG,
                                               data=self.node.args[0]._meta_data)
        physical_other_operand = OperationData(name=str(self.node.args[1]),
                                               type=OperationDataType.ARG,
                                               data=self.node.args[1]._meta_data)
        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)

        mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
        return mapping

    def get_strategy_generator(self) -> List[StrategyGenerator_V2]:
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh))
        return generators