mirror of https://github.com/hpcaitech/ColossalAI
[moe] update capacity computing (#5253)
* [moe] top2 allow uneven input
* [moe] update capacity computing
* [moe] remove debug info
* [moe] update capacity computing
* [moe] update capacity computing

pull/5372/head
parent 7d8e0338a4
commit c904d2ae99
@@ -126,12 +126,15 @@ def main():
     load_model(args.model_name, model, booster)
     coordinator.print_on_master(f"Finish load ckpt")

-    text = ["Hello my name is", "1+1=?"]
+    if coordinator.rank == 0:
+        text = ["Hello my name is"]
+    else:
+        text = ["What's the largest country in the world?", "How many people live in China?", "帮我续写这首诗:离离原上草"]
     tokenizer.pad_token = tokenizer.unk_token
     inputs = tokenizer(text, return_tensors="pt", padding=True).to(torch.cuda.current_device())
     outputs = model.module.generate(**inputs, max_new_tokens=20)
-    outputs = tokenizer.batch_decode(outputs)[0]
-    print(outputs)
+    outputs = tokenizer.batch_decode(outputs)
+    print(f"[{coordinator.rank}] {outputs}")


 if __name__ == "__main__":
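This splits the demo prompts by rank, so the expert-parallel ranks now push batches of different sizes through the MoE layers (the "top2 allow uneven input" part of this change), and the decode/print lines are updated to show every generated sequence tagged with its rank instead of only the first one. A minimal sketch of why the new prompt lists produce uneven input; the checkpoint name is illustrative and exact shapes depend on the tokenizer:

    from transformers import AutoTokenizer

    # Illustrative checkpoint; any tokenizer with an unk_token behaves the same.
    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
    tokenizer.pad_token = tokenizer.unk_token

    # What rank 0 vs. the other ranks would tokenize after this change:
    batch_rank0 = tokenizer(["Hello my name is"], return_tensors="pt", padding=True)
    batch_rank1 = tokenizer(
        ["What's the largest country in the world?", "How many people live in China?"],
        return_tensors="pt",
        padding=True,
    )
    # Different batch sizes and padded lengths mean different local token counts,
    # so each rank's router would compute a different capacity on its own.
    print(batch_rank0.input_ids.shape, batch_rank1.input_ids.shape)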
@@ -45,9 +45,13 @@ class MoeRouter(nn.Module, ABC):
         self._z_loss = None
         self.use_kernel = use_kernel

-    def get_capacity(self, logits_shape):
+    def get_capacity(self, num_tokens, num_experts, ep_group=None):
+        if ep_group is not None:
+            num_tokens_tensor = torch.tensor(num_tokens, device=get_current_device())
+            dist.all_reduce(num_tokens_tensor, group=ep_group)
+            num_tokens = num_tokens_tensor.item() // dist.get_world_size(ep_group)
         capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
-        capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1])
+        capacity = math.floor(self.k_value * capacity_factor * num_tokens / num_experts)
         capacity += capacity % 2
         capacity = max(capacity, self.min_capacity)
         assert capacity > 0
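get_capacity now takes the token and expert counts directly instead of reading them off the logits shape, which implicitly assumed every rank routed the same number of tokens; when ep_group is given, the local count is first replaced by the group mean. A sketch of just the arithmetic (the function name and the min_capacity default of 4 are mine for illustration):

    import math

    def capacity_sketch(k_value, capacity_factor, num_tokens, num_experts, min_capacity=4):
        # Same arithmetic as the new get_capacity, minus the all_reduce branch:
        # expected tokens per expert, scaled by top-k and the capacity factor.
        capacity = math.floor(k_value * capacity_factor * num_tokens / num_experts)
        capacity += capacity % 2               # bump odd values to the next even number
        capacity = max(capacity, min_capacity)
        assert capacity > 0
        return capacity

    # Worked example: top-2 routing, 24 tokens, 8 experts, factor 1.25:
    # floor(2 * 1.25 * 24 / 8) = 7, bumped to even -> 8
    print(capacity_sketch(k_value=2, capacity_factor=1.25, num_tokens=24, num_experts=8))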
@ -175,7 +179,8 @@ class Top1Router(MoeRouter):
|
||||||
assert inputs.dtype == torch.float
|
assert inputs.dtype == torch.float
|
||||||
probs = F.softmax(inputs, dim=-1)
|
probs = F.softmax(inputs, dim=-1)
|
||||||
num_experts = probs.size(-1)
|
num_experts = probs.size(-1)
|
||||||
capacity = self.get_capacity(inputs.shape)
|
num_tokens = inputs.size(0)
|
||||||
|
capacity = self.get_capacity(num_tokens, num_experts, ep_group)
|
||||||
|
|
||||||
top1_idx = torch.argmax(inputs, dim=-1)
|
top1_idx = torch.argmax(inputs, dim=-1)
|
||||||
mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
|
mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
|
||||||
|
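Top1Router now derives capacity from its local token count plus the EP group rather than from inputs.shape. The shown lines stop at building mask; in the usual Switch-style top-1 recipe (an assumption about what follows, not code from this diff), capacity is consumed via a per-expert cumulative count that drops tokens past the limit:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    inputs = torch.randn(10, 4)                  # (num_tokens, num_experts), illustrative
    top1_idx = torch.argmax(inputs, dim=-1)
    mask = F.one_hot(top1_idx, num_classes=4).to(torch.int32)

    capacity = 3                                 # what get_capacity would return
    ranks = torch.cumsum(mask, dim=0) * mask     # 1-based slot of each token within its expert
    kept = mask * (ranks <= capacity)            # tokens past an expert's capacity are dropped
    print(mask.sum().item(), "routed,", kept.sum().item(), "kept")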
@@ -276,7 +281,8 @@ class Top2Router(MoeRouter):
         probs = probs / routing_weights.sum(dim=-1, keepdim=True)

         num_experts = probs.size(-1)
-        capacity = self.get_capacity(inputs.shape)
+        num_tokens = inputs.size(0)
+        capacity = self.get_capacity(num_tokens, num_experts, ep_group)

         top1_idx = torch.argmax(probs, dim=-1)
         mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
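Top2Router gets the same treatment. The point of threading ep_group into get_capacity is that, with uneven inputs, the ranks would otherwise disagree on capacity and their dispatch buffers would no longer match in shape for the all-to-all. A sketch of the synchronization step on its own (to be run inside an initialized process group; names are illustrative):

    import torch
    import torch.distributed as dist

    def synced_num_tokens(local_num_tokens: int, ep_group) -> int:
        # Mirrors the new branch in get_capacity: sum the local counts over the
        # expert-parallel group, then take the integer mean so every rank agrees.
        t = torch.tensor(local_num_tokens, device=torch.cuda.current_device())
        dist.all_reduce(t, group=ep_group)       # all_reduce defaults to SUM
        return t.item() // dist.get_world_size(ep_group)

    # e.g. two EP ranks holding 6 and 33 local tokens both obtain (6 + 33) // 2 = 19,
    # and therefore size their dispatch buffers identically.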