[bugfix] colo attn bug fix

moe_sp
haze188 2024-07-24 06:53:24 +00:00
parent e521890d32
commit 2d73efdfdd
5 changed files with 106 additions and 55 deletions

View File

@@ -73,8 +73,8 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
}
if len(pg_param_list[dp_process_group]) == 0 or len(pg_param_list[moe_dp_group]) == 0:
raise ValueError("No parameters found in dp_process_group or moe_dp_group")
# if len(pg_param_list[dp_process_group]) == 0 or len(pg_param_list[moe_dp_group]) == 0:
# raise ValueError("No parameters found in dp_process_group or moe_dp_group")
super().__init__(
model=model,

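The pg_param_list above splits the model's parameters into two ZeRO buckets: dense parameters synchronized over the plain data-parallel group and expert parameters synchronized over the smaller MoE data-parallel group. A minimal sketch of that split, assuming an is_moe_tensor predicate like the one imported by the plugin (illustrative only, not ColossalAI's implementation):

```python
# Toy sketch of the two-bucket parameter split used by MoeHybridParallelZeroOptimizer.
# `is_moe_tensor` is assumed to mark expert parameters; the dict keys stand in for the
# real process-group handles (dp_process_group / moe_dp_group) shown in the diff.
import torch.nn as nn

def split_zero_param_groups(model: nn.Module, is_moe_tensor) -> dict:
    dense, moe = [], []
    for p in model.parameters():
        (moe if is_moe_tensor(p) else dense).append(p)
    return {"dp_process_group": dense, "moe_dp_group": moe}
```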
View File

@@ -1,7 +1,9 @@
import math
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.distributed
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
@@ -34,6 +36,8 @@ from colossalai.shardformer.shard import ShardConfig
from colossalai.shardformer.shard.utils import set_tensors_to_none
from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group
from ..layer import ColoAttention
# copied from modeling_deepseek.py
class AddAuxiliaryLoss(torch.autograd.Function):
@@ -529,34 +533,30 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
output_attentions = False
bsz, q_len, _ = hidden_states.size()
bsz, q_len, _ = hidden_states.size() # 1 4, 32
# sp: modify sp_len when sequence parallel mode is ring
if sp_mode in ["split_gather", "ring"]:
q_len *= sp_size
import torch.distributed as dist
rank = dist.get_rank()
print(f"{rank=}, hidden states:{hidden_states.shape}")
dist.get_rank()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
rank = dist.get_rank()
print(f"{rank=}, before all to all q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")
# sp: all-to-all communication when introducing sequence parallelism
if sp_mode == "all_to_all":
query_states = all_to_all_comm(query_states, sp_group)
key_states = all_to_all_comm(key_states, sp_group)
value_states = all_to_all_comm(value_states, sp_group)
bsz, q_len, _ = query_states.size()
print(f"{rank=}, after all to all q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
print(f"{rank=}, after view to (b,s,h,d) q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
@@ -565,7 +565,6 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids, unsqueeze_dim=0
)
print(f"{rank=}, after rope q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
@@ -573,13 +572,11 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
print(
f"{rank=}, after transpose to (b, nh, s, d) q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}"
)
dropout_rate = self.attention_dropout if self.training else 0.0
# query_states = query_states.transpose(1, 2)
# key_states = key_states.transpose(1, 2)
# value_states = value_states.transpose(1, 2)
# dropout_rate = self.attention_dropout if self.training else 0.0
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states get silently cast to float32. Hence, we need
@@ -606,22 +603,57 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
print(f"{rank=}, before flash attn q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")
attn_output = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
)
# attn_output = self._flash_attention_forward(
# query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
# )
if shard_config.enable_flash_attention:
assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
else:
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
# sp: all-to-all communication when introducing sequence parallelism
if sp_mode == "all_to_all":
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous() # (1, 8, 128)
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) # (1, 4, 256)
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
# print(f"{rank=}, shard attn output after all to all:{attn_output[0][0]}")
else:
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
# print(f"{rank=}, {attn_output[0][0]}")
if not output_attentions:
attn_weights = None
import torch.distributed as dist
dist.get_rank()
# print(f"{rank=}, {attn_output[0][0]}")
return attn_output, attn_weights, past_key_value
return forward
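The shape comments in this hunk ((1, 8, 128) and (1, 4, 256)) reflect the all_to_all sequence-parallel layout: before attention each rank trades its sequence shard for a head shard, so it sees the full sequence with 1/sp_size of the hidden/head dimension, and the inverse all-to-all after attention restores the local sub-sequence with the full hidden dimension. A single-process, shape-only sketch of that bookkeeping (the real all_to_all_comm also exchanges data across ranks; chunk/cat below only mimic the local shapes):

```python
# Shape-only illustration of the all_to_all sequence-parallel layout change.
# Splitting scatter_dim into sp_size chunks and concatenating along gather_dim
# reproduces the local shapes produced by all_to_all_comm, not its data movement.
import torch

def fake_all_to_all(x: torch.Tensor, sp_size: int, scatter_dim: int, gather_dim: int) -> torch.Tensor:
    return torch.cat(x.chunk(sp_size, dim=scatter_dim), dim=gather_dim)

sp_size = 2
qkv = torch.randn(1, 4, 256)                                  # local sub-sequence, full hidden dim
before_attn = fake_all_to_all(qkv, sp_size, scatter_dim=2, gather_dim=1)
print(before_attn.shape)                                      # torch.Size([1, 8, 128]): full seq, head shard

attn_out = torch.randn(1, 8, 128)                             # attention output in the same layout
after_attn = fake_all_to_all(attn_out, sp_size, scatter_dim=1, gather_dim=2)
print(after_attn.shape)                                       # torch.Size([1, 4, 256]): back to sub-sequence
```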
@@ -683,24 +715,38 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._use_sdpa and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
if shard_config.enable_flash_attention:
mask_shape = (
inputs_embeds.shape[0],
1,
past_key_values_length + inputs_embeds.shape[1],
past_key_values_length + inputs_embeds.shape[1],
)
attention_mask = ColoAttention.prepare_attn_kwargs(
mask_shape,
inputs_embeds.dtype,
inputs_embeds.device,
q_padding_mask=attention_mask,
is_causal=True,
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._use_sdpa and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
if sp_mode in ["ring", "split_gather"]:
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
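With shard_config.enable_flash_attention set, the hunk above feeds ColoAttention.prepare_attn_kwargs a (batch, 1, kv_len, kv_len) mask shape, the 2D padding mask, and is_causal=True instead of the HF 4D-mask helpers; the attention forward then asserts the result is a dict and unpacks it into ColoAttention.attention. Roughly, the mask geometry being described looks like the following (a standalone sketch, not the library helper):

```python
# Illustration of the (bsz, 1, kv_len, kv_len) causal-plus-padding mask geometry that
# prepare_attn_kwargs is asked to encode. This is a pure-torch sketch, not the real helper.
import torch

def causal_padding_mask(padding_mask: torch.Tensor) -> torch.Tensor:
    bsz, kv_len = padding_mask.shape
    causal = torch.tril(torch.ones(kv_len, kv_len, dtype=torch.bool))   # query i may attend keys <= i
    keep = padding_mask.bool()[:, None, None, :] & causal               # also drop padded key positions
    return keep                                                         # (bsz, 1, kv_len, kv_len)

mask = causal_padding_mask(torch.tensor([[1, 1, 1, 0]]))                # last position is padding
print(mask.shape)                                                       # torch.Size([1, 1, 4, 4])
```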
@@ -714,7 +760,7 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si
all_self_attns = () if output_attentions else None
next_decoder_cache = None
for decoder_layer in self.layers:
for i, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
@@ -746,8 +792,10 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si
if output_attentions:
all_self_attns += (layer_outputs[1],)
# import torch.distributed as dist
# rank = dist.get_rank()
# print(f"{rank=}, {hidden_states[0][0]}")
hidden_states = self.norm(hidden_states)
if sp_mode == "ring" or sp_mode == "split_gather":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
elif sp_mode == "all_to_all":

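For the ring/split_gather modes, split_forward_gather_backward shards inputs_embeds along the sequence dimension before the decoder stack and gather_forward_split_backward restores the full sequence after the final norm, each op presumably wiring up the inverse collective for the backward pass. A single-process, shape-only sketch of that flow:

```python
# Shape-only sketch of the split_gather sequence-parallel flow around the decoder stack:
# split the sequence dim going in, gather it back after self.norm. torch.chunk / torch.cat
# stand in for the real collectives and only reproduce the local shapes, not the data.
import torch

sp_size, rank = 2, 0
inputs_embeds = torch.randn(1, 8, 16)                      # (batch, seq, hidden)
local = inputs_embeds.chunk(sp_size, dim=1)[rank]          # split_forward_gather_backward(..., 1, sp_group)
print(local.shape)                                         # torch.Size([1, 4, 16])

hidden_states = local                                      # ... decoder layers run on the local shard ...
full = torch.cat([hidden_states] * sp_size, dim=1)         # gather_forward_split_backward(..., 1, sp_group)
print(full.shape)                                          # torch.Size([1, 8, 16])
```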
View File

@@ -1,4 +1,3 @@
import warnings
from functools import partial
from typing import Callable, Dict, List, Union
@@ -194,11 +193,11 @@ class DeepseekPolicy(Policy):
target_key="DeepseekModel",
)
if self.shard_config.enable_flash_attention:
warnings.warn(
"Flash attention has already been replaced in deepseek, and now set enable_flash_attention = False."
)
self.shard_config.enable_flash_attention = False
# if self.shard_config.enable_flash_attention:
# warnings.warn(
# "Flash attention has already been replaced in deepseek, and now set enable_flash_attention = False."
# )
# self.shard_config.enable_flash_attention = False
return policy

View File

@@ -59,7 +59,7 @@ def init_deepseek():
num_attention_heads=8,
num_key_value_heads=8,
# vocab_size=2200,
first_k_dense_replace=1,
first_k_dense_replace=2,
attn_implementation="flash_attention_2",
torch_dtype="float16",
n_routed_experts=8,
@@ -68,7 +68,6 @@ def init_deepseek():
if hasattr(config, "pad_token_id"):
config.pad_token_id = config.eos_token_id
print(config)
model = transformers.AutoModel.from_config(config, trust_remote_code=True)
return model

View File

@@ -30,7 +30,12 @@ os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
# TODO: SGD failed for full dp
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.SGD
# model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.SGD
model_fn,
loss_fn,
test_config,
pluggin_cls=MoeHybridParallelPlugin,
optim_class=torch.optim.SGD,
)
org_model = org_model.to(torch.float16)
@@ -39,16 +44,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
)
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
rank = dist.get_rank()
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config["precision"] == "fp32":
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
check_output_hidden_state(org_output, sharded_output, stage_manager, atol, rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwrap model
mixtral_model = unwrap_model(org_model, "DeepseekModel", "model")
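The comparison above uses tolerances of 1e-5/1e-3 for fp32 and 5e-3/5e-3 for fp16. A minimal sketch of the kind of check that check_loss performs (assumed behaviour; the helper lives in the shared test utilities and is not part of this diff):

```python
# Minimal tolerance check in the spirit of check_loss above (assumed behaviour, not the
# actual test utility): compare the single-model loss against the sharded-model loss.
import torch

def assert_loss_close(org_loss: torch.Tensor, sharded_loss: torch.Tensor,
                      atol: float = 5e-3, rtol: float = 5e-3) -> None:
    ok = torch.allclose(org_loss.float(), sharded_loss.float(), atol=atol, rtol=rtol)
    diff = (org_loss.float() - sharded_loss.float()).abs().max().item()
    assert ok, f"loss mismatch, max abs diff {diff:.3e}"

assert_loss_close(torch.tensor(2.3412), torch.tensor(2.3407))   # passes with the fp16 tolerances
```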
@@ -178,12 +182,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"sp_size": 2,
"ep_size": 2,
"enable_sequence_parallelism": True,
"enable_flash_attention": True,
"sequence_parallelism_mode": "all_to_all",
"zero_stage": 1,
"overlap_communication": False,
"precision": "fp16",
"initial_scale": 1,
"find_unused_parameters": True,
# "find_unused_parameters": True,
},
# {
# "tp_size": 1,
@@ -224,7 +229,7 @@ def check_deepseek(rank, world_size, port):
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_mixtral():
spawn(check_deepseek, 4)
spawn(check_deepseek, 2)
if __name__ == "__main__":