
[autoparallel] adapt solver with self attention (#2037)

* [autoparallel] adapt solver with self attention

* polish code
Branch: pull/2061/head
Author: YuliangLiu0306, committed via GitHub 2 years ago
Commit: 1c1fe44305
Changed files:
  1. colossalai/auto_parallel/tensor_shard/constants.py (9 changed lines)
  2. colossalai/auto_parallel/tensor_shard/sharding_strategy.py (18 changed lines)
  3. colossalai/auto_parallel/tensor_shard/solver/cost_graph.py (32 changed lines)
  4. colossalai/auto_parallel/tensor_shard/solver/solver.py (6 changed lines)
  5. colossalai/auto_parallel/tensor_shard/utils/reshape.py (38 changed lines)
  6. tests/test_auto_parallel/test_tensor_shard/test_solver_self_attention_block.py (230 changed lines)

colossalai/auto_parallel/tensor_shard/constants.py (9 changed lines)

@@ -26,7 +26,14 @@ ELEMENTWISE_METHOD_OP = [
# TODO: contiguous may need some extra processing.
torch.Tensor.contiguous
]
RESHAPE_FUNC_OP = [torch.flatten, torch.reshape]
RESHAPE_FUNC_OP = [
torch.flatten,
torch.reshape,
torch.transpose,
torch.split,
torch.permute,
operator.getitem,
]
RESHAPE_METHOD_OP = [
torch.Tensor.view,
torch.Tensor.unsqueeze,
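
These lists feed StrategiesVector.check_merge below: a call_function node whose target is in RESHAPE_FUNC_OP gets merge_label = True. A minimal standalone sketch (the TinyQKV module and the printout are illustrative, not part of ColossalAI) of which call_function targets in a traced graph now match the extended list:

```python
import operator

import torch
from torch import fx

# Mirror of the extended list above, used here only for membership checks.
RESHAPE_FUNC_OP = [torch.flatten, torch.reshape, torch.transpose, torch.split, torch.permute, operator.getitem]


class TinyQKV(torch.nn.Module):
    """Toy module whose trace contains split/getitem/permute nodes."""

    def forward(self, x):
        parts = torch.split(x, 4, dim=-1)    # traced as call_function: torch.split
        q = parts[0]                         # traced as call_function: operator.getitem
        return torch.permute(q, (0, 2, 1))   # traced as call_function: torch.permute


graph = fx.symbolic_trace(TinyQKV()).graph
for node in graph.nodes:
    if node.op == 'call_function' and node.target in RESHAPE_FUNC_OP:
        print(f'{node.name}: reshape-like, candidate for node merging')
```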

colossalai/auto_parallel/tensor_shard/sharding_strategy.py (18 changed lines)

@@ -9,7 +9,14 @@ from torch.fx.node import Node
from colossalai.tensor.shape_consistency import CommSpec
from colossalai.tensor.sharding_spec import ShardingSpec
from .constants import BCAST_FUNC_OP, ELEMENTWISE_FUNC_OP, ELEMENTWISE_MODULE_OP, RESHAPE_FUNC_OP
from .constants import (
BCAST_FUNC_OP,
ELEMENTWISE_FUNC_OP,
ELEMENTWISE_METHOD_OP,
ELEMENTWISE_MODULE_OP,
RESHAPE_FUNC_OP,
RESHAPE_METHOD_OP,
)
__all__ = ['OperationDataType', 'OperationData', 'TrainCycleItem', 'MemoryCost', 'ShardingStrategy', 'StrategiesVector']
@@ -249,8 +256,15 @@ class StrategiesVector(list):
# we could merge bcast op if the rhs is a scalar, because it will fall back to the element-wise case.
if self.node.target in BCAST_FUNC_OP and len(self.predecessor_nodes) == 1:
merge_label = True
# we could merge reshape op, because the output sharding spec of reshape op is always fully replicated.
# we could merge reshape op, because their computation costs are negligible.
if self.node.target in RESHAPE_FUNC_OP:
merge_label = True
if self.node.op == 'call_method':
# we could merge reshape op, because their computation costs are negligible.
method = getattr(self.node.args[0]._meta_data.__class__, self.node.target)
if method in RESHAPE_METHOD_OP:
merge_label = True
if method in ELEMENTWISE_METHOD_OP:
merge_label = True
return merge_label
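
For call_method nodes, node.target is only a string (e.g. 'view'), so the new branch resolves it to the underlying torch.Tensor method via the class of the first argument's meta data before testing membership. A minimal sketch of that resolution (the method lists here are abbreviated copies of the ones in constants.py, the rest is illustrative):

```python
import torch

# Abbreviated copies of the lists in constants.py, for illustration only.
RESHAPE_METHOD_OP = [torch.Tensor.view, torch.Tensor.unsqueeze]
ELEMENTWISE_METHOD_OP = [torch.Tensor.contiguous]

# Stand-in for node.args[0]._meta_data of a call_method node.
meta_tensor = torch.empty(2, 4, device='meta')

for target in ('view', 'unsqueeze', 'contiguous', 'sum'):
    method = getattr(meta_tensor.__class__, target)   # e.g. 'view' -> torch.Tensor.view
    merge_label = method in RESHAPE_METHOD_OP or method in ELEMENTWISE_METHOD_OP
    print(f'{target}: merge_label = {merge_label}')
```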

colossalai/auto_parallel/tensor_shard/solver/cost_graph.py (32 changed lines)

@@ -63,14 +63,40 @@ class CostGraph:
edge_cost[(j, i)] = resharding_cost_item.total
self.edge_costs[node_pair] = edge_cost
# add parents and children attribute to node
parent_nodes = [node for node in strategies_vector.predecessor_nodes]
children_nodes = [node for node in strategies_vector.successor_nodes]
# parent_nodes = [node for node in strategies_vector.predecessor_nodes]
# children_nodes = [node for node in strategies_vector.successor_nodes]
parent_nodes = []
children_nodes = []
def _check_tensor_in_node(data):
"""
This method is used to check whether the data has a tensor inside or not.
"""
has_tensor_flag = False
if isinstance(data, torch.Tensor):
return True
elif isinstance(data, (tuple, list)):
for d in data:
has_tensor_flag = has_tensor_flag or _check_tensor_in_node(d)
return has_tensor_flag
for node in strategies_vector.predecessor_nodes:
if _check_tensor_in_node(node._meta_data):
parent_nodes.append(node)
for node in strategies_vector.successor_nodes:
if _check_tensor_in_node(node._meta_data):
children_nodes.append(node)
setattr(dst_node, 'parents', parent_nodes)
setattr(dst_node, 'children', children_nodes)
if self.simplify and strategies_vector.check_merge():
for followed_node in strategies_vector.predecessor_nodes:
self.merge_pair.append((followed_node, dst_node))
# we only merge node pairs whose src node has a tensor element inside.
# This is necessary because a node without a tensor element inside will not
# be assigned any strategy.
if _check_tensor_in_node(followed_node._meta_data):
self.merge_pair.append((followed_node, dst_node))
def get_edge_cost(self, src_node, dst_node):
return self.edge_costs[(src_node, dst_node)]
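
The recursive tensor check exists because some traced nodes, such as results of size() arithmetic or other non-tensor outputs, carry no tensor in their meta data, never receive sharding strategies, and must not be merged against. A standalone sketch of the same check, with assumed example data:

```python
import torch


def check_tensor_in_data(data) -> bool:
    """Return True if data is a tensor or a (possibly nested) tuple/list containing one."""
    if isinstance(data, torch.Tensor):
        return True
    if isinstance(data, (tuple, list)):
        return any(check_tensor_in_data(d) for d in data)
    return False


# The meta data of a torch.split node is a tuple of tensors -> kept as parent/child, eligible for merging.
split_meta = (torch.empty(2, 4, device='meta'), torch.empty(2, 4, device='meta'))
# The meta data of a size()-derived value is a plain int -> it will never be assigned a strategy.
scalar_meta = 768

print(check_tensor_in_data(split_meta))   # True
print(check_tensor_in_data(scalar_meta))  # False
```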

colossalai/auto_parallel/tensor_shard/solver/solver.py (6 changed lines)

@@ -154,12 +154,16 @@ class Solver:
if self.forward_only:
origin_communication_cost = communication_cost_item.fwd
compute_cost = compute_cost_item.fwd
# extract MemoryCost item from the memory TrainCycleItem
memory_cost = memory_cost_item.fwd
else:
origin_communication_cost = communication_cost_item.total
compute_cost = compute_cost_item.total
# extract MemoryCost item from the memory TrainCycleItem
memory_cost = memory_cost_item.total
# extract the memory cost in float from MemoryCost item and sum them up
memory_cost = memory_cost.parameter + memory_cost.activation + memory_cost.buffer
compute_costs.append(compute_cost)
# node in extra_node_costs means it has some extra communication
# cost from node merging, so we need to add those extra communication
@@ -366,6 +370,8 @@ class Solver:
for liveness_stage in liveness_set:
mem = 0
for live_variable in liveness_stage.unique_live_vars:
if live_variable.node not in self.node_index_dict:
continue
node_index = self.node_index_dict[live_variable.node]
mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
prob += mem <= memory_budget
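
The first change flattens the new MemoryCost objects into a single float per strategy; the second skips live variables whose nodes never entered the ILP (such as the tensor-free nodes filtered out in the cost graph). A toy PuLP sketch, with made-up costs, of the kind of per-stage memory-budget constraint the solver expresses with lpSum:

```python
from pulp import LpBinary, LpMinimize, LpProblem, LpVariable, lpSum

# Two nodes with two candidate strategies each; all numbers are made up.
m = [[100.0, 60.0], [80.0, 50.0]]   # memory cost (parameter + activation + buffer) per strategy
c = [[1.0, 3.0], [1.0, 4.0]]        # compute + communication cost per strategy
memory_budget = 130.0

prob = LpProblem('auto_parallel_toy', LpMinimize)
s = [[LpVariable(f's_{i}_{j}', cat=LpBinary) for j in range(2)] for i in range(2)]

# objective: minimize the total compute/communication cost of the chosen strategies
prob += lpSum(s[i][j] * c[i][j] for i in range(2) for j in range(2))

# each node picks exactly one strategy
for i in range(2):
    prob += lpSum(s[i]) == 1

# the memory of each liveness stage (a single stage here) must fit the budget
prob += lpSum(s[i][j] * m[i][j] for i in range(2) for j in range(2)) <= memory_budget

prob.solve()
print([[int(v.value()) for v in row] for row in s])   # expected: the low-memory strategies
```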

colossalai/auto_parallel/tensor_shard/utils/reshape.py (38 changed lines)

@@ -53,17 +53,38 @@ def detect_reshape_mapping(origin_shape: torch.Size, tgt_shape: torch.Size) -> D
while origin_index != len(origin_shape) or tgt_index != len(tgt_shape):
if original_dimension_size == tgt_dimension_size:
reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims)
origin_index += 1
tgt_index += 1
# if the origin_dims has no element, it means the original tensor has been fully matched.
# Therefore, we do not have to increase the origin_index for that case.
if len(origin_dims) > 0:
origin_index += 1
# if the tgt_dims has no element, it means the target tensor has been fully matched.
# Therefore, we do not have to increase the tgt_index for that case.
if len(tgt_dims) > 0:
tgt_index += 1
# the last step of the loop should always end on this condition,
# so we need to manually skip the preparation for the next step
# in the last iteration.
if origin_index == len(origin_shape):
if origin_index == len(origin_shape) and tgt_index == len(tgt_shape):
continue
original_dimension_size = origin_shape[origin_index]
tgt_dimension_size = tgt_shape[tgt_index]
origin_dims = [origin_len - origin_index - 1]
tgt_dims = [tgt_len - tgt_index - 1]
# If origin_index equals origin_len, we just need to set the original_dimension_size
# to 1 to match the remaining '1's in the target tensor shape.
if origin_index == len(origin_shape):
original_dimension_size = 1
origin_dims = []
else:
original_dimension_size = origin_shape[origin_index]
origin_dims = [origin_len - origin_index - 1]
# If tgt_index equals tgt_len, we just need to set the tgt_dimension_size
# to 1 to match the remaining '1's in the original tensor shape.
if tgt_index == len(tgt_shape):
tgt_dimension_size = 1
tgt_dims = []
else:
tgt_dimension_size = tgt_shape[tgt_index]
tgt_dims = [tgt_len - tgt_index - 1]
previous_label = PreviousStatus.RESET
elif original_dimension_size > tgt_dimension_size:
@@ -141,6 +162,9 @@ def check_keep_sharding_status(input_dim_partition_dict: Dict[int, List[int]],
"""
sharded_dims = list(input_dim_partition_dict.keys())
for input_dims in reshape_mapping_dict.keys():
# if input_dims has no element, we could just skip this iteration.
if len(input_dims) == 0:
continue
min_element = min(input_dims)
for dim in input_dims:
if dim in sharded_dims and dim is not min_element:
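
The two reshape fixes cover the case where one side of the view is exhausted while the other still has trailing size-1 dimensions: the mapping then contains a group with no input dimensions at all, and check_keep_sharding_status must not call min() on it. A standalone sketch of that guard, using an assumed mapping (the dimension numbering is illustrative, not the exact output of detect_reshape_mapping):

```python
from typing import Dict, List, Tuple

# Assumed mapping for a view that appends a trailing size-1 dim, e.g. (8, 4) -> (8, 4, 1):
# the empty tuple means "no input dimension maps to this output group".
reshape_mapping_dict: Dict[Tuple[int, ...], Tuple[int, ...]] = {
    (0,): (0,),
    (1,): (1,),
    (): (2,),
}
input_dim_partition_dict: Dict[int, List[int]] = {0: [0]}   # input dim 0 sharded on mesh dim 0
sharded_dims = list(input_dim_partition_dict.keys())

keep_sharding = True
for input_dims in reshape_mapping_dict:
    if len(input_dims) == 0:
        # without this guard, min(input_dims) raises ValueError on the empty group
        continue
    min_element = min(input_dims)
    for dim in input_dims:
        if dim in sharded_dims and dim != min_element:
            keep_sharding = False

print(keep_sharding)   # True: the only sharded dim leads its group, so the sharding can be kept
```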

tests/test_auto_parallel/test_tensor_shard/test_solver_self_attention_block.py (230 changed lines)

@@ -0,0 +1,230 @@
from typing import Optional, Tuple, Union
import torch
# from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
import torch.nn as nn
import transformers
from torch.fx import GraphModule
from torchvision.models import resnet50
from transformers.pytorch_utils import Conv1D
from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
from colossalai.auto_parallel.tensor_shard.solver import (
CostGraph,
GraphAnalyser,
Solver,
SolverOptions,
StrategiesConstructor,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.testing.pytest_wrapper import run_on_environment_flag
BATCH_SIZE = 1
SEQ_LENGTH = 32
HIDDEN_DIM = 768
# The reason why we don't import GPT2Attention from transformers directly is that:
# 1. The tracer will not work correctly when we feed meta_args and concrete_args at the same time,
# so we have to build a customized GPT2Attention class and remove the conditional branches manually.
# 2. The order of the split and view ops has been changed in the customized GPT2Attention class; the new
# order is the same as in the Megatron-LM GPT model.
class GPT2Attention(nn.Module):
def __init__(self, config, is_cross_attention=False, layer_idx=None):
super().__init__()
max_positions = config.max_position_embeddings
self.register_buffer(
"bias",
torch.tril(torch.ones((max_positions, max_positions),
dtype=torch.uint8)).view(1, 1, max_positions, max_positions),
)
self.register_buffer("masked_bias", torch.tensor(-1e4))
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
self.split_size = self.embed_dim
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads}).")
self.scale_attn_weights = config.scale_attn_weights
self.is_cross_attention = is_cross_attention
# Layer-wise attention scaling, reordering, and upcasting
self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
self.layer_idx = layer_idx
self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
if self.is_cross_attention:
self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
else:
self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
self.pruned_heads = set()
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
attn_weights = torch.matmul(query, key.transpose(-1, -2))
if self.scale_attn_weights:
attn_weights = attn_weights / (value.size(-1)**0.5)
# Layer-wise attention scaling
if self.scale_attn_by_inverse_layer_idx:
attn_weights = attn_weights / float(self.layer_idx + 1)
if not self.is_cross_attention:
# only the "normal" attention layer implements the causal mask
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].to(torch.bool)
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
if attention_mask is not None:
# Apply the attention mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
# Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
attn_weights = attn_weights.type(value.dtype)
attn_weights = self.attn_dropout(attn_weights)
# Mask heads if we want to
if head_mask is not None:
attn_weights = attn_weights * head_mask
attn_output = torch.matmul(attn_weights, value)
return attn_output, attn_weights
def _split_heads(self, tensor, num_heads, attn_head_size):
"""
Splits hidden_size dim into attn_head_size and num_heads
"""
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
tensor = tensor.view(new_shape)
return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
def _merge_heads(self, tensor, num_heads, attn_head_size):
"""
Merges attn_head_size dim and num_attn_heads dim into hidden_size
"""
tensor = tensor.permute(0, 2, 1, 3).contiguous()
new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
return tensor.view(new_shape)
def forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
if encoder_hidden_states is not None:
if not hasattr(self, "q_attn"):
raise ValueError(
"If class is used as cross attention, the weights `q_attn` have to be defined. "
"Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`.")
query = self.q_attn(hidden_states)
key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
attention_mask = encoder_attention_mask
else:
# query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
qkv = self.c_attn(hidden_states)
# query = self._split_heads(query, self.num_heads, self.head_dim)
# key = self._split_heads(key, self.num_heads, self.head_dim)
# value = self._split_heads(value, self.num_heads, self.head_dim)
query, key, value = self._split_heads(qkv, self.num_heads, 3 * self.head_dim).split(self.head_dim, dim=3)
if layer_past is not None:
past_key, past_value = layer_past
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)
present = (key, value)
if self.reorder_and_upcast_attn:
attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
else:
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
outputs = (attn_output, present)
outputs += (attn_weights,)
return outputs # a, present, (attentions)
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_self_attention_block():
config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM)
model_cls = GPT2Attention
model = model_cls(config=config)
# output = model(torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM), attention_mask=torch.rand(1, SEQ_LENGTH))
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
shape_consistency_manager = ShapeConsistencyManager()
tracer = ColoTracer()
input_sample = {
'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
'attention_mask': torch.rand(1, SEQ_LENGTH).to('meta'),
}
graph = tracer.trace(root=model, meta_args=input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
print(gm.graph)
gm.recompile()
graph_analyser = GraphAnalyser(gm)
liveness_list = graph_analyser.liveness_analysis()
solver_options = SolverOptions()
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=-1)
ret = solver.call_solver_serialized_args()
strategies_list = solver.last_s_val
nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
computation_cost = 0
communication_cost = 0
memory_cost = 0
for index, node in enumerate(nodes):
print(node.name, node.strategies_vector[strategies_list[index]].name)
computation_cost += node.strategies_vector[strategies_list[index]].compute_cost.total
communication_cost += node.strategies_vector[strategies_list[index]].communication_cost.total
node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost.total
if isinstance(node_memory_cost, tuple):
node_memory_cost = node_memory_cost[0]
memory_cost += node_memory_cost.activation + node_memory_cost.parameter
print(f'computation cost is {computation_cost}')
print(f'communication cost is {communication_cost}')
print(f'memory cost is {memory_cost}')
if __name__ == '__main__':
test_self_attention_block()
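
To make the reordered split/view in the customized forward concrete, here is a shape-only sketch (the sizes are assumptions matching the test constants) of how the fused QKV projection is first viewed per head and only then split into query/key/value, following the Megatron-LM ordering mentioned in the comment above:

```python
import torch

BATCH_SIZE, SEQ_LENGTH, NUM_HEADS, HEAD_DIM = 1, 32, 16, 48   # 16 * 48 = 768 = HIDDEN_DIM

# Output of c_attn: the fused QKV projection, shape (batch, seq, 3 * hidden_dim).
qkv = torch.rand(BATCH_SIZE, SEQ_LENGTH, 3 * NUM_HEADS * HEAD_DIM)

# _split_heads(qkv, num_heads, 3 * head_dim): view per head, then move heads before the sequence dim.
qkv = qkv.view(BATCH_SIZE, SEQ_LENGTH, NUM_HEADS, 3 * HEAD_DIM).permute(0, 2, 1, 3)

# .split(head_dim, dim=3): query, key and value each end up as (batch, heads, seq, head_dim).
query, key, value = qkv.split(HEAD_DIM, dim=3)
print(query.shape, key.shape, value.shape)   # three times torch.Size([1, 16, 32, 48])
```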