diff --git a/colossalai/auto_parallel/solver/conv_handler.py b/colossalai/auto_parallel/solver/conv_handler.py
new file mode 100644
index 000000000..a6f3a682b
--- /dev/null
+++ b/colossalai/auto_parallel/solver/conv_handler.py
@@ -0,0 +1,384 @@
+import operator
+from functools import reduce
+import torch
+from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
+
+
+class ConvHandler:
+    '''
+    The ConvHandler is used to generate all possible sharding strategies for a Conv node.
+
+    Argument:
+        input_node(Node): the input node in the conv node argument list.
+        input_index(int): the index of the input node in the conv node argument list.
+        weight(torch.Tensor): weight of the conv node.
+        output_node(Node): the output node of the conv node.
+        device_mesh(DeviceMesh): a logical view of a physical mesh.
+        strategies_vector(StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
+        shape_consistency_manager(ShapeConsistencyManager): ShapeConsistencyManager will give the resharding costs of the different sharding specs.
+    '''
+
+    def __init__(self, input_node, input_index, weight, output_node, device_mesh, strategies_vector,
+                 shape_consistency_manager):
+        self.input_node = input_node
+        self.input_data = self.input_node._meta_data
+        self.weight = weight
+        self.input_index = input_index
+        self.output_node = output_node
+        self.output = self.output_node._meta_data
+        self.device_mesh = device_mesh
+        self.strategies_vector = strategies_vector
+        self.shape_consistency_manager = shape_consistency_manager
+        self._sanity_check()
+
+    def _sanity_check(self):
+        '''
+        The sanity check makes sure the input data has the correct number of dimensions.
+        For Conv1d, the input should have 3 dims ([N, C, L]).
+        For Conv2d, the input should have 4 dims ([N, C, H, W]).
+        For Conv3d, the input should have 5 dims ([N, C, D, H, W]).
+        '''
+        assert self.input_data.dim() in (
+            3, 4, 5), f'The input fed into a conv op should have 3, 4 or 5 dims, but got {self.input_data.dim()}.'
+
+    def _generate_sharding_spec_for_input(self, dim_partition_dict_for_input):
+        '''
+        Generate the sharding spec for the input node.
+        '''
+        entire_shape_for_input = self.input_data.shape
+        sharding_spec_for_input = ShardingSpec(device_mesh=self.device_mesh,
+                                               entire_shape=entire_shape_for_input,
+                                               dim_partition_dict=dim_partition_dict_for_input)
+        return sharding_spec_for_input
+
+    def _generate_sharding_spec_for_weight(self, dim_partition_dict_for_weight):
+        '''
+        Generate the sharding spec for the weight.
+        '''
+        entire_shape_for_weight = self.weight.shape
+        sharding_spec_for_weight = ShardingSpec(device_mesh=self.device_mesh,
+                                                entire_shape=entire_shape_for_weight,
+                                                dim_partition_dict=dim_partition_dict_for_weight)
+        return sharding_spec_for_weight
+
+    def _generate_sharding_spec_for_output(self, dim_partition_dict_for_output):
+        '''
+        Generate the sharding spec for the output node.
+        '''
+        entire_shape_for_output = self.output.shape
+        sharding_spec_for_output = ShardingSpec(device_mesh=self.device_mesh,
+                                                entire_shape=entire_shape_for_output,
+                                                dim_partition_dict=dim_partition_dict_for_output)
+        return sharding_spec_for_output
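+
+    # NOTE: a dim_partition_dict maps a sharded tensor dim to the device mesh dims it is
+    # sharded along. For example, {0: [0], 1: [1]} shards tensor dim 0 along mesh dim 0 and
+    # tensor dim 1 along mesh dim 1, while {} keeps the tensor fully replicated.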
+
+    def _generate_resharding_costs(self, resharding_costs, sharding_spec_for_input):
+        '''
+        Compute the resharding costs with this specific strategy.
+
+        Note: the resharding cost of the weight is NOT counted.
+
+        Argument:
+            resharding_costs(Dict[int, List[float]]): the resharding costs generated in this method will be appended into this dictionary.
+                resharding_costs[i][j] means the cost of transforming the i-th argument in the output node argument list
+                from its j-th strategy in its strategies_vector to the sharding spec wanted in this strategy.
+            sharding_spec_for_input(ShardingSpec): ShardingSpec of the input node.
+        '''
+        # TODO: the resharding cost of the weight is not counted yet; it becomes relevant
+        # when the same weight is shared by multiple nodes.
+        resharding_costs[self.input_index] = []
+        for strategy in self.input_node.strategies_vector.strategies:
+            _, _, resharding_cost = self.shape_consistency_manager.shape_consistency(strategy, sharding_spec_for_input)
+            resharding_costs[self.input_index].append(resharding_cost)
+
+    def _generate_compute_cost(self, bs, channel_in, channel_out):
+        '''
+        Compute the computation cost per device with this specific strategy.
+
+        Note: the compute cost should be divided by the device TFLOPS to get a time estimate; for now it only reflects the computation size.
+
+        Argument:
+            bs(int): batch size of the input data.
+            channel_in(int): the channel dimension of the input data.
+            channel_out(int): the output channel of the conv weight.
+
+        Return:
+            compute_cost(float): computation cost per device with this specific strategy
+        '''
+        # TODO: divide the compute cost by the device TFLOPS; for now it only reflects the computation size.
+        # 1D: (L) * N * Cout * Cin * kernel
+        # 2D: (H * W) * N * Cout * Cin * kernel
+        # 3D: (H * W * D) * N * Cout * Cin * kernel
+        output_size = self.output.shape[2:]
+        output_size_product = reduce(operator.mul, output_size, 1)
+        kernel_size = self.weight.shape[2:]
+        kernel_size_product = reduce(operator.mul, kernel_size, 1)
+        compute_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
+        return compute_cost
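+
+    # A worked example of the formula above, using the shapes from the docstring of
+    # register_strategy_into_strategies_vector: Conv2d(16, 32, kernel_size=3) on a
+    # (4, 16, 64, 64) input gives a 62 x 62 output, so the non-split compute cost is
+    # (62 * 62) * 4 * 16 * 16 * (3 * 3) = 35426304, where channel_out is read from
+    # weight.shape[1] following this handler's weight layout convention.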
+
+    def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
+        name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
+
+        dim_partition_dict_for_input = {0: [mesh_dim_0]}
+        sharding_spec_for_input = self._generate_sharding_spec_for_input(dim_partition_dict_for_input)
+
+        # NOTE: this handler treats dim 1 of the weight as its output channel dim.
+        dim_partition_dict_for_weight = {1: [mesh_dim_1]}
+        sharding_spec_for_weight = self._generate_sharding_spec_for_weight(dim_partition_dict_for_weight)
+
+        dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
+        sharding_spec_for_output = self._generate_sharding_spec_for_output(dim_partition_dict_for_output)
+
+        # generate the resharding costs for this strategy
+        resharding_costs = {}
+        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+
+        # compute the computation cost of this strategy
+        bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
+        channel_in = self.input_data.shape[1]
+        channel_out = self.weight.shape[1] // self.device_mesh.shape[mesh_dim_1]
+        compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
+
+        # compute the memory cost of this strategy
+        dtype = self.input_data.dtype
+        numel = self.output.numel()
+        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
+        sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
+        memory_cost = numel * size_per_elem_bytes / sharding_size
+
+        # This strategy does not need an all_reduce operation
+        communication_cost = 0
+        sharding_strategies = ShardingStrategy(name,
+                                               output_sharding_spec=sharding_spec_for_output,
+                                               compute_cost=compute_cost,
+                                               communication_cost=communication_cost,
+                                               memory_cost=memory_cost,
+                                               resharding_costs=resharding_costs,
+                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
+        self.strategies_vector.strategies.append(sharding_strategies)
+
+    def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
+        name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
+
+        dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
+        sharding_spec_for_input = self._generate_sharding_spec_for_input(dim_partition_dict_for_input)
+
+        # the in-channel dim of the weight follows the in-channel shard of the input
+        dim_partition_dict_for_weight = {0: [mesh_dim_1]}
+        sharding_spec_for_weight = self._generate_sharding_spec_for_weight(dim_partition_dict_for_weight)
+
+        dim_partition_dict_for_output = {0: [mesh_dim_0]}
+        sharding_spec_for_output = self._generate_sharding_spec_for_output(dim_partition_dict_for_output)
+
+        # generate the resharding costs for this strategy
+        resharding_costs = {}
+        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+
+        # compute the computation cost of this strategy
+        bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
+        channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_1]
+        channel_out = self.weight.shape[1]
+        compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
+
+        # compute the memory cost of this strategy
+        dtype = self.input_data.dtype
+        numel = self.output.numel()
+        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
+        sharding_size = self.device_mesh.shape[mesh_dim_0]
+        memory_cost = numel * size_per_elem_bytes / sharding_size
+
+        # the in-channel dim is sharded along mesh_dim_1, so the partial results are all_reduced over it
+        communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1)
+        sharding_strategies = ShardingStrategy(name,
+                                               output_sharding_spec=sharding_spec_for_output,
+                                               compute_cost=compute_cost,
+                                               communication_cost=communication_cost,
+                                               memory_cost=memory_cost,
+                                               resharding_costs=resharding_costs,
+                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
+        self.strategies_vector.strategies.append(sharding_strategies)
+
+    def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1):
+        name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
+
+        dim_partition_dict_for_input = {1: [mesh_dim_0]}
+        sharding_spec_for_input = self._generate_sharding_spec_for_input(dim_partition_dict_for_input)
+
+        dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
+        sharding_spec_for_weight = self._generate_sharding_spec_for_weight(dim_partition_dict_for_weight)
+
+        dim_partition_dict_for_output = {1: [mesh_dim_1]}
+        sharding_spec_for_output = self._generate_sharding_spec_for_output(dim_partition_dict_for_output)
+
+        # generate the resharding costs for this strategy
+        resharding_costs = {}
+        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+
+        # compute the computation cost of this strategy
+        bs = self.input_data.shape[0]
+        channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_0]
+        channel_out = self.weight.shape[1] // self.device_mesh.shape[mesh_dim_1]
+        compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
+
+        # compute the memory cost of this strategy: the output is sharded along mesh_dim_1 only
+        dtype = self.input_data.dtype
+        numel = self.output.numel()
+        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
+        sharding_size = self.device_mesh.shape[mesh_dim_1]
+        memory_cost = numel * size_per_elem_bytes / sharding_size
+
+        # the in-channel dim is sharded along mesh_dim_0, so the partial results are all_reduced over it
+        communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_0)
+        sharding_strategies = ShardingStrategy(name,
+                                               output_sharding_spec=sharding_spec_for_output,
+                                               compute_cost=compute_cost,
+                                               communication_cost=communication_cost,
+                                               memory_cost=memory_cost,
+                                               resharding_costs=resharding_costs,
+                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
+        self.strategies_vector.strategies.append(sharding_strategies)
+
+    def split_weight_out_channel(self, mesh_dim_0):
+        name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'
+
+        dim_partition_dict_for_input = {}
+        sharding_spec_for_input = self._generate_sharding_spec_for_input(dim_partition_dict_for_input)
+
+        dim_partition_dict_for_weight = {1: [mesh_dim_0]}
+        sharding_spec_for_weight = self._generate_sharding_spec_for_weight(dim_partition_dict_for_weight)
+
+        dim_partition_dict_for_output = {1: [mesh_dim_0]}
+        sharding_spec_for_output = self._generate_sharding_spec_for_output(dim_partition_dict_for_output)
+
+        # generate the resharding costs for this strategy
+        resharding_costs = {}
+        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+
+        # compute the computation cost of this strategy
+        bs = self.input_data.shape[0]
+        channel_in = self.input_data.shape[1]
+        channel_out = self.weight.shape[1] // self.device_mesh.shape[mesh_dim_0]
+        compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
+
+        # compute the memory cost of this strategy
+        dtype = self.input_data.dtype
+        numel = self.output.numel()
+        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
+        sharding_size = self.device_mesh.shape[mesh_dim_0]
+        memory_cost = numel * size_per_elem_bytes / sharding_size
+
+        # This strategy does not need an all_reduce operation
+        communication_cost = 0
+
+        sharding_strategies = ShardingStrategy(name,
+                                               output_sharding_spec=sharding_spec_for_output,
+                                               compute_cost=compute_cost,
+                                               communication_cost=communication_cost,
+                                               memory_cost=memory_cost,
+                                               resharding_costs=resharding_costs,
+                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
+        self.strategies_vector.strategies.append(sharding_strategies)
+
+    def non_split(self):
+        name = 'RR = RR x RR'
+
+        dim_partition_dict_for_input = {}
+        sharding_spec_for_input = self._generate_sharding_spec_for_input(dim_partition_dict_for_input)
+
+        dim_partition_dict_for_weight = {}
+        sharding_spec_for_weight = self._generate_sharding_spec_for_weight(dim_partition_dict_for_weight)
+
+        dim_partition_dict_for_output = {}
+        sharding_spec_for_output = self._generate_sharding_spec_for_output(dim_partition_dict_for_output)
+
+        # generate the resharding costs for this strategy
+        resharding_costs = {}
+        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+
+        # compute the computation cost of this strategy
+        bs = self.input_data.shape[0]
+        channel_in = self.input_data.shape[1]
+        channel_out = self.weight.shape[1]
+        compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
+
+        # compute the memory cost of this strategy
+        dtype = self.input_data.dtype
+        numel = self.output.numel()
+        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
+        memory_cost = numel * size_per_elem_bytes
+
+        # This strategy does not need an all_reduce operation
+        communication_cost = 0
+
+        sharding_strategies = ShardingStrategy(name,
+                                               output_sharding_spec=sharding_spec_for_output,
+                                               compute_cost=compute_cost,
+                                               communication_cost=communication_cost,
+                                               memory_cost=memory_cost,
+                                               resharding_costs=resharding_costs,
+                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
+        self.strategies_vector.strategies.append(sharding_strategies)
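+
+    # NOTE: the strategies above share one cost model: the memory cost is the output numel
+    # times the element size divided by the number of shards of the output, and the
+    # communication cost is an all_reduce over the mesh dim along which the in-channel
+    # (contracted) dim is sharded, or zero if the in-channel dim is replicated.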
+
+    def register_strategy_into_strategies_vector(self):
+        '''
+        Generate all possible sharding strategies for the Conv node and record them into the strategies_vector.
+
+        Example:
+            physical_mesh_id = torch.arange(0, 4)
+            mesh_shape = (2, 2)
+            # [[0, 1]
+            #  [2, 3]]
+            device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+            shape_consistency_manager = ShapeConsistencyManager()
+
+            tracer = ColoTracer()
+            model = ConvModel(16, 32)
+            input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
+            # graph():
+            #     %x : torch.Tensor [#users=1] = placeholder[target=x]
+            #     %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
+            #     %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {})
+            #     return conv
+            graph = tracer.trace(root=model, meta_args=input_sample)
+            gm = GraphModule(model, graph, model.__class__.__name__)
+            gm.recompile()
+            # [x, mul, conv, output]
+            nodes = [node for node in gm.graph.nodes]
+
+            # strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
+            strategies_vector_for_input = StrategiesVector(node=nodes[1], in_nodes=[nodes[0], 2], strategies=strategies_for_input)
+            setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
+
+            strategies_vector = StrategiesVector(node=nodes[2], in_nodes=[nodes[1], ])
+            conv_handler = ConvHandler(input_node=nodes[1], input_index=0, weight=dict(gm.named_modules())[nodes[2].name].weight,
+                                       output_node=nodes[2], device_mesh=device_mesh, strategies_vector=strategies_vector,
+                                       shape_consistency_manager=shape_consistency_manager)
+            conv_handler.register_strategy_into_strategies_vector()
+            for strategy in conv_handler.strategies_vector.strategies:
+                print(f'{strategy.name}: compute_cost is {strategy.compute_cost}, communication_cost is {strategy.communication_cost}, memory_cost is {strategy.memory_cost}, resharding_costs is {strategy.resharding_costs}')
+
+        Output:
+            S0S1 = S0R x RS1: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {0: [0, 32769.001, 131074.2, 0, 32769.1, 131074.2, 98307.201]}
+            S1S0 = S1R x RS0: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {0: [0, 131074.2, 32769.001, 131074.2, 98307.201, 0, 32769.1]}
+            S0R = S0S1 x S1R: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {0: [0, 65538.002, 0, 0, 0, 65538.002, 196614.402]}
+            S1R = S1S0 x S0R: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {0: [0, 0, 65538.002, 65538.002, 196614.402, 0, 0]}
+            RS1 = RS0 x S0S1: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {0: [0, 0, 131074.2, 32769.001, 98307.201, 131074.2, 32769.1]}
+            RS0 = RS1 x S1S0: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {0: [0, 131074.2, 0, 131074.2, 32769.1, 32769.001, 98307.201]}
+            RS0 = RR x RS0: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {0: [0, 65537.1, 65537.1, 65537.1, 131075.30000000002, 65537.1, 131075.30000000002]}
+            RS1 = RR x RS1: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {0: [0, 65537.1, 65537.1, 65537.1, 131075.30000000002, 65537.1, 131075.30000000002]}
+            RR = RR x RR: compute_cost is 35426304, communication_cost is 0, memory_cost is 1968128, resharding_costs is {0: [0, 65537.1, 65537.1, 65537.1, 131075.30000000002, 65537.1, 131075.30000000002]}
+        '''
+        # SS = SR x RS
+        self.split_input_batch_weight_out_channel(0, 1)
+        self.split_input_batch_weight_out_channel(1, 0)
+
+        # SR = SS x SR
+        self.split_input_both_dim_weight_in_channel(0, 1)
+        self.split_input_both_dim_weight_in_channel(1, 0)
+
+        # RS = RS x SS
+        self.split_input_in_channel_weight_both_channel(0, 1)
+        self.split_input_in_channel_weight_both_channel(1, 0)
+
+        # RS = RR x RS
+        self.split_weight_out_channel(0)
+        self.split_weight_out_channel(1)
+
+        # RR = RR x RR
+        self.non_split()
diff --git a/colossalai/auto_parallel/solver/sharding_strategy.py b/colossalai/auto_parallel/solver/sharding_strategy.py
new file mode 100644
index 000000000..8f227076d
--- /dev/null
+++ b/colossalai/auto_parallel/solver/sharding_strategy.py
@@ -0,0 +1,54 @@
+class ShardingStrategy:
+    '''
+    ShardingStrategy is a structure containing the sharding specs of the inputs and the output of a node,
+    together with the cost information used by the solver.
+
+    Argument:
+        name(str): the sharding strategy expressed as a string, such as 'S0S1 = S0R x RS1'.
+        output_sharding_spec(ShardingSpec): ShardingSpec of the output node.
+        compute_cost(float): computation cost to complete this strategy. (default to 0)
+        communication_cost(float): communication cost to complete this strategy. (default to 0)
+        memory_cost(float): memory cost of the output node using this strategy. (default to 0)
+        resharding_costs(Dict[int, List[float]]): resharding_costs[i][j] means the cost of transforming the i-th argument
+            in the output node argument list from its j-th strategy in its strategies_vector to the sharding spec
+            wanted in this strategy. (default to None)
+        input_shardings(List(ShardingSpec)): the ShardingSpecs of the input nodes.
+    '''
+
+    def __init__(self,
+                 name,
+                 output_sharding_spec,
+                 compute_cost=0,
+                 communication_cost=0,
+                 memory_cost=0,
+                 resharding_costs=None,
+                 input_shardings=None):
+        self.name = name
+        self.output_sharding_spec = output_sharding_spec
+        self.compute_cost = compute_cost
+        self.communication_cost = communication_cost
+        self.memory_cost = memory_cost
+        self.resharding_costs = resharding_costs
+        self.input_shardings = input_shardings
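+
+# A minimal sketch of how a strategy is assembled (mirroring ConvHandler above); the
+# sharding specs are assumed to be built elsewhere with ShardingSpec:
+#     strategy = ShardingStrategy('S0R = S0S1 x S1R',
+#                                 output_sharding_spec=output_spec,
+#                                 compute_cost=8856576,
+#                                 communication_cost=984065.01,
+#                                 memory_cost=984064.0,
+#                                 resharding_costs={0: [...]},
+#                                 input_shardings=(input_spec, weight_spec))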
+
+
+class StrategiesVector:
+    '''
+    Each node in an fx graph will have a corresponding StrategiesVector, which stores all the possible
+    sharding strategies of the node.
+
+    Argument:
+        node(Node): node to build the corresponding strategies_vector for.
+        in_nodes(List[Node]): input nodes in the argument list of the node.
+        following_nodes(List[Node]): the nodes that take the target node as their argument.
+        strategies(List[ShardingStrategy]): all the possible sharding strategies of the node.
+    '''
+
+    def __init__(self, node, in_nodes, following_nodes=None, strategies=None):
+        self.node = node
+        self.in_nodes = in_nodes
+        self.following_nodes = following_nodes
+        # avoid a mutable default argument: each vector needs its own strategy list
+        self.strategies = strategies if strategies is not None else []
+
+    def check_merge(self):
+        pass
diff --git a/colossalai/tensor/sharding_spec.py b/colossalai/tensor/sharding_spec.py
index 15643b5a2..50083ca5b 100644
--- a/colossalai/tensor/sharding_spec.py
+++ b/colossalai/tensor/sharding_spec.py
@@ -199,7 +199,7 @@ class ShardingSpec:
             if not dim_spec.is_replica:
                 if index not in new_dim_partition_dict:
                     new_dim_partition_dict[index] = []
-                new_dim_partition_dict[index].append(dim_spec.shard_list)
+                new_dim_partition_dict[index].extend(dim_spec.shard_list)
         self.dim_partition_dict = new_dim_partition_dict
 
     def sharding_sequence_difference(self, other):
diff --git a/tests/test_auto_parallel/test_conv_handler.py b/tests/test_auto_parallel/test_conv_handler.py
new file mode 100644
index 000000000..33a4a6a60
--- /dev/null
+++ b/tests/test_auto_parallel/test_conv_handler.py
@@ -0,0 +1,115 @@
+import torch
+from torch.fx import GraphModule
+import torch.nn as nn
+import pytest
+
+from colossalai.fx.proxy import ColoProxy
+from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
+from colossalai.auto_parallel.solver.conv_handler import ConvHandler
+from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager
+from colossalai.device.device_mesh import DeviceMesh
+
+
+class ConvModel(nn.Module):
+
+    def __init__(self, c_in, c_out):
+        super().__init__()
+        self.conv = nn.Conv2d(c_in, c_out, kernel_size=3)
+
+    def forward(self, x):
+        x = x * 2
+        x = self.conv(x)
+        return x
+
+
+def test_conv_handler():
+    physical_mesh_id = torch.arange(0, 4)
+    mesh_shape = (2, 2)
+    # [[0, 1]
+    #  [2, 3]]
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+    entire_shape = torch.Size((4, 16, 64, 64))
+    shape_consistency_manager = ShapeConsistencyManager()
+
+    tracer = ColoTracer()
+    model = ConvModel(16, 32)
+    input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
+    # graph():
+    #     %x : torch.Tensor [#users=1] = placeholder[target=x]
+    #     %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
+    #     %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {})
+    #     return conv
+    graph = tracer.trace(root=model, meta_args=input_sample)
+    gm = GraphModule(model, graph, model.__class__.__name__)
+    gm.recompile()
+    # [x, mul, conv, output]
+    nodes = [node for node in gm.graph.nodes]
+
+    # enumerate the input sharding specs: shard the first two dims along distinct mesh dims or keep them replicated
+    strategies_for_input = []
+    sharding_option = (None, 0, 1)
+    for first_sharding_index in sharding_option:
+        for second_sharding_index in sharding_option:
+            if first_sharding_index is not None and second_sharding_index == first_sharding_index:
+                continue
+            if first_sharding_index is None:
+                first_dim_spec = _DimSpec([])
+            else:
+                first_dim_spec = _DimSpec([first_sharding_index])
+
+            if second_sharding_index is None:
+                second_dim_spec = _DimSpec([])
+            else:
+                second_dim_spec = _DimSpec([second_sharding_index])
+
+            replica_dim_spec = _DimSpec([])
+            sharding_sequence = [first_dim_spec, second_dim_spec, replica_dim_spec, replica_dim_spec]
+            sharding_spec = ShardingSpec(device_mesh=device_mesh,
+                                         entire_shape=entire_shape,
+                                         sharding_sequence=sharding_sequence)
+            strategies_for_input.append(sharding_spec)
+
+    # strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
+    strategies_vector_for_input = StrategiesVector(node=nodes[1],
+                                                   in_nodes=[nodes[0], 2],
+                                                   strategies=strategies_for_input)
+    setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
+
+    strategies_vector = StrategiesVector(node=nodes[2], in_nodes=[
+        nodes[1],
+    ])
+    conv_handler = ConvHandler(input_node=nodes[1],
+                               input_index=0,
+                               weight=dict(gm.named_modules())[nodes[2].name].weight,
+                               output_node=nodes[2],
+                               device_mesh=device_mesh,
+                               strategies_vector=strategies_vector,
+                               shape_consistency_manager=shape_consistency_manager)
+    conv_handler.register_strategy_into_strategies_vector()
+
+    # ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR']
+    strategy_name_list = [strategy.name for strategy in conv_handler.strategies_vector.strategies]
+
+    # SS = SR x RS
+    assert 'S0S1 = S0R x RS1' in strategy_name_list
+    assert 'S1S0 = S1R x RS0' in strategy_name_list
+
+    # SR = SS x SR
+    assert 'S0R = S0S1 x S1R' in strategy_name_list
+    assert 'S1R = S1S0 x S0R' in strategy_name_list
+
+    # RS = RS x SS
+    assert 'RS0 = RS1 x S1S0' in strategy_name_list
+    assert 'RS1 = RS0 x S0S1' in strategy_name_list
+
+    # RS = RR x RS
+    assert 'RS0 = RR x RS0' in strategy_name_list
+    assert 'RS1 = RR x RS1' in strategy_name_list
+
+    # RR = RR x RR
+    assert 'RR = RR x RR' in strategy_name_list
+
+
+if __name__ == '__main__':
+    test_conv_handler()
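As a quick sanity check, the following standalone sketch recomputes the compute and memory costs quoted in the `register_strategy_into_strategies_vector` docstring for `Conv2d(16, 32, kernel_size=3)` on a `(4, 16, 64, 64)` input with a 2x2 mesh. It assumes the handler's convention of reading the "out channel" from `weight.shape[1]`; it is an illustration, not part of the diff above.

```python
import torch

# Shapes from the docstring example: Conv2d(16, 32, kernel_size=3) on a (4, 16, 64, 64) input.
input_shape = (4, 16, 64, 64)
weight_shape = (32, 16, 3, 3)      # nn.Conv2d weight layout: [C_out, C_in, kH, kW]
output_shape = (4, 32, 62, 62)     # 64 - 3 + 1 = 62 spatial size (stride 1, no padding)

bs, channel_in = input_shape[0], input_shape[1]
channel_out = weight_shape[1]      # the handler reads the "out channel" from dim 1
output_size_product = 62 * 62
kernel_size_product = 3 * 3

# Non-split (RR = RR x RR) compute cost, matching the docstring output.
compute_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
assert compute_cost == 35426304

# Memory cost: output numel * bytes per element, divided by the number of shards.
numel = 4 * 32 * 62 * 62
size_per_elem_bytes = torch.tensor([], dtype=torch.float32).element_size()
assert numel * size_per_elem_bytes == 1968128        # RR: fully replicated output
assert numel * size_per_elem_bytes / 4 == 492032.0   # S0S1: sharded over all 4 devices
```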