diff --git a/colossalai/auto_parallel/tensor_shard/constants.py b/colossalai/auto_parallel/tensor_shard/constants.py
index 9143ad9db..99c124934 100644
--- a/colossalai/auto_parallel/tensor_shard/constants.py
+++ b/colossalai/auto_parallel/tensor_shard/constants.py
@@ -26,7 +26,14 @@ ELEMENTWISE_METHOD_OP = [
     # TODO: contiguous maybe need some extra processes.
     torch.Tensor.contiguous
 ]
-RESHAPE_FUNC_OP = [torch.flatten, torch.reshape]
+RESHAPE_FUNC_OP = [
+    torch.flatten,
+    torch.reshape,
+    torch.transpose,
+    torch.split,
+    torch.permute,
+    operator.getitem,
+]
 RESHAPE_METHOD_OP = [
     torch.Tensor.view,
     torch.Tensor.unsqueeze,
diff --git a/colossalai/auto_parallel/tensor_shard/sharding_strategy.py b/colossalai/auto_parallel/tensor_shard/sharding_strategy.py
index bbf4215d9..b758e1e09 100644
--- a/colossalai/auto_parallel/tensor_shard/sharding_strategy.py
+++ b/colossalai/auto_parallel/tensor_shard/sharding_strategy.py
@@ -9,7 +9,14 @@ from torch.fx.node import Node
 from colossalai.tensor.shape_consistency import CommSpec
 from colossalai.tensor.sharding_spec import ShardingSpec
 
-from .constants import BCAST_FUNC_OP, ELEMENTWISE_FUNC_OP, ELEMENTWISE_MODULE_OP, RESHAPE_FUNC_OP
+from .constants import (
+    BCAST_FUNC_OP,
+    ELEMENTWISE_FUNC_OP,
+    ELEMENTWISE_METHOD_OP,
+    ELEMENTWISE_MODULE_OP,
+    RESHAPE_FUNC_OP,
+    RESHAPE_METHOD_OP,
+)
 
 __all__ = ['OperationDataType', 'OperationData', 'TrainCycleItem', 'MemoryCost', 'ShardingStrategy', 'StrategiesVector']
 
@@ -249,8 +256,15 @@ class StrategiesVector(list):
         # we could merge bcast op if the rhs is a scalar, because it will fall back to the element-wise case.
         if self.node.target in BCAST_FUNC_OP and len(self.predecessor_nodes) == 1:
             merge_label = True
-        # we could merge reshape op, because the output sharding spec of reshape op is always fully replicated.
+        # we could merge reshape ops, because their computation costs are negligible.
        if self.node.target in RESHAPE_FUNC_OP:
             merge_label = True
+        if self.node.op == 'call_method':
+            # we could merge reshape ops, because their computation costs are negligible.
+            method = getattr(self.node.args[0]._meta_data.__class__, self.node.target)
+            if method in RESHAPE_METHOD_OP:
+                merge_label = True
+            if method in ELEMENTWISE_METHOD_OP:
+                merge_label = True
 
         return merge_label
diff --git a/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py b/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py
index f1509af56..038e56547 100644
--- a/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py
+++ b/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py
@@ -63,14 +63,40 @@ class CostGraph:
                         edge_cost[(j, i)] = resharding_cost_item.total
                 self.edge_costs[node_pair] = edge_cost
             # add parents and children attribute to node
-            parent_nodes = [node for node in strategies_vector.predecessor_nodes]
-            children_nodes = [node for node in strategies_vector.successor_nodes]
+            # parent_nodes = [node for node in strategies_vector.predecessor_nodes]
+            # children_nodes = [node for node in strategies_vector.successor_nodes]
+            parent_nodes = []
+            children_nodes = []
+
+            def _check_tensor_in_node(data):
+                """
+                This method is used to check whether the data has a tensor inside or not.
+ """ + has_tensor_flag = False + if isinstance(data, torch.Tensor): + return True + elif isinstance(data, (tuple, list)): + for d in data: + has_tensor_flag = has_tensor_flag or _check_tensor_in_node(d) + return has_tensor_flag + + for node in strategies_vector.predecessor_nodes: + if _check_tensor_in_node(node._meta_data): + parent_nodes.append(node) + for node in strategies_vector.successor_nodes: + if _check_tensor_in_node(node._meta_data): + children_nodes.append(node) + setattr(dst_node, 'parents', parent_nodes) setattr(dst_node, 'children', children_nodes) if self.simplify and strategies_vector.check_merge(): for followed_node in strategies_vector.predecessor_nodes: - self.merge_pair.append((followed_node, dst_node)) + # we only merge node pairs which src node has a tensor element inside. + # This is necessay because the node without a tensor element inside will not + # be assigned any strategy. + if _check_tensor_in_node(followed_node._meta_data): + self.merge_pair.append((followed_node, dst_node)) def get_edge_cost(self, src_node, dst_node): return self.edge_costs[(src_node, dst_node)] diff --git a/colossalai/auto_parallel/tensor_shard/solver/solver.py b/colossalai/auto_parallel/tensor_shard/solver/solver.py index 7f972884e..89d0da223 100644 --- a/colossalai/auto_parallel/tensor_shard/solver/solver.py +++ b/colossalai/auto_parallel/tensor_shard/solver/solver.py @@ -154,12 +154,16 @@ class Solver: if self.forward_only: origin_communication_cost = communication_cost_item.fwd compute_cost = compute_cost_item.fwd + # extract MemoryCost item from the memory TrainCycleItem memory_cost = memory_cost_item.fwd else: origin_communication_cost = communication_cost_item.total compute_cost = compute_cost_item.total + # extract MemoryCost item from the memory TrainCycleItem memory_cost = memory_cost_item.total + # extract the memory cost in float from MemoryCost item and sum them up + memory_cost = memory_cost.parameter + memory_cost.activation + memory_cost.buffer compute_costs.append(compute_cost) # node in extra_node_costs means it has some extra communication # cost from node merging, so we need to add those extra communication @@ -366,6 +370,8 @@ class Solver: for liveness_stage in liveness_set: mem = 0 for live_variable in liveness_stage.unique_live_vars: + if live_variable.node not in self.node_index_dict: + continue node_index = self.node_index_dict[live_variable.node] mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index]))) prob += mem <= memory_budget diff --git a/colossalai/auto_parallel/tensor_shard/utils/reshape.py b/colossalai/auto_parallel/tensor_shard/utils/reshape.py index 8e02544f7..a32a14bf7 100644 --- a/colossalai/auto_parallel/tensor_shard/utils/reshape.py +++ b/colossalai/auto_parallel/tensor_shard/utils/reshape.py @@ -53,17 +53,38 @@ def detect_reshape_mapping(origin_shape: torch.Size, tgt_shape: torch.Size) -> D while origin_index != len(origin_shape) or tgt_index != len(tgt_shape): if original_dimension_size == tgt_dimension_size: reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims) - origin_index += 1 - tgt_index += 1 + # if the origin_dims has no element, it means the original tensor has been fully matched. + # Therefore, we do not have to increase the origin_index for that case. + if len(origin_dims) > 0: + origin_index += 1 + # if the tgt_dims has no element, it means the original tensor has been fully matched. + # Therefore, we do not have to increase the tgt_index for that case. 
+            if len(tgt_dims) > 0:
+                tgt_index += 1
             # the last step of loop should always end with condition
             # so we need to manually skip the preparation for next step
             # in the last step.
-            if origin_index == len(origin_shape):
+            if origin_index == len(origin_shape) and tgt_index == len(tgt_shape):
                 continue
-            original_dimension_size = origin_shape[origin_index]
-            tgt_dimension_size = tgt_shape[tgt_index]
-            origin_dims = [origin_len - origin_index - 1]
-            tgt_dims = [tgt_len - tgt_index - 1]
+
+            # If origin_index equals origin_len, we just need to set the original_dimension_size
+            # to 1 to match the remaining '1's in the target tensor shape.
+            if origin_index == len(origin_shape):
+                original_dimension_size = 1
+                origin_dims = []
+            else:
+                original_dimension_size = origin_shape[origin_index]
+                origin_dims = [origin_len - origin_index - 1]
+
+            # If tgt_index equals tgt_len, we just need to set the tgt_dimension_size
+            # to 1 to match the remaining '1's in the original tensor shape.
+            if tgt_index == len(tgt_shape):
+                tgt_dimension_size = 1
+                tgt_dims = []
+            else:
+                tgt_dimension_size = tgt_shape[tgt_index]
+                tgt_dims = [tgt_len - tgt_index - 1]
+
             previous_label = PreviousStatus.RESET
 
         elif original_dimension_size > tgt_dimension_size:
@@ -141,6 +162,9 @@ def check_keep_sharding_status(input_dim_partition_dict: Dict[int, List[int]],
     """
     sharded_dims = list(input_dim_partition_dict.keys())
     for input_dims in reshape_mapping_dict.keys():
+        # if input_dims is empty, we can just skip this iteration.
+        if len(input_dims) == 0:
+            continue
         min_element = min(input_dims)
         for dim in input_dims:
             if dim in sharded_dims and dim is not min_element:
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_solver_self_attention_block.py b/tests/test_auto_parallel/test_tensor_shard/test_solver_self_attention_block.py
new file mode 100644
index 000000000..7a1524966
--- /dev/null
+++ b/tests/test_auto_parallel/test_tensor_shard/test_solver_self_attention_block.py
@@ -0,0 +1,230 @@
+from typing import Optional, Tuple, Union
+
+import torch
+# from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
+import torch.nn as nn
+import transformers
+from torch.fx import GraphModule
+from torchvision.models import resnet50
+from transformers.pytorch_utils import Conv1D
+
+from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
+from colossalai.auto_parallel.tensor_shard.solver import (
+    CostGraph,
+    GraphAnalyser,
+    Solver,
+    SolverOptions,
+    StrategiesConstructor,
+)
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
+
+BATCH_SIZE = 1
+SEQ_LENGTH = 32
+HIDDEN_DIM = 768
+
+
+# The reason why we don't import GPT2Attention from transformers directly is that:
+# 1. The tracer will not work correctly when we feed meta_args and concrete_args at the same time,
+# so we have to build the customized GPT2Attention class and remove the conditional branch manually.
+# 2. The order of the split and view ops has been changed in the customized GPT2Attention class; the new
+# order is the same as in the Megatron-LM GPT model.
+class GPT2Attention(nn.Module):
+
+    def __init__(self, config, is_cross_attention=False, layer_idx=None):
+        super().__init__()
+
+        max_positions = config.max_position_embeddings
+        self.register_buffer(
+            "bias",
+            torch.tril(torch.ones((max_positions, max_positions),
+                                  dtype=torch.uint8)).view(1, 1, max_positions, max_positions),
+        )
+        self.register_buffer("masked_bias", torch.tensor(-1e4))
+
+        self.embed_dim = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.embed_dim // self.num_heads
+        self.split_size = self.embed_dim
+        if self.head_dim * self.num_heads != self.embed_dim:
+            raise ValueError(
+                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+                f" {self.num_heads}).")
+
+        self.scale_attn_weights = config.scale_attn_weights
+        self.is_cross_attention = is_cross_attention
+
+        # Layer-wise attention scaling, reordering, and upcasting
+        self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
+        self.layer_idx = layer_idx
+        self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
+
+        if self.is_cross_attention:
+            self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
+            self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
+        else:
+            self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
+        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
+
+        self.attn_dropout = nn.Dropout(config.attn_pdrop)
+        self.resid_dropout = nn.Dropout(config.resid_pdrop)
+
+        self.pruned_heads = set()
+
+    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
+        attn_weights = torch.matmul(query, key.transpose(-1, -2))
+
+        if self.scale_attn_weights:
+            attn_weights = attn_weights / (value.size(-1)**0.5)
+
+        # Layer-wise attention scaling
+        if self.scale_attn_by_inverse_layer_idx:
+            attn_weights = attn_weights / float(self.layer_idx + 1)
+
+        if not self.is_cross_attention:
+            # if only "normal" attention layer implements causal mask
+            query_length, key_length = query.size(-2), key.size(-2)
+            causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].to(torch.bool)
+            attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
+
+        if attention_mask is not None:
+            # Apply the attention mask
+            attn_weights = attn_weights + attention_mask
+
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
+        attn_weights = attn_weights.type(value.dtype)
+        attn_weights = self.attn_dropout(attn_weights)
+
+        # Mask heads if we want to
+        if head_mask is not None:
+            attn_weights = attn_weights * head_mask
+
+        attn_output = torch.matmul(attn_weights, value)
+
+        return attn_output, attn_weights
+
+    def _split_heads(self, tensor, num_heads, attn_head_size):
+        """
+        Splits hidden_size dim into attn_head_size and num_heads
+        """
+        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
+        tensor = tensor.view(new_shape)
+        return tensor.permute(0, 2, 1, 3)    # (batch, head, seq_length, head_features)
+
+    def _merge_heads(self, tensor, num_heads, attn_head_size):
+        """
+        Merges attn_head_size dim and num_attn_heads dim into hidden_size
+        """
+        tensor = tensor.permute(0, 2, 1, 3).contiguous()
+        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
+        return tensor.view(new_shape)
+
+    def forward(
+        self,
+        hidden_states: Optional[Tuple[torch.FloatTensor]],
+        layer_past: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
+        if encoder_hidden_states is not None:
+            if not hasattr(self, "q_attn"):
+                raise ValueError(
+                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
+                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`.")
+
+            query = self.q_attn(hidden_states)
+            key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
+            attention_mask = encoder_attention_mask
+        else:
+            # query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
+            qkv = self.c_attn(hidden_states)
+
+        # query = self._split_heads(query, self.num_heads, self.head_dim)
+        # key = self._split_heads(key, self.num_heads, self.head_dim)
+        # value = self._split_heads(value, self.num_heads, self.head_dim)
+        query, key, value = self._split_heads(qkv, self.num_heads, 3 * self.head_dim).split(self.head_dim, dim=3)
+
+        if layer_past is not None:
+            past_key, past_value = layer_past
+            key = torch.cat((past_key, key), dim=-2)
+            value = torch.cat((past_value, value), dim=-2)
+
+        present = (key, value)
+
+        if self.reorder_and_upcast_attn:
+            attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
+        else:
+            attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
+
+        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
+        attn_output = self.c_proj(attn_output)
+        attn_output = self.resid_dropout(attn_output)
+
+        outputs = (attn_output, present)
+        outputs += (attn_weights,)
+
+        return outputs    # a, present, (attentions)
+
+
+@run_on_environment_flag(name='AUTO_PARALLEL')
+def test_self_attention_block():
+    config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM)
+    model_cls = GPT2Attention
+    model = model_cls(config=config)
+    # output = model(torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM), attention_mask=torch.rand(1, SEQ_LENGTH))
+    physical_mesh_id = torch.arange(0, 4)
+    mesh_shape = (2, 2)
+    # [[0, 1]
+    #  [2, 3]]
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+    shape_consistency_manager = ShapeConsistencyManager()
+
+    tracer = ColoTracer()
+    input_sample = {
+        'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
+        'attention_mask': torch.rand(1, SEQ_LENGTH).to('meta'),
+    }
+
+    graph = tracer.trace(root=model, meta_args=input_sample)
+
+    gm = GraphModule(model, graph, model.__class__.__name__)
+    print(gm.graph)
+    gm.recompile()
+    graph_analyser = GraphAnalyser(gm)
+    liveness_list = graph_analyser.liveness_analysis()
+    solver_options = SolverOptions()
+    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
+    strategies_constructor.build_strategies_and_cost()
+
+    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
+    cost_graph.simplify_graph()
+    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=-1)
+    ret = solver.call_solver_serialized_args()
+    strategies_list = solver.last_s_val
+    nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
+
+    computation_cost = 0
+    communication_cost = 0
+    memory_cost = 0
+    for index, node in enumerate(nodes):
+        print(node.name, node.strategies_vector[strategies_list[index]].name)
+        computation_cost += node.strategies_vector[strategies_list[index]].compute_cost.total
+        communication_cost += node.strategies_vector[strategies_list[index]].communication_cost.total
+        node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost.total
+        if isinstance(node_memory_cost, tuple):
+            node_memory_cost = node_memory_cost[0]
+        memory_cost += node_memory_cost.activation + node_memory_cost.parameter
+
+    print(f'computation cost is {computation_cost}')
+    print(f'communication cost is {communication_cost}')
+    print(f'memory cost is {memory_cost}')
+
+
+if __name__ == '__main__':
+    test_self_attention_block()
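
Note (illustration only, not part of the patch): the `_check_tensor_in_node` guard added to `cost_graph.py` matters for nodes whose meta data carries no tensor at all, since such nodes are never assigned a strategy and therefore must not take part in node merging. A minimal standalone sketch of the same check, with the helper copied out of the diff for demonstration:

import torch


def _check_tensor_in_node(data):
    # Mirror of the helper added in cost_graph.py: returns True if `data` is a
    # tensor or a (possibly nested) tuple/list that contains at least one tensor.
    has_tensor_flag = False
    if isinstance(data, torch.Tensor):
        return True
    elif isinstance(data, (tuple, list)):
        for d in data:
            has_tensor_flag = has_tensor_flag or _check_tensor_in_node(d)
    return has_tensor_flag


# a split node's meta data is a tuple of tensors, so it is kept as a parent;
# a node carrying only plain Python values is skipped.
print(_check_tensor_in_node(torch.empty(2, 2)))                   # True
print(_check_tensor_in_node((torch.empty(2), torch.empty(2))))    # True
print(_check_tensor_in_node(3))                                   # False
print(_check_tensor_in_node((1, 2, 3)))                           # False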
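
The solver change reads a `MemoryCost` item out of the memory `TrainCycleItem` and reduces it to a single float. A rough sketch of that reduction with stand-in dataclasses (field names are taken from the diff; the real classes in `sharding_strategy.py` may carry more information):

from dataclasses import dataclass


@dataclass
class MemoryCost:
    # field names as used in the diff; stand-in for the real MemoryCost class
    activation: float = 0.0
    parameter: float = 0.0
    buffer: float = 0.0


@dataclass
class TrainCycleItem:
    # stand-in for the real TrainCycleItem, which groups fwd/bwd/total costs
    fwd: MemoryCost
    bwd: MemoryCost
    total: MemoryCost


memory_cost_item = TrainCycleItem(fwd=MemoryCost(4.0, 2.0, 0.0),
                                  bwd=MemoryCost(6.0, 2.0, 0.0),
                                  total=MemoryCost(10.0, 4.0, 0.0))

forward_only = True
memory_cost = memory_cost_item.fwd if forward_only else memory_cost_item.total
# same reduction as the new line in solver.py
memory_cost = memory_cost.parameter + memory_cost.activation + memory_cost.buffer
print(memory_cost)    # 6.0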
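
The `detect_reshape_mapping` branch added in `reshape.py` covers reshapes where one side is exhausted while the other still has trailing '1' dimensions. A hedged usage sketch of the kind of shape pair that branch is meant to handle (assuming a ColossalAI build that exposes the function at the module path shown in the diff; the resulting mapping is only printed, not asserted):

import torch

from colossalai.auto_parallel.tensor_shard.utils.reshape import detect_reshape_mapping

# reshape that appends a trailing '1' dimension: the origin shape is fully
# matched before the target shape, which previously ran past the end of the loop
origin_shape = torch.Size([4, 4])
tgt_shape = torch.Size([4, 4, 1])
print(detect_reshape_mapping(origin_shape, tgt_shape))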