mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel]add backward cost info into strategies (#1524)
parent
1a3599410d
commit
0908d0fc61
|
@ -49,11 +49,59 @@ class ConvHandler(OperatorHandler):
|
|||
# 3D: (H * W * D) * N * Cout * Cin * kernel
|
||||
output_size = self.output_data.shape[2:]
|
||||
output_size_product = reduce(operator.mul, output_size, 1)
|
||||
input_size = self.input_data.shape[2:]
|
||||
input_size_product = reduce(operator.mul, input_size, 1)
|
||||
kernel_size = self.weight.shape[2:]
|
||||
kernel_size_product = reduce(operator.mul, kernel_size, 1)
|
||||
compute_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
|
||||
forward_compute_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
|
||||
backward_activation_cost = input_size_product * bs * channel_in * channel_out * kernel_size_product
|
||||
backward_weight_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
|
||||
compute_cost = forward_compute_cost + backward_activation_cost + backward_weight_cost
|
||||
return compute_cost
|
||||
|
||||
def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation,
|
||||
sharding_size_backward_weight):
|
||||
'''
|
||||
Compute the memory cost per device with this specific strategy.
|
||||
|
||||
Argument:
|
||||
sharding_size_forward(int): The forward activation will be divided
|
||||
into sharding_size_forward number partions.
|
||||
sharding_size_backward_activation(int): The backward activation will
|
||||
be divided into sharding_size_backward_activation number partions.
|
||||
sharding_size_backward_weight(int): The backward weight will be divided
|
||||
into sharding_size_backward_weight number partions.
|
||||
|
||||
Return:
|
||||
memory_cost(Tuple[float]): Memory cost per device with this
|
||||
specific strategy, the first element of this tuple is forward
|
||||
memory cost, and the second element of this tuple is backward
|
||||
memory cost.
|
||||
memory_cost_forward(float): Memory cost of forward activation per
|
||||
device with this specific strategy.
|
||||
memory_cost_backward_activation(float): Memory cost of backward activation
|
||||
per device with this specific strategy.
|
||||
'''
|
||||
# compute the memory cost of this strategy
|
||||
dtype = self.input_data.dtype
|
||||
numel_output = self.output_data.numel()
|
||||
numel_input = self.input_data.numel()
|
||||
numel_weight = self.weight.numel()
|
||||
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||
|
||||
# forward memory_cost
|
||||
memory_cost_forward = numel_output * size_per_elem_bytes / sharding_size_forward
|
||||
|
||||
# backward memory_cost
|
||||
memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
|
||||
memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_backward_weight
|
||||
memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight
|
||||
|
||||
# memory_cost pair
|
||||
memory_cost = (memory_cost_forward, memory_cost_backward)
|
||||
|
||||
return memory_cost, memory_cost_forward, memory_cost_backward_activation
|
||||
|
||||
def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
|
||||
|
||||
|
@ -76,14 +124,19 @@ class ConvHandler(OperatorHandler):
|
|||
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
dtype = self.input_data.dtype
|
||||
numel = self.output_data.numel()
|
||||
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||
sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
memory_cost = numel * size_per_elem_bytes / sharding_size
|
||||
sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
|
||||
sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_1]
|
||||
memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost(
|
||||
sharding_size_forward, sharding_size_backward_activation, sharding_size_backward_weight)
|
||||
|
||||
# This strategy do not need to do all_reduce operation during forward
|
||||
communication_cost_forward = 0
|
||||
# compute the backward communication cost of this strategy
|
||||
communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_1)
|
||||
# total communication cost
|
||||
communication_cost = communication_cost_forward + communication_cost_backward
|
||||
|
||||
# This strategy do not need to do all_reduce operation
|
||||
communication_cost = 0
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_ouput,
|
||||
compute_cost=compute_cost,
|
||||
|
@ -115,13 +168,13 @@ class ConvHandler(OperatorHandler):
|
|||
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
dtype = self.input_data.dtype
|
||||
numel = self.output_data.numel()
|
||||
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||
sharding_size = self.device_mesh.shape[mesh_dim_0]
|
||||
memory_cost = numel * size_per_elem_bytes / sharding_size
|
||||
sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
|
||||
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
|
||||
sharding_size_backward_weight = 1
|
||||
memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
|
||||
sharding_size_backward_weight)
|
||||
|
||||
# This strategy do not need to do all_reduce operation
|
||||
# This strategy do not need to do all_reduce operation in both forward and backward phase.
|
||||
communication_cost = 0
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_ouput,
|
||||
|
@ -154,14 +207,18 @@ class ConvHandler(OperatorHandler):
|
|||
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
dtype = self.input_data.dtype
|
||||
numel = self.output_data.numel()
|
||||
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||
sharding_size = self.device_mesh.shape[mesh_dim_0]
|
||||
memory_cost = numel * size_per_elem_bytes / sharding_size
|
||||
sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
|
||||
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_1]
|
||||
memory_cost, memory_cost_forward, _ = self._generate_memory_cost(sharding_size_forward,
|
||||
sharding_size_backward_activation,
|
||||
sharding_size_backward_weight)
|
||||
|
||||
# compute the communication cost of this strategy
|
||||
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1)
|
||||
# compute the communication cost of this strategy during forward phase
|
||||
communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward, mesh_dim_1)
|
||||
# This strategy do not need to do all_reduce operation during backward phase
|
||||
communication_cost_backward = 0
|
||||
communication_cost = communication_cost_forward + communication_cost_backward
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_ouput,
|
||||
compute_cost=compute_cost,
|
||||
|
@ -193,14 +250,17 @@ class ConvHandler(OperatorHandler):
|
|||
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
dtype = self.input_data.dtype
|
||||
numel = self.output_data.numel()
|
||||
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||
sharding_size = self.device_mesh.shape[mesh_dim_0]
|
||||
memory_cost = numel * size_per_elem_bytes / sharding_size
|
||||
sharding_size_forward = self.device_mesh.shape[mesh_dim_1]
|
||||
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
|
||||
sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
memory_cost, memory_cost_forward, memory_cost_backward_activation = self._generate_memory_cost(
|
||||
sharding_size_forward, sharding_size_backward_activation, sharding_size_backward_weight)
|
||||
|
||||
# compute the communication cost of this strategy
|
||||
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_0)
|
||||
# compute the communication cost of this strategy during forward phase
|
||||
communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward, mesh_dim_0)
|
||||
# compute the communication cost of this strategy during backward phase
|
||||
communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_1)
|
||||
communication_cost = communication_cost_forward + communication_cost_backward
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_ouput,
|
||||
compute_cost=compute_cost,
|
||||
|
@ -232,13 +292,18 @@ class ConvHandler(OperatorHandler):
|
|||
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
dtype = self.input_data.dtype
|
||||
numel = self.output_data.numel()
|
||||
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||
memory_cost = numel * size_per_elem_bytes
|
||||
sharding_size_forward = 1
|
||||
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
|
||||
sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_0]
|
||||
memory_cost, memory_cost_forward, _ = self._generate_memory_cost(sharding_size_forward,
|
||||
sharding_size_backward_activation,
|
||||
sharding_size_backward_weight)
|
||||
|
||||
# compute the communication cost of this strategy
|
||||
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_0)
|
||||
# compute the communication cost of this strategy during forward phase
|
||||
communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward, mesh_dim_0)
|
||||
# This strategy do NOT need all_reduce during forward phase
|
||||
communication_cost_backward = 0
|
||||
communication_cost = communication_cost_forward + communication_cost_backward
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_ouput,
|
||||
compute_cost=compute_cost,
|
||||
|
@ -270,15 +335,17 @@ class ConvHandler(OperatorHandler):
|
|||
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
dtype = self.input_data.dtype
|
||||
numel = self.output_data.numel()
|
||||
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||
sharding_size = self.device_mesh.shape[mesh_dim_0]
|
||||
memory_cost = numel * size_per_elem_bytes / sharding_size
|
||||
|
||||
# This strategy do not need to do all_reduce operation
|
||||
communication_cost = 0
|
||||
sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
|
||||
sharding_size_backward_activation = 1
|
||||
sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_0]
|
||||
memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost(
|
||||
sharding_size_forward, sharding_size_backward_activation, sharding_size_backward_weight)
|
||||
|
||||
# This strategy do not need to do all_reduce during forward phase
|
||||
communication_cost_forward = 0
|
||||
# compute the communication cost of this strategy during backward phase
|
||||
communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_0)
|
||||
communication_cost = communication_cost_forward + communication_cost_backward
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_ouput,
|
||||
compute_cost=compute_cost,
|
||||
|
@ -310,12 +377,13 @@ class ConvHandler(OperatorHandler):
|
|||
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
dtype = self.input_data.dtype
|
||||
numel = self.output_data.numel()
|
||||
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||
memory_cost = numel * size_per_elem_bytes
|
||||
sharding_size_forward = 1
|
||||
sharding_size_backward_activation = 1
|
||||
sharding_size_backward_weight = 1
|
||||
memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
|
||||
sharding_size_backward_weight)
|
||||
|
||||
# This strategy do not need to do all_reduce operation
|
||||
# This strategy do not need to do all_reduce in both forward and backward phase
|
||||
communication_cost = 0
|
||||
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
|
@ -349,13 +417,14 @@ class ConvHandler(OperatorHandler):
|
|||
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
dtype = self.input_data.dtype
|
||||
numel = self.output_data.numel()
|
||||
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||
sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
memory_cost = numel * size_per_elem_bytes / sharding_size
|
||||
sharding_size_forward = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1]
|
||||
sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[
|
||||
mesh_dim_1]
|
||||
sharding_size_backward_weight = 1
|
||||
memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
|
||||
sharding_size_backward_weight)
|
||||
|
||||
# This strategy do not need to do all_reduce operation
|
||||
# This strategy do not need to do all_reduce in both forward and backward phase
|
||||
communication_cost = 0
|
||||
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
|
@ -390,13 +459,19 @@ class ConvHandler(OperatorHandler):
|
|||
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
dtype = self.input_data.dtype
|
||||
numel = self.output_data.numel()
|
||||
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||
memory_cost = numel * size_per_elem_bytes
|
||||
sharding_size_forward = 1
|
||||
sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[
|
||||
mesh_dim_1]
|
||||
sharding_size_backward_weight = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1]
|
||||
memory_cost, memory_cost_forward, _ = self._generate_memory_cost(sharding_size_forward,
|
||||
sharding_size_backward_activation,
|
||||
sharding_size_backward_weight)
|
||||
|
||||
# compute communication cost
|
||||
communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost, 0)
|
||||
# compute communication cost during forward phase
|
||||
communication_cost_forward = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost_forward, 0)
|
||||
# This strategy do NOT need do all_reduce during backward phase
|
||||
communication_cost_backward = 0
|
||||
communication_cost = communication_cost_forward + communication_cost_backward
|
||||
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_ouput,
|
||||
|
|
|
@ -85,12 +85,17 @@ class OperatorHandler(ABC):
|
|||
'''
|
||||
# The resharding_cost of weight is counted due to sharing weight cases.
|
||||
resharding_costs = {}
|
||||
for input_node, input_spec in zip(self.predecessor_node, sharding_spec_for_input):
|
||||
for input_node, target_spec in zip(self.predecessor_node, sharding_spec_for_input):
|
||||
resharding_costs[input_node] = []
|
||||
for strategy in input_node.strategies_vector:
|
||||
input_sharding_spec = strategy.output_sharding_spec
|
||||
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
|
||||
_, _, resharding_cost = self.shape_consistency_manager.shape_consistency(
|
||||
input_sharding_spec, input_spec)
|
||||
# compute the resharding cost during forward phase
|
||||
_, _, resharding_cost_forward = self.shape_consistency_manager.shape_consistency(
|
||||
input_sharding_spec, target_spec)
|
||||
# In backward phase, we should convert grad with target_spec into input_sharding_spec
|
||||
_, _, resharding_cost_backward = self.shape_consistency_manager.shape_consistency(
|
||||
target_spec, input_sharding_spec)
|
||||
resharding_cost = resharding_cost_forward + resharding_cost_backward
|
||||
resharding_costs[input_node].append(resharding_cost)
|
||||
return resharding_costs
|
||||
|
|
|
@ -82,7 +82,6 @@ def test_conv_handler():
|
|||
strategies_vector=strategies_vector,
|
||||
shape_consistency_manager=shape_consistency_manager)
|
||||
conv_handler.register_strategy()
|
||||
|
||||
# ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0R x RR', 'S1R = S1R x RR', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RR = RS0 x S0R', 'RR = RS1 x S1R', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR', 'S01R = S01R x RR', 'RR = RS01 x S01R']
|
||||
strategy_name_list = [strategy.name for strategy in conv_handler.strategies_vector]
|
||||
|
||||
|
|
Loading…
Reference in New Issue