mirror of https://github.com/hpcaitech/ColossalAI
YuliangLiu0306
2 years ago
committed by
GitHub
6 changed files with 349 additions and 4 deletions
@ -0,0 +1,55 @@ |
|||||||
|
from typing import Dict, List |
||||||
|
|
||||||
|
import torch |
||||||
|
|
||||||
|
from ..sharding_strategy import OperationData, OperationDataType |
||||||
|
from .node_handler import NodeHandler |
||||||
|
from .registry import operator_registry |
||||||
|
from .strategy import SoftmaxGenerator, StrategyGenerator |
||||||
|
|
||||||
|
__all__ = ['SoftmaxHandler'] |
||||||
|
|
||||||
|
|
||||||
|
@operator_registry.register(torch.nn.Softmax) |
||||||
|
@operator_registry.register(torch.nn.functional.softmax) |
||||||
|
class SoftmaxHandler(NodeHandler): |
||||||
|
""" |
||||||
|
A SoftmaxHandler which deals with the sharding strategies for |
||||||
|
torch.nn.Softmax or torch.nn.functional.softmax. |
||||||
|
""" |
||||||
|
|
||||||
|
def get_strategy_generator(self) -> List[StrategyGenerator]: |
||||||
|
op_data_mapping = self.get_operation_data_mapping() |
||||||
|
generators = [] |
||||||
|
generators.append(SoftmaxGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) |
||||||
|
return generators |
||||||
|
|
||||||
|
def get_operation_data_mapping(self) -> Dict[str, OperationData]: |
||||||
|
# check if the input operand is a parameter |
||||||
|
if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter): |
||||||
|
data_type = OperationDataType.PARAM |
||||||
|
else: |
||||||
|
data_type = OperationDataType.ARG |
||||||
|
|
||||||
|
input_data = self.node.args[0]._meta_data |
||||||
|
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data) |
||||||
|
|
||||||
|
softmax_dim = self.node.kwargs['dim'] |
||||||
|
|
||||||
|
num_dims = self.node.args[0]._meta_data.dim() |
||||||
|
# recover negative value to positive |
||||||
|
if softmax_dim < 0: |
||||||
|
softmax_dim += num_dims |
||||||
|
|
||||||
|
physical_dim_operand = OperationData(name='softmax_dim', type=OperationDataType.ARG, data=softmax_dim) |
||||||
|
|
||||||
|
output_data = self.node._meta_data |
||||||
|
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) |
||||||
|
|
||||||
|
mapping = { |
||||||
|
"input": physical_input_operand, |
||||||
|
"softmax_dim": physical_dim_operand, |
||||||
|
"output": physical_output_operand |
||||||
|
} |
||||||
|
|
||||||
|
return mapping |
@ -0,0 +1,104 @@ |
|||||||
|
import copy |
||||||
|
import operator |
||||||
|
from functools import reduce |
||||||
|
from typing import List |
||||||
|
|
||||||
|
from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator |
||||||
|
from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( |
||||||
|
CommAction, |
||||||
|
CommType, |
||||||
|
MemoryCost, |
||||||
|
ShardingStrategy, |
||||||
|
TrainCycleItem, |
||||||
|
) |
||||||
|
from colossalai.auto_parallel.tensor_shard.utils import ( |
||||||
|
check_keep_sharding_status, |
||||||
|
detect_reshape_mapping, |
||||||
|
infer_output_dim_partition_dict, |
||||||
|
) |
||||||
|
from colossalai.tensor.shape_consistency import CollectiveCommPattern |
||||||
|
|
||||||
|
__all__ = ['SoftmaxGenerator'] |
||||||
|
|
||||||
|
|
||||||
|
class SoftmaxGenerator(FollowingStrategyGenerator): |
||||||
|
""" |
||||||
|
SoftmaxGenerator is used to generate strategies for torch.nn.Softmax or F.softmax. |
||||||
|
""" |
||||||
|
|
||||||
|
def validate(self) -> bool: |
||||||
|
return super().validate() |
||||||
|
|
||||||
|
def update_compute_cost(self, strategy: ShardingStrategy): |
||||||
|
''' |
||||||
|
Compute the computation cost per device with this specific strategy. |
||||||
|
''' |
||||||
|
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() |
||||||
|
sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device() |
||||||
|
input_size_product = reduce(operator.mul, sharded_input_shape) |
||||||
|
output_size_product = reduce(operator.mul, sharded_output_shape) |
||||||
|
|
||||||
|
forward_compute_cost = output_size_product * 2 |
||||||
|
backward_compute_cost = input_size_product |
||||||
|
total_compute_cost = forward_compute_cost + backward_compute_cost |
||||||
|
compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost) |
||||||
|
strategy.compute_cost = compute_cost |
||||||
|
|
||||||
|
def update_memory_cost(self, strategy: ShardingStrategy): |
||||||
|
''' |
||||||
|
Compute the memory cost per device with this specific strategy. |
||||||
|
''' |
||||||
|
forward_size_mapping = { |
||||||
|
'input': self._compute_size_in_bytes(strategy, "input"), |
||||||
|
'output': self._compute_size_in_bytes(strategy, "output") |
||||||
|
} |
||||||
|
|
||||||
|
backward_size_mapping = copy.deepcopy(forward_size_mapping) |
||||||
|
backward_size_mapping.pop("output") |
||||||
|
# compute fwd cost incurred |
||||||
|
# fwd_cost = input + output |
||||||
|
fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)]) |
||||||
|
fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)]) |
||||||
|
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost) |
||||||
|
|
||||||
|
# compute bwd cost incurred |
||||||
|
# bwd_cost = input_grad |
||||||
|
bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)]) |
||||||
|
bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)]) |
||||||
|
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) |
||||||
|
|
||||||
|
# compute total cost |
||||||
|
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, |
||||||
|
parameter=fwd_parameter_cost + bwd_parameter_cost) |
||||||
|
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) |
||||||
|
strategy.memory_cost = memory_cost |
||||||
|
|
||||||
|
def collate_strategies(self) -> List[ShardingStrategy]: |
||||||
|
strategy_list = [] |
||||||
|
for index, strategy in enumerate(self.predecessor_node.strategies_vector): |
||||||
|
dim_partition_dict_mapping = {} |
||||||
|
communication_action_mapping = {} |
||||||
|
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] |
||||||
|
dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict) |
||||||
|
softmax_dim = self.op_data['softmax_dim'].data |
||||||
|
|
||||||
|
if softmax_dim in dim_partition_dict_for_input: |
||||||
|
recover_dims = dim_partition_dict_for_input.pop(softmax_dim) |
||||||
|
|
||||||
|
dim_partition_dict_for_output = copy.deepcopy(dim_partition_dict_for_input) |
||||||
|
dim_partition_dict_mapping = { |
||||||
|
"input": dim_partition_dict_for_input, |
||||||
|
"output": dim_partition_dict_for_output, |
||||||
|
} |
||||||
|
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) |
||||||
|
# add index into name to pass the duplicated check |
||||||
|
# we keep same strategies with different name for node merging, and it will not increase the searching space, |
||||||
|
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. |
||||||
|
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' |
||||||
|
|
||||||
|
strategy = self.get_sharding_strategy(name=name, |
||||||
|
sharding_spec_mapping=sharding_spec_mapping, |
||||||
|
communication_action_mapping=communication_action_mapping) |
||||||
|
strategy_list.append(strategy) |
||||||
|
|
||||||
|
return strategy_list |
@ -0,0 +1,186 @@ |
|||||||
|
from functools import partial |
||||||
|
|
||||||
|
import pytest |
||||||
|
import torch |
||||||
|
import torch.multiprocessing as mp |
||||||
|
import torch.nn as nn |
||||||
|
import torch.nn.functional as F |
||||||
|
|
||||||
|
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler |
||||||
|
from colossalai.auto_parallel.tensor_shard.node_handler.softmax_handler import SoftmaxHandler |
||||||
|
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector |
||||||
|
from colossalai.device.device_mesh import DeviceMesh |
||||||
|
from colossalai.fx import ColoGraphModule, ColoTracer |
||||||
|
from colossalai.initialize import launch |
||||||
|
from colossalai.logging import disable_existing_loggers |
||||||
|
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use |
||||||
|
from colossalai.testing.pytest_wrapper import run_on_environment_flag |
||||||
|
from colossalai.utils import free_port |
||||||
|
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy |
||||||
|
|
||||||
|
|
||||||
|
class LinearSplitModel(nn.Module): |
||||||
|
|
||||||
|
def __init__(self, softmax_dim): |
||||||
|
super().__init__() |
||||||
|
self.softmax_dim = softmax_dim |
||||||
|
|
||||||
|
def forward(self, input, other): |
||||||
|
linear_node = F.linear(input, other, bias=None) |
||||||
|
softmax_node = F.softmax(linear_node, self.softmax_dim) |
||||||
|
return softmax_node |
||||||
|
|
||||||
|
|
||||||
|
def check_split_handler(rank, softmax_dim, model_cls, world_size, port): |
||||||
|
disable_existing_loggers() |
||||||
|
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') |
||||||
|
model = model_cls(softmax_dim=softmax_dim).cuda() |
||||||
|
|
||||||
|
input = torch.rand(8, 16, 64, 32).to('cuda') |
||||||
|
other = torch.rand(64, 32).to('cuda') |
||||||
|
# index of linear node in computation graph |
||||||
|
node_index = 2 |
||||||
|
# total number of linear strategies |
||||||
|
strategy_number = 23 |
||||||
|
|
||||||
|
physical_mesh_id = torch.arange(0, 4) |
||||||
|
mesh_shape = (2, 2) |
||||||
|
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) |
||||||
|
|
||||||
|
numerical_test_for_node_strategy(model=model, |
||||||
|
device_mesh=device_mesh, |
||||||
|
node_index=node_index, |
||||||
|
strategy_number=strategy_number, |
||||||
|
input_args=[input, other], |
||||||
|
meta_arg_names=['input', 'other'], |
||||||
|
node_type='following') |
||||||
|
tracer = ColoTracer() |
||||||
|
|
||||||
|
# graph(): |
||||||
|
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input] |
||||||
|
# %other : torch.Tensor [#users=1] = placeholder[target=other] |
||||||
|
# %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None}) |
||||||
|
# %softmax : [#users=1] = call_method[target=split](args = (%linear,), kwargs = {}) |
||||||
|
# return split |
||||||
|
graph = tracer.trace(model, |
||||||
|
meta_args={ |
||||||
|
"input": torch.rand(8, 16, 64, 32).to('meta'), |
||||||
|
"other": torch.rand(64, 32).to('meta'), |
||||||
|
}) |
||||||
|
|
||||||
|
gm = ColoGraphModule(model, graph) |
||||||
|
|
||||||
|
previous_mod_node = list(graph.nodes)[2] |
||||||
|
split_node = list(graph.nodes)[3] |
||||||
|
split_strategies_vector = StrategiesVector(split_node) |
||||||
|
previous_strategies_vector = StrategiesVector(previous_mod_node) |
||||||
|
|
||||||
|
# build handler |
||||||
|
assert len(previous_strategies_vector) == 0 |
||||||
|
linear_handler = LinearFunctionHandler(node=previous_mod_node, |
||||||
|
device_mesh=device_mesh, |
||||||
|
strategies_vector=previous_strategies_vector) |
||||||
|
linear_handler.register_strategy(compute_resharding_cost=False) |
||||||
|
setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) |
||||||
|
|
||||||
|
softmax_handler = SoftmaxHandler(node=split_node, |
||||||
|
device_mesh=device_mesh, |
||||||
|
strategies_vector=split_strategies_vector) |
||||||
|
|
||||||
|
softmax_handler.register_strategy(compute_resharding_cost=False) |
||||||
|
|
||||||
|
# check operation data mapping |
||||||
|
mapping = softmax_handler.get_operation_data_mapping() |
||||||
|
|
||||||
|
for name, op_data in mapping.items(): |
||||||
|
op_data: OperationData |
||||||
|
# make sure they have valid values |
||||||
|
assert op_data.data is not None |
||||||
|
|
||||||
|
assert mapping['input'].name == "linear" |
||||||
|
assert mapping['input'].data.is_meta |
||||||
|
assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64]) |
||||||
|
assert mapping['input'].type == OperationDataType.ARG |
||||||
|
assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64]) |
||||||
|
|
||||||
|
assert mapping['softmax_dim'].name == "softmax_dim" |
||||||
|
assert mapping['softmax_dim'].data == softmax_dim |
||||||
|
assert mapping['softmax_dim'].type == OperationDataType.ARG |
||||||
|
|
||||||
|
assert mapping['output'].name == "softmax" |
||||||
|
assert mapping['output'].data.shape == torch.Size([8, 16, 64, 64]) |
||||||
|
assert mapping['output'].logical_shape == torch.Size([8, 16, 64, 64]) |
||||||
|
assert mapping['output'].type == OperationDataType.OUTPUT |
||||||
|
|
||||||
|
# reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node. |
||||||
|
assert len(split_strategies_vector) == len(previous_strategies_vector) |
||||||
|
strategy_name_list = [strategy.name for strategy in split_strategies_vector] |
||||||
|
|
||||||
|
if softmax_dim == 0: |
||||||
|
assert '[R, R, R, S1] -> [R, R, R, S1]_0' in strategy_name_list |
||||||
|
assert '[R, S0, R, S1] -> [R, S0, R, S1]_1' in strategy_name_list |
||||||
|
assert '[R, R, S0, S1] -> [R, R, S0, S1]_2' in strategy_name_list |
||||||
|
assert '[R, R, R, S0] -> [R, R, R, S0]_3' in strategy_name_list |
||||||
|
assert '[R, S1, R, S0] -> [R, S1, R, S0]_4' in strategy_name_list |
||||||
|
assert '[R, R, S1, S0] -> [R, R, S1, S0]_5' in strategy_name_list |
||||||
|
assert '[R, R, R, R] -> [R, R, R, R]_6' in strategy_name_list |
||||||
|
assert '[R, S0, R, R] -> [R, S0, R, R]_7' in strategy_name_list |
||||||
|
assert '[R, R, S0, R] -> [R, R, S0, R]_8' in strategy_name_list |
||||||
|
assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list |
||||||
|
assert '[R, S1, R, R] -> [R, S1, R, R]_10' in strategy_name_list |
||||||
|
assert '[R, R, S1, R] -> [R, R, S1, R]_11' in strategy_name_list |
||||||
|
assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list |
||||||
|
assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list |
||||||
|
assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list |
||||||
|
assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list |
||||||
|
assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list |
||||||
|
assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list |
||||||
|
assert '[R, R, R, R] -> [R, R, R, R]_18' in strategy_name_list |
||||||
|
assert '[R, S01, R, R] -> [R, S01, R, R]_19' in strategy_name_list |
||||||
|
assert '[R, R, S01, R] -> [R, R, S01, R]_20' in strategy_name_list |
||||||
|
assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list |
||||||
|
assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list |
||||||
|
|
||||||
|
if softmax_dim == 1: |
||||||
|
assert '[S0, R, R, S1] -> [S0, R, R, S1]_0' in strategy_name_list |
||||||
|
assert '[R, R, R, S1] -> [R, R, R, S1]_1' in strategy_name_list |
||||||
|
assert '[R, R, S0, S1] -> [R, R, S0, S1]_2' in strategy_name_list |
||||||
|
assert '[S1, R, R, S0] -> [S1, R, R, S0]_3' in strategy_name_list |
||||||
|
assert '[R, R, R, S0] -> [R, R, R, S0]_4' in strategy_name_list |
||||||
|
assert '[R, R, S1, S0] -> [R, R, S1, S0]_5' in strategy_name_list |
||||||
|
assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list |
||||||
|
assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list |
||||||
|
assert '[R, R, S0, R] -> [R, R, S0, R]_8' in strategy_name_list |
||||||
|
assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list |
||||||
|
assert '[R, R, R, R] -> [R, R, R, R]_10' in strategy_name_list |
||||||
|
assert '[R, R, S1, R] -> [R, R, S1, R]_11' in strategy_name_list |
||||||
|
assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list |
||||||
|
assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list |
||||||
|
assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list |
||||||
|
assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list |
||||||
|
assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list |
||||||
|
assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list |
||||||
|
assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list |
||||||
|
assert '[R, R, R, R] -> [R, R, R, R]_19' in strategy_name_list |
||||||
|
assert '[R, R, S01, R] -> [R, R, S01, R]_20' in strategy_name_list |
||||||
|
assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list |
||||||
|
assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list |
||||||
|
|
||||||
|
|
||||||
|
@run_on_environment_flag(name='AUTO_PARALLEL') |
||||||
|
@pytest.mark.dist |
||||||
|
@rerun_if_address_is_in_use() |
||||||
|
@parameterize('softmax_dim', [0, 1, 2, 3]) |
||||||
|
@parameterize('model_cls', [LinearSplitModel]) |
||||||
|
def test_split_handler(softmax_dim, model_cls): |
||||||
|
world_size = 4 |
||||||
|
run_func = partial(check_split_handler, |
||||||
|
softmax_dim=softmax_dim, |
||||||
|
model_cls=model_cls, |
||||||
|
world_size=world_size, |
||||||
|
port=free_port()) |
||||||
|
mp.spawn(run_func, nprocs=world_size) |
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__': |
||||||
|
test_split_handler() |
Loading…
Reference in new issue