mirror of https://github.com/hpcaitech/ColossalAI
40 lines
1.8 KiB
Python
40 lines
1.8 KiB
Python
![]() |
import torch
|
||
|
from .node_handler import NodeHandler
|
||
|
from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData, StrategiesVector
|
||
|
from ..strategy import TensorStrategyGenerator, TensorTupleStrategyGenerator, StrategyGenerator_V2
|
||
|
from typing import List, Dict
|
||
|
from .registry import operator_registry
|
||
|
import operator
|
||
|
|
||
|
__all__ = ['GetItemHandler']
|
||
|
|
||
|
|
||
|
@operator_registry.register(operator.getitem)
class GetItemHandler(NodeHandler):
    """
    Node handler registered for ``operator.getitem``.

    It describes the operands of an indexing node (the indexed value, the
    index and the result) and selects the strategy generator that matches
    the kind of value being indexed.
    """

    def get_strategy_generator(self) -> List[StrategyGenerator_V2]:
        """
        Return a one-element list holding the strategy generator for this node.

        ``TensorStrategyGenerator`` is used when the data of the ``"input"``
        operand is a ``torch.Tensor``; any other input data falls back to
        ``TensorTupleStrategyGenerator``.
        """
        operands = self.get_operation_data_mapping()

        # Only the generator class depends on the input kind; the constructor
        # arguments are the same in both cases.
        if isinstance(operands["input"].data, torch.Tensor):
            generator_cls = TensorStrategyGenerator
        else:
            generator_cls = TensorTupleStrategyGenerator

        return [generator_cls(operands, self.device_mesh, self.node.args[0])]

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        """
        Build the ``OperationData`` records for this getitem node.

        Returns:
            A dict with three entries:
            ``"input"``  - the value being indexed (first node argument),
            ``"index"``  - the index itself (second node argument),
            ``"output"`` - the node's own result.
        """
        # The indexed value is itself a node-like object; its ``_meta_data`` is
        # presumably the meta information recorded for it during tracing.
        indexed_operand = OperationData(
            name=str(self.node.args[0]),
            type=OperationDataType.ARG,
            data=self.node.args[0]._meta_data,
        )

        # The index is stored as-is under the fixed name "index".
        index_operand = OperationData(
            name="index",
            type=OperationDataType.ARG,
            data=self.node.args[1],
        )

        result_operand = OperationData(
            name=str(self.node),
            type=OperationDataType.OUTPUT,
            data=self.node._meta_data,
        )

        return {
            "input": indexed_operand,
            "index": index_operand,
            "output": result_operand,
        }
|