[autoparallel] add reshape handler v2 and fix some previous bug (#1683)

pull/1690/head
YuliangLiu0306 2 years ago committed by GitHub
parent 6878e42248
commit af718e83f2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,35 @@
import torch
from .node_handler import NodeHandler
from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData, StrategiesVector
from ..strategy import ReshapeGenerator, StrategyGenerator_V2
from typing import List, Dict
from .registry import operator_registry
import operator
__all__ = ['ReshapeHandler']
@operator_registry.register(torch.reshape)
@operator_registry.register(torch.Tensor.permute)
class ReshapeHandler(NodeHandler):
"""
A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
"""
def get_strategy_generator(self) -> List[StrategyGenerator_V2]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(ReshapeGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.node.args[0]._meta_data)
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
mapping = {"input": physical_input_operand, "output": physical_output}
return mapping

@ -5,10 +5,11 @@ from .batch_norm_generator import BatchNormStrategyGenerator
from .unary_elementwise_generator import UnaryElementwiseGenerator from .unary_elementwise_generator import UnaryElementwiseGenerator
from .getitem_generator import GetItemStrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator from .getitem_generator import GetItemStrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator
from .layer_norm_generator import LayerNormGenerator from .layer_norm_generator import LayerNormGenerator
from .reshape_generator import ReshapeGenerator
__all__ = [ __all__ = [
'StrategyGenerator_V2', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator', 'StrategyGenerator_V2', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator',
'LinearProjectionStrategyGenerator', 'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator', 'UnaryElementwiseGenerator', 'LinearProjectionStrategyGenerator', 'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator',
'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator', 'UnaryElementwiseGenerator', 'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator',
'LayerNormGenerator' 'TensorTupleStrategyGenerator', 'LayerNormGenerator', 'ReshapeGenerator'
] ]

@ -37,7 +37,7 @@ class BatchNormStrategyGenerator(StrategyGenerator_V2):
assert input_op_data.dim() in (3, 4, assert input_op_data.dim() in (3, 4,
5), f'We suppose the dim of input fed into conv op should in range of [3, 5].' 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].'
def update_compute_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem: def update_compute_cost(self, strategy: ShardingStrategy_V2):
''' '''
Compute the computation cost per device with this specific strategy. Compute the computation cost per device with this specific strategy.
@ -62,9 +62,9 @@ class BatchNormStrategyGenerator(StrategyGenerator_V2):
backward_compute_cost += bias_compute_cost backward_compute_cost += bias_compute_cost
total_compute_cost = forward_compute_cost + backward_compute_cost total_compute_cost = forward_compute_cost + backward_compute_cost
compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost) compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
return compute_cost strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem: def update_memory_cost(self, strategy: ShardingStrategy_V2):
forward_size_mapping = { forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"), 'input': self._compute_size_in_bytes(strategy, "input"),
'other': self._compute_size_in_bytes(strategy, "other"), 'other': self._compute_size_in_bytes(strategy, "other"),

@ -29,7 +29,7 @@ class ConvStrategyGenerator(StrategyGenerator_V2):
assert input_op_data.dim() in (3, 4, assert input_op_data.dim() in (3, 4,
5), f'We suppose the dim of input fed into conv op should in range of [3, 5].' 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].'
def update_compute_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem: def update_compute_cost(self, strategy: ShardingStrategy_V2):
''' '''
Compute the computation cost per device with this specific strategy. Compute the computation cost per device with this specific strategy.
@ -67,9 +67,9 @@ class ConvStrategyGenerator(StrategyGenerator_V2):
total_compute_cost = forward_compute_cost + backward_compute_cost total_compute_cost = forward_compute_cost + backward_compute_cost
compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost) compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
return compute_cost strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2: def update_memory_cost(self, strategy: ShardingStrategy_V2):
forward_size_mapping = { forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"), 'input': self._compute_size_in_bytes(strategy, "input"),
'other': self._compute_size_in_bytes(strategy, "other"), 'other': self._compute_size_in_bytes(strategy, "other"),

@ -28,10 +28,11 @@ class GetItemStrategyGenerator(FollowingStrategyGenerator):
def validate(self) -> bool: def validate(self) -> bool:
return super().validate() return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem: def update_compute_cost(self, strategy: ShardingStrategy_V2):
return TrainCycleItem(fwd=10, bwd=10, total=20) compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem: def update_memory_cost(self, strategy: ShardingStrategy_V2):
''' '''
Compute the memory cost per device with this specific strategy. Compute the memory cost per device with this specific strategy.
''' '''
@ -59,7 +60,6 @@ class GetItemStrategyGenerator(FollowingStrategyGenerator):
parameter=fwd_parameter_cost + bwd_parameter_cost) parameter=fwd_parameter_cost + bwd_parameter_cost)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost strategy.memory_cost = memory_cost
return super().update_memory_cost(strategy)
class TensorStrategyGenerator(GetItemStrategyGenerator): class TensorStrategyGenerator(GetItemStrategyGenerator):

