# ColossalAI/colossalai/shardformer/layer/utils.py

from contextlib import contextmanager
from typing import List, Optional, Union

import torch
import torch.distributed as dist
from torch import nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed import ProcessGroup, get_world_size

from colossalai.accelerator import get_accelerator


class SeqParallelUtils:
    @staticmethod
    def marked_as_sp_partial_derived_param(param):
        """
        Mark a parameter as partially derived in sequence parallelism.

        Args:
            param: The parameter to mark as partially derived.
        """
        setattr(param, "partial_derived", True)

    @staticmethod
    def is_sp_partial_derived_param(param):
        """
        Check whether a parameter is marked as partially derived in sequence parallelism.

        Args:
            param: The parameter to check.

        Returns:
            bool: True if the parameter is marked as partially derived, False otherwise.
        """
        return getattr(param, "partial_derived", False)
    @staticmethod
    def allreduce_partial_data_grad(
        process_group: ProcessGroup,
        model: nn.Module = None,
        grads: List[torch.Tensor] = None,
    ):
        """
        All-reduce partially derived gradients across the specified process group.

        This function synchronizes gradients of parameters that are marked as
        partially derived in sequence parallelism.

        Args:
            process_group (ProcessGroup): The process group for gradient synchronization.
            model (nn.Module): The model whose partially derived gradients will be synchronized.
            grads (List[torch.Tensor]): The list of gradients to synchronize directly.

        Raises:
            AssertionError: If both `model` and `grads` are provided, or neither is.
        """
        # Ensure that exactly one of `model` and `grads` is provided.
        assert (model is not None) ^ (grads is not None), "Exactly one of model and grads must be not None."

        # The size of the process group determines whether synchronization is needed.
        group_size = get_world_size(process_group) if process_group is not None else 1

        if group_size == 1:
            # If the process group size is 1, no synchronization is required.
            return

        if model is not None:
            # If `model` is provided, collect the partially derived gradients from its parameters.
            grads = []
            for p in model.parameters():
                if p.grad is not None:
                    if SeqParallelUtils.is_sp_partial_derived_param(p):
                        grads.append(p.grad.data)

            # Flatten and reduce the gradients using the specified process group.
            if len(grads) == 0:
                return
            coalesced = _flatten_dense_tensors(grads)
            dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=process_group)

            # Unflatten the synchronized gradients and update the model's gradients.
            for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                buf.copy_(synced)
        else:
            # If `grads` are provided explicitly, synchronize those gradients directly.
            coalesced = _flatten_dense_tensors(grads)
            dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=process_group)
            for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                buf.copy_(synced)
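

# A minimal usage sketch (not part of the module API): `layer` and `sp_group`
# below are hypothetical stand-ins. Parameters whose gradients are only partial
# sums under sequence parallelism (e.g. LayerNorm weights) are marked, and
# their gradients are all-reduced after the backward pass.
def _example_sync_partial_grads(layer: nn.Module, sp_group: ProcessGroup):
    """Sketch: mark a layer's params as partially derived and sync their grads."""
    for p in layer.parameters():
        SeqParallelUtils.marked_as_sp_partial_derived_param(p)
    # ... run the forward pass and loss.backward() here ...
    SeqParallelUtils.allreduce_partial_data_grad(process_group=sp_group, model=layer)

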
class Randomizer:
"""
Randomizer enables the program to be executed under a different seed within the context.
Example:
```python
randomizer = Randomizer(seed=1024)
with randomizer.fork():
# do something here with seed 1024
do_something()
```
Args:
seed (int): The random seed to set.
enable_cpu (bool): fork the CPU RNG state as well.
with_index (bool): whether to use the index of the randomizer.
"""
_INDEX = 0
def __init__(self, seed: int):
self.seed = seed
# Handle device rng state
# 1. get the current rng state
# 2. set the seed and store the rng state
# 3. recover the original rng state
device_original_rng_state = get_accelerator().get_rng_state()
get_accelerator().manual_seed(seed)
self.device_rng_state = get_accelerator().get_rng_state()
get_accelerator().set_rng_state(device_original_rng_state)
# to the same for cpu rng state
cpu_original_rng_state = torch.get_rng_state()
torch.manual_seed(seed)
self.cpu_rng_state = torch.get_rng_state()
torch.set_rng_state(cpu_original_rng_state)
    def _set_device_rng_state(self, rng_state):
        get_accelerator().set_rng_state(rng_state)

    def _get_device_rng_state(self):
        current_state = get_accelerator().get_rng_state()
        return current_state

    def _set_cpu_rng_state(self, rng_state):
        torch.set_rng_state(rng_state)

    def _get_cpu_rng_state(self):
        current_state = torch.get_rng_state()
        return current_state
    @contextmanager
    def fork_rng(self, enable_cpu: bool = False):
        """
        A context manager that switches to this Randomizer's RNG state and
        restores the original state on exit.

        Args:
            enable_cpu (bool): If True, fork the CPU RNG state as well.

        Usage:
        ::
            >>> with randomizer.fork_rng():
            >>>     input = super().forward(input)
        """
        try:
            current_device_rng_state = self._get_device_rng_state()
            self._set_device_rng_state(self.device_rng_state)

            if enable_cpu:
                current_cpu_rng_state = self._get_cpu_rng_state()
                self._set_cpu_rng_state(self.cpu_rng_state)
            yield
        finally:
            self.device_rng_state = self._get_device_rng_state()
            self._set_device_rng_state(current_device_rng_state)

            if enable_cpu:
                self.cpu_rng_state = self._get_cpu_rng_state()
                self._set_cpu_rng_state(current_cpu_rng_state)
    @staticmethod
    def index():
        """
        Return the current index of the randomizer. The index is useful when the user
        wants to introduce controlled randomness in the program.

        Note:
            The index is advanced by calling `increment_index()`; this method itself
            does not change it.

        Example:
        ```python
        # Assume we need a randomizer to initialize the weights of different layers.
        # We can use the index of the randomizer so that each layer gets its own
        # randomizer with a different seed.
        base_seed = torch.random.initial_seed()
        seed = base_seed + Randomizer.index()
        randomizer = Randomizer(seed)

        with randomizer.fork_rng():
            init_weights()
        ```
        """
        idx = Randomizer._INDEX
        return idx
    @staticmethod
    def increment_index():
        """
        Increment the index of the randomizer by one.
        """
        Randomizer._INDEX += 1

    @staticmethod
    def reset_index():
        """
        Reset the index to zero.
        """
        Randomizer._INDEX = 0
    @staticmethod
    def is_randomizer_index_synchronized(process_group: ProcessGroup = None):
        """
        Return whether the randomizer index is synchronized across processes.
        """
        index = Randomizer.index()

        if dist.is_initialized():
            # Convert the index to a tensor.
            index_tensor = torch.tensor(index, dtype=torch.int32, device=get_accelerator().get_current_device())

            # All-gather the index.
            gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))]
            dist.all_gather(gathered_index, index_tensor, process_group)

            # Make sure all the gathered indices are the same.
            for i in range(1, dist.get_world_size(process_group)):
                if gathered_index[i] != gathered_index[0]:
                    return False
        return True
    @staticmethod
    def synchronize_index(process_group: ProcessGroup = None):
        """
        All-gather the index across processes and keep the largest value.
        """
        index = Randomizer.index()

        if dist.is_initialized():
            # Convert the index to a tensor.
            index_tensor = torch.tensor(index, dtype=torch.int32, device=get_accelerator().get_current_device())

            # All-gather the index.
            gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))]
            dist.all_gather(gathered_index, index_tensor, process_group)

            # Pick the largest index.
            for i in range(1, dist.get_world_size(process_group)):
                if gathered_index[i] > index_tensor:
                    index_tensor = gathered_index[i]

            # Set the index.
            Randomizer._INDEX = index_tensor.item()


def create_randomizer_with_offset(
    seed: int, process_group: ProcessGroup = None, offset_by_rank: bool = True, offset_by_index: bool = True
):
    """
    Create a randomizer with an offset seed: the offset is the sum of the
    process rank and the current randomizer index.

    Args:
        seed (int): The base random seed to set.
        process_group (ProcessGroup): The process group to get the rank from.
        offset_by_rank (bool): Whether to offset by the rank of the process, i.e., add the rank to the seed. Default: True.
        offset_by_index (bool): Whether to offset by the index of the randomizer, i.e., add the index to the seed. Default: True.

    Returns:
        Randomizer: the randomizer with an offset seed.
    """
    base_seed = seed

    if offset_by_rank and dist.is_initialized():
        rank = dist.get_rank(process_group)
        base_seed += rank

    if offset_by_index:
        # Check that the randomizer index is synchronized across processes.
        is_synchronized = Randomizer.is_randomizer_index_synchronized(process_group)
        assert is_synchronized, (
            "The randomizer index is not synchronized across processes. "
            "This is not allowed when creating a randomizer with an index offset. "
            "Please call Randomizer.synchronize_index() first."
        )

        base_seed += Randomizer.index()
        Randomizer.increment_index()

    return Randomizer(seed=base_seed)
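

# A minimal usage sketch (not part of the module API): `tp_group` is a
# hypothetical tensor-parallel process group. Each call consumes one randomizer
# index, so layers initialized one after another on the same rank draw from
# different seeds.
def _example_init_sharded_weight(weight: torch.Tensor, seed: int, tp_group: ProcessGroup):
    """Sketch: initialize a sharded weight under a rank- and index-offset seed."""
    randomizer = create_randomizer_with_offset(seed, process_group=tp_group)
    with randomizer.fork_rng(enable_cpu=True):
        nn.init.normal_(weight, std=0.02)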


def split_batch_zigzag(
    batch: Union[torch.Tensor, List[torch.Tensor]], sp_group: ProcessGroup, seq_dim: int = 1, is_label: bool = False
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """
    Split the input sequence batch in a zigzag fashion. Naively splitting the
    attention mask in the causal setting would leave the preceding ranks with much
    less workload, so we split after "folding" the 2D attention mask in half
    (https://github.com/zhuzilin/ring-flash-attention/issues/2).
    For example, for sp_size = 4 and seq_len = 8, we get | s0, s7 | s1, s6 | s2, s5 | s3, s4 |.

    Args:
        batch (List[torch.Tensor] or Tensor): The input tensor(s) to split.
        sp_group (ProcessGroup): The process group for sequence parallelism.
        seq_dim (int): The sequence dimension to split along.
        is_label (bool): If True, mask and shift the tensor for next-token prediction.
    """
    sp_size = dist.get_world_size(sp_group)
    sp_rank = dist.get_rank(sp_group)
    if sp_size == 1:
        return batch

    if isinstance(batch, torch.Tensor):
        batch = [batch]
    seq_dim = seq_dim if seq_dim != -1 else batch[0].dim() - 1

    if sp_size > 1:
        for idx, tensor in enumerate(batch):
            assert (
                tensor.shape[seq_dim] // (sp_size * 2) > 1 and tensor.shape[seq_dim] % (sp_size * 2) == 0
            ), f"The sequence length {tensor.shape[seq_dim]} of tensor {idx} cannot be split by {sp_size * 2}!"
            if is_label:
                assert tensor.dim() == 2, "Label shape should be (B, Seqlen)"
                tensor = torch.cat([tensor[:, 1:], torch.full_like(tensor[:, :1], -100)], dim=1)

            tensor = tensor.view(
                *tensor.shape[:seq_dim],
                2 * sp_size,
                tensor.shape[seq_dim] // (2 * sp_size),
                *tensor.shape[seq_dim + 1 :],
            )
            indices = torch.tensor([sp_rank, 2 * sp_size - 1 - sp_rank], device=tensor.device)
            tensor = tensor.index_select(seq_dim, indices).contiguous()
            # (B, 2, Sq // (2 * sp_size), ...) -> (B, Sq // sp_size, ...)
            batch[idx] = tensor.view(*tensor.shape[:seq_dim], -1, *tensor.shape[seq_dim + 2 :])

    if len(batch) == 1:
        return batch[0]
    return batch
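

# A worked sketch of the zigzag split (not part of the module API): for
# sp_size = 2 and seq_len = 8, the sequence is cut into 2 * sp_size = 4 chunks,
# and rank r keeps chunks (r, 2 * sp_size - 1 - r), so rank 0 holds tokens
# [0 1 6 7] and rank 1 holds [2 3 4 5], balancing causal-attention work.
# `sp_group` is a hypothetical SP process group.
def _example_split_batch_zigzag(sp_group: ProcessGroup) -> torch.Tensor:
    """Sketch: zigzag-split a (1, 8) token batch over the SP group."""
    tokens = torch.arange(8).unsqueeze(0)  # shape (B=1, Sq=8)
    local = split_batch_zigzag(tokens, sp_group, seq_dim=1)
    return local  # shape (1, 8 // sp_size)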


def split_varlen_zigzag(
    batch: Union[List[torch.Tensor], torch.Tensor],
    cu_seqlens: torch.Tensor,
    sp_group: ProcessGroup,
    max_seqlen: int = 0,
    is_batched_seq: bool = False,
    is_label: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
    """Split a packed sequence or a batch of padded sequences in a zigzag fashion.
    Unlike split_batch_zigzag, the inputs here have variable sequence lengths.

    Args:
        batch (List[torch.Tensor]): Packed sequences of shape (T, ...), or (B, Sq, ...) if is_batched_seq,
            where T is the total number of tokens.
        cu_seqlens (torch.Tensor): Cumulative sequence lengths of shape (B + 1) before splitting.
        sp_group (ProcessGroup): The process group for sequence parallelism.
        max_seqlen (int): The maximum sequence length in the batch before splitting.
        is_batched_seq (bool): If True, the input is a batch of sequences padded to the same length.
        is_label (bool): If True, mask out the first token in each sequence (<Start of Sentence>).

    Returns:
        batch (List[torch.Tensor]): Packed sequences of shape (T, ...),
            or (B, max_seqlen // sp_size, ...) if is_batched_seq.
    """
    sp_size = dist.get_world_size(sp_group)
    sp_rank = dist.get_rank(sp_group)
    if sp_size == 1:
        return batch

    if is_batched_seq:
        assert max_seqlen > 0, "max_seqlen must be provided for 2D input"

    if isinstance(batch, torch.Tensor):
        batch = [batch]
    # seq shape in the batched case: (B, Sq, h, n)
    for i, packed_seq in enumerate(batch):
        device = packed_seq.device
        dtype = packed_seq.dtype

        if is_batched_seq:
            assert max_seqlen % (sp_size * 2) == 0, f"max_seqlen {max_seqlen} must be divisible by {sp_size * 2}"
            # Recreate a padded tensor with the new max seqlen.
            shape = (packed_seq.shape[0], max_seqlen // sp_size, *packed_seq.shape[2:])
            local_seq = torch.zeros(shape, dtype=dtype, device=device)
        else:
            total_seqlen = cu_seqlens[-1]
            assert (
                total_seqlen % (2 * sp_size) == 0
            ), f"total_seqlen {total_seqlen} must be divisible by 2 * sp_size = {2 * sp_size}"
            local_seq = []

        for j in range(len(cu_seqlens) - 1):
            start, end = cu_seqlens[j], cu_seqlens[j + 1]
            seqlen = end - start
            assert (
                seqlen % (2 * sp_size) == 0
            ), f"batch {i} seq {j}'s length ({seqlen}) must be divisible by 2 * sp_size = {2 * sp_size} for splitting"

            if is_batched_seq:
                seq = packed_seq[j][:seqlen]
                if is_label:
                    # Shift the labels by one token for next-token prediction and mask the trailing position.
                    seq = torch.cat([seq[1:], torch.tensor([-100], dtype=dtype, device=device)])

                seq = seq.chunk(2 * sp_size, dim=0)
                half = seqlen // sp_size // 2
                local_seq[j][:half] = seq[sp_rank]
                local_seq[j][half : seqlen // sp_size] = seq[2 * sp_size - 1 - sp_rank]
            else:
                seq = packed_seq[start:end]
                if is_label:
                    seq = torch.cat([seq[1:], torch.tensor([-100], dtype=dtype, device=device)])
                seq = seq.chunk(sp_size * 2)
                local_seq.extend([seq[sp_rank], seq[2 * sp_size - 1 - sp_rank]])

        if is_batched_seq:
            batch[i] = local_seq.contiguous()
        else:
            batch[i] = torch.cat(local_seq, dim=0)

    if len(batch) == 1:
        batch = batch[0]
    return batch
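

# A minimal usage sketch (not part of the module API), assuming sp_size = 2 and
# a hypothetical `sp_group`: two packed sequences of lengths 8 and 16 (each
# divisible by 2 * sp_size) are zigzag-split, so every rank keeps a front and a
# back chunk of each sequence.
def _example_split_varlen_zigzag(sp_group: ProcessGroup) -> torch.Tensor:
    """Sketch: zigzag-split a packed variable-length batch."""
    cu_seqlens = torch.tensor([0, 8, 24])  # B = 2 sequences, T = 24 tokens
    packed = torch.arange(24)  # shape (T,)
    local = split_varlen_zigzag(packed, cu_seqlens, sp_group)
    return local  # shape (T // sp_size,) = (12,)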


def is_share_sp_tp(sp_mode: str):
    """SP modes "ring" and "split_gather" use the TP group as the SP group
    to split both the vocabulary and the sequence, so we must gather the sequence
    to correctly compute the logits at each position.
    """
    return sp_mode in ["ring", "split_gather"]


class RingComm:
    def __init__(self, process_group: dist.ProcessGroup):
        self._process_group = process_group
        self._ops = []
        self.rank = dist.get_rank(self._process_group)
        self.world_size = dist.get_world_size(self._process_group)
        self._reqs = []

        self.send_rank = (self.rank + 1) % self.world_size
        self.recv_rank = (self.rank - 1) % self.world_size
        self.send_rank = dist.get_global_rank(self._process_group, self.send_rank)
        self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank)

    def send_recv(
        self,
        send_tensor: torch.Tensor,
        recv_tensor: Optional[torch.Tensor] = None,
        commit: bool = True,
    ) -> torch.Tensor:
        if recv_tensor is None:
            res = torch.empty_like(send_tensor)
        else:
            res = recv_tensor

        # It looks like batch_isend_irecv doesn't deadlock even
        # when we don't swap the send/recv ops based on rank.
        send_op = dist.P2POp(dist.isend, send_tensor, self.send_rank, group=self._process_group)
        recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group)
        self._ops.extend([send_op, recv_op])

        if commit:
            self._reqs = dist.batch_isend_irecv(self._ops)
        return res

    def commit(self):
        assert len(self._ops) > 0, "No ops to commit"
        self._reqs = dist.batch_isend_irecv(self._ops)

    def wait(self):
        assert len(self._reqs) > 0, "No requests to wait for"
        for req in self._reqs:
            req.wait()
        self._reqs = []
        self._ops = []
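

# A minimal usage sketch (not part of the module API): one step of a ring
# exchange where each rank sends its KV block to rank + 1 and receives from
# rank - 1. `sp_group` and `kv_block` are hypothetical; deferring the commit
# lets attention compute overlap the communication.
def _example_ring_step(sp_group: dist.ProcessGroup, kv_block: torch.Tensor) -> torch.Tensor:
    """Sketch: exchange one KV block around the ring and wait for completion."""
    comm = RingComm(sp_group)
    recv_block = comm.send_recv(kv_block, commit=False)
    comm.commit()
    # ... attention compute on the current block could overlap here ...
    comm.wait()
    return recv_block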


@torch.jit.script
def get_half_index(cu_seqlens, *, front: bool):
    """Build a boolean mask selecting the front (or back) half of every sequence in a packed batch."""
    index = torch.zeros(cu_seqlens[-1], dtype=torch.bool, device=cu_seqlens.device)
    for i in range(len(cu_seqlens) - 1):
        start, end = cu_seqlens[i], cu_seqlens[i + 1]
        if front:
            end = (start + end) // 2
        else:
            start = (start + end) // 2
        index[start:end] = True
    return index
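

# A worked sketch (not part of the module API): for cu_seqlens = [0, 4, 10],
# the front mask selects tokens 0-1 and 4-6 (the first half of each packed
# sequence) and the back mask selects the rest. This is the indexing the
# zigzag layout uses to address each sequence's halves.
def _example_get_half_index():
    """Sketch: build front/back half masks for two packed sequences."""
    cu_seqlens = torch.tensor([0, 4, 10])
    front = get_half_index(cu_seqlens, front=True)  # True at indices [0, 1, 4, 5, 6]
    back = get_half_index(cu_seqlens, front=False)  # True at indices [2, 3, 7, 8, 9]
    return front, back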