ColossalAI/colossalai/auto_parallel/solver/strategy/getitem_generator.py

148 lines
6.4 KiB
Python

import operator
from functools import reduce
from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import FollowingStrategyGenerator
from typing import List
from .._utils import exception_handler
import copy
__all__ = ['GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator']
class GetItemStrategyGenerator(FollowingStrategyGenerator):
    """
    Generic strategy generator for ``operator.getitem``.

    The operation is modelled as ``output = input[other]``. Three use cases
    are distinguished:
    1. args_0._meta_data: torch.Tensor, args_1._meta_data: int
    2. args_0._meta_data: torch.Tensor, args_1._meta_data: slice
    3. args_0._meta_data: Tuple[torch.Tensor], args_1._meta_data: int
    """

    @property
    def has_bias(self):
        """Whether the operation data contains an entry keyed ``'bias'``."""
        return 'bias' in self.op_data

    def validate(self) -> bool:
        """Delegate validation to the parent generator."""
        return super().validate()

    def update_compute_cost(self, strategy: ShardingStrategy_V2):
        """Attach a fixed compute cost (fwd=10, bwd=10, total=20) to ``strategy``."""
        strategy.compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)

    def update_memory_cost(self, strategy: ShardingStrategy_V2):
        '''
        Compute the memory cost per device with this specific strategy.

        The forward pass accounts for both the input and the output; the
        backward pass accounts for the input only. The result is stored on
        ``strategy.memory_cost``.
        '''

        def split_by_kind(size_by_name):
            # Separate the byte counts into (activation, parameter) totals.
            activation = sum(size for name, size in size_by_name.items() if not self.is_param(name))
            parameter = sum(size for name, size in size_by_name.items() if self.is_param(name))
            return activation, parameter

        fwd_sizes = {name: self._compute_size_in_bytes(strategy, name) for name in ("input", "output")}
        # the backward pass sees everything the forward pass does except the output
        bwd_sizes = {name: size for name, size in fwd_sizes.items() if name != "output"}

        # fwd_cost = input + output
        fwd_activation, fwd_parameter = split_by_kind(fwd_sizes)
        # bwd_cost = input_grad
        bwd_activation, bwd_parameter = split_by_kind(bwd_sizes)

        strategy.memory_cost = TrainCycleItem(
            fwd=MemoryCost(activation=fwd_activation, parameter=fwd_parameter),
            bwd=MemoryCost(activation=bwd_activation, parameter=bwd_parameter),
            total=MemoryCost(activation=fwd_activation + bwd_activation,
                             parameter=fwd_parameter + bwd_parameter),
        )
class TensorStrategyGenerator(GetItemStrategyGenerator):
    '''
    Deal with case 1 and 2.
    '''

    def generate(self):
        '''
        Build one getitem strategy per strategy of the predecessor node.

        The output partition is the input partition with dim 0 removed and
        every remaining dim index decreased by one. When dim 0 of the input
        is sharded, a GATHER_FWD_SPLIT_BWD communication action is attached
        to the input.
        '''
        generated = []

        for predecessor_strategy in self.predecessor_node.strategies_vector:
            input_partition = predecessor_strategy.output_sharding_specs[self.op_data["input"]].dim_partition_dict
            remaining_partition = copy.deepcopy(input_partition)

            # a sharded dim 0 has to be gathered before the item is taken
            needs_gather = 0 in input_partition
            if needs_gather:
                gathered_axis = remaining_partition.pop(0)

            # NOTE(review): the shift is applied unconditionally, which matches
            # integer indexing (case 1); confirm it is also intended for slice
            # indexing (case 2).
            output_partition = {dim - 1: mesh_dims for dim, mesh_dims in remaining_partition.items()}

            spec_mapping = self.to_sharding_spec_mapping({
                "input": input_partition,
                "output": output_partition,
            })

            comm_actions = {}
            if needs_gather:
                comm_actions["input"] = self.get_communication_spec(
                    spec_mapping["input"],
                    communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
                    logical_process_axis=gathered_axis)

            name = f'{spec_mapping["output"].sharding_sequence} = {spec_mapping["input"].sharding_sequence}'
            generated.append(
                self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=spec_mapping,
                                           communication_action_mapping=comm_actions))

        # costs are filled in once every strategy has been built
        for candidate in generated:
            self.update_communication_cost(candidate)
            self.update_compute_cost(candidate)
            self.update_memory_cost(candidate)

        return generated
class TensorTupleStrategyGenerator(GetItemStrategyGenerator):
    '''
    Deal with case 3.
    '''

    def generate(self):
        '''
        Build one getitem strategy per strategy of the predecessor node.

        The input is a tuple of sharding specs; the output reuses the
        partition of the element selected by ``self.op_data["index"].data``.
        No communication action is attached.

        Returns:
            The list of generated strategies, with communication, compute
            and memory costs already updated.
        '''
        strategy_list = []
        index = self.op_data["index"].data

        for predecessor_strategy in self.predecessor_node.strategies_vector:
            # the sharding spec for input in this case is a tuple of ShardingSpec.
            sharding_spec_for_input = predecessor_strategy.output_sharding_specs[self.op_data["input"]]
            dim_partition_dict_for_output = sharding_spec_for_input[index].dim_partition_dict
            communication_action_mapping = {}

            sharding_spec_mapping = self.to_sharding_spec_mapping({
                "output": dim_partition_dict_for_output,
            })
            sharding_spec_mapping["input"] = sharding_spec_for_input

            # The input is a tuple, which has no `sharding_sequence` attribute of
            # its own, so the name is assembled from the sequence of each element.
            # Elements without a `sharding_sequence` are shown as-is so that
            # building this label can never fail.
            input_sharding_info = ", ".join(
                str(getattr(spec, "sharding_sequence", spec)) for spec in sharding_spec_for_input)
            name = (f'{sharding_spec_mapping["output"].sharding_sequence} = '
                    f'get the {index} element from ({input_sharding_info})')

            strategy_list.append(
                self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping))

        for strategy in strategy_list:
            self.update_communication_cost(strategy)
            self.update_compute_cost(strategy)
            self.update_memory_cost(strategy)

        return strategy_list