From 5b6cf7cab024d78059d361a0d2737823ae5e5d42 Mon Sep 17 00:00:00 2001 From: Wenwen Qu Date: Tue, 8 Aug 2023 15:07:04 +0800 Subject: [PATCH] reformat code --- .pre-commit-config.yaml | 4 +- internlm/moe/experts.py | 19 +-- internlm/moe/sharded_moe.py | 252 +++++++++++++++++------------------- 3 files changed, 134 insertions(+), 141 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 19cd7c8..182486e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -49,5 +49,5 @@ repos: args: [ '--rcfile=.pylintrc', - '--disable=C0114,C0415,W0212,W0235,W0238,W0621,C0103,R1735,C2801,E0402,C0412,W0719,R1728,W1514,W0718,W0105,W0707,C0209,W0703,W1203' - ] \ No newline at end of file + '--disable=C0330, C0114,C0415,W0212,W0235,W0238,W0621,C0103,R1735,C2801,E0402,C0412,W0719,R1728,W1514,W0718,W0105,W0707,C0209,W0703,W1203' + ] diff --git a/internlm/moe/experts.py b/internlm/moe/experts.py index 3b7af2c..bf34666 100644 --- a/internlm/moe/experts.py +++ b/internlm/moe/experts.py @@ -10,21 +10,24 @@ https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/experts.py # DeepSpeed Team +from typing import Union, cast + import torch -import copy from torch.nn import Module, ModuleList -from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast + class Experts(torch.nn.Module): + """ + Local Experts. + """ def __init__(self, experts: Union[Module, ModuleList], num_local_experts=1): - super(Experts, self).__init__() + super().__init__() - # TODO: We can not deepcopy FeedForward since it contains a process_group in submodules + # TODO: We can not deepcopy FeedForward since it contains a process_group in submodules # self.experts = torch.nn.ModuleList([copy.deepcopy(expert) for i in range(num_local_experts)]) - - if type(experts) == ModuleList: + if isinstance(experts, ModuleList): self.experts = cast(ModuleList, experts) else: self.experts = ModuleList([experts]) @@ -33,7 +36,7 @@ class Experts(torch.nn.Module): # TODO: revisit allreduce for moe.gate... for expert in self.experts: # TODO: Create param groups to handle expert + data case (e.g. 
param.group = moe_group) - for name, param in expert.named_parameters(): + for _, param in expert.named_parameters(): param.all_reduce = False def forward(self, inputs): @@ -41,7 +44,7 @@ class Experts(torch.nn.Module): expert_outputs = [] for chunk, expert in zip(chunks, self.experts): out = expert(chunk) - if type(out) is tuple: + if isinstance(out, tuple): out = out[0] # Ignore the bias term for now expert_outputs += [out] diff --git a/internlm/moe/sharded_moe.py b/internlm/moe/sharded_moe.py index 01daecc..7af036c 100644 --- a/internlm/moe/sharded_moe.py +++ b/internlm/moe/sharded_moe.py @@ -1,14 +1,3 @@ -import torch.distributed as dist - -from internlm.utils.logger import get_logger -from internlm.utils.megatron_timers import megatron_timer as timer -from internlm.core.context import global_context as gpc -from internlm.core.context import ParallelMode - - -# global llm logger -logger = get_logger(__file__) - """ The file has been adapted from the following files: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/experts.py @@ -22,13 +11,19 @@ https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/experts.py # DeepSpeed Team - -from typing import Callable, Dict, TYPE_CHECKING, Any, Optional, Tuple +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple import torch +import torch.distributed as dist +import torch.nn.functional as F from torch import Tensor from torch.nn import Module -import torch.nn.functional as F + +from internlm.utils.logger import get_logger +from internlm.utils.megatron_timers import megatron_timer as timer + +# global llm logger +logger = get_logger(__file__) if TYPE_CHECKING: Base = Module[Tensor] @@ -57,9 +52,9 @@ def multiplicative_jitter(x, device: torch.device, epsilon=1e-2): return x uniform = uniform_map.get(device) if uniform is None: - uniform = torch.distributions.uniform.Uniform(low=torch.tensor(1.0 - epsilon, device=device), - high=torch.tensor(1.0 + epsilon, - device=device)).rsample # type: ignore + uniform = torch.distributions.uniform.Uniform( + low=torch.tensor(1.0 - epsilon, device=device), high=torch.tensor(1.0 + epsilon, device=device) + ).rsample # type: ignore uniform_map[device] = uniform return x * uniform(x.shape) @@ -73,23 +68,28 @@ def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor: gumbel_map[device] = gumbel return gumbel(shape) + # einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity # See https://arxiv.org/pdf/2006.16668.pdf for details. 
# Based on https://github.com/pytorch/pytorch/pull/40762 class _AllToAll(torch.autograd.Function): + """ + All to all communication + """ @staticmethod def forward( - ctx: Any, - # TODO: replace with DS process group - group: torch.distributed.ProcessGroup, - input: Tensor) -> Tensor: # type: ignore + ctx: Any, + # TODO: replace with DS process group + group: torch.distributed.ProcessGroup, + inputs: Tensor, + ) -> Tensor: # type: ignore ctx.group = group - input = input.contiguous() - output = torch.empty_like(input) - dist.all_to_all_single(output, input, group=group) + inputs = inputs.contiguous() + output = torch.empty_like(inputs) + dist.all_to_all_single(output, inputs, group=group) return output @staticmethod @@ -107,26 +107,26 @@ USE_EINSUM = True def einsum(rule, a, b): if USE_EINSUM: return torch.einsum(rule, a, b) - elif rule == 's,se->se': - ## [1, s] * [s, e] - return a.reshape(a.shape[0], -1) * b - elif rule == 'se,sc->sec': - ## [s,e,1] * [s,1,c] - return a.unsqueeze(2) * b.unsqueeze(1) - elif rule == 'se,se->s': - ## [s,1,e] * [s,e,1] - return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1) - elif rule == 'sec,sm->ecm': - ## [e*c, s] * [s, m] + elif rule == "s,se->se": + # [1, s] * [s, e] + return a.reshape(a.shape[0], -1) * b + elif rule == "se,sc->sec": + # [s,e,1] * [s,1,c] + return a.unsqueeze(2) * b.unsqueeze(1) + elif rule == "se,se->s": + # [s,1,e] * [s,e,1] + return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1) + elif rule == "sec,sm->ecm": + # [e*c, s] * [s, m] s = a.shape[0] e = a.shape[1] c = a.shape[2] m = b.shape[1] - return torch.matmul(a.reshape(s, -1).t(), b).reshape(e, c, m) - elif rule == 'sec,ecm->sm': - ## [s, e*c] * [e*c, m] - return torch.matmul(a.reshape(a.shape[0], -1), b.reshape(-1, b.shape[-1])) - elif rule == 'ks,ksm->sm': + return torch.matmul(a.reshape(s, -1).t(), b).reshape(e, c, m) + elif rule == "sec,ecm->sm": + # [s, e*c] * [e*c, m] + return torch.matmul(a.reshape(a.shape[0], -1), b.reshape(-1, b.shape[-1])) + elif rule == "ks,ksm->sm": k = b.shape[0] s = b.shape[1] m = b.shape[2] @@ -172,16 +172,17 @@ def _one_hot_to_float(x, num_classes): return F.one_hot(x, num_classes=num_classes).float() -def top1gating(logits: Tensor, - capacity_factor: float, - min_capacity: int, - used_token: Tensor = None, - noisy_gate_policy: Optional[str] = None, - drop_tokens: bool = True, - use_rts: bool = True, - use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]: +def top1gating( + logits: Tensor, + capacity_factor: float, + min_capacity: int, + used_token: Tensor = None, + noisy_gate_policy: Optional[str] = None, + drop_tokens: bool = True, + use_rts: bool = True, +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: """Implements Top1Gating on logits.""" - if noisy_gate_policy == 'RSample': + if noisy_gate_policy == "RSample": logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device) # everything is in fp32 in this function gates = F.softmax(logits, dim=1) @@ -190,7 +191,7 @@ def top1gating(logits: Tensor, # Create a mask for 1st's expert per token # noisy gating - indices1_s = torch.argmax(logits_w_noise if noisy_gate_policy == 'RSample' else gates, dim=1) + indices1_s = torch.argmax(logits_w_noise if noisy_gate_policy == "RSample" else gates, dim=1) num_experts = int(gates.shape[1]) mask1 = F.one_hot(indices1_s, num_classes=num_experts) @@ -199,7 +200,7 @@ def top1gating(logits: Tensor, mask1 = einsum("s,se->se", used_token, mask1) # gating decisions - exp_counts = torch.sum(mask1, dim=0).detach().to('cpu') + 
exp_counts = torch.sum(mask1, dim=0).detach().to("cpu") # if we don't want to drop any tokens if not drop_tokens: @@ -216,42 +217,28 @@ def top1gating(logits: Tensor, if use_rts: uniform = exp_selection_uniform_map.get(logits.device) if uniform is None: - uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=logits.device), - high=torch.tensor(1.0, device=logits.device)).rsample + uniform = torch.distributions.uniform.Uniform( + low=torch.tensor(0.0, device=logits.device), high=torch.tensor(1.0, device=logits.device) + ).rsample exp_selection_uniform_map[logits.device] = uniform mask1_rand = mask1 * uniform(mask1.shape) else: mask1_rand = mask1 - assert logits.shape[ - 0] >= min_capacity, "No. of tokens (batch-size) should be greater than min_capacity. Either set min_capacity to 0 or increase your batch size." + assert ( + logits.shape[0] >= min_capacity + ), """No. of tokens (batch-size) should be greater than min_capacity. + Either set min_capacity to 0 or increase your batch size.""" - top_idx = _top_idx(mask1_rand, capacity) #@wenwen: token index + top_idx = _top_idx(mask1_rand, capacity) # @wenwen: token index new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1) mask1 = new_mask1 - if use_tutel: - # Tutel doesn't support index values masked with zero - # so we need to replace masked indices with -1 - indices_mask = mask1.sum(dim=1) * num_experts - 1 - indices1_s = torch.min(indices1_s, indices_mask) - # Compute locations in capacity buffer - - locations1 = torch.cumsum(mask1, dim=0) - 1 - if use_tutel: - gates1_s = (gates * mask1).sum(dim=1) - locations1_s = torch.sum(locations1 * mask1, dim=1) - return l_aux, capacity, num_experts, [ - indices1_s, - ], [ - locations1_s, - ], [ - gates1_s, - ], exp_counts + locations1 = torch.cumsum(mask1, dim=0) - 1 # Store the capacity location for each token locations1_s = torch.sum(locations1 * mask1, dim=1) @@ -295,7 +282,7 @@ def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tup locations2 += torch.sum(mask1, dim=0, keepdim=True) # gating decisions - exp_counts = torch.sum(mask1, dim=0).detach().to('cpu') + exp_counts = torch.sum(mask1, dim=0).detach().to("cpu") # Compute l_aux me = torch.mean(gates, dim=0) @@ -352,21 +339,23 @@ class TopKGate(Module): wg: torch.nn.Linear - def __init__(self, - model_dim: int, - num_experts: int, - k: int = 1, - capacity_factor: float = 1.0, - eval_capacity_factor: float = 1.0, - min_capacity: int = 8, - noisy_gate_policy: Optional[str] = None, - drop_tokens: bool = True, - use_rts: bool = True) -> None: + def __init__( + self, + model_dim: int, + num_experts: int, + k: int = 1, + capacity_factor: float = 1.0, + eval_capacity_factor: float = 1.0, + min_capacity: int = 8, + noisy_gate_policy: Optional[str] = None, + drop_tokens: bool = True, + use_rts: bool = True, + ) -> None: super().__init__() # Only top-1 and top-2 are supported at the moment. 
- if k != 1 and k != 2: - raise ValueError('Only top-1 and top-2 gatings are supported.') + if k not in (1, 2): + raise ValueError("Only top-1 and top-2 gatings are supported.") self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float() self.k = k self.capacity_factor = capacity_factor @@ -378,34 +367,40 @@ class TopKGate(Module): self.drop_tokens = drop_tokens self.use_rts = use_rts - def forward(self, - input: torch.Tensor, - used_token: torch.Tensor = None, - use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor]: # type: ignore + def forward( + self, inputs: torch.Tensor, used_token: torch.Tensor = None + ) -> Tuple[Tensor, Tensor, Tensor]: # type: ignore if self.wall_clock_breakdown: - timer('TopKGate').start() + timer("TopKGate").start() if self.wg.weight.dtype != torch.float32: self.wg = self.wg.float() - input_fp32 = input.float() + inputs_fp32 = inputs.float() # input jittering - if self.noisy_gate_policy == 'Jitter' and self.training: - input_fp32 = multiplicative_jitter(input_fp32, device=input.device) - logits = self.wg(input_fp32) + if self.noisy_gate_policy == "Jitter" and self.training: + inputs_fp32 = multiplicative_jitter(inputs_fp32, device=inputs.device) + logits = self.wg(inputs_fp32) if self.k == 1: - gate_output = top1gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor, - self.min_capacity, used_token, self.noisy_gate_policy if self.training else None, - self.drop_tokens, self.use_rts, use_tutel) + gate_output = top1gating( + logits, + self.capacity_factor if self.training else self.eval_capacity_factor, + self.min_capacity, + used_token, + self.noisy_gate_policy if self.training else None, + self.drop_tokens, + self.use_rts, + ) else: - gate_output = top2gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor, - self.min_capacity) + gate_output = top2gating( + logits, self.capacity_factor if self.training else self.eval_capacity_factor, self.min_capacity + ) if self.wall_clock_breakdown: - timer('TopKGate').stop() - self.gate_time = timer('TopKGate').elapsed(reset=False) + timer("TopKGate").stop() + self.gate_time = timer("TopKGate").elapsed(reset=False) return gate_output @@ -416,7 +411,7 @@ class MOELayer(Base): gate = TopKGate(model_dim, num_experts) moe = MOELayer(gate, expert) - output = moe(input) + output = moe(inputs) l_aux = moe.l_aux .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf @@ -428,12 +423,7 @@ class MOELayer(Base): expert network """ - def __init__(self, - gate: Module, - experts: Module, - ep_group, - ep_size, - num_local_experts: int) -> None: + def __init__(self, gate: Module, experts: Module, ep_group, ep_size, num_local_experts: int) -> None: super().__init__() self.gate = gate self.experts = experts @@ -445,59 +435,59 @@ class MOELayer(Base): self.time_moe = 0.0 self.wall_clock_breakdown = False - def _set_ep_group(self, ep_group): self.ep_group = ep_group - def forward(self, *input: Tensor, **kwargs: Any) -> Tensor: + def forward(self, *inputs: Tensor) -> Tensor: if self.wall_clock_breakdown: - timer('moe').start() + timer("moe").start() # Implement Algorithm 2 from GShard paper. - d_model = input[0].shape[-1] + d_model = inputs[0].shape[-1] # Initial implementation -> Reshape into S tokens by dropping sequence dimension. 
# Reshape into G groups so that each group can distribute tokens equally # group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1 - reshaped_input = input[0].reshape(-1, d_model) + reshaped_inputs = inputs[0].reshape(-1, d_model) - self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_input, input[1]) - dispatched_input = einsum("sec,sm->ecm", dispatch_mask.type_as(input[0]), reshaped_input) ## TODO: heavy memory usage due to long sequence length + self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_inputs, inputs[1]) + dispatched_inputs = einsum( + "sec,sm->ecm", dispatch_mask.type_as(inputs[0]), reshaped_inputs + ) # TODO: heavy memory usage due to long sequence length if self.wall_clock_breakdown: - timer('falltoall').start() + timer("falltoall").start() - - dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input) + dispatched_inputs = _AllToAll.apply(self.ep_group, dispatched_inputs) if self.wall_clock_breakdown: - timer('falltoall').stop() - self.time_falltoall = timer('falltoall').elapsed(reset=False) + timer("falltoall").stop() + self.time_falltoall = timer("falltoall").elapsed(reset=False) # Re-shape after all-to-all: ecm -> gecm - dispatched_input = dispatched_input.reshape(self.ep_size, self.num_local_experts, -1, d_model) + dispatched_inputs = dispatched_inputs.reshape(self.ep_size, self.num_local_experts, -1, d_model) - expert_output = self.experts(dispatched_input) + expert_output = self.experts(dispatched_inputs) if self.wall_clock_breakdown: - timer('salltoall').start() + timer("salltoall").start() expert_output = _AllToAll.apply(self.ep_group, expert_output) if self.wall_clock_breakdown: - timer('salltoall').stop() - self.time_salltoall = timer('salltoall').elapsed(reset=False) + timer("salltoall").stop() + self.time_salltoall = timer("salltoall").elapsed(reset=False) # Re-shape back: gecm -> ecm expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model) - combined_output = einsum("sec,ecm->sm", combine_weights.type_as(input[0]), expert_output) + combined_output = einsum("sec,ecm->sm", combine_weights.type_as(inputs[0]), expert_output) - a = combined_output.reshape(input[0].shape) + a = combined_output.reshape(inputs[0].shape) if self.wall_clock_breakdown: - timer('moe').stop() - self.time_moe = timer('moe').elapsed(reset=False) + timer("moe").stop() + self.time_moe = timer("moe").elapsed(reset=False) return a
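
Note: the hand-written branches in `einsum` above are intended as drop-in equivalents of `torch.einsum` for the dispatch ("sec,sm->ecm") and combine ("sec,ecm->sm") rules used in `MOELayer.forward`. A minimal sketch, assuming only PyTorch and using illustrative shapes (not values from this patch), that checks the two fallback expressions against `torch.einsum`:

import torch

# Illustrative sizes: s tokens, e experts, c capacity slots, m model dim.
s, e, c, m = 8, 4, 2, 16
dispatch_mask = torch.randn(s, e, c)   # stands in for the dispatch mask
tokens = torch.randn(s, m)             # stands in for the reshaped input tokens
expert_out = torch.randn(e, c, m)      # stands in for the expert outputs

# "sec,sm->ecm": the fallback flattens (e, c) and uses a single matmul.
ref_dispatch = torch.einsum("sec,sm->ecm", dispatch_mask, tokens)
alt_dispatch = torch.matmul(dispatch_mask.reshape(s, -1).t(), tokens).reshape(e, c, m)
assert torch.allclose(ref_dispatch, alt_dispatch, atol=1e-5)

# "sec,ecm->sm": the fallback flattens the shared (e, c) dimension on both sides.
ref_combine = torch.einsum("sec,ecm->sm", dispatch_mask, expert_out)
alt_combine = torch.matmul(dispatch_mask.reshape(s, -1), expert_out.reshape(-1, m))
assert torch.allclose(ref_combine, alt_combine, atol=1e-5)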