[autoparallel] add backward cost info into strategies (#1524)

pull/1583/head
YuliangLiu0306 2022-09-07 11:19:00 +08:00 committed by GitHub
parent 1a3599410d
commit 0908d0fc61
3 changed files with 142 additions and 63 deletions

@@ -49,11 +49,59 @@ class ConvHandler(OperatorHandler):
# 3D: (H * W * D) * N * Cout * Cin * kernel
output_size = self.output_data.shape[2:]
output_size_product = reduce(operator.mul, output_size, 1)
input_size = self.input_data.shape[2:]
input_size_product = reduce(operator.mul, input_size, 1)
kernel_size = self.weight.shape[2:]
kernel_size_product = reduce(operator.mul, kernel_size, 1)
compute_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
forward_compute_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
backward_activation_cost = input_size_product * bs * channel_in * channel_out * kernel_size_product
backward_weight_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
compute_cost = forward_compute_cost + backward_activation_cost + backward_weight_cost
return compute_cost
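The compute cost now accounts for all three convolution passes: the forward pass, the gradient with respect to the activation, and the gradient with respect to the weight. A minimal standalone sketch of the same accounting, assuming a hypothetical 3x3 Conv2d with made-up sizes (none of these numbers come from the handler):

import operator
from functools import reduce

# Hypothetical Conv2d sizes for illustration only.
bs, channel_in, channel_out = 4, 64, 128
input_size = (32, 32)      # spatial size of the input feature map
output_size = (32, 32)     # spatial size of the output feature map (stride 1, padded)
kernel_size = (3, 3)

input_size_product = reduce(operator.mul, input_size, 1)
output_size_product = reduce(operator.mul, output_size, 1)
kernel_size_product = reduce(operator.mul, kernel_size, 1)

# Same decomposition as _generate_compute_cost above.
forward_compute_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
backward_activation_cost = input_size_product * bs * channel_in * channel_out * kernel_size_product
backward_weight_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
compute_cost = forward_compute_cost + backward_activation_cost + backward_weight_cost
# The two backward terms roughly double the forward-only estimate used before this change.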
def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation,
sharding_size_backward_weight):
'''
Compute the memory cost per device with this specific strategy.
Arguments:
sharding_size_forward(int): The forward activation will be divided
into sharding_size_forward partitions.
sharding_size_backward_activation(int): The backward activation will
be divided into sharding_size_backward_activation partitions.
sharding_size_backward_weight(int): The backward weight will be divided
into sharding_size_backward_weight partitions.
Return:
memory_cost(Tuple[float]): Memory cost per device with this
specific strategy; the first element of this tuple is the forward
memory cost, and the second element is the backward memory cost.
memory_cost_forward(float): Memory cost of forward activation per
device with this specific strategy.
memory_cost_backward_activation(float): Memory cost of backward activation
per device with this specific strategy.
'''
# compute the memory cost of this strategy
dtype = self.input_data.dtype
numel_output = self.output_data.numel()
numel_input = self.input_data.numel()
numel_weight = self.weight.numel()
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
# forward memory_cost
memory_cost_forward = numel_output * size_per_elem_bytes / sharding_size_forward
# backward memory_cost
memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_backward_weight
memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight
# memory_cost pair
memory_cost = (memory_cost_forward, memory_cost_backward)
return memory_cost, memory_cost_forward, memory_cost_backward_activation
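To see how the three sharding sizes are chosen, here is a minimal sketch of the same per-device accounting for the S0S1 = S0R x RS1 strategy below, assuming a hypothetical 2 x 4 device mesh, fp32 tensors, and made-up tensor shapes:

import torch

# Hypothetical shapes and mesh for illustration only.
mesh_shape = (2, 4)                              # mesh_dim_0 has 2 devices, mesh_dim_1 has 4
numel_output = 4 * 128 * 32 * 32                 # N x Cout x H x W
numel_input = 4 * 64 * 32 * 32                   # N x Cin x H x W
numel_weight = 128 * 64 * 3 * 3                  # Cout x Cin x kH x kW
size_per_elem_bytes = torch.tensor([], dtype=torch.float32).element_size()   # 4 bytes

# S0S1 = S0R x RS1: the output is sharded over both mesh dims, the input grad
# only over mesh_dim_0, and the weight grad only over mesh_dim_1.
sharding_size_forward = mesh_shape[0] * mesh_shape[1]        # 8
sharding_size_backward_activation = mesh_shape[0]            # 2
sharding_size_backward_weight = mesh_shape[1]                # 4

memory_cost_forward = numel_output * size_per_elem_bytes / sharding_size_forward
memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_backward_weight
memory_cost = (memory_cost_forward, memory_cost_backward_activation + memory_cost_backward_weight)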
def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
@@ -76,14 +124,19 @@ class ConvHandler(OperatorHandler):
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
# compute the memory cost of this strategy
dtype = self.input_data.dtype
numel = self.output_data.numel()
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
memory_cost = numel * size_per_elem_bytes / sharding_size
sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_1]
memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost(
sharding_size_forward, sharding_size_backward_activation, sharding_size_backward_weight)
# This strategy does not need to do all_reduce operation during forward phase
communication_cost_forward = 0
# compute the backward communication cost of this strategy
communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_1)
# total communication cost
communication_cost = communication_cost_forward + communication_cost_backward
# This strategy do not need to do all_reduce operation
communication_cost = 0
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
compute_cost=compute_cost,
@@ -115,13 +168,13 @@ class ConvHandler(OperatorHandler):
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
# compute the memory cost of this strategy
dtype = self.input_data.dtype
numel = self.output_data.numel()
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
sharding_size = self.device_mesh.shape[mesh_dim_0]
memory_cost = numel * size_per_elem_bytes / sharding_size
sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
sharding_size_backward_weight = 1
memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
sharding_size_backward_weight)
# This strategy do not need to do all_reduce operation
# This strategy does not need to do all_reduce operation in either forward or backward phase.
communication_cost = 0
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
@@ -154,14 +207,18 @@ class ConvHandler(OperatorHandler):
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
# compute the memory cost of this strategy
dtype = self.input_data.dtype
numel = self.output_data.numel()
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
sharding_size = self.device_mesh.shape[mesh_dim_0]
memory_cost = numel * size_per_elem_bytes / sharding_size
sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_1]
memory_cost, memory_cost_forward, _ = self._generate_memory_cost(sharding_size_forward,
sharding_size_backward_activation,
sharding_size_backward_weight)
# compute the communication cost of this strategy
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1)
# compute the communication cost of this strategy during forward phase
communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward, mesh_dim_1)
# This strategy does not need to do all_reduce operation during backward phase
communication_cost_backward = 0
communication_cost = communication_cost_forward + communication_cost_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
compute_cost=compute_cost,
@@ -193,14 +250,17 @@ class ConvHandler(OperatorHandler):
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
# compute the memory cost of this strategy
dtype = self.input_data.dtype
numel = self.output_data.numel()
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
sharding_size = self.device_mesh.shape[mesh_dim_0]
memory_cost = numel * size_per_elem_bytes / sharding_size
sharding_size_forward = self.device_mesh.shape[mesh_dim_1]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
memory_cost, memory_cost_forward, memory_cost_backward_activation = self._generate_memory_cost(
sharding_size_forward, sharding_size_backward_activation, sharding_size_backward_weight)
# compute the communication cost of this strategy
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_0)
# compute the communication cost of this strategy during forward phase
communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward, mesh_dim_0)
# compute the communication cost of this strategy during backward phase
communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_1)
communication_cost = communication_cost_forward + communication_cost_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
compute_cost=compute_cost,
@@ -232,13 +292,18 @@ class ConvHandler(OperatorHandler):
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
# compute the memory cost of this strategy
dtype = self.input_data.dtype
numel = self.output_data.numel()
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
memory_cost = numel * size_per_elem_bytes
sharding_size_forward = 1
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_0]
memory_cost, memory_cost_forward, _ = self._generate_memory_cost(sharding_size_forward,
sharding_size_backward_activation,
sharding_size_backward_weight)
# compute the communication cost of this strategy
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_0)
# compute the communication cost of this strategy during forward phase
communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward, mesh_dim_0)
# This strategy does NOT need all_reduce during backward phase
communication_cost_backward = 0
communication_cost = communication_cost_forward + communication_cost_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
compute_cost=compute_cost,
@@ -270,15 +335,17 @@ class ConvHandler(OperatorHandler):
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
# compute the memory cost of this strategy
dtype = self.input_data.dtype
numel = self.output_data.numel()
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
sharding_size = self.device_mesh.shape[mesh_dim_0]
memory_cost = numel * size_per_elem_bytes / sharding_size
# This strategy do not need to do all_reduce operation
communication_cost = 0
sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
sharding_size_backward_activation = 1
sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_0]
memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost(
sharding_size_forward, sharding_size_backward_activation, sharding_size_backward_weight)
# This strategy does not need to do all_reduce during forward phase
communication_cost_forward = 0
# compute the communication cost of this strategy during backward phase
communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_0)
communication_cost = communication_cost_forward + communication_cost_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
compute_cost=compute_cost,
@@ -310,12 +377,13 @@ class ConvHandler(OperatorHandler):
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
# compute the memory cost of this strategy
dtype = self.input_data.dtype
numel = self.output_data.numel()
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
memory_cost = numel * size_per_elem_bytes
sharding_size_forward = 1
sharding_size_backward_activation = 1
sharding_size_backward_weight = 1
memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
sharding_size_backward_weight)
# This strategy do not need to do all_reduce operation
# This strategy does not need to do all_reduce in either forward or backward phase
communication_cost = 0
sharding_strategies = ShardingStrategy(name,
@@ -349,13 +417,14 @@ class ConvHandler(OperatorHandler):
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
# compute the memory cost of this strategy
dtype = self.input_data.dtype
numel = self.output_data.numel()
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
memory_cost = numel * size_per_elem_bytes / sharding_size
sharding_size_forward = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1]
sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[
mesh_dim_1]
sharding_size_backward_weight = 1
memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
sharding_size_backward_weight)
# This strategy do not need to do all_reduce operation
# This strategy does not need to do all_reduce in either forward or backward phase
communication_cost = 0
sharding_strategies = ShardingStrategy(name,
@@ -390,13 +459,19 @@ class ConvHandler(OperatorHandler):
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
# compute the memory cost of this strategy
dtype = self.input_data.dtype
numel = self.output_data.numel()
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
memory_cost = numel * size_per_elem_bytes
sharding_size_forward = 1
sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[
mesh_dim_1]
sharding_size_backward_weight = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1]
memory_cost, memory_cost_forward, _ = self._generate_memory_cost(sharding_size_forward,
sharding_size_backward_activation,
sharding_size_backward_weight)
# compute communication cost
communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost, 0)
# compute communication cost during forward phase
communication_cost_forward = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost_forward, 0)
# This strategy does NOT need to do all_reduce during backward phase
communication_cost_backward = 0
communication_cost = communication_cost_forward + communication_cost_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,

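Across the strategies above the pattern is the same: each phase that leaves only a partial sum on every device contributes one all_reduce over the contracted mesh dimension, and the strategy's communication cost is the sum of the forward and backward contributions. A rough sketch with a hypothetical alpha-beta all_reduce cost model (DeviceMesh.all_reduce_cost plays this role in the handler, with its own constants):

# Hypothetical ring all_reduce cost model for illustration only.
def all_reduce_cost(message_bytes: float, num_devices: int,
                    latency: float = 1e-5, bandwidth: float = 1e10) -> float:
    if num_devices <= 1:
        return 0.0
    # A ring all_reduce moves about 2 * (n - 1) / n of the message over each link.
    return latency + 2 * (num_devices - 1) / num_devices * message_bytes / bandwidth

# Example in the style of the strategies that all_reduce the forward output over one
# mesh dim and the input grad over the other (made-up per-device sizes).
memory_cost_forward = 2.0e6                 # bytes of the partial output per device
memory_cost_backward_activation = 1.5e6     # bytes of the partial input grad per device
communication_cost_forward = all_reduce_cost(memory_cost_forward, 4)
communication_cost_backward = all_reduce_cost(memory_cost_backward_activation, 2)
communication_cost = communication_cost_forward + communication_cost_backward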

@@ -85,12 +85,17 @@ class OperatorHandler(ABC):
'''
# The resharding_cost of weight is counted due to sharing weight cases.
resharding_costs = {}
for input_node, input_spec in zip(self.predecessor_node, sharding_spec_for_input):
for input_node, target_spec in zip(self.predecessor_node, sharding_spec_for_input):
resharding_costs[input_node] = []
for strategy in input_node.strategies_vector:
input_sharding_spec = strategy.output_sharding_spec
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
_, _, resharding_cost = self.shape_consistency_manager.shape_consistency(
input_sharding_spec, input_spec)
# compute the resharding cost during forward phase
_, _, resharding_cost_forward = self.shape_consistency_manager.shape_consistency(
input_sharding_spec, target_spec)
# In the backward phase, the grad laid out with target_spec should be converted back into input_sharding_spec
_, _, resharding_cost_backward = self.shape_consistency_manager.shape_consistency(
target_spec, input_sharding_spec)
resharding_cost = resharding_cost_forward + resharding_cost_backward
resharding_costs[input_node].append(resharding_cost)
return resharding_costs
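The resharding cost is now counted in both directions: the forward pass converts the predecessor's output spec into the spec this operator consumes, and the backward pass converts the gradient, laid out with the target spec, back into the predecessor's spec. A minimal sketch of that symmetric accounting, with a hypothetical cost function standing in for the shape_consistency call (whose third return value is the cost used above):

# Hypothetical stand-in for the shape-consistency cost query, for illustration only.
def resharding_cost_between(src_spec: str, dst_spec: str) -> float:
    # Toy, asymmetric cost table; real costs come from the shape consistency manager.
    toy_costs = {("S0R", "RS1"): 10.0, ("RS1", "S0R"): 12.0}
    return toy_costs.get((src_spec, dst_spec), 0.0)

def total_resharding_cost(input_sharding_spec: str, target_spec: str) -> float:
    # Forward: reshard the predecessor's output into the spec this op consumes.
    resharding_cost_forward = resharding_cost_between(input_sharding_spec, target_spec)
    # Backward: reshard the gradient, produced with target_spec, back into the predecessor's spec.
    resharding_cost_backward = resharding_cost_between(target_spec, input_sharding_spec)
    return resharding_cost_forward + resharding_cost_backward

assert total_resharding_cost("S0R", "RS1") == 22.0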

@@ -82,7 +82,6 @@ def test_conv_handler():
strategies_vector=strategies_vector,
shape_consistency_manager=shape_consistency_manager)
conv_handler.register_strategy()
# ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0R x RR', 'S1R = S1R x RR', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RR = RS0 x S0R', 'RR = RS1 x S1R', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR', 'S01R = S01R x RR', 'RR = RS01 x S01R']
strategy_name_list = [strategy.name for strategy in conv_handler.strategies_vector]