modify the all2all

pull/407/head
yingtongxiong 2023-10-08 17:21:17 +08:00
parent bf475b6940
commit bd4af3a31f
3 changed files with 3 additions and 4 deletions

@@ -5,7 +5,7 @@ SEQ_LEN = 2048
 HIDDEN_SIZE = 4096
 NUM_ATTENTION_HEAD = 32
 MLP_RATIO = 8 / 3
-NUM_LAYER = 4
+NUM_LAYER = 32
 VOCAB_SIZE = 103168
 MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"

@@ -372,7 +372,6 @@ class PackedFlashInternLm1D(nn.Module):
     def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
         # attention_mask: compute attention on the places where the value is 1
-        import pdb; pdb.set_trace()
         if hasattr(self, "embedding"):
             hidden_states = self.embedding(input_ids)
             if self.embed_grad_scale != 1:

@@ -115,14 +115,14 @@ class DistributedAttention(torch.nn.Module):
         """
         # TODO Merge three alltoall calls into one
         #in shape : e.g., [s/p:h:]
-        qkv = _SeqAllToAll.apply(self.spg, qkv, self.scatter_idx, self.gather_idx)
+        qkv = _SeqAllToAll.apply(self.spg, qkv, 2, 0)
         # key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx)
         # value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx)
         #out shape : e.g., [s:h/p:]
         context_layer = self.local_attn(qkv, **kwargs)
-        output = _SeqAllToAll.apply(self.spg, context_layer, 0, 2)
+        output = _SeqAllToAll.apply(self.spg, context_layer, 0, 1)
         #out e.g., [s/p::h]
         return output
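
For context on the hard-coded indices: the third and fourth arguments of _SeqAllToAll.apply choose which dimension of the packed tensor is scattered across the sequence-parallel group and which dimension the received chunks are gathered along. The snippet below is a minimal sketch of that exchange, not the repository's actual autograd implementation; the helper name seq_all_to_all and the evenly-divisible shapes are assumptions for illustration.

import torch
import torch.distributed as dist

def seq_all_to_all(group, tensor, scatter_idx, gather_idx):
    # Split `tensor` into one chunk per rank along `scatter_idx`, send chunk i
    # to rank i, and concatenate the received chunks along `gather_idx`.
    # The result is sharded on `scatter_idx` and made whole on `gather_idx`
    # (assumes both dimensions divide evenly by the group size).
    world_size = dist.get_world_size(group)
    send_chunks = [c.contiguous() for c in torch.chunk(tensor, world_size, dim=scatter_idx)]
    recv_chunks = [torch.empty_like(c) for c in send_chunks]
    dist.all_to_all(recv_chunks, send_chunks, group=group)
    return torch.cat(recv_chunks, dim=gather_idx)

Under this reading, a packed qkv of shape roughly [s/p, 3, heads, head_dim] enters the first call with (scatter_idx=2, gather_idx=0) and comes out as [s, 3, heads/p, head_dim] for the local attention; the second call with (0, 1) scatters the sequence back and gathers the heads of the [s, heads/p, head_dim] context layer, whereas the old gather_idx=2 would have concatenated along the per-head dimension. These shapes are assumptions about the packed layout, not taken from the diff itself.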