[autoparallel] complete gpt block searching (#2065)

* [autoparallel] complete gpt block searching

* fix test
YuliangLiu0306 2022-12-06 10:17:10 +08:00 committed by GitHub
parent 597cdd3006
commit f123476666
4 changed files with 74 additions and 60 deletions

View File

@@ -12,6 +12,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
from colossalai.auto_parallel.tensor_shard.utils import (
enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
ignore_sharding_exception,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
@@ -94,6 +95,7 @@ class LayerNormGenerator(StrategyGenerator):
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@ignore_sharding_exception
def _generate_strategy_with_dim_partition(self, dim_partition):
dim_partition_dict_mapping = {
"input": dim_partition,
@@ -151,6 +153,7 @@ class LayerNormGenerator(StrategyGenerator):
strategy_list.append(strategy)
return strategy_list
@ignore_sharding_exception
def non_split(self):
name = f'RR = RR x R'
dim_partition_dict_mapping = {
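
The decorator added in this file lets a strategy-generating method bail out cleanly when a dim partition cannot legally be applied, instead of aborting the whole search. A minimal sketch of that pattern, using a hypothetical `ShardingException` and a skip-on-failure convention rather than ColossalAI's actual implementation:

```python
import functools

class ShardingException(Exception):
    """Hypothetical error raised when a dim partition is invalid for a tensor."""

def ignore_sharding_exception(fn):
    # Swallow invalid-sharding errors so the generator yields nothing for that
    # candidate instead of crashing the strategy search.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except ShardingException:
            return None
    return wrapper

@ignore_sharding_exception
def split_input_batch(dim_partition):
    if 0 not in dim_partition.get("input", {}):
        raise ShardingException("input is not sharded along the batch dim")
    return {"name": "S0R = S0R x R", "dim_partition": dim_partition}

assert split_input_batch({"weight": {0: [0]}}) is None      # invalid candidate skipped
assert split_input_batch({"input": {0: [0]}})["name"] == "S0R = S0R x R"
```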

View File

@@ -14,6 +14,8 @@ __all__ = ['UnaryElementwiseHandler']
@operator_registry.register(torch.Tensor.type)
@operator_registry.register(torch.abs)
@operator_registry.register(torch.nn.ReLU)
@operator_registry.register(torch.nn.Tanh)
@operator_registry.register(torch.tanh)
# TODO: softmax need to be relocated
@operator_registry.register(torch.nn.functional.softmax)
@operator_registry.register(torch.nn.modules.dropout.Dropout)
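
This file routes `torch.nn.Tanh` and `torch.tanh` to the same unary element-wise handler. A sketch of the registry-decorator pattern at work here, with a toy `Registry` class standing in for ColossalAI's actual `operator_registry`:

```python
import torch

class Registry:
    """Toy target -> handler map; stands in for the real operator_registry."""

    def __init__(self):
        self._store = {}

    def register(self, target):
        def decorator(handler_cls):
            self._store[target] = handler_cls
            return handler_cls
        return decorator

    def get(self, target):
        return self._store[target]

operator_registry = Registry()

@operator_registry.register(torch.tanh)
@operator_registry.register(torch.nn.Tanh)
class UnaryElementwiseHandler:
    """All registered targets resolve to this one handler class."""

assert operator_registry.get(torch.tanh) is UnaryElementwiseHandler
assert operator_registry.get(torch.nn.Tanh) is UnaryElementwiseHandler
```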

View File

@@ -254,8 +254,9 @@ class StrategiesVector(list):
if self.node.target in ELEMENTWISE_FUNC_OP:
merge_label = True
# we could merge bcast op if the rhs is a scalar, because it will fall back to the element-wise case.
if self.node.target in BCAST_FUNC_OP and len(self.predecessor_nodes) == 1:
merge_label = True
# TODO: remove this after we support the fall back logic.
# if self.node.target in BCAST_FUNC_OP and len(self.predecessor_nodes) == 1:
# merge_label = True
# we could merge reshape op, because their computation costs are negligible.
if self.node.target in RESHAPE_FUNC_OP:
merge_label = True
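
The net effect of this hunk: element-wise and reshape nodes are still folded into their producer node, while broadcast ops are no longer merged until the scalar fall-back logic lands. A small illustration of the resulting decision, with stand-in op sets (the real constants live in the tensor-shard constants module):

```python
import operator
import torch

# Stand-in op sets for illustration only.
ELEMENTWISE_FUNC_OP = {torch.abs, torch.tanh, torch.nn.functional.relu}
RESHAPE_FUNC_OP = {torch.flatten, torch.Tensor.view, torch.Tensor.permute}
BCAST_FUNC_OP = {operator.add, operator.mul, torch.add}

def should_merge(target):
    """Merge decision after this commit: cheap element-wise and reshape nodes
    are merged into their predecessor; broadcast ops are left alone for now."""
    if target in ELEMENTWISE_FUNC_OP:
        return True
    # Broadcast ops with a scalar rhs used to be merged here; disabled until
    # the element-wise fall-back is supported.
    if target in RESHAPE_FUNC_OP:
        return True
    return False

assert should_merge(torch.tanh)
assert should_merge(torch.Tensor.view)
assert not should_merge(operator.add)   # previously True for a scalar rhs
```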

View File

@@ -5,7 +5,7 @@ import torch
import torch.nn as nn
import transformers
from torch.fx import GraphModule
from torchvision.models import resnet50
from transformers.models.gpt2.modeling_gpt2 import GPT2MLP
from transformers.pytorch_utils import Conv1D
from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
@@ -19,6 +19,7 @@ from colossalai.auto_parallel.tensor_shard.solver import (
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.testing import parameterize
from colossalai.testing.pytest_wrapper import run_on_environment_flag
BATCH_SIZE = 1
@@ -33,7 +34,7 @@ HIDDEN_DIM = 768
# order is same as megatron-lm gpt model.
class GPT2Attention(nn.Module):
def __init__(self, config, is_cross_attention=False, layer_idx=None):
def __init__(self, config, layer_idx=None):
super().__init__()
max_positions = config.max_position_embeddings
@@ -48,24 +49,13 @@ class GPT2Attention(nn.Module):
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
self.split_size = self.embed_dim
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads}).")
self.scale_attn_weights = config.scale_attn_weights
self.is_cross_attention = is_cross_attention
# Layer-wise attention scaling, reordering, and upcasting
self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
self.layer_idx = layer_idx
self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
if self.is_cross_attention:
self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
else:
self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
self.attn_dropout = nn.Dropout(config.attn_pdrop)
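
Both projections above use transformers' `Conv1D`, GPT-2's linear layer with the weight stored as `(in_features, out_features)`. A quick check of that equivalence against `nn.Linear`, with toy sizes independent of the test below:

```python
import torch
import torch.nn as nn
from transformers.pytorch_utils import Conv1D

embed_dim = 8
c_attn = Conv1D(3 * embed_dim, embed_dim)       # Conv1D(nf, nx): out 3E, in E
linear = nn.Linear(embed_dim, 3 * embed_dim)
with torch.no_grad():
    linear.weight.copy_(c_attn.weight.t())      # Conv1D stores W as (nx, nf)
    linear.bias.copy_(c_attn.bias)

x = torch.randn(2, 5, embed_dim)
assert torch.allclose(c_attn(x), linear(x), atol=1e-6)
```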
@@ -83,11 +73,10 @@ class GPT2Attention(nn.Module):
if self.scale_attn_by_inverse_layer_idx:
attn_weights = attn_weights / float(self.layer_idx + 1)
if not self.is_cross_attention:
# if only "normal" attention layer implements causal mask
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].to(torch.bool)
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
# if only "normal" attention layer implements causal mask
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].to(torch.bool)
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
if attention_mask is not None:
# Apply the attention mask
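
With the cross-attention branch removed above, the causal mask is applied unconditionally. A minimal sketch of that masking step, assuming the usual registered lower-triangular `bias` buffer and a large-negative `masked_bias`:

```python
import torch

max_positions = 8
# Lower-triangular buffer of shape (1, 1, max_pos, max_pos), as GPT-2 registers it.
bias = torch.tril(torch.ones(max_positions, max_positions, dtype=torch.bool))
bias = bias.view(1, 1, max_positions, max_positions)
masked_bias = torch.tensor(torch.finfo(torch.float32).min)

query_length, key_length = 4, 4
attn_weights = torch.randn(1, 2, query_length, key_length)   # (batch, heads, q, k)

causal_mask = bias[:, :, key_length - query_length:key_length, :key_length]
attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
probs = torch.softmax(attn_weights, dim=-1)

# Query position 0 can only attend to key position 0.
assert probs[0, 0, 0, 0].item() == 1.0
assert probs[0, 0, 0, 1:].sum().item() == 0.0
```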
@@ -108,17 +97,11 @@ class GPT2Attention(nn.Module):
return attn_output, attn_weights
def _split_heads(self, tensor, num_heads, attn_head_size):
"""
Splits hidden_size dim into attn_head_size and num_heads
"""
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
tensor = tensor.view(new_shape)
return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
def _merge_heads(self, tensor, num_heads, attn_head_size):
"""
Merges attn_head_size dim and num_attn_heads dim into hidden_size
"""
tensor = tensor.permute(0, 2, 1, 3).contiguous()
new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
return tensor.view(new_shape)
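
`_split_heads` reshapes `(batch, seq, hidden)` into `(batch, heads, seq, head_dim)` and `_merge_heads` inverts it. A quick round-trip check with toy sizes:

```python
import torch

def split_heads(tensor, num_heads, attn_head_size):
    # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
    new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
    return tensor.view(new_shape).permute(0, 2, 1, 3)

def merge_heads(tensor, num_heads, attn_head_size):
    # (batch, heads, seq, head_dim) -> (batch, seq, hidden)
    tensor = tensor.permute(0, 2, 1, 3).contiguous()
    return tensor.view(tensor.size()[:-2] + (num_heads * attn_head_size,))

x = torch.randn(2, 16, 768)                         # batch=2, seq=16, hidden=768
heads = split_heads(x, 12, 64)                      # 12 heads x 64 dims = 768
assert heads.shape == (2, 12, 16, 64)
assert torch.equal(merge_heads(heads, 12, 64), x)   # lossless round trip
```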
@@ -126,41 +109,19 @@ class GPT2Attention(nn.Module):
def forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
if encoder_hidden_states is not None:
if not hasattr(self, "q_attn"):
raise ValueError(
"If class is used as cross attention, the weights `q_attn` have to be defined. "
"Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`.")
query = self.q_attn(hidden_states)
key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
attention_mask = encoder_attention_mask
else:
# query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
qkv = self.c_attn(hidden_states)
# query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
qkv = self.c_attn(hidden_states)
# query = self._split_heads(query, self.num_heads, self.head_dim)
# key = self._split_heads(key, self.num_heads, self.head_dim)
# value = self._split_heads(value, self.num_heads, self.head_dim)
query, key, value = self._split_heads(qkv, self.num_heads, 3 * self.head_dim).split(self.head_dim, dim=3)
if layer_past is not None:
past_key, past_value = layer_past
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)
present = (key, value)
if self.reorder_and_upcast_attn:
attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
else:
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
attn_output = self.c_proj(attn_output)
@@ -172,12 +133,54 @@ class GPT2Attention(nn.Module):
return outputs # a, present, (attentions)
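
The rewritten forward keeps `c_attn`'s fused output and splits heads before splitting q/k/v, presumably so the traced graph keeps the projection as a single matmul. A shape walk-through of that split with toy sizes; note that, as far as the reshape math goes, each head then reads one contiguous `[q_h | k_h | v_h]` block of the fused projection, which differs from the stock GPT-2 layout but does not matter for the randomly initialised test model:

```python
import torch

batch, seq, num_heads, head_dim = 2, 8, 4, 16
embed_dim = num_heads * head_dim

qkv = torch.randn(batch, seq, 3 * embed_dim)        # stand-in for c_attn's output

def split_heads(tensor, num_heads, attn_head_size):
    new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
    return tensor.view(new_shape).permute(0, 2, 1, 3)

query, key, value = split_heads(qkv, num_heads, 3 * head_dim).split(head_dim, dim=3)
assert query.shape == key.shape == value.shape == (batch, num_heads, seq, head_dim)

# Head 0's query is the first head_dim columns of the fused projection.
assert torch.equal(query[:, 0], qkv[..., :head_dim])
```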
class GPT2Block(nn.Module):
def __init__(self, config, layer_idx=None):
super().__init__()
hidden_size = config.hidden_size
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPT2Attention(config, layer_idx=layer_idx)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPT2MLP(inner_dim, config)
def forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
residual = hidden_states
# %transformer_h_0_ln_1
hidden_states = self.ln_1(hidden_states)
attn_outputs = self.attn(
hidden_states,
attention_mask=attention_mask,
head_mask=head_mask,
)
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
outputs = attn_outputs[1:]
# residual connection
hidden_states = attn_output + residual
residual = hidden_states
hidden_states = self.ln_2(hidden_states)
feed_forward_hidden_states = self.mlp(hidden_states)
# residual connection
hidden_states = residual + feed_forward_hidden_states
outputs = (hidden_states,) + outputs[1:]
return outputs # hidden_states, present, (attentions, cross_attentions)
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_self_attention_block():
@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP])
def test_self_attention_block(model_cls):
config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM)
model_cls = GPT2Attention
model = model_cls(config=config)
# output = model(torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM), attention_mask=torch.rand(1, SEQ_LENGTH))
if model_cls == GPT2MLP:
model = model_cls(intermediate_size=4 * config.hidden_size, config=config)
else:
model = model_cls(config=config)
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
@@ -186,10 +189,15 @@ def test_self_attention_block():
shape_consistency_manager = ShapeConsistencyManager()
tracer = ColoTracer()
input_sample = {
'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
'attention_mask': torch.rand(1, SEQ_LENGTH).to('meta'),
}
if model_cls == GPT2MLP:
input_sample = {
'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
}
else:
input_sample = {
'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
'attention_mask': torch.rand(1, SEQ_LENGTH).to('meta'),
}
graph = tracer.trace(root=model, meta_args=input_sample)
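
The inputs are meta tensors, so `ColoTracer` only needs shapes and dtypes, and the four devices are folded into the 2 x 2 logical mesh the solver searches over. Two small, self-contained illustrations of those pieces, with toy sizes standing in for BATCH_SIZE / SEQ_LENGTH / HIDDEN_DIM:

```python
import torch

# Meta tensors carry shape and dtype only: no memory is allocated and no
# kernels run, which is all the tracer needs to annotate the graph.
hidden_states = torch.rand(1, 32, 768).to('meta')
assert hidden_states.is_meta and hidden_states.shape == (1, 32, 768)

# The four physical devices arranged as the 2 x 2 logical mesh used above:
# [[0, 1],
#  [2, 3]]
physical_mesh_id = torch.arange(0, 4)
logical_mesh = physical_mesh_id.reshape(2, 2)
```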