mirror of https://github.com/InternLM/InternLM

modify the all2all

parent bf475b6940
commit bd4af3a31f
@@ -5,7 +5,7 @@ SEQ_LEN = 2048
 HIDDEN_SIZE = 4096
 NUM_ATTENTION_HEAD = 32
 MLP_RATIO = 8 / 3
-NUM_LAYER = 4
+NUM_LAYER = 32
 VOCAB_SIZE = 103168

 MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
@@ -372,7 +372,6 @@ class PackedFlashInternLm1D(nn.Module):

     def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
         # attention_mask: compute attention on the places where the value is 1
-        import pdb; pdb.set_trace()
         if hasattr(self, "embedding"):
             hidden_states = self.embedding(input_ids)
             if self.embed_grad_scale != 1:
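The truncated `if self.embed_grad_scale != 1:` branch above is where InternLM-style models typically damp the gradient flowing into the token embedding without changing the forward activations. The sketch below is a hedged reconstruction of that pattern, not code taken from this commit; the helper name is made up for illustration.

import torch

def scale_embedding_grad(hidden_states: torch.Tensor, embed_grad_scale: float) -> torch.Tensor:
    # Hypothetical helper (not from this diff): the forward value equals
    # hidden_states, but the gradient reaching the embedding weights is
    # multiplied by embed_grad_scale, since the detached term carries no gradient.
    return embed_grad_scale * hidden_states + (1 - embed_grad_scale) * hidden_states.detach()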
@@ -115,14 +115,14 @@ class DistributedAttention(torch.nn.Module):
         """
         # TODO Merge three alltoall calls into one
         # in shape : e.g., [s/p:h:]
-        qkv = _SeqAllToAll.apply(self.spg, qkv, self.scatter_idx, self.gather_idx)
+        qkv = _SeqAllToAll.apply(self.spg, qkv, 2, 0)
         # key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx)
         # value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx)

         # out shape : e.g., [s:h/p:]
         context_layer = self.local_attn(qkv, **kwargs)

-        output = _SeqAllToAll.apply(self.spg, context_layer, 0, 2)
+        output = _SeqAllToAll.apply(self.spg, context_layer, 0, 1)

         # out e.g., [s/p::h]
         return output
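For readers following the scatter/gather index change above: `_SeqAllToAll.apply(group, tensor, scatter_idx, gather_idx)` reshards a tensor across the sequence-parallel group `self.spg` by splitting it along `scatter_idx`, exchanging one chunk with every rank, and concatenating the received chunks along `gather_idx` (so a sequence-sharded `[s/p, ..., h]` layout becomes a head-sharded `[s, ..., h/p]` layout and back). The sketch below is a minimal reconstruction of that forward path under those assumptions; it omits the autograd wrapper and InternLM's exact tensor layout, so treat it as illustrative rather than the repository's implementation.

import torch
import torch.distributed as dist

def seq_all_to_all(group, x: torch.Tensor, scatter_idx: int, gather_idx: int) -> torch.Tensor:
    # Split x into one chunk per rank along scatter_idx, exchange chunks with
    # every rank in the group, then stitch the received chunks together along
    # gather_idx. All ranks are assumed to hold equally sized shards.
    world_size = dist.get_world_size(group)
    input_chunks = [c.contiguous() for c in x.chunk(world_size, dim=scatter_idx)]
    output_chunks = [torch.empty_like(c) for c in input_chunks]
    dist.all_to_all(output_chunks, input_chunks, group=group)
    # Received chunks arrive ordered by source rank; concatenating along
    # gather_idx shrinks the scatter dimension by p and grows the gather dimension by p.
    return torch.cat(output_chunks, dim=gather_idx)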