[autoparallel] adapt handlers with attention block (#1990)

* [autoparallel] adapt handlers with attention block

* polish
pull/1999/head
YuliangLiu0306 2022-11-21 10:44:11 +08:00 committed by GitHub
parent b5dbb46172
commit 35e6b9ec82
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 114 additions and 33 deletions

View File

@ -12,6 +12,7 @@ __all__ = ['ReshapeHandler']
@operator_registry.register(torch.reshape)
@operator_registry.register(torch.Tensor.split)
@operator_registry.register(torch.split)
@operator_registry.register(torch.flatten)
@operator_registry.register(torch.Tensor.transpose)
@operator_registry.register(torch.Tensor.permute)

View File

@ -220,7 +220,9 @@ class BatchNormStrategyGenerator(StrategyGenerator):
logical_process_axis=mesh_dim_0,
comm_type=CommType.IMPLICIT)
communication_action_mapping = {"output": output_comm_action}
# TODO: Temporary solution has no communication cost,
# above action should be added after the SyncBN replace pass completed.
communication_action_mapping = {}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
@ -256,7 +258,9 @@ class BatchNormStrategyGenerator(StrategyGenerator):
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.IMPLICIT)
communication_action_mapping = {"output": output_comm_action}
# TODO: Temporary solution has no communication cost,
# above action should be added after the SyncBN replace pass completed.
communication_action_mapping = {}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
@ -302,7 +306,9 @@ class BatchNormStrategyGenerator(StrategyGenerator):
logical_process_axis=[mesh_dim_0],
comm_type=CommType.IMPLICIT)
communication_action_mapping = {"output": output_comm_action}
# TODO: Temporary solution has no communication cost,
# above action should be added after the SyncBN replace pass completed.
communication_action_mapping = {}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,

View File

@ -69,7 +69,7 @@ class TensorStrategyGenerator(GetItemStrategyGenerator):
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
for strategy in self.predecessor_node.strategies_vector:
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
dim_partition_dict_mapping = {}
communication_action_mapping = {}
dim_partition_dict_for_input = strategy.output_sharding_specs[self.op_data["input"]].dim_partition_dict
@ -96,7 +96,7 @@ class TensorStrategyGenerator(GetItemStrategyGenerator):
arg_index=0)
communication_action_mapping["input"] = input_communication_action
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}'
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}_{index}'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
@ -121,7 +121,7 @@ class TensorTupleStrategyGenerator(GetItemStrategyGenerator):
strategy_list = []
index = self.op_data["index"].data
for strategy in self.predecessor_node.strategies_vector:
for strategy_index, strategy in enumerate(self.predecessor_node.strategies_vector):
# the sharding spec for input in this case is a tuple of ShardingSpec.
sharding_spec_for_input = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_output = sharding_spec_for_input[index].dim_partition_dict
@ -132,8 +132,11 @@ class TensorTupleStrategyGenerator(GetItemStrategyGenerator):
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
sharding_spec_mapping["input"] = sharding_spec_for_input
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}'
input_sharding_info = f"get the {index} element from ("
for sharding_spec in sharding_spec_for_input:
input_sharding_info += f'{sharding_spec.sharding_sequence}, '
input_sharding_info += ")"
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {input_sharding_info}_{strategy_index}'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,

View File

