[autoparallel] add conv handler v2 (#1663)

pull/1664/head
YuliangLiu0306 2022-09-28 11:24:59 +08:00 committed by GitHub
parent 1e7816a460
commit 095854477f
6 changed files with 849 additions and 4 deletions

@@ -0,0 +1,145 @@
import torch
import torch.nn.functional as F
from .node_handler import ModuleHandler, NodeHandler
from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData
from ..strategy import ConvStrategyGenerator, StrategyGenerator_V2
from typing import List, Dict
from .registry import operator_registry
__all__ = ['ConvModuleHandler', 'ConvFunctionHandler']
@operator_registry.register(torch.nn.Conv1d)
@operator_registry.register(torch.nn.Conv2d)
@operator_registry.register(torch.nn.Conv3d)
class ConvModuleHandler(ModuleHandler):
"""
A ConvModuleHandler which deals with the sharding strategies for nn.ConvXd modules.
"""
def get_strategy_generator(self) -> List[StrategyGenerator_V2]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(ConvStrategyGenerator(op_data_mapping, self.device_mesh))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use the transposed shape for strategies
# the strategies will be transformed back to their original shape in self.post_process
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.node.args[0]._meta_data)
logical_shape_for_weight = list(self.named_parameters["weight"].shape)
logical_shape_for_weight[0], logical_shape_for_weight[1] = logical_shape_for_weight[
1], logical_shape_for_weight[0]
physical_other_operand = OperationData(name="weight",
type=OperationDataType.PARAM,
data=self.named_parameters['weight'],
logical_shape=torch.Size(logical_shape_for_weight))
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
if self.named_parameters['bias'] is not None:
physical_bias_operand = OperationData(name="bias",
type=OperationDataType.PARAM,
data=self.named_parameters['bias'])
mapping['bias'] = physical_bias_operand
return mapping
def post_process(self, strategy: ShardingStrategy_V2):
"""
Convert the sharding spec of the weight parameter back to its original shape.
"""
for op_data, sharding_spec in strategy.input_sharding_specs.items():
if op_data.name == "weight":
assert op_data.logical_shape != op_data.data.shape
dim_partition_dict = sharding_spec.dim_partition_dict
# switch first and second dim of the conv module weight
first_dim_partition = dim_partition_dict.pop(1, None)
second_dim_partition = dim_partition_dict.pop(0, None)
if first_dim_partition:
dim_partition_dict[0] = first_dim_partition
if second_dim_partition:
dim_partition_dict[1] = second_dim_partition
# re-init the sharding spec
sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
return strategy
@operator_registry.register(F.conv1d)
@operator_registry.register(F.conv2d)
@operator_registry.register(F.conv3d)
class ConvFunctionHandler(NodeHandler):
"""
A ConvFunctionHandler which deals with the sharding strategies for nn.functional.convXd functions.
"""
def get_strategy_generator(self) -> List[StrategyGenerator_V2]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(ConvStrategyGenerator(op_data_mapping, self.device_mesh))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use the transposed shape for strategies
# the strategies will be transformed back to their original shape in self.post_process
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.node.args[0]._meta_data)
# check if the other operand is a parameter
if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
logical_shape_for_weight = list(self.node.args[1]._meta_data.shape)
logical_shape_for_weight[0], logical_shape_for_weight[1] = logical_shape_for_weight[
1], logical_shape_for_weight[0]
physical_other_operand = OperationData(name=str(self.node.args[1]),
type=data_type,
data=self.node.args[1]._meta_data,
logical_shape=torch.Size(logical_shape_for_weight))
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
if "bias" in self.node.kwargs:
# check if the other operand is a parameter
if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
physical_bias_operand = OperationData(name=str(self.node.kwargs["bias"]),
type=data_type,
data=self.node.kwargs["bias"]._meta_data)
mapping['bias'] = physical_bias_operand
return mapping
def post_process(self, strategy: ShardingStrategy_V2):
"""
Convert the sharding spec of the weight parameter back to its original shape.
"""
for op_data, sharding_spec in strategy.input_sharding_specs.items():
if op_data.name == str(self.node.args[1]):
assert op_data.logical_shape != op_data.data.shape
dim_partition_dict = sharding_spec.dim_partition_dict
# switch first and second dim of the conv function weight
first_dim_partition = dim_partition_dict.pop(1, None)
second_dim_partition = dim_partition_dict.pop(0, None)
if first_dim_partition:
dim_partition_dict[0] = first_dim_partition
if second_dim_partition:
dim_partition_dict[1] = second_dim_partition
# re-init the sharding spec
sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
return strategy
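For illustration, here is a minimal standalone sketch of the dim swap that post_process performs: strategies are generated against the transposed logical weight shape [Cin, Cout, *kernel], so partitions recorded on dims 0 and 1 have to be swapped back to match the physical [Cout, Cin, *kernel] layout. The helper name below is hypothetical and not part of this commit.

# minimal sketch, not part of this PR: swap the partitions of dims 0 and 1,
# mirroring what ConvModuleHandler.post_process does on the sharding spec
def swap_first_two_dims(dim_partition_dict):
    first_dim_partition = dim_partition_dict.pop(1, None)
    second_dim_partition = dim_partition_dict.pop(0, None)
    if first_dim_partition is not None:
        dim_partition_dict[0] = first_dim_partition
    if second_dim_partition is not None:
        dim_partition_dict[1] = second_dim_partition
    return dim_partition_dict

# a weight sharded on its logical out-channel dim (dim 1) over mesh axis 1
# becomes a shard on physical dim 0 after the swap
assert swap_first_two_dims({1: [1]}) == {0: [1]}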

@@ -82,8 +82,6 @@ class ModuleHandler(NodeHandler):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
print("created")
# set attributes to access module parameters for convenience
assert self.node.graph.owning_module is not None, \
f'The graph is not associated with a module, please make sure it can be used to instantiate a GraphModule object.'

@@ -1,7 +1,8 @@
from .strategy_generator import StrategyGenerator_V2
from .matmul_strategy_generator import DotProductStrategyGenerator, MatVecStrategyGenerator, LinearProjectionStrategyGenerator, BatchedMatMulStrategyGenerator
from .conv_strategy_generator import ConvStrategyGenerator
__all__ = [
'StrategyGenerator_V2', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator',
'LinearProjectionStrategyGenerator', 'BatchedMatMulStrategyGenerator'
'LinearProjectionStrategyGenerator', 'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator'
]

@@ -0,0 +1,491 @@
import operator
from functools import reduce
from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator_V2
from typing import List
from .._utils import exception_handler
import copy
class ConvStrategyGenerator(StrategyGenerator_V2):
"""
ConvStrategyGenerator is a generic class to generate sharding strategies for convolution.
The operation data is defined as `output = input x other + bias`.
"""
@property
def has_bias(self):
return 'bias' in self.op_data
def validate(self) -> bool:
'''
In the sanity check, we need to make sure the input data has the correct number of dimensions.
For Conv1d, the input should have 3 dims ([N, C, L]).
For Conv2d, the input should have 4 dims ([N, C, H, W]).
For Conv3d, the input should have 5 dims ([N, C, D, H, W]).
'''
input_op_data = self.op_data['input']
assert input_op_data.dim() in (3, 4, 5), 'the dim of the input fed into a conv op should be in the range [3, 5]'
def update_compute_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
'''
Compute the computation cost per device with this specific strategy.
Note: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.
'''
# TODO: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.
# 1D: (L) * N * Cout * Cin * kernel
# 2D: (H * W) * N * Cout * Cin * kernel
# 3D: (H * W * D) * N * Cout * Cin * kernel
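# e.g. for the unsharded 'RR = RR x RR' case with input [4, 4, 64, 64] and
# weight [16, 4, 3, 3] (the shapes used in the tests below), the 2D formula gives:
#   forward  = (64 * 64) * 4 * 4 * 16 * (3 * 3) = 9,437,184
#   backward = activation grad + weight grad = 2 * forward = 18,874,368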
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device()
sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
if self.has_bias:
# bias add is an element-wise operation, so its cost equals the product of the output shape.
bias_compute_cost = reduce(operator.mul, sharded_output_shape)
output_size = sharded_output_shape[2:]
output_size_product = reduce(operator.mul, output_size)
input_size = sharded_input_shape[2:]
input_size_product = reduce(operator.mul, input_size, 1)
kernel_size = sharded_other_shape[2:]
kernel_size_product = reduce(operator.mul, kernel_size, 1)
batch_size = sharded_input_shape[0]
channel_in = sharded_input_shape[1]
channel_out = sharded_other_shape[1]
forward_compute_cost = output_size_product * batch_size * channel_in * channel_out * kernel_size_product
backward_activation_cost = input_size_product * batch_size * channel_in * channel_out * kernel_size_product
backward_weight_cost = output_size_product * batch_size * channel_in * channel_out * kernel_size_product
backward_compute_cost = backward_weight_cost + backward_activation_cost
if self.has_bias:
forward_compute_cost += bias_compute_cost
backward_compute_cost += bias_compute_cost
total_compute_cost = forward_compute_cost + backward_compute_cost
compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
return compute_cost
def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'other': self._compute_size_in_bytes(strategy, "other"),
'output': self._compute_size_in_bytes(strategy, "output")
}
if self.has_bias:
bias_size = self._compute_size_in_bytes(strategy, "bias")
forward_size_mapping['bias'] = bias_size
backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
# compute fwd cost incurred
# fwd_cost = input + other + bias + output
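# e.g. for the unsharded module case (fp32 input [4, 4, 64, 64], output [4, 16, 64, 64],
# weight [16, 4, 3, 3], bias [16]), assuming _compute_size_in_bytes returns
# numel * element size: activation = 262,144 + 1,048,576 B, parameter = 2,304 + 64 B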
fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
# compute bwd cost incurred
# bwd_cost = input_grad + other_grad + bias_grad
bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + bwd_parameter_cost)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0]
},
"other": {
1: [mesh_dim_1]
},
"output": {
0: [mesh_dim_0],
1: [mesh_dim_1]
},
}
if self.has_bias:
dim_partition_dict_mapping["bias"] = {0: [mesh_dim_1]}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# set communication action
input_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1)
communication_action_mapping = {"input": input_comm_spec}
if self.is_param("other"):
other_comm_spec = self.get_communication_spec(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
communication_action_mapping["other"] = other_comm_spec
if self.has_bias and self.is_param("bias"):
bias_comm_spec = self.get_communication_spec(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
communication_action_mapping["bias"] = bias_comm_spec
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def split_input_batch(self, mesh_dim_0):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR'
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0]
},
"other": {},
"output": {
0: [mesh_dim_0],
},
}
if self.has_bias:
dim_partition_dict_mapping["bias"] = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
communication_action_mapping = {}
if self.is_param("other"):
other_comm_spec = self.get_communication_spec(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
communication_action_mapping["other"] = other_comm_spec
if self.has_bias and self.is_param("bias"):
bias_comm_spec = self.get_communication_spec(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
communication_action_mapping["bias"] = bias_comm_spec
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0],
1: [mesh_dim_1],
},
"other": {
0: [mesh_dim_1]
},
"output": {
0: [mesh_dim_0],
},
}
if self.has_bias:
dim_partition_dict_mapping["bias"] = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# set communication action
output_comm_spec = self.get_communication_spec(
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_1)
communication_action_mapping = {"output": output_comm_spec}
if self.is_param("other"):
other_comm_spec = self.get_communication_spec(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
communication_action_mapping["other"] = other_comm_spec
if self.has_bias and self.is_param("bias"):
bias_comm_spec = self.get_communication_spec(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
communication_action_mapping["bias"] = bias_comm_spec
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
dim_partition_dict_mapping = {
"input": {
1: [mesh_dim_0],
},
"other": {
0: [mesh_dim_0],
1: [mesh_dim_1],
},
"output": {
1: [mesh_dim_1],
},
}
if self.has_bias:
dim_partition_dict_mapping["bias"] = {
0: [mesh_dim_1],
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# set communication action
output_comm_spec = self.get_communication_spec(
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0)
input_comm_spec = self.get_communication_spec(
sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
communication_action_mapping = {"output": output_comm_spec, "input": input_comm_spec}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def split_input_in_channel_weight_in_channel(self, mesh_dim_0):
name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R'
dim_partition_dict_mapping = {
"input": {
1: [mesh_dim_0],
},
"other": {
0: [mesh_dim_0],
},
"output": {},
}
if self.has_bias:
dim_partition_dict_mapping["bias"] = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# set communication action
output_comm_spec = self.get_communication_spec(
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0)
communication_action_mapping = {"output": output_comm_spec}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def split_weight_out_channel(self, mesh_dim_0):
name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'
dim_partition_dict_mapping = {
"input": {},
"other": {
1: [mesh_dim_0],
},
"output": {
1: [mesh_dim_0],
},
}
if self.has_bias:
dim_partition_dict_mapping["bias"] = {
0: [mesh_dim_0],
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# set communication action
input_comm_spec = self.get_communication_spec(
sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
communication_action_mapping = {"input": input_comm_spec}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def non_split(self):
name = 'RR = RR x RR'
dim_partition_dict_mapping = {
"input": {},
"other": {},
"output": {},
}
if self.has_bias:
dim_partition_dict_mapping["bias"] = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0, mesh_dim_1],
},
"other": {},
"output": {
0: [mesh_dim_0, mesh_dim_1],
},
}
if self.has_bias:
dim_partition_dict_mapping["bias"] = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
communication_action_mapping = {}
if self.is_param("other"):
other_comm_spec = self.get_communication_spec(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
communication_action_mapping["other"] = other_comm_spec
if self.has_bias and self.is_param("bias"):
bias_comm_spec = self.get_communication_spec(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
communication_action_mapping["bias"] = bias_comm_spec
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1):
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
dim_partition_dict_mapping = {
"input": {
1: [mesh_dim_0, mesh_dim_1],
},
"other": {
0: [mesh_dim_0, mesh_dim_1],
},
"output": {},
}
if self.has_bias:
dim_partition_dict_mapping["bias"] = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# set communication action
output_comm_spec = self.get_communication_spec(
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
communication_action_mapping = {"output": output_comm_spec}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def split_1d_parallel_on_out_channel(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
dim_partition_dict_mapping = {
"input": {},
"other": {
1: [mesh_dim_0, mesh_dim_1],
},
"output": {
1: [mesh_dim_0, mesh_dim_1],
},
}
if self.has_bias:
dim_partition_dict_mapping["bias"] = {
0: [mesh_dim_0, mesh_dim_1],
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# set communication action
input_comm_spec = self.get_communication_spec(
sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
communication_action_mapping = {"input": input_comm_spec}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def generate(self) -> List[ShardingStrategy_V2]:
strategies = []
# SS = SR x RS
strategies.append(self.split_input_batch_weight_out_channel(0, 1))
strategies.append(self.split_input_batch_weight_out_channel(1, 0))
# SR = SR x RR
strategies.append(self.split_input_batch(0))
strategies.append(self.split_input_batch(1))
# SR = SS x SR
strategies.append(self.split_input_both_dim_weight_in_channel(0, 1))
strategies.append(self.split_input_both_dim_weight_in_channel(1, 0))
# RS = RS x SS
strategies.append(self.split_input_in_channel_weight_both_channel(0, 1))
strategies.append(self.split_input_in_channel_weight_both_channel(1, 0))
# RR = RS x SR
strategies.append(self.split_input_in_channel_weight_in_channel(0))
strategies.append(self.split_input_in_channel_weight_in_channel(1))
# RS = RR x RS
strategies.append(self.split_weight_out_channel(0))
strategies.append(self.split_weight_out_channel(1))
# RR= RR x RR
strategies.append(self.non_split())
# S01R = S01R x RR
strategies.append(self.split_1d_parallel_on_input_batch(0, 1))
# RR = RS01 x S01R
strategies.append(self.split_1d_parallel_on_in_channel(0, 1))
# RS01 = RR x RS01
strategies.append(self.split_1d_parallel_on_out_channel(0, 1))
# update meta info on cost
for strategy in strategies:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
return strategies
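To make the strategy names above concrete, here is one hedged reading against the 2 x 2 device mesh and tensor shapes used in the tests below (per-device shapes assume even splits):

# 'S0S1 = S0R x RS1' (split_input_batch_weight_out_channel) on a (2, 2) mesh,
# with input [4, 4, 64, 64] and physical weight [16, 4, 3, 3]:
#   input  S0R : batch dim sharded over mesh axis 0                  -> [2, 4, 64, 64] per device
#   weight RS1 : logical out-channel dim (dim 1) sharded over mesh
#                axis 1, i.e. physical dim 0 after post_process      -> [8, 4, 3, 3] per device
#   output S0S1: batch on mesh axis 0, out-channel on mesh axis 1    -> [2, 8, 64, 64] per device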

@@ -0,0 +1,210 @@
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.op_handler.conv_handler_v2 import ConvModuleHandler, ConvFunctionHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
def test_conv_module_handler():
model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1).to('meta'))
tracer = ColoTracer()
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {})
# return _0
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 4, 64, 64).to('meta')})
gm = ColoGraphModule(model, graph)
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
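# the 4 physical devices form a 2 x 2 logical mesh; mesh axes 0 and 1 are the
# axes referenced by S0 / S1 in the strategy names asserted below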
conv_mod_node = list(graph.nodes)[1]
strategies_vector = StrategiesVector(conv_mod_node)
# build handler
handler = ConvModuleHandler(node=conv_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
# check operation data mapping
mapping = handler.get_operation_data_mapping()
for name, op_data in mapping.items():
op_data: OperationData
# make sure they have valid values
assert op_data.logical_shape is not None
assert op_data.data is not None
assert mapping['input'].name == "input_1"
assert mapping['input'].data.is_meta
assert mapping['input'].data.shape == torch.Size([4, 4, 64, 64])
assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == torch.Size([4, 4, 64, 64])
assert mapping['other'].name == "weight"
assert mapping['other'].data.is_meta
assert mapping['other'].data.shape == torch.Size([16, 4, 3, 3])
assert mapping['other'].type == OperationDataType.PARAM
assert mapping['other'].logical_shape == torch.Size([4, 16, 3, 3])
assert mapping['bias'].name == "bias"
assert mapping['bias'].data.is_meta
assert mapping['bias'].data.shape == torch.Size([16])
assert mapping['bias'].type == OperationDataType.PARAM
assert mapping['bias'].logical_shape == torch.Size([16])
assert mapping['output'].name == "_0"
assert mapping['output'].data.is_meta
assert mapping['output'].data.shape == torch.Size([4, 16, 64, 64])
assert mapping['output'].type == OperationDataType.OUTPUT
strategies_vector = handler.register_strategy()
strategy_name_list = [val.name for val in strategies_vector]
# SS = SR x RS
assert 'S0S1 = S0R x RS1' in strategy_name_list
assert 'S1S0 = S1R x RS0' in strategy_name_list
# SR = SR x RR
assert 'S0R = S0R x RR' in strategy_name_list
assert 'S1R = S1R x RR' in strategy_name_list
# SR = SS x SR
assert 'S0R = S0S1 x S1R' in strategy_name_list
assert 'S1R = S1S0 x S0R' in strategy_name_list
# RS = RS x SS
assert 'RS0 = RS1 x S1S0' in strategy_name_list
assert 'RS1 = RS0 x S0S1' in strategy_name_list
# RR = RS x SR
assert 'RR = RS0 x S0R' in strategy_name_list
assert 'RR = RS1 x S1R' in strategy_name_list
# RS= RR x RS
assert 'RS0 = RR x RS0' in strategy_name_list
assert 'RS1 = RR x RS1' in strategy_name_list
# RR = RR x RR
assert 'RR = RR x RR' in strategy_name_list
# S01R = S01R x RR
assert 'S01R = S01R x RR' in strategy_name_list
# RR = RS01 x S01R
assert 'RR = RS01 x S01R' in strategy_name_list
# RS01 = RR x RS01
assert 'RS01 = RR x RS01' in strategy_name_list
class ConvModel(nn.Module):
def __init__(self):
super().__init__()
def forward(self, input, others, bias=None):
x = nn.functional.conv2d(input, others, bias=bias, padding=1)
return x
def test_conv_function_handler():
model = ConvModel()
tracer = ColoTracer()
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %others : torch.Tensor [#users=1] = placeholder[target=others]
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %others), kwargs = {})
# return conv2d
graph = tracer.trace(model,
meta_args={
"input": torch.rand(4, 4, 64, 64).to('meta'),
"others": torch.rand(16, 4, 3, 3).to('meta'),
"bias": torch.rand(16).to('meta')
})
gm = ColoGraphModule(model, graph)
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
conv_mod_node = list(graph.nodes)[3]
strategies_vector = StrategiesVector(conv_mod_node)
# build handler
handler = ConvFunctionHandler(node=conv_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
# check operation data mapping
mapping = handler.get_operation_data_mapping()
for name, op_data in mapping.items():
op_data: OperationData
# make sure they have valid values
assert op_data.logical_shape is not None
assert op_data.data is not None
assert mapping['input'].name == "input_1"
assert mapping['input'].data.is_meta
assert mapping['input'].data.shape == torch.Size([4, 4, 64, 64])
assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == torch.Size([4, 4, 64, 64])
assert mapping['other'].name == "others"
assert mapping['other'].data.is_meta
assert mapping['other'].data.shape == torch.Size([16, 4, 3, 3])
assert mapping['other'].type == OperationDataType.ARG
assert mapping['other'].logical_shape == torch.Size([4, 16, 3, 3])
assert mapping['bias'].name == "bias"
assert mapping['bias'].data.is_meta
assert mapping['bias'].data.shape == torch.Size([16])
assert mapping['bias'].type == OperationDataType.ARG
assert mapping['bias'].logical_shape == torch.Size([16])
assert mapping['output'].name == "conv2d"
assert mapping['output'].data.is_meta
assert mapping['output'].data.shape == torch.Size([4, 16, 64, 64])
assert mapping['output'].type == OperationDataType.OUTPUT
strategies_vector = handler.register_strategy()
strategy_name_list = [val.name for val in strategies_vector]
# SS = SR x RS
assert 'S0S1 = S0R x RS1' in strategy_name_list
assert 'S1S0 = S1R x RS0' in strategy_name_list
# SR = SR x RR
assert 'S0R = S0R x RR' in strategy_name_list
assert 'S1R = S1R x RR' in strategy_name_list
# SR = SS x SR
assert 'S0R = S0S1 x S1R' in strategy_name_list
assert 'S1R = S1S0 x S0R' in strategy_name_list
# RS = RS x SS
assert 'RS0 = RS1 x S1S0' in strategy_name_list
assert 'RS1 = RS0 x S0S1' in strategy_name_list
# RR = RS x SR
assert 'RR = RS0 x S0R' in strategy_name_list
assert 'RR = RS1 x S1R' in strategy_name_list
# RS= RR x RS
assert 'RS0 = RR x RS0' in strategy_name_list
assert 'RS1 = RR x RS1' in strategy_name_list
# RR = RR x RR
assert 'RR = RR x RR' in strategy_name_list
# S01R = S01R x RR
assert 'S01R = S01R x RR' in strategy_name_list
# RR = RS01 x S01R
assert 'RR = RS01 x S01R' in strategy_name_list
# RS01 = RR x RS01
assert 'RS01 = RR x RS01' in strategy_name_list
if __name__ == '__main__':
test_conv_module_handler()
test_conv_function_handler()

@@ -48,7 +48,7 @@ def test_linear_module_handler():
assert mapping['bias'].data.is_meta
assert mapping['bias'].data.shape == torch.Size([32])
assert mapping['bias'].type == OperationDataType.PARAM
assert mapping['other'].logical_shape == torch.Size([16, 32])
assert mapping['bias'].logical_shape == torch.Size([32])
assert mapping['output'].name == "_0"
assert mapping['output'].data.is_meta