# mirror of https://github.com/hpcaitech/ColossalAI
import copy
import operator
from functools import reduce
from typing import List

from colossalai.tensor.shape_consistency import CollectiveCommPattern

from .._utils import exception_handler
from ..sharding_strategy import MemoryCost, ShardingStrategy_V2, TrainCycleItem
from .strategy_generator import StrategyGenerator_V2

class ConvStrategyGenerator(StrategyGenerator_V2):
    """
    ConvStrategyGenerator is a generic class to generate sharding strategies for convolution operators.
    The operation data is defined as `output = input x other + bias`.
    """

    @property
    def has_bias(self):
        return 'bias' in self.op_data

    def validate(self) -> bool:
        '''
        As a sanity check, we make sure the input data has the correct number of dimensions.
        For Conv1d, the input should be 3-dimensional ([N, C, L]).
        For Conv2d, the input should be 4-dimensional ([N, C, H, W]).
        For Conv3d, the input should be 5-dimensional ([N, C, H, W, D]).
        '''
        input_op_data = self.op_data['input']
        assert input_op_data.dim() in (3, 4, 5), \
            'We expect the input fed into a conv op to have between 3 and 5 dimensions.'

    def update_compute_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
        '''
        Compute the computation cost per device with this specific strategy.

        Note: the compute cost is not yet divided by the device TFLOPS; for now it only reflects the computation size.
        '''
        # TODO: divide compute_cost by TFLOPS; currently it just shows the computation size.
        # 1D: (L) * N * Cout * Cin * kernel
        # 2D: (H * W) * N * Cout * Cin * kernel
        # 3D: (H * W * D) * N * Cout * Cin * kernel
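        # A worked example with hypothetical per-device numbers: a Conv2d with
        # sharded input [N=8, Cin=64, H=W=32], sharded weight [Cin=64, Cout=128,
        # kernel=3x3] and output spatial size 32x32 gives
        #   forward cost = (32 * 32) * 8 * 64 * 128 * 9 = 603,979,776
        # multiply-accumulate operations per device.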
        sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
        sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device()
        sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()

        if self.has_bias:
            # bias add is an element-wise operation, so its cost equals the product of the output shape.
            bias_compute_cost = reduce(operator.mul, sharded_output_shape)

        output_size = sharded_output_shape[2:]
        output_size_product = reduce(operator.mul, output_size, 1)
        input_size = sharded_input_shape[2:]
        input_size_product = reduce(operator.mul, input_size, 1)
        kernel_size = sharded_other_shape[2:]
        kernel_size_product = reduce(operator.mul, kernel_size, 1)
        batch_size = sharded_input_shape[0]
        channel_in = sharded_input_shape[1]
        channel_out = sharded_other_shape[1]

        forward_compute_cost = output_size_product * batch_size * channel_in * channel_out * kernel_size_product

        backward_activation_cost = input_size_product * batch_size * channel_in * channel_out * kernel_size_product
        backward_weight_cost = output_size_product * batch_size * channel_in * channel_out * kernel_size_product
        backward_compute_cost = backward_weight_cost + backward_activation_cost

        if self.has_bias:
            forward_compute_cost += bias_compute_cost
            backward_compute_cost += bias_compute_cost
        total_compute_cost = forward_compute_cost + backward_compute_cost

        compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
        return compute_cost

    def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
        forward_size_mapping = {
            'input': self._compute_size_in_bytes(strategy, "input"),
            'other': self._compute_size_in_bytes(strategy, "other"),
            'output': self._compute_size_in_bytes(strategy, "output")
        }

        if self.has_bias:
            bias_size = self._compute_size_in_bytes(strategy, "bias")
            forward_size_mapping['bias'] = bias_size

        backward_size_mapping = copy.deepcopy(forward_size_mapping)
        backward_size_mapping.pop("output")
        # compute the forward cost incurred
        # fwd_cost = input + other + bias + output
        fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
        fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
        fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
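
        # Hedged reading: is_param(k) separates operands held as parameters
        # (e.g. the weight, and the bias when it is a parameter) from
        # activations, so MemoryCost reports per-device parameter memory and
        # activation memory separately.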

        # compute the backward cost incurred
        # bwd_cost = input_grad + other_grad + bias_grad
        bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
        bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
        bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)

        # compute the total cost
        total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
                                    parameter=fwd_parameter_cost + bwd_parameter_cost)
        memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
        strategy.memory_cost = memory_cost
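
    # Naming convention used by the strategies below: 'S{i}' marks a tensor
    # dimension sharded along logical mesh axis i, 'R' marks a replicated
    # dimension, and 'S01' marks a dimension sharded jointly over both mesh
    # axes. For example, 'S0S1 = S0R x RS1' reads: the output is sharded on
    # batch (mesh axis 0) and out-channel (mesh axis 1), computed from an
    # input sharded on batch and a weight sharded on out-channel.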

    def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
        name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'

        dim_partition_dict_mapping = {
            "input": {
                0: [mesh_dim_0]
            },
            "other": {
                1: [mesh_dim_1]
            },
            "output": {
                0: [mesh_dim_0],
                1: [mesh_dim_1]
            },
        }
        if self.has_bias:
            dim_partition_dict_mapping["bias"] = {0: [mesh_dim_1]}

        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # set communication action
        input_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping["input"],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=mesh_dim_1)
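        # Hedged note based on the enum names: IDENTITY_FWD_ALLREDUCE_BWD is a
        # no-op in the forward pass and all-reduces the gradient in the backward
        # pass (to sync grads across a mesh axis the tensor is replicated over);
        # ALLREDUCE_FWD_IDENTITY_BWD is the reverse.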

        communication_action_mapping = {"input": input_comm_spec}

        if self.is_param("other"):
            other_comm_spec = self.get_communication_spec(
                sharding_spec_mapping["other"],
                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                logical_process_axis=mesh_dim_0)
            communication_action_mapping["other"] = other_comm_spec

        if self.has_bias and self.is_param("bias"):
            bias_comm_spec = self.get_communication_spec(
                sharding_spec_mapping["bias"],
                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                logical_process_axis=mesh_dim_0)
            communication_action_mapping["bias"] = bias_comm_spec

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def split_input_batch(self, mesh_dim_0):
        name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR'

        dim_partition_dict_mapping = {
            "input": {
                0: [mesh_dim_0]
            },
            "other": {},
            "output": {
                0: [mesh_dim_0],
            },
        }
        if self.has_bias:
            dim_partition_dict_mapping["bias"] = {}

        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        communication_action_mapping = {}
        if self.is_param("other"):
            other_comm_spec = self.get_communication_spec(
                sharding_spec_mapping["other"],
                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                logical_process_axis=mesh_dim_0)
            communication_action_mapping["other"] = other_comm_spec

        if self.has_bias and self.is_param("bias"):
            bias_comm_spec = self.get_communication_spec(
                sharding_spec_mapping["bias"],
                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                logical_process_axis=mesh_dim_0)
            communication_action_mapping["bias"] = bias_comm_spec

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
        name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'

        dim_partition_dict_mapping = {
            "input": {
                0: [mesh_dim_0],
                1: [mesh_dim_1],
            },
            "other": {
                0: [mesh_dim_1]
            },
            "output": {
                0: [mesh_dim_0],
            },
        }
        if self.has_bias:
            dim_partition_dict_mapping["bias"] = {}

        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # set communication action
        output_comm_spec = self.get_communication_spec(
            sharding_spec_mapping["output"],
            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
            logical_process_axis=mesh_dim_1)
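        # Each device holds only a partial sum over its in-channel shard, so
        # the output is all-reduced in the forward pass.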

        communication_action_mapping = {"output": output_comm_spec}

        if self.is_param("other"):
            other_comm_spec = self.get_communication_spec(
                sharding_spec_mapping["other"],
                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                logical_process_axis=mesh_dim_0)
            communication_action_mapping["other"] = other_comm_spec

        if self.has_bias and self.is_param("bias"):
            bias_comm_spec = self.get_communication_spec(
                sharding_spec_mapping["bias"],
                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                logical_process_axis=mesh_dim_0)
            communication_action_mapping["bias"] = bias_comm_spec

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1):
        name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'

        dim_partition_dict_mapping = {
            "input": {
                1: [mesh_dim_0],
            },
            "other": {
                0: [mesh_dim_0],
                1: [mesh_dim_1],
            },
            "output": {
                1: [mesh_dim_1],
            },
        }

        if self.has_bias:
            dim_partition_dict_mapping["bias"] = {
                0: [mesh_dim_1],
            }

        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # set communication action
        output_comm_spec = self.get_communication_spec(
            sharding_spec_mapping["output"],
            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
            logical_process_axis=mesh_dim_0)
        input_comm_spec = self.get_communication_spec(
            sharding_spec_mapping["input"],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=mesh_dim_0)

        communication_action_mapping = {"output": output_comm_spec, "input": input_comm_spec}

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def split_input_in_channel_weight_in_channel(self, mesh_dim_0):
        name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R'

        dim_partition_dict_mapping = {
            "input": {
                1: [mesh_dim_0],
            },
            "other": {
                0: [mesh_dim_0],
            },
            "output": {},
        }

        if self.has_bias:
            dim_partition_dict_mapping["bias"] = {}

        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # set communication action
        output_comm_spec = self.get_communication_spec(
            sharding_spec_mapping["output"],
            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
            logical_process_axis=mesh_dim_0)

        communication_action_mapping = {"output": output_comm_spec}

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def split_weight_out_channel(self, mesh_dim_0):
        name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'

        dim_partition_dict_mapping = {
            "input": {},
            "other": {
                1: [mesh_dim_0],
            },
            "output": {
                1: [mesh_dim_0],
            },
        }

        if self.has_bias:
            dim_partition_dict_mapping["bias"] = {
                0: [mesh_dim_0],
            }

        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # set communication action
        input_comm_spec = self.get_communication_spec(
            sharding_spec_mapping["input"],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=mesh_dim_0)

        communication_action_mapping = {"input": input_comm_spec}

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def non_split(self):
        name = 'RR = RR x RR'

        dim_partition_dict_mapping = {
            "input": {},
            "other": {},
            "output": {},
        }

        if self.has_bias:
            dim_partition_dict_mapping["bias"] = {}

        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping={})

    def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1):
        name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'

        dim_partition_dict_mapping = {
            "input": {
                0: [mesh_dim_0, mesh_dim_1],
            },
            "other": {},
            "output": {
                0: [mesh_dim_0, mesh_dim_1],
            },
        }

        if self.has_bias:
            dim_partition_dict_mapping["bias"] = {}

        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        communication_action_mapping = {}
        if self.is_param("other"):
            other_comm_spec = self.get_communication_spec(
                sharding_spec_mapping["other"],
                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                logical_process_axis=[mesh_dim_0, mesh_dim_1])
            communication_action_mapping["other"] = other_comm_spec

        if self.has_bias and self.is_param("bias"):
            bias_comm_spec = self.get_communication_spec(
                sharding_spec_mapping["bias"],
                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                logical_process_axis=[mesh_dim_0, mesh_dim_1])
            communication_action_mapping["bias"] = bias_comm_spec

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1):
        name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'

        dim_partition_dict_mapping = {
            "input": {
                1: [mesh_dim_0, mesh_dim_1],
            },
            "other": {
                0: [mesh_dim_0, mesh_dim_1],
            },
            "output": {},
        }

        if self.has_bias:
            dim_partition_dict_mapping["bias"] = {}

        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # set communication action
        output_comm_spec = self.get_communication_spec(
            sharding_spec_mapping["output"],
            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
            logical_process_axis=[mesh_dim_0, mesh_dim_1])

        communication_action_mapping = {"output": output_comm_spec}

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def split_1d_parallel_on_out_channel(self, mesh_dim_0, mesh_dim_1):
        name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'

        dim_partition_dict_mapping = {
            "input": {},
            "other": {
                1: [mesh_dim_0, mesh_dim_1],
            },
            "output": {
                1: [mesh_dim_0, mesh_dim_1],
            },
        }

        if self.has_bias:
            dim_partition_dict_mapping["bias"] = {
                0: [mesh_dim_0, mesh_dim_1],
            }

        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # set communication action
        input_comm_spec = self.get_communication_spec(
            sharding_spec_mapping["input"],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=[mesh_dim_0, mesh_dim_1])

        communication_action_mapping = {"input": input_comm_spec}

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def generate(self) -> List[ShardingStrategy_V2]:
        strategies = []
        # SS = SR x RS
        strategies.append(self.split_input_batch_weight_out_channel(0, 1))
        strategies.append(self.split_input_batch_weight_out_channel(1, 0))

        # SR = SR x RR
        strategies.append(self.split_input_batch(0))
        strategies.append(self.split_input_batch(1))

        # SR = SS x SR
        strategies.append(self.split_input_both_dim_weight_in_channel(0, 1))
        strategies.append(self.split_input_both_dim_weight_in_channel(1, 0))

        # RS = RS x SS
        strategies.append(self.split_input_in_channel_weight_both_channel(0, 1))
        strategies.append(self.split_input_in_channel_weight_both_channel(1, 0))

        # RR = RS x SR
        strategies.append(self.split_input_in_channel_weight_in_channel(0))
        strategies.append(self.split_input_in_channel_weight_in_channel(1))

        # RS = RR x RS
        strategies.append(self.split_weight_out_channel(0))
        strategies.append(self.split_weight_out_channel(1))

        # RR = RR x RR
        strategies.append(self.non_split())

        # S01R = S01R x RR
        strategies.append(self.split_1d_parallel_on_input_batch(0, 1))

        # RR = RS01 x S01R
        strategies.append(self.split_1d_parallel_on_in_channel(0, 1))

        # RS01 = RR x RS01
        strategies.append(self.split_1d_parallel_on_out_channel(0, 1))

        # update meta info on cost
        for strategy in strategies:
            self.update_communication_cost(strategy)
            self.update_compute_cost(strategy)
            self.update_memory_cost(strategy)

        return strategies
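
# A minimal usage sketch (hypothetical; the exact constructor arguments depend
# on the surrounding handler code that builds the operation data and mesh):
#   generator = ConvStrategyGenerator(operation_data_mapping, device_mesh)
#   strategies = generator.generate()
#   # each ShardingStrategy_V2 now carries compute and memory costs per device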