@ -23,7 +23,7 @@ class LayerNormGenerator(StrategyGenerator_V2):
def validate(self) -> bool: def validate(self) -> bool:
return super().validate() return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem: def update_compute_cost(self, strategy: ShardingStrategy_V2):
''' '''
Compute the computation cost per device with this specific strategy. Compute the computation cost per device with this specific strategy.
@ -52,9 +52,9 @@ class LayerNormGenerator(StrategyGenerator_V2):
backward_compute_cost += bias_compute_cost backward_compute_cost += bias_compute_cost
total_compute_cost = forward_compute_cost + backward_compute_cost total_compute_cost = forward_compute_cost + backward_compute_cost
compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost) compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
return compute_cost strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem: def update_memory_cost(self, strategy: ShardingStrategy_V2):
''' '''
Compute the memory cost per device with this specific strategy. Compute the memory cost per device with this specific strategy.
''' '''
@ -103,6 +103,9 @@ class LayerNormGenerator(StrategyGenerator_V2):
total_mesh_dim_list = [] total_mesh_dim_list = []
for mesh_dim_list in dim_partition.values(): for mesh_dim_list in dim_partition.values():
total_mesh_dim_list.extend(mesh_dim_list) total_mesh_dim_list.extend(mesh_dim_list)
# if there is only one sharding dimension, we should use the value instead of list as logical_process_axis.
if len(total_mesh_dim_list) == 1:
total_mesh_dim_list = total_mesh_dim_list[0]
communication_action_mapping = {} communication_action_mapping = {}
other_comm_spec = self.get_communication_spec( other_comm_spec = self.get_communication_spec(

@ -0,0 +1,100 @@
import operator
from functools import reduce
from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import FollowingStrategyGenerator
from typing import List
import copy
__all__ = ['ReshapeGenerator']
class ReshapeGenerator(FollowingStrategyGenerator):
"""
ReshapeGenerator which deals with the sharding strategies of Reshape Op, such as torch.Tensor.permute.
"""
def validate(self) -> bool:
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy_V2):
compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy_V2):
'''
Compute the memory cost per device with this specific strategy.
'''
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'output': self._compute_size_in_bytes(strategy, "output")
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
# compute fwd cost incurred
# fwd_cost = input + output
fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
# compute bwd cost incurred
# bwd_cost = input_grad
bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + bwd_parameter_cost)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
def generate(self):
strategy_list = []
# For reshape function, to keep the computing correctness we keep the sharding
# spec of input is fully replicated. In addition, we will keep the output in
# replica status and let the successor node choose the way to resharding the
# output node. Therefore, the different strategies of input node with same
# output sharding spec will generate same strategy for reshape function.
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
dim_partition_dict_mapping = {}
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
dim_partition_dict_for_output = {}
if isinstance(self.op_data["output"].data, tuple):
dim_partition_dict_for_output = [{} for _ in range(len(self.op_data["output"].data))]
dim_partition_dict_mapping = {
"input": dim_partition_dict_for_input,
"output": dim_partition_dict_for_output,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# add index into name to pass the duplicated check
# we keep same strategies with different name for node merging, and it will not increase the searching space,
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> FULLY REPLICATED_{index}'
total_mesh_dim_list = []
for mesh_dim_list in dim_partition_dict_for_input.values():
total_mesh_dim_list.extend(mesh_dim_list)
# if there is only one sharding dimension, we should use the value instead of list as logical_process_axis.
if len(total_mesh_dim_list) == 1:
total_mesh_dim_list = total_mesh_dim_list[0]
input_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
logical_process_axis=total_mesh_dim_list)
communication_action_mapping["input"] = input_comm_spec
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
for strategy in strategy_list:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
return strategy_list

