From c904d2ae997b161a5c6ddbf2057a7e194472c525 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Thu, 11 Jan 2024 16:09:38 +0800 Subject: [PATCH] [moe] update capacity computing (#5253) * [moe] top2 allow uneven input * [moe] update capacity computing * [moe] remove debug info * [moe] update capacity computing * [moe] update capacity computing --- applications/ColossalMoE/infer.py | 9 ++++++--- colossalai/moe/routers.py | 14 ++++++++++---- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/applications/ColossalMoE/infer.py b/applications/ColossalMoE/infer.py index 70ddff940..d234fb628 100644 --- a/applications/ColossalMoE/infer.py +++ b/applications/ColossalMoE/infer.py @@ -126,12 +126,15 @@ def main(): load_model(args.model_name, model, booster) coordinator.print_on_master(f"Finish load ckpt") - text = ["Hello my name is", "1+1=?"] + if coordinator.rank == 0: + text = ["Hello my name is"] + else: + text = ["What's the largest country in the world?", "How many people live in China?", "帮我续写这首诗:离离原上草"] tokenizer.pad_token = tokenizer.unk_token inputs = tokenizer(text, return_tensors="pt", padding=True).to(torch.cuda.current_device()) outputs = model.module.generate(**inputs, max_new_tokens=20) - outputs = tokenizer.batch_decode(outputs)[0] - print(outputs) + outputs = tokenizer.batch_decode(outputs) + print(f"[{coordinator.rank}] {outputs}") if __name__ == "__main__": diff --git a/colossalai/moe/routers.py b/colossalai/moe/routers.py index 5c7d06656..4d99e48d3 100644 --- a/colossalai/moe/routers.py +++ b/colossalai/moe/routers.py @@ -45,9 +45,13 @@ class MoeRouter(nn.Module, ABC): self._z_loss = None self.use_kernel = use_kernel - def get_capacity(self, logits_shape): + def get_capacity(self, num_tokens, num_experts, ep_group=None): + if ep_group is not None: + num_tokens_tensor = torch.tensor(num_tokens, device=get_current_device()) + dist.all_reduce(num_tokens_tensor, group=ep_group) + num_tokens = num_tokens_tensor.item() // dist.get_world_size(ep_group) capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval - capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1]) + capacity = math.floor(self.k_value * capacity_factor * num_tokens / num_experts) capacity += capacity % 2 capacity = max(capacity, self.min_capacity) assert capacity > 0 @@ -175,7 +179,8 @@ class Top1Router(MoeRouter): assert inputs.dtype == torch.float probs = F.softmax(inputs, dim=-1) num_experts = probs.size(-1) - capacity = self.get_capacity(inputs.shape) + num_tokens = inputs.size(0) + capacity = self.get_capacity(num_tokens, num_experts, ep_group) top1_idx = torch.argmax(inputs, dim=-1) mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) @@ -276,7 +281,8 @@ class Top2Router(MoeRouter): probs = probs / routing_weights.sum(dim=-1, keepdim=True) num_experts = probs.size(-1) - capacity = self.get_capacity(inputs.shape) + num_tokens = inputs.size(0) + capacity = self.get_capacity(num_tokens, num_experts, ep_group) top1_idx = torch.argmax(probs, dim=-1) mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)