mirror of https://github.com/InternLM/InternLM
modify the all2all
parent bf475b6940
commit bd4af3a31f
@@ -5,7 +5,7 @@ SEQ_LEN = 2048
 HIDDEN_SIZE = 4096
 NUM_ATTENTION_HEAD = 32
 MLP_RATIO = 8 / 3
-NUM_LAYER = 4
+NUM_LAYER = 32
 VOCAB_SIZE = 103168
 
 MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
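The config hunk changes NUM_LAYER from 4 to 32, which together with HIDDEN_SIZE = 4096, NUM_ATTENTION_HEAD = 32 and VOCAB_SIZE = 103168 corresponds roughly to the 7B configuration. A back-of-the-envelope parameter count under two assumptions that are not stated in this diff (a LLaMA-style gated MLP whose hidden width is MLP_RATIO * HIDDEN_SIZE rounded to a multiple of 256, and untied input/output embeddings):

# Rough parameter count implied by this config (a sketch; the gated-MLP layout
# and the rounding of the MLP width to a multiple of 256 are assumptions).
HIDDEN_SIZE = 4096
NUM_LAYER = 32
VOCAB_SIZE = 103168
MLP_RATIO = 8 / 3

mlp_hidden = round(HIDDEN_SIZE * MLP_RATIO / 256) * 256   # 11008 under this rounding assumption
attn_params = 4 * HIDDEN_SIZE ** 2                        # Wq, Wk, Wv, Wo
mlp_params = 3 * HIDDEN_SIZE * mlp_hidden                 # gate, up and down projections
embed_params = 2 * VOCAB_SIZE * HIDDEN_SIZE               # token embedding + output head (untied)

total = NUM_LAYER * (attn_params + mlp_params) + embed_params
print(f"{total / 1e9:.2f}B")                              # ~7.3B; with the old NUM_LAYER = 4 it is ~1.7B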
@@ -372,7 +372,6 @@ class PackedFlashInternLm1D(nn.Module):
 
     def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
         # attention_mask: compute attention on the places where the value is 1
-        import pdb; pdb.set_trace()
         if hasattr(self, "embedding"):
             hidden_states = self.embedding(input_ids)
             if self.embed_grad_scale != 1:
@@ -115,14 +115,14 @@ class DistributedAttention(torch.nn.Module):
         """
         # TODO Merge three alltoall calls into one
         #in shape : e.g., [s/p:h:]
-        qkv = _SeqAllToAll.apply(self.spg, qkv, self.scatter_idx, self.gather_idx)
+        qkv = _SeqAllToAll.apply(self.spg, qkv, 2, 0)
         # key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx)
         # value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx)
 
         #out shape : e.g., [s:h/p:]
         context_layer = self.local_attn(qkv, **kwargs)
 
-        output = _SeqAllToAll.apply(self.spg, context_layer, 0, 2)
+        output = _SeqAllToAll.apply(self.spg, context_layer, 0, 1)
 
         #out e.g., [s/p::h]
         return output
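On the DistributedAttention hunk: the layer does Ulysses-style sequence parallelism, and the new hard-coded indices appear to match the actual tensor layouts. Packed qkv of shape [s/p, 3, heads, head_dim] is all-to-all'd with scatter dim 2 (heads) and gather dim 0 (sequence), so each rank attends over the full sequence with heads/p heads; the context layer, being 3-D [s, heads/p, head_dim], is sent back with scatter dim 0 and gather dim 1, presumably because its head dimension is 1, not 2 as the old call assumed. A minimal forward-only sketch of such a dimension-swapping all-to-all, assuming the tensor divides evenly across the group and torch.distributed is initialized (the helper name seq_all_to_all is illustrative, not InternLM's _SeqAllToAll):

# Forward-only sketch of a dimension-swapping all-to-all for sequence parallelism.
# Assumes an initialized process group and even divisibility along scatter_dim;
# not the actual InternLM implementation.
import torch
import torch.distributed as dist

def seq_all_to_all(x: torch.Tensor, scatter_dim: int, gather_dim: int, group=None) -> torch.Tensor:
    world_size = dist.get_world_size(group)
    # one equally sized chunk of the scatter dimension per rank
    inputs = [t.contiguous() for t in torch.tensor_split(x, world_size, dim=scatter_dim)]
    outputs = [torch.empty_like(inputs[0]) for _ in range(world_size)]
    dist.all_to_all(outputs, inputs, group=group)
    # chunks arrive in rank order; stitch them back together along the gather dimension
    return torch.cat(outputs, dim=gather_dim).contiguous()

# Mirroring the diff (hypothetical shapes):
#   qkv:     [s/p, 3, heads, head_dim]  --scatter 2, gather 0-->  [s, 3, heads/p, head_dim]
#   context: [s, heads/p, head_dim]     --scatter 0, gather 1-->  [s/p, heads, head_dim]

A full _SeqAllToAll autograd.Function would typically wrap this and run the same collective with the scatter and gather dimensions swapped in its backward pass, so only the forward call sites carry the indices changed in this commit.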