mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] Add conv handler to generate strategies and costs info for conv (#1467)
parent
1b491ad7de
commit
26a37b5cd5
|
@ -0,0 +1,384 @@
|
|||
import operator
|
||||
from functools import reduce
|
||||
import torch
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
|
||||
|
||||
class ConvHandler:
|
||||
'''
|
||||
The ConvHandler is used to generate every possible strategies for a Conv node.
|
||||
|
||||
Argument:
|
||||
input_node(Node): the input node in conv node argument list.
|
||||
input_index(int): the index of input node in the conv node argument list.
|
||||
weight(torch.Tensor): Weight of the conv node.
|
||||
output_node(Node): Output_node is the output of the conv node.
|
||||
device_mesh(DeviceMesh): A logical view of a physical mesh.
|
||||
strategies_vector(StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
|
||||
shape_consistency_manager(ShapeConsistencyManager): ShapeConsistencyManager will give the resharding costs of the different sharding specs.
|
||||
'''
|
||||
|
||||
def __init__(self, input_node, input_index, weight, output_node, device_mesh, strategies_vector,
|
||||
shape_consistency_manager):
|
||||
self.input_node = input_node
|
||||
self.input_data = self.input_node._meta_data
|
||||
self.weight = weight
|
||||
self.input_index = input_index
|
||||
self.output_node = output_node
|
||||
self.output = self.output_node._meta_data
|
||||
self.device_mesh = device_mesh
|
||||
self.strategies_vector = strategies_vector
|
||||
self.shape_consistency_manager = shape_consistency_manager
|
||||
self._sanity_check()
|
||||
|
||||
def _sanity_check(self):
|
||||
'''
|
||||
In sanity check, we need make sure the input data having correct dimension size.
|
||||
For Conv1d, the dim of input data should be 3([N, C, L]).
|
||||
For Conv2d, the dim of input data should be 4([N, C, H, W]).
|
||||
For Conv3d, the dim of input data should be 5([N, C, H, W, D]).
|
||||
'''
|
||||
assert self.input_data.dim() in (3, 4,
|
||||
5), f'We suppose the dim of input fed into conv op should in range of [3, 5].'
|
||||
|
||||
def _generate_sharding_spec_for_input(self, dim_partition_dict_for_input):
|
||||
'''
|
||||
Generate sharding spec for the input node.
|
||||
'''
|
||||
entire_shape_for_input = self.input_data.shape
|
||||
sharding_spec_for_input = ShardingSpec(device_mesh=self.device_mesh,
|
||||
entire_shape=entire_shape_for_input,
|
||||
dim_partition_dict=dim_partition_dict_for_input)
|
||||
return sharding_spec_for_input
|
||||
|
||||
def _generate_sharding_spec_for_weight(self, dim_partition_dict_for_weight):
|
||||
'''
|
||||
Generate sharding spec for the weight.
|
||||
'''
|
||||
entire_shape_for_weight = self.weight.shape
|
||||
sharding_spec_for_weight = ShardingSpec(device_mesh=self.device_mesh,
|
||||
entire_shape=entire_shape_for_weight,
|
||||
dim_partition_dict=dim_partition_dict_for_weight)
|
||||
return sharding_spec_for_weight
|
||||
|
||||
def _generate_sharding_spec_for_output(self, dim_partition_dict_for_output):
|
||||
'''
|
||||
Generate sharding spec for the output node.
|
||||
'''
|
||||
entire_shape_for_output = self.output.shape
|
||||
sharding_spec_for_output = ShardingSpec(device_mesh=self.device_mesh,
|
||||
entire_shape=entire_shape_for_output,
|
||||
dim_partition_dict=dim_partition_dict_for_output)
|
||||
return sharding_spec_for_output
|
||||
|
||||
def _generate_resharding_costs(self, resharding_costs, sharding_spec_for_input):
|
||||
'''
|
||||
Compute the resharding costs with this specific strategy.
|
||||
|
||||
Note: The resharding_cost of weight is NOT counted.
|
||||
|
||||
Argument:
|
||||
resharding_costs(Dict[int, List[float]]): The resharding cost generated in this method will be appended into this dictionary.
|
||||
Resharding_cost[i][j] means the cost of i-th argument in the output node argument list
|
||||
with j-th strategy in its strategies_vector transforms to sharding spec wanted in this
|
||||
strategy.
|
||||
sharding_spec_for_input(ShardingSpec): ShardingSpec of the input node.
|
||||
'''
|
||||
# The resharding_cost of weight is counted due to sharing weight cases.
|
||||
resharding_costs[self.input_index] = []
|
||||
for stategy in self.input_node.strategies_vector.strategies:
|
||||
_, _, resharding_cost = self.shape_consistency_manager.shape_consistency(stategy, sharding_spec_for_input)
|
||||
resharding_costs[self.input_index].append(resharding_cost)
|
||||
|
||||
def _generate_compute_cost(self, bs, channel_in, channel_out):
|
||||
'''
|
||||
Compute the computation cost per device with this specific strategy.
|
||||
|
||||
Note: compute_cost need to be devided by TFLOPS, now it just shows the computation size.
|
||||
|
||||
Argument:
|
||||
bs(int): Batch size of the input data.
|
||||
channel_in(int): The channel dimension of input data.
|
||||
channel_out(int): The out channel of the conv weight.
|
||||
|
||||
Return:
|
||||
compute_cost(float): Computation cost per device with this specific strategy
|
||||
'''
|
||||
# TODO: compute_cost need to be devided by TFLOPS, now it just shows the computation size.
|
||||
# 1D: (L) * N * Cout * Cin * kernel
|
||||
# 2D: (H * W) * N * Cout * Cin * kernel
|
||||
# 3D: (H * W * D) * N * Cout * Cin * kernel
|
||||
output_size = self.output.shape[2:]
|
||||
output_size_product = reduce(operator.mul, output_size, 1)
|
||||
kernel_size = self.weight.shape[2:]
|
||||
kernel_size_product = reduce(operator.mul, kernel_size, 1)
|
||||
compute_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
|
||||
return compute_cost
|
||||
|
||||
def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
|
||||
|
||||
dim_partition_dict_for_input = {0: [mesh_dim_0]}
|
||||
sharding_spec_for_input = self._generate_sharding_spec_for_input(dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {1: [mesh_dim_1]}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec_for_weight(dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
|
||||
sharding_spec_for_ouput = self._generate_sharding_spec_for_output(dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = {}
|
||||
self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
|
||||
channel_in = self.input_data.shape[1]
|
||||
channel_out = self.weight.shape[1] // self.device_mesh.shape[mesh_dim_1]
|
||||
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
dtype = self.input_data.dtype
|
||||
numel = self.output.numel()
|
||||
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||
sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
memory_cost = numel * size_per_elem_bytes / sharding_size
|
||||
|
||||
# This strategy do not need to do all_reduce operation
|
||||
communication_cost = 0
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_ouput,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.strategies.append(sharding_strategies)
|
||||
|
||||
def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
|
||||
|
||||
dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
|
||||
sharding_spec_for_input = self._generate_sharding_spec_for_input(dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {0: [mesh_dim_0]}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec_for_weight(dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {0: [mesh_dim_0]}
|
||||
sharding_spec_for_ouput = self._generate_sharding_spec_for_output(dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = {}
|
||||
self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
|
||||
channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_1]
|
||||
channel_out = self.weight.shape[1]
|
||||
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
dtype = self.input_data.dtype
|
||||
numel = self.output.numel()
|
||||
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||
sharding_size = self.device_mesh.shape[mesh_dim_0]
|
||||
memory_cost = numel * size_per_elem_bytes / sharding_size
|
||||
|
||||
# compute the communication cost of this strategy
|
||||
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1)
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_ouput,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.strategies.append(sharding_strategies)
|
||||
|
||||
def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
|
||||
|
||||
dim_partition_dict_for_input = {1: [mesh_dim_0]}
|
||||
sharding_spec_for_input = self._generate_sharding_spec_for_input(dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec_for_weight(dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {1: [mesh_dim_1]}
|
||||
sharding_spec_for_ouput = self._generate_sharding_spec_for_output(dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = {}
|
||||
self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
bs = self.input_data.shape[0]
|
||||
channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_0]
|
||||
channel_out = self.weight.shape[1] // self.device_mesh.shape[mesh_dim_1]
|
||||
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
dtype = self.input_data.dtype
|
||||
numel = self.output.numel()
|
||||
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||
sharding_size = self.device_mesh.shape[mesh_dim_0]
|
||||
memory_cost = numel * size_per_elem_bytes / sharding_size
|
||||
|
||||
# compute the communication cost of this strategy
|
||||
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1)
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_ouput,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.strategies.append(sharding_strategies)
|
||||
|
||||
def split_weight_out_channel(self, mesh_dim_0):
|
||||
name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'
|
||||
|
||||
dim_partition_dict_for_input = {}
|
||||
sharding_spec_for_input = self._generate_sharding_spec_for_input(dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {1: [mesh_dim_0]}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec_for_weight(dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {1: [mesh_dim_0]}
|
||||
sharding_spec_for_ouput = self._generate_sharding_spec_for_output(dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = {}
|
||||
self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
bs = self.input_data.shape[0]
|
||||
channel_in = self.input_data.shape[1]
|
||||
channel_out = self.weight.shape[1] // self.device_mesh.shape[mesh_dim_0]
|
||||
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
dtype = self.input_data.dtype
|
||||
numel = self.output.numel()
|
||||
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||
sharding_size = self.device_mesh.shape[mesh_dim_0]
|
||||
memory_cost = numel * size_per_elem_bytes / sharding_size
|
||||
|
||||
# This strategy do not need to do all_reduce operation
|
||||
communication_cost = 0
|
||||
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_ouput,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.strategies.append(sharding_strategies)
|
||||
|
||||
def non_split(self):
|
||||
name = f'RR = RR x RR'
|
||||
|
||||
dim_partition_dict_for_input = {}
|
||||
sharding_spec_for_input = self._generate_sharding_spec_for_input(dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec_for_weight(dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {}
|
||||
sharding_spec_for_ouput = self._generate_sharding_spec_for_output(dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = {}
|
||||
self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
bs = self.input_data.shape[0]
|
||||
channel_in = self.input_data.shape[1]
|
||||
channel_out = self.weight.shape[1]
|
||||
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
dtype = self.input_data.dtype
|
||||
numel = self.output.numel()
|
||||
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||
memory_cost = numel * size_per_elem_bytes
|
||||
|
||||
# This strategy do not need to do all_reduce operation
|
||||
communication_cost = 0
|
||||
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_ouput,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.strategies.append(sharding_strategies)
|
||||
|
||||
def register_strategy_into_strategies_vector(self):
|
||||
'''
|
||||
Generate every possible strategies for a Conv node, and record all strategies into the strategies_vector.
|
||||
|
||||
Example:
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
# [[0, 1]
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
|
||||
model = ConvModel(16, 32)
|
||||
input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
|
||||
# graph():
|
||||
# %x : torch.Tensor [#users=1] = placeholder[target=x]
|
||||
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
|
||||
# %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {})
|
||||
# return conv
|
||||
graph = tracer.trace(root=model, meta_args=input_sample)
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
gm.recompile()
|
||||
# [x, mul, conv, output]
|
||||
nodes = [node for node in gm.graph.nodes]
|
||||
|
||||
# strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
|
||||
strategies_vector_for_input = StrategiesVector(node=nodes[0], in_nodes=[nodes[1], 2], strategies=strategies_for_input)
|
||||
setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
|
||||
|
||||
strategies_vector = StrategiesVector(node=nodes[2], in_nodes=[nodes[1], ])
|
||||
conv_handler = ConvHandler(input_node=nodes[1], input_index=0, weight=dict(gm.named_modules())[nodes[2].name].weight, output_node=nodes[2],
|
||||
device_mesh=device_mesh, strategies_vector=strategies_vector, shape_consistency_manager=shape_consistency_manager)
|
||||
conv_handler.register_strategy_into_strategies_vector()
|
||||
for strategy in conv_handler.strategies_vector.strategies:
|
||||
print(f'{strategy.name}: compute_cost is {strategy.compute_cost}, communication_cost is {strategy.communication_cost}, memory_cost is {strategy.memory_cost}, resharding_costs is {strategy.resharding_costs}')
|
||||
|
||||
Output:
|
||||
S0S1 = S0R x RS1: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {0: [0, 32769.001, 131074.2, 0, 32769.1, 131074.2, 98307.201]}
|
||||
S1S0 = S1R x RS0: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {0: [0, 131074.2, 32769.001, 131074.2, 98307.201, 0, 32769.1]}
|
||||
S0R = S0S1 x S1R: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {0: [0, 65538.002, 0, 0, 0, 65538.002, 196614.402]}
|
||||
S1R = S1S0 x S0R: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {0: [0, 0, 65538.002, 65538.002, 196614.402, 0, 0]}
|
||||
RS1 = RS0 x S0S1: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {0: [0, 0, 131074.2, 32769.001, 98307.201, 131074.2, 32769.1]}
|
||||
RS0 = RS1 x S1S0: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {0: [0, 131074.2, 0, 131074.2, 32769.1, 32769.001, 98307.201]}
|
||||
RS0 = RR x RS0: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {0: [0, 65537.1, 65537.1, 65537.1, 131075.30000000002, 65537.1, 131075.30000000002]}
|
||||
RS1 = RR x RS1: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {0: [0, 65537.1, 65537.1, 65537.1, 131075.30000000002, 65537.1, 131075.30000000002]}
|
||||
RR = RR x RR: compute_cost is 35426304, communication_cost is 0, memory_cost is 1968128, resharding_costs is {0: [0, 65537.1, 65537.1, 65537.1, 131075.30000000002, 65537.1, 131075.30000000002]}
|
||||
'''
|
||||
# SS = SR x RS
|
||||
self.split_input_batch_weight_out_channel(0, 1)
|
||||
self.split_input_batch_weight_out_channel(1, 0)
|
||||
|
||||
# SR = SS x SR
|
||||
self.split_input_both_dim_weight_in_channel(0, 1)
|
||||
self.split_input_both_dim_weight_in_channel(1, 0)
|
||||
|
||||
# RS = RS x SS
|
||||
self.split_input_in_channel_weight_both_channel(0, 1)
|
||||
self.split_input_in_channel_weight_both_channel(1, 0)
|
||||
|
||||
# RS = RR x RS
|
||||
self.split_weight_out_channel(0)
|
||||
self.split_weight_out_channel(1)
|
||||
|
||||
# RR= RR x RR
|
||||
self.non_split()
|
|
@ -0,0 +1,54 @@
|
|||
class ShardingStrategy:
|
||||
'''
|
||||
ShardingStrategy is a structure containing sharding strategies of inputs and output of this node
|
||||
and costs information using in solver.
|
||||
|
||||
Argument:
|
||||
name(str): express the sharding strategies in string, such as 'S0S1 = S0R x RS1'.
|
||||
output_sharding_spec(ShardingSpec): ShardingSpec of the output node.
|
||||
compute_cost(float): Computation cost to complete this strategy.(default to 0)
|
||||
communication_cost(float): Communication cost to complete this strategy.(default to 0)
|
||||
memory_cost(float): Memory cost of the output node using this strategy.(default to 0)
|
||||
resharding_costs(Dict[int, List[float]]): resharding_cost[i][j] means the cost of i-th argument in the output node argument list
|
||||
with j-th strategy in its strategies_vector transforms to sharding spec wanted in this
|
||||
strategy.(default to None)
|
||||
input_shardings(List(ShardingSpec)): The ShardingSpecs of the input nodes.
|
||||
'''
|
||||
|
||||
def __init__(self,
|
||||
name,
|
||||
output_sharding_spec,
|
||||
compute_cost=0,
|
||||
communication_cost=0,
|
||||
memory_cost=0,
|
||||
resharding_costs=None,
|
||||
input_shardings=None):
|
||||
self.name = name
|
||||
self.output_sharding_spec = output_sharding_spec
|
||||
self.compute_cost = compute_cost
|
||||
self.communication_cost = communication_cost
|
||||
self.memory_cost = memory_cost
|
||||
self.resharding_costs = resharding_costs
|
||||
self.input_shardings = input_shardings
|
||||
|
||||
|
||||
class StrategiesVector:
|
||||
'''
|
||||
Each node in fx graph will have a corresponding StrategiesVector, to store all the possible
|
||||
strategies of the node.
|
||||
|
||||
Argument:
|
||||
node(Node): node to build corresponding strategies_vector.
|
||||
in_nodes(List[Node]): input nodes in the argument list of the node.
|
||||
following_nodes(List[Node]): the nodes take the target node as their argument.
|
||||
strategies(List[ShardingStrategy]): enumerate all the possible sharding strategies of the node.
|
||||
'''
|
||||
|
||||
def __init__(self, node, in_nodes, following_nodes=None, strategies=[]):
|
||||
self.node = node
|
||||
self.in_nodes = in_nodes
|
||||
self.following_nodes = following_nodes
|
||||
self.strategies = strategies
|
||||
|
||||
def check_merge(self):
|
||||
pass
|
|
@ -199,7 +199,7 @@ class ShardingSpec:
|
|||
if not dim_spec.is_replica:
|
||||
if index not in new_dim_partition_dict:
|
||||
new_dim_partition_dict[index] = []
|
||||
new_dim_partition_dict[index].append(dim_spec.shard_list)
|
||||
new_dim_partition_dict[index].extend(dim_spec.shard_list)
|
||||
self.dim_partition_dict = new_dim_partition_dict
|
||||
|
||||
def sharding_sequence_difference(self, other):
|
||||
|
|
|
@ -0,0 +1,115 @@
|
|||
import torch
|
||||
from torch.fx import GraphModule
|
||||
import torch.nn as nn
|
||||
import pytest
|
||||
|
||||
from colossalai.fx.proxy import ColoProxy
|
||||
from colossalai.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
|
||||
from colossalai.auto_parallel.solver.conv_handler import ConvHandler
|
||||
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
|
||||
|
||||
class ConvModel(nn.Module):
|
||||
|
||||
def __init__(self, c_in, c_out):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(c_in, c_out, kernel_size=3)
|
||||
|
||||
def forward(self, x):
|
||||
x = x * 2
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
def test_conv_handler():
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
# [[0, 1]
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
entire_shape = torch.Size((4, 16, 64, 64))
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
|
||||
tracer = ColoTracer()
|
||||
model = ConvModel(16, 32)
|
||||
input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
|
||||
# graph():
|
||||
# %x : torch.Tensor [#users=1] = placeholder[target=x]
|
||||
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
|
||||
# %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {})
|
||||
# return conv
|
||||
graph = tracer.trace(root=model, meta_args=input_sample)
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
gm.recompile()
|
||||
# [x, mul, conv, output]
|
||||
nodes = [node for node in gm.graph.nodes]
|
||||
|
||||
strategies_for_input = []
|
||||
sharding_option = (None, 0, 1)
|
||||
for first_sharding_index in sharding_option:
|
||||
for second_sharding_index in sharding_option:
|
||||
if first_sharding_index is not None and second_sharding_index == first_sharding_index:
|
||||
continue
|
||||
if first_sharding_index is None:
|
||||
first_dim_spec = _DimSpec([])
|
||||
else:
|
||||
first_dim_spec = _DimSpec([first_sharding_index])
|
||||
|
||||
if second_sharding_index is None:
|
||||
second_dim_spec = _DimSpec([])
|
||||
else:
|
||||
second_dim_spec = _DimSpec([second_sharding_index])
|
||||
|
||||
replica_dim_spec = _DimSpec([])
|
||||
sharding_sequence = [first_dim_spec, second_dim_spec, replica_dim_spec, replica_dim_spec]
|
||||
sharding_spec = ShardingSpec(device_mesh=device_mesh,
|
||||
entire_shape=entire_shape,
|
||||
sharding_sequence=sharding_sequence)
|
||||
strategies_for_input.append(sharding_spec)
|
||||
|
||||
# strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
|
||||
strategies_vector_for_input = StrategiesVector(node=nodes[0],
|
||||
in_nodes=[nodes[1], 2],
|
||||
strategies=strategies_for_input)
|
||||
setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
|
||||
|
||||
strategies_vector = StrategiesVector(node=nodes[2], in_nodes=[
|
||||
nodes[1],
|
||||
])
|
||||
conv_handler = ConvHandler(input_node=nodes[1],
|
||||
input_index=0,
|
||||
weight=dict(gm.named_modules())[nodes[2].name].weight,
|
||||
output_node=nodes[2],
|
||||
device_mesh=device_mesh,
|
||||
strategies_vector=strategies_vector,
|
||||
shape_consistency_manager=shape_consistency_manager)
|
||||
conv_handler.register_strategy_into_strategies_vector()
|
||||
|
||||
# ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR']
|
||||
strategy_name_list = [strategy.name for strategy in conv_handler.strategies_vector.strategies]
|
||||
|
||||
# SS = SR x RS
|
||||
assert 'S0S1 = S0R x RS1' in strategy_name_list
|
||||
assert 'S1S0 = S1R x RS0' in strategy_name_list
|
||||
|
||||
# SR = SS x SR
|
||||
assert 'S0R = S0S1 x S1R' in strategy_name_list
|
||||
assert 'S1R = S1S0 x S0R' in strategy_name_list
|
||||
|
||||
# RS = RS x SS
|
||||
assert 'RS0 = RS1 x S1S0' in strategy_name_list
|
||||
assert 'RS1 = RS0 x S0S1' in strategy_name_list
|
||||
|
||||
# RS = RR x RS
|
||||
assert 'RS0 = RR x RS0' in strategy_name_list
|
||||
assert 'RS1 = RR x RS1' in strategy_name_list
|
||||
|
||||
# RR= RR x RR
|
||||
assert 'RR = RR x RR' in strategy_name_list
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_conv_handler()
|
Loading…
Reference in New Issue