mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] adapt handlers with attention block (#1990)
* [autoparallel] adapt handlers with attention block * polishpull/1999/head
parent
b5dbb46172
commit
35e6b9ec82
|
@ -12,6 +12,7 @@ __all__ = ['ReshapeHandler']
|
|||
|
||||
@operator_registry.register(torch.reshape)
|
||||
@operator_registry.register(torch.Tensor.split)
|
||||
@operator_registry.register(torch.split)
|
||||
@operator_registry.register(torch.flatten)
|
||||
@operator_registry.register(torch.Tensor.transpose)
|
||||
@operator_registry.register(torch.Tensor.permute)
|
||||
|
|
|
@ -220,7 +220,9 @@ class BatchNormStrategyGenerator(StrategyGenerator):
|
|||
logical_process_axis=mesh_dim_0,
|
||||
comm_type=CommType.IMPLICIT)
|
||||
|
||||
communication_action_mapping = {"output": output_comm_action}
|
||||
# TODO: Temporary solution has no communication cost,
|
||||
# above action should be added after the SyncBN replace pass completed.
|
||||
communication_action_mapping = {}
|
||||
|
||||
return self.get_sharding_strategy(name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
|
@ -256,7 +258,9 @@ class BatchNormStrategyGenerator(StrategyGenerator):
|
|||
logical_process_axis=[mesh_dim_0, mesh_dim_1],
|
||||
comm_type=CommType.IMPLICIT)
|
||||
|
||||
communication_action_mapping = {"output": output_comm_action}
|
||||
# TODO: Temporary solution has no communication cost,
|
||||
# above action should be added after the SyncBN replace pass completed.
|
||||
communication_action_mapping = {}
|
||||
|
||||
return self.get_sharding_strategy(name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
|
@ -302,7 +306,9 @@ class BatchNormStrategyGenerator(StrategyGenerator):
|
|||
logical_process_axis=[mesh_dim_0],
|
||||
comm_type=CommType.IMPLICIT)
|
||||
|
||||
communication_action_mapping = {"output": output_comm_action}
|
||||
# TODO: Temporary solution has no communication cost,
|
||||
# above action should be added after the SyncBN replace pass completed.
|
||||
communication_action_mapping = {}
|
||||
|
||||
return self.get_sharding_strategy(name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
|
|
|
@ -69,7 +69,7 @@ class TensorStrategyGenerator(GetItemStrategyGenerator):
|
|||
|
||||
def collate_strategies(self) -> List[ShardingStrategy]:
|
||||
strategy_list = []
|
||||
for strategy in self.predecessor_node.strategies_vector:
|
||||
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
|
||||
dim_partition_dict_mapping = {}
|
||||
communication_action_mapping = {}
|
||||
dim_partition_dict_for_input = strategy.output_sharding_specs[self.op_data["input"]].dim_partition_dict
|
||||
|
@ -96,7 +96,7 @@ class TensorStrategyGenerator(GetItemStrategyGenerator):
|
|||
arg_index=0)
|
||||
communication_action_mapping["input"] = input_communication_action
|
||||
|
||||
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}'
|
||||
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}_{index}'
|
||||
|
||||
strategy = self.get_sharding_strategy(name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
|
@ -121,7 +121,7 @@ class TensorTupleStrategyGenerator(GetItemStrategyGenerator):
|
|||
strategy_list = []
|
||||
index = self.op_data["index"].data
|
||||
|
||||
for strategy in self.predecessor_node.strategies_vector:
|
||||
for strategy_index, strategy in enumerate(self.predecessor_node.strategies_vector):
|
||||
# the sharding spec for input in this case is a tuple of ShardingSpec.
|
||||
sharding_spec_for_input = strategy.output_sharding_specs[self.op_data["input"]]
|
||||
dim_partition_dict_for_output = sharding_spec_for_input[index].dim_partition_dict
|
||||
|
@ -132,8 +132,11 @@ class TensorTupleStrategyGenerator(GetItemStrategyGenerator):
|
|||
}
|
||||
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
|
||||
sharding_spec_mapping["input"] = sharding_spec_for_input
|
||||
|
||||
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}'
|
||||
input_sharding_info = f"get the {index} element from ("
|
||||
for sharding_spec in sharding_spec_for_input:
|
||||
input_sharding_info += f'{sharding_spec.sharding_sequence}, '
|
||||
input_sharding_info += ")"
|
||||
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {input_sharding_info}_{strategy_index}'
|
||||
|
||||
strategy = self.get_sharding_strategy(name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
import copy
|
||||
from typing import List
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
|
||||
from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding,
|
||||
enumerate_all_possible_2d_sharding)
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
|
||||
from colossalai.auto_parallel.tensor_shard.utils import (
|
||||
enumerate_all_possible_1d_sharding,
|
||||
enumerate_all_possible_2d_sharding,
|
||||
ignore_sharding_exception,
|
||||
)
|
||||
|
||||
from .strategy_generator import StrategyGenerator
|
||||
|
||||
|
@ -50,6 +53,7 @@ class WhereGenerator(StrategyGenerator):
|
|||
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
|
||||
strategy.memory_cost = memory_cost
|
||||
|
||||
@ignore_sharding_exception
|
||||
def _generate_strategy_with_dim_partition(self, dim_partition):
|
||||
dim_partition_dict_mapping = {
|
||||
"condition": dim_partition,
|
||||
|
|
|
@ -14,6 +14,11 @@ __all__ = ['UnaryElementwiseHandler']
|
|||
@operator_registry.register(torch.Tensor.type)
|
||||
@operator_registry.register(torch.abs)
|
||||
@operator_registry.register(torch.nn.ReLU)
|
||||
# TODO: softmax need to be relocated
|
||||
@operator_registry.register(torch.nn.functional.softmax)
|
||||
@operator_registry.register(torch.nn.modules.dropout.Dropout)
|
||||
@operator_registry.register(torch.Tensor.contiguous)
|
||||
@operator_registry.register(torch.nn.functional.dropout)
|
||||
class UnaryElementwiseHandler(NodeHandler):
|
||||
"""
|
||||
A UnaryElementwiseHandler which deals with the sharding strategies for UnaryElementwise Op.
|
||||
|
|
|
@ -57,24 +57,6 @@ class WhereHandler(NodeHandler):
|
|||
logical_operand.logical_shape = target_shape
|
||||
return logical_operand
|
||||
|
||||
def register_strategy(self, compute_resharding_cost: bool = False) -> StrategiesVector:
|
||||
"""
|
||||
Register different sharding strategies for the current node.
|
||||
"""
|
||||
strategy_generators = self.get_strategy_generator()
|
||||
|
||||
for generator in strategy_generators:
|
||||
strategies = generator.generate()
|
||||
strategies_vector = map(self.post_process, strategies)
|
||||
# compute the resharding costs based on the previous node
|
||||
# strategies if specified
|
||||
if compute_resharding_cost:
|
||||
strategies = list(map(self.update_resharding_cost, strategies))
|
||||
self.strategies_vector.extend(strategies)
|
||||
|
||||
self.strategies_vector = list(strategies_vector)
|
||||
return self.strategies_vector
|
||||
|
||||
def post_process(self, strategy: ShardingStrategy):
|
||||
logical_op_data_mapping, physical_op_data_mapping = self.get_operation_data_mapping()
|
||||
for key in logical_op_data_mapping.keys():
|
||||
|
|
|
@ -3,6 +3,8 @@ import torch.nn as nn
|
|||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.getitem_handler import GetItemHandler
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlacehodlerHandler
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.reshape_handler import ReshapeHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
|
@ -10,7 +12,7 @@ from colossalai.fx.tracer.meta_patch.patched_module import linear
|
|||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
|
||||
|
||||
class GetItemModel(nn.Module):
|
||||
class GetItemFromTensorModel(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -21,8 +23,8 @@ class GetItemModel(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
def test_getitem_function_handler():
|
||||
model = GetItemModel()
|
||||
def test_getitem_from_tensor_handler():
|
||||
model = GetItemFromTensorModel()
|
||||
tracer = ColoTracer()
|
||||
# graph():
|
||||
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
|
||||
|
@ -83,5 +85,83 @@ def test_getitem_function_handler():
|
|||
assert len(getitem_strategies_vector) == len(conv_strategies_vector)
|
||||
|
||||
|
||||
class GetItemFromTupleModel(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, input):
|
||||
split_node = torch.split(input, 2, 0)
|
||||
x = split_node[1]
|
||||
return x
|
||||
|
||||
|
||||
def test_getitem_from_tuple_handler():
|
||||
model = GetItemFromTupleModel()
|
||||
tracer = ColoTracer()
|
||||
# graph():
|
||||
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
|
||||
# %split : [#users=1] = call_function[target=torch.functional.split](args = (%conv2d, 2), kwargs = {dim: 0})
|
||||
# %getitem : [#users=1] = call_function[target=operator.getitem](args = (%split, 1), kwargs = {})
|
||||
# return getitem
|
||||
graph = tracer.trace(model, meta_args={
|
||||
"input": torch.rand(4, 4, 64, 64).to('meta'),
|
||||
})
|
||||
gm = ColoGraphModule(model, graph)
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
|
||||
mesh_shape = (2, 2)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
input_node = list(graph.nodes)[0]
|
||||
split_node = list(graph.nodes)[1]
|
||||
getitem_node = list(graph.nodes)[2]
|
||||
input_strategies_vector = StrategiesVector(input_node)
|
||||
getitem_strategies_vector = StrategiesVector(getitem_node)
|
||||
split_strategies_vector = StrategiesVector(split_node)
|
||||
|
||||
# build handler
|
||||
input_handler = PlacehodlerHandler(
|
||||
node=input_node,
|
||||
device_mesh=device_mesh,
|
||||
strategies_vector=input_strategies_vector,
|
||||
placeholder_option='replicated',
|
||||
)
|
||||
input_handler.register_strategy(compute_resharding_cost=False)
|
||||
setattr(input_node, 'strategies_vector', input_strategies_vector)
|
||||
split_handler = ReshapeHandler(node=split_node, device_mesh=device_mesh, strategies_vector=split_strategies_vector)
|
||||
split_handler.register_strategy(compute_resharding_cost=False)
|
||||
setattr(split_node, 'strategies_vector', split_strategies_vector)
|
||||
getitem_handler = GetItemHandler(node=getitem_node,
|
||||
device_mesh=device_mesh,
|
||||
strategies_vector=getitem_strategies_vector)
|
||||
getitem_handler.register_strategy(compute_resharding_cost=False)
|
||||
setattr(getitem_node, 'strategies_vector', getitem_strategies_vector)
|
||||
|
||||
# check operation data mapping
|
||||
mapping = getitem_handler.get_operation_data_mapping()
|
||||
|
||||
for name, op_data in mapping.items():
|
||||
op_data: OperationData
|
||||
# make sure they have valid values
|
||||
assert op_data.data is not None
|
||||
|
||||
assert mapping['input'].name == "split"
|
||||
assert mapping['input'].type == OperationDataType.ARG
|
||||
assert mapping['input'].logical_shape == (torch.Size([2, 4, 64, 64]), torch.Size([2, 4, 64, 64]))
|
||||
|
||||
assert mapping['index'].name == "index"
|
||||
assert isinstance(mapping['index'].data, int)
|
||||
assert mapping['index'].type == OperationDataType.ARG
|
||||
|
||||
assert mapping['output'].name == "getitem"
|
||||
assert mapping['output'].data.is_meta
|
||||
assert mapping['output'].data.shape == torch.Size([2, 4, 64, 64])
|
||||
assert mapping['output'].type == OperationDataType.OUTPUT
|
||||
|
||||
# getitem is a following strategy handler, so the number of strategies is equal to the predecessor node.
|
||||
assert len(getitem_strategies_vector) == len(split_strategies_vector)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_getitem_function_handler()
|
||||
test_getitem_from_tensor_handler()
|
||||
test_getitem_from_tuple_handler()
|
||||
|
|
Loading…
Reference in New Issue