mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel]add backward cost info into strategies (#1524)
parent 1a3599410d
commit 0908d0fc61
@@ -49,11 +49,59 @@ class ConvHandler(OperatorHandler):
         # 3D: (H * W * D) * N * Cout * Cin * kernel
         output_size = self.output_data.shape[2:]
         output_size_product = reduce(operator.mul, output_size, 1)
+        input_size = self.input_data.shape[2:]
+        input_size_product = reduce(operator.mul, input_size, 1)
         kernel_size = self.weight.shape[2:]
         kernel_size_product = reduce(operator.mul, kernel_size, 1)
-        compute_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
+        forward_compute_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
+        backward_activation_cost = input_size_product * bs * channel_in * channel_out * kernel_size_product
+        backward_weight_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
+        compute_cost = forward_compute_cost + backward_activation_cost + backward_weight_cost
         return compute_cost
 
+    def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation,
+                              sharding_size_backward_weight):
+        '''
+        Compute the memory cost per device with this specific strategy.
+
+        Argument:
+            sharding_size_forward(int): The forward activation will be divided
+                into sharding_size_forward number partions.
+            sharding_size_backward_activation(int): The backward activation will
+                be divided into sharding_size_backward_activation number partions.
+            sharding_size_backward_weight(int): The backward weight will be divided
+                into sharding_size_backward_weight number partions.
+
+        Return:
+            memory_cost(Tuple[float]): Memory cost per device with this
+                specific strategy, the first element of this tuple is forward
+                memory cost, and the second element of this tuple is backward
+                memory cost.
+            memory_cost_forward(float): Memory cost of forward activation per
+                device with this specific strategy.
+            memory_cost_backward_activation(float): Memory cost of backward activation
+                per device with this specific strategy.
+        '''
+        # compute the memory cost of this strategy
+        dtype = self.input_data.dtype
+        numel_output = self.output_data.numel()
+        numel_input = self.input_data.numel()
+        numel_weight = self.weight.numel()
+        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
+
+        # forward memory_cost
+        memory_cost_forward = numel_output * size_per_elem_bytes / sharding_size_forward
+
+        # backward memory_cost
+        memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
+        memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_backward_weight
+        memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight
+
+        # memory_cost pair
+        memory_cost = (memory_cost_forward, memory_cost_backward)
+
+        return memory_cost, memory_cost_forward, memory_cost_backward_activation
+
     def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
         name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
 
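Note: the sketch below is not part of this commit; it is a minimal standalone illustration of the cost accounting added above, using hypothetical conv shapes and hypothetical sharding sizes. The compute cost now counts three matmul-like terms (forward output, gradient w.r.t. the input, gradient w.r.t. the weight), and the per-device memory cost is simply numel * element_size / sharding_size for each tensor involved.

import operator
from functools import reduce

import torch

# hypothetical 2D conv workload: N=32, Cin=64, Cout=128, 3x3 kernel, 56x56 maps
bs, channel_in, channel_out = 32, 64, 128
input_data = torch.empty(bs, channel_in, 56, 56)
weight = torch.empty(channel_out, channel_in, 3, 3)
output_data = torch.empty(bs, channel_out, 56, 56)

output_size_product = reduce(operator.mul, output_data.shape[2:], 1)
input_size_product = reduce(operator.mul, input_data.shape[2:], 1)
kernel_size_product = reduce(operator.mul, weight.shape[2:], 1)

# forward pass, gradient w.r.t. the input, gradient w.r.t. the weight
forward_compute_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
backward_activation_cost = input_size_product * bs * channel_in * channel_out * kernel_size_product
backward_weight_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
compute_cost = forward_compute_cost + backward_activation_cost + backward_weight_cost

# per-device memory cost with hypothetical sharding sizes 4 / 2 / 2 for the
# output, the input gradient and the weight gradient respectively
size_per_elem_bytes = torch.tensor([], dtype=input_data.dtype).element_size()
memory_cost_forward = output_data.numel() * size_per_elem_bytes / 4
memory_cost_backward = (input_data.numel() * size_per_elem_bytes / 2
                        + weight.numel() * size_per_elem_bytes / 2)
print(compute_cost, (memory_cost_forward, memory_cost_backward))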
@@ -76,14 +124,19 @@ class ConvHandler(OperatorHandler):
         compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
 
         # compute the memory cost of this strategy
-        dtype = self.input_data.dtype
-        numel = self.output_data.numel()
-        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-        sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
-        memory_cost = numel * size_per_elem_bytes / sharding_size
+        sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
+        sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
+        sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_1]
+        memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost(
+            sharding_size_forward, sharding_size_backward_activation, sharding_size_backward_weight)
 
-        # This strategy do not need to do all_reduce operation
-        communication_cost = 0
+        # This strategy do not need to do all_reduce operation during forward
+        communication_cost_forward = 0
+        # compute the backward communication cost of this strategy
+        communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_1)
+        # total communication cost
+        communication_cost = communication_cost_forward + communication_cost_backward
+
         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_ouput,
                                                compute_cost=compute_cost,
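Note: the following is a minimal sketch, not this commit's code, of how the forward and backward communication costs combine for the S0S1 = S0R x RS1 strategy above (and, with different mesh dimensions, for the remaining strategies in this file). It assumes a simple alpha-beta ring all-reduce model; the all_reduce_cost helper below is a hypothetical stand-in for DeviceMesh.all_reduce_cost, whose real implementation is not shown in this diff.

# hypothetical ring all-reduce cost model: 2 * (p - 1) steps, each moving
# size / p bytes per device plus a fixed per-step latency
def all_reduce_cost(size_in_bytes, num_devices, latency=1e-5, bandwidth=1e10):
    if num_devices <= 1:
        return 0.0
    return 2 * (num_devices - 1) * (latency + size_in_bytes / (num_devices * bandwidth))

# S0S1 = S0R x RS1: the forward pass needs no collective, while the backward
# pass all-reduces the partial input gradient along mesh_dim_1
memory_cost_backward_activation = 32 * 64 * 56 * 56 * 4 / 2    # hypothetical bytes per device
communication_cost_forward = 0
communication_cost_backward = all_reduce_cost(memory_cost_backward_activation, num_devices=4)
communication_cost = communication_cost_forward + communication_cost_backward
print(communication_cost)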
@@ -115,13 +168,13 @@ class ConvHandler(OperatorHandler):
         compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
 
         # compute the memory cost of this strategy
-        dtype = self.input_data.dtype
-        numel = self.output_data.numel()
-        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-        sharding_size = self.device_mesh.shape[mesh_dim_0]
-        memory_cost = numel * size_per_elem_bytes / sharding_size
+        sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
+        sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
+        sharding_size_backward_weight = 1
+        memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
+                                                       sharding_size_backward_weight)
 
-        # This strategy do not need to do all_reduce operation
+        # This strategy do not need to do all_reduce operation in both forward and backward phase.
         communication_cost = 0
         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_ouput,
@@ -154,14 +207,18 @@ class ConvHandler(OperatorHandler):
         compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
 
         # compute the memory cost of this strategy
-        dtype = self.input_data.dtype
-        numel = self.output_data.numel()
-        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-        sharding_size = self.device_mesh.shape[mesh_dim_0]
-        memory_cost = numel * size_per_elem_bytes / sharding_size
+        sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
+        sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
+        sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_1]
+        memory_cost, memory_cost_forward, _ = self._generate_memory_cost(sharding_size_forward,
+                                                                         sharding_size_backward_activation,
+                                                                         sharding_size_backward_weight)
 
-        # compute the communication cost of this strategy
-        communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1)
+        # compute the communication cost of this strategy during forward phase
+        communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward, mesh_dim_1)
+        # This strategy do not need to do all_reduce operation during backward phase
+        communication_cost_backward = 0
+        communication_cost = communication_cost_forward + communication_cost_backward
         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_ouput,
                                                compute_cost=compute_cost,
@@ -193,14 +250,17 @@ class ConvHandler(OperatorHandler):
         compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
 
         # compute the memory cost of this strategy
-        dtype = self.input_data.dtype
-        numel = self.output_data.numel()
-        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-        sharding_size = self.device_mesh.shape[mesh_dim_0]
-        memory_cost = numel * size_per_elem_bytes / sharding_size
+        sharding_size_forward = self.device_mesh.shape[mesh_dim_1]
+        sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
+        sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
+        memory_cost, memory_cost_forward, memory_cost_backward_activation = self._generate_memory_cost(
+            sharding_size_forward, sharding_size_backward_activation, sharding_size_backward_weight)
 
-        # compute the communication cost of this strategy
-        communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_0)
+        # compute the communication cost of this strategy during forward phase
+        communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward, mesh_dim_0)
+        # compute the communication cost of this strategy during backward phase
+        communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_1)
+        communication_cost = communication_cost_forward + communication_cost_backward
         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_ouput,
                                                compute_cost=compute_cost,
@@ -232,13 +292,18 @@ class ConvHandler(OperatorHandler):
         compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
 
         # compute the memory cost of this strategy
-        dtype = self.input_data.dtype
-        numel = self.output_data.numel()
-        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-        memory_cost = numel * size_per_elem_bytes
+        sharding_size_forward = 1
+        sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
+        sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_0]
+        memory_cost, memory_cost_forward, _ = self._generate_memory_cost(sharding_size_forward,
+                                                                         sharding_size_backward_activation,
+                                                                         sharding_size_backward_weight)
 
-        # compute the communication cost of this strategy
-        communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_0)
+        # compute the communication cost of this strategy during forward phase
+        communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward, mesh_dim_0)
+        # This strategy do NOT need all_reduce during forward phase
+        communication_cost_backward = 0
+        communication_cost = communication_cost_forward + communication_cost_backward
         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_ouput,
                                                compute_cost=compute_cost,
@@ -270,15 +335,17 @@ class ConvHandler(OperatorHandler):
         compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
 
         # compute the memory cost of this strategy
-        dtype = self.input_data.dtype
-        numel = self.output_data.numel()
-        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-        sharding_size = self.device_mesh.shape[mesh_dim_0]
-        memory_cost = numel * size_per_elem_bytes / sharding_size
-
-        # This strategy do not need to do all_reduce operation
-        communication_cost = 0
+        sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
+        sharding_size_backward_activation = 1
+        sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_0]
+        memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost(
+            sharding_size_forward, sharding_size_backward_activation, sharding_size_backward_weight)
 
+        # This strategy do not need to do all_reduce during forward phase
+        communication_cost_forward = 0
+        # compute the communication cost of this strategy during backward phase
+        communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_0)
+        communication_cost = communication_cost_forward + communication_cost_backward
         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_ouput,
                                                compute_cost=compute_cost,
@@ -310,12 +377,13 @@ class ConvHandler(OperatorHandler):
         compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
 
        # compute the memory cost of this strategy
-        dtype = self.input_data.dtype
-        numel = self.output_data.numel()
-        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-        memory_cost = numel * size_per_elem_bytes
+        sharding_size_forward = 1
+        sharding_size_backward_activation = 1
+        sharding_size_backward_weight = 1
+        memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
+                                                       sharding_size_backward_weight)
 
-        # This strategy do not need to do all_reduce operation
+        # This strategy do not need to do all_reduce in both forward and backward phase
         communication_cost = 0
 
         sharding_strategies = ShardingStrategy(name,
@@ -349,13 +417,14 @@ class ConvHandler(OperatorHandler):
         compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
 
         # compute the memory cost of this strategy
-        dtype = self.input_data.dtype
-        numel = self.output_data.numel()
-        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-        sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
-        memory_cost = numel * size_per_elem_bytes / sharding_size
+        sharding_size_forward = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1]
+        sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[
+            mesh_dim_1]
+        sharding_size_backward_weight = 1
+        memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
+                                                       sharding_size_backward_weight)
 
-        # This strategy do not need to do all_reduce operation
+        # This strategy do not need to do all_reduce in both forward and backward phase
         communication_cost = 0
 
         sharding_strategies = ShardingStrategy(name,
@@ -390,13 +459,19 @@ class ConvHandler(OperatorHandler):
         compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
 
         # compute the memory cost of this strategy
-        dtype = self.input_data.dtype
-        numel = self.output_data.numel()
-        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-        memory_cost = numel * size_per_elem_bytes
+        sharding_size_forward = 1
+        sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[
+            mesh_dim_1]
+        sharding_size_backward_weight = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1]
+        memory_cost, memory_cost_forward, _ = self._generate_memory_cost(sharding_size_forward,
+                                                                         sharding_size_backward_activation,
+                                                                         sharding_size_backward_weight)
 
-        # compute communication cost
-        communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost, 0)
+        # compute communication cost during forward phase
+        communication_cost_forward = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost_forward, 0)
+        # This strategy do NOT need do all_reduce during backward phase
+        communication_cost_backward = 0
+        communication_cost = communication_cost_forward + communication_cost_backward
 
         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_ouput,
@@ -85,12 +85,17 @@ class OperatorHandler(ABC):
         '''
         # The resharding_cost of weight is counted due to sharing weight cases.
         resharding_costs = {}
-        for input_node, input_spec in zip(self.predecessor_node, sharding_spec_for_input):
+        for input_node, target_spec in zip(self.predecessor_node, sharding_spec_for_input):
             resharding_costs[input_node] = []
             for strategy in input_node.strategies_vector:
                 input_sharding_spec = strategy.output_sharding_spec
                 assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
-                _, _, resharding_cost = self.shape_consistency_manager.shape_consistency(
-                    input_sharding_spec, input_spec)
+                # compute the resharding cost during forward phase
+                _, _, resharding_cost_forward = self.shape_consistency_manager.shape_consistency(
+                    input_sharding_spec, target_spec)
+                # In backward phase, we should convert grad with target_spec into input_sharding_spec
+                _, _, resharding_cost_backward = self.shape_consistency_manager.shape_consistency(
+                    target_spec, input_sharding_spec)
+                resharding_cost = resharding_cost_forward + resharding_cost_backward
                 resharding_costs[input_node].append(resharding_cost)
         return resharding_costs
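Note: a minimal sketch of the resharding-cost symmetry introduced above, not part of this commit. The real shape_consistency call returns a 3-tuple whose last element is the cost; the shape_consistency_cost helper below is a hypothetical stand-in, and the spec strings are made up.

# hypothetical stand-in for ShapeConsistencyManager.shape_consistency: pretend
# every mesh dimension whose placement differs costs one unit-size collective
def shape_consistency_cost(src_spec, dst_spec):
    return float(sum(a != b for a, b in zip(src_spec, dst_spec)))

# forward: convert the producer's output spec into the spec this operator wants;
# backward: the gradient flows the other way, so the reverse conversion is paid too
input_sharding_spec, target_spec = 'S0R', 'RS1'    # hypothetical sharding specs
resharding_cost_forward = shape_consistency_cost(input_sharding_spec, target_spec)
resharding_cost_backward = shape_consistency_cost(target_spec, input_sharding_spec)
resharding_cost = resharding_cost_forward + resharding_cost_backward
print(resharding_cost)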
@@ -82,7 +82,6 @@ def test_conv_handler():
                                strategies_vector=strategies_vector,
                                shape_consistency_manager=shape_consistency_manager)
     conv_handler.register_strategy()
-
     # ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0R x RR', 'S1R = S1R x RR', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RR = RS0 x S0R', 'RR = RS1 x S1R', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR', 'S01R = S01R x RR', 'RR = RS01 x S01R']
     strategy_name_list = [strategy.name for strategy in conv_handler.strategies_vector]
 
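Note: once both forward and backward terms are folded into each strategy, a solver can rank strategies by total cost. The sketch below is hypothetical and not the ColossalAI solver; it assumes the ShardingStrategy objects in conv_handler.strategies_vector expose compute_cost and communication_cost attributes mirroring the constructor arguments assembled above.

def strategy_total_cost(strategy):
    # compute_cost and communication_cost now both include backward-phase terms
    return strategy.compute_cost + strategy.communication_cost

best_strategy = min(conv_handler.strategies_vector, key=strategy_total_cost)
print(best_strategy.name)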