@ -53,9 +53,16 @@ class StrategyGenerator_V2(ABC):
for op_data_name, dim_partition_dict in mapping.items(): for op_data_name, dim_partition_dict in mapping.items():
if op_data_name in self.op_data: if op_data_name in self.op_data:
op_data = self.op_data[op_data_name] op_data = self.op_data[op_data_name]
sharding_spec = ShardingSpec(device_mesh=self.device_mesh, if isinstance(op_data.data, tuple) and isinstance(op_data.data[0], torch.Tensor):
entire_shape=op_data.logical_shape, sharding_spec = []
dim_partition_dict=dim_partition_dict) for output, dim_partition_dict_element in zip(op_data.data, dim_partition_dict):
sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=output.shape,
dim_partition_dict=dim_partition_dict_element)
else:
sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=op_data.logical_shape,
dim_partition_dict=dim_partition_dict)
results[op_data_name] = sharding_spec results[op_data_name] = sharding_spec
return results return results

@ -18,10 +18,11 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator):
def validate(self) -> bool: def validate(self) -> bool:
return super().validate() return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem: def update_compute_cost(self, strategy: ShardingStrategy_V2):
return TrainCycleItem(fwd=10, bwd=10, total=20) compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem: def update_memory_cost(self, strategy: ShardingStrategy_V2):
''' '''
Compute the memory cost per device with this specific strategy. Compute the memory cost per device with this specific strategy.
''' '''
@ -49,7 +50,6 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator):
parameter=fwd_parameter_cost + bwd_parameter_cost) parameter=fwd_parameter_cost + bwd_parameter_cost)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost strategy.memory_cost = memory_cost
return super().update_memory_cost(strategy)
def generate(self): def generate(self):
strategy_list = [] strategy_list = []

@ -0,0 +1,81 @@
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.op_handler.conv_handler_v2 import ConvFunctionHandler
from colossalai.auto_parallel.solver.op_handler.reshape_handler_v2 import ReshapeHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
class ReshapeModel(nn.Module):
def __init__(self):
super().__init__()
def forward(self, input, other):
conv_node = nn.functional.conv2d(input, other)
reshape_node = conv_node.view(2, -1)
return reshape_node
def test_reshape_handler():
model = ReshapeModel()
tracer = ColoTracer()
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %other : torch.Tensor [#users=1] = placeholder[target=other]
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {})
# %view : [#users=1] = call_method[target=view](args = (%conv2d, 2, -1), kwargs = {})
# return view
graph = tracer.trace(model,
meta_args={
"input": torch.rand(4, 4, 64, 64).to('meta'),
"other": torch.rand(4, 16, 3, 3).to('meta'),
})
gm = ColoGraphModule(model, graph)
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
conv_mod_node = list(graph.nodes)[2]
reshape_node = list(graph.nodes)[3]
reshape_strategies_vector = StrategiesVector(reshape_node)
conv_strategies_vector = StrategiesVector(conv_mod_node)
# build handler
conv_handler = ConvFunctionHandler(node=conv_mod_node,
device_mesh=device_mesh,
strategies_vector=conv_strategies_vector)
conv_handler.register_strategy()
setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector)
reshape_handler = ReshapeHandler(node=reshape_node,
device_mesh=device_mesh,
strategies_vector=reshape_strategies_vector)
reshape_handler.register_strategy()
# check operation data mapping
mapping = reshape_handler.get_operation_data_mapping()
for name, op_data in mapping.items():
op_data: OperationData
# make sure they have valid values
assert op_data.data is not None
assert mapping['input'].name == "conv2d"
assert mapping['input'].data.is_meta
assert mapping['input'].data.shape == torch.Size([4, 4, 62, 62])
assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == torch.Size([4, 4, 62, 62])
assert mapping['output'].name == "view"
assert mapping['output'].data.is_meta
assert mapping['output'].data.shape == torch.Size([2, 30752])
assert mapping['output'].type == OperationDataType.OUTPUT
# reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node.
assert len(reshape_strategies_vector) == len(conv_strategies_vector)
if __name__ == '__main__':
test_reshape_handler()
Loading…
Cancel
Save