import torch
import torch.nn.functional as F
from .node_handler import ModuleHandler, NodeHandler
from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData
from ..strategy import ConvStrategyGenerator, StrategyGenerator_V2
from typing import List, Dict
from .registry import operator_registry

__all__ = ['ConvModuleHandler', 'ConvFunctionHandler']


@operator_registry.register(torch.nn.Conv1d)
@operator_registry.register(torch.nn.Conv2d)
@operator_registry.register(torch.nn.Conv3d)
class ConvModuleHandler(ModuleHandler):
    """
    A ConvModuleHandler which deals with the sharding strategies for nn.ConvXd modules.
    """

    def get_strategy_generator(self) -> List[StrategyGenerator_V2]:
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(ConvStrategyGenerator(op_data_mapping, self.device_mesh))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # use the transposed shape for strategy generation;
        # the strategies will be transformed back to the original shape in self.post_process
        physical_input_operand = OperationData(name=str(self.node.args[0]),
                                               type=OperationDataType.ARG,
                                               data=self.node.args[0]._meta_data)
        # swap the first two dims of the weight to obtain its logical shape
        logical_shape_for_weight = list(self.named_parameters["weight"].shape)
        logical_shape_for_weight[0], logical_shape_for_weight[1] = \
            logical_shape_for_weight[1], logical_shape_for_weight[0]
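
        # Illustrative example (shapes assumed, not taken from this file): an
        # nn.Conv2d(4, 16, kernel_size=3) weight has physical shape
        # (16, 4, 3, 3) = (out_channels, in_channels, kH, kW), so its logical
        # shape becomes (4, 16, 3, 3), i.e. strategies are generated over an
        # (in_channels, out_channels, ...) layout.
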
physical_other_operand = OperationData(name="weight",
|
||
|
type=OperationDataType.PARAM,
|
||
|
data=self.named_parameters['weight'],
|
||
|
logical_shape=torch.Size(logical_shape_for_weight))
|
||
|
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
|
||
|
|
||
|
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
|
||
|
|
||
|
        # 'bias' is absent from named_parameters when the module was built with
        # bias=False, so guard the lookup instead of indexing directly
        if self.named_parameters.get('bias', None) is not None:
            physical_bias_operand = OperationData(name="bias",
                                                  type=OperationDataType.PARAM,
                                                  data=self.named_parameters['bias'])
            mapping['bias'] = physical_bias_operand
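
        # Roughly what the mapping holds for a meta-traced nn.Conv2d(4, 16, 3)
        # with bias (illustrative; names and shapes assumed):
        #   "input":  ARG,    meta tensor of shape (B, 4, H, W)
        #   "other":  PARAM,  weight of shape (16, 4, 3, 3), logical_shape (4, 16, 3, 3)
        #   "bias":   PARAM,  bias of shape (16,)
        #   "output": OUTPUT, meta tensor of shape (B, 16, H', W')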
        return mapping

    def post_process(self, strategy: ShardingStrategy_V2):
        """
        Convert the sharding spec of the weight parameter back to its original shape.
        """
        for op_data, sharding_spec in strategy.input_sharding_specs.items():
            if op_data.name == "weight":
                # the logical shape is transposed, so it must differ from the physical one
                assert op_data.logical_shape != op_data.data.shape
                dim_partition_dict = sharding_spec.dim_partition_dict

                # switch the first and second dim of the conv module weight
                first_dim_partition = dim_partition_dict.pop(1, None)
                second_dim_partition = dim_partition_dict.pop(0, None)

                if first_dim_partition:
                    dim_partition_dict[0] = first_dim_partition

                if second_dim_partition:
                    dim_partition_dict[1] = second_dim_partition
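
                # Illustrative example (assumed sharding): a logical spec of
                # {0: [0]} (in_channels split over mesh dim 0) becomes a
                # physical spec of {1: [0]}, since in_channels sits at dim 1 of
                # the real (out_channels, in_channels, *kernel) weight.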

                # re-init the sharding spec
                sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
        return strategy


@operator_registry.register(F.conv1d)
@operator_registry.register(F.conv2d)
@operator_registry.register(F.conv3d)
class ConvFunctionHandler(NodeHandler):
    """
    A ConvFunctionHandler which deals with the sharding strategies for F.convXd functions.
    """

    def get_strategy_generator(self) -> List[StrategyGenerator_V2]:
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(ConvStrategyGenerator(op_data_mapping, self.device_mesh))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # use the transposed shape for strategy generation;
        # the strategies will be transformed back to the original shape in self.post_process
        physical_input_operand = OperationData(name=str(self.node.args[0]),
                                               type=OperationDataType.ARG,
                                               data=self.node.args[0]._meta_data)

        # check if the other operand is a parameter
        if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter):
            data_type = OperationDataType.PARAM
        else:
            data_type = OperationDataType.ARG

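        # F.convXd weights share the nn.ConvXd layout
        # (out_channels, in_channels / groups, *kernel), so the same
        # (dim 0 <-> dim 1) transposition applies here.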
        logical_shape_for_weight = list(self.node.args[1]._meta_data.shape)
        logical_shape_for_weight[0], logical_shape_for_weight[1] = \
            logical_shape_for_weight[1], logical_shape_for_weight[0]
        physical_other_operand = OperationData(name=str(self.node.args[1]),
                                               type=data_type,
                                               data=self.node.args[1]._meta_data,
                                               logical_shape=torch.Size(logical_shape_for_weight))
        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)

        mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}

if "bias" in self.node.kwargs:
|
||
|
# check if the other operand is a parameter
|
||
|
if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter):
|
||
|
data_type = OperationDataType.PARAM
|
||
|
else:
|
||
|
data_type = OperationDataType.ARG
|
||
|
physical_bias_operand = OperationData(name=str(self.node.kwargs["bias"]),
|
||
|
type=data_type,
|
||
|
data=self.node.kwargs["bias"]._meta_data)
|
||
|
mapping['bias'] = physical_bias_operand
|
||
|
return mapping
|
||
|
|
||
|
    def post_process(self, strategy: ShardingStrategy_V2):
        """
        Convert the sharding spec of the weight operand back to its original shape.
        """
        for op_data, sharding_spec in strategy.input_sharding_specs.items():
            if op_data.name == str(self.node.args[1]):
                # the logical shape is transposed, so it must differ from the physical one
                assert op_data.logical_shape != op_data.data.shape
                dim_partition_dict = sharding_spec.dim_partition_dict

                # switch the first and second dim of the conv function weight
                first_dim_partition = dim_partition_dict.pop(1, None)
                second_dim_partition = dim_partition_dict.pop(0, None)

                if first_dim_partition:
                    dim_partition_dict[0] = first_dim_partition

                if second_dim_partition:
                    dim_partition_dict[1] = second_dim_partition

                # re-init the sharding spec
                sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
        return strategy
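

# A standalone sanity check of the swap-back logic used by both post_process
# methods above (illustrative only; operates on plain dicts, no ColossalAI
# objects needed):
if __name__ == "__main__":
    dim_partition_dict = {0: [0], 2: [1]}    # logical spec: dim 0 sharded over mesh dim 0
    first_dim_partition = dim_partition_dict.pop(1, None)
    second_dim_partition = dim_partition_dict.pop(0, None)
    if first_dim_partition:
        dim_partition_dict[0] = first_dim_partition
    if second_dim_partition:
        dim_partition_dict[1] = second_dim_partition
    # the partition on logical dim 0 moved to physical dim 1; dim 2 is untouched
    assert dim_partition_dict == {2: [1], 1: [0]}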