diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml index c09c10308..9f1d504ad 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -8,6 +8,13 @@ body: attributes: value: > #### Not suitable for your needs? [Open a blank issue](https://github.com/hpcaitech/ColossalAI/issues/new). +- type: checkboxes + attributes: + label: Is there an existing issue for this bug? + description: Please search [here](https://github.com/hpcaitech/ColossalAI/issues) to see if an open or closed issue already exists for the bug you have encountered. + options: + - label: I have searched the existing issues + required: true - type: textarea attributes: label: 🐛 Describe the bug diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml index 27ab7c76a..37f39ec95 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -140,7 +140,7 @@ jobs: - name: Install Colossal-AI run: | - BUILD_EXT=1 pip install -v -e . + pip install -v -e . pip install -r requirements/requirements-test.txt - name: Store Colossal-AI Cache diff --git a/README.md b/README.md index 9e215df63..3157d74c9 100644 --- a/README.md +++ b/README.md @@ -418,7 +418,7 @@ Please visit our [documentation](https://www.colossalai.org/) and [examples](htt ## Installation Requirements: -- PyTorch >= 1.11 and PyTorch <= 2.1 +- PyTorch >= 2.1 - Python >= 3.7 - CUDA >= 11.0 - [NVIDIA GPU Compute Capability](https://developer.nvidia.com/cuda-gpus) >= 7.0 (V100/RTX20 and higher) diff --git a/applications/Colossal-LLaMA/prepare_sft_dataset.py b/applications/Colossal-LLaMA/prepare_sft_dataset.py index be5f9bcca..a857d6c0c 100644 --- a/applications/Colossal-LLaMA/prepare_sft_dataset.py +++ b/applications/Colossal-LLaMA/prepare_sft_dataset.py @@ -10,7 +10,7 @@ import math import os from multiprocessing import cpu_count -from colossal_llama.dataset.conversation import default_conversation +from colossal_llama.dataset.conversation import LLaMA2_Conv from colossal_llama.dataset.spliced_and_tokenized_dataset import supervised_tokenize_sft from datasets import dataset_dict, load_dataset from transformers import AddedToken, AutoTokenizer @@ -78,6 +78,7 @@ def main(): # Fix split issue: https://github.com/huggingface/transformers/issues/23833 if args.llama_version == 2: tokenizer.add_tokens(AddedToken("", normalized=False, special=True), special_tokens=True) + default_conversation = LLaMA2_Conv tokenizer.add_bos_token = False tokenizer.add_eos_token = False diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 97057481e..f4f9f7a50 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1,7 +1,9 @@ import ctypes import random import warnings +from collections import defaultdict from contextlib import contextmanager +from copy import deepcopy from functools import partial from types import MethodType from typing import Any, Callable, Dict, Iterator, List, Optional, OrderedDict, Tuple, Union @@ -24,6 +26,8 @@ from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOpt from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO from colossalai.cluster import ProcessGroupMesh from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper +from colossalai.interface.optimizer import DistributedOptim +from colossalai.nn.optimizer import DistGaloreAwamW 
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer @@ -735,7 +739,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): # Get all working gradients and gradients to be synchronized. all_working_grads = _get_all_working_grads() grads_to_sync = _get_grads_to_sync(all_working_grads) - if self.require_grad_sync and grads_to_sync is not None: + if self._grad_store.require_grad_sync and grads_to_sync is not None: # Synchronize sequence parallelism gradients if required. SeqParallelUtils.allreduce_partial_data_grad(process_group=self.tp_pg, grads=grads_to_sync) else: @@ -759,7 +763,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): # Call the superclass backward method to compute gradients. super().backward(loss, retain_graph) - if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism: + if self._grad_store.require_grad_sync and self.model.shard_config.enable_sequence_parallelism: # If gradient synchronization is required, sync sequence parallelism gradients. self._sync_sp_grads() else: @@ -784,7 +788,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): # Call the superclass backward_by_grad method to compute gradients. super().backward_by_grad(tensor, grad) - if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism: + if self._grad_store.require_grad_sync and self.model.shard_config.enable_sequence_parallelism: # If gradient synchronization is required, sync sequence parallelism gradients. self._sync_sp_grads() else: @@ -1171,6 +1175,15 @@ class HybridParallelPlugin(PipelinePluginBase): lr_scheduler: Optional[LRScheduler] = None, ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: param_info = get_param_info(optimizer) + + # TODO: Support Galore + ZeRO + zero_stage = self.zero_stage + zero_config = deepcopy(self.zero_config) + if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0: + warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.") + zero_config["partition_grad"] = False + zero_stage = 0 + if not isinstance(model, ModelWrapper): use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or ( self.dp_size == 1 @@ -1194,7 +1207,8 @@ class HybridParallelPlugin(PipelinePluginBase): custom_policy=self.custom_policy, ) if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): - if self.zero_stage == 0: + if zero_stage == 0: + is_zero = False if self.precision in ["fp16", "bf16"]: optimizer = HybridParallelAMPOptimizer( optimizer, @@ -1218,11 +1232,11 @@ class HybridParallelPlugin(PipelinePluginBase): tp_process_group=self.tp_group, ) else: - zero_dp_size = dist.get_world_size(dp_group) - if zero_dp_size == 1: + is_zero = self.dp_size > 1 + if self.dp_size == 1: warnings.warn( "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. " - "If you are not intended to use cpu_offload, please consider set zero_stage=0." + "If you do not intend to use cpu_offload, please consider set zero_stage=0." ) assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO." 
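The hunk above downgrades ZeRO when a DistGaloreAwamW optimizer is passed to the hybrid plugin. A minimal, self-contained sketch of that guard follows; the helper name resolve_zero_settings and the name-based type check are illustrative only, while zero_stage, zero_config["partition_grad"], and the warning text mirror the plugin code in this patch.

import warnings
from copy import deepcopy


def resolve_zero_settings(optimizer, zero_stage: int, zero_config: dict, dp_size: int):
    """Sketch of the Galore/ZeRO guard used in HybridParallelPlugin.configure:
    Galore is not yet compatible with ZeRO, so gradient partitioning is
    switched off and the ZeRO stage is forced back to 0."""
    zero_config = deepcopy(zero_config)
    # The plugin checks isinstance(optimizer, DistGaloreAwamW); the name check
    # here only keeps this sketch free of colossalai imports.
    if type(optimizer).__name__ == "DistGaloreAwamW" and zero_stage > 0 and dp_size > 0:
        warnings.warn(
            "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO."
        )
        zero_config["partition_grad"] = False
        zero_stage = 0
    return zero_stage, zero_config

The resolved zero_stage and zero_config are then used when building the ZeRO optimizer, and the plugin finally calls optim.setup_distributed(...) on any DistributedOptim so it knows the TP/DP process groups and the ZeRO shard-to-working-param mapping.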
@@ -1236,11 +1250,19 @@ class HybridParallelPlugin(PipelinePluginBase): pp_process_group=self.pp_group, verbose=True, clip_grad_norm=self.max_norm, - **self.zero_config, + **zero_config, **self.amp_config, ) # inject update_master_params model.update_master_params = MethodType(optimizer.update_master_params, model) + + # Setup optimizers that require global states + optim = optimizer.optim + if isinstance(optim, DistributedOptim): + shard_to_param = optimizer.get_master_to_working_map() if is_zero else {} + padding_map = optimizer.get_param_padding_map() if is_zero else defaultdict(int) + optim.setup_distributed(self.tp_group, self.dp_group, shard_to_param, padding_map, is_zero) + return model, optimizer, criterion, dataloader, lr_scheduler def execute_pipeline( @@ -1272,7 +1294,7 @@ class HybridParallelPlugin(PipelinePluginBase): # run with gradients accumulation if model.require_grad_sync == False or ( - isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False + isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer._grad_store.require_grad_sync == False ): return outputs diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index be75bebac..dfc743fe5 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -8,7 +8,10 @@ from types import MethodType from typing import Callable, Dict, Iterator, List, Optional, Tuple import torch +import torch.distributed +import torch.distributed as dist import torch.nn as nn +from torch.distributed.distributed_c10d import _get_default_group from torch.nn import Parameter from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler @@ -28,6 +31,8 @@ from colossalai.checkpoint_io.utils import ( sharded_optimizer_loading_epilogue, ) from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper +from colossalai.interface.optimizer import DistributedOptim +from colossalai.nn.optimizer import DistGaloreAwamW from colossalai.quantization import BnbQuantizationConfig, quantize_model from colossalai.zero import LowLevelZeroOptimizer @@ -428,13 +433,31 @@ class LowLevelZeroPlugin(DPPluginBase): if not isinstance(model, ModelWrapper): model = LowLevelZeroModel(model, self.precision) + # TODO: Support Galore + ZeRO + zero_stage = self.stage + zero_optim_kwargs = {**self.zero_optim_kwargs} + dp_size = dist.get_world_size() + if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and dp_size > 0: + warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. 
Disabling ZeRO.") + zero_optim_kwargs["partition_grad"] = False + zero_stage = 0 + if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer( - optimizer, **self.zero_optim_kwargs, verbose=self.verbose + optimizer, **zero_optim_kwargs, verbose=self.verbose ) # inject update_master_params model.update_master_params = MethodType(optimizer.update_master_params, model) + # Setup optimizers that require global states + optim = optimizer.optim + is_zero = dp_size > 1 and zero_stage > 0 + dp_group = _get_default_group() # Use the whole world + if isinstance(optim, DistributedOptim): + shard_to_param = optimizer.get_master_to_working_map() + padding_map = optimizer.get_param_padding_map() + optim.setup_distributed(None, dp_group, shard_to_param, padding_map, is_zero) + return model, optimizer, criterion, dataloader, lr_scheduler def control_checkpoint_io(self) -> bool: diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py index 43095af50..fea4a23ba 100644 --- a/colossalai/cluster/process_group_mesh.py +++ b/colossalai/cluster/process_group_mesh.py @@ -38,7 +38,12 @@ class ProcessGroupMesh: def __init__(self, *size: int) -> None: assert dist.is_initialized(), "Please initialize torch.distributed first." - assert prod(size) == dist.get_world_size(), "The product of the size must be equal to the world size." + world_size = dist.get_world_size() + prod_size = prod(size) + assert ( + prod_size == world_size + ), f"The product of the size({prod_size}) must be equal to the world size({world_size})." + self._shape = size self._rank = dist.get_rank() self._coord = ProcessGroupMesh.unravel(self._rank, self._shape) diff --git a/colossalai/device/device_mesh.py b/colossalai/device/device_mesh.py index 3949590e8..171d88762 100644 --- a/colossalai/device/device_mesh.py +++ b/colossalai/device/device_mesh.py @@ -306,9 +306,8 @@ class DeviceMesh: # index means the local rank in the current axis # inner_tensor refers to the processes with the same local rank - if inner_tensor.numel() == 1: - # if the inner_tensor only has one element, it means that - # it already reaches the last axis + if inner_tensor.dim() == 0: + # if the inner_tensor already reaches the last axis, # we append its local_rank in the last axis to the index_list # and assign to the mapping # the value of the mapping is the the local rank at the indexed axis of the device mesh @@ -459,6 +458,7 @@ class DeviceMesh: # replace the local rank in the given dimension with the # local rank of the current process iterated + process_coordinates[dim] = _local_rank processes_in_the_same_process_group[dim].append(process_coordinates) diff --git a/colossalai/interface/optimizer.py b/colossalai/interface/optimizer.py index 95d11087b..6cd74b3b4 100644 --- a/colossalai/interface/optimizer.py +++ b/colossalai/interface/optimizer.py @@ -1,6 +1,7 @@ -from typing import Union +from typing import Dict, Optional, Union import torch +import torch.distributed as dist import torch.nn as nn from torch import Tensor from torch.optim import Optimizer @@ -133,3 +134,25 @@ class OptimizerWrapper: Unwrap the optimizer for checkpoint saving/loading. 
""" return self.optim + + +class DistributedOptim(Optimizer): + def setup_distributed( + self, + tp_group: Optional[dist.ProcessGroup] = None, + dp_group: Optional[dist.ProcessGroup] = None, + shard_to_working_param: Optional[Dict] = {}, + padding_map: Optional[Dict] = None, + is_zero: Optional[bool] = False, + ): + """Assign process groups for TP and ZeRO 2. + Arguments: + tp_group (dist.ProcessGroup): Tensor Parallel process group + dp_group (dist.ProcessGroup): ZeRO stage 2 process group + shard_to_working_param (Dict): ZeRO stage 2 feeds the optimizer a sharded param view to match grad shape. + This maps from id(view) to model params used in forward & backward. + padding_map (Dict): Per-param padding from ZeRO stage 2 + is_zero (bool): Whether to use ZeRO stage 2. + """ + + raise NotImplementedError("setup_distributed for TP/DP isn't supported by this optimizer yet!") diff --git a/colossalai/lazy/pretrained.py b/colossalai/lazy/pretrained.py index 21d44d424..736ffc5e4 100644 --- a/colossalai/lazy/pretrained.py +++ b/colossalai/lazy/pretrained.py @@ -1,3 +1,4 @@ +import copy import os from typing import Callable, Optional, Union @@ -74,6 +75,24 @@ def new_from_pretrained( subfolder = kwargs.pop("subfolder", "") commit_hash = kwargs.pop("_commit_hash", None) variant = kwargs.pop("variant", None) + + kwargs.pop("state_dict", None) + kwargs.pop("from_tf", False) + kwargs.pop("from_flax", False) + kwargs.pop("output_loading_info", False) + kwargs.pop("trust_remote_code", None) + kwargs.pop("low_cpu_mem_usage", None) + kwargs.pop("device_map", None) + kwargs.pop("max_memory", None) + kwargs.pop("offload_folder", None) + kwargs.pop("offload_state_dict", False) + kwargs.pop("load_in_8bit", False) + kwargs.pop("load_in_4bit", False) + kwargs.pop("quantization_config", None) + kwargs.pop("adapter_kwargs", {}) + kwargs.pop("adapter_name", "default") + kwargs.pop("use_flash_attention_2", False) + use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False) if len(kwargs) > 0: @@ -108,6 +127,10 @@ def new_from_pretrained( **kwargs, ) else: + config = copy.deepcopy(config) + kwarg_attn_imp = kwargs.pop("attn_implementation", None) + if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp: + config._attn_implementation = kwarg_attn_imp model_kwargs = kwargs if commit_hash is None: diff --git a/colossalai/legacy/inference/async_manager.py b/colossalai/legacy/inference/async_manager.py index 60440a792..526e0f632 100644 --- a/colossalai/legacy/inference/async_manager.py +++ b/colossalai/legacy/inference/async_manager.py @@ -55,14 +55,14 @@ class Async_DynamicBatchManager(DynamicBatchManager): self.stats_tool.count_prompt_tokens(new_batch) self.running_batch = new_batch has_new_finished, outputs = self._prefill_batch(self.running_batch) - self._filter_runing_batch() + self._filter_running_batch() self.has_wait_tokens = 0 else: if self.has_wait_tokens < self.max_wait_tokens: self.stats_tool.count_output_tokens(self.running_batch) has_new_finished, outputs = self._decode_batch(self.running_batch) - self._filter_runing_batch() + self._filter_running_batch() self.has_wait_tokens += 1 else: @@ -78,7 +78,7 @@ class Async_DynamicBatchManager(DynamicBatchManager): else: self.stats_tool.count_output_tokens(self.running_batch) has_new_finished, outputs = self._decode_batch(self.running_batch) - self._filter_runing_batch() + self._filter_running_batch() self.has_wait_tokens += 1 if has_new_finished: diff --git a/colossalai/legacy/inference/manager.py 
b/colossalai/legacy/inference/manager.py index 9672a5014..050dc22b5 100644 --- a/colossalai/legacy/inference/manager.py +++ b/colossalai/legacy/inference/manager.py @@ -131,14 +131,14 @@ class DynamicBatchManager: self.stats_tool.count_prompt_tokens(new_batch) self.running_batch = new_batch yield from self._prefill_batch(self.running_batch) - self._filter_runing_batch() + self._filter_running_batch() self.has_wait_tokens = 0 return if self.has_wait_tokens < self.max_wait_tokens: self.stats_tool.count_output_tokens(self.running_batch) yield from self._decode_batch(self.running_batch) - self._filter_runing_batch() + self._filter_running_batch() self.has_wait_tokens += 1 return else: @@ -154,7 +154,7 @@ class DynamicBatchManager: else: self.stats_tool.count_output_tokens(self.running_batch) yield from self._decode_batch(self.running_batch) - self._filter_runing_batch() + self._filter_running_batch() self.has_wait_tokens += 1 return @@ -243,7 +243,7 @@ class DynamicBatchManager: self._filter_batch(batch) yield from self._output_process(finished_reqs) - def _filter_runing_batch(self): + def _filter_running_batch(self): if self.running_batch is not None and self.running_batch.is_clear(): self.running_batch = None diff --git a/colossalai/nn/optimizer/__init__.py b/colossalai/nn/optimizer/__init__.py index 26f152da2..c7261b1bc 100644 --- a/colossalai/nn/optimizer/__init__.py +++ b/colossalai/nn/optimizer/__init__.py @@ -1,9 +1,36 @@ +from galore_torch import GaLoreAdafactor, GaLoreAdamW + +from .came import CAME from .cpu_adam import CPUAdam +from .distributed_adafactor import DistributedAdaFactor +from .distributed_came import DistributedCAME +from .distributed_galore import DistGaloreAwamW +from .distributed_lamb import DistributedLamb from .fused_adam import FusedAdam from .fused_lamb import FusedLAMB from .fused_sgd import FusedSGD +from .galore import GaLoreAdamW8bit from .hybrid_adam import HybridAdam from .lamb import Lamb from .lars import Lars -__all__ = ["FusedLAMB", "FusedAdam", "FusedSGD", "Lamb", "Lars", "CPUAdam", "HybridAdam"] +from .adafactor import Adafactor # noqa + +__all__ = [ + "FusedLAMB", + "FusedAdam", + "FusedSGD", + "Lamb", + "Lars", + "CPUAdam", + "HybridAdam", + "DistributedLamb", + "DistGaloreAwamW", + "GaLoreAdamW", + "GaLoreAdafactor", + "GaLoreAdamW8bit", + "CAME", + "DistributedCAME", + "Adafactor", + "DistributedAdaFactor", +] diff --git a/colossalai/nn/optimizer/adafactor.py b/colossalai/nn/optimizer/adafactor.py new file mode 100644 index 000000000..22a6c8f4d --- /dev/null +++ b/colossalai/nn/optimizer/adafactor.py @@ -0,0 +1,201 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
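A short single-process usage sketch of the optimizers now exported from colossalai.nn.optimizer (see the __init__.py hunk above). The toy model and hyperparameter values are illustrative; the constructor arguments follow the signatures added in this patch.

import torch
import torch.nn as nn

from colossalai.nn.optimizer import CAME, Adafactor

model = nn.Linear(128, 64)

# Adafactor: with relative_step=True (the default) the learning rate is derived
# from the step count and the parameter RMS, so no explicit lr is passed.
optim_adafactor = Adafactor(model.parameters(), weight_decay=0.0)

# CAME: requires an explicit positive lr (the constructor asserts lr > 0.0).
optim_came = CAME(model.parameters(), lr=2e-4, betas=(0.9, 0.999, 0.9999), weight_decay=0.0)

loss = model(torch.randn(4, 128)).sum()
loss.backward()
optim_adafactor.step()  # or optim_came.step() when training with CAME instead

The distributed variants (DistributedAdaFactor, DistributedCAME, DistGaloreAwamW) additionally expect setup_distributed(...) to be called with the TP/DP process groups; the booster plugins modified in this patch do that automatically after wrapping the optimizer.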
+ +import math + +import torch +from torch.optim import Optimizer + +__all__ = ["Adafactor"] + + +# Adafactor +class Adafactor(Optimizer): + def __init__( + self, + params, + lr=None, + eps=(1e-30, 1e-3), + clip_threshold=1.0, + decay_rate=-0.8, + beta1=None, + weight_decay=0.0, + scale_parameter=True, + relative_step=True, + warmup_init=False, + ): + lr = None + if lr is not None and relative_step: + raise ValueError("Cannot combine manual `lr` and `relative_step=True` options") + if warmup_init and not relative_step: + raise ValueError("`warmup_init=True` requires `relative_step=True`") + + defaults = { + "lr": lr, + "eps": eps, + "clip_threshold": clip_threshold, + "decay_rate": decay_rate, + "beta1": beta1, + "weight_decay": weight_decay, + "scale_parameter": scale_parameter, + "relative_step": relative_step, + "warmup_init": warmup_init, + } + super().__init__(params, defaults) + + @staticmethod + def _get_lr(param_group, param_state): + rel_step_sz = param_group["lr"] + if param_group["relative_step"]: + min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2 + rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"])) + param_scale = 1.0 + if param_group["scale_parameter"]: + param_scale = max(param_group["eps"][1], param_state["RMS"]) + return param_scale * rel_step_sz + + @staticmethod + def _get_options(param_group, param_shape): + factored = len(param_shape) >= 2 + use_first_moment = param_group["beta1"] is not None + return factored, use_first_moment + + @staticmethod + def _rms(tensor): + return tensor.norm(2) / (tensor.numel() ** 0.5) + + @staticmethod + def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col): + r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1) + c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + return torch.mul(r_factor, c_factor) + + @torch.no_grad() + def step(self, closure=None): + """ + Performs a single optimization step + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + loss = None + if closure is not None: + loss = closure() + + """ + param_groups: Dict + { + "params":[weight, bias] + "lr" + "eps" + "clip_threshold" + "decay_rate" + "beta1" + "weight_decay" + "scale_parameter" + "relative_step" + "warmup_init" + } + """ + + for group in self.param_groups: + # update weight & bias + for p in group["params"]: + if p.grad is None: + continue + """ + # grad shape is same as weigh / bias + """ + grad = p.grad + if grad.is_sparse: + raise RuntimeError("Adafactor does not support sparse gradients.") + + """ + p is weight + state + {'step', + 'exp_avg_sq_row', + 'exp_avg_sq_col', + 'RMS' + } + + p is bias + state + {'step', + 'exp_avg_sq', + 'RMS' + } + """ + + state = self.state[p] + grad_shape = grad.shape + + factored, use_first_moment = self._get_options(group, grad_shape) + # State Initialization + if len(state) == 0: + state["step"] = 0 + if use_first_moment: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(grad) + if factored: + state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1], device=grad.device) + state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:], device=grad.device) + else: + state["exp_avg_sq"] = torch.zeros_like(grad) + + state["RMS"] = 0 + else: + if use_first_moment: + state["exp_avg"] = state["exp_avg"] + if factored: + state["exp_avg_sq_row"] = state["exp_avg_sq_row"] + state["exp_avg_sq_col"] = state["exp_avg_sq_col"] + else: + state["exp_avg_sq"] = state["exp_avg_sq"] + + state["step"] += 1 + # state["RMS"] = self._rms(p_data_fp32) + lr = self._get_lr(group, state) + beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) + update = (grad**2) + group["eps"][0] + if factored: + exp_avg_sq_row = state["exp_avg_sq_row"] + exp_avg_sq_col = state["exp_avg_sq_col"] + # Exponential average of row indexes + exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t)) + # Exponential average of columns indexes + exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t)) + # Approximation of exponential moving average of square of gradient + update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update.mul_(grad) + else: + exp_avg_sq = state["exp_avg_sq"] + exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) + update = exp_avg_sq.rsqrt().mul_(grad) + # RMS + update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) + update.mul_(lr) + + if use_first_moment: + exp_avg = state["exp_avg"] + exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"])) + update = exp_avg + + if group["weight_decay"] != 0: + p.add_(p, alpha=(-group["weight_decay"] * lr)) + p.add_(-update) + + return loss diff --git a/colossalai/nn/optimizer/came.py b/colossalai/nn/optimizer/came.py new file mode 100644 index 000000000..3a1a79dff --- /dev/null +++ b/colossalai/nn/optimizer/came.py @@ -0,0 +1,150 @@ +# Copied from https://github.com/yangluo7/CAME/blob/master/came_pytorch/CAME.py +import torch +import torch.optim + + +class CAME(torch.optim.Optimizer): + """Implements CAME algorithm. 
+ This implementation is based on: + `CAME: Confidence-guided Adaptive Memory Efficient Optimization` + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): external learning rate (default: None) + eps (tuple[float, float]): regularization constants for square gradient + and instability respectively (default: (1e-30, 1e-16)) + clip_threshold (float): threshold of root-mean-square of + final gradient update (default: 1.0) + betas (tuple[float, float, float]): coefficient used for computing running averages of + update, square gradient and instability (default: (0.9, 0.999, 0.9999))) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + """ + + def __init__( + self, + params, + lr=None, + eps=(1e-30, 1e-16), + clip_threshold=1.0, + betas=(0.9, 0.999, 0.9999), + weight_decay=0.0, + ): + assert lr > 0.0 + assert all([0.0 <= beta <= 1.0 for beta in betas]) + + defaults = dict( + lr=lr, + eps=eps, + clip_threshold=clip_threshold, + betas=betas, + weight_decay=weight_decay, + ) + super(CAME, self).__init__(params, defaults) + + @property + def supports_memory_efficient_fp16(self): + return True + + @property + def supports_flat_params(self): + return False + + def _get_options(self, param_shape): + factored = len(param_shape) >= 2 + return factored + + def _rms(self, tensor): + return tensor.norm(2) / (tensor.numel() ** 0.5) + + def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col): + r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1) + c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + return torch.mul(r_factor, c_factor) + + def step(self, closure=None): + """Performs a single optimization step. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad + if grad.is_sparse: + raise RuntimeError("CAME does not support sparse gradients.") + + state = self.state[p] + grad_shape = grad.shape + + factored = self._get_options(grad_shape) + # State Initialization + if len(state) == 0: + state["step"] = 0 + + state["exp_avg"] = torch.zeros_like(grad) + if factored: + state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1], dtype=p.dtype, device=p.device) + state["exp_avg_sq_col"] = torch.zeros( + grad_shape[:-2] + grad_shape[-1:], dtype=p.dtype, device=p.device + ) + + state["exp_avg_res_row"] = torch.zeros(grad_shape[:-1], dtype=p.dtype, device=p.device) + state["exp_avg_res_col"] = torch.zeros( + grad_shape[:-2] + grad_shape[-1:], dtype=p.dtype, device=p.device + ) + else: + state["exp_avg_sq"] = torch.zeros_like(p) + + state["step"] += 1 + + update = (grad**2) + group["eps"][0] + + if factored: + exp_avg_sq_row = state["exp_avg_sq_row"] + exp_avg_sq_col = state["exp_avg_sq_col"] + + exp_avg_sq_row.mul_(group["betas"][1]).add_(update.mean(dim=-1), alpha=1.0 - group["betas"][1]) + exp_avg_sq_col.mul_(group["betas"][1]).add_(update.mean(dim=-2), alpha=1.0 - group["betas"][1]) + + # Approximation of exponential moving average of square of gradient + update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update.mul_(grad) + else: + exp_avg_sq = state["exp_avg_sq"] + + exp_avg_sq.mul_(group["betas"][1]).add_(update, alpha=1.0 - group["betas"][1]) + update = exp_avg_sq.rsqrt().mul_(grad) + + update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) + + exp_avg = state["exp_avg"] + exp_avg.mul_(group["betas"][0]).add_(update, alpha=1 - group["betas"][0]) + + # Confidence-guided strategy + # Calculation of instability + res = (update - exp_avg) ** 2 + group["eps"][1] + + if factored: + exp_avg_res_row = state["exp_avg_res_row"] + exp_avg_res_col = state["exp_avg_res_col"] + exp_avg_res_row.mul_(group["betas"][2]).add_(res.mean(dim=-1), alpha=1.0 - group["betas"][2]) + exp_avg_res_col.mul_(group["betas"][2]).add_(res.mean(dim=-2), alpha=1.0 - group["betas"][2]) + + # Approximation of exponential moving average of instability + res_approx = self._approx_sq_grad(exp_avg_res_row, exp_avg_res_col) + update = res_approx.mul_(exp_avg) + else: + update = exp_avg.clone() + + if group["weight_decay"] != 0: + p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"]) + update.mul_(group["lr"]) + p.data.add_(-update) + + return loss diff --git a/colossalai/nn/optimizer/distributed_adafactor.py b/colossalai/nn/optimizer/distributed_adafactor.py new file mode 100644 index 000000000..1e5b7cb93 --- /dev/null +++ b/colossalai/nn/optimizer/distributed_adafactor.py @@ -0,0 +1,440 @@ +import math +from typing import Dict + +import torch +import torch.distributed as dist + +from colossalai.interface.optimizer import DistributedOptim +from colossalai.shardformer.layer._operation import _gather, _split +from colossalai.tensor.d_tensor import get_sharding_spec, is_distributed_tensor + +# DistributedAdaFactor (with Tensor parallel and Zero stage 2) +__all__ = ["DistributedAdaFactor"] + + +class DistributedAdaFactor(DistributedOptim): + def __init__( + self, + params, + lr=None, + eps=(1e-30, 1e-3), + clip_threshold=1.0, + decay_rate=-0.8, + beta1=None, + weight_decay=0.0, + scale_parameter=True, + relative_step=True, + warmup_init=False, + ): + lr = None + if lr is not None 
and relative_step: + raise ValueError("Cannot combine manual `lr` and `relative_step=True` options") + if warmup_init and not relative_step: + raise ValueError("`warmup_init=True` requires `relative_step=True`") + + defaults = { + "lr": lr, + "eps": eps, + "clip_threshold": clip_threshold, + "decay_rate": decay_rate, + "beta1": beta1, + "weight_decay": weight_decay, + "scale_parameter": scale_parameter, + "relative_step": relative_step, + "warmup_init": warmup_init, + } + self.tp_size = 1 + self.tp_group = None + self.dp_size = 1 + self.dp_group = None + self.shard_to_working_param = None # Dict{id:shape}, sample {id(param): torch.tensor} + self.use_zero = True + + self.param_is_dtensor_dict = {} # {id(p): True/False} + self.grad_shape_dict = {} # {id(p): master param shape} + self.factored_dict = {} # {id(p): True/False} + self.use_first_moment_dict = {} # {id(p): True/False} + self.shard_spec_dict = {} # {id(p): ShardSpec} + super().__init__(params, defaults) + + def setup_distributed( + self, + tp_group: dist.ProcessGroup = None, + dp_group: dist.ProcessGroup = None, + shard_to_working_param: Dict = {}, + padding_map=None, + use_zero: bool = True, + ) -> None: + """Setup process groups for TP and ZeRO 2. + Inject features to the Optimizer + + Args: + tp_group: The devices group for tensor parallel; + dp_group: The devices group for data parallel; + shard_to_working_param (Dict): ZeRO 2 feeds the optimizer a sharded param view as grads are sharded. + This maps from id(view) to working params used in forward & backward. + padding_map: An empty interface placeholder; + use_zero: Whether or not to use zero; + + """ + self.tp_group = tp_group # "Expected row process group" + self.dp_group = dp_group + if self.tp_group is not None: + self.tp_size = dist.get_world_size(self.tp_group) + if self.dp_group is not None: + self.dp_size = dist.get_world_size(self.dp_group) + self.use_zero = use_zero + + self.shard_to_working_param = shard_to_working_param if shard_to_working_param is not None else {} + # grad is None, cause we dont setup now + for group in self.param_groups: + for p in group["params"]: + self.shard_to_working_param[id(p)] = self.shard_to_working_param.get( + id(p), p + ) # If not ZeRO, working param is master param + self.param_is_dtensor_dict[id(p)] = is_distributed_tensor(self.shard_to_working_param[id(p)]) + self.grad_shape_dict[id(p)] = self.shard_to_working_param.get(id(p)).shape + self.factored_dict[id(p)], self.use_first_moment_dict[id(p)] = self._get_options( + group, self.grad_shape_dict[id(p)] + ) + if self.param_is_dtensor_dict[id(p)]: + self.shard_spec_dict[id(p)] = get_sharding_spec(self.shard_to_working_param[id(p)]) + else: + self.shard_spec_dict[id(p)] = None + + @staticmethod + def _get_lr(param_group, param_state): + rel_step_sz = param_group["lr"] + if param_group["relative_step"]: + min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2 + rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"])) + param_scale = 1.0 + if param_group["scale_parameter"]: + param_scale = max(param_group["eps"][1], param_state["RMS"]) + return param_scale * rel_step_sz + + @staticmethod + def _get_options(param_group, param_shape): + """ + Determines whether the current param is factored + Args: + param_group : param group + param_shape : Original Shape of param + + """ + factored = len(param_shape) >= 2 + use_first_moment = param_group["beta1"] is not None + return factored, use_first_moment + + @staticmethod + def _rms(tensor, param_is_dtensor, 
use_zero, tp_size, dp_size, tp_group, dp_group): + tensor_sum = tensor.pow(2).sum() + num_of_element = tensor.numel() + + if param_is_dtensor: + # reduce tensor_sum from tp_group + dist.all_reduce(tensor_sum, group=tp_group) + num_of_element = num_of_element * tp_size + if use_zero: + dist.all_reduce(tensor_sum, group=dp_group) + num_of_element = num_of_element * dp_size + rms = (tensor_sum / num_of_element).sqrt() + return rms + + @staticmethod + def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col): + r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1) + c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + return torch.mul(r_factor, c_factor) + + # approx_sq_grad for row parallel weight + @staticmethod + def _approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_meam): + # row_meam = sq_row_meam + r_factor = (exp_avg_sq_row / sq_row_meam).rsqrt_().unsqueeze(-1) + c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + return torch.mul(r_factor, c_factor) + + def _col_parallel_factor(self, update, grad, state, grad_shape, beta2t): + if grad_shape[0] % self.dp_size != 0: + # gather update[flatten] along dp group then reshape to [H, W/tp] + update = _gather(input_=update, dim=-1, process_group=self.dp_group) + update_reshape = update.view(-1, grad_shape[1]) + # gather grad[flatten] along dp group then reshape to [H, W/tp] + grad = _gather(input_=grad, dim=-1, process_group=self.dp_group) + grad_reshape = grad.view(-1, grad_shape[1]) + exp_avg_sq_row = state["exp_avg_sq_row"] # [H] + exp_avg_sq_col = state["exp_avg_sq_col"] # [W/tp] + exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) + update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update_reshape.mul_(grad_reshape) + else: + update_reshape = update.view(-1, grad_shape[1]) + grad_reshape = grad.view(-1, grad_shape[1]) + exp_avg_sq_row = state["exp_avg_sq_row"] # [H/dp] + exp_avg_sq_col = state["exp_avg_sq_col"] # [W/tp] + exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) + dist.all_reduce(exp_avg_sq_row, group=self.tp_group) + exp_avg_sq_row.div_(self.tp_size) + update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update_reshape.mul_(grad_reshape) + + if self.use_zero: + update = update_reshape.view(-1) + else: + update = update_reshape + return update + + def _row_parallel_factor(self, update, grad, state, grad_shape, beta2t): + if grad_shape[0] % self.dp_size != 0: + # gather update[flatten] along dp group then reshape to [H/tp, W] + update = _gather(input_=update, dim=-1, process_group=self.dp_group) + # view update to origin[tp] shape + update_reshape = update.view(-1, grad_shape[1]) + # gather grad[flatten] along dp group then reshape to [H/tp, W] + grad = _gather(input_=grad, dim=-1, process_group=self.dp_group) + grad_reshape = grad.view(-1, grad_shape[1]) + exp_avg_sq_row = state["exp_avg_sq_row"] # [H/tp] + exp_avg_sq_col = state["exp_avg_sq_col"] # [W] + exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) + # reduce col + dist.all_reduce(exp_avg_sq_col, group=self.tp_group) + exp_avg_sq_col.div_(self.tp_size) + update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + 
update_reshape.mul_(grad_reshape) + if self.use_zero: + update = _split(input_=update_reshape.view(-1), dim=-1, process_group=self.dp_group) + else: + update = update_reshape + else: + update_reshape = update.view(-1, grad_shape[1]) + grad_reshape = grad.view(-1, grad_shape[1]) + exp_avg_sq_row = state["exp_avg_sq_row"] # [H/dp/tp] + exp_avg_sq_col = state["exp_avg_sq_col"] # [W] + exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) + # reduce col + dist.all_reduce(exp_avg_sq_col, group=self.tp_group) + exp_avg_sq_col.div_(self.tp_size) + # gather row + exp_avg_sq_row_gather = _gather(input_=exp_avg_sq_row, dim=-1, process_group=self.tp_group) + sq_row_meam = exp_avg_sq_row_gather.mean(dim=-1, keepdim=True) + update_reshape = self._approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_meam) + update_reshape.mul_(grad_reshape) + if self.use_zero: + update = update_reshape.view(-1) + else: + update = update_reshape + return update + + def _base_factor(self, update, grad, state, grad_shape, beta2t): + if self.use_zero: + # only zero + if grad_shape[0] % self.dp_size != 0: + # view update to origin shape update.view(grad_shape[0]//self.data_parallel_size , grad_shape[1]) + # row mean no change + # col mean need reduce and div + # gather update[flatten] along dp group then reshape to [H, W] + update = _gather(input_=update, dim=-1, process_group=self.dp_group) + # view update to origin[tp] shape + update_reshape = update.view(-1, grad_shape[1]) + # gather grad[flatten] along dp group then reshape to [H, W] + grad = _gather(input_=grad, dim=-1, process_group=self.dp_group) + grad_reshape = grad.view(-1, grad_shape[1]) + exp_avg_sq_row = state["exp_avg_sq_row"] # [H/dp] + exp_avg_sq_col = state["exp_avg_sq_col"] # [W] + exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) + # reduce col + dist.all_reduce(exp_avg_sq_col, group=self.tp_group) + exp_avg_sq_col.div_(self.tp_size) + update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update_reshape.mul_(grad_reshape) + update = _split(input_=update_reshape.view(-1), dim=-1, process_group=self.dp_group) + else: + # no residual row + # view update to origin[tp] shape + update_reshape = update.view(-1, grad_shape[1]) # [H/dp, W] + grad_reshape = grad.view(-1, grad_shape[1]) # [H/dp, W] + exp_avg_sq_row = state["exp_avg_sq_row"] # [H/tp] + exp_avg_sq_col = state["exp_avg_sq_col"] # [W] + exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) + # reduce col + dist.all_reduce(exp_avg_sq_col, group=self.tp_group) + exp_avg_sq_col.div_(self.tp_size) + update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update_reshape.mul_(grad_reshape) + update = update_reshape.view(-1) + else: + # base factor; no tp, no dp + exp_avg_sq_row = state["exp_avg_sq_row"] + exp_avg_sq_col = state["exp_avg_sq_col"] + # Exponential average of row indexes + exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t)) + # Exponential average of columns indexes + exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t)) + # Approximation of exponential moving average of square of gradient + update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + 
update.mul_(grad) + return update + + @torch.no_grad() + def step(self, closure=None): + """ + Performs a single optimization steps + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + """ + param_groups: Dict + { + "params":[weight, bias] + "lr" + "eps" + "clip_threshold" + "decay_rate" + "beta1" + "weight_decay" + "scale_parameter" + "relative_step" + "warmup_init" + } + """ + for group in self.param_groups: + # update weight & bias + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad + if grad.is_sparse: + raise RuntimeError("Adafactor does not support sparse gradients.") + + state = self.state[p] + grad_shape = self.grad_shape_dict[id(p)] + param_is_dtensor = self.param_is_dtensor_dict[id(p)] + if param_is_dtensor: + grad_shape = self.shard_to_working_param.get(id(p)).shape # tp shape (2 dim) + factored, use_first_moment = self.factored_dict[id(p)], self.use_first_moment_dict[id(p)] + + shard_spec = self.shard_spec_dict[id(p)] + if len(state) == 0: + state["step"] = 0 + if use_first_moment: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p) + if factored: + if param_is_dtensor: + if shard_spec.sharding_sequence[0] == "R": # Col Parallel + if grad_shape[0] % self.dp_size != 0: + state["exp_avg_sq_row"] = torch.zeros( + grad_shape[0], device=p.device, dtype=p.dtype + ) # [H] + else: + state["exp_avg_sq_row"] = torch.zeros( + grad_shape[0] // self.dp_size, device=p.device, dtype=p.dtype + ) # [H/dp] + state["exp_avg_sq_col"] = torch.zeros( + grad_shape[1], device=p.device, dtype=p.dtype + ) # [W/TP] + + if shard_spec.sharding_sequence[-1] == "R": # Row Parallel + # Row indivisible shape situation + if grad_shape[0] % self.dp_size != 0: + state["exp_avg_sq_row"] = torch.zeros( + grad_shape[0], device=p.device, dtype=p.dtype + ) # [H/tp] + else: + state["exp_avg_sq_row"] = torch.zeros( + grad_shape[0] // self.dp_size, device=p.device, dtype=p.dtype + ) # [H/dp/tp] + + state["exp_avg_sq_col"] = torch.zeros( + grad_shape[1], device=p.device, dtype=p.dtype + ) # [W] + else: + if self.use_zero: + if grad_shape[0] % self.dp_size != 0: + # save all exp_avg_sq_row [H] + state["exp_avg_sq_row"] = torch.zeros( + grad_shape[0], device=grad.device, dtype=p.dtype + ) + else: + # exp_avg_sq_row [H // dp] + state["exp_avg_sq_row"] = torch.zeros( + grad_shape[0] // self.dp_size, device=grad.device, dtype=p.dtype + ) + else: + # exp_avg_sq_row [H] + state["exp_avg_sq_row"] = torch.zeros(grad_shape[0], device=grad.device, dtype=p.dtype) + # exp_avg_sq_col alaways [W] + state["exp_avg_sq_col"] = torch.zeros(grad_shape[1], device=grad.device, dtype=p.dtype) + else: + state["exp_avg_sq"] = torch.zeros_like(p) + state["RMS"] = 0 + else: + if use_first_moment: + state["exp_avg"] = state["exp_avg"] + if factored: + state["exp_avg_sq_row"] = state["exp_avg_sq_row"] + state["exp_avg_sq_col"] = state["exp_avg_sq_col"] + else: + state["exp_avg_sq"] = state["exp_avg_sq"] + + state["step"] += 1 + lr = self._get_lr(group, state) + beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) + update = (grad**2) + group["eps"][0] + + if factored: + if param_is_dtensor: + # ============================== + # First Dim is R, Last Dim is S{} means split dim -1 ---> + # Coloum Parallel ---> sq_row need Do (col) Reduce + # ============================== + if shard_spec.sharding_sequence[0] == "R": + update = self._col_parallel_factor(update, grad, 
state, grad_shape, beta2t) + # ============================== + # Last Dim is R, First Dim is S{} means split dim 0 ---> + # Row Parallel ---> sq_col need Do (row) Reduce + # ============================== + elif shard_spec.sharding_sequence[-1] == "R": + update = self._row_parallel_factor(update, grad, state, grad_shape, beta2t) + else: + update = self._base_factor(update, grad, state, grad_shape, beta2t) + else: + exp_avg_sq = state["exp_avg_sq"] + exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) + update = exp_avg_sq.rsqrt().mul_(grad) + + # # (Line No.8) RMS + rms = self._rms( + update, + param_is_dtensor, + self.use_zero, + self.tp_size, + self.dp_size, + self.tp_group, + self.dp_group, + ) + update.div_((rms / group["clip_threshold"]).clamp_(min=1.0)) + + update.mul_(lr) + if use_first_moment: + exp_avg = state["exp_avg"] + exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"])) + update = exp_avg + + if group["weight_decay"] != 0: + p.add_(p, alpha=(-group["weight_decay"] * lr)) + + p.add_(-update) + + return loss diff --git a/colossalai/nn/optimizer/distributed_came.py b/colossalai/nn/optimizer/distributed_came.py new file mode 100644 index 000000000..d93ec4982 --- /dev/null +++ b/colossalai/nn/optimizer/distributed_came.py @@ -0,0 +1,557 @@ +from typing import Dict + +import torch +import torch.distributed as dist + +from colossalai.interface.optimizer import DistributedOptim +from colossalai.shardformer.layer._operation import _gather, _split +from colossalai.tensor.d_tensor import get_sharding_spec, is_distributed_tensor + + +class DistributedCAME(DistributedOptim): + """Implements CAME algorithm. + This implementation is based on: + `CAME: Confidence-guided Adaptive Memory Efficient Optimization` + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): external learning rate (default: None) + eps (tuple[float, float]): regularization constants for square gradient + and instability respectively (default: (1e-30, 1e-16)) + clip_threshold (float): threshold of root-mean-square of + final gradient update (default: 1.0) + betas (tuple[float, float, float]): coefficient used for computing running averages of + update, square gradient and instability (default: (0.9, 0.999, 0.9999))) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + """ + + def __init__( + self, + params, + lr=None, + eps=(1e-30, 1e-16), + clip_threshold=1.0, + betas=(0.9, 0.999, 0.9999), + weight_decay=0.0, + ): + assert lr > 0.0 + assert all([0.0 <= beta <= 1.0 for beta in betas]) + + defaults = dict( + lr=lr, + eps=eps, + clip_threshold=clip_threshold, + betas=betas, + weight_decay=weight_decay, + ) + + self.tp_size = 1 + self.tp_group = None + self.dp_size = 1 + self.dp_group = None + self.shard_to_working_param = None # Dict{id:shape}, sample {id(param): torch.tensor} + self.use_zero = True + + self.param_is_dtensor_dict = {} # {id(p): True/False} + self.grad_shape_dict = {} # {id(p): master param shape} + self.factored_dict = {} # {id(p): True/False} + self.use_first_moment_dict = {} # {id(p): True/False} + self.shard_spec_dict = {} # {id(p): ShardSpec} + + super(DistributedCAME, self).__init__(params, defaults) + + @property + def supports_memory_efficient_fp16(self): + return True + + @property + def supports_flat_params(self): + return False + + def setup_distributed( + self, + tp_group: dist.ProcessGroup = None, + dp_group: dist.ProcessGroup = None, + shard_to_working_param: Dict = {}, + 
padding_map=None, + use_zero: bool = True, + ) -> None: + """ + Inject features to the Optimizer + + Args: + tp_group: The devices group for tensor parallel; + dp_group: The devices group for data parallel; + shard_to_working_param (Dict): ZeRO 2 feeds the optimizer a sharded param view as grads are sharded. + This maps from id(view) to working params used in forward & backward. + padding_map: Interface placeholder + use_zero: Whether or not to use zero; + + """ + self.tp_group = tp_group # "Expected row process group" + self.dp_group = dp_group + if self.tp_group is not None: + self.tp_size = dist.get_world_size(self.tp_group) + if self.dp_group is not None: + self.dp_size = dist.get_world_size(self.dp_group) + self.use_zero = use_zero + + self.shard_to_working_param = shard_to_working_param if shard_to_working_param is not None else {} + # grad is None, cause we dont setup now + for group in self.param_groups: + for p in group["params"]: + # w/o ZeRO: master param = working param + self.shard_to_working_param[id(p)] = self.shard_to_working_param.get(id(p), p) + self.param_is_dtensor_dict[id(p)] = is_distributed_tensor(self.shard_to_working_param[id(p)]) + self.grad_shape_dict[id(p)] = self.shard_to_working_param[id(p)].shape + # Avoid row parallel lead H=1, then factored param is determined as not factored; + if self.param_is_dtensor_dict[id(p)]: + self.shard_spec_dict[id(p)] = get_sharding_spec(self.shard_to_working_param[id(p)]) + if self.shard_spec_dict[id(p)].sharding_sequence[0] == "R": + self.factored_dict[id(p)] = True + elif self.shard_spec_dict[id(p)].sharding_sequence[-1] == "R": + self.factored_dict[id(p)] = True + else: + self.factored_dict[id(p)] = self._get_options(self.grad_shape_dict[id(p)]) + + else: + self.shard_spec_dict[id(p)] = None + self.factored_dict[id(p)] = self._get_options(self.grad_shape_dict[id(p)]) + + @staticmethod + def _get_options(param_shape): + factored = len(param_shape) >= 2 + return factored + + @staticmethod + def _rms(tensor, param_is_dtensor, use_zero, tp_size, dp_size, tp_group, dp_group): + tensor_sum = tensor.pow(2).sum() + num_of_element = tensor.numel() + + if param_is_dtensor: + # reduce tensor_sum from tp_group + dist.all_reduce(tensor_sum, group=tp_group) + num_of_element = num_of_element * tp_size + if use_zero: + dist.all_reduce(tensor_sum, group=dp_group) + num_of_element = num_of_element * dp_size + rms = (tensor_sum / num_of_element).sqrt() + return rms + + @staticmethod + def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col): + r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1) + c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + return torch.mul(r_factor, c_factor) + + # approx_sq_grad for row parallel weight + @staticmethod + def _approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_meam): + r_factor = (exp_avg_sq_row / sq_row_meam).rsqrt_().unsqueeze(-1) + c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + return torch.mul(r_factor, c_factor) + + def _col_parallel_factor(self, update, grad, state_row, state_col, grad_shape, beta2t): + if grad_shape[0] % self.dp_size != 0: + # gather update[flatten] along dp group then reshape to [H, W/tp] + update = _gather(input_=update, dim=-1, process_group=self.dp_group) + update_reshape = update.view(-1, grad_shape[1]) + # gather grad[flatten] along dp group then reshape to [H, W/tp] + grad = _gather(input_=grad, dim=-1, process_group=self.dp_group) + grad_reshape = grad.view(-1, grad_shape[1]) + exp_avg_sq_row = state_row # [H] + 
exp_avg_sq_col = state_col # [W/tp] + exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) + update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update_reshape.mul_(grad_reshape) + else: + update_reshape = update.view(-1, grad_shape[1]) + grad_reshape = grad.view(-1, grad_shape[1]) + exp_avg_sq_row = state_row # [H] + exp_avg_sq_col = state_col # [W/tp] + exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) + dist.all_reduce(exp_avg_sq_row, group=self.tp_group) + exp_avg_sq_row.div_(self.tp_size) + update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update_reshape.mul_(grad_reshape) + + if self.use_zero: + update = update_reshape.view(-1) + else: + update = update_reshape + return update + + def _row_parallel_factor(self, update, grad, state_row, state_col, grad_shape, beta2t): + if grad_shape[0] % self.dp_size != 0: + # gather update[flatten] along dp group then reshape to [H/tp, W] + update = _gather(input_=update, dim=-1, process_group=self.dp_group) + # view update to origin[tp] shape + update_reshape = update.view(-1, grad_shape[1]) + # gather grad[flatten] along dp group then reshape to [H/tp, W] + grad = _gather(input_=grad, dim=-1, process_group=self.dp_group) + grad_reshape = grad.view(-1, grad_shape[1]) + exp_avg_sq_row = state_row # [H] + exp_avg_sq_col = state_col # [W/tp] + exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) + # reduce col + dist.all_reduce(exp_avg_sq_col, group=self.tp_group) + exp_avg_sq_col.div_(self.tp_size) + update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update_reshape.mul_(grad_reshape) + if self.use_zero: + update = _split(input_=update_reshape.view(-1), dim=-1, process_group=self.dp_group) + else: + update = update_reshape + else: + update_reshape = update.view(-1, grad_shape[1]) + grad_reshape = grad.view(-1, grad_shape[1]) + exp_avg_sq_row = state_row # [H] + exp_avg_sq_col = state_col # [W/tp] + exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) + # reduce col + dist.all_reduce(exp_avg_sq_col, group=self.tp_group) + exp_avg_sq_col.div_(self.tp_size) + # gather row + exp_avg_sq_row_gather = _gather(input_=exp_avg_sq_row, dim=-1, process_group=self.tp_group) + sq_row_meam = exp_avg_sq_row_gather.mean(dim=-1, keepdim=True) + update_reshape = self._approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_meam) + update_reshape.mul_(grad_reshape) + if self.use_zero: + update = update_reshape.view(-1) + else: + update = update_reshape + return update + + def _base_factor(self, update, grad, state_row, state_col, grad_shape, beta2t): + if self.use_zero: + # only zero + # [30522, 128], [2, 128] + if grad_shape[0] % self.dp_size != 0: + # view update to origin shape update.view(grad_shape[0]//self.data_parallel_size , grad_shape[1]) + # row mean no change + # col mean need reduce and div + # gather update[flatten] along dp group then reshape to [H, W] + update = _gather(input_=update, dim=-1, process_group=self.dp_group) + # view update to origin[tp] shape + update_reshape = update.view(-1, grad_shape[1]) + # gather 
grad[flatten] along dp group then reshape to [H, W] + grad = _gather(input_=grad, dim=-1, process_group=self.dp_group) + grad_reshape = grad.view(-1, grad_shape[1]) + exp_avg_sq_row = state_row # [H/dp] + exp_avg_sq_col = state_col # [W] + exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) + # reduce col + dist.all_reduce(exp_avg_sq_col, group=self.tp_group) + exp_avg_sq_col.div_(self.tp_size) + update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update_reshape.mul_(grad_reshape) + update = _split(input_=update_reshape.view(-1), dim=-1, process_group=self.dp_group) + else: + # no residual row + # view update to origin[tp] shape + update_reshape = update.view(-1, grad_shape[1]) # [H/dp, W] + grad_reshape = grad.view(-1, grad_shape[1]) # [H/dp, W] + exp_avg_sq_row = state_row # [H/dp] + exp_avg_sq_col = state_col # [W] + exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) + # reduce col + dist.all_reduce(exp_avg_sq_col, group=self.tp_group) + exp_avg_sq_col.div_(self.tp_size) + update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update_reshape.mul_(grad_reshape) + update = update_reshape.view(-1) + else: + # # base factor; no tp, no dp + exp_avg_sq_row = state_row # [H/dp] + exp_avg_sq_col = state_col # [W] + # Exponential average of row indexes + exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t)) + # Exponential average of columns indexes + exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t)) + # Approximation of exponential moving average of square of gradient + update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update.mul_(grad) + return update + + # factor + def _base_res_factor(self, res, exp_avg, state_row, state_col, grad_shape, beta2t): + if self.use_zero: + # only zero + if grad_shape[0] % self.dp_size != 0: + # view res to origin shape res.view(grad_shape[0]//self.data_parallel_size , grad_shape[1]) + # row mean no change + # col mean need reduce and div + # gather res[flatten] along dp group then reshape to [H, W] + res = _gather(input_=res, dim=-1, process_group=self.dp_group) + # view res to origin[tp] shape + res_reshape = res.view(-1, grad_shape[1]) + # gather exp_avg[flatten] along dp group then reshape to [H, W] + exp_avg = _gather(input_=exp_avg, dim=-1, process_group=self.dp_group) + exp_avg_reshape = exp_avg.view(-1, grad_shape[1]) + exp_avg_sq_row = state_row # [H/dp] + exp_avg_sq_col = state_col # [W] + exp_avg_sq_row.mul_(beta2t).add_(res_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(res_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) + # reduce col + dist.all_reduce(exp_avg_sq_col, group=self.tp_group) + exp_avg_sq_col.div_(self.tp_size) + res_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + res_reshape.mul_(exp_avg_reshape) + res = _split(input_=res_reshape.view(-1), dim=-1, process_group=self.dp_group) + else: + # no residual row + # view res to origin[tp] shape + res_reshape = res.view(-1, grad_shape[1]) # [H/dp, W] + exp_avg_reshape = exp_avg.view(-1, grad_shape[1]) # [H/dp, W] + exp_avg_sq_row = state_row # [H/dp] + exp_avg_sq_col = state_col # [W] + exp_avg_sq_row.mul_(beta2t).add_(res_reshape.mean(dim=-1), alpha=(1.0 - beta2t)) + 
exp_avg_sq_col.mul_(beta2t).add_(res_reshape.mean(dim=-2), alpha=(1.0 - beta2t)) + # reduce col + dist.all_reduce(exp_avg_sq_col, group=self.tp_group) + exp_avg_sq_col.div_(self.tp_size) + res_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + res_reshape.mul_(exp_avg_reshape) + res = res_reshape.view(-1) + else: + # # base factor; no tp, no dp + exp_avg_sq_row = state_row # [H/dp] + exp_avg_sq_col = state_col # [W] + # Exponential average of row indexes + exp_avg_sq_row.mul_(beta2t).add_(res.mean(dim=-1), alpha=(1.0 - beta2t)) + # Exponential average of columns indexes + exp_avg_sq_col.mul_(beta2t).add_(res.mean(dim=-2), alpha=(1.0 - beta2t)) + # Approximation of exponential moving average of square of gradient + res = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + res.mul_(exp_avg) + return res + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad + if grad.is_sparse: + raise RuntimeError("CAME does not support sparse gradients.") + + state = self.state[p] + # Under zero the grad_shape is the original grad that is flattened and then cut (only one dimension) + grad_shape = grad.shape + grad_shape = self.grad_shape_dict[id(p)] + param_is_dtensor = self.param_is_dtensor_dict[id(p)] + if param_is_dtensor: + grad_shape = self.shard_to_working_param.get(id(p)).shape # tp shape (2 dim) + factored = self.factored_dict[id(p)] + shard_spec = self.shard_spec_dict[id(p)] + + # State Initialization + if len(state) == 0: + state["step"] = 0 + state["exp_avg"] = torch.zeros_like(p) + if factored: + if param_is_dtensor: + if shard_spec.sharding_sequence[0] == "R": # Col Parallel + if grad_shape[0] % self.dp_size != 0: + state["exp_avg_sq_row"] = torch.zeros( + grad_shape[0], device=p.device, dtype=p.dtype + ) # [H] + state["exp_avg_res_row"] = torch.zeros( + grad_shape[0], device=p.device, dtype=p.dtype + ) # [H] + else: + state["exp_avg_sq_row"] = torch.zeros( + grad_shape[0] // self.dp_size, device=p.device, dtype=p.dtype + ) # [H/dp] + state["exp_avg_res_row"] = torch.zeros( + grad_shape[0] // self.dp_size, device=p.device, dtype=p.dtype + ) # [H/dp] + state["exp_avg_sq_col"] = torch.zeros( + grad_shape[1], device=p.device, dtype=p.dtype + ) # [W/TP] + state["exp_avg_res_col"] = torch.zeros( + grad_shape[1], device=p.device, dtype=p.dtype + ) # [W/TP] + + if shard_spec.sharding_sequence[-1] == "R": # Row Parallel + # Row indivisible shape situation + if grad_shape[0] % self.dp_size != 0: + state["exp_avg_sq_row"] = torch.zeros( + grad_shape[0], device=p.device, dtype=p.dtype + ) # [H/tp] + state["exp_avg_res_row"] = torch.zeros( + grad_shape[0], device=p.device, dtype=p.dtype + ) # [H/tp] + else: + state["exp_avg_sq_row"] = torch.zeros( + grad_shape[0] // self.dp_size, device=p.device, dtype=p.dtype + ) # [H/dp/tp] + state["exp_avg_res_row"] = torch.zeros( + grad_shape[0] // self.dp_size, device=p.device, dtype=p.dtype + ) # [H/dp/tp] + + state["exp_avg_sq_col"] = torch.zeros( + grad_shape[1], device=p.device, dtype=p.dtype + ) # [W] + state["exp_avg_res_col"] = torch.zeros( + grad_shape[1], device=p.device, dtype=p.dtype + ) # [W] + else: + if self.use_zero: + if grad_shape[0] % self.dp_size != 0: + # save all exp_avg_sq_row [H] + state["exp_avg_sq_row"] = 
torch.zeros( + grad_shape[0], device=grad.device, dtype=p.dtype + ) + state["exp_avg_res_row"] = torch.zeros( + grad_shape[0], device=grad.device, dtype=p.dtype + ) + else: + # exp_avg_sq_row [H // dp] + state["exp_avg_sq_row"] = torch.zeros( + grad_shape[0] // self.dp_size, device=grad.device, dtype=p.dtype + ) + state["exp_avg_res_row"] = torch.zeros( + grad_shape[0] // self.dp_size, device=grad.device, dtype=p.dtype + ) + else: + # exp_avg_sq_row [H] + state["exp_avg_sq_row"] = torch.zeros(grad_shape[0], device=grad.device, dtype=p.dtype) + state["exp_avg_res_row"] = torch.zeros(grad_shape[0], device=grad.device, dtype=p.dtype) + # exp_avg_sq_col alaways [W] + state["exp_avg_sq_col"] = torch.zeros(grad_shape[1], device=grad.device, dtype=p.dtype) + state["exp_avg_res_col"] = torch.zeros(grad_shape[1], device=grad.device, dtype=p.dtype) + else: + state["exp_avg_sq"] = torch.zeros_like(p) + state["RMS"] = 0 + else: + if factored: + state["exp_avg_sq_row"] = state["exp_avg_sq_row"] + state["exp_avg_sq_col"] = state["exp_avg_sq_col"] + state["exp_avg_res_row"] = state["exp_avg_sq_row"] + state["exp_avg_res_col"] = state["exp_avg_sq_col"] + else: + state["exp_avg_sq"] = state["exp_avg_sq"] + + state["step"] += 1 + + update = (grad**2) + group["eps"][0] + if factored: + if param_is_dtensor: + # ============================== + # First Dim is R, Last Dim is S{} means split dim -1 ---> + # Coloum Parallel ---> sq_row need Do (col) Reduce + # ============================== + if shard_spec.sharding_sequence[0] == "R": + update = self._col_parallel_factor( + update, + grad, + state["exp_avg_sq_row"], + state["exp_avg_sq_col"], + grad_shape, + group["betas"][1], + ) + # ============================== + # Last Dim is R, First Dim is S{} means split dim 0 ---> + # Row Parallel ---> sq_col need Do (row) Reduce + # ============================== + elif shard_spec.sharding_sequence[-1] == "R": + update = self._row_parallel_factor( + update, + grad, + state["exp_avg_sq_row"], + state["exp_avg_sq_col"], + grad_shape, + group["betas"][1], + ) + else: + update = self._base_factor( + update, + grad, + state["exp_avg_sq_row"], + state["exp_avg_sq_col"], + grad_shape, + group["betas"][1], + ) + else: + exp_avg_sq = state["exp_avg_sq"] + exp_avg_sq.mul_(group["betas"][1]).add_(update, alpha=(1.0 - group["betas"][1])) + update = exp_avg_sq.rsqrt().mul_(grad) + rms = self._rms( + update, + param_is_dtensor, + self.use_zero, + self.tp_size, + self.dp_size, + self.tp_group, + self.dp_group, + ) + + update.div_((rms / group["clip_threshold"]).clamp_(min=1.0)) + + exp_avg = state["exp_avg"] + exp_avg.mul_(group["betas"][0]).add_(update, alpha=1 - group["betas"][0]) + # Confidence-guided strategy + # Calculation of instability + res = (update - exp_avg) ** 2 + group["eps"][1] + if factored: + if param_is_dtensor: + # ============================== + # First Dim is R, Last Dim is S{} means split dim -1 ---> + # Coloum Parallel ---> sq_row need Do (col) Reduce + # ============================== + if shard_spec.sharding_sequence[0] == "R": + update = self._col_parallel_factor( + res, + exp_avg, + state["exp_avg_res_row"], + state["exp_avg_res_col"], + grad_shape, + group["betas"][2], + ) + # ============================== + # Last Dim is R, First Dim is S{} means split dim 0 ---> + # Row Parallel ---> sq_col need Do (row) Reduce + # ============================== + elif shard_spec.sharding_sequence[-1] == "R": + update = self._row_parallel_factor( + res, + exp_avg, + state["exp_avg_res_row"], + state["exp_avg_res_col"], 
+ grad_shape, + group["betas"][2], + ) + else: + update = self._base_res_factor( + res, + exp_avg, + state["exp_avg_res_row"], + state["exp_avg_res_col"], + grad_shape, + group["betas"][2], + ) + else: + update = exp_avg + + if group["weight_decay"] != 0: + p.add_(p, alpha=-group["weight_decay"] * group["lr"]) + update.mul_(group["lr"]) + p.add_(-update) + return loss diff --git a/colossalai/nn/optimizer/distributed_galore.py b/colossalai/nn/optimizer/distributed_galore.py new file mode 100644 index 000000000..3f42dd5b9 --- /dev/null +++ b/colossalai/nn/optimizer/distributed_galore.py @@ -0,0 +1,279 @@ +""" adapted from https://github.com/jiaweizzhao/GaLore/blob/master/galore_torch/adamw8bit.py""" + +import warnings +from collections import defaultdict +from typing import Dict, Optional + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from bitsandbytes.optim.optimizer import Optimizer2State + +from colossalai.interface.optimizer import DistributedOptim +from colossalai.tensor.d_tensor import get_shard_dim_1d, is_distributed_tensor + +from .galore import GaLoreProjector, make_low_rank_buffer + +__all__ = ["DistributedGalore"] +# Mark sharded dimension + + +class DistGaloreAwamW(DistributedOptim, Optimizer2State): + r"""Implements Galore, a optimizer-agonistic gradient compression technique on 8-bit AdamW. + It largely compresses gradient via low-rank projection and is claimed to be insensitive to hyperparams like lr. + Supports Tensor Parallel and ZeRO stage 1 and 2 via booster and plugin. + Proposed in `GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection` + https://arxiv.org/abs/2403.03507 + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its norm. (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-6) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0.01) + nbits: Number of bits for quantization optim states. Only 32 and 8 are supported. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer (handle memory spike via CPU-GPU transfer) or not. + """ + + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + nbits=8, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): + super().__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + nbits, + None, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=is_paged, + ) + self.tp_size = 1 + self.dp_size = 1 + self.is_dist = {} + proj_none = all(["rank" not in group for group in self.param_groups]) + if proj_none: + warnings.warn( + "Will not apply GaLore as rank isn't in any param group. 
If you forgot to, try get_galore_param_groups" + ) + + # Default from the paper + for group in self.param_groups: + if "rank" in group: + group["update_proj_gap"] = group.get("update_proj_gap", 200) + group["proj_type"] = group.get("proj_type", "std") + group["scale"] = group.get("scale", 0.25) + + def setup_distributed( + self, + tp_group: Optional[dist.ProcessGroup] = None, + dp_group: Optional[dist.ProcessGroup] = None, + shard_to_working_param: Optional[Dict] = {}, + padding_map: Optional[Dict] = defaultdict(int), + is_zero: Optional[bool] = False, + ): + """Setup process groups for TP and ZeRO 2. + Arguments: + tp_group (dist.ProcessGroup): Tensor Parallel process group + dp_group (dist.ProcessGroup): ZeRO 2 process group + shard_to_working_param (Dict): ZeRO 2 feeds the optimizer a sharded param view as grads are sharded. + This maps from id(view) to working params used in forward & backward. + padding_map (Dict): Padding size of each param from ZeRO's param store. Required if ZeRO is used. + is_zero (bool): Whether to use ZeRO 2. + """ + assert dist.is_initialized(), "You forgot to initialized distributed backend..." + + self.tp_group = tp_group + self.dp_group = dp_group + if tp_group is not None: + self.tp_size = dist.get_world_size(tp_group) + if dp_group is not None: + self.dp_size = dist.get_world_size(dp_group) + + self.shard_to_working_param = shard_to_working_param if shard_to_working_param is not None else {} + self.is_zero = is_zero and self.dp_size > 1 + self.padding_map = padding_map if padding_map is not None else defaultdict(int) + if is_zero: + assert self.padding_map is not defaultdict( + int + ), "We can't do SVD without knowing ZeRO's per-param padding size" + self.distributed_on = self.tp_size > 0 or self.dp_size > 0 + + # Cache working param layout + self.shard_dim = {} + for group in self.param_groups: + for p in group["params"]: + # w/o ZeRO: master param = working param + self.shard_to_working_param[id(p)] = self.shard_to_working_param.get(id(p), p) + if id(p) not in self.padding_map: + self.padding_map[id(p)] = 0 + + self.is_dist[id(p)] = is_distributed_tensor(self.shard_to_working_param[id(p)]) + if is_distributed_tensor(self.shard_to_working_param[id(p)]): + self.shard_dim[id(p)] = get_shard_dim_1d(self.shard_to_working_param[id(p)]) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + if not self.initialized: + self.check_overrides() + self.to_gpu() + self.initialized = True + + for gindex, group in enumerate(self.param_groups): + for pindex, p in enumerate(group["params"]): + if p.grad is None: + continue + state = self.state[p] + + if "step" not in state: + state["step"] = 0 + + # GaLore Projection + if "rank" in group: + if "projector" not in state: + state["projector"] = GaLoreProjector( + group["rank"], + scale=group["scale"], + update_proj_gap=group["update_proj_gap"], + proj_type=group["proj_type"], + ) + # decoupled weight decay + if "weight_decay" in group and group["weight_decay"] > 0: + group["weight_decay_saved"] = group["weight_decay"] + group["weight_decay"] = 0 + + grad = p.grad + working_shape = list(self.shard_to_working_param[id(p)].shape) + padding = self.padding_map[id(p)] + + # All-gather grads for projection step + if self.distributed_on: + # Gather for ZeRO 1 & 2 implementation don't retain full grads + if self.is_zero: + # (m, n).flatten().chunk(dp_size) equals to (m / dp_size, n).flatten() + working_shape[0] //= self.dp_size + # Gather grads for projection + if state["step"] % group["update_proj_gap"] == 0: + all_grads = [ + torch.empty_like(grad, dtype=p.grad.dtype, device=p.grad.device) + for _ in range(self.dp_size) + ] + dist.all_gather(all_grads, grad, self.dp_group) + grad = torch.cat(all_grads) + # To working param shape + if padding > 0: + grad = grad[:-padding] + working_shape[0] *= self.dp_size + grad = grad.reshape(working_shape) # unflatten + + # Gather TP grads + if self.is_dist[id(p)] and state["step"] % group["update_proj_gap"] == 0: + all_grads = [ + torch.empty_like(grad, dtype=p.grad.dtype, device=p.grad.device) + for _ in range(self.tp_size) + ] + dist.all_gather(all_grads, grad.contiguous(), self.tp_group) + grad = torch.cat(all_grads, dim=self.shard_dim[id(p)]) + + # Compute SVD. Will use a subset of singular vectors when grads are sharded. + grad = state["projector"].project(grad, state["step"]) + + # Re-shard gathered grads after SVD + if self.distributed_on and state["step"] % group["update_proj_gap"] == 0: + # TP + if self.is_dist[id(p)]: + grad = grad.chunk(self.tp_size, dim=self.shard_dim[id(p)])[dist.get_rank(self.tp_group)] + # ZeRO + # TODO: this might not work with padding, e.g. 
(3, 3) with dp size 2 + # Need extra logic in ZeRO to pad nRows/nCols to be divisible by dp_size + if self.is_zero: + grad = grad.chunk(self.dp_size)[dist.get_rank(self.dp_group)] + grad = grad.contiguous() # avoid bitsandbytes update error + + working_shape = grad.shape + # To flattended master param shape + grad = self.to_master_shape(grad, padding) + make_low_rank_buffer(p, grad) + + if "state1" not in state: + self.init_state(group, p, gindex, pindex) + + self.prefetch_state(p) + self.update_step(group, p, gindex, pindex) + torch.cuda.synchronize() + + # Project Back to working param shape + if "rank" in group: + # Unpad + if self.is_zero: + if padding > 0: + p.data = p.data[:-padding] + p.data = p.data.reshape(working_shape) + + p.data = state["projector"].project_back(p.data) + # Re-flatten grads for ZeRO + p.data = self.to_master_shape(p.data, padding) + p.data = p.saved_data.add_(p.data) + + # apply decoupled weight decay + if "weight_decay_saved" in group: + p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay_saved"]) + group["weight_decay"] = group["weight_decay_saved"] + del group["weight_decay_saved"] + + if self.is_paged: + # all paged operation are asynchronous, we need + # to sync to make sure all tensors are in the right state + torch.cuda.synchronize() + return loss + + def to_master_shape(self, data, padding): + """Pad to master (optimizer) param shape""" + if not self.is_zero: + return data + data = data.view(-1) + if padding > 0: + data = F.pad(data, [0, padding]) + return data + + def __del__(self): + """Avoid buffer memory leak""" + for group in self.param_groups: + for p in group["params"]: + if hasattr(p, "saved_data"): + del p.saved_data diff --git a/colossalai/nn/optimizer/distributed_lamb.py b/colossalai/nn/optimizer/distributed_lamb.py new file mode 100644 index 000000000..c9ab8feab --- /dev/null +++ b/colossalai/nn/optimizer/distributed_lamb.py @@ -0,0 +1,181 @@ +# Disclaimer: Modified from https://github.com/NUS-HPC-AI-Lab/pytorch-lamb/blob/master/optim/lamb.py + + +from typing import Dict, Optional + +import torch +import torch.distributed as dist + +from colossalai.interface.optimizer import DistributedOptim +from colossalai.tensor.d_tensor import is_distributed_tensor + +__all__ = ["DistributedLamb"] + + +class DistributedLamb(DistributedOptim): + r"""Implements the Lamb algorithm, with extra support for ZeRO 2 and Tensor Parallel. + Proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. + It's recommended to use this with HybridParallelPlugin/ZeRO plugin and booster, + which will take care of setup_distributed. + Example with 4 devices: + >>> optim = DistributedLamb(model.parameters(), lr=1e-3) + >>> proc_mesh = ProcessGroupMesh(tp_size, zero_size) + >>> tp_group = proc_mesh.get_group_along_axis(0) + >>> dp_group = proc_mesh.get_group_along_axis(1) + >>> optim.setup_distributed(tp_group, dp_group) + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + .. 
_Large Batch Optimization for Deep Learning: Training BERT in 76 minutes: + https://arxiv.org/abs/1904.00962 + """ + + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-6, + weight_decay=0, + bias_correction=True, + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + + # self.setup_distributed(tp_group, dp_group) + self.shard_to_working_param = {} + self.tp_size = self.dp_size = 1 + self.is_zero = False + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction) + super().__init__(params, defaults) + + def setup_distributed( + self, + tp_group: Optional[dist.ProcessGroup] = None, + dp_group: Optional[dist.ProcessGroup] = None, + shard_to_working_param: Optional[Dict] = {}, + padding_map=None, + is_zero: Optional[bool] = False, + ): + """Assign process groups for TP and ZeRO 2. + Arguments: + tp_group (dist.ProcessGroup): Tensor Parallel process group + dp_group (dist.ProcessGroup): ZeRO 2 process group + shard_to_working_param (Dict): ZeRO 2 feeds the optimizer a sharded param view as grads are sharded. + This maps from id(view) to working params used in forward & backward. + padding_map: An empty interface placeholder + is_zero (bool): Whether to use ZeRO 2. + """ + self.tp_group = tp_group + self.dp_group = dp_group + if tp_group is not None: + self.tp_size = dist.get_world_size(tp_group) + if dp_group is not None: + self.dp_size = dist.get_world_size(dp_group) + + self.shard_to_working_param = shard_to_working_param if shard_to_working_param is not None else {} + self.is_zero = is_zero + self.is_dist = {} + # Cache parameter layout + for group in self.param_groups: + for p in group["params"]: + # w/o ZeRO: master param = working param + self.shard_to_working_param[id(p)] = self.shard_to_working_param.get(id(p), p) + self.is_dist[p] = ( + is_distributed_tensor(p) + if self.dp_size <= 1 + else is_distributed_tensor(self.shard_to_working_param.get(id(p), None)) + ) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError("Lamb does not support sparse gradients, consider SparseAdam instad.") + + state = self.state[p] + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(p.data) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] + + state["step"] += 1 + + # Decay the first and second moment running average coefficient + # m_t + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + # v_t + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + scaled_lr = group["lr"] + if group["bias_correction"]: + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + # Apply debiasing to lr to avoid broadcast + scaled_lr *= (bias_correction2**0.5) / bias_correction1 + # exp_avg.div_(bias_correction1) + # exp_avg_sq.div_(bias_correction2) + + update = exp_avg / exp_avg_sq.sqrt().add(group["eps"]) + if group["weight_decay"] != 0: + update.add_(p.data, alpha=group["weight_decay"]) + + # Compute global layer-wise trust ratio + if self.is_dist[p] or self.is_zero: + p_local = p + g_sum = (update**2).sum() + if self.dp_size > 1 and self.is_zero: + # ZeRO 2 doesn't shard param. Compute full param norm w/o communication. + dist.all_reduce(g_sum, group=self.dp_group) + p_local = self.shard_to_working_param[id(p)] + + w_sum = (p_local**2).sum() + sums = torch.stack([w_sum, g_sum]) + + # Get global l2 norms + if self.tp_size > 1: + dist.all_reduce(sums, group=self.tp_group) + w_norm, g_norm = sums.sqrt().chunk(2) + else: + # Fall back to vanilla Lamb + w_norm = torch.norm(p) + g_norm = torch.norm(update) + + trust_ratio = torch.where(w_norm > 0 and g_norm > 0, (w_norm / g_norm), 1.0).item() + + scaled_lr *= trust_ratio + p.data.add_(update, alpha=-scaled_lr) + + return loss diff --git a/colossalai/nn/optimizer/galore.py b/colossalai/nn/optimizer/galore.py new file mode 100644 index 000000000..f7556fe61 --- /dev/null +++ b/colossalai/nn/optimizer/galore.py @@ -0,0 +1,315 @@ +""" adapted from https://github.com/jiaweizzhao/GaLore/blob/master/galore_torch/adamw8bit.py""" + +import warnings +from typing import List + +import torch +from bitsandbytes.optim.optimizer import Optimizer2State +from torch._C import _LinAlgError + + +def get_galore_param_groups( + model, weight_decay, rank=256, update_proj_gap=200, scale=0.25, proj_type="std" +) -> List[dict]: + """ + It's advised to use this instead of manually specifying which param groups + to apply GaLore on. + """ + galore_params = [] + non_galore = [] + no_decay_params = [] + no_decay = ["bias", "LayerNorm.weight"] + + for name, param in model.named_parameters(): + # Only make sense to do SVD on 2d gradient matrices + # e.g. nn.Linear, VocabEmbedding, etc. 
+ if any(nd in name for nd in no_decay): + no_decay_params.append(param) + elif param.dim() == 2: + galore_params.append(param) + else: + non_galore.append(param) + + param_groups = [ + { + "params": galore_params, + "rank": rank, + "update_proj_gap": update_proj_gap, + "scale": scale, + "proj_type": proj_type, + "weight_decay": weight_decay, + }, + {"params": non_galore, "weight_decay": weight_decay}, + {"params": no_decay_params, "weight_decay": 0.0}, + ] + + return param_groups + + +def make_low_rank_buffer(p, grad): + """For compatibility with bitsandbytes's update_step, we need an empty low-rank + param update buffer to avoid mutating original params. + TODO: optimize by reusing the memory for p.grad? Need to modify bitsandbytes? + """ + p.saved_data = p.data.clone() + # p.data = grad.clone().to(p.data.dtype).to(p.data.device) + p.data = torch.zeros_like(grad, device=grad.device, dtype=grad.dtype) + # p.data.zero_() + p.grad = grad + + +class GaLoreProjector: + def __init__(self, rank, verbose=False, update_proj_gap=200, scale=1.0, proj_type="std"): + self.rank = rank + self.verbose = verbose + self.update_proj_gap = update_proj_gap + self.scale = scale + self.ortho_matrix = None + self.proj_type = proj_type + self.svd_type = None + + def project(self, full_rank_grad, iter): + dim = full_rank_grad.dim() + if dim != 2: + warnings.warn( + f"Warning: You shouldn't specify projection rank for {dim}D params in param_groups. Skipping SVD." + ) + return full_rank_grad + + m, n = full_rank_grad.shape # For ZeRO sharded grads + if self.proj_type == "std": + # Project the lower dim to minimize information loss + if self.svd_type is None: + self.svd_type = "right" if m >= n else "left" + # SVD step + if self.ortho_matrix is None or iter % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type=self.svd_type) + if self.svd_type == "right": + low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t()[:n]) + else: + low_rank_grad = torch.matmul(self.ortho_matrix.t()[:, :m], full_rank_grad) + + elif self.proj_type == "reverse_std": + if self.svd_type is None: + self.svd_type = "left" if m >= n else "right" + # SVD step + if self.ortho_matrix is None or iter % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type=self.svd_type) + + if self.svd_type == "left": + low_rank_grad = torch.matmul(self.ortho_matrix.t()[:, :m], full_rank_grad) + else: + low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t()[:n]) + return low_rank_grad + + def project_back(self, low_rank_grad): + if low_rank_grad.dim() != 2: + return + + m, n = low_rank_grad.shape + if self.svd_type == "right": + full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix[:n]) + else: + full_rank_grad = torch.matmul(self.ortho_matrix[:, :m], low_rank_grad) + + return full_rank_grad * self.scale + + # svd decomposition + def get_orthogonal_matrix(self, weights, rank, type): + module_params = weights + + if module_params.data.dtype != torch.float: + float_data = False + original_type = module_params.data.dtype + original_device = module_params.data.device + matrix = module_params.data.float() + else: + float_data = True + matrix = module_params.data + + # TODO: redo SVD in the next step. 
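# What the surrounding SVD is used for, in one compact sketch: keep the top-`rank` singular
# vectors of the gradient, run the inner Adam update in that low-rank subspace, then map the
# result back (and scale it). The standalone function below is illustrative only; it mirrors
# the "std" projection path of this class on a single, unsharded gradient:
import torch

def galore_project_sketch(grad: torch.Tensor, rank: int):
    m, n = grad.shape
    U, S, Vh = torch.linalg.svd(grad.float(), full_matrices=False)
    if m >= n:  # project from the right; the low-rank buffer is [m, rank]
        ortho = Vh[:rank, :]                  # [rank, n]
        low_rank = grad.float() @ ortho.T     # project
        restored = low_rank @ ortho           # project_back (before scaling)
    else:       # project from the left; the low-rank buffer is [rank, n]
        ortho = U[:, :rank]                   # [m, rank]
        low_rank = ortho.T @ grad.float()
        restored = ortho @ low_rank
    return low_rank, restored.to(grad.dtype)

# In the optimizer, the projected-back quantity is the low-rank *update* (not the raw
# gradient), it is multiplied by group["scale"], and the orthogonal matrix is only
# recomputed every `update_proj_gap` steps rather than at every iteration.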
+ if matrix.isnan().any(): + print(f"{__file__}: skipping SVD due to NaN matrix") + return self.ortho_matrix + try: + U, s, Vh = torch.linalg.svd(matrix, full_matrices=False) + except _LinAlgError as e: + print(f"{__file__}: skipping SVD due to {e}") + return self.ortho_matrix + + # make the smaller matrix always to be orthogonal matrix + if type == "right": + B = Vh[:rank, :] + + if not float_data: + B = B.to(original_device).type(original_type) + return B + elif type == "left": + A = U[:, :rank] + if not float_data: + A = A.to(original_device).type(original_type) + return A + elif type == "full": + A = U[:, :rank] + B = Vh[:rank, :] + if not float_data: + A = A.to(original_device).type(original_type) + B = B.to(original_device).type(original_type) + return [A, B] + else: + raise ValueError("type should be left, right or full") + + +class GaLoreAdamW8bit(Optimizer2State): + r"""Implements Galore, a optimizer-agonistic gradient compression technique on 8-bit AdamW. + Proposed in `GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection`. It compresses + gradient via low-rank projection and is claimed to be insensitive to hyperparams like lr. + https://arxiv.org/abs/2403.03507 + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its norm. (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-6) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0.01) + nbits (int): The number of bits of optim states. Only 32 and 8 are supported. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer (handle memory spike via CPU-GPU transfer) or not. + Example: + + """ + + def __init__( + self, + params, + lr=1e-2, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + nbits=8, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): + super().__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + nbits, + None, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=is_paged, + ) + + proj_none = all(["rank" not in group for group in self.param_groups]) + if proj_none: + warnings.warn( + "Will not apply GaLore as no rank is specified. Or did you forget to? Try get_galore_param_groups" + ) + + # Defaults from the paper + for group in self.param_groups: + if "rank" in group: + group["update_proj_gap"] = group.get("update_proj_gap", 200) + group["proj_type"] = group.get("proj_type", "std") + group["scale"] = group.get("scale", 0.25) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + if not self.initialized: + self.check_overrides() + self.to_gpu() # needed for fairseq pure fp16 training + self.initialized = True + + for gindex, group in enumerate(self.param_groups): + for pindex, p in enumerate(group["params"]): + if p.grad is None: + continue + if p is self.param_groups[0]["params"][0]: + torch.save(p.grad, "grad.pt") + state = self.state[p] + + if "step" not in state: + state["step"] = 0 + + # GaLore Projection + if "rank" in group: + if "projector" not in state: + state["projector"] = GaLoreProjector( + group["rank"], + scale=group["scale"], + update_proj_gap=group["update_proj_gap"], + proj_type=group["proj_type"], + ) + + if "weight_decay" in group and group["weight_decay"] > 0: + # ensure that the weight decay is not applied to the norm grad + group["weight_decay_saved"] = group["weight_decay"] + group["weight_decay"] = 0 + + grad = state["projector"].project(p.grad, state["step"]) + make_low_rank_buffer(p, grad) + + if "state1" not in state: + self.init_state(group, p, gindex, pindex) + + # p.grad = p.grad.contiguous() # avoid bitsandbytes update error + # Prefetch if paged + self.prefetch_state(p) + # Adam update step using the buffer + self.update_step(group, p, gindex, pindex) + torch.cuda.synchronize() + + # GaLore Projection Back + if "rank" in group: + if p is self.param_groups[0]["params"][1]: + pass + update = state["projector"].project_back(p.data) + p.data = p.saved_data.add_(update) + + # apply weight decay + if "weight_decay_saved" in group: + p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay_saved"]) + group["weight_decay"] = group["weight_decay_saved"] + del group["weight_decay_saved"] + + if self.is_paged: + # all paged operation are asynchronous, we need + # to sync to make sure all tensors are in the right state + torch.cuda.synchronize() + + return loss + + def __del__(self): + """Avoid buffer memory leak""" + for group in self.param_groups: + for p in group["params"]: + if hasattr(p, "saved_data"): + del p.saved_data diff --git a/colossalai/nn/optimizer/lamb.py b/colossalai/nn/optimizer/lamb.py index 0d742487f..eee93fb69 100644 --- a/colossalai/nn/optimizer/lamb.py +++ b/colossalai/nn/optimizer/lamb.py @@ -26,7 +26,9 @@ class Lamb(Optimizer): https://arxiv.org/abs/1904.00962 """ - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0, adam=False): + def __init__( + self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0, adam=False, bias_correction=False + ): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: @@ -35,7 +37,7 @@ class Lamb(Optimizer): raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) - defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction) self.adam = adam super(Lamb, self).__init__(params, defaults) @@ -79,12 +81,15 @@ class Lamb(Optimizer): # v_t exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - # Paper v3 does not use debiasing. - # bias_correction1 = 1 - beta1 ** state['step'] - # bias_correction2 = 1 - beta2 ** state['step'] - # Apply bias to lr to avoid broadcast. 
- # * math.sqrt(bias_correction2) / bias_correction1 - step_size = group["lr"] + # NOTE: Paper v3 does not use debiasing. + scaled_lr = group["lr"] + if group["bias_correction"]: + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + # Apply debiasing to lr to avoid broadcast + scaled_lr *= (bias_correction2**0.5) / bias_correction1 + # exp_avg.div_(bias_correction1) + # exp_avg_sq.div_(bias_correction2) weight_norm = p.data.pow(2).sum().sqrt() @@ -97,12 +102,10 @@ class Lamb(Optimizer): trust_ratio = 1 else: trust_ratio = weight_norm / adam_norm - state["weight_norm"] = weight_norm - state["adam_norm"] = adam_norm - state["trust_ratio"] = trust_ratio + if self.adam: trust_ratio = 1 - p.data.add_(adam_step, alpha=-step_size * trust_ratio) + p.data.add_(adam_step, alpha=-scaled_lr * trust_ratio) return loss diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py index d5f00fc9f..93da71abb 100644 --- a/colossalai/shardformer/modeling/mistral.py +++ b/colossalai/shardformer/modeling/mistral.py @@ -16,7 +16,7 @@ from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.shard import ShardConfig -from ..layer import ColoAttention +from ..layer import ColoAttention, cross_entropy_1d logger = logging.get_logger(__name__) @@ -270,11 +270,21 @@ class MistralForwards: shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) # Enable model parallelism shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) + if shard_config.enable_tensor_parallelism and shard_config.parallel_output: + new_vocab_size = logits.shape[-1] + shift_logits = shift_logits.view(-1, new_vocab_size) + loss = cross_entropy_1d( + shift_logits, + shift_labels, + process_group=shard_config.tensor_parallel_process_group, + vocab_size=self.lm_head.out_features, + ) + else: + shift_logits = shift_logits.view(-1, self.config.vocab_size) + loss = loss_fct(shift_logits, shift_labels) if not return_dict: output = (logits,) + outputs[1:] @@ -609,3 +619,100 @@ def get_mistral_flash_attention_forward(shard_config: ShardConfig): return attn_output, None, past_key_value return forward + + +def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): + from transformers import MistralForCausalLM + + def forward( + self: MistralForCausalLM, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+ + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MistralForCausalLM + + >>> model = MistralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + new_vocab_size = logits.shape[-1] + shift_logits = shift_logits.view(-1, new_vocab_size) + loss = cross_entropy_1d( + shift_logits, + shift_labels, + process_group=shard_config.tensor_parallel_process_group, + vocab_size=self.lm_head.out_features, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + return forward diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index 81521c30b..5282e2eaa 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -22,6 +22,8 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer import ColoAttention from colossalai.shardformer.shard import ShardConfig +from ..layer import cross_entropy_1d + logger = logging.get_logger(__name__) @@ -336,8 +338,22 @@ class OPTPipelineForwards: shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens - loss_fct = CrossEntropyLoss() - loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) + + if shard_config.enable_tensor_parallelism and shard_config.parallel_output: + new_vocab_size = logits.shape[-1] + shift_logits = shift_logits.view(-1, new_vocab_size) + shift_labels = shift_labels.view(-1) + loss = cross_entropy_1d( + shift_logits, + shift_labels, + process_group=shard_config.tensor_parallel_process_group, + vocab_size=self.lm_head.out_features, + ) + else: + loss_fct = CrossEntropyLoss() + 
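# The cross_entropy_1d branch above computes the loss without ever gathering the TP-sharded
# vocab dimension. The sketch below shows the general vocab-parallel cross-entropy idea only;
# it is not ColossalAI's cross_entropy_1d implementation, and ignore_index handling and the
# custom backward are omitted for brevity:
import torch
import torch.distributed as dist

def vocab_parallel_ce_sketch(local_logits, targets, vocab_start, vocab_end, tp_group):
    # local_logits: [N, V_local] shard of the vocab dim; targets: [N] global token ids.
    # Stable log-sum-exp over the full vocabulary using only collectives on small tensors.
    logits_max = local_logits.max(dim=-1).values
    dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=tp_group)
    sum_exp = (local_logits - logits_max.unsqueeze(-1)).exp().sum(dim=-1)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=tp_group)
    # Each target logit lives on exactly one rank; the other ranks contribute zero.
    in_shard = (targets >= vocab_start) & (targets < vocab_end)
    local_idx = (targets - vocab_start).clamp(0, vocab_end - vocab_start - 1)
    target_logit = local_logits.gather(-1, local_idx.unsqueeze(-1)).squeeze(-1)
    target_logit = torch.where(in_shard, target_logit, torch.zeros_like(target_logit))
    dist.all_reduce(target_logit, op=dist.ReduceOp.SUM, group=tp_group)
    # loss_i = logsumexp(logits_i) - logit_{i, target_i}
    return (logits_max + sum_exp.log() - target_logit).mean()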
shift_logits = shift_logits.view(-1, self.config.vocab_size) + loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) + if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output @@ -844,3 +860,146 @@ def get_jit_fused_opt_decoder_layer_forward(): return outputs return forward + + +def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): + def forward( + self: OPTForCausalLM, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional + tensors are only required when the model is used as a decoder in a Sequence to Sequence model. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, OPTForCausalLM + + >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m") + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = self.lm_head(outputs[0]).contiguous() + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + new_vocab_size = logits.shape[-1] + shift_logits = shift_logits.view(-1, new_vocab_size) + loss = cross_entropy_1d( + shift_logits, + shift_labels, + process_group=shard_config.tensor_parallel_process_group, + vocab_size=self.lm_head.out_features, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + return forward diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py new file mode 100644 index 000000000..8f8ab25a5 --- /dev/null +++ b/colossalai/shardformer/modeling/qwen2.py @@ -0,0 +1,758 @@ +from typing import List, Optional, Tuple, Union + 
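# A note on the label handling shared by the causal-LM forwards in this patch: the model
# predicts token t+1 from positions <= t, so logits are shifted left and labels right before
# flattening for the loss. Under tensor parallelism the last logits dim is the *local* vocab
# shard, which is why the code reads `new_vocab_size = logits.shape[-1]` rather than
# `config.vocab_size`. A tiny shape walk-through (values are illustrative):
import torch

B, L, V = 2, 5, 8                          # batch, sequence length, (local) vocab size
logits = torch.randn(B, L, V)
labels = torch.randint(0, V, (B, L))
shift_logits = logits[..., :-1, :].contiguous().view(-1, V)   # [B*(L-1), V]
shift_labels = labels[..., 1:].contiguous().view(-1)          # [B*(L-1)]
assert shift_logits.shape == (B * (L - 1), V) and shift_labels.shape == (B * (L - 1),)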
+import torch +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) + +try: + from transformers.models.qwen2.modeling_qwen2 import ( + Qwen2Attention, + Qwen2ForCausalLM, + Qwen2ForSequenceClassification, + Qwen2Model, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, + apply_rotary_pos_emb, + repeat_kv, + ) +except ImportError: + Qwen2Model = "Qwen2Model" + Qwen2ForCausalLM = "Qwen2ForCausalLM" + Qwen2Attention = "Qwen2Attention" + Qwen2ForSequenceClassification = "Qwen2ForSequenceClassification" + +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.shard import ShardConfig + +from ..layer import ColoAttention, cross_entropy_1d + + +class Qwen2PipelineForwards: + """ + This class serves as a micro library for forward function substitution of Qwen2 models + under pipeline setting. + """ + + @staticmethod + def qwen2_model_forward( + self: Qwen2Model, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + logger = logging.get_logger(__name__) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if stage_manager.is_first_stage(): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + else: + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + device = hidden_states.device + + seq_length_with_past = seq_length + past_key_values_length = 0 + + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. 
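# The control flow of these pipeline forwards, in miniature: the stage manager hands each
# rank a contiguous slice of decoder layers; only the first stage embeds token ids, only the
# last stage runs the final norm (and, one level up, the LM head and loss), and every
# intermediate stage returns {"hidden_states": ...} for the next stage. A toy, self-contained
# version of that dispatch (all names below are illustrative):
import torch.nn as nn

def staged_forward_sketch(layers: nn.ModuleList, stage_index, is_first: bool, is_last: bool,
                          embed=None, norm=None, input_ids=None, hidden_states=None):
    if is_first:
        hidden_states = embed(input_ids)          # only stage 0 consumes token ids
    start_idx, end_idx = stage_index
    for layer in layers[start_idx:end_idx]:       # each stage owns layers[start:end]
        hidden_states = layer(hidden_states)
    if is_last:
        return norm(hidden_states)                # last stage finishes the backbone
    return {"hidden_states": hidden_states}       # hand-off to the next stage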
+ if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + if use_cache: + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") + use_cache = False + + # assert past_key_values is None, "past_key_values is not supported for Qwen2 models at the moment." + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: + is_padding_right = attention_mask[:, -1].sum().item() != batch_size + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + # embed positions, for the first stage, hidden_states is the input embeddings, + # for the other stages, hidden_states is the output of the previous stage + if shard_config.enable_flash_attention: + # in this case, attention_mask is a dict rather than a tensor + mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past) + attention_mask = ColoAttention.prepare_attn_kwargs( + mask_shape, + hidden_states.dtype, + hidden_states.device, + q_padding_mask=attention_mask, + is_causal=True, + ) + else: + if self._attn_implementation == "flash_attention_2": + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._attn_implementation == "sdpa" and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. 
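# For reference, the 4D mask built below combines the causal (lower-triangular) structure
# with the 2D padding mask and converts it to an additive bias of shape
# [batch, 1, q_len, kv_len]: 0 where attention is allowed, a large negative value elsewhere.
# A minimal sketch for the no-KV-cache, no-sliding-window case (helper name is illustrative):
import torch

def causal_bias_sketch(padding_mask: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    # padding_mask: [batch, seq_len], 1 for real tokens and 0 for padding.
    bsz, seq_len = padding_mask.shape
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=padding_mask.device))
    visible = causal[None, None] & padding_mask[:, None, None, :].bool()   # [B, 1, L, L]
    bias = torch.zeros(bsz, 1, seq_len, seq_len, dtype=dtype, device=padding_mask.device)
    return bias.masked_fill(~visible, torch.finfo(dtype).min)

# The transformers helpers used here additionally account for past_key_values_length and
# Qwen2's sliding window, and the flash-attention path skips building this dense mask entirely.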
+
+                attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                    attention_mask,
+                    (batch_size, seq_length),
+                    inputs_embeds,
+                    past_key_values_length,
+                )
+            else:
+                # 4d mask is passed through the layers
+
+                attention_mask = _prepare_4d_causal_attention_mask(
+                    attention_mask,
+                    (batch_size, seq_length),
+                    hidden_states,
+                    past_key_values_length,
+                    sliding_window=self.config.sliding_window,
+                )
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        next_decoder_cache = None
+
+        start_idx, end_idx = stage_index[0], stage_index[1]
+        for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    position_ids,
+                    past_key_values,
+                    output_attentions,
+                    use_cache,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                    past_key_value=past_key_value,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                )
+
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+        if stage_manager.is_last_stage():
+            hidden_states = self.norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+
+        if stage_manager.is_last_stage():
+            if not return_dict:
+                return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+            return BaseModelOutputWithPast(
+                last_hidden_state=hidden_states,
+                past_key_values=next_cache,
+                hidden_states=all_hidden_states,
+                attentions=all_self_attns,
+            )
+        # always return dict for intermediate stages
+        return {"hidden_states": hidden_states}
+
+    @staticmethod
+    def qwen2_for_causal_lm_forward(
+        self: Qwen2ForCausalLM,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        stage_manager: Optional[PipelineStageManager] = None,
+        hidden_states: Optional[torch.FloatTensor] = None,
+        stage_index: Optional[List[int]] = None,
+        shard_config: ShardConfig = None,
+    ):
+        r"""
+        Args:
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, Qwen2ForCausalLM
+
+        >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+        ```"""
+        logger = logging.get_logger(__name__)
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
+        if output_attentions:
+            logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
+            output_attentions = False
+        if output_hidden_states:
+            logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
+            output_hidden_states = False
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = Qwen2PipelineForwards.qwen2_model_forward(
+            self.model,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            stage_manager=stage_manager,
+            hidden_states=hidden_states,
+            stage_index=stage_index,
+            shard_config=shard_config,
+        )
+        past_key_values = None
+
+        if stage_manager.is_last_stage():
+            hidden_states = outputs[0]
+            logits = self.lm_head(hidden_states)
+            loss = None
+            if labels is not None:
+                # Shift so that tokens < n predict n
+                shift_logits = logits[..., :-1, :].contiguous()
+                shift_labels = labels[..., 1:].contiguous()
+                # Flatten the tokens
+                loss_fct = CrossEntropyLoss()
+                shift_labels = shift_labels.view(-1)
+                # Enable model parallelism
+                shift_labels = shift_labels.to(shift_logits.device)
+                if shard_config.enable_tensor_parallelism:
+                    new_vocab_size = logits.shape[-1]
+                    shift_logits = shift_logits.view(-1, new_vocab_size)
+                    loss = cross_entropy_1d(
+                        shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
+                    )
+                else:
+                    shift_logits = shift_logits.view(-1, self.config.vocab_size)
+                    loss = loss_fct(shift_logits, shift_labels)
+
+            if not return_dict:
+                output = (logits,) + outputs[1:]
+                return (loss,) + output if loss is not None else output
+
+            return CausalLMOutputWithPast(
+                loss=loss,
+                logits=logits,
+                past_key_values=outputs.past_key_values,
+                hidden_states=outputs.hidden_states,
+                attentions=outputs.attentions,
+            )
+        else:
+            hidden_states = outputs.get("hidden_states")
+            return {"hidden_states": hidden_states}
+
+    @staticmethod
+    def qwen2_for_sequence_classification_forward(
+        self: Qwen2ForSequenceClassification,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        stage_manager: Optional[PipelineStageManager] = None,
+        hidden_states: Optional[torch.FloatTensor] = None,
+        stage_index: Optional[List[int]] = None,
+        shard_config: ShardConfig = None,
+    ):
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        logger = logging.get_logger(__name__)
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
+        if output_attentions:
+            logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
+            output_attentions = False
+        if output_hidden_states:
+            logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
+            output_hidden_states = False
+
+        transformer_outputs = Qwen2PipelineForwards.qwen2_model_forward(
+            self.model,
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            stage_manager=stage_manager,
+            hidden_states=hidden_states,
+            stage_index=stage_index,
+            shard_config=shard_config,
+        )
+
+        if input_ids is not None:
+            batch_size = input_ids.shape[0]
+        elif inputs_embeds is not None:
+            batch_size = inputs_embeds.shape[0]
+        else:
+            batch_size = hidden_states.shape[0]
+
+        if stage_manager.is_last_stage():
+            hidden_states = transformer_outputs[0]
+            logits = self.score(hidden_states)
+
+            if self.config.pad_token_id is None and batch_size != 1:
+                raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+            if self.config.pad_token_id is None:
+                sequence_lengths = -1
+            else:
+                if input_ids is not None:
+                    sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+                else:
+                    sequence_lengths = -1
+
+            pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+            loss = None
+            if labels is not None:
+                labels = labels.to(logits.device)
+                if self.config.problem_type is None:
+                    if self.num_labels == 1:
+                        self.config.problem_type = "regression"
+                    elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                        self.config.problem_type = "single_label_classification"
+                    else:
+                        self.config.problem_type = "multi_label_classification"
+
+                if self.config.problem_type == "regression":
+                    loss_fct = MSELoss()
+                    if self.num_labels == 1:
+                        loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                    else:
+                        loss = loss_fct(pooled_logits, labels)
+                elif self.config.problem_type == "single_label_classification":
+                    loss_fct = CrossEntropyLoss()
+                    loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+                elif self.config.problem_type == "multi_label_classification":
+                    
loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + else: + hidden_states = transformer_outputs.get("hidden_states") + return {"hidden_states": hidden_states} + + +def get_qwen2_flash_attention_forward(shard_config: ShardConfig): + def forward( + self: Qwen2Attention, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + # Because the input can be padded, the absolute sequence length depends on the max position id. 
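+        # For example, with kv_seq_len = 10 and left-padded position_ids whose last entry is 15,
+        # rotary_seq_len below becomes max(10, 15) + 1 = 16, so the rotary cos/sin tables cover
+        # every position id that can appear in this batch.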
+ rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 + cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict." + attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + return forward + + +def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig): + logger = logging.get_logger(__name__) + assert shard_config.enable_flash_attention, "Flash Attention is not enabled." 
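+    # The closure below is meant to replace Qwen2Model.forward when flash attention is enabled
+    # and no pipeline stage manager is used; the policy registers it roughly as
+    #   append_or_create_method_replacement(
+    #       description={"forward": get_qwen2_model_forward_for_flash_attn(shard_config)},
+    #       policy=policy, target_key=Qwen2Model)
+    # (an illustrative shorthand of the registration done in Qwen2Policy.module_policy).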
+ + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # embed positions + hidden_states = inputs_embeds + + # in this case, attention_mask is a dict rather than a tensor + mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past) + attention_mask = ColoAttention.prepare_attn_kwargs( + mask_shape, + hidden_states.dtype, + hidden_states.device, + q_padding_mask=attention_mask, + is_causal=True, + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + return forward + + +def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): + def forward( + self: Qwen2ForCausalLM, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Qwen2ForCausalLM + + >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + if shard_config.enable_tensor_parallelism: + new_vocab_size = logits.shape[-1] + shift_logits = shift_logits.view(-1, new_vocab_size) + loss = cross_entropy_1d( + shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group + ) + else: + shift_logits = shift_logits.view(-1, self.config.vocab_size) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + return forward diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index d2b582af5..69df021b0 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -182,6 +182,16 @@ _POLICY_LIST = { "transformers.models.mistral.modeling_mistral.MistralForSequenceClassification": PolicyLocation( file_name="mistral", class_name="MistralForSequenceClassificationPolicy" ), + # Qwen2 + "transformers.models.qwen2.modeling_qwen2.Qwen2Model": PolicyLocation( + file_name="qwen2", class_name="Qwen2ModelPolicy" + ), + "transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM": PolicyLocation( + file_name="qwen2", class_name="Qwen2ForCausalLMPolicy" + ), + "transformers.models.qwen2.modeling_qwen2.Qwen2ForSequenceClassification": PolicyLocation( + file_name="qwen2", class_name="Qwen2ForSequenceClassificationPolicy" + ), } diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index 984b71646..621982f29 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -18,6 +18,7 @@ from colossalai.shardformer.layer import ( from ..modeling.mistral import ( MistralForwards, + get_lm_forward_with_dist_cross_entropy, get_mistral_flash_attention_forward, get_mistral_model_forward_for_flash_attn, ) @@ -275,14 +276,18 @@ class MistralForCausalLMPolicy(MistralPolicy): SubModuleReplacementDescription( suffix="lm_head", target_module=VocabParallelLMHead1D, - kwargs=dict( - gather_output=True, - make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by, - ), + kwargs={ + 
"gather_output": not self.shard_config.parallel_output, + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + }, ) ] ) } + if self.shard_config.parallel_output: + new_item[MistralForCausalLM].method_replacement = { + "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config) + } else: new_item = { MistralForCausalLM: ModulePolicyDescription( diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 9619b3d41..524d2b8cd 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -21,6 +21,7 @@ from ..modeling.jit import get_jit_fused_dropout_add_func from ..modeling.opt import ( OPTPipelineForwards, get_jit_fused_opt_decoder_layer_forward, + get_lm_forward_with_dist_cross_entropy, get_opt_decoder_forward_for_flash_attention, get_opt_flash_attention_forward, ) @@ -269,12 +270,18 @@ class OPTForCausalLMPolicy(OPTPolicy): suffix="lm_head", target_module=VocabParallelLMHead1D, kwargs=dict( - gather_output=True, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by + gather_output=not self.shard_config.parallel_output, + make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by, ), ), policy=policy, target_key=OPTForCausalLM, ) + if self.shard_config.parallel_output: + method_replacement = {"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=OPTForCausalLM + ) else: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( diff --git a/colossalai/shardformer/policies/qwen2.py b/colossalai/shardformer/policies/qwen2.py new file mode 100644 index 000000000..3e427c4a1 --- /dev/null +++ b/colossalai/shardformer/policies/qwen2.py @@ -0,0 +1,374 @@ +import warnings +from functools import partial +from typing import Callable, Dict, List, Union + +import torch.nn as nn +from torch import Tensor +from torch.nn import Module + +from colossalai.shardformer.layer import ( + FusedRMSNorm, + Linear1D_Col, + Linear1D_Row, + PaddingEmbedding, + RMSNorm, + VocabParallelEmbedding1D, +) + +from ..modeling.qwen2 import ( + Qwen2PipelineForwards, + get_lm_forward_with_dist_cross_entropy, + get_qwen2_flash_attention_forward, + get_qwen2_model_forward_for_flash_attn, +) + +try: + from transformers.models.qwen2.modeling_qwen2 import ( + Qwen2Attention, + Qwen2DecoderLayer, + Qwen2FlashAttention2, + Qwen2ForCausalLM, + Qwen2ForSequenceClassification, + Qwen2Model, + Qwen2SdpaAttention, + ) +except ImportError: + Qwen2ForCausalLM = "Qwen2ForCausalLM" + Qwen2ForSequenceClassification = "Qwen2ForSequenceClassification" + Qwen2Attention = "Qwen2Attention" + Qwen2FlashAttention2 = "Qwen2FlashAttention2" + Qwen2SdpaAttention = "Qwen2SdpaAttention" + Qwen2DecoderLayer = "Qwen2DecoderLayer" + Qwen2Model = "Qwen2Model" + +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = ["Qwen2Policy", "Qwen2ForCausalLMPolicy", "Qwen2ForSequenceClassificationPolicy"] + + +class Qwen2Policy(Policy): + def __init__(self) -> None: + super().__init__() + import transformers + from packaging.version import Version + + assert Version(transformers.__version__) >= Version( + "4.39.1" + ), "The Qwen2 model should run on a transformers version of 4.39.1." 
+ + def config_sanity_check(self): + pass + + def preprocess(self): + self.tie_weight = self.tie_weight_check() + self.origin_attn_implement = self.model.config._attn_implementation + return self.model + + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + ATTN_IMPLEMENTATION = { + "eager": Qwen2Attention, + "flash_attention_2": Qwen2FlashAttention2, + "sdpa": Qwen2SdpaAttention, + } + + policy = {} + + attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement] + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = PaddingEmbedding + norm_cls = FusedRMSNorm if self.shard_config.enable_fused_normalization else RMSNorm + + if self.shard_config.enable_sequence_parallelism: + self.shard_config.enable_sequence_parallelism = False + warnings.warn("Qwen2 doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") + + if self.shard_config.enable_tensor_parallelism: + assert ( + self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 + ), f"The number of attention heads must be divisible by tensor parallel size." + if hasattr(self.model.config, "num_key_value_heads"): + assert ( + self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0 + ), f"The number of key_value heads must be divisible by tensor parallel size." + decoder_attribute_replacement = { + "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + } + if getattr(self.model.config, "num_key_value_heads", False): + decoder_attribute_replacement["self_attn.num_key_value_heads"] = ( + self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size + ) + + policy[Qwen2DecoderLayer] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="mlp.gate_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.up_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.down_proj", + target_module=Linear1D_Row, + ), + ], + ) + + if embedding_cls is not None: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), + policy=policy, + target_key=Qwen2Model, + ) + + # optimization configuration + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="input_layernorm", + target_module=norm_cls, + ), + SubModuleReplacementDescription( + suffix="post_attention_layernorm", + target_module=norm_cls, + ), + ], + policy=policy, + target_key=Qwen2DecoderLayer, + ) + + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="norm", + 
target_module=norm_cls,
+            ),
+            policy=policy,
+            target_key=Qwen2Model,
+        )
+
+        # use flash attention
+        if self.shard_config.enable_flash_attention:
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": get_qwen2_flash_attention_forward(self.shard_config),
+                },
+                policy=policy,
+                target_key=attn_cls,
+            )
+            if self.pipeline_stage_manager is None:
+                # replace qwen2 model forward method
+                self.append_or_create_method_replacement(
+                    description={
+                        "forward": get_qwen2_model_forward_for_flash_attn(self.shard_config),
+                    },
+                    policy=policy,
+                    target_key=Qwen2Model,
+                )
+
+        return policy
+
+    def postprocess(self):
+        return self.model
+
+    def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
+        """Under the pipeline parallel setting, replace the original huggingface forward method
+        with the customized forward method and register this change in the policy."""
+        if self.pipeline_stage_manager is None:
+            return
+
+        stage_manager = self.pipeline_stage_manager
+        if self.model.__class__.__name__ == "Qwen2Model":
+            module = self.model
+        else:
+            module = self.model.model
+
+        if stage_manager.is_interleave:
+            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+            stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)
+            method_replacement = {
+                "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config)
+            }
+
+        else:
+            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+            stage_index = stage_manager.get_stage_index(layers_per_stage)
+            method_replacement = {
+                "forward": partial(
+                    new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
+                )
+            }
+
+        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
+
+    def get_held_layers(self) -> List[Module]:
+        """Get pipeline layers for current stage."""
+        assert self.pipeline_stage_manager is not None
+
+        if self.model.__class__.__name__ == "Qwen2Model":
+            module = self.model
+        else:
+            module = self.model.model
+
+        stage_manager = self.pipeline_stage_manager
+
+        held_layers = []
+        if stage_manager.is_interleave:
+            assert stage_manager.num_model_chunks is not None
+            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+            stage_indices = stage_manager.get_stage_index(layers_per_stage)
+            if stage_manager.is_first_stage(ignore_chunk=True):
+                held_layers.append(module.embed_tokens)
+            for start_idx, end_idx in stage_indices:
+                held_layers.extend(module.layers[start_idx:end_idx])
+            if stage_manager.is_last_stage(ignore_chunk=True):
+                held_layers.append(module.norm)
+
+        else:
+            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+            if stage_manager.is_first_stage():
+                held_layers.append(module.embed_tokens)
+            start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
+            held_layers.extend(module.layers[start_idx:end_idx])
+            if stage_manager.is_last_stage():
+                held_layers.append(module.norm)
+
+        return held_layers
+
+
+class Qwen2ModelPolicy(Qwen2Policy):
+    def module_policy(self):
+        policy = super().module_policy()
+
+        if self.pipeline_stage_manager:
+            # set None as default
+            self.set_pipeline_forward(
+                model_cls=Qwen2Model, new_forward=Qwen2PipelineForwards.qwen2_model_forward, policy=policy
+            )
+        return policy
+
+    def get_held_layers(self) -> List[Module]:
+        """Get 
pipeline layers for current stage.""" + held_layers = super().get_held_layers() + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in Qwen2 model""" + return [] + + +class Qwen2ForCausalLMPolicy(Qwen2Policy): + def module_policy(self): + policy = super().module_policy() + setattr(self.shard_config, "causal_lm", True) + + if self.shard_config.enable_tensor_parallelism: + # add a new item for casual lm + new_item = { + Qwen2ForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription(suffix="lm_head", target_module=Linear1D_Col) + ], + method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}, + ) + } + policy.update(new_item) + + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls=Qwen2ForCausalLM, new_forward=Qwen2PipelineForwards.qwen2_for_causal_lm_forward, policy=policy + ) + + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(self.model.lm_head) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + qwen2_model = self.model.model + if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: + if ( + id(qwen2_model.embed_tokens.weight) == id(self.model.lm_head.weight) + and self.pipeline_stage_manager.num_stages > 1 + ): + # tie weights + return [ + { + 0: qwen2_model.embed_tokens.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight, + } + ] + return [] + + +class Qwen2ForSequenceClassificationPolicy(Qwen2Policy): + def module_policy(self): + policy = super().module_policy() + if self.shard_config.enable_tensor_parallelism: + # add a new item for sequence classification + new_item = { + Qwen2ForSequenceClassification: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + ) + ] + ) + } + policy.update(new_item) + # to be confirmed + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls=Qwen2ForSequenceClassification, + new_forward=Qwen2PipelineForwards.qwen2_for_sequence_classification_forward, + policy=policy, + ) + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(self.model.score) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in Qwen2 for sequence classification model""" + return [] diff --git a/colossalai/tensor/d_tensor/__init__.py b/colossalai/tensor/d_tensor/__init__.py index 4129ec62e..26290c7ba 100644 --- a/colossalai/tensor/d_tensor/__init__.py +++ b/colossalai/tensor/d_tensor/__init__.py @@ -6,6 +6,7 @@ from .api import ( get_device_mesh, get_global_shape, get_layout, + get_shard_dim_1d, get_sharding_spec, init_as_dtensor, init_tensor_as_customization_distributed, @@ -37,6 +38,7 @@ __all__ = [ "get_device_mesh", "redistribute", "get_layout", + "get_shard_dim_1d", "is_customized_distributed_tensor", "distribute_tensor_with_customization", "init_tensor_as_customization_distributed", diff --git 
a/colossalai/tensor/d_tensor/api.py b/colossalai/tensor/d_tensor/api.py index 725817088..ede59ee89 100644 --- a/colossalai/tensor/d_tensor/api.py +++ b/colossalai/tensor/d_tensor/api.py @@ -8,6 +8,7 @@ import torch.distributed as dist from torch.distributed import ProcessGroup from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.d_tensor.sharding_spec import DimSpec from .layout import Layout from .layout_converter import LayoutConverter @@ -15,6 +16,22 @@ from .sharding_spec import ShardingSpec layout_converter = LayoutConverter() +_SHARD_DIM = DimSpec([0]) + + +def get_shard_dim_1d(p: torch.Tensor): + """ + Get the dimension along which the tensor is sharded, for example in 1D Tensor Parallel. + Args: + p (torch.Tensor): the input tensor + Returns: + int: the dimension along which the tensor is sharded + """ + if not is_distributed_tensor(p): + raise ValueError("p is not a distributed tensor") + sharding = p.dist_layout.sharding_spec.sharding_sequence + return sharding.index(_SHARD_DIM) + def clear_layout_converter(): global layout_converter diff --git a/colossalai/tensor/d_tensor/sharding_spec.py b/colossalai/tensor/d_tensor/sharding_spec.py index 2ac0ca73e..16a4f248b 100644 --- a/colossalai/tensor/d_tensor/sharding_spec.py +++ b/colossalai/tensor/d_tensor/sharding_spec.py @@ -140,8 +140,9 @@ class DimSpec: class ShardingSpec: """ - Sharding spec describes how to shard a tensor with dim_size dimensions. The sharding sequence looks like - [R, R, S0, S1], which means + Sharding spec describes how to shard a tensor with dim_size dimensions. For example for a 3D tensor, the sharding sequence + [R, S0, S1] means not sharding the first dim, sharding the 3rd along the 1st device mesh axis (Process group) + and sharding the 3th dim along the 2nd device mesh axis. Useful for say, 2D Tensor Parallel. Argument: dim_partition_dict(Dict[int, List[int]], optional): The key is the dimension of tensor to be sharded, diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py index 5bc662a61..333a3f224 100644 --- a/colossalai/zero/gemini/chunk/manager.py +++ b/colossalai/zero/gemini/chunk/manager.py @@ -20,7 +20,12 @@ class ChunkManager: init_device (torch.device): optional, the device on which the chunk is initialized. The default is None. 
""" - def __init__(self, chunk_configuration, init_device: Optional[torch.device] = None) -> None: + def __init__( + self, + chunk_configuration, + init_device: Optional[torch.device] = None, + reuse_fp16_chunk: bool = True, + ) -> None: self.device = init_device or get_accelerator().get_current_device() self.dp_degree_chunk_size_dict: Dict[int, int] = dict() self.kwargs_config = chunk_configuration @@ -33,6 +38,10 @@ class ChunkManager: self.accessed_chunks: Set[Chunk] = set() self.accessed_mem: int = 0 self.total_mem: Dict[str, int] = {"cpu": 0, "cuda": 0} + self.reuse_fp16_chunk = reuse_fp16_chunk + # Whether model is accumulating gradients, + self.accumulating_grads = False + self.overflow_counter = 0 def register_tensor( self, diff --git a/colossalai/zero/gemini/chunk/utils.py b/colossalai/zero/gemini/chunk/utils.py index 7a2ea3606..049c5c102 100644 --- a/colossalai/zero/gemini/chunk/utils.py +++ b/colossalai/zero/gemini/chunk/utils.py @@ -19,6 +19,7 @@ def init_chunk_manager( model: nn.Module, init_device: Optional[torch.device] = None, hidden_dim: Optional[int] = None, + reuse_fp16_chunk: bool = True, verbose: bool = False, **kwargs, ) -> ChunkManager: @@ -50,5 +51,9 @@ def init_chunk_manager( ) dist.barrier() - chunk_manager = ChunkManager(config_dict, init_device) + chunk_manager = ChunkManager( + config_dict, + init_device, + reuse_fp16_chunk=reuse_fp16_chunk, + ) return chunk_manager diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index b25de1d68..c1029097a 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -98,8 +98,14 @@ class GeminiDDP(ModelWrapper): verbose: bool = False, ) -> None: assert mixed_precision in (torch.float16, torch.bfloat16) + reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False + self.enable_gradient_accumulation = enable_gradient_accumulation if chunk_config_dict is not None: - self.chunk_manager = ChunkManager(chunk_config_dict, chunk_init_device) + self.chunk_manager = ChunkManager( + chunk_config_dict, + chunk_init_device, + reuse_fp16_chunk=reuse_fp16_chunk, + ) else: # some ugly hotfix for the compatibility with Lightning if search_range_m is None: @@ -112,6 +118,7 @@ class GeminiDDP(ModelWrapper): min_chunk_size_m=min_chunk_size_m, strict_ddp_flag=strict_ddp_mode, process_group=zero_group, + reuse_fp16_chunk=reuse_fp16_chunk, verbose=verbose, ) self.gemini_manager = GeminiManager( @@ -128,7 +135,6 @@ class GeminiDDP(ModelWrapper): self.param_op_hook = GeminiZeROHook(self.gemini_manager) self.fp32_params: List[torch.Tensor] = list() self.fp16_params: List[ColoParameter] = list() - self.overflow_counter = 0 self.grads_device: Dict[torch.Tensor, torch.device] = dict() self.param2name: Dict[nn.Parameter, str] = dict() self.name2param: Dict[str, nn.Parameter] = dict() @@ -137,14 +143,8 @@ class GeminiDDP(ModelWrapper): self.zero_group = zero_group or _get_default_group() self.extra_dp_group = extra_dp_group - self.reuse_fp16_chunk = master_weights self.master_weights = master_weights - self.enable_gradient_accumulation = enable_gradient_accumulation - if self.enable_gradient_accumulation: - self.reuse_fp16_chunk = False - self.accumulating_grads = False # Whether model is accumulating gradients - self._logger = get_dist_logger() if self.gemini_manager._premade_memstats_: @@ -178,7 +178,29 @@ class GeminiDDP(ModelWrapper): if is_ddp_ignored(p): continue if p.requires_grad: - p.register_hook(partial(self.grad_handle, p)) + p._grad_handle = 
p.register_hook( + partial( + GeminiDDP.grad_handle, + chunk_manager=self.chunk_manager, + param2name=self.param2name, + grads_device=self.grads_device, + master_weights=self.master_weights, + enable_gradient_accumulation=self.enable_gradient_accumulation, + p=p, + ) + ) + + def remove_hooks(self): + for p in self.module.parameters(): + if is_ddp_ignored(p): + continue + if p.requires_grad: + assert hasattr(p, "_grad_handle") + p._grad_handle.remove() + delattr(p, "_grad_handle") + + def __del__(self): + self.remove_hooks() def parameters(self, recurse: bool = True): return self.module.parameters(recurse) @@ -324,8 +346,8 @@ class GeminiDDP(ModelWrapper): f"{error_str}", ) self._setup_grads_ptr() - if self.enable_gradient_accumulation and not self.accumulating_grads: - self.accumulating_grads = True # Turn on the state of gradient accumulation. + if self.enable_gradient_accumulation and not self.chunk_manager.accumulating_grads: + self.chunk_manager.accumulating_grads = True # Turn on the state of gradient accumulation. self._logger.debug( f"comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}" ) @@ -340,25 +362,34 @@ class GeminiDDP(ModelWrapper): def backward_by_grad(self, tensor, grad): raise RuntimeError("Gemini is not compatible with pipeline. backward_by_grad shoudn't be called in Gemini.") - def grad_handle(self, p, grad): + @staticmethod + def grad_handle( + grad, + chunk_manager: ChunkManager, + param2name: Dict, + grads_device: Dict, + master_weights: bool, + enable_gradient_accumulation: bool, + p: nn.Parameter, + ): setattr(p, "_gemini_reduced", True) empty_grad = torch.empty_like(grad) free_storage(empty_grad) with torch._C.DisableTorchFunction(): - chunk = self.chunk_manager.get_chunk(p) + chunk = chunk_manager.get_chunk(p) if chunk.tensors_info[p].state != TensorState.HOLD_AFTER_BWD: raise RuntimeError( - f"Parameter `{self.param2name[p]}` failed at the gradient reduction. " + f"Parameter `{param2name[p]}` failed at the gradient reduction. " "Some unsupported torch function is operated upon this parameter." 
) grad_chunk = chunk - if not self.reuse_fp16_chunk: - if not self.accumulating_grads: - grad_chunk = self.chunk_manager.init_grad_chunk(chunk) + if not chunk_manager.reuse_fp16_chunk: + if not chunk_manager.accumulating_grads: + grad_chunk = chunk_manager.init_grad_chunk(chunk) else: assert chunk.grad_chunk is not None - if chunk.grad_chunk not in self.chunk_manager.accessed_chunks: - grad_chunk = self.chunk_manager.rearrange_accumulated_grad_chunk(chunk) + if chunk.grad_chunk not in chunk_manager.accessed_chunks: + grad_chunk = chunk_manager.rearrange_accumulated_grad_chunk(chunk) else: grad_chunk = chunk.grad_chunk chunk.grad_chunk.l2_norm = None @@ -371,33 +402,33 @@ class GeminiDDP(ModelWrapper): chunk.tensor_trans_state(p, TensorState.HOLD) grad_chunk.tensor_trans_state(p, TensorState.READY_FOR_REDUCE) - if not self.accumulating_grads: - grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=self.reuse_fp16_chunk) + if not chunk_manager.accumulating_grads: + grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=chunk_manager.reuse_fp16_chunk) else: grad_chunk.add_tensor_to_chunk_slice(p, grad) - reduced = self.chunk_manager.reduce_chunk(grad_chunk) + reduced = chunk_manager.reduce_chunk(grad_chunk) if reduced: - if not self.reuse_fp16_chunk: + if not chunk_manager.reuse_fp16_chunk: if chunk.keep_gathered: - self.chunk_manager.fake_release_chunk(chunk) + chunk_manager.fake_release_chunk(chunk) else: - self.chunk_manager.release_chunk(chunk) + chunk_manager.release_chunk(chunk) if grad_chunk.is_gathered: grad_chunk.cuda_global_chunk.div_(chunk.pg_size) - if self.extra_dp_group is not None: + if chunk.extra_dp_group is not None: grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size) else: grad_chunk.cuda_shard.div_(chunk.pg_size) - if self.extra_dp_group is not None: + if chunk.extra_dp_group is not None: grad_chunk.cuda_shard.div_(chunk.extra_dp_size) # check overflow elements - self.overflow_counter += grad_chunk.has_inf_or_nan + chunk_manager.overflow_counter += grad_chunk.has_inf_or_nan # record l2 norm for gradient clipping. 
flag is bound to fp16 chunk if chunk.l2_norm_flag: grad_chunk.set_l2_norm() - self.chunk_manager.move_chunk(grad_chunk, self.grads_device[p], force_copy=True) - if not (self.master_weights) or (self.enable_gradient_accumulation): - self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True) + chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True) + if not (master_weights) or (enable_gradient_accumulation): + chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True) return empty_grad def zero_grad(self, set_to_none: bool = False) -> None: @@ -513,11 +544,11 @@ class GeminiDDP(ModelWrapper): # get copies of fp32 parameters in CPU # as memory of fp16_params may be reused by grad, it's not reliable, we should use fp32_params and convert to fp16 - params = self.fp32_params if self.reuse_fp16_chunk else self.fp16_params + params = self.fp32_params if self.chunk_manager.reuse_fp16_chunk else self.fp16_params param_to_save_data = self._get_param_to_save_data(params, only_rank_0) # get the mapping between copies and fp16 parameters p_mapping = dict() - if self.reuse_fp16_chunk: + if self.chunk_manager.reuse_fp16_chunk: for p, fp32_p in zip(self.fp16_params, self.fp32_params): name = self.param2name[p] assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name) @@ -713,7 +744,7 @@ class GeminiDDP(ModelWrapper): name = self.param2name[p] fp32_to_name[fp32_p] = name - params_to_load = self.fp32_params if self.reuse_fp16_chunk else self.fp16_params + params_to_load = self.fp32_params if self.chunk_manager.reuse_fp16_chunk else self.fp16_params chunk_list = self.chunk_manager.get_chunks(params_to_load) for chunk in chunk_list: temp_chunk = get_temp_total_chunk_on_cuda(chunk, self.mixed_precision) @@ -728,7 +759,9 @@ class GeminiDDP(ModelWrapper): shard_fn = tensor.shard_fn gather_fn = tensor.gather_fn - parameter_name = fp32_to_name[tensor] if self.reuse_fp16_chunk else self.param2name[tensor] + parameter_name = ( + fp32_to_name[tensor] if self.chunk_manager.reuse_fp16_chunk else self.param2name[tensor] + ) parameter_slice = temp_chunk[tensor_info.offset : tensor_info.end] load( parameter_name, @@ -900,7 +933,7 @@ class GeminiDDP(ModelWrapper): gathered_param = param if keep_vars else param.detach() else: # as memory of fp16 param may be reused, we should use fp32 param and then convert to fp16 - param_to_save = fp16_to_fp32[param] if self.reuse_fp16_chunk else param + param_to_save = fp16_to_fp32[param] if self.chunk_manager.reuse_fp16_chunk else param if param_to_save not in gathered_param_buffer: chunk = self.chunk_manager.get_chunk(param_to_save) gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0)) diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index ae02fe297..18918eabc 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -62,10 +62,10 @@ class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin): self.module = module def check_local_overflow(self) -> bool: - return self.module.overflow_counter > 0 + return self.module.chunk_manager.overflow_counter > 0 def pre_zero_grad(self) -> None: - self.module.overflow_counter = 0 + self.module.chunk_manager.overflow_counter = 0 class GeminiOptimizer(OptimizerWrapper): @@ -202,7 +202,7 @@ class GeminiOptimizer(OptimizerWrapper): chunk16 = self.param_to_chunk16[fake_param] begin, end = self.param_to_range[fake_param] - grad_chunk16 = chunk16 if 
self.module.reuse_fp16_chunk else chunk16.grad_chunk + grad_chunk16 = chunk16 if self.module.chunk_manager.reuse_fp16_chunk else chunk16.grad_chunk fake_param.data = grad_chunk16.payload[begin:end] fake_param.grad = fake_param.data @@ -221,14 +221,14 @@ class GeminiOptimizer(OptimizerWrapper): def _clear_global_norm(self) -> None: for c16 in self.chunk16_set: - grad_chunk = c16 if self.module.reuse_fp16_chunk else c16.grad_chunk + grad_chunk = c16 if self.module.chunk_manager.reuse_fp16_chunk else c16.grad_chunk grad_chunk.l2_norm = None def _calc_global_norm(self) -> float: norm_sqr: float = 0.0 group_to_norm = dict() for c16 in self.chunk16_set: - grad_chunk = c16 if self.module.reuse_fp16_chunk else c16.grad_chunk + grad_chunk = c16 if self.module.chunk_manager.reuse_fp16_chunk else c16.grad_chunk assert grad_chunk.l2_norm is not None if grad_chunk.is_gathered: @@ -275,7 +275,7 @@ class GeminiOptimizer(OptimizerWrapper): self._logger.info(f"Found overflow. Skip step") self._clear_global_norm() # clear recorded norm self.zero_grad() # reset all gradients - if self.module.reuse_fp16_chunk: + if self.module.chunk_manager.reuse_fp16_chunk: self._update_fp16_params() return @@ -288,7 +288,7 @@ class GeminiOptimizer(OptimizerWrapper): self.zero_grad() if self.module.master_weights: self._update_fp16_params() - self.module.accumulating_grads = False + self.module.chunk_manager.accumulating_grads = False return ret def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0): diff --git a/colossalai/zero/low_level/bookkeeping/base_store.py b/colossalai/zero/low_level/bookkeeping/base_store.py index 107d62dcb..7f2f9664b 100644 --- a/colossalai/zero/low_level/bookkeeping/base_store.py +++ b/colossalai/zero/low_level/bookkeeping/base_store.py @@ -6,6 +6,7 @@ class BaseStore: def __init__(self, torch_pg: ProcessGroup): self._world_size = dist.get_world_size(group=torch_pg) self._local_rank = dist.get_rank(group=torch_pg) + self.torch_pg = torch_pg @property def world_size(self): diff --git a/colossalai/zero/low_level/bookkeeping/bucket_store.py b/colossalai/zero/low_level/bookkeeping/bucket_store.py index 2ebc704f7..1496603fa 100644 --- a/colossalai/zero/low_level/bookkeeping/bucket_store.py +++ b/colossalai/zero/low_level/bookkeeping/bucket_store.py @@ -1,16 +1,43 @@ -from typing import Dict +from typing import Dict, Optional import torch +import torch.distributed as dist from torch import Tensor from torch._utils import _flatten_dense_tensors from torch.distributed import ProcessGroup +from colossalai.accelerator import get_accelerator + from .base_store import BaseStore class BucketStore(BaseStore): - def __init__(self, torch_pg: ProcessGroup): + def __init__( + self, + torch_pg: ProcessGroup, + reduce_bucket_size: int, + overlap_communication: bool, + communication_dtype: Optional[torch.dtype] = None, + moe_extra_dp_process_group: ProcessGroup = None, + ): super().__init__(torch_pg) + self.reduce_bucket_size = reduce_bucket_size + # communication params + self._overlap_communication = overlap_communication + self._communication_dtype = communication_dtype + if self._overlap_communication: + self.comm_stream = get_accelerator().Stream() + self.zero_local_rank = dist.get_rank(group=self.torch_pg) + self.zero_world_size = dist.get_world_size(group=self.torch_pg) + # extra dp + # This group is used to sync moe param, dp_world_size = moe_duplicates * extra_dp_size. + # Non moe param will be sync by global dp pg, moe param will be sync by extra dp pg. 
+ # Moe param grad is be split as non moe param by global dp pg, and grad will be merged in step. + # And moe working and master param are split by extra dp pg. + self.moe_extra_dp_pg = moe_extra_dp_process_group + if self.moe_extra_dp_pg is not None: + self.moe_extra_dp_pg_size = dist.get_world_size(group=self.moe_extra_dp_pg) + self.moe_extra_dp_pg_rank = dist.get_rank(group=self.moe_extra_dp_pg) self.reset_all() def reset_all(self) -> None: diff --git a/colossalai/zero/low_level/bookkeeping/gradient_store.py b/colossalai/zero/low_level/bookkeeping/gradient_store.py index 6d4fcbb86..fc28b7795 100644 --- a/colossalai/zero/low_level/bookkeeping/gradient_store.py +++ b/colossalai/zero/low_level/bookkeeping/gradient_store.py @@ -6,7 +6,7 @@ from .base_store import BaseStore class GradientStore(BaseStore): - def __init__(self, *args, partition_grad: bool = False): + def __init__(self, *args, partition_grad: bool = False, require_grad_sync: bool = True): super().__init__(*args) """ self._grads_of_params mapping the parameter and its gradient slices @@ -18,9 +18,12 @@ class GradientStore(BaseStore): } """ self._grads_of_params = dict() - # for zero2, it's `param_id: [grad_local_rank]` + # stage 2 + self._partition_grads = partition_grad + # grad accumulation + self.require_grad_sync = require_grad_sync self._working_index = 0 if partition_grad else self._local_rank - + # for zero2, it's `param_id: [grad_local_rank]` self.grad_to_param_mapping = dict() def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List: diff --git a/colossalai/zero/low_level/bookkeeping/parameter_store.py b/colossalai/zero/low_level/bookkeeping/parameter_store.py index e94fb4de9..c03231f5f 100644 --- a/colossalai/zero/low_level/bookkeeping/parameter_store.py +++ b/colossalai/zero/low_level/bookkeeping/parameter_store.py @@ -1,3 +1,5 @@ +from typing import Dict + from torch import Tensor from torch.distributed import ProcessGroup @@ -47,3 +49,12 @@ class ParameterStore(BaseStore): self.master_to_working_param[id(master_param)] = working_param self.working_to_master_param[id(working_param)] = master_param + + def get_padding_map(self) -> Dict[int, Tensor]: + """Return the padding map + + Returns: + Dict[int, Tensor]: The padding map + """ + + return self._padding_map diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 345dfde73..5f7f2a4e2 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -90,38 +90,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper): self._logger = get_dist_logger() self._verbose = verbose - # stage 2 - self._partition_grads = partition_grad - self._cpu_offload = cpu_offload - # grad accumulation - self.require_grad_sync = True - - # if process_group is none, will use the default one - self.dp_pg = dp_process_group - self._local_rank = dist.get_rank(group=self.dp_pg) - self._world_size = dist.get_world_size(group=self.dp_pg) - - # extra dp - # This group is used to sync moe param, dp_world_size = moe_duplicates * extra_dp_size. - # Non moe param will be sync by global dp pg, moe param will be sync by extra dp pg. - # Moe param grad is be split as non moe param by global dp pg, and grad will be merged in step. - # And moe working and master param are split by extra dp pg. 
- self.moe_extra_dp_pg = moe_extra_dp_process_group - if self.moe_extra_dp_pg is not None: - self.moe_extra_dp_pg_size = dist.get_world_size(group=self.moe_extra_dp_pg) - self.moe_extra_dp_pg_rank = dist.get_rank(group=self.moe_extra_dp_pg) - # working and master params for mixed precision training self._working_param_groups = dict() self._master_param_groups_of_current_rank = dict() - # communication params - self._overlap_communication = overlap_communication - self._reduce_bucket_size = reduce_bucket_size - self._communication_dtype = communication_dtype - # gradient clipping self._clip_grad_norm = clip_grad_norm @@ -140,9 +114,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # ParameterStore will manage the tensor buffers used for zero # it will not manage the tensors used by mixed precision training - self._param_store = ParameterStore(self.dp_pg) - self._grad_store = GradientStore(self.dp_pg, partition_grad=partition_grad) - self._bucket_store = BucketStore(self.dp_pg) + self._param_store = ParameterStore(dp_process_group) + self._grad_store = GradientStore(dp_process_group, partition_grad=partition_grad, require_grad_sync=True) + self._bucket_store = BucketStore( + dp_process_group, reduce_bucket_size, overlap_communication, communication_dtype, moe_extra_dp_process_group + ) # moe param should not be stored in working_groups # because they have different parallel strategy @@ -157,7 +133,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): group_params = list() for param in param_group["params"]: if param.requires_grad: - if self.moe_extra_dp_pg is None: + if self._bucket_store.moe_extra_dp_pg is None: # skip moe param if is_moe_tensor(param): self.working_moe_params.append(param) @@ -194,15 +170,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper): param_group["params"] = self.master_moe_params self.optim.param_groups.append(param_group) - # initialize communication stream for - # communication-computation overlapping - if self._overlap_communication: - self._comm_stream = get_accelerator().Stream() - # reduction hook is only used if overlapping communication # or stage 2 is used # if it is stage 1 without overlapping, no hook will be attached - if self._overlap_communication or self._partition_grads: + if self._bucket_store._overlap_communication or self._grad_store._partition_grads: self._attach_reduction_hook() # initialize mixed precision mixin @@ -222,6 +193,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper): elif self._dtype is torch.bfloat16: self.mixed_precision_mixin = BF16MixedPrecisionMixin() + def __del__(self): + self.remove_hooks() + @property def dtype(self): return self._dtype @@ -246,7 +220,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper): device = "cpu" if self._cpu_offload else get_accelerator().get_current_device() for param in param_list: - padding_size = (self._world_size - param.numel() % self._world_size) % self._world_size + padding_size = ( + self._bucket_store.zero_world_size - param.numel() % self._bucket_store.zero_world_size + ) % self._bucket_store.zero_world_size self._param_store.record_param_padding_size(param, padding_size) with torch.no_grad(): @@ -258,12 +234,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper): else: padding_param = param.data.view(-1) - if self.moe_extra_dp_pg is not None and is_moe_tensor(param): - splited_params = padding_param.split(padding_param.numel() // self.moe_extra_dp_pg_size) - splited_params = splited_params[self.moe_extra_dp_pg_rank] + if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(param): 
+ splited_params = padding_param.split( + padding_param.numel() // self._bucket_store.moe_extra_dp_pg_size + ) + splited_params = splited_params[self._bucket_store.moe_extra_dp_pg_rank] else: - splited_params = padding_param.split(padding_param.numel() // self._world_size) - splited_params = splited_params[self._local_rank] + splited_params = padding_param.split(padding_param.numel() // self._bucket_store.zero_world_size) + splited_params = splited_params[self._bucket_store.zero_local_rank] # use fp32 when master_weights is True if self._master_weights is True: @@ -271,6 +249,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): else: splited_param_current_rank = splited_params + # Send the splited view to the optimizer to match ZeRO 2 grad shape params_current_rank.append(splited_param_current_rank) self._param_store.link_master_and_working_param(splited_param_current_rank, param) @@ -280,10 +259,17 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # Backward Reduction Hook # ########################### - def _grad_handler(self, group_id, param): + @staticmethod + def grad_handler( + param: nn.Parameter, + group_id: int, + bucket_store: BucketStore, + param_store: ParameterStore, + grad_store: GradientStore, + ): # if run with no_sync context, would not sync grad when backward - if self.require_grad_sync: - self._add_to_bucket(param, group_id) + if grad_store.require_grad_sync: + LowLevelZeroOptimizer.add_to_bucket(param, group_id, bucket_store, param_store, grad_store) def _attach_reduction_hook(self): # we iterate over the working params @@ -292,29 +278,36 @@ class LowLevelZeroOptimizer(OptimizerWrapper): param_group = self._working_param_groups[group_id] for param in param_group: if param.requires_grad: - param.register_post_accumulate_grad_hook(partial(self._grad_handler, group_id)) + param._grad_handle = param.register_post_accumulate_grad_hook( + partial( + LowLevelZeroOptimizer.grad_handler, + group_id=group_id, + bucket_store=self._bucket_store, + param_store=self._param_store, + grad_store=self._grad_store, + ) + ) ####################### # Reduction Functions # ####################### - - def _run_reduction(self): - if self._bucket_store.num_elements_in_bucket() > 0: - self._bucket_store.build_grad_in_bucket() - - if self.moe_extra_dp_pg is None: - flat_grads = self._bucket_store.get_flatten_grad() - flat_grads /= self._world_size + @staticmethod + def run_reduction(bucket_store: BucketStore, grad_store: GradientStore): + if bucket_store.num_elements_in_bucket() > 0: + bucket_store.build_grad_in_bucket() + if bucket_store.moe_extra_dp_pg is None: + flat_grads = bucket_store.get_flatten_grad() + flat_grads /= bucket_store.zero_world_size else: # record moe and non moe param moe_list = [] - for param in self._bucket_store._param_list: + for param in bucket_store._param_list: moe_list.append(is_moe_tensor(param)) # divide them into different groups moe_grad_list = [] non_moe_grad_list = [] - for grad_list in self._bucket_store._grad_in_bucket.values(): + for grad_list in bucket_store._grad_in_bucket.values(): non_moe_cur_grad = [] moe_cur_grad = [] for i in range(len(grad_list)): @@ -332,7 +325,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): for grad_list in non_moe_grad_list: non_moe_flat_grads.append(_flatten_dense_tensors(grad_list)) non_moe_flat_grads = _flatten_dense_tensors(non_moe_flat_grads) - non_moe_flat_grads /= self._world_size + non_moe_flat_grads /= bucket_store.zero_world_size if len(moe_grad_list) > 0: moe_flat_grads = [] @@ -341,12 +334,12 @@ class 
LowLevelZeroOptimizer(OptimizerWrapper): moe_flat_grads = _flatten_dense_tensors(moe_flat_grads) # ready to add other tensors to bucket - self._bucket_store.reset_num_elements_in_bucket() + bucket_store.reset_num_elements_in_bucket() - if self._overlap_communication: - stream = self._comm_stream + if bucket_store._overlap_communication: + stream = bucket_store.comm_stream # in case of the memory being reused in the default stream - if self.moe_extra_dp_pg is None: + if bucket_store.moe_extra_dp_pg is None: flat_grads.record_stream(stream) else: if len(non_moe_grad_list) > 0: @@ -359,53 +352,63 @@ class LowLevelZeroOptimizer(OptimizerWrapper): stream = get_accelerator().current_stream() with get_accelerator().stream(stream): - group_id = self._bucket_store.current_group_id + group_id = bucket_store.current_group_id - if self.moe_extra_dp_pg is None: + if bucket_store.moe_extra_dp_pg is None: grad_dtype = flat_grads.dtype - if self._communication_dtype is not None: - flat_grads = flat_grads.to(self._communication_dtype) + if bucket_store._communication_dtype is not None: + flat_grads = flat_grads.to(bucket_store._communication_dtype) - if not self._partition_grads: - if self.moe_extra_dp_pg is None: - dist.all_reduce(flat_grads, group=self.dp_pg) + if not grad_store._partition_grads: + if bucket_store.moe_extra_dp_pg is None: + dist.all_reduce(flat_grads, group=bucket_store.torch_pg) if flat_grads.dtype != grad_dtype: flat_grads = flat_grads.to(grad_dtype) - flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size) - grad_in_bucket = self._bucket_store.get_grad() - self._update_unpartitoned_grad(grad_in_bucket.values(), flat_grads_per_rank, group_id) + flat_grads_per_rank = flat_grads.split(flat_grads.numel() // bucket_store.zero_world_size) + grad_in_bucket = bucket_store.get_grad() + LowLevelZeroOptimizer.update_unpartitoned_grad( + bucket_store, grad_store, grad_in_bucket.values(), flat_grads_per_rank, group_id + ) # sync extra zero group else: # sync non moe param in global dp group if len(non_moe_grad_list) > 0: - dist.all_reduce(non_moe_flat_grads, group=self.dp_pg) + dist.all_reduce(non_moe_flat_grads, group=bucket_store.torch_pg) flat_grads_per_rank = non_moe_flat_grads.split( - non_moe_flat_grads.numel() // self._world_size + non_moe_flat_grads.numel() // bucket_store.zero_world_size + ) + LowLevelZeroOptimizer.update_unpartitoned_grad( + bucket_store, grad_store, non_moe_grad_list, flat_grads_per_rank, group_id ) - self._update_unpartitoned_grad(non_moe_grad_list, flat_grads_per_rank, group_id) # sync moe param only in zero group if len(moe_grad_list) > 0: - dist.all_reduce(moe_flat_grads, group=self.moe_extra_dp_pg) - flat_grads_per_rank = moe_flat_grads.split(moe_flat_grads.numel() // self._world_size) - self._update_unpartitoned_grad(moe_grad_list, flat_grads_per_rank, group_id) + dist.all_reduce(moe_flat_grads, group=bucket_store.moe_extra_dp_pg) + flat_grads_per_rank = moe_flat_grads.split( + moe_flat_grads.numel() // bucket_store.zero_world_size + ) + LowLevelZeroOptimizer.update_unpartitoned_grad( + bucket_store, grad_store, moe_grad_list, flat_grads_per_rank, group_id + ) else: - if self.moe_extra_dp_pg is None: - flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size)) - recieved_grad = torch.zeros_like(flat_grads_list[0]) - dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg) + if bucket_store.moe_extra_dp_pg is None: + flat_grads_list = list(flat_grads.split(len(flat_grads) // bucket_store.zero_world_size)) + 
received_grad = torch.zeros_like(flat_grads_list[0]) + dist.reduce_scatter(received_grad, flat_grads_list, group=bucket_store.torch_pg) - if recieved_grad.dtype != grad_dtype: - recieved_grad = recieved_grad.to(grad_dtype) + if received_grad.dtype != grad_dtype: + received_grad = received_grad.to(grad_dtype) - grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank] - self._update_partitoned_grad(grad_in_bucket_current_rank, recieved_grad, group_id, 1) + grad_in_bucket_current_rank = bucket_store.get_grad()[bucket_store.zero_local_rank] + LowLevelZeroOptimizer.update_partitoned_grad( + bucket_store, grad_store, grad_in_bucket_current_rank, received_grad, group_id, 1 + ) else: # categorize moe and non moe param - grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank] + grad_in_bucket_current_rank = bucket_store.get_grad()[bucket_store.zero_local_rank] moe_grad_in_bucket_current_rank = [] non_moe_grad_in_bucket_current_rank = [] for idx, grad in enumerate(grad_in_bucket_current_rank): @@ -416,48 +419,61 @@ class LowLevelZeroOptimizer(OptimizerWrapper): if len(non_moe_grad_list) > 0: flat_grads_list = list( - non_moe_flat_grads.split(len(non_moe_flat_grads) // self._world_size) + non_moe_flat_grads.split(len(non_moe_flat_grads) // bucket_store.zero_world_size) ) - recieved_grad = torch.zeros_like(flat_grads_list[0]) - dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg) - self._update_partitoned_grad( + received_grad = torch.zeros_like(flat_grads_list[0]) + dist.reduce_scatter(received_grad, flat_grads_list, group=bucket_store.torch_pg) + LowLevelZeroOptimizer.update_partitoned_grad( + bucket_store, + grad_store, non_moe_grad_in_bucket_current_rank, - recieved_grad, + received_grad, group_id, 1, ) if len(moe_grad_list) > 0: flat_grads_list = list( - moe_flat_grads.split(len(moe_flat_grads) // self.moe_extra_dp_pg_size) + moe_flat_grads.split(len(moe_flat_grads) // bucket_store.moe_extra_dp_pg_size) ) - recieved_grad = torch.zeros_like(flat_grads_list[0]) + received_grad = torch.zeros_like(flat_grads_list[0]) dist.reduce_scatter( - recieved_grad, + received_grad, flat_grads_list, - group=self.moe_extra_dp_pg, + group=bucket_store.moe_extra_dp_pg, ) - param_slice = self._world_size // self.moe_extra_dp_pg_size - recieved_grad = list(recieved_grad.split(len(recieved_grad) // param_slice)) - for split_recieved_grad in recieved_grad: + param_slice = bucket_store.zero_world_size // bucket_store.moe_extra_dp_pg_size + received_grad = list(received_grad.split(len(received_grad) // param_slice)) + for split_recieved_grad in received_grad: split_recieved_grad = _unflatten_dense_tensors( split_recieved_grad, moe_grad_in_bucket_current_rank ) for real_grad, grad in zip(split_recieved_grad, moe_grad_in_bucket_current_rank): - param_id = self._bucket_store.get_param_id_of_grad(grad) - self._add_grad(real_grad, param_slice, group_id, param_id) + param_id = bucket_store.get_param_id_of_grad(grad) + LowLevelZeroOptimizer.add_grad( + grad_store, real_grad, param_slice, group_id, param_id + ) - self._bucket_store.reset() + bucket_store.reset() - def _update_unpartitoned_grad(self, origin_grad_list: List, flat_grad_list: List, group_id: int) -> None: + @staticmethod + def update_unpartitoned_grad( + bucket_store: BucketStore, + grad_store: GradientStore, + origin_grad_list: List, + flat_grad_list: List, + group_id: int, + ) -> None: for rank, grad_list in enumerate(origin_grad_list): sync_tensor(flat_grad_list[rank], grad_list) for grad in grad_list: - 
param_id = self._bucket_store.get_param_id_of_grad(grad) - self._add_grad(grad, self._world_size, group_id, param_id, rank) + param_id = bucket_store.get_param_id_of_grad(grad) + LowLevelZeroOptimizer.add_grad(grad_store, grad, bucket_store.zero_world_size, group_id, param_id, rank) - def _update_partitoned_grad( - self, + @staticmethod + def update_partitoned_grad( + bucket_store: BucketStore, + grad_store: GradientStore, origin_grad_list: List, flat_grad: torch.Tensor, group_id: int, @@ -465,23 +481,31 @@ class LowLevelZeroOptimizer(OptimizerWrapper): ) -> None: sync_tensor(flat_grad, origin_grad_list) for grad in origin_grad_list: - param_id = self._bucket_store.get_param_id_of_grad(grad) - self._add_grad(grad, partition_num, group_id, param_id) + param_id = bucket_store.get_param_id_of_grad(grad) + LowLevelZeroOptimizer.add_grad(grad_store, grad, partition_num, group_id, param_id) - def _add_grad( - self, + @staticmethod + def add_grad( + grad_store: GradientStore, grad: torch.Tensor, partition_num: int, group_id: int, param_id: int, rank: int = 0, ) -> None: - if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < partition_num: - self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) + if len(grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < partition_num: + grad_store.append_gradients_by_param_id(grad, group_id, param_id) else: - self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id) + grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id) - def _add_to_bucket(self, param, group_id): + @staticmethod + def add_to_bucket( + param: nn.Parameter, + group_id: int, + bucket_store: BucketStore, + param_store: ParameterStore, + grad_store: GradientStore, + ): param_size = param.numel() # check if the bucket is full @@ -489,13 +513,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # or got a grad of param from another group # after reduction, the bucket will be empty if ( - self._bucket_store.num_elements_in_bucket() + param_size > self._reduce_bucket_size - or group_id != self._bucket_store.current_group_id + bucket_store.num_elements_in_bucket() + param_size > bucket_store.reduce_bucket_size + or group_id != bucket_store.current_group_id ): - self._run_reduction() + LowLevelZeroOptimizer.run_reduction(bucket_store, grad_store) - padding_size = self._param_store.get_param_padding_size(param) - self._bucket_store.add_param_grad(group_id, param, padding_size) + padding_size = param_store.get_param_padding_size(param) + bucket_store.add_param_grad(group_id, param, padding_size) ################################ # torch.optim.Optimizer methods @@ -503,7 +527,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): def backward(self, loss, retain_graph=False): assert not ( - self._partition_grads and not self.require_grad_sync + self._grad_store._partition_grads and not self._grad_store.require_grad_sync ), "ZeRO2(partition_grads) and no_sync are not compatible" if self.mixed_precision_mixin is not None: @@ -511,31 +535,31 @@ class LowLevelZeroOptimizer(OptimizerWrapper): loss.backward(retain_graph=retain_graph) - if not self.require_grad_sync: + if not self._grad_store.require_grad_sync: return - self._reduce_grad(self._partition_grads) + self._reduce_grad(self._grad_store._partition_grads) # clear reduced grads - if self._overlap_communication: + if self._bucket_store._overlap_communication: get_accelerator().synchronize() self.zero_grad() def backward_by_grad(self, tensor, grad): assert not ( - 
self._partition_grads and not self.require_grad_sync + self._grad_store._partition_grads and not self._grad_store.require_grad_sync ), "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible" if self.mixed_precision_mixin is not None: grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad) torch.autograd.backward(tensor, grad) - if not self.require_grad_sync: + if not self._grad_store.require_grad_sync: return - self._reduce_grad(self._partition_grads) + self._reduce_grad(self._grad_store._partition_grads) # clear reduced grads - if self._overlap_communication: + if self._bucket_store._overlap_communication: get_accelerator().synchronize() self.zero_grad() @@ -566,7 +590,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): def step(self, closure=None): assert closure is None, "closure is not supported by step()" - if not self.require_grad_sync: + if not self._grad_store.require_grad_sync: return if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step(): @@ -585,7 +609,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # and should not be updated real_working_params = dict() real_master_params = dict() - grad_index = 0 if self._partition_grads else self._local_rank + grad_index = 0 if self._grad_store._partition_grads else self._bucket_store.zero_local_rank for group_id in range(self.num_param_groups): master_params = self._master_param_groups_of_current_rank[group_id] real_working_params[group_id] = [] @@ -598,14 +622,16 @@ class LowLevelZeroOptimizer(OptimizerWrapper): grads = self._grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param)) if len(grads) > 0: # moe hybrid zero - if self.moe_extra_dp_pg is not None and is_moe_tensor(working_param): + if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(working_param): real_working_params[group_id].append(working_param) - if self._partition_grads: + if self._grad_store._partition_grads: grad = grads else: - param_slice = self._world_size // self.moe_extra_dp_pg_size + param_slice = self._bucket_store.zero_world_size // self._bucket_store.moe_extra_dp_pg_size grad = grads[ - self.moe_extra_dp_pg_rank * param_slice : (self.moe_extra_dp_pg_rank + 1) * param_slice + self._bucket_store.moe_extra_dp_pg_rank + * param_slice : (self._bucket_store.moe_extra_dp_pg_rank + 1) + * param_slice ] grad = flatten(grad) else: @@ -674,25 +700,25 @@ class LowLevelZeroOptimizer(OptimizerWrapper): master_working_param = self.optim.param_groups[group_id]["params"] for idx, splited_param in enumerate(master_working_param): working_param = real_working_params[group_id][idx] - if self.moe_extra_dp_pg is not None and is_moe_tensor(working_param): + if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(working_param): all_splited_param = [ torch.zeros(splited_param.shape, device=device, dtype=self._dtype) - for _ in range(self.moe_extra_dp_pg_size) + for _ in range(self._bucket_store.moe_extra_dp_pg_size) ] dist.all_gather( all_splited_param, splited_param.to(device).to(self._dtype), - group=self.moe_extra_dp_pg, + group=self._bucket_store.moe_extra_dp_pg, ) else: all_splited_param = [ torch.zeros(splited_param.shape, device=device, dtype=self._dtype) - for _ in range(self._world_size) + for _ in range(self._bucket_store.zero_world_size) ] dist.all_gather( all_splited_param, splited_param.to(device).to(self._dtype), - group=self.dp_pg, + group=self._bucket_store.torch_pg, ) working_param.data.copy_(flatten(all_splited_param)[: 
working_param.numel()].reshape_as(working_param)) self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id] @@ -720,7 +746,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): device=get_accelerator().get_current_device(), dtype=torch.float, ) - dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg) + dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self._bucket_store.torch_pg) total_norm = total_norm_cuda.item() else: @@ -738,7 +764,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): torch.distributed.all_reduce( total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, - group=self.dp_pg, + group=self._bucket_store.torch_pg, ) total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type) @@ -773,27 +799,33 @@ class LowLevelZeroOptimizer(OptimizerWrapper): param_group = self._working_param_groups[group_id] for param in param_group: if param.requires_grad and param.grad is not None: - self._add_to_bucket(param, group_id) + LowLevelZeroOptimizer.add_to_bucket( + param, + group_id, + self._bucket_store, + self._param_store, + self._grad_store, + ) - self._run_reduction() + LowLevelZeroOptimizer.run_reduction(self._bucket_store, self._grad_store) def _reduce_grad(self, partition_grad): # if not overlapping communication (no reduction hook is attached) when zero1 # we need to manually reduce these gradients - if not partition_grad and not self._overlap_communication: + if not partition_grad and not self._bucket_store._overlap_communication: self._sync_grad() else: - self._run_reduction() + LowLevelZeroOptimizer.run_reduction(self._bucket_store, self._grad_store) # this context comes from pytorch DDP @contextmanager def no_sync(self): - old_require_grad_sync = self.require_grad_sync - self.require_grad_sync = False + old_require_grad_sync = self._grad_store.require_grad_sync + self._grad_store.require_grad_sync = False try: yield finally: - self.require_grad_sync = old_require_grad_sync + self._grad_store.require_grad_sync = old_require_grad_sync ############## # State Dict # @@ -833,16 +865,18 @@ class LowLevelZeroOptimizer(OptimizerWrapper): for k, v in state.items(): if isinstance(v, torch.Tensor) and k != "step": working_param = self._param_store.master_to_working_param[id(param)] - if self.moe_extra_dp_pg is not None and is_moe_tensor(v): + if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(v): gather_tensor = [ - torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size) + torch.zeros(v.shape, device=device, dtype=v.dtype) + for _ in range(self._bucket_store.moe_extra_dp_pg_size) ] - dist.all_gather(gather_tensor, v.to(device), group=self.moe_extra_dp_pg) + dist.all_gather(gather_tensor, v.to(device), group=self._bucket_store.moe_extra_dp_pg) else: gather_tensor = [ - torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self._world_size) + torch.zeros(v.shape, device=device, dtype=v.dtype) + for _ in range(self._bucket_store.zero_world_size) ] - dist.all_gather(gather_tensor, v.to(device), group=self.dp_pg) + dist.all_gather(gather_tensor, v.to(device), group=self._bucket_store.torch_pg) param_state = ( torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu() ) @@ -862,17 +896,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper): for param_idx, state in zero_state_dict["state"].items(): for k, v in state.items(): if isinstance(v, torch.Tensor) and k != "step": - padding_size = 
(self._world_size - v.numel() % self._world_size) % self._world_size + padding_size = ( + self._bucket_store.zero_world_size - v.numel() % self._bucket_store.zero_world_size + ) % self._bucket_store.zero_world_size with torch.no_grad(): v = v.flatten() if padding_size > 0: v = torch.nn.functional.pad(v, [0, padding_size]) - if self.moe_extra_dp_pg is not None and is_moe_tensor(v): - v_list = v.split(v.numel() // self.moe_extra_dp_pg_size) - zero_state_dict["state"][param_idx][k] = v_list[self.moe_extra_dp_pg_rank].detach().clone() + if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(v): + v_list = v.split(v.numel() // self._bucket_store.moe_extra_dp_pg_size) + zero_state_dict["state"][param_idx][k] = ( + v_list[self._bucket_store.moe_extra_dp_pg_rank].detach().clone() + ) else: - v_list = v.split(v.numel() // self._world_size) - zero_state_dict["state"][param_idx][k] = v_list[self._local_rank].detach().clone() + v_list = v.split(v.numel() // self._bucket_store.zero_world_size) + zero_state_dict["state"][param_idx][k] = ( + v_list[self._bucket_store.zero_local_rank].detach().clone() + ) self.optim.load_state_dict(zero_state_dict) @@ -904,16 +944,18 @@ class LowLevelZeroOptimizer(OptimizerWrapper): for k, v in states.items(): if isinstance(v, torch.Tensor) and k != "step": - if self.moe_extra_dp_pg is not None and is_moe_tensor(v): + if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(v): state_tensor = [ - torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size) + torch.zeros(v.shape, device=device, dtype=v.dtype) + for _ in range(self._bucket_store.moe_extra_dp_pg_size) ] - dist.all_gather(state_tensor, v.to(device), group=self.moe_extra_dp_pg) + dist.all_gather(state_tensor, v.to(device), group=self._bucket_store.moe_extra_dp_pg) else: state_tensor = [ - torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self._world_size) + torch.zeros(v.shape, device=device, dtype=v.dtype) + for _ in range(self._bucket_store.zero_world_size) ] - dist.all_gather(state_tensor, v.to(device), group=self.dp_pg) + dist.all_gather(state_tensor, v.to(device), group=self._bucket_store.torch_pg) state_tensor = ( torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu() ) @@ -944,14 +986,30 @@ class LowLevelZeroOptimizer(OptimizerWrapper): working_param = p.data.view(-1) if padding_size > 0: working_param = torch.nn.functional.pad(working_param, [0, padding_size]) - if self.moe_extra_dp_pg is not None and is_moe_tensor(p): + if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(p): master_param.copy_(working_param.chunk(self.extra_dp_pg_size)[self.extra_dp_pg_rank]) else: - master_param.copy_(working_param.chunk(self._world_size)[self._local_rank]) + master_param.copy_( + working_param.chunk(self._bucket_store.zero_world_size)[self._bucket_store.zero_local_rank] + ) if hasattr(self, "master_moe_params"): for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params): master_moe_param.copy_(working_moe_param) + def remove_hooks(self) -> None: + """remove the registered hooks + + Args: + plugin (LowLevelZeroPlugin): the plugin to bound this method. 
+        """
+        for group_id in range(self.num_param_groups):
+            param_group = self._working_param_groups[group_id]
+            for param in param_group:
+                if param.requires_grad:
+                    assert hasattr(param, "_grad_handle")
+                    param._grad_handle.remove()
+                    delattr(param, "_grad_handle")
+
     def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
         return self._param_store.working_to_master_param
@@ -962,3 +1020,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 **self.moe_master_to_working_map,
             }
         return self._param_store.master_to_working_param
+
+    def get_param_padding_map(self) -> Dict[int, torch.Tensor]:
+        return self._param_store.get_padding_map()
diff --git a/docs/README-zh-Hans.md b/docs/README-zh-Hans.md
index 2e5437752..41110612c 100644
--- a/docs/README-zh-Hans.md
+++ b/docs/README-zh-Hans.md
@@ -413,7 +413,7 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的
 
 环境要求:
 
-- PyTorch >= 1.11 并且 PyTorch <= 2.1
+- PyTorch >= 2.1
 - Python >= 3.7
 - CUDA >= 11.0
 - [NVIDIA GPU Compute Capability](https://developer.nvidia.com/cuda-gpus) >= 7.0 (V100/RTX20 and higher)
diff --git a/docs/source/en/features/distributed_optimizers.md b/docs/source/en/features/distributed_optimizers.md
new file mode 100644
index 000000000..7590669df
--- /dev/null
+++ b/docs/source/en/features/distributed_optimizers.md
@@ -0,0 +1,141 @@
+# Distributed Optimizers
+
+Authors: [Wenxuan Tan](https://github.com/Edenzzzz), [Junwen Duan](https://github.com/duanjunwen), [Renjie Mao](https://github.com/chongqichuizi875)
+
+**Related Papers**
+- [Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235)
+- [CAME: Confidence-guided Adaptive Memory Efficient Optimization](https://arxiv.org/abs/2307.02047)
+- [GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection](https://arxiv.org/abs/2403.03507)
+- [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes](https://arxiv.org/pdf/1904.00962)
+
+## Introduction
+Apart from the widely adopted Adam and SGD, many modern optimizers require layer-wise statistics to update parameters efficiently, and are thus not directly applicable to parallel settings where model layers are sharded across multiple devices. We provide optimized distributed implementations with minimal extra communication, and seamless integration with Tensor Parallel, DDP and ZeRO via plugins.
+
+## Optimizers
+Adafactor is a first-order Adam variant that uses Non-negative Matrix Factorization (NMF) to reduce its memory footprint. CAME improves on it by introducing a confidence matrix to correct NMF. GaLore further reduces memory by projecting gradients into a low-rank space and applying 8-bit block-wise quantization. Lamb enables huge batch sizes without losing accuracy via a layer-wise adaptive update bounded by the inverse of its Lipschitz constant.
+
+## API Reference
+
+{{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }}
+{{ autodoc:colossalai.nn.optimizer.distributed_lamb.DistributedLamb }}
+{{ autodoc:colossalai.nn.optimizer.distributed_galore.DistGaloreAwamW }}
+{{ autodoc:colossalai.nn.optimizer.distributed_came.DistributedCAME }}
+
+## Hands-On Practice
+We now demonstrate how to use Distributed Adafactor with the booster API, combining Tensor Parallel and ZeRO 2 on 4 GPUs.
+
+### Step 1. Import Libraries
+
+```python
+from transformers import LlamaModel, LlamaConfig
+from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor
+from colossalai.booster import Booster
+from colossalai.booster.plugin import HybridParallelPlugin
+import colossalai
+import torch
+```
+
+### Step 2. Initialize Distributed Environment and Parallelism Group
+We first need to initialize the distributed environment. For demonstration purposes, we launch with `colossal run --nproc_per_node 4`. You can refer to [Launch Colossal-AI](../basics/launch_colossalai.md) for other launch methods.
+
+```python
+colossalai.launch_from_torch()
+```
+
+### Step 3. Initialize Module and Optimizer
+Build the model and optimizer. Here we initialize a LLaMA model from Hugging Face and wrap its parameters with Distributed Adafactor.
+
+```python
+# Init Llama from huggingface
+configuration = LlamaConfig()
+model = LlamaModel(configuration).cuda()
+criterion = lambda x: x.mean()
+dist_optim = DistributedAdaFactor(model.parameters())
+```
+
+### Step 4. Init Booster
+
+```python
+plugin = HybridParallelPlugin(tp_size=2, zero_stage=2, pp_size=1, enable_all_optimization=True)
+booster = Booster(plugin=plugin)
+# You should also pass in your own dataloader.
+model, dist_optim, criterion, dataloader, _ = booster.boost(model, dist_optim, criterion)
+```
+
+### Step 5. Train Your Model
+```python
+steps = 10
+for step in range(steps):
+    input_ids = torch.ones(1, 100, device="cuda", dtype=torch.int)
+    attention_mask = input_ids.clone()
+    outputs = model(input_ids.cuda(), attention_mask.cuda())
+    loss = criterion(outputs.last_hidden_state)
+    booster.backward(loss, dist_optim)
+    dist_optim.step()
+    dist_optim.zero_grad()
+```
+
+### GaLore Special Handling
+For GaLore, we need to specify a projection rank for each parameter group, along with the quantization and paged-optimizer parameters. Please refer to bitsandbytes for quantization details. Support for ZeRO is underway.
+```python
+from colossalai.nn.optimizer.galore import get_galore_param_groups
+from colossalai.nn.optimizer import DistGaloreAwamW
+optim = DistGaloreAwamW(
+    get_galore_param_groups(model, decay=1e-2, rank=8),
+    lr=lr,
+    betas=(beta1, beta2),
+    eps=eps,
+    nbits=8,
+    percentile_clipping=100,
+    block_wise=True,
+    min_8bit_size=4096,
+)
+```
+
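+A minimal sketch of boosting the GaLore optimizer defined above is shown below. It reuses the `model` and `criterion` from the previous steps, and the parallel sizes are only illustrative; because ZeRO support for GaLore is still underway, the plugin keeps `zero_stage=0` and relies on tensor parallelism only.
+
+```python
+from colossalai.booster import Booster
+from colossalai.booster.plugin import HybridParallelPlugin
+
+# GaLore currently runs without ZeRO, so keep zero_stage=0 and shard with tensor parallelism.
+plugin = HybridParallelPlugin(tp_size=2, zero_stage=0, pp_size=1)
+booster = Booster(plugin=plugin)
+# `optim` is the DistGaloreAwamW instance created above; pass your own dataloader if needed.
+model, optim, criterion, _, _ = booster.boost(model, optim, criterion)
+```
+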
+## Plugin compatibility
+
+| Model/Feature          | Lamb | GaLore | Adafactor | CAME |
+| ---------------------- | ---- | ------ | --------- | ---- |
+| Hybrid Parallel Plugin | ✔️   | ✔️     | ✔️        | ✔️   |
+| Low Level Zero Plugin  | ✔️   | ❌     | ✔️        | ✔️   |
+| Torch DDP Plugin       | ✔️   | ✔️     | ✔️        | ✔️   |
+| Gemini Plugin          | ❌   | ❌     | ❌        | ❌   |
+| Moe Hybrid Plugin      | ❌   | ❌     | ❌        | ❌   |
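+
+As the table shows, Lamb is also supported by the Low Level Zero Plugin. The snippet below is a hedged sketch of that combination, reusing the LLaMA `model` and `criterion` from the hands-on steps above; the `LowLevelZeroPlugin` arguments and the learning rate are illustrative assumptions rather than recommended settings.
+
+```python
+from colossalai.booster import Booster
+from colossalai.booster.plugin import LowLevelZeroPlugin
+from colossalai.nn.optimizer.distributed_lamb import DistributedLamb
+
+# Distributed Lamb over the model parameters (example learning rate).
+optim = DistributedLamb(model.parameters(), lr=1e-3)
+
+# ZeRO stage 2 with bf16 mixed precision (illustrative plugin settings).
+plugin = LowLevelZeroPlugin(stage=2, precision="bf16")
+booster = Booster(plugin=plugin)
+model, optim, criterion, _, _ = booster.boost(model, optim, criterion)
+```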