@ -1,9 +1,12 @@
import copy
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
from colossalai.auto_parallel.tensor_shard.utils import (
enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
ignore_sharding_exception,
)
from .strategy_generator import StrategyGenerator
@ -50,6 +53,7 @@ class WhereGenerator(StrategyGenerator):
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@ignore_sharding_exception
def _generate_strategy_with_dim_partition(self, dim_partition):
dim_partition_dict_mapping = {
"condition": dim_partition,

View File

@ -14,6 +14,11 @@ __all__ = ['UnaryElementwiseHandler']
@operator_registry.register(torch.Tensor.type)
@operator_registry.register(torch.abs)
@operator_registry.register(torch.nn.ReLU)
# TODO: softmax need to be relocated
@operator_registry.register(torch.nn.functional.softmax)
@operator_registry.register(torch.nn.modules.dropout.Dropout)
@operator_registry.register(torch.Tensor.contiguous)
@operator_registry.register(torch.nn.functional.dropout)
class UnaryElementwiseHandler(NodeHandler):
"""
A UnaryElementwiseHandler which deals with the sharding strategies for UnaryElementwise Op.

View File

@ -57,24 +57,6 @@ class WhereHandler(NodeHandler):
logical_operand.logical_shape = target_shape
return logical_operand
def register_strategy(self, compute_resharding_cost: bool = False) -> StrategiesVector:
"""
Register different sharding strategies for the current node.
"""
strategy_generators = self.get_strategy_generator()
for generator in strategy_generators:
strategies = generator.generate()
strategies_vector = map(self.post_process, strategies)
# compute the resharding costs based on the previous node
# strategies if specified
if compute_resharding_cost:
strategies = list(map(self.update_resharding_cost, strategies))
self.strategies_vector.extend(strategies)
self.strategies_vector = list(strategies_vector)
return self.strategies_vector
def post_process(self, strategy: ShardingStrategy):
logical_op_data_mapping, physical_op_data_mapping = self.get_operation_data_mapping()
for key in logical_op_data_mapping.keys():

View File

@ -3,6 +3,8 @@ import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.getitem_handler import GetItemHandler
from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlacehodlerHandler
from colossalai.auto_parallel.tensor_shard.node_handler.reshape_handler import ReshapeHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
@ -10,7 +12,7 @@ from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.testing.pytest_wrapper import run_on_environment_flag
class GetItemModel(nn.Module):
class GetItemFromTensorModel(nn.Module):
def __init__(self):
super().__init__()
@ -21,8 +23,8 @@ class GetItemModel(nn.Module):
return x
def test_getitem_function_handler():
model = GetItemModel()
def test_getitem_from_tensor_handler():
model = GetItemFromTensorModel()
tracer = ColoTracer()
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
@ -83,5 +85,83 @@ def test_getitem_function_handler():
assert len(getitem_strategies_vector) == len(conv_strategies_vector)
class GetItemFromTupleModel(nn.Module):
def __init__(self):
super().__init__()
def forward(self, input):
split_node = torch.split(input, 2, 0)
x = split_node[1]
return x
def test_getitem_from_tuple_handler():
model = GetItemFromTupleModel()
tracer = ColoTracer()
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %split : [#users=1] = call_function[target=torch.functional.split](args = (%conv2d, 2), kwargs = {dim: 0})
# %getitem : [#users=1] = call_function[target=operator.getitem](args = (%split, 1), kwargs = {})
# return getitem
graph = tracer.trace(model, meta_args={
"input": torch.rand(4, 4, 64, 64).to('meta'),
})
gm = ColoGraphModule(model, graph)
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
input_node = list(graph.nodes)[0]
split_node = list(graph.nodes)[1]
getitem_node = list(graph.nodes)[2]
input_strategies_vector = StrategiesVector(input_node)
getitem_strategies_vector = StrategiesVector(getitem_node)
split_strategies_vector = StrategiesVector(split_node)
# build handler
input_handler = PlacehodlerHandler(
node=input_node,
device_mesh=device_mesh,
strategies_vector=input_strategies_vector,
placeholder_option='replicated',
)
input_handler.register_strategy(compute_resharding_cost=False)
setattr(input_node, 'strategies_vector', input_strategies_vector)
split_handler = ReshapeHandler(node=split_node, device_mesh=device_mesh, strategies_vector=split_strategies_vector)
split_handler.register_strategy(compute_resharding_cost=False)
setattr(split_node, 'strategies_vector', split_strategies_vector)
getitem_handler = GetItemHandler(node=getitem_node,
device_mesh=device_mesh,
strategies_vector=getitem_strategies_vector)
getitem_handler.register_strategy(compute_resharding_cost=False)
setattr(getitem_node, 'strategies_vector', getitem_strategies_vector)
# check operation data mapping
mapping = getitem_handler.get_operation_data_mapping()
for name, op_data in mapping.items():
op_data: OperationData
# make sure they have valid values
assert op_data.data is not None
assert mapping['input'].name == "split"
assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == (torch.Size([2, 4, 64, 64]), torch.Size([2, 4, 64, 64]))
assert mapping['index'].name == "index"
assert isinstance(mapping['index'].data, int)
assert mapping['index'].type == OperationDataType.ARG
assert mapping['output'].name == "getitem"
assert mapping['output'].data.is_meta
assert mapping['output'].data.shape == torch.Size([2, 4, 64, 64])
assert mapping['output'].type == OperationDataType.OUTPUT
# getitem is a following strategy handler, so the number of strategies is equal to the predecessor node.
assert len(getitem_strategies_vector) == len(split_strategies_vector)
if __name__ == '__main__':
test_getitem_function_handler()
test_getitem_from_tensor_handler()
test_getitem_from_tuple_handler()