mirror of https://github.com/hpcaitech/ColossalAI
[bugfix] colo attn bug fix
parent
e521890d32
commit
2d73efdfdd
|
@ -73,8 +73,8 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
|
|||
moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
|
||||
}
|
||||
|
||||
if len(pg_param_list[dp_process_group]) == 0 or len(pg_param_list[moe_dp_group]) == 0:
|
||||
raise ValueError("No parameters found in dp_process_group or moe_dp_group")
|
||||
# if len(pg_param_list[dp_process_group]) == 0 or len(pg_param_list[moe_dp_group]) == 0:
|
||||
# raise ValueError("No parameters found in dp_process_group or moe_dp_group")
|
||||
|
||||
super().__init__(
|
||||
model=model,
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
import math
|
||||
import warnings
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.distributed import ProcessGroup
|
||||
|
@ -34,6 +36,8 @@ from colossalai.shardformer.shard import ShardConfig
|
|||
from colossalai.shardformer.shard.utils import set_tensors_to_none
|
||||
from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group
|
||||
|
||||
from ..layer import ColoAttention
|
||||
|
||||
|
||||
# copied from modeling_deepseek.py
|
||||
class AddAuxiliaryLoss(torch.autograd.Function):
|
||||
|
@ -529,34 +533,30 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
|
|||
|
||||
output_attentions = False
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
bsz, q_len, _ = hidden_states.size() # 1 4, 32
|
||||
|
||||
# sp: modify sp_len when sequence parallel mode is ring
|
||||
if sp_mode in ["split_gather", "ring"]:
|
||||
q_len *= sp_size
|
||||
import torch.distributed as dist
|
||||
|
||||
rank = dist.get_rank()
|
||||
print(f"{rank=}, hidden states:{hidden_states.shape}")
|
||||
dist.get_rank()
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
rank = dist.get_rank()
|
||||
print(f"{rank=}, before all to all q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")
|
||||
# sp: all-to-all comminucation when introducing sequence parallel
|
||||
if sp_mode == "all_to_all":
|
||||
query_states = all_to_all_comm(query_states, sp_group)
|
||||
key_states = all_to_all_comm(key_states, sp_group)
|
||||
value_states = all_to_all_comm(value_states, sp_group)
|
||||
bsz, q_len, _ = query_states.size()
|
||||
print(f"{rank=}, after all to all q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")
|
||||
# Flash attention requires the input to have the shape
|
||||
# batch_size x seq_length x head_dim x hidden_dim
|
||||
# therefore we just need to keep the original shape
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
print(f"{rank=}, after view to (b,s,h,d) q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
|
@ -565,7 +565,6 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
|
|||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids, unsqueeze_dim=0
|
||||
)
|
||||
print(f"{rank=}, after rope q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
||||
|
@ -573,13 +572,11 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
|
|||
|
||||
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
||||
# to be able to avoid many of these transpose/reshape/view.
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
print(
|
||||
f"{rank=}, after transpose to (b, nh, s, d) q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}"
|
||||
)
|
||||
dropout_rate = self.attention_dropout if self.training else 0.0
|
||||
|
||||
# query_states = query_states.transpose(1, 2)
|
||||
# key_states = key_states.transpose(1, 2)
|
||||
# value_states = value_states.transpose(1, 2)
|
||||
# dropout_rate = self.attention_dropout if self.training else 0.0
|
||||
|
||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
|
@ -606,22 +603,57 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
|
|||
query_states = query_states.to(target_dtype)
|
||||
key_states = key_states.to(target_dtype)
|
||||
value_states = value_states.to(target_dtype)
|
||||
print(f"{rank=}, before flash attn q:{query_states.shape}, k:{key_states.shape}, v:{value_states.shape}")
|
||||
attn_output = self._flash_attention_forward(
|
||||
query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
|
||||
)
|
||||
# attn_output = self._flash_attention_forward(
|
||||
# query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
|
||||
# )
|
||||
|
||||
if shard_config.enable_flash_attention:
|
||||
assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
|
||||
attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
|
||||
else:
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
# sp: all-to-all comminucation when introducing sequence parallel
|
||||
if sp_mode == "all_to_all":
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous() # (1, 8, 128)
|
||||
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) # (1, 4, 256)
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()
|
||||
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
|
||||
# print(f"{rank=}, shard attn output after all to all:{attn_output[0][0]}")
|
||||
else:
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
# print(f"{rank=}, {attn_output[0][0]}")
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
import torch.distributed as dist
|
||||
|
||||
dist.get_rank()
|
||||
# print(f"{rank=}, {attn_output[0][0]}")
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
return forward
|
||||
|
@ -683,24 +715,38 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si
|
|||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if self._use_flash_attention_2:
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
elif self._use_sdpa and not output_attentions:
|
||||
# output_attentions=True can not be supported when using SDPA, and we fall back on
|
||||
# the manual implementation that requires a 4D causal mask in all cases.
|
||||
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
if shard_config.enable_flash_attention:
|
||||
mask_shape = (
|
||||
inputs_embeds.shape[0],
|
||||
1,
|
||||
past_key_values_length + inputs_embeds.shape[1],
|
||||
past_key_values_length + inputs_embeds.shape[1],
|
||||
)
|
||||
attention_mask = ColoAttention.prepare_attn_kwargs(
|
||||
mask_shape,
|
||||
inputs_embeds.dtype,
|
||||
inputs_embeds.device,
|
||||
q_padding_mask=attention_mask,
|
||||
is_causal=True,
|
||||
)
|
||||
else:
|
||||
# 4d mask is passed through the layers
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
||||
)
|
||||
if self._use_flash_attention_2:
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
elif self._use_sdpa and not output_attentions:
|
||||
# output_attentions=True can not be supported when using SDPA, and we fall back on
|
||||
# the manual implementation that requires a 4D causal mask in all cases.
|
||||
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
)
|
||||
else:
|
||||
# 4d mask is passed through the layers
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
||||
)
|
||||
|
||||
if sp_mode in ["ring", "split_gather"]:
|
||||
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
|
||||
|
@ -714,7 +760,7 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si
|
|||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = None
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
for i, decoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
|
@ -746,8 +792,10 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si
|
|||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
# import torch.distributed as dist
|
||||
# rank = dist.get_rank()
|
||||
# print(f"{rank=}, {hidden_states[0][0]}")
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
if sp_mode == "ring" or sp_mode == "split_gather":
|
||||
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
|
||||
elif sp_mode == "all_to_all":
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
import warnings
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, List, Union
|
||||
|
||||
|
@ -194,11 +193,11 @@ class DeepseekPolicy(Policy):
|
|||
target_key="DeepseekModel",
|
||||
)
|
||||
|
||||
if self.shard_config.enable_flash_attention:
|
||||
warnings.warn(
|
||||
"Flash attention has already been replaced in deepseek, and now set enable_flash_attention = False."
|
||||
)
|
||||
self.shard_config.enable_flash_attention = False
|
||||
# if self.shard_config.enable_flash_attention:
|
||||
# warnings.warn(
|
||||
# "Flash attention has already been replaced in deepseek, and now set enable_flash_attention = False."
|
||||
# )
|
||||
# self.shard_config.enable_flash_attention = False
|
||||
|
||||
return policy
|
||||
|
||||
|
|
|
@ -59,7 +59,7 @@ def init_deepseek():
|
|||
num_attention_heads=8,
|
||||
num_key_value_heads=8,
|
||||
# vocab_size=2200,
|
||||
first_k_dense_replace=1,
|
||||
first_k_dense_replace=2,
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype="float16",
|
||||
n_routed_experts=8,
|
||||
|
@ -68,7 +68,6 @@ def init_deepseek():
|
|||
|
||||
if hasattr(config, "pad_token_id"):
|
||||
config.pad_token_id = config.eos_token_id
|
||||
print(config)
|
||||
model = transformers.AutoModel.from_config(config, trust_remote_code=True)
|
||||
|
||||
return model
|
||||
|
|
|
@ -30,7 +30,12 @@ os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
|
|||
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
|
||||
# TODO: SGD failed for full dp
|
||||
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
|
||||
model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.SGD
|
||||
# model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.SGD
|
||||
model_fn,
|
||||
loss_fn,
|
||||
test_config,
|
||||
pluggin_cls=MoeHybridParallelPlugin,
|
||||
optim_class=torch.optim.SGD,
|
||||
)
|
||||
|
||||
org_model = org_model.to(torch.float16)
|
||||
|
@ -39,16 +44,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
)
|
||||
stage_manager = booster.plugin.stage_manager
|
||||
tp_group = booster.plugin.tp_group
|
||||
|
||||
rank = dist.get_rank()
|
||||
# check last hidden state & loss
|
||||
if stage_manager is None or stage_manager.is_last_stage():
|
||||
if test_config["precision"] == "fp32":
|
||||
atol, rtol = 1e-5, 1e-3
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
|
||||
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
|
||||
check_output_hidden_state(org_output, sharded_output, stage_manager, atol, rtol)
|
||||
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
|
||||
|
||||
# unwrap model
|
||||
mixtral_model = unwrap_model(org_model, "DeepseekModel", "model")
|
||||
|
@ -178,12 +182,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
"sp_size": 2,
|
||||
"ep_size": 2,
|
||||
"enable_sequence_parallelism": True,
|
||||
"enable_flash_attention": True,
|
||||
"sequence_parallelism_mode": "all_to_all",
|
||||
"zero_stage": 1,
|
||||
"overlap_communication": False,
|
||||
"precision": "fp16",
|
||||
"initial_scale": 1,
|
||||
"find_unused_parameters": True,
|
||||
# "find_unused_parameters": True,
|
||||
},
|
||||
# {
|
||||
# "tp_size": 1,
|
||||
|
@ -224,7 +229,7 @@ def check_deepseek(rank, world_size, port):
|
|||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_mixtral():
|
||||
spawn(check_deepseek, 4)
|
||||
spawn(check_deepseek, 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue