Merge pull request #5737 from yuanheng-zhao/inference/sync/main

[sync] Sync feature/colossal-infer with main
pull/5743/head
Yuanheng Zhao 2024-05-21 11:26:37 +08:00 committed by GitHub
commit c06208e72c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
61 changed files with 6976 additions and 276 deletions

View File

@ -8,6 +8,13 @@ body:
attributes:
value: >
#### Not suitable for your needs? [Open a blank issue](https://github.com/hpcaitech/ColossalAI/issues/new).
- type: checkboxes
attributes:
label: Is there an existing issue for this bug?
description: Please search [here](https://github.com/hpcaitech/ColossalAI/issues) to see if an open or closed issue already exists for the bug you have encountered.
options:
- label: I have searched the existing issues
required: true
- type: textarea
attributes:
label: 🐛 Describe the bug

View File

@ -140,7 +140,7 @@ jobs:
- name: Install Colossal-AI
run: |
BUILD_EXT=1 pip install -v -e .
pip install -v -e .
pip install -r requirements/requirements-test.txt
- name: Store Colossal-AI Cache

View File

@ -418,7 +418,7 @@ Please visit our [documentation](https://www.colossalai.org/) and [examples](htt
## Installation
Requirements:
- PyTorch >= 1.11 and PyTorch <= 2.1
- PyTorch >= 2.1
- Python >= 3.7
- CUDA >= 11.0
- [NVIDIA GPU Compute Capability](https://developer.nvidia.com/cuda-gpus) >= 7.0 (V100/RTX20 and higher)

View File

@ -10,7 +10,7 @@ import math
import os
from multiprocessing import cpu_count
from colossal_llama.dataset.conversation import default_conversation
from colossal_llama.dataset.conversation import LLaMA2_Conv
from colossal_llama.dataset.spliced_and_tokenized_dataset import supervised_tokenize_sft
from datasets import dataset_dict, load_dataset
from transformers import AddedToken, AutoTokenizer
@ -78,6 +78,7 @@ def main():
# Fix </s> split issue: https://github.com/huggingface/transformers/issues/23833
if args.llama_version == 2:
tokenizer.add_tokens(AddedToken("</s>", normalized=False, special=True), special_tokens=True)
default_conversation = LLaMA2_Conv
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False

View File

@ -1,7 +1,9 @@
import ctypes
import random
import warnings
from collections import defaultdict
from contextlib import contextmanager
from copy import deepcopy
from functools import partial
from types import MethodType
from typing import Any, Callable, Dict, Iterator, List, Optional, OrderedDict, Tuple, Union
@ -24,6 +26,8 @@ from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOpt
from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptim
from colossalai.nn.optimizer import DistGaloreAwamW
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
@ -735,7 +739,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
# Get all working gradients and gradients to be synchronized.
all_working_grads = _get_all_working_grads()
grads_to_sync = _get_grads_to_sync(all_working_grads)
if self.require_grad_sync and grads_to_sync is not None:
if self._grad_store.require_grad_sync and grads_to_sync is not None:
# Synchronize sequence parallelism gradients if required.
SeqParallelUtils.allreduce_partial_data_grad(process_group=self.tp_pg, grads=grads_to_sync)
else:
@ -759,7 +763,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
# Call the superclass backward method to compute gradients.
super().backward(loss, retain_graph)
if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
if self._grad_store.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
# If gradient synchronization is required, sync sequence parallelism gradients.
self._sync_sp_grads()
else:
@ -784,7 +788,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
# Call the superclass backward_by_grad method to compute gradients.
super().backward_by_grad(tensor, grad)
if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
if self._grad_store.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
# If gradient synchronization is required, sync sequence parallelism gradients.
self._sync_sp_grads()
else:
@ -1171,6 +1175,15 @@ class HybridParallelPlugin(PipelinePluginBase):
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
param_info = get_param_info(optimizer)
# TODO: Support Galore + ZeRO
zero_stage = self.zero_stage
zero_config = deepcopy(self.zero_config)
if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0:
warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.")
zero_config["partition_grad"] = False
zero_stage = 0
if not isinstance(model, ModelWrapper):
use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
self.dp_size == 1
@ -1194,7 +1207,8 @@ class HybridParallelPlugin(PipelinePluginBase):
custom_policy=self.custom_policy,
)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if self.zero_stage == 0:
if zero_stage == 0:
is_zero = False
if self.precision in ["fp16", "bf16"]:
optimizer = HybridParallelAMPOptimizer(
optimizer,
@ -1218,11 +1232,11 @@ class HybridParallelPlugin(PipelinePluginBase):
tp_process_group=self.tp_group,
)
else:
zero_dp_size = dist.get_world_size(dp_group)
if zero_dp_size == 1:
is_zero = self.dp_size > 1
if self.dp_size == 1:
warnings.warn(
"Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
"If you are not intended to use cpu_offload, please consider set zero_stage=0."
"If you do not intend to use cpu_offload, please consider set zero_stage=0."
)
assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
@ -1236,11 +1250,19 @@ class HybridParallelPlugin(PipelinePluginBase):
pp_process_group=self.pp_group,
verbose=True,
clip_grad_norm=self.max_norm,
**self.zero_config,
**zero_config,
**self.amp_config,
)
# inject update_master_params
model.update_master_params = MethodType(optimizer.update_master_params, model)
# Setup optimizers that require global states
optim = optimizer.optim
if isinstance(optim, DistributedOptim):
shard_to_param = optimizer.get_master_to_working_map() if is_zero else {}
padding_map = optimizer.get_param_padding_map() if is_zero else defaultdict(int)
optim.setup_distributed(self.tp_group, self.dp_group, shard_to_param, padding_map, is_zero)
return model, optimizer, criterion, dataloader, lr_scheduler
def execute_pipeline(
@ -1272,7 +1294,7 @@ class HybridParallelPlugin(PipelinePluginBase):
# run with gradients accumulation
if model.require_grad_sync == False or (
isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer._grad_store.require_grad_sync == False
):
return outputs

View File

@ -8,7 +8,10 @@ from types import MethodType
from typing import Callable, Dict, Iterator, List, Optional, Tuple
import torch
import torch.distributed
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.distributed_c10d import _get_default_group
from torch.nn import Parameter
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
@ -28,6 +31,8 @@ from colossalai.checkpoint_io.utils import (
sharded_optimizer_loading_epilogue,
)
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptim
from colossalai.nn.optimizer import DistGaloreAwamW
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.zero import LowLevelZeroOptimizer
@ -428,13 +433,31 @@ class LowLevelZeroPlugin(DPPluginBase):
if not isinstance(model, ModelWrapper):
model = LowLevelZeroModel(model, self.precision)
# TODO: Support Galore + ZeRO
zero_stage = self.stage
zero_optim_kwargs = {**self.zero_optim_kwargs}
dp_size = dist.get_world_size()
if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and dp_size > 0:
warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.")
zero_optim_kwargs["partition_grad"] = False
zero_stage = 0
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(
optimizer, **self.zero_optim_kwargs, verbose=self.verbose
optimizer, **zero_optim_kwargs, verbose=self.verbose
)
# inject update_master_params
model.update_master_params = MethodType(optimizer.update_master_params, model)
# Setup optimizers that require global states
optim = optimizer.optim
is_zero = dp_size > 1 and zero_stage > 0
dp_group = _get_default_group() # Use the whole world
if isinstance(optim, DistributedOptim):
shard_to_param = optimizer.get_master_to_working_map()
padding_map = optimizer.get_param_padding_map()
optim.setup_distributed(None, dp_group, shard_to_param, padding_map, is_zero)
return model, optimizer, criterion, dataloader, lr_scheduler
def control_checkpoint_io(self) -> bool:

View File

@ -38,7 +38,12 @@ class ProcessGroupMesh:
def __init__(self, *size: int) -> None:
assert dist.is_initialized(), "Please initialize torch.distributed first."
assert prod(size) == dist.get_world_size(), "The product of the size must be equal to the world size."
world_size = dist.get_world_size()
prod_size = prod(size)
assert (
prod_size == world_size
), f"The product of the size({prod_size}) must be equal to the world size({world_size})."
self._shape = size
self._rank = dist.get_rank()
self._coord = ProcessGroupMesh.unravel(self._rank, self._shape)

View File

@ -306,9 +306,8 @@ class DeviceMesh:
# index means the local rank in the current axis
# inner_tensor refers to the processes with the same local rank
if inner_tensor.numel() == 1:
# if the inner_tensor only has one element, it means that
# it already reaches the last axis
if inner_tensor.dim() == 0:
# if the inner_tensor already reaches the last axis,
# we append its local_rank in the last axis to the index_list
# and assign to the mapping
# the value of the mapping is the the local rank at the indexed axis of the device mesh
@ -459,6 +458,7 @@ class DeviceMesh:
# replace the local rank in the given dimension with the
# local rank of the current process iterated
process_coordinates[dim] = _local_rank
processes_in_the_same_process_group[dim].append(process_coordinates)

View File

@ -1,6 +1,7 @@
from typing import Union
from typing import Dict, Optional, Union
import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer
@ -133,3 +134,25 @@ class OptimizerWrapper:
Unwrap the optimizer for checkpoint saving/loading.
"""
return self.optim
class DistributedOptim(Optimizer):
def setup_distributed(
self,
tp_group: Optional[dist.ProcessGroup] = None,
dp_group: Optional[dist.ProcessGroup] = None,
shard_to_working_param: Optional[Dict] = {},
padding_map: Optional[Dict] = None,
is_zero: Optional[bool] = False,
):
"""Assign process groups for TP and ZeRO 2.
Arguments:
tp_group (dist.ProcessGroup): Tensor Parallel process group
dp_group (dist.ProcessGroup): ZeRO stage 2 process group
shard_to_working_param (Dict): ZeRO stage 2 feeds the optimizer a sharded param view to match grad shape.
This maps from id(view) to model params used in forward & backward.
padding_map (Dict): Per-param padding from ZeRO stage 2
is_zero (bool): Whether to use ZeRO stage 2.
"""
raise NotImplementedError("setup_distributed for TP/DP isn't supported by this optimizer yet!")

View File

@ -1,3 +1,4 @@
import copy
import os
from typing import Callable, Optional, Union
@ -74,6 +75,24 @@ def new_from_pretrained(
subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)
variant = kwargs.pop("variant", None)
kwargs.pop("state_dict", None)
kwargs.pop("from_tf", False)
kwargs.pop("from_flax", False)
kwargs.pop("output_loading_info", False)
kwargs.pop("trust_remote_code", None)
kwargs.pop("low_cpu_mem_usage", None)
kwargs.pop("device_map", None)
kwargs.pop("max_memory", None)
kwargs.pop("offload_folder", None)
kwargs.pop("offload_state_dict", False)
kwargs.pop("load_in_8bit", False)
kwargs.pop("load_in_4bit", False)
kwargs.pop("quantization_config", None)
kwargs.pop("adapter_kwargs", {})
kwargs.pop("adapter_name", "default")
kwargs.pop("use_flash_attention_2", False)
use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
if len(kwargs) > 0:
@ -108,6 +127,10 @@ def new_from_pretrained(
**kwargs,
)
else:
config = copy.deepcopy(config)
kwarg_attn_imp = kwargs.pop("attn_implementation", None)
if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp:
config._attn_implementation = kwarg_attn_imp
model_kwargs = kwargs
if commit_hash is None:

View File

@ -55,14 +55,14 @@ class Async_DynamicBatchManager(DynamicBatchManager):
self.stats_tool.count_prompt_tokens(new_batch)
self.running_batch = new_batch
has_new_finished, outputs = self._prefill_batch(self.running_batch)
self._filter_runing_batch()
self._filter_running_batch()
self.has_wait_tokens = 0
else:
if self.has_wait_tokens < self.max_wait_tokens:
self.stats_tool.count_output_tokens(self.running_batch)
has_new_finished, outputs = self._decode_batch(self.running_batch)
self._filter_runing_batch()
self._filter_running_batch()
self.has_wait_tokens += 1
else:
@ -78,7 +78,7 @@ class Async_DynamicBatchManager(DynamicBatchManager):
else:
self.stats_tool.count_output_tokens(self.running_batch)
has_new_finished, outputs = self._decode_batch(self.running_batch)
self._filter_runing_batch()
self._filter_running_batch()
self.has_wait_tokens += 1
if has_new_finished:

View File

@ -131,14 +131,14 @@ class DynamicBatchManager:
self.stats_tool.count_prompt_tokens(new_batch)
self.running_batch = new_batch
yield from self._prefill_batch(self.running_batch)
self._filter_runing_batch()
self._filter_running_batch()
self.has_wait_tokens = 0
return
if self.has_wait_tokens < self.max_wait_tokens:
self.stats_tool.count_output_tokens(self.running_batch)
yield from self._decode_batch(self.running_batch)
self._filter_runing_batch()
self._filter_running_batch()
self.has_wait_tokens += 1
return
else:
@ -154,7 +154,7 @@ class DynamicBatchManager:
else:
self.stats_tool.count_output_tokens(self.running_batch)
yield from self._decode_batch(self.running_batch)
self._filter_runing_batch()
self._filter_running_batch()
self.has_wait_tokens += 1
return
@ -243,7 +243,7 @@ class DynamicBatchManager:
self._filter_batch(batch)
yield from self._output_process(finished_reqs)
def _filter_runing_batch(self):
def _filter_running_batch(self):
if self.running_batch is not None and self.running_batch.is_clear():
self.running_batch = None

View File

@ -1,9 +1,36 @@
from galore_torch import GaLoreAdafactor, GaLoreAdamW
from .came import CAME
from .cpu_adam import CPUAdam
from .distributed_adafactor import DistributedAdaFactor
from .distributed_came import DistributedCAME
from .distributed_galore import DistGaloreAwamW
from .distributed_lamb import DistributedLamb
from .fused_adam import FusedAdam
from .fused_lamb import FusedLAMB
from .fused_sgd import FusedSGD
from .galore import GaLoreAdamW8bit
from .hybrid_adam import HybridAdam
from .lamb import Lamb
from .lars import Lars
__all__ = ["FusedLAMB", "FusedAdam", "FusedSGD", "Lamb", "Lars", "CPUAdam", "HybridAdam"]
from .adafactor import Adafactor # noqa
__all__ = [
"FusedLAMB",
"FusedAdam",
"FusedSGD",
"Lamb",
"Lars",
"CPUAdam",
"HybridAdam",
"DistributedLamb",
"DistGaloreAwamW",
"GaLoreAdamW",
"GaLoreAdafactor",
"GaLoreAdamW8bit",
"CAME",
"DistributedCAME",
"Adafactor",
"DistributedAdaFactor",
]

View File

@ -0,0 +1,201 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
from torch.optim import Optimizer
__all__ = ["Adafactor"]
# Adafactor
class Adafactor(Optimizer):
def __init__(
self,
params,
lr=None,
eps=(1e-30, 1e-3),
clip_threshold=1.0,
decay_rate=-0.8,
beta1=None,
weight_decay=0.0,
scale_parameter=True,
relative_step=True,
warmup_init=False,
):
lr = None
if lr is not None and relative_step:
raise ValueError("Cannot combine manual `lr` and `relative_step=True` options")
if warmup_init and not relative_step:
raise ValueError("`warmup_init=True` requires `relative_step=True`")
defaults = {
"lr": lr,
"eps": eps,
"clip_threshold": clip_threshold,
"decay_rate": decay_rate,
"beta1": beta1,
"weight_decay": weight_decay,
"scale_parameter": scale_parameter,
"relative_step": relative_step,
"warmup_init": warmup_init,
}
super().__init__(params, defaults)
@staticmethod
def _get_lr(param_group, param_state):
rel_step_sz = param_group["lr"]
if param_group["relative_step"]:
min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"]))
param_scale = 1.0
if param_group["scale_parameter"]:
param_scale = max(param_group["eps"][1], param_state["RMS"])
return param_scale * rel_step_sz
@staticmethod
def _get_options(param_group, param_shape):
factored = len(param_shape) >= 2
use_first_moment = param_group["beta1"] is not None
return factored, use_first_moment
@staticmethod
def _rms(tensor):
return tensor.norm(2) / (tensor.numel() ** 0.5)
@staticmethod
def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
return torch.mul(r_factor, c_factor)
@torch.no_grad()
def step(self, closure=None):
"""
Performs a single optimization step
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
"""
param_groups: Dict
{
"params":[weight, bias]
"lr"
"eps"
"clip_threshold"
"decay_rate"
"beta1"
"weight_decay"
"scale_parameter"
"relative_step"
"warmup_init"
}
"""
for group in self.param_groups:
# update weight & bias
for p in group["params"]:
if p.grad is None:
continue
"""
# grad shape is same as weigh / bias
"""
grad = p.grad
if grad.is_sparse:
raise RuntimeError("Adafactor does not support sparse gradients.")
"""
p is weight
state
{'step',
'exp_avg_sq_row',
'exp_avg_sq_col',
'RMS'
}
p is bias
state
{'step',
'exp_avg_sq',
'RMS'
}
"""
state = self.state[p]
grad_shape = grad.shape
factored, use_first_moment = self._get_options(group, grad_shape)
# State Initialization
if len(state) == 0:
state["step"] = 0
if use_first_moment:
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(grad)
if factored:
state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1], device=grad.device)
state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:], device=grad.device)
else:
state["exp_avg_sq"] = torch.zeros_like(grad)
state["RMS"] = 0
else:
if use_first_moment:
state["exp_avg"] = state["exp_avg"]
if factored:
state["exp_avg_sq_row"] = state["exp_avg_sq_row"]
state["exp_avg_sq_col"] = state["exp_avg_sq_col"]
else:
state["exp_avg_sq"] = state["exp_avg_sq"]
state["step"] += 1
# state["RMS"] = self._rms(p_data_fp32)
lr = self._get_lr(group, state)
beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
update = (grad**2) + group["eps"][0]
if factored:
exp_avg_sq_row = state["exp_avg_sq_row"]
exp_avg_sq_col = state["exp_avg_sq_col"]
# Exponential average of row indexes
exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
# Exponential average of columns indexes
exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
# Approximation of exponential moving average of square of gradient
update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update.mul_(grad)
else:
exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
update = exp_avg_sq.rsqrt().mul_(grad)
# RMS
update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
update.mul_(lr)
if use_first_moment:
exp_avg = state["exp_avg"]
exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
update = exp_avg
if group["weight_decay"] != 0:
p.add_(p, alpha=(-group["weight_decay"] * lr))
p.add_(-update)
return loss

View File

@ -0,0 +1,150 @@
# Copied from https://github.com/yangluo7/CAME/blob/master/came_pytorch/CAME.py
import torch
import torch.optim
class CAME(torch.optim.Optimizer):
"""Implements CAME algorithm.
This implementation is based on:
`CAME: Confidence-guided Adaptive Memory Efficient Optimization`
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): external learning rate (default: None)
eps (tuple[float, float]): regularization constants for square gradient
and instability respectively (default: (1e-30, 1e-16))
clip_threshold (float): threshold of root-mean-square of
final gradient update (default: 1.0)
betas (tuple[float, float, float]): coefficient used for computing running averages of
update, square gradient and instability (default: (0.9, 0.999, 0.9999)))
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
"""
def __init__(
self,
params,
lr=None,
eps=(1e-30, 1e-16),
clip_threshold=1.0,
betas=(0.9, 0.999, 0.9999),
weight_decay=0.0,
):
assert lr > 0.0
assert all([0.0 <= beta <= 1.0 for beta in betas])
defaults = dict(
lr=lr,
eps=eps,
clip_threshold=clip_threshold,
betas=betas,
weight_decay=weight_decay,
)
super(CAME, self).__init__(params, defaults)
@property
def supports_memory_efficient_fp16(self):
return True
@property
def supports_flat_params(self):
return False
def _get_options(self, param_shape):
factored = len(param_shape) >= 2
return factored
def _rms(self, tensor):
return tensor.norm(2) / (tensor.numel() ** 0.5)
def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
return torch.mul(r_factor, c_factor)
def step(self, closure=None):
"""Performs a single optimization step.
Args:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad
if grad.is_sparse:
raise RuntimeError("CAME does not support sparse gradients.")
state = self.state[p]
grad_shape = grad.shape
factored = self._get_options(grad_shape)
# State Initialization
if len(state) == 0:
state["step"] = 0
state["exp_avg"] = torch.zeros_like(grad)
if factored:
state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1], dtype=p.dtype, device=p.device)
state["exp_avg_sq_col"] = torch.zeros(
grad_shape[:-2] + grad_shape[-1:], dtype=p.dtype, device=p.device
)
state["exp_avg_res_row"] = torch.zeros(grad_shape[:-1], dtype=p.dtype, device=p.device)
state["exp_avg_res_col"] = torch.zeros(
grad_shape[:-2] + grad_shape[-1:], dtype=p.dtype, device=p.device
)
else:
state["exp_avg_sq"] = torch.zeros_like(p)
state["step"] += 1
update = (grad**2) + group["eps"][0]
if factored:
exp_avg_sq_row = state["exp_avg_sq_row"]
exp_avg_sq_col = state["exp_avg_sq_col"]
exp_avg_sq_row.mul_(group["betas"][1]).add_(update.mean(dim=-1), alpha=1.0 - group["betas"][1])
exp_avg_sq_col.mul_(group["betas"][1]).add_(update.mean(dim=-2), alpha=1.0 - group["betas"][1])
# Approximation of exponential moving average of square of gradient
update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update.mul_(grad)
else:
exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq.mul_(group["betas"][1]).add_(update, alpha=1.0 - group["betas"][1])
update = exp_avg_sq.rsqrt().mul_(grad)
update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
exp_avg = state["exp_avg"]
exp_avg.mul_(group["betas"][0]).add_(update, alpha=1 - group["betas"][0])
# Confidence-guided strategy
# Calculation of instability
res = (update - exp_avg) ** 2 + group["eps"][1]
if factored:
exp_avg_res_row = state["exp_avg_res_row"]
exp_avg_res_col = state["exp_avg_res_col"]
exp_avg_res_row.mul_(group["betas"][2]).add_(res.mean(dim=-1), alpha=1.0 - group["betas"][2])
exp_avg_res_col.mul_(group["betas"][2]).add_(res.mean(dim=-2), alpha=1.0 - group["betas"][2])
# Approximation of exponential moving average of instability
res_approx = self._approx_sq_grad(exp_avg_res_row, exp_avg_res_col)
update = res_approx.mul_(exp_avg)
else:
update = exp_avg.clone()
if group["weight_decay"] != 0:
p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])
update.mul_(group["lr"])
p.data.add_(-update)
return loss

View File

@ -0,0 +1,440 @@
import math
from typing import Dict
import torch
import torch.distributed as dist
from colossalai.interface.optimizer import DistributedOptim
from colossalai.shardformer.layer._operation import _gather, _split
from colossalai.tensor.d_tensor import get_sharding_spec, is_distributed_tensor
# DistributedAdaFactor (with Tensor parallel and Zero stage 2)
__all__ = ["DistributedAdaFactor"]
class DistributedAdaFactor(DistributedOptim):
def __init__(
self,
params,
lr=None,
eps=(1e-30, 1e-3),
clip_threshold=1.0,
decay_rate=-0.8,
beta1=None,
weight_decay=0.0,
scale_parameter=True,
relative_step=True,
warmup_init=False,
):
lr = None
if lr is not None and relative_step:
raise ValueError("Cannot combine manual `lr` and `relative_step=True` options")
if warmup_init and not relative_step:
raise ValueError("`warmup_init=True` requires `relative_step=True`")
defaults = {
"lr": lr,
"eps": eps,
"clip_threshold": clip_threshold,
"decay_rate": decay_rate,
"beta1": beta1,
"weight_decay": weight_decay,
"scale_parameter": scale_parameter,
"relative_step": relative_step,
"warmup_init": warmup_init,
}
self.tp_size = 1
self.tp_group = None
self.dp_size = 1
self.dp_group = None
self.shard_to_working_param = None # Dict{id:shape}, sample {id(param): torch.tensor}
self.use_zero = True
self.param_is_dtensor_dict = {} # {id(p): True/False}
self.grad_shape_dict = {} # {id(p): master param shape}
self.factored_dict = {} # {id(p): True/False}
self.use_first_moment_dict = {} # {id(p): True/False}
self.shard_spec_dict = {} # {id(p): ShardSpec}
super().__init__(params, defaults)
def setup_distributed(
self,
tp_group: dist.ProcessGroup = None,
dp_group: dist.ProcessGroup = None,
shard_to_working_param: Dict = {},
padding_map=None,
use_zero: bool = True,
) -> None:
"""Setup process groups for TP and ZeRO 2.
Inject features to the Optimizer
Args:
tp_group: The devices group for tensor parallel;
dp_group: The devices group for data parallel;
shard_to_working_param (Dict): ZeRO 2 feeds the optimizer a sharded param view as grads are sharded.
This maps from id(view) to working params used in forward & backward.
padding_map: An empty interface placeholder;
use_zero: Whether or not to use zero;
"""
self.tp_group = tp_group # "Expected row process group"
self.dp_group = dp_group
if self.tp_group is not None:
self.tp_size = dist.get_world_size(self.tp_group)
if self.dp_group is not None:
self.dp_size = dist.get_world_size(self.dp_group)
self.use_zero = use_zero
self.shard_to_working_param = shard_to_working_param if shard_to_working_param is not None else {}
# grad is None, cause we dont setup now
for group in self.param_groups:
for p in group["params"]:
self.shard_to_working_param[id(p)] = self.shard_to_working_param.get(
id(p), p
) # If not ZeRO, working param is master param
self.param_is_dtensor_dict[id(p)] = is_distributed_tensor(self.shard_to_working_param[id(p)])
self.grad_shape_dict[id(p)] = self.shard_to_working_param.get(id(p)).shape
self.factored_dict[id(p)], self.use_first_moment_dict[id(p)] = self._get_options(
group, self.grad_shape_dict[id(p)]
)
if self.param_is_dtensor_dict[id(p)]:
self.shard_spec_dict[id(p)] = get_sharding_spec(self.shard_to_working_param[id(p)])
else:
self.shard_spec_dict[id(p)] = None
@staticmethod
def _get_lr(param_group, param_state):
rel_step_sz = param_group["lr"]
if param_group["relative_step"]:
min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"]))
param_scale = 1.0
if param_group["scale_parameter"]:
param_scale = max(param_group["eps"][1], param_state["RMS"])
return param_scale * rel_step_sz
@staticmethod
def _get_options(param_group, param_shape):
"""
Determines whether the current param is factored
Args:
param_group : param group
param_shape : Original Shape of param
"""
factored = len(param_shape) >= 2
use_first_moment = param_group["beta1"] is not None
return factored, use_first_moment
@staticmethod
def _rms(tensor, param_is_dtensor, use_zero, tp_size, dp_size, tp_group, dp_group):
tensor_sum = tensor.pow(2).sum()
num_of_element = tensor.numel()
if param_is_dtensor:
# reduce tensor_sum from tp_group
dist.all_reduce(tensor_sum, group=tp_group)
num_of_element = num_of_element * tp_size
if use_zero:
dist.all_reduce(tensor_sum, group=dp_group)
num_of_element = num_of_element * dp_size
rms = (tensor_sum / num_of_element).sqrt()
return rms
@staticmethod
def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
return torch.mul(r_factor, c_factor)
# approx_sq_grad for row parallel weight
@staticmethod
def _approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_meam):
# row_meam = sq_row_meam
r_factor = (exp_avg_sq_row / sq_row_meam).rsqrt_().unsqueeze(-1)
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
return torch.mul(r_factor, c_factor)
def _col_parallel_factor(self, update, grad, state, grad_shape, beta2t):
if grad_shape[0] % self.dp_size != 0:
# gather update[flatten] along dp group then reshape to [H, W/tp]
update = _gather(input_=update, dim=-1, process_group=self.dp_group)
update_reshape = update.view(-1, grad_shape[1])
# gather grad[flatten] along dp group then reshape to [H, W/tp]
grad = _gather(input_=grad, dim=-1, process_group=self.dp_group)
grad_reshape = grad.view(-1, grad_shape[1])
exp_avg_sq_row = state["exp_avg_sq_row"] # [H]
exp_avg_sq_col = state["exp_avg_sq_col"] # [W/tp]
exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t))
exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t))
update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update_reshape.mul_(grad_reshape)
else:
update_reshape = update.view(-1, grad_shape[1])
grad_reshape = grad.view(-1, grad_shape[1])
exp_avg_sq_row = state["exp_avg_sq_row"] # [H/dp]
exp_avg_sq_col = state["exp_avg_sq_col"] # [W/tp]
exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t))
exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t))
dist.all_reduce(exp_avg_sq_row, group=self.tp_group)
exp_avg_sq_row.div_(self.tp_size)
update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update_reshape.mul_(grad_reshape)
if self.use_zero:
update = update_reshape.view(-1)
else:
update = update_reshape
return update
def _row_parallel_factor(self, update, grad, state, grad_shape, beta2t):
if grad_shape[0] % self.dp_size != 0:
# gather update[flatten] along dp group then reshape to [H/tp, W]
update = _gather(input_=update, dim=-1, process_group=self.dp_group)
# view update to origin[tp] shape
update_reshape = update.view(-1, grad_shape[1])
# gather grad[flatten] along dp group then reshape to [H/tp, W]
grad = _gather(input_=grad, dim=-1, process_group=self.dp_group)
grad_reshape = grad.view(-1, grad_shape[1])
exp_avg_sq_row = state["exp_avg_sq_row"] # [H/tp]
exp_avg_sq_col = state["exp_avg_sq_col"] # [W]
exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t))
exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t))
# reduce col
dist.all_reduce(exp_avg_sq_col, group=self.tp_group)
exp_avg_sq_col.div_(self.tp_size)
update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update_reshape.mul_(grad_reshape)
if self.use_zero:
update = _split(input_=update_reshape.view(-1), dim=-1, process_group=self.dp_group)
else:
update = update_reshape
else:
update_reshape = update.view(-1, grad_shape[1])
grad_reshape = grad.view(-1, grad_shape[1])
exp_avg_sq_row = state["exp_avg_sq_row"] # [H/dp/tp]
exp_avg_sq_col = state["exp_avg_sq_col"] # [W]
exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t))
exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t))
# reduce col
dist.all_reduce(exp_avg_sq_col, group=self.tp_group)
exp_avg_sq_col.div_(self.tp_size)
# gather row
exp_avg_sq_row_gather = _gather(input_=exp_avg_sq_row, dim=-1, process_group=self.tp_group)
sq_row_meam = exp_avg_sq_row_gather.mean(dim=-1, keepdim=True)
update_reshape = self._approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_meam)
update_reshape.mul_(grad_reshape)
if self.use_zero:
update = update_reshape.view(-1)
else:
update = update_reshape
return update
def _base_factor(self, update, grad, state, grad_shape, beta2t):
if self.use_zero:
# only zero
if grad_shape[0] % self.dp_size != 0:
# view update to origin shape update.view(grad_shape[0]//self.data_parallel_size , grad_shape[1])
# row mean no change
# col mean need reduce and div
# gather update[flatten] along dp group then reshape to [H, W]
update = _gather(input_=update, dim=-1, process_group=self.dp_group)
# view update to origin[tp] shape
update_reshape = update.view(-1, grad_shape[1])
# gather grad[flatten] along dp group then reshape to [H, W]
grad = _gather(input_=grad, dim=-1, process_group=self.dp_group)
grad_reshape = grad.view(-1, grad_shape[1])
exp_avg_sq_row = state["exp_avg_sq_row"] # [H/dp]
exp_avg_sq_col = state["exp_avg_sq_col"] # [W]
exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t))
exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t))
# reduce col
dist.all_reduce(exp_avg_sq_col, group=self.tp_group)
exp_avg_sq_col.div_(self.tp_size)
update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update_reshape.mul_(grad_reshape)
update = _split(input_=update_reshape.view(-1), dim=-1, process_group=self.dp_group)
else:
# no residual row
# view update to origin[tp] shape
update_reshape = update.view(-1, grad_shape[1]) # [H/dp, W]
grad_reshape = grad.view(-1, grad_shape[1]) # [H/dp, W]
exp_avg_sq_row = state["exp_avg_sq_row"] # [H/tp]
exp_avg_sq_col = state["exp_avg_sq_col"] # [W]
exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t))
exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t))
# reduce col
dist.all_reduce(exp_avg_sq_col, group=self.tp_group)
exp_avg_sq_col.div_(self.tp_size)
update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update_reshape.mul_(grad_reshape)
update = update_reshape.view(-1)
else:
# base factor; no tp, no dp
exp_avg_sq_row = state["exp_avg_sq_row"]
exp_avg_sq_col = state["exp_avg_sq_col"]
# Exponential average of row indexes
exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
# Exponential average of columns indexes
exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
# Approximation of exponential moving average of square of gradient
update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update.mul_(grad)
return update
@torch.no_grad()
def step(self, closure=None):
"""
Performs a single optimization steps
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
"""
param_groups: Dict
{
"params":[weight, bias]
"lr"
"eps"
"clip_threshold"
"decay_rate"
"beta1"
"weight_decay"
"scale_parameter"
"relative_step"
"warmup_init"
}
"""
for group in self.param_groups:
# update weight & bias
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad
if grad.is_sparse:
raise RuntimeError("Adafactor does not support sparse gradients.")
state = self.state[p]
grad_shape = self.grad_shape_dict[id(p)]
param_is_dtensor = self.param_is_dtensor_dict[id(p)]
if param_is_dtensor:
grad_shape = self.shard_to_working_param.get(id(p)).shape # tp shape (2 dim)
factored, use_first_moment = self.factored_dict[id(p)], self.use_first_moment_dict[id(p)]
shard_spec = self.shard_spec_dict[id(p)]
if len(state) == 0:
state["step"] = 0
if use_first_moment:
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(p)
if factored:
if param_is_dtensor:
if shard_spec.sharding_sequence[0] == "R": # Col Parallel
if grad_shape[0] % self.dp_size != 0:
state["exp_avg_sq_row"] = torch.zeros(
grad_shape[0], device=p.device, dtype=p.dtype
) # [H]
else:
state["exp_avg_sq_row"] = torch.zeros(
grad_shape[0] // self.dp_size, device=p.device, dtype=p.dtype
) # [H/dp]
state["exp_avg_sq_col"] = torch.zeros(
grad_shape[1], device=p.device, dtype=p.dtype
) # [W/TP]
if shard_spec.sharding_sequence[-1] == "R": # Row Parallel
# Row indivisible shape situation
if grad_shape[0] % self.dp_size != 0:
state["exp_avg_sq_row"] = torch.zeros(
grad_shape[0], device=p.device, dtype=p.dtype
) # [H/tp]
else:
state["exp_avg_sq_row"] = torch.zeros(
grad_shape[0] // self.dp_size, device=p.device, dtype=p.dtype
) # [H/dp/tp]
state["exp_avg_sq_col"] = torch.zeros(
grad_shape[1], device=p.device, dtype=p.dtype
) # [W]
else:
if self.use_zero:
if grad_shape[0] % self.dp_size != 0:
# save all exp_avg_sq_row [H]
state["exp_avg_sq_row"] = torch.zeros(
grad_shape[0], device=grad.device, dtype=p.dtype
)
else:
# exp_avg_sq_row [H // dp]
state["exp_avg_sq_row"] = torch.zeros(
grad_shape[0] // self.dp_size, device=grad.device, dtype=p.dtype
)
else:
# exp_avg_sq_row [H]
state["exp_avg_sq_row"] = torch.zeros(grad_shape[0], device=grad.device, dtype=p.dtype)
# exp_avg_sq_col alaways [W]
state["exp_avg_sq_col"] = torch.zeros(grad_shape[1], device=grad.device, dtype=p.dtype)
else:
state["exp_avg_sq"] = torch.zeros_like(p)
state["RMS"] = 0
else:
if use_first_moment:
state["exp_avg"] = state["exp_avg"]
if factored:
state["exp_avg_sq_row"] = state["exp_avg_sq_row"]
state["exp_avg_sq_col"] = state["exp_avg_sq_col"]
else:
state["exp_avg_sq"] = state["exp_avg_sq"]
state["step"] += 1
lr = self._get_lr(group, state)
beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
update = (grad**2) + group["eps"][0]
if factored:
if param_is_dtensor:
# ==============================
# First Dim is R, Last Dim is S{} means split dim -1 --->
# Coloum Parallel ---> sq_row need Do (col) Reduce
# ==============================
if shard_spec.sharding_sequence[0] == "R":
update = self._col_parallel_factor(update, grad, state, grad_shape, beta2t)
# ==============================
# Last Dim is R, First Dim is S{} means split dim 0 --->
# Row Parallel ---> sq_col need Do (row) Reduce
# ==============================
elif shard_spec.sharding_sequence[-1] == "R":
update = self._row_parallel_factor(update, grad, state, grad_shape, beta2t)
else:
update = self._base_factor(update, grad, state, grad_shape, beta2t)
else:
exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
update = exp_avg_sq.rsqrt().mul_(grad)
# # (Line No.8) RMS
rms = self._rms(
update,
param_is_dtensor,
self.use_zero,
self.tp_size,
self.dp_size,
self.tp_group,
self.dp_group,
)
update.div_((rms / group["clip_threshold"]).clamp_(min=1.0))
update.mul_(lr)
if use_first_moment:
exp_avg = state["exp_avg"]
exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
update = exp_avg
if group["weight_decay"] != 0:
p.add_(p, alpha=(-group["weight_decay"] * lr))
p.add_(-update)
return loss

View File

@ -0,0 +1,557 @@
from typing import Dict
import torch
import torch.distributed as dist
from colossalai.interface.optimizer import DistributedOptim
from colossalai.shardformer.layer._operation import _gather, _split
from colossalai.tensor.d_tensor import get_sharding_spec, is_distributed_tensor
class DistributedCAME(DistributedOptim):
"""Implements CAME algorithm.
This implementation is based on:
`CAME: Confidence-guided Adaptive Memory Efficient Optimization`
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): external learning rate (default: None)
eps (tuple[float, float]): regularization constants for square gradient
and instability respectively (default: (1e-30, 1e-16))
clip_threshold (float): threshold of root-mean-square of
final gradient update (default: 1.0)
betas (tuple[float, float, float]): coefficient used for computing running averages of
update, square gradient and instability (default: (0.9, 0.999, 0.9999)))
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
"""
def __init__(
self,
params,
lr=None,
eps=(1e-30, 1e-16),
clip_threshold=1.0,
betas=(0.9, 0.999, 0.9999),
weight_decay=0.0,
):
assert lr > 0.0
assert all([0.0 <= beta <= 1.0 for beta in betas])
defaults = dict(
lr=lr,
eps=eps,
clip_threshold=clip_threshold,
betas=betas,
weight_decay=weight_decay,
)
self.tp_size = 1
self.tp_group = None
self.dp_size = 1
self.dp_group = None
self.shard_to_working_param = None # Dict{id:shape}, sample {id(param): torch.tensor}
self.use_zero = True
self.param_is_dtensor_dict = {} # {id(p): True/False}
self.grad_shape_dict = {} # {id(p): master param shape}
self.factored_dict = {} # {id(p): True/False}
self.use_first_moment_dict = {} # {id(p): True/False}
self.shard_spec_dict = {} # {id(p): ShardSpec}
super(DistributedCAME, self).__init__(params, defaults)
@property
def supports_memory_efficient_fp16(self):
return True
@property
def supports_flat_params(self):
return False
def setup_distributed(
self,
tp_group: dist.ProcessGroup = None,
dp_group: dist.ProcessGroup = None,
shard_to_working_param: Dict = {},
padding_map=None,
use_zero: bool = True,
) -> None:
"""
Inject features to the Optimizer
Args:
tp_group: The devices group for tensor parallel;
dp_group: The devices group for data parallel;
shard_to_working_param (Dict): ZeRO 2 feeds the optimizer a sharded param view as grads are sharded.
This maps from id(view) to working params used in forward & backward.
padding_map: Interface placeholder
use_zero: Whether or not to use zero;
"""
self.tp_group = tp_group # "Expected row process group"
self.dp_group = dp_group
if self.tp_group is not None:
self.tp_size = dist.get_world_size(self.tp_group)
if self.dp_group is not None:
self.dp_size = dist.get_world_size(self.dp_group)
self.use_zero = use_zero
self.shard_to_working_param = shard_to_working_param if shard_to_working_param is not None else {}
# grad is None, cause we dont setup now
for group in self.param_groups:
for p in group["params"]:
# w/o ZeRO: master param = working param
self.shard_to_working_param[id(p)] = self.shard_to_working_param.get(id(p), p)
self.param_is_dtensor_dict[id(p)] = is_distributed_tensor(self.shard_to_working_param[id(p)])
self.grad_shape_dict[id(p)] = self.shard_to_working_param[id(p)].shape
# Avoid row parallel lead H=1, then factored param is determined as not factored;
if self.param_is_dtensor_dict[id(p)]:
self.shard_spec_dict[id(p)] = get_sharding_spec(self.shard_to_working_param[id(p)])
if self.shard_spec_dict[id(p)].sharding_sequence[0] == "R":
self.factored_dict[id(p)] = True
elif self.shard_spec_dict[id(p)].sharding_sequence[-1] == "R":
self.factored_dict[id(p)] = True
else:
self.factored_dict[id(p)] = self._get_options(self.grad_shape_dict[id(p)])
else:
self.shard_spec_dict[id(p)] = None
self.factored_dict[id(p)] = self._get_options(self.grad_shape_dict[id(p)])
@staticmethod
def _get_options(param_shape):
factored = len(param_shape) >= 2
return factored
@staticmethod
def _rms(tensor, param_is_dtensor, use_zero, tp_size, dp_size, tp_group, dp_group):
tensor_sum = tensor.pow(2).sum()
num_of_element = tensor.numel()
if param_is_dtensor:
# reduce tensor_sum from tp_group
dist.all_reduce(tensor_sum, group=tp_group)
num_of_element = num_of_element * tp_size
if use_zero:
dist.all_reduce(tensor_sum, group=dp_group)
num_of_element = num_of_element * dp_size
rms = (tensor_sum / num_of_element).sqrt()
return rms
@staticmethod
def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
return torch.mul(r_factor, c_factor)
# approx_sq_grad for row parallel weight
@staticmethod
def _approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_meam):
r_factor = (exp_avg_sq_row / sq_row_meam).rsqrt_().unsqueeze(-1)
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
return torch.mul(r_factor, c_factor)
def _col_parallel_factor(self, update, grad, state_row, state_col, grad_shape, beta2t):
if grad_shape[0] % self.dp_size != 0:
# gather update[flatten] along dp group then reshape to [H, W/tp]
update = _gather(input_=update, dim=-1, process_group=self.dp_group)
update_reshape = update.view(-1, grad_shape[1])
# gather grad[flatten] along dp group then reshape to [H, W/tp]
grad = _gather(input_=grad, dim=-1, process_group=self.dp_group)
grad_reshape = grad.view(-1, grad_shape[1])
exp_avg_sq_row = state_row # [H]
exp_avg_sq_col = state_col # [W/tp]
exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t))
exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t))
update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update_reshape.mul_(grad_reshape)
else:
update_reshape = update.view(-1, grad_shape[1])
grad_reshape = grad.view(-1, grad_shape[1])
exp_avg_sq_row = state_row # [H]
exp_avg_sq_col = state_col # [W/tp]
exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t))
exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t))
dist.all_reduce(exp_avg_sq_row, group=self.tp_group)
exp_avg_sq_row.div_(self.tp_size)
update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update_reshape.mul_(grad_reshape)
if self.use_zero:
update = update_reshape.view(-1)
else:
update = update_reshape
return update
def _row_parallel_factor(self, update, grad, state_row, state_col, grad_shape, beta2t):
if grad_shape[0] % self.dp_size != 0:
# gather update[flatten] along dp group then reshape to [H/tp, W]
update = _gather(input_=update, dim=-1, process_group=self.dp_group)
# view update to origin[tp] shape
update_reshape = update.view(-1, grad_shape[1])
# gather grad[flatten] along dp group then reshape to [H/tp, W]
grad = _gather(input_=grad, dim=-1, process_group=self.dp_group)
grad_reshape = grad.view(-1, grad_shape[1])
exp_avg_sq_row = state_row # [H]
exp_avg_sq_col = state_col # [W/tp]
exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t))
exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t))
# reduce col
dist.all_reduce(exp_avg_sq_col, group=self.tp_group)
exp_avg_sq_col.div_(self.tp_size)
update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update_reshape.mul_(grad_reshape)
if self.use_zero:
update = _split(input_=update_reshape.view(-1), dim=-1, process_group=self.dp_group)
else:
update = update_reshape
else:
update_reshape = update.view(-1, grad_shape[1])
grad_reshape = grad.view(-1, grad_shape[1])
exp_avg_sq_row = state_row # [H]
exp_avg_sq_col = state_col # [W/tp]
exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t))
exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t))
# reduce col
dist.all_reduce(exp_avg_sq_col, group=self.tp_group)
exp_avg_sq_col.div_(self.tp_size)
# gather row
exp_avg_sq_row_gather = _gather(input_=exp_avg_sq_row, dim=-1, process_group=self.tp_group)
sq_row_meam = exp_avg_sq_row_gather.mean(dim=-1, keepdim=True)
update_reshape = self._approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_meam)
update_reshape.mul_(grad_reshape)
if self.use_zero:
update = update_reshape.view(-1)
else:
update = update_reshape
return update
def _base_factor(self, update, grad, state_row, state_col, grad_shape, beta2t):
if self.use_zero:
# only zero
# [30522, 128], [2, 128]
if grad_shape[0] % self.dp_size != 0:
# view update to origin shape update.view(grad_shape[0]//self.data_parallel_size , grad_shape[1])
# row mean no change
# col mean need reduce and div
# gather update[flatten] along dp group then reshape to [H, W]
update = _gather(input_=update, dim=-1, process_group=self.dp_group)
# view update to origin[tp] shape
update_reshape = update.view(-1, grad_shape[1])
# gather grad[flatten] along dp group then reshape to [H, W]
grad = _gather(input_=grad, dim=-1, process_group=self.dp_group)
grad_reshape = grad.view(-1, grad_shape[1])
exp_avg_sq_row = state_row # [H/dp]
exp_avg_sq_col = state_col # [W]
exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t))
exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t))
# reduce col
dist.all_reduce(exp_avg_sq_col, group=self.tp_group)
exp_avg_sq_col.div_(self.tp_size)
update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update_reshape.mul_(grad_reshape)
update = _split(input_=update_reshape.view(-1), dim=-1, process_group=self.dp_group)
else:
# no residual row
# view update to origin[tp] shape
update_reshape = update.view(-1, grad_shape[1]) # [H/dp, W]
grad_reshape = grad.view(-1, grad_shape[1]) # [H/dp, W]
exp_avg_sq_row = state_row # [H/dp]
exp_avg_sq_col = state_col # [W]
exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t))
exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t))
# reduce col
dist.all_reduce(exp_avg_sq_col, group=self.tp_group)
exp_avg_sq_col.div_(self.tp_size)
update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update_reshape.mul_(grad_reshape)
update = update_reshape.view(-1)
else:
# # base factor; no tp, no dp
exp_avg_sq_row = state_row # [H/dp]
exp_avg_sq_col = state_col # [W]
# Exponential average of row indexes
exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
# Exponential average of columns indexes
exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
# Approximation of exponential moving average of square of gradient
update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update.mul_(grad)
return update
# factor
def _base_res_factor(self, res, exp_avg, state_row, state_col, grad_shape, beta2t):
if self.use_zero:
# only zero
if grad_shape[0] % self.dp_size != 0:
# view res to origin shape res.view(grad_shape[0]//self.data_parallel_size , grad_shape[1])
# row mean no change
# col mean need reduce and div
# gather res[flatten] along dp group then reshape to [H, W]
res = _gather(input_=res, dim=-1, process_group=self.dp_group)
# view res to origin[tp] shape
res_reshape = res.view(-1, grad_shape[1])
# gather exp_avg[flatten] along dp group then reshape to [H, W]
exp_avg = _gather(input_=exp_avg, dim=-1, process_group=self.dp_group)
exp_avg_reshape = exp_avg.view(-1, grad_shape[1])
exp_avg_sq_row = state_row # [H/dp]
exp_avg_sq_col = state_col # [W]
exp_avg_sq_row.mul_(beta2t).add_(res_reshape.mean(dim=-1), alpha=(1.0 - beta2t))
exp_avg_sq_col.mul_(beta2t).add_(res_reshape.mean(dim=-2), alpha=(1.0 - beta2t))
# reduce col
dist.all_reduce(exp_avg_sq_col, group=self.tp_group)
exp_avg_sq_col.div_(self.tp_size)
res_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
res_reshape.mul_(exp_avg_reshape)
res = _split(input_=res_reshape.view(-1), dim=-1, process_group=self.dp_group)
else:
# no residual row
# view res to origin[tp] shape
res_reshape = res.view(-1, grad_shape[1]) # [H/dp, W]
exp_avg_reshape = exp_avg.view(-1, grad_shape[1]) # [H/dp, W]
exp_avg_sq_row = state_row # [H/dp]
exp_avg_sq_col = state_col # [W]
exp_avg_sq_row.mul_(beta2t).add_(res_reshape.mean(dim=-1), alpha=(1.0 - beta2t))
exp_avg_sq_col.mul_(beta2t).add_(res_reshape.mean(dim=-2), alpha=(1.0 - beta2t))
# reduce col
dist.all_reduce(exp_avg_sq_col, group=self.tp_group)
exp_avg_sq_col.div_(self.tp_size)
res_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
res_reshape.mul_(exp_avg_reshape)
res = res_reshape.view(-1)
else:
# # base factor; no tp, no dp
exp_avg_sq_row = state_row # [H/dp]
exp_avg_sq_col = state_col # [W]
# Exponential average of row indexes
exp_avg_sq_row.mul_(beta2t).add_(res.mean(dim=-1), alpha=(1.0 - beta2t))
# Exponential average of columns indexes
exp_avg_sq_col.mul_(beta2t).add_(res.mean(dim=-2), alpha=(1.0 - beta2t))
# Approximation of exponential moving average of square of gradient
res = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
res.mul_(exp_avg)
return res
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Args:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad
if grad.is_sparse:
raise RuntimeError("CAME does not support sparse gradients.")
state = self.state[p]
# Under zero the grad_shape is the original grad that is flattened and then cut (only one dimension)
grad_shape = grad.shape
grad_shape = self.grad_shape_dict[id(p)]
param_is_dtensor = self.param_is_dtensor_dict[id(p)]
if param_is_dtensor:
grad_shape = self.shard_to_working_param.get(id(p)).shape # tp shape (2 dim)
factored = self.factored_dict[id(p)]
shard_spec = self.shard_spec_dict[id(p)]
# State Initialization
if len(state) == 0:
state["step"] = 0
state["exp_avg"] = torch.zeros_like(p)
if factored:
if param_is_dtensor:
if shard_spec.sharding_sequence[0] == "R": # Col Parallel
if grad_shape[0] % self.dp_size != 0:
state["exp_avg_sq_row"] = torch.zeros(
grad_shape[0], device=p.device, dtype=p.dtype
) # [H]
state["exp_avg_res_row"] = torch.zeros(
grad_shape[0], device=p.device, dtype=p.dtype
) # [H]
else:
state["exp_avg_sq_row"] = torch.zeros(
grad_shape[0] // self.dp_size, device=p.device, dtype=p.dtype
) # [H/dp]
state["exp_avg_res_row"] = torch.zeros(
grad_shape[0] // self.dp_size, device=p.device, dtype=p.dtype
) # [H/dp]
state["exp_avg_sq_col"] = torch.zeros(
grad_shape[1], device=p.device, dtype=p.dtype
) # [W/TP]
state["exp_avg_res_col"] = torch.zeros(
grad_shape[1], device=p.device, dtype=p.dtype
) # [W/TP]
if shard_spec.sharding_sequence[-1] == "R": # Row Parallel
# Row indivisible shape situation
if grad_shape[0] % self.dp_size != 0:
state["exp_avg_sq_row"] = torch.zeros(
grad_shape[0], device=p.device, dtype=p.dtype
) # [H/tp]
state["exp_avg_res_row"] = torch.zeros(
grad_shape[0], device=p.device, dtype=p.dtype
) # [H/tp]
else:
state["exp_avg_sq_row"] = torch.zeros(
grad_shape[0] // self.dp_size, device=p.device, dtype=p.dtype
) # [H/dp/tp]
state["exp_avg_res_row"] = torch.zeros(
grad_shape[0] // self.dp_size, device=p.device, dtype=p.dtype
) # [H/dp/tp]
state["exp_avg_sq_col"] = torch.zeros(
grad_shape[1], device=p.device, dtype=p.dtype
) # [W]
state["exp_avg_res_col"] = torch.zeros(
grad_shape[1], device=p.device, dtype=p.dtype
) # [W]
else:
if self.use_zero:
if grad_shape[0] % self.dp_size != 0:
# save all exp_avg_sq_row [H]
state["exp_avg_sq_row"] = torch.zeros(
grad_shape[0], device=grad.device, dtype=p.dtype
)
state["exp_avg_res_row"] = torch.zeros(
grad_shape[0], device=grad.device, dtype=p.dtype
)
else:
# exp_avg_sq_row [H // dp]
state["exp_avg_sq_row"] = torch.zeros(
grad_shape[0] // self.dp_size, device=grad.device, dtype=p.dtype
)
state["exp_avg_res_row"] = torch.zeros(
grad_shape[0] // self.dp_size, device=grad.device, dtype=p.dtype
)
else:
# exp_avg_sq_row [H]
state["exp_avg_sq_row"] = torch.zeros(grad_shape[0], device=grad.device, dtype=p.dtype)
state["exp_avg_res_row"] = torch.zeros(grad_shape[0], device=grad.device, dtype=p.dtype)
# exp_avg_sq_col alaways [W]
state["exp_avg_sq_col"] = torch.zeros(grad_shape[1], device=grad.device, dtype=p.dtype)
state["exp_avg_res_col"] = torch.zeros(grad_shape[1], device=grad.device, dtype=p.dtype)
else:
state["exp_avg_sq"] = torch.zeros_like(p)
state["RMS"] = 0
else:
if factored:
state["exp_avg_sq_row"] = state["exp_avg_sq_row"]
state["exp_avg_sq_col"] = state["exp_avg_sq_col"]
state["exp_avg_res_row"] = state["exp_avg_sq_row"]
state["exp_avg_res_col"] = state["exp_avg_sq_col"]
else:
state["exp_avg_sq"] = state["exp_avg_sq"]
state["step"] += 1
update = (grad**2) + group["eps"][0]
if factored:
if param_is_dtensor:
# ==============================
# First Dim is R, Last Dim is S{} means split dim -1 --->
# Coloum Parallel ---> sq_row need Do (col) Reduce
# ==============================
if shard_spec.sharding_sequence[0] == "R":
update = self._col_parallel_factor(
update,
grad,
state["exp_avg_sq_row"],
state["exp_avg_sq_col"],
grad_shape,
group["betas"][1],
)
# ==============================
# Last Dim is R, First Dim is S{} means split dim 0 --->
# Row Parallel ---> sq_col need Do (row) Reduce
# ==============================
elif shard_spec.sharding_sequence[-1] == "R":
update = self._row_parallel_factor(
update,
grad,
state["exp_avg_sq_row"],
state["exp_avg_sq_col"],
grad_shape,
group["betas"][1],
)
else:
update = self._base_factor(
update,
grad,
state["exp_avg_sq_row"],
state["exp_avg_sq_col"],
grad_shape,
group["betas"][1],
)
else:
exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq.mul_(group["betas"][1]).add_(update, alpha=(1.0 - group["betas"][1]))
update = exp_avg_sq.rsqrt().mul_(grad)
rms = self._rms(
update,
param_is_dtensor,
self.use_zero,
self.tp_size,
self.dp_size,
self.tp_group,
self.dp_group,
)
update.div_((rms / group["clip_threshold"]).clamp_(min=1.0))
exp_avg = state["exp_avg"]
exp_avg.mul_(group["betas"][0]).add_(update, alpha=1 - group["betas"][0])
# Confidence-guided strategy
# Calculation of instability
res = (update - exp_avg) ** 2 + group["eps"][1]
if factored:
if param_is_dtensor:
# ==============================
# First Dim is R, Last Dim is S{} means split dim -1 --->
# Coloum Parallel ---> sq_row need Do (col) Reduce
# ==============================
if shard_spec.sharding_sequence[0] == "R":
update = self._col_parallel_factor(
res,
exp_avg,
state["exp_avg_res_row"],
state["exp_avg_res_col"],
grad_shape,
group["betas"][2],
)
# ==============================
# Last Dim is R, First Dim is S{} means split dim 0 --->
# Row Parallel ---> sq_col need Do (row) Reduce
# ==============================
elif shard_spec.sharding_sequence[-1] == "R":
update = self._row_parallel_factor(
res,
exp_avg,
state["exp_avg_res_row"],
state["exp_avg_res_col"],
grad_shape,
group["betas"][2],
)
else:
update = self._base_res_factor(
res,
exp_avg,
state["exp_avg_res_row"],
state["exp_avg_res_col"],
grad_shape,
group["betas"][2],
)
else:
update = exp_avg
if group["weight_decay"] != 0:
p.add_(p, alpha=-group["weight_decay"] * group["lr"])
update.mul_(group["lr"])
p.add_(-update)
return loss

View File

@ -0,0 +1,279 @@
""" adapted from https://github.com/jiaweizzhao/GaLore/blob/master/galore_torch/adamw8bit.py"""
import warnings
from collections import defaultdict
from typing import Dict, Optional
import torch
import torch.distributed as dist
import torch.nn.functional as F
from bitsandbytes.optim.optimizer import Optimizer2State
from colossalai.interface.optimizer import DistributedOptim
from colossalai.tensor.d_tensor import get_shard_dim_1d, is_distributed_tensor
from .galore import GaLoreProjector, make_low_rank_buffer
__all__ = ["DistributedGalore"]
# Mark sharded dimension
class DistGaloreAwamW(DistributedOptim, Optimizer2State):
r"""Implements Galore, a optimizer-agonistic gradient compression technique on 8-bit AdamW.
It largely compresses gradient via low-rank projection and is claimed to be insensitive to hyperparams like lr.
Supports Tensor Parallel and ZeRO stage 1 and 2 via booster and plugin.
Proposed in `GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection`
https://arxiv.org/abs/2403.03507
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its norm. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-6)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0.01)
nbits: Number of bits for quantization optim states. Only 32 and 8 are supported.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer (handle memory spike via CPU-GPU transfer) or not.
"""
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
nbits=8,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
nbits,
None,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=is_paged,
)
self.tp_size = 1
self.dp_size = 1
self.is_dist = {}
proj_none = all(["rank" not in group for group in self.param_groups])
if proj_none:
warnings.warn(
"Will not apply GaLore as rank isn't in any param group. If you forgot to, try get_galore_param_groups"
)
# Default from the paper
for group in self.param_groups:
if "rank" in group:
group["update_proj_gap"] = group.get("update_proj_gap", 200)
group["proj_type"] = group.get("proj_type", "std")
group["scale"] = group.get("scale", 0.25)
def setup_distributed(
self,
tp_group: Optional[dist.ProcessGroup] = None,
dp_group: Optional[dist.ProcessGroup] = None,
shard_to_working_param: Optional[Dict] = {},
padding_map: Optional[Dict] = defaultdict(int),
is_zero: Optional[bool] = False,
):
"""Setup process groups for TP and ZeRO 2.
Arguments:
tp_group (dist.ProcessGroup): Tensor Parallel process group
dp_group (dist.ProcessGroup): ZeRO 2 process group
shard_to_working_param (Dict): ZeRO 2 feeds the optimizer a sharded param view as grads are sharded.
This maps from id(view) to working params used in forward & backward.
padding_map (Dict): Padding size of each param from ZeRO's param store. Required if ZeRO is used.
is_zero (bool): Whether to use ZeRO 2.
"""
assert dist.is_initialized(), "You forgot to initialized distributed backend..."
self.tp_group = tp_group
self.dp_group = dp_group
if tp_group is not None:
self.tp_size = dist.get_world_size(tp_group)
if dp_group is not None:
self.dp_size = dist.get_world_size(dp_group)
self.shard_to_working_param = shard_to_working_param if shard_to_working_param is not None else {}
self.is_zero = is_zero and self.dp_size > 1
self.padding_map = padding_map if padding_map is not None else defaultdict(int)
if is_zero:
assert self.padding_map is not defaultdict(
int
), "We can't do SVD without knowing ZeRO's per-param padding size"
self.distributed_on = self.tp_size > 0 or self.dp_size > 0
# Cache working param layout
self.shard_dim = {}
for group in self.param_groups:
for p in group["params"]:
# w/o ZeRO: master param = working param
self.shard_to_working_param[id(p)] = self.shard_to_working_param.get(id(p), p)
if id(p) not in self.padding_map:
self.padding_map[id(p)] = 0
self.is_dist[id(p)] = is_distributed_tensor(self.shard_to_working_param[id(p)])
if is_distributed_tensor(self.shard_to_working_param[id(p)]):
self.shard_dim[id(p)] = get_shard_dim_1d(self.shard_to_working_param[id(p)])
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
if not self.initialized:
self.check_overrides()
self.to_gpu()
self.initialized = True
for gindex, group in enumerate(self.param_groups):
for pindex, p in enumerate(group["params"]):
if p.grad is None:
continue
state = self.state[p]
if "step" not in state:
state["step"] = 0
# GaLore Projection
if "rank" in group:
if "projector" not in state:
state["projector"] = GaLoreProjector(
group["rank"],
scale=group["scale"],
update_proj_gap=group["update_proj_gap"],
proj_type=group["proj_type"],
)
# decoupled weight decay
if "weight_decay" in group and group["weight_decay"] > 0:
group["weight_decay_saved"] = group["weight_decay"]
group["weight_decay"] = 0
grad = p.grad
working_shape = list(self.shard_to_working_param[id(p)].shape)
padding = self.padding_map[id(p)]
# All-gather grads for projection step
if self.distributed_on:
# Gather for ZeRO 1 & 2 implementation don't retain full grads
if self.is_zero:
# (m, n).flatten().chunk(dp_size) equals to (m / dp_size, n).flatten()
working_shape[0] //= self.dp_size
# Gather grads for projection
if state["step"] % group["update_proj_gap"] == 0:
all_grads = [
torch.empty_like(grad, dtype=p.grad.dtype, device=p.grad.device)
for _ in range(self.dp_size)
]
dist.all_gather(all_grads, grad, self.dp_group)
grad = torch.cat(all_grads)
# To working param shape
if padding > 0:
grad = grad[:-padding]
working_shape[0] *= self.dp_size
grad = grad.reshape(working_shape) # unflatten
# Gather TP grads
if self.is_dist[id(p)] and state["step"] % group["update_proj_gap"] == 0:
all_grads = [
torch.empty_like(grad, dtype=p.grad.dtype, device=p.grad.device)
for _ in range(self.tp_size)
]
dist.all_gather(all_grads, grad.contiguous(), self.tp_group)
grad = torch.cat(all_grads, dim=self.shard_dim[id(p)])
# Compute SVD. Will use a subset of singular vectors when grads are sharded.
grad = state["projector"].project(grad, state["step"])
# Re-shard gathered grads after SVD
if self.distributed_on and state["step"] % group["update_proj_gap"] == 0:
# TP
if self.is_dist[id(p)]:
grad = grad.chunk(self.tp_size, dim=self.shard_dim[id(p)])[dist.get_rank(self.tp_group)]
# ZeRO
# TODO: this might not work with padding, e.g. (3, 3) with dp size 2
# Need extra logic in ZeRO to pad nRows/nCols to be divisible by dp_size
if self.is_zero:
grad = grad.chunk(self.dp_size)[dist.get_rank(self.dp_group)]
grad = grad.contiguous() # avoid bitsandbytes update error
working_shape = grad.shape
# To flattended master param shape
grad = self.to_master_shape(grad, padding)
make_low_rank_buffer(p, grad)
if "state1" not in state:
self.init_state(group, p, gindex, pindex)
self.prefetch_state(p)
self.update_step(group, p, gindex, pindex)
torch.cuda.synchronize()
# Project Back to working param shape
if "rank" in group:
# Unpad
if self.is_zero:
if padding > 0:
p.data = p.data[:-padding]
p.data = p.data.reshape(working_shape)
p.data = state["projector"].project_back(p.data)
# Re-flatten grads for ZeRO
p.data = self.to_master_shape(p.data, padding)
p.data = p.saved_data.add_(p.data)
# apply decoupled weight decay
if "weight_decay_saved" in group:
p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay_saved"])
group["weight_decay"] = group["weight_decay_saved"]
del group["weight_decay_saved"]
if self.is_paged:
# all paged operation are asynchronous, we need
# to sync to make sure all tensors are in the right state
torch.cuda.synchronize()
return loss
def to_master_shape(self, data, padding):
"""Pad to master (optimizer) param shape"""
if not self.is_zero:
return data
data = data.view(-1)
if padding > 0:
data = F.pad(data, [0, padding])
return data
def __del__(self):
"""Avoid buffer memory leak"""
for group in self.param_groups:
for p in group["params"]:
if hasattr(p, "saved_data"):
del p.saved_data

View File

@ -0,0 +1,181 @@
# Disclaimer: Modified from https://github.com/NUS-HPC-AI-Lab/pytorch-lamb/blob/master/optim/lamb.py
from typing import Dict, Optional
import torch
import torch.distributed as dist
from colossalai.interface.optimizer import DistributedOptim
from colossalai.tensor.d_tensor import is_distributed_tensor
__all__ = ["DistributedLamb"]
class DistributedLamb(DistributedOptim):
r"""Implements the Lamb algorithm, with extra support for ZeRO 2 and Tensor Parallel.
Proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
It's recommended to use this with HybridParallelPlugin/ZeRO plugin and booster,
which will take care of setup_distributed.
Example with 4 devices:
>>> optim = DistributedLamb(model.parameters(), lr=1e-3)
>>> proc_mesh = ProcessGroupMesh(tp_size, zero_size)
>>> tp_group = proc_mesh.get_group_along_axis(0)
>>> dp_group = proc_mesh.get_group_along_axis(1)
>>> optim.setup_distributed(tp_group, dp_group)
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
.. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
https://arxiv.org/abs/1904.00962
"""
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-6,
weight_decay=0,
bias_correction=True,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
# self.setup_distributed(tp_group, dp_group)
self.shard_to_working_param = {}
self.tp_size = self.dp_size = 1
self.is_zero = False
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
super().__init__(params, defaults)
def setup_distributed(
self,
tp_group: Optional[dist.ProcessGroup] = None,
dp_group: Optional[dist.ProcessGroup] = None,
shard_to_working_param: Optional[Dict] = {},
padding_map=None,
is_zero: Optional[bool] = False,
):
"""Assign process groups for TP and ZeRO 2.
Arguments:
tp_group (dist.ProcessGroup): Tensor Parallel process group
dp_group (dist.ProcessGroup): ZeRO 2 process group
shard_to_working_param (Dict): ZeRO 2 feeds the optimizer a sharded param view as grads are sharded.
This maps from id(view) to working params used in forward & backward.
padding_map: An empty interface placeholder
is_zero (bool): Whether to use ZeRO 2.
"""
self.tp_group = tp_group
self.dp_group = dp_group
if tp_group is not None:
self.tp_size = dist.get_world_size(tp_group)
if dp_group is not None:
self.dp_size = dist.get_world_size(dp_group)
self.shard_to_working_param = shard_to_working_param if shard_to_working_param is not None else {}
self.is_zero = is_zero
self.is_dist = {}
# Cache parameter layout
for group in self.param_groups:
for p in group["params"]:
# w/o ZeRO: master param = working param
self.shard_to_working_param[id(p)] = self.shard_to_working_param.get(id(p), p)
self.is_dist[p] = (
is_distributed_tensor(p)
if self.dp_size <= 1
else is_distributed_tensor(self.shard_to_working_param.get(id(p), None))
)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError("Lamb does not support sparse gradients, consider SparseAdam instad.")
state = self.state[p]
# State initialization
if len(state) == 0:
state["step"] = 0
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(p.data)
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = group["betas"]
state["step"] += 1
# Decay the first and second moment running average coefficient
# m_t
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
# v_t
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
scaled_lr = group["lr"]
if group["bias_correction"]:
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]
# Apply debiasing to lr to avoid broadcast
scaled_lr *= (bias_correction2**0.5) / bias_correction1
# exp_avg.div_(bias_correction1)
# exp_avg_sq.div_(bias_correction2)
update = exp_avg / exp_avg_sq.sqrt().add(group["eps"])
if group["weight_decay"] != 0:
update.add_(p.data, alpha=group["weight_decay"])
# Compute global layer-wise trust ratio
if self.is_dist[p] or self.is_zero:
p_local = p
g_sum = (update**2).sum()
if self.dp_size > 1 and self.is_zero:
# ZeRO 2 doesn't shard param. Compute full param norm w/o communication.
dist.all_reduce(g_sum, group=self.dp_group)
p_local = self.shard_to_working_param[id(p)]
w_sum = (p_local**2).sum()
sums = torch.stack([w_sum, g_sum])
# Get global l2 norms
if self.tp_size > 1:
dist.all_reduce(sums, group=self.tp_group)
w_norm, g_norm = sums.sqrt().chunk(2)
else:
# Fall back to vanilla Lamb
w_norm = torch.norm(p)
g_norm = torch.norm(update)
trust_ratio = torch.where(w_norm > 0 and g_norm > 0, (w_norm / g_norm), 1.0).item()
scaled_lr *= trust_ratio
p.data.add_(update, alpha=-scaled_lr)
return loss

View File

@ -0,0 +1,315 @@
""" adapted from https://github.com/jiaweizzhao/GaLore/blob/master/galore_torch/adamw8bit.py"""
import warnings
from typing import List
import torch
from bitsandbytes.optim.optimizer import Optimizer2State
from torch._C import _LinAlgError
def get_galore_param_groups(
model, weight_decay, rank=256, update_proj_gap=200, scale=0.25, proj_type="std"
) -> List[dict]:
"""
It's advised to use this instead of manually specifying which param groups
to apply GaLore on.
"""
galore_params = []
non_galore = []
no_decay_params = []
no_decay = ["bias", "LayerNorm.weight"]
for name, param in model.named_parameters():
# Only make sense to do SVD on 2d gradient matrices
# e.g. nn.Linear, VocabEmbedding, etc.
if any(nd in name for nd in no_decay):
no_decay_params.append(param)
elif param.dim() == 2:
galore_params.append(param)
else:
non_galore.append(param)
param_groups = [
{
"params": galore_params,
"rank": rank,
"update_proj_gap": update_proj_gap,
"scale": scale,
"proj_type": proj_type,
"weight_decay": weight_decay,
},
{"params": non_galore, "weight_decay": weight_decay},
{"params": no_decay_params, "weight_decay": 0.0},
]
return param_groups
def make_low_rank_buffer(p, grad):
"""For compatibility with bitsandbytes's update_step, we need an empty low-rank
param update buffer to avoid mutating original params.
TODO: optimize by reusing the memory for p.grad? Need to modify bitsandbytes?
"""
p.saved_data = p.data.clone()
# p.data = grad.clone().to(p.data.dtype).to(p.data.device)
p.data = torch.zeros_like(grad, device=grad.device, dtype=grad.dtype)
# p.data.zero_()
p.grad = grad
class GaLoreProjector:
def __init__(self, rank, verbose=False, update_proj_gap=200, scale=1.0, proj_type="std"):
self.rank = rank
self.verbose = verbose
self.update_proj_gap = update_proj_gap
self.scale = scale
self.ortho_matrix = None
self.proj_type = proj_type
self.svd_type = None
def project(self, full_rank_grad, iter):
dim = full_rank_grad.dim()
if dim != 2:
warnings.warn(
f"Warning: You shouldn't specify projection rank for {dim}D params in param_groups. Skipping SVD."
)
return full_rank_grad
m, n = full_rank_grad.shape # For ZeRO sharded grads
if self.proj_type == "std":
# Project the lower dim to minimize information loss
if self.svd_type is None:
self.svd_type = "right" if m >= n else "left"
# SVD step
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type=self.svd_type)
if self.svd_type == "right":
low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t()[:n])
else:
low_rank_grad = torch.matmul(self.ortho_matrix.t()[:, :m], full_rank_grad)
elif self.proj_type == "reverse_std":
if self.svd_type is None:
self.svd_type = "left" if m >= n else "right"
# SVD step
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type=self.svd_type)
if self.svd_type == "left":
low_rank_grad = torch.matmul(self.ortho_matrix.t()[:, :m], full_rank_grad)
else:
low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t()[:n])
return low_rank_grad
def project_back(self, low_rank_grad):
if low_rank_grad.dim() != 2:
return
m, n = low_rank_grad.shape
if self.svd_type == "right":
full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix[:n])
else:
full_rank_grad = torch.matmul(self.ortho_matrix[:, :m], low_rank_grad)
return full_rank_grad * self.scale
# svd decomposition
def get_orthogonal_matrix(self, weights, rank, type):
module_params = weights
if module_params.data.dtype != torch.float:
float_data = False
original_type = module_params.data.dtype
original_device = module_params.data.device
matrix = module_params.data.float()
else:
float_data = True
matrix = module_params.data
# TODO: redo SVD in the next step.
if matrix.isnan().any():
print(f"{__file__}: skipping SVD due to NaN matrix")
return self.ortho_matrix
try:
U, s, Vh = torch.linalg.svd(matrix, full_matrices=False)
except _LinAlgError as e:
print(f"{__file__}: skipping SVD due to {e}")
return self.ortho_matrix
# make the smaller matrix always to be orthogonal matrix
if type == "right":
B = Vh[:rank, :]
if not float_data:
B = B.to(original_device).type(original_type)
return B
elif type == "left":
A = U[:, :rank]
if not float_data:
A = A.to(original_device).type(original_type)
return A
elif type == "full":
A = U[:, :rank]
B = Vh[:rank, :]
if not float_data:
A = A.to(original_device).type(original_type)
B = B.to(original_device).type(original_type)
return [A, B]
else:
raise ValueError("type should be left, right or full")
class GaLoreAdamW8bit(Optimizer2State):
r"""Implements Galore, a optimizer-agonistic gradient compression technique on 8-bit AdamW.
Proposed in `GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection`. It compresses
gradient via low-rank projection and is claimed to be insensitive to hyperparams like lr.
https://arxiv.org/abs/2403.03507
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its norm. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-6)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0.01)
nbits (int): The number of bits of optim states. Only 32 and 8 are supported.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer (handle memory spike via CPU-GPU transfer) or not.
Example:
"""
def __init__(
self,
params,
lr=1e-2,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
nbits=8,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
nbits,
None,
min_8bit_size,
percentile_clipping,
block_wise,
is_paged=is_paged,
)
proj_none = all(["rank" not in group for group in self.param_groups])
if proj_none:
warnings.warn(
"Will not apply GaLore as no rank is specified. Or did you forget to? Try get_galore_param_groups"
)
# Defaults from the paper
for group in self.param_groups:
if "rank" in group:
group["update_proj_gap"] = group.get("update_proj_gap", 200)
group["proj_type"] = group.get("proj_type", "std")
group["scale"] = group.get("scale", 0.25)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
if not self.initialized:
self.check_overrides()
self.to_gpu() # needed for fairseq pure fp16 training
self.initialized = True
for gindex, group in enumerate(self.param_groups):
for pindex, p in enumerate(group["params"]):
if p.grad is None:
continue
if p is self.param_groups[0]["params"][0]:
torch.save(p.grad, "grad.pt")
state = self.state[p]
if "step" not in state:
state["step"] = 0
# GaLore Projection
if "rank" in group:
if "projector" not in state:
state["projector"] = GaLoreProjector(
group["rank"],
scale=group["scale"],
update_proj_gap=group["update_proj_gap"],
proj_type=group["proj_type"],
)
if "weight_decay" in group and group["weight_decay"] > 0:
# ensure that the weight decay is not applied to the norm grad
group["weight_decay_saved"] = group["weight_decay"]
group["weight_decay"] = 0
grad = state["projector"].project(p.grad, state["step"])
make_low_rank_buffer(p, grad)
if "state1" not in state:
self.init_state(group, p, gindex, pindex)
# p.grad = p.grad.contiguous() # avoid bitsandbytes update error
# Prefetch if paged
self.prefetch_state(p)
# Adam update step using the buffer
self.update_step(group, p, gindex, pindex)
torch.cuda.synchronize()
# GaLore Projection Back
if "rank" in group:
if p is self.param_groups[0]["params"][1]:
pass
update = state["projector"].project_back(p.data)
p.data = p.saved_data.add_(update)
# apply weight decay
if "weight_decay_saved" in group:
p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay_saved"])
group["weight_decay"] = group["weight_decay_saved"]
del group["weight_decay_saved"]
if self.is_paged:
# all paged operation are asynchronous, we need
# to sync to make sure all tensors are in the right state
torch.cuda.synchronize()
return loss
def __del__(self):
"""Avoid buffer memory leak"""
for group in self.param_groups:
for p in group["params"]:
if hasattr(p, "saved_data"):
del p.saved_data

View File

@ -26,7 +26,9 @@ class Lamb(Optimizer):
https://arxiv.org/abs/1904.00962
"""
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0, adam=False):
def __init__(
self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0, adam=False, bias_correction=False
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
@ -35,7 +37,7 @@ class Lamb(Optimizer):
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
self.adam = adam
super(Lamb, self).__init__(params, defaults)
@ -79,12 +81,15 @@ class Lamb(Optimizer):
# v_t
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
# Paper v3 does not use debiasing.
# bias_correction1 = 1 - beta1 ** state['step']
# bias_correction2 = 1 - beta2 ** state['step']
# Apply bias to lr to avoid broadcast.
# * math.sqrt(bias_correction2) / bias_correction1
step_size = group["lr"]
# NOTE: Paper v3 does not use debiasing.
scaled_lr = group["lr"]
if group["bias_correction"]:
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]
# Apply debiasing to lr to avoid broadcast
scaled_lr *= (bias_correction2**0.5) / bias_correction1
# exp_avg.div_(bias_correction1)
# exp_avg_sq.div_(bias_correction2)
weight_norm = p.data.pow(2).sum().sqrt()
@ -97,12 +102,10 @@ class Lamb(Optimizer):
trust_ratio = 1
else:
trust_ratio = weight_norm / adam_norm
state["weight_norm"] = weight_norm
state["adam_norm"] = adam_norm
state["trust_ratio"] = trust_ratio
if self.adam:
trust_ratio = 1
p.data.add_(adam_step, alpha=-step_size * trust_ratio)
p.data.add_(adam_step, alpha=-scaled_lr * trust_ratio)
return loss

View File

@ -16,7 +16,7 @@ from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.shard import ShardConfig
from ..layer import ColoAttention
from ..layer import ColoAttention, cross_entropy_1d
logger = logging.get_logger(__name__)
@ -270,11 +270,21 @@ class MistralForwards:
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
new_vocab_size = logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
loss = cross_entropy_1d(
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
)
else:
shift_logits = shift_logits.view(-1, self.config.vocab_size)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
@ -609,3 +619,100 @@ def get_mistral_flash_attention_forward(shard_config: ShardConfig):
return attn_output, None, past_key_value
return forward
def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
from transformers import MistralForCausalLM
def forward(
self: MistralForCausalLM,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, MistralForCausalLM
>>> model = MistralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
new_vocab_size = logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
loss = cross_entropy_1d(
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
return forward

View File

@ -22,6 +22,8 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import ColoAttention
from colossalai.shardformer.shard import ShardConfig
from ..layer import cross_entropy_1d
logger = logging.get_logger(__name__)
@ -336,8 +338,22 @@ class OPTPipelineForwards:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
new_vocab_size = logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
shift_labels = shift_labels.view(-1)
loss = cross_entropy_1d(
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
)
else:
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
@ -844,3 +860,146 @@ def get_jit_fused_opt_decoder_layer_forward():
return outputs
return forward
def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
def forward(
self: OPTForCausalLM,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, OPTForCausalLM
>>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
>>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
logits = self.lm_head(outputs[0]).contiguous()
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
new_vocab_size = logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
loss = cross_entropy_1d(
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
return forward

View File

@ -0,0 +1,758 @@
from typing import List, Optional, Tuple, Union
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
)
try:
from transformers.models.qwen2.modeling_qwen2 import (
Qwen2Attention,
Qwen2ForCausalLM,
Qwen2ForSequenceClassification,
Qwen2Model,
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
apply_rotary_pos_emb,
repeat_kv,
)
except ImportError:
Qwen2Model = "Qwen2Model"
Qwen2ForCausalLM = "Qwen2ForCausalLM"
Qwen2Attention = "Qwen2Attention"
Qwen2ForSequenceClassification = "Qwen2ForSequenceClassification"
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.shard import ShardConfig
from ..layer import ColoAttention, cross_entropy_1d
class Qwen2PipelineForwards:
"""
This class serves as a micro library for forward function substitution of Qwen2 models
under pipeline setting.
"""
@staticmethod
def qwen2_model_forward(
self: Qwen2Model,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
logger = logging.get_logger(__name__)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if stage_manager.is_first_stage():
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
else:
input_shape = hidden_states.shape[:-1]
batch_size, seq_length = input_shape
device = hidden_states.device
seq_length_with_past = seq_length
past_key_values_length = 0
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
if use_cache:
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
use_cache = False
# assert past_key_values is None, "past_key_values is not supported for Qwen2 models at the moment."
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
if is_padding_right:
raise ValueError(
"You are attempting to perform batched generation with padding_side='right'"
" this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to "
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
)
# embed positions, for the first stage, hidden_states is the input embeddings,
# for the other stages, hidden_states is the output of the previous stage
if shard_config.enable_flash_attention:
# in this case, attention_mask is a dict rather than a tensor
mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
attention_mask = ColoAttention.prepare_attn_kwargs(
mask_shape,
hidden_states.dtype,
hidden_states.device,
q_padding_mask=attention_mask,
is_causal=True,
)
else:
if self._attn_implementation == "flash_attention_2":
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._attn_implementation == "sdpa" and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask,
(batch_size, seq_length),
hidden_states,
past_key_values_length,
sliding_window=self.config.sliding_window,
)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
start_idx, end_idx = stage_index[0], stage_index[1]
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if stage_manager.is_last_stage():
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if stage_manager.is_last_stage():
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
# always return dict for imediate stage
return {"hidden_states": hidden_states}
@staticmethod
def qwen2_for_causal_lm_forward(
self: Qwen2ForCausalLM,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
):
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, Qwen2ForCausalLM
>>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you consciours? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
```"""
logger = logging.get_logger(__name__)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = Qwen2PipelineForwards.qwen2_model_forward(
self.model,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
)
past_key_values = None
if stage_manager.is_last_stage():
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
if shard_config.enable_tensor_parallelism:
new_vocab_size = logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
loss = cross_entropy_1d(
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
)
else:
shift_logits = shift_logits.view(-1, self.config.vocab_size)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
else:
hidden_states = outputs.get("hidden_states")
return {"hidden_states": hidden_states}
@staticmethod
def qwen2_for_sequence_classification_forward(
self: Qwen2ForSequenceClassification,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
logger = logging.get_logger(__name__)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
transformer_outputs = Qwen2PipelineForwards.qwen2_model_forward(
self.model,
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
)
if input_ids is not None:
batch_size = input_ids.shape[0]
elif inputs_embeds is not None:
batch_size = inputs_embeds.shape[0]
else:
batch_size = hidden_states.shape[0]
if stage_manager.is_last_stage():
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)
if self.config.pad_token_id is None and batch_size != 1:
print(self.config.pad_token_id)
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
else:
sequence_lengths = -1
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
loss = None
if labels is not None:
labels = labels.to(logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(pooled_logits, labels)
if not return_dict:
output = (pooled_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
else:
hidden_states = transformer_outputs.get("hidden_states")
return {"hidden_states": hidden_states}
def get_qwen2_flash_attention_forward(shard_config: ShardConfig):
def forward(
self: Qwen2Attention,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
# Because the input can be padded, the absolute sequence length depends on the max position id.
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
# Activate slicing cache only if the config has a value `sliding_windows` attribute
cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
if (
getattr(self.config, "sliding_window", None) is not None
and kv_seq_len > self.config.sliding_window
and cache_has_contents
):
slicing_tokens = 1 - self.config.sliding_window
past_key = past_key_value[self.layer_idx][0]
past_value = past_key_value[self.layer_idx][1]
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
if past_key.shape[-2] != self.config.sliding_window - 1:
raise ValueError(
f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
f" {past_key.shape}"
)
if attention_mask is not None:
attention_mask = attention_mask[:, slicing_tokens:]
attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
return forward
def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig):
logger = logging.get_logger(__name__)
assert shard_config.enable_flash_attention, "Flash Attention is not enabled."
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
seq_length_with_past = seq_length
past_key_values_length = 0
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
hidden_states = inputs_embeds
# in this case, attention_mask is a dict rather than a tensor
mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
attention_mask = ColoAttention.prepare_attn_kwargs(
mask_shape,
hidden_states.dtype,
hidden_states.device,
q_padding_mask=attention_mask,
is_causal=True,
)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
return forward
def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
def forward(
self: Qwen2ForCausalLM,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, Qwen2ForCausalLM
>>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
if shard_config.enable_tensor_parallelism:
new_vocab_size = logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
loss = cross_entropy_1d(
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
)
else:
shift_logits = shift_logits.view(-1, self.config.vocab_size)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
return forward

View File

@ -182,6 +182,16 @@ _POLICY_LIST = {
"transformers.models.mistral.modeling_mistral.MistralForSequenceClassification": PolicyLocation(
file_name="mistral", class_name="MistralForSequenceClassificationPolicy"
),
# Qwen2
"transformers.models.qwen2.modeling_qwen2.Qwen2Model": PolicyLocation(
file_name="qwen2", class_name="Qwen2ModelPolicy"
),
"transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM": PolicyLocation(
file_name="qwen2", class_name="Qwen2ForCausalLMPolicy"
),
"transformers.models.qwen2.modeling_qwen2.Qwen2ForSequenceClassification": PolicyLocation(
file_name="qwen2", class_name="Qwen2ForSequenceClassificationPolicy"
),
}

View File

@ -18,6 +18,7 @@ from colossalai.shardformer.layer import (
from ..modeling.mistral import (
MistralForwards,
get_lm_forward_with_dist_cross_entropy,
get_mistral_flash_attention_forward,
get_mistral_model_forward_for_flash_attn,
)
@ -275,14 +276,18 @@ class MistralForCausalLMPolicy(MistralPolicy):
SubModuleReplacementDescription(
suffix="lm_head",
target_module=VocabParallelLMHead1D,
kwargs=dict(
gather_output=True,
make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
),
kwargs={
"gather_output": not self.shard_config.parallel_output,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
},
)
]
)
}
if self.shard_config.parallel_output:
new_item[MistralForCausalLM].method_replacement = {
"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
}
else:
new_item = {
MistralForCausalLM: ModulePolicyDescription(

View File

@ -21,6 +21,7 @@ from ..modeling.jit import get_jit_fused_dropout_add_func
from ..modeling.opt import (
OPTPipelineForwards,
get_jit_fused_opt_decoder_layer_forward,
get_lm_forward_with_dist_cross_entropy,
get_opt_decoder_forward_for_flash_attention,
get_opt_flash_attention_forward,
)
@ -269,12 +270,18 @@ class OPTForCausalLMPolicy(OPTPolicy):
suffix="lm_head",
target_module=VocabParallelLMHead1D,
kwargs=dict(
gather_output=True, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by
gather_output=not self.shard_config.parallel_output,
make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
),
),
policy=policy,
target_key=OPTForCausalLM,
)
if self.shard_config.parallel_output:
method_replacement = {"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=OPTForCausalLM
)
else:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(

View File

@ -0,0 +1,374 @@
import warnings
from functools import partial
from typing import Callable, Dict, List, Union
import torch.nn as nn
from torch import Tensor
from torch.nn import Module
from colossalai.shardformer.layer import (
FusedRMSNorm,
Linear1D_Col,
Linear1D_Row,
PaddingEmbedding,
RMSNorm,
VocabParallelEmbedding1D,
)
from ..modeling.qwen2 import (
Qwen2PipelineForwards,
get_lm_forward_with_dist_cross_entropy,
get_qwen2_flash_attention_forward,
get_qwen2_model_forward_for_flash_attn,
)
try:
from transformers.models.qwen2.modeling_qwen2 import (
Qwen2Attention,
Qwen2DecoderLayer,
Qwen2FlashAttention2,
Qwen2ForCausalLM,
Qwen2ForSequenceClassification,
Qwen2Model,
Qwen2SdpaAttention,
)
except ImportError:
Qwen2ForCausalLM = "Qwen2ForCausalLM"
Qwen2ForSequenceClassification = "Qwen2ForSequenceClassification"
Qwen2Attention = "Qwen2Attention"
Qwen2FlashAttention2 = "Qwen2FlashAttention2"
Qwen2SdpaAttention = "Qwen2SdpaAttention"
Qwen2DecoderLayer = "Qwen2DecoderLayer"
Qwen2Model = "Qwen2Model"
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ["Qwen2Policy", "Qwen2ForCausalLMPolicy", "Qwen2ForSequenceClassificationPolicy"]
class Qwen2Policy(Policy):
def __init__(self) -> None:
super().__init__()
import transformers
from packaging.version import Version
assert Version(transformers.__version__) >= Version(
"4.39.1"
), "The Qwen2 model should run on a transformers version of 4.39.1."
def config_sanity_check(self):
pass
def preprocess(self):
self.tie_weight = self.tie_weight_check()
self.origin_attn_implement = self.model.config._attn_implementation
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
ATTN_IMPLEMENTATION = {
"eager": Qwen2Attention,
"flash_attention_2": Qwen2FlashAttention2,
"sdpa": Qwen2SdpaAttention,
}
policy = {}
attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
embedding_cls = VocabParallelEmbedding1D
else:
if self.tie_weight:
embedding_cls = PaddingEmbedding
norm_cls = FusedRMSNorm if self.shard_config.enable_fused_normalization else RMSNorm
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
warnings.warn("Qwen2 doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
if self.shard_config.enable_tensor_parallelism:
assert (
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
), f"The number of attention heads must be divisible by tensor parallel size."
if hasattr(self.model.config, "num_key_value_heads"):
assert (
self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
), f"The number of key_value heads must be divisible by tensor parallel size."
decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
}
if getattr(self.model.config, "num_key_value_heads", False):
decoder_attribute_replacement["self_attn.num_key_value_heads"] = (
self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size
)
policy[Qwen2DecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=Linear1D_Row,
),
],
)
if embedding_cls is not None:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
),
policy=policy,
target_key=Qwen2Model,
)
# optimization configuration
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=norm_cls,
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=norm_cls,
),
],
policy=policy,
target_key=Qwen2DecoderLayer,
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="norm",
target_module=norm_cls,
),
policy=policy,
target_key=Qwen2Model,
)
# use flash attention
if self.shard_config.enable_flash_attention:
self.append_or_create_method_replacement(
description={
"forward": get_qwen2_flash_attention_forward(self.shard_config),
},
policy=policy,
target_key=attn_cls,
)
if self.pipeline_stage_manager is None:
# replace qwen2 model forward method
self.append_or_create_method_replacement(
description={
"forward": get_qwen2_model_forward_for_flash_attn(self.shard_config),
},
policy=policy,
target_key=Qwen2Model,
)
return policy
def postprocess(self):
return self.model
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
"""If under pipeline parallel setting, replacing the original forward method of huggingface
to customized forward method, and add this changing to policy."""
if self.pipeline_stage_manager is None:
return
stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == "Qwen2Model":
module = self.model
else:
module = self.model.model
if stage_manager.is_interleave:
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)
method_replacement = {
"forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config)
}
else:
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
stage_index = stage_manager.get_stage_index(layers_per_stage)
method_replacement = {
"forward": partial(
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
)
}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=model_cls
)
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
assert self.pipeline_stage_manager is not None
if self.model.__class__.__name__ == "Qwen2Model":
module = self.model
else:
module = self.model.model
stage_manager = self.pipeline_stage_manager
held_layers = []
if stage_manager.is_interleave:
assert stage_manager.num_model_chunks is not None
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
stage_indices = stage_manager.get_stage_index(layers_per_stage)
if stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(module.embed_tokens)
for start_idx, end_idx in stage_indices:
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(module.norm)
else:
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
if stage_manager.is_first_stage():
held_layers.append(module.embed_tokens)
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.norm)
return held_layers
class Qwen2ModelPolicy(Qwen2Policy):
def module_policy(self):
policy = super().module_policy()
if self.pipeline_stage_manager:
# set None as default
self.set_pipeline_forward(
model_cls=Qwen2Model, new_forward=Qwen2PipelineForwards.qwen2_model_forward, policy=policy
)
return policy
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
held_layers = super().get_held_layers()
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in Qwen2 model"""
return []
class Qwen2ForCausalLMPolicy(Qwen2Policy):
def module_policy(self):
policy = super().module_policy()
setattr(self.shard_config, "causal_lm", True)
if self.shard_config.enable_tensor_parallelism:
# add a new item for casual lm
new_item = {
Qwen2ForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(suffix="lm_head", target_module=Linear1D_Col)
],
method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
)
}
policy.update(new_item)
if self.pipeline_stage_manager:
# set None as default
self.set_pipeline_forward(
model_cls=Qwen2ForCausalLM, new_forward=Qwen2PipelineForwards.qwen2_for_causal_lm_forward, policy=policy
)
return policy
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.lm_head)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
qwen2_model = self.model.model
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
if (
id(qwen2_model.embed_tokens.weight) == id(self.model.lm_head.weight)
and self.pipeline_stage_manager.num_stages > 1
):
# tie weights
return [
{
0: qwen2_model.embed_tokens.weight,
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
}
]
return []
class Qwen2ForSequenceClassificationPolicy(Qwen2Policy):
def module_policy(self):
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
# add a new item for sequence classification
new_item = {
Qwen2ForSequenceClassification: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
)
]
)
}
policy.update(new_item)
# to be confirmed
if self.pipeline_stage_manager:
# set None as default
self.set_pipeline_forward(
model_cls=Qwen2ForSequenceClassification,
new_forward=Qwen2PipelineForwards.qwen2_for_sequence_classification_forward,
policy=policy,
)
return policy
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.score)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in Qwen2 for sequence classification model"""
return []

View File

@ -6,6 +6,7 @@ from .api import (
get_device_mesh,
get_global_shape,
get_layout,
get_shard_dim_1d,
get_sharding_spec,
init_as_dtensor,
init_tensor_as_customization_distributed,
@ -37,6 +38,7 @@ __all__ = [
"get_device_mesh",
"redistribute",
"get_layout",
"get_shard_dim_1d",
"is_customized_distributed_tensor",
"distribute_tensor_with_customization",
"init_tensor_as_customization_distributed",

View File

@ -8,6 +8,7 @@ import torch.distributed as dist
from torch.distributed import ProcessGroup
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor.sharding_spec import DimSpec
from .layout import Layout
from .layout_converter import LayoutConverter
@ -15,6 +16,22 @@ from .sharding_spec import ShardingSpec
layout_converter = LayoutConverter()
_SHARD_DIM = DimSpec([0])
def get_shard_dim_1d(p: torch.Tensor):
"""
Get the dimension along which the tensor is sharded, for example in 1D Tensor Parallel.
Args:
p (torch.Tensor): the input tensor
Returns:
int: the dimension along which the tensor is sharded
"""
if not is_distributed_tensor(p):
raise ValueError("p is not a distributed tensor")
sharding = p.dist_layout.sharding_spec.sharding_sequence
return sharding.index(_SHARD_DIM)
def clear_layout_converter():
global layout_converter

View File

@ -140,8 +140,9 @@ class DimSpec:
class ShardingSpec:
"""
Sharding spec describes how to shard a tensor with dim_size dimensions. The sharding sequence looks like
[R, R, S0, S1], which means
Sharding spec describes how to shard a tensor with dim_size dimensions. For example for a 3D tensor, the sharding sequence
[R, S0, S1] means not sharding the first dim, sharding the 3rd along the 1st device mesh axis (Process group)
and sharding the 3th dim along the 2nd device mesh axis. Useful for say, 2D Tensor Parallel.
Argument:
dim_partition_dict(Dict[int, List[int]], optional): The key is the dimension of tensor to be sharded,

View File

@ -20,7 +20,12 @@ class ChunkManager:
init_device (torch.device): optional, the device on which the chunk is initialized. The default is None.
"""
def __init__(self, chunk_configuration, init_device: Optional[torch.device] = None) -> None:
def __init__(
self,
chunk_configuration,
init_device: Optional[torch.device] = None,
reuse_fp16_chunk: bool = True,
) -> None:
self.device = init_device or get_accelerator().get_current_device()
self.dp_degree_chunk_size_dict: Dict[int, int] = dict()
self.kwargs_config = chunk_configuration
@ -33,6 +38,10 @@ class ChunkManager:
self.accessed_chunks: Set[Chunk] = set()
self.accessed_mem: int = 0
self.total_mem: Dict[str, int] = {"cpu": 0, "cuda": 0}
self.reuse_fp16_chunk = reuse_fp16_chunk
# Whether model is accumulating gradients,
self.accumulating_grads = False
self.overflow_counter = 0
def register_tensor(
self,

View File

@ -19,6 +19,7 @@ def init_chunk_manager(
model: nn.Module,
init_device: Optional[torch.device] = None,
hidden_dim: Optional[int] = None,
reuse_fp16_chunk: bool = True,
verbose: bool = False,
**kwargs,
) -> ChunkManager:
@ -50,5 +51,9 @@ def init_chunk_manager(
)
dist.barrier()
chunk_manager = ChunkManager(config_dict, init_device)
chunk_manager = ChunkManager(
config_dict,
init_device,
reuse_fp16_chunk=reuse_fp16_chunk,
)
return chunk_manager

View File

@ -98,8 +98,14 @@ class GeminiDDP(ModelWrapper):
verbose: bool = False,
) -> None:
assert mixed_precision in (torch.float16, torch.bfloat16)
reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False
self.enable_gradient_accumulation = enable_gradient_accumulation
if chunk_config_dict is not None:
self.chunk_manager = ChunkManager(chunk_config_dict, chunk_init_device)
self.chunk_manager = ChunkManager(
chunk_config_dict,
chunk_init_device,
reuse_fp16_chunk=reuse_fp16_chunk,
)
else:
# some ugly hotfix for the compatibility with Lightning
if search_range_m is None:
@ -112,6 +118,7 @@ class GeminiDDP(ModelWrapper):
min_chunk_size_m=min_chunk_size_m,
strict_ddp_flag=strict_ddp_mode,
process_group=zero_group,
reuse_fp16_chunk=reuse_fp16_chunk,
verbose=verbose,
)
self.gemini_manager = GeminiManager(
@ -128,7 +135,6 @@ class GeminiDDP(ModelWrapper):
self.param_op_hook = GeminiZeROHook(self.gemini_manager)
self.fp32_params: List[torch.Tensor] = list()
self.fp16_params: List[ColoParameter] = list()
self.overflow_counter = 0
self.grads_device: Dict[torch.Tensor, torch.device] = dict()
self.param2name: Dict[nn.Parameter, str] = dict()
self.name2param: Dict[str, nn.Parameter] = dict()
@ -137,14 +143,8 @@ class GeminiDDP(ModelWrapper):
self.zero_group = zero_group or _get_default_group()
self.extra_dp_group = extra_dp_group
self.reuse_fp16_chunk = master_weights
self.master_weights = master_weights
self.enable_gradient_accumulation = enable_gradient_accumulation
if self.enable_gradient_accumulation:
self.reuse_fp16_chunk = False
self.accumulating_grads = False # Whether model is accumulating gradients
self._logger = get_dist_logger()
if self.gemini_manager._premade_memstats_:
@ -178,7 +178,29 @@ class GeminiDDP(ModelWrapper):
if is_ddp_ignored(p):
continue
if p.requires_grad:
p.register_hook(partial(self.grad_handle, p))
p._grad_handle = p.register_hook(
partial(
GeminiDDP.grad_handle,
chunk_manager=self.chunk_manager,
param2name=self.param2name,
grads_device=self.grads_device,
master_weights=self.master_weights,
enable_gradient_accumulation=self.enable_gradient_accumulation,
p=p,
)
)
def remove_hooks(self):
for p in self.module.parameters():
if is_ddp_ignored(p):
continue
if p.requires_grad:
assert hasattr(p, "_grad_handle")
p._grad_handle.remove()
delattr(p, "_grad_handle")
def __del__(self):
self.remove_hooks()
def parameters(self, recurse: bool = True):
return self.module.parameters(recurse)
@ -324,8 +346,8 @@ class GeminiDDP(ModelWrapper):
f"{error_str}",
)
self._setup_grads_ptr()
if self.enable_gradient_accumulation and not self.accumulating_grads:
self.accumulating_grads = True # Turn on the state of gradient accumulation.
if self.enable_gradient_accumulation and not self.chunk_manager.accumulating_grads:
self.chunk_manager.accumulating_grads = True # Turn on the state of gradient accumulation.
self._logger.debug(
f"comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}"
)
@ -340,25 +362,34 @@ class GeminiDDP(ModelWrapper):
def backward_by_grad(self, tensor, grad):
raise RuntimeError("Gemini is not compatible with pipeline. backward_by_grad shoudn't be called in Gemini.")
def grad_handle(self, p, grad):
@staticmethod
def grad_handle(
grad,
chunk_manager: ChunkManager,
param2name: Dict,
grads_device: Dict,
master_weights: bool,
enable_gradient_accumulation: bool,
p: nn.Parameter,
):
setattr(p, "_gemini_reduced", True)
empty_grad = torch.empty_like(grad)
free_storage(empty_grad)
with torch._C.DisableTorchFunction():
chunk = self.chunk_manager.get_chunk(p)
chunk = chunk_manager.get_chunk(p)
if chunk.tensors_info[p].state != TensorState.HOLD_AFTER_BWD:
raise RuntimeError(
f"Parameter `{self.param2name[p]}` failed at the gradient reduction. "
f"Parameter `{param2name[p]}` failed at the gradient reduction. "
"Some unsupported torch function is operated upon this parameter."
)
grad_chunk = chunk
if not self.reuse_fp16_chunk:
if not self.accumulating_grads:
grad_chunk = self.chunk_manager.init_grad_chunk(chunk)
if not chunk_manager.reuse_fp16_chunk:
if not chunk_manager.accumulating_grads:
grad_chunk = chunk_manager.init_grad_chunk(chunk)
else:
assert chunk.grad_chunk is not None
if chunk.grad_chunk not in self.chunk_manager.accessed_chunks:
grad_chunk = self.chunk_manager.rearrange_accumulated_grad_chunk(chunk)
if chunk.grad_chunk not in chunk_manager.accessed_chunks:
grad_chunk = chunk_manager.rearrange_accumulated_grad_chunk(chunk)
else:
grad_chunk = chunk.grad_chunk
chunk.grad_chunk.l2_norm = None
@ -371,33 +402,33 @@ class GeminiDDP(ModelWrapper):
chunk.tensor_trans_state(p, TensorState.HOLD)
grad_chunk.tensor_trans_state(p, TensorState.READY_FOR_REDUCE)
if not self.accumulating_grads:
grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=self.reuse_fp16_chunk)
if not chunk_manager.accumulating_grads:
grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=chunk_manager.reuse_fp16_chunk)
else:
grad_chunk.add_tensor_to_chunk_slice(p, grad)
reduced = self.chunk_manager.reduce_chunk(grad_chunk)
reduced = chunk_manager.reduce_chunk(grad_chunk)
if reduced:
if not self.reuse_fp16_chunk:
if not chunk_manager.reuse_fp16_chunk:
if chunk.keep_gathered:
self.chunk_manager.fake_release_chunk(chunk)
chunk_manager.fake_release_chunk(chunk)
else:
self.chunk_manager.release_chunk(chunk)
chunk_manager.release_chunk(chunk)
if grad_chunk.is_gathered:
grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
if self.extra_dp_group is not None:
if chunk.extra_dp_group is not None:
grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
else:
grad_chunk.cuda_shard.div_(chunk.pg_size)
if self.extra_dp_group is not None:
if chunk.extra_dp_group is not None:
grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
# check overflow elements
self.overflow_counter += grad_chunk.has_inf_or_nan
chunk_manager.overflow_counter += grad_chunk.has_inf_or_nan
# record l2 norm for gradient clipping. flag is bound to fp16 chunk
if chunk.l2_norm_flag:
grad_chunk.set_l2_norm()
self.chunk_manager.move_chunk(grad_chunk, self.grads_device[p], force_copy=True)
if not (self.master_weights) or (self.enable_gradient_accumulation):
self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True)
chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
if not (master_weights) or (enable_gradient_accumulation):
chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
return empty_grad
def zero_grad(self, set_to_none: bool = False) -> None:
@ -513,11 +544,11 @@ class GeminiDDP(ModelWrapper):
# get copies of fp32 parameters in CPU
# as memory of fp16_params may be reused by grad, it's not reliable, we should use fp32_params and convert to fp16
params = self.fp32_params if self.reuse_fp16_chunk else self.fp16_params
params = self.fp32_params if self.chunk_manager.reuse_fp16_chunk else self.fp16_params
param_to_save_data = self._get_param_to_save_data(params, only_rank_0)
# get the mapping between copies and fp16 parameters
p_mapping = dict()
if self.reuse_fp16_chunk:
if self.chunk_manager.reuse_fp16_chunk:
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
name = self.param2name[p]
assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
@ -713,7 +744,7 @@ class GeminiDDP(ModelWrapper):
name = self.param2name[p]
fp32_to_name[fp32_p] = name
params_to_load = self.fp32_params if self.reuse_fp16_chunk else self.fp16_params
params_to_load = self.fp32_params if self.chunk_manager.reuse_fp16_chunk else self.fp16_params
chunk_list = self.chunk_manager.get_chunks(params_to_load)
for chunk in chunk_list:
temp_chunk = get_temp_total_chunk_on_cuda(chunk, self.mixed_precision)
@ -728,7 +759,9 @@ class GeminiDDP(ModelWrapper):
shard_fn = tensor.shard_fn
gather_fn = tensor.gather_fn
parameter_name = fp32_to_name[tensor] if self.reuse_fp16_chunk else self.param2name[tensor]
parameter_name = (
fp32_to_name[tensor] if self.chunk_manager.reuse_fp16_chunk else self.param2name[tensor]
)
parameter_slice = temp_chunk[tensor_info.offset : tensor_info.end]
load(
parameter_name,
@ -900,7 +933,7 @@ class GeminiDDP(ModelWrapper):
gathered_param = param if keep_vars else param.detach()
else:
# as memory of fp16 param may be reused, we should use fp32 param and then convert to fp16
param_to_save = fp16_to_fp32[param] if self.reuse_fp16_chunk else param
param_to_save = fp16_to_fp32[param] if self.chunk_manager.reuse_fp16_chunk else param
if param_to_save not in gathered_param_buffer:
chunk = self.chunk_manager.get_chunk(param_to_save)
gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0))

View File

@ -62,10 +62,10 @@ class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
self.module = module
def check_local_overflow(self) -> bool:
return self.module.overflow_counter > 0
return self.module.chunk_manager.overflow_counter > 0
def pre_zero_grad(self) -> None:
self.module.overflow_counter = 0
self.module.chunk_manager.overflow_counter = 0
class GeminiOptimizer(OptimizerWrapper):
@ -202,7 +202,7 @@ class GeminiOptimizer(OptimizerWrapper):
chunk16 = self.param_to_chunk16[fake_param]
begin, end = self.param_to_range[fake_param]
grad_chunk16 = chunk16 if self.module.reuse_fp16_chunk else chunk16.grad_chunk
grad_chunk16 = chunk16 if self.module.chunk_manager.reuse_fp16_chunk else chunk16.grad_chunk
fake_param.data = grad_chunk16.payload[begin:end]
fake_param.grad = fake_param.data
@ -221,14 +221,14 @@ class GeminiOptimizer(OptimizerWrapper):
def _clear_global_norm(self) -> None:
for c16 in self.chunk16_set:
grad_chunk = c16 if self.module.reuse_fp16_chunk else c16.grad_chunk
grad_chunk = c16 if self.module.chunk_manager.reuse_fp16_chunk else c16.grad_chunk
grad_chunk.l2_norm = None
def _calc_global_norm(self) -> float:
norm_sqr: float = 0.0
group_to_norm = dict()
for c16 in self.chunk16_set:
grad_chunk = c16 if self.module.reuse_fp16_chunk else c16.grad_chunk
grad_chunk = c16 if self.module.chunk_manager.reuse_fp16_chunk else c16.grad_chunk
assert grad_chunk.l2_norm is not None
if grad_chunk.is_gathered:
@ -275,7 +275,7 @@ class GeminiOptimizer(OptimizerWrapper):
self._logger.info(f"Found overflow. Skip step")
self._clear_global_norm() # clear recorded norm
self.zero_grad() # reset all gradients
if self.module.reuse_fp16_chunk:
if self.module.chunk_manager.reuse_fp16_chunk:
self._update_fp16_params()
return
@ -288,7 +288,7 @@ class GeminiOptimizer(OptimizerWrapper):
self.zero_grad()
if self.module.master_weights:
self._update_fp16_params()
self.module.accumulating_grads = False
self.module.chunk_manager.accumulating_grads = False
return ret
def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0):

View File

@ -6,6 +6,7 @@ class BaseStore:
def __init__(self, torch_pg: ProcessGroup):
self._world_size = dist.get_world_size(group=torch_pg)
self._local_rank = dist.get_rank(group=torch_pg)
self.torch_pg = torch_pg
@property
def world_size(self):

View File

@ -1,16 +1,43 @@
from typing import Dict
from typing import Dict, Optional
import torch
import torch.distributed as dist
from torch import Tensor
from torch._utils import _flatten_dense_tensors
from torch.distributed import ProcessGroup
from colossalai.accelerator import get_accelerator
from .base_store import BaseStore
class BucketStore(BaseStore):
def __init__(self, torch_pg: ProcessGroup):
def __init__(
self,
torch_pg: ProcessGroup,
reduce_bucket_size: int,
overlap_communication: bool,
communication_dtype: Optional[torch.dtype] = None,
moe_extra_dp_process_group: ProcessGroup = None,
):
super().__init__(torch_pg)
self.reduce_bucket_size = reduce_bucket_size
# communication params
self._overlap_communication = overlap_communication
self._communication_dtype = communication_dtype
if self._overlap_communication:
self.comm_stream = get_accelerator().Stream()
self.zero_local_rank = dist.get_rank(group=self.torch_pg)
self.zero_world_size = dist.get_world_size(group=self.torch_pg)
# extra dp
# This group is used to sync moe param, dp_world_size = moe_duplicates * extra_dp_size.
# Non moe param will be sync by global dp pg, moe param will be sync by extra dp pg.
# Moe param grad is be split as non moe param by global dp pg, and grad will be merged in step.
# And moe working and master param are split by extra dp pg.
self.moe_extra_dp_pg = moe_extra_dp_process_group
if self.moe_extra_dp_pg is not None:
self.moe_extra_dp_pg_size = dist.get_world_size(group=self.moe_extra_dp_pg)
self.moe_extra_dp_pg_rank = dist.get_rank(group=self.moe_extra_dp_pg)
self.reset_all()
def reset_all(self) -> None:

View File

@ -6,7 +6,7 @@ from .base_store import BaseStore
class GradientStore(BaseStore):
def __init__(self, *args, partition_grad: bool = False):
def __init__(self, *args, partition_grad: bool = False, require_grad_sync: bool = True):
super().__init__(*args)
"""
self._grads_of_params mapping the parameter and its gradient slices
@ -18,9 +18,12 @@ class GradientStore(BaseStore):
}
"""
self._grads_of_params = dict()
# for zero2, it's `param_id: [grad_local_rank]`
# stage 2
self._partition_grads = partition_grad
# grad accumulation
self.require_grad_sync = require_grad_sync
self._working_index = 0 if partition_grad else self._local_rank
# for zero2, it's `param_id: [grad_local_rank]`
self.grad_to_param_mapping = dict()
def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List:

View File

@ -1,3 +1,5 @@
from typing import Dict
from torch import Tensor
from torch.distributed import ProcessGroup
@ -47,3 +49,12 @@ class ParameterStore(BaseStore):
self.master_to_working_param[id(master_param)] = working_param
self.working_to_master_param[id(working_param)] = master_param
def get_padding_map(self) -> Dict[int, Tensor]:
"""Return the padding map
Returns:
Dict[int, Tensor]: The padding map
"""
return self._padding_map

View File

@ -90,38 +90,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
self._logger = get_dist_logger()
self._verbose = verbose
# stage 2
self._partition_grads = partition_grad
self._cpu_offload = cpu_offload
# grad accumulation
self.require_grad_sync = True
# if process_group is none, will use the default one
self.dp_pg = dp_process_group
self._local_rank = dist.get_rank(group=self.dp_pg)
self._world_size = dist.get_world_size(group=self.dp_pg)
# extra dp
# This group is used to sync moe param, dp_world_size = moe_duplicates * extra_dp_size.
# Non moe param will be sync by global dp pg, moe param will be sync by extra dp pg.
# Moe param grad is be split as non moe param by global dp pg, and grad will be merged in step.
# And moe working and master param are split by extra dp pg.
self.moe_extra_dp_pg = moe_extra_dp_process_group
if self.moe_extra_dp_pg is not None:
self.moe_extra_dp_pg_size = dist.get_world_size(group=self.moe_extra_dp_pg)
self.moe_extra_dp_pg_rank = dist.get_rank(group=self.moe_extra_dp_pg)
# working and master params for mixed precision training
self._working_param_groups = dict()
self._master_param_groups_of_current_rank = dict()
# communication params
self._overlap_communication = overlap_communication
self._reduce_bucket_size = reduce_bucket_size
self._communication_dtype = communication_dtype
# gradient clipping
self._clip_grad_norm = clip_grad_norm
@ -140,9 +114,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# ParameterStore will manage the tensor buffers used for zero
# it will not manage the tensors used by mixed precision training
self._param_store = ParameterStore(self.dp_pg)
self._grad_store = GradientStore(self.dp_pg, partition_grad=partition_grad)
self._bucket_store = BucketStore(self.dp_pg)
self._param_store = ParameterStore(dp_process_group)
self._grad_store = GradientStore(dp_process_group, partition_grad=partition_grad, require_grad_sync=True)
self._bucket_store = BucketStore(
dp_process_group, reduce_bucket_size, overlap_communication, communication_dtype, moe_extra_dp_process_group
)
# moe param should not be stored in working_groups
# because they have different parallel strategy
@ -157,7 +133,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
group_params = list()
for param in param_group["params"]:
if param.requires_grad:
if self.moe_extra_dp_pg is None:
if self._bucket_store.moe_extra_dp_pg is None:
# skip moe param
if is_moe_tensor(param):
self.working_moe_params.append(param)
@ -194,15 +170,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
param_group["params"] = self.master_moe_params
self.optim.param_groups.append(param_group)
# initialize communication stream for
# communication-computation overlapping
if self._overlap_communication:
self._comm_stream = get_accelerator().Stream()
# reduction hook is only used if overlapping communication
# or stage 2 is used
# if it is stage 1 without overlapping, no hook will be attached
if self._overlap_communication or self._partition_grads:
if self._bucket_store._overlap_communication or self._grad_store._partition_grads:
self._attach_reduction_hook()
# initialize mixed precision mixin
@ -222,6 +193,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
elif self._dtype is torch.bfloat16:
self.mixed_precision_mixin = BF16MixedPrecisionMixin()
def __del__(self):
self.remove_hooks()
@property
def dtype(self):
return self._dtype
@ -246,7 +220,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
device = "cpu" if self._cpu_offload else get_accelerator().get_current_device()
for param in param_list:
padding_size = (self._world_size - param.numel() % self._world_size) % self._world_size
padding_size = (
self._bucket_store.zero_world_size - param.numel() % self._bucket_store.zero_world_size
) % self._bucket_store.zero_world_size
self._param_store.record_param_padding_size(param, padding_size)
with torch.no_grad():
@ -258,12 +234,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
else:
padding_param = param.data.view(-1)
if self.moe_extra_dp_pg is not None and is_moe_tensor(param):
splited_params = padding_param.split(padding_param.numel() // self.moe_extra_dp_pg_size)
splited_params = splited_params[self.moe_extra_dp_pg_rank]
if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(param):
splited_params = padding_param.split(
padding_param.numel() // self._bucket_store.moe_extra_dp_pg_size
)
splited_params = splited_params[self._bucket_store.moe_extra_dp_pg_rank]
else:
splited_params = padding_param.split(padding_param.numel() // self._world_size)
splited_params = splited_params[self._local_rank]
splited_params = padding_param.split(padding_param.numel() // self._bucket_store.zero_world_size)
splited_params = splited_params[self._bucket_store.zero_local_rank]
# use fp32 when master_weights is True
if self._master_weights is True:
@ -271,6 +249,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
else:
splited_param_current_rank = splited_params
# Send the splited view to the optimizer to match ZeRO 2 grad shape
params_current_rank.append(splited_param_current_rank)
self._param_store.link_master_and_working_param(splited_param_current_rank, param)
@ -280,10 +259,17 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# Backward Reduction Hook #
###########################
def _grad_handler(self, group_id, param):
@staticmethod
def grad_handler(
param: nn.Parameter,
group_id: int,
bucket_store: BucketStore,
param_store: ParameterStore,
grad_store: GradientStore,
):
# if run with no_sync context, would not sync grad when backward
if self.require_grad_sync:
self._add_to_bucket(param, group_id)
if grad_store.require_grad_sync:
LowLevelZeroOptimizer.add_to_bucket(param, group_id, bucket_store, param_store, grad_store)
def _attach_reduction_hook(self):
# we iterate over the working params
@ -292,29 +278,36 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
param_group = self._working_param_groups[group_id]
for param in param_group:
if param.requires_grad:
param.register_post_accumulate_grad_hook(partial(self._grad_handler, group_id))
param._grad_handle = param.register_post_accumulate_grad_hook(
partial(
LowLevelZeroOptimizer.grad_handler,
group_id=group_id,
bucket_store=self._bucket_store,
param_store=self._param_store,
grad_store=self._grad_store,
)
)
#######################
# Reduction Functions #
#######################
def _run_reduction(self):
if self._bucket_store.num_elements_in_bucket() > 0:
self._bucket_store.build_grad_in_bucket()
if self.moe_extra_dp_pg is None:
flat_grads = self._bucket_store.get_flatten_grad()
flat_grads /= self._world_size
@staticmethod
def run_reduction(bucket_store: BucketStore, grad_store: GradientStore):
if bucket_store.num_elements_in_bucket() > 0:
bucket_store.build_grad_in_bucket()
if bucket_store.moe_extra_dp_pg is None:
flat_grads = bucket_store.get_flatten_grad()
flat_grads /= bucket_store.zero_world_size
else:
# record moe and non moe param
moe_list = []
for param in self._bucket_store._param_list:
for param in bucket_store._param_list:
moe_list.append(is_moe_tensor(param))
# divide them into different groups
moe_grad_list = []
non_moe_grad_list = []
for grad_list in self._bucket_store._grad_in_bucket.values():
for grad_list in bucket_store._grad_in_bucket.values():
non_moe_cur_grad = []
moe_cur_grad = []
for i in range(len(grad_list)):
@ -332,7 +325,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
for grad_list in non_moe_grad_list:
non_moe_flat_grads.append(_flatten_dense_tensors(grad_list))
non_moe_flat_grads = _flatten_dense_tensors(non_moe_flat_grads)
non_moe_flat_grads /= self._world_size
non_moe_flat_grads /= bucket_store.zero_world_size
if len(moe_grad_list) > 0:
moe_flat_grads = []
@ -341,12 +334,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
moe_flat_grads = _flatten_dense_tensors(moe_flat_grads)
# ready to add other tensors to bucket
self._bucket_store.reset_num_elements_in_bucket()
bucket_store.reset_num_elements_in_bucket()
if self._overlap_communication:
stream = self._comm_stream
if bucket_store._overlap_communication:
stream = bucket_store.comm_stream
# in case of the memory being reused in the default stream
if self.moe_extra_dp_pg is None:
if bucket_store.moe_extra_dp_pg is None:
flat_grads.record_stream(stream)
else:
if len(non_moe_grad_list) > 0:
@ -359,53 +352,63 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
stream = get_accelerator().current_stream()
with get_accelerator().stream(stream):
group_id = self._bucket_store.current_group_id
group_id = bucket_store.current_group_id
if self.moe_extra_dp_pg is None:
if bucket_store.moe_extra_dp_pg is None:
grad_dtype = flat_grads.dtype
if self._communication_dtype is not None:
flat_grads = flat_grads.to(self._communication_dtype)
if bucket_store._communication_dtype is not None:
flat_grads = flat_grads.to(bucket_store._communication_dtype)
if not self._partition_grads:
if self.moe_extra_dp_pg is None:
dist.all_reduce(flat_grads, group=self.dp_pg)
if not grad_store._partition_grads:
if bucket_store.moe_extra_dp_pg is None:
dist.all_reduce(flat_grads, group=bucket_store.torch_pg)
if flat_grads.dtype != grad_dtype:
flat_grads = flat_grads.to(grad_dtype)
flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size)
grad_in_bucket = self._bucket_store.get_grad()
self._update_unpartitoned_grad(grad_in_bucket.values(), flat_grads_per_rank, group_id)
flat_grads_per_rank = flat_grads.split(flat_grads.numel() // bucket_store.zero_world_size)
grad_in_bucket = bucket_store.get_grad()
LowLevelZeroOptimizer.update_unpartitoned_grad(
bucket_store, grad_store, grad_in_bucket.values(), flat_grads_per_rank, group_id
)
# sync extra zero group
else:
# sync non moe param in global dp group
if len(non_moe_grad_list) > 0:
dist.all_reduce(non_moe_flat_grads, group=self.dp_pg)
dist.all_reduce(non_moe_flat_grads, group=bucket_store.torch_pg)
flat_grads_per_rank = non_moe_flat_grads.split(
non_moe_flat_grads.numel() // self._world_size
non_moe_flat_grads.numel() // bucket_store.zero_world_size
)
LowLevelZeroOptimizer.update_unpartitoned_grad(
bucket_store, grad_store, non_moe_grad_list, flat_grads_per_rank, group_id
)
self._update_unpartitoned_grad(non_moe_grad_list, flat_grads_per_rank, group_id)
# sync moe param only in zero group
if len(moe_grad_list) > 0:
dist.all_reduce(moe_flat_grads, group=self.moe_extra_dp_pg)
flat_grads_per_rank = moe_flat_grads.split(moe_flat_grads.numel() // self._world_size)
self._update_unpartitoned_grad(moe_grad_list, flat_grads_per_rank, group_id)
dist.all_reduce(moe_flat_grads, group=bucket_store.moe_extra_dp_pg)
flat_grads_per_rank = moe_flat_grads.split(
moe_flat_grads.numel() // bucket_store.zero_world_size
)
LowLevelZeroOptimizer.update_unpartitoned_grad(
bucket_store, grad_store, moe_grad_list, flat_grads_per_rank, group_id
)
else:
if self.moe_extra_dp_pg is None:
flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
recieved_grad = torch.zeros_like(flat_grads_list[0])
dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
if bucket_store.moe_extra_dp_pg is None:
flat_grads_list = list(flat_grads.split(len(flat_grads) // bucket_store.zero_world_size))
received_grad = torch.zeros_like(flat_grads_list[0])
dist.reduce_scatter(received_grad, flat_grads_list, group=bucket_store.torch_pg)
if recieved_grad.dtype != grad_dtype:
recieved_grad = recieved_grad.to(grad_dtype)
if received_grad.dtype != grad_dtype:
received_grad = received_grad.to(grad_dtype)
grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank]
self._update_partitoned_grad(grad_in_bucket_current_rank, recieved_grad, group_id, 1)
grad_in_bucket_current_rank = bucket_store.get_grad()[bucket_store.zero_local_rank]
LowLevelZeroOptimizer.update_partitoned_grad(
bucket_store, grad_store, grad_in_bucket_current_rank, received_grad, group_id, 1
)
else:
# categorize moe and non moe param
grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank]
grad_in_bucket_current_rank = bucket_store.get_grad()[bucket_store.zero_local_rank]
moe_grad_in_bucket_current_rank = []
non_moe_grad_in_bucket_current_rank = []
for idx, grad in enumerate(grad_in_bucket_current_rank):
@ -416,48 +419,61 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if len(non_moe_grad_list) > 0:
flat_grads_list = list(
non_moe_flat_grads.split(len(non_moe_flat_grads) // self._world_size)
non_moe_flat_grads.split(len(non_moe_flat_grads) // bucket_store.zero_world_size)
)
recieved_grad = torch.zeros_like(flat_grads_list[0])
dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
self._update_partitoned_grad(
received_grad = torch.zeros_like(flat_grads_list[0])
dist.reduce_scatter(received_grad, flat_grads_list, group=bucket_store.torch_pg)
LowLevelZeroOptimizer.update_partitoned_grad(
bucket_store,
grad_store,
non_moe_grad_in_bucket_current_rank,
recieved_grad,
received_grad,
group_id,
1,
)
if len(moe_grad_list) > 0:
flat_grads_list = list(
moe_flat_grads.split(len(moe_flat_grads) // self.moe_extra_dp_pg_size)
moe_flat_grads.split(len(moe_flat_grads) // bucket_store.moe_extra_dp_pg_size)
)
recieved_grad = torch.zeros_like(flat_grads_list[0])
received_grad = torch.zeros_like(flat_grads_list[0])
dist.reduce_scatter(
recieved_grad,
received_grad,
flat_grads_list,
group=self.moe_extra_dp_pg,
group=bucket_store.moe_extra_dp_pg,
)
param_slice = self._world_size // self.moe_extra_dp_pg_size
recieved_grad = list(recieved_grad.split(len(recieved_grad) // param_slice))
for split_recieved_grad in recieved_grad:
param_slice = bucket_store.zero_world_size // bucket_store.moe_extra_dp_pg_size
received_grad = list(received_grad.split(len(received_grad) // param_slice))
for split_recieved_grad in received_grad:
split_recieved_grad = _unflatten_dense_tensors(
split_recieved_grad, moe_grad_in_bucket_current_rank
)
for real_grad, grad in zip(split_recieved_grad, moe_grad_in_bucket_current_rank):
param_id = self._bucket_store.get_param_id_of_grad(grad)
self._add_grad(real_grad, param_slice, group_id, param_id)
param_id = bucket_store.get_param_id_of_grad(grad)
LowLevelZeroOptimizer.add_grad(
grad_store, real_grad, param_slice, group_id, param_id
)
self._bucket_store.reset()
bucket_store.reset()
def _update_unpartitoned_grad(self, origin_grad_list: List, flat_grad_list: List, group_id: int) -> None:
@staticmethod
def update_unpartitoned_grad(
bucket_store: BucketStore,
grad_store: GradientStore,
origin_grad_list: List,
flat_grad_list: List,
group_id: int,
) -> None:
for rank, grad_list in enumerate(origin_grad_list):
sync_tensor(flat_grad_list[rank], grad_list)
for grad in grad_list:
param_id = self._bucket_store.get_param_id_of_grad(grad)
self._add_grad(grad, self._world_size, group_id, param_id, rank)
param_id = bucket_store.get_param_id_of_grad(grad)
LowLevelZeroOptimizer.add_grad(grad_store, grad, bucket_store.zero_world_size, group_id, param_id, rank)
def _update_partitoned_grad(
self,
@staticmethod
def update_partitoned_grad(
bucket_store: BucketStore,
grad_store: GradientStore,
origin_grad_list: List,
flat_grad: torch.Tensor,
group_id: int,
@ -465,23 +481,31 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
) -> None:
sync_tensor(flat_grad, origin_grad_list)
for grad in origin_grad_list:
param_id = self._bucket_store.get_param_id_of_grad(grad)
self._add_grad(grad, partition_num, group_id, param_id)
param_id = bucket_store.get_param_id_of_grad(grad)
LowLevelZeroOptimizer.add_grad(grad_store, grad, partition_num, group_id, param_id)
def _add_grad(
self,
@staticmethod
def add_grad(
grad_store: GradientStore,
grad: torch.Tensor,
partition_num: int,
group_id: int,
param_id: int,
rank: int = 0,
) -> None:
if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < partition_num:
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
if len(grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < partition_num:
grad_store.append_gradients_by_param_id(grad, group_id, param_id)
else:
self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id)
grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id)
def _add_to_bucket(self, param, group_id):
@staticmethod
def add_to_bucket(
param: nn.Parameter,
group_id: int,
bucket_store: BucketStore,
param_store: ParameterStore,
grad_store: GradientStore,
):
param_size = param.numel()
# check if the bucket is full
@ -489,13 +513,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# or got a grad of param from another group
# after reduction, the bucket will be empty
if (
self._bucket_store.num_elements_in_bucket() + param_size > self._reduce_bucket_size
or group_id != self._bucket_store.current_group_id
bucket_store.num_elements_in_bucket() + param_size > bucket_store.reduce_bucket_size
or group_id != bucket_store.current_group_id
):
self._run_reduction()
LowLevelZeroOptimizer.run_reduction(bucket_store, grad_store)
padding_size = self._param_store.get_param_padding_size(param)
self._bucket_store.add_param_grad(group_id, param, padding_size)
padding_size = param_store.get_param_padding_size(param)
bucket_store.add_param_grad(group_id, param, padding_size)
################################
# torch.optim.Optimizer methods
@ -503,7 +527,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
def backward(self, loss, retain_graph=False):
assert not (
self._partition_grads and not self.require_grad_sync
self._grad_store._partition_grads and not self._grad_store.require_grad_sync
), "ZeRO2(partition_grads) and no_sync are not compatible"
if self.mixed_precision_mixin is not None:
@ -511,31 +535,31 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
loss.backward(retain_graph=retain_graph)
if not self.require_grad_sync:
if not self._grad_store.require_grad_sync:
return
self._reduce_grad(self._partition_grads)
self._reduce_grad(self._grad_store._partition_grads)
# clear reduced grads
if self._overlap_communication:
if self._bucket_store._overlap_communication:
get_accelerator().synchronize()
self.zero_grad()
def backward_by_grad(self, tensor, grad):
assert not (
self._partition_grads and not self.require_grad_sync
self._grad_store._partition_grads and not self._grad_store.require_grad_sync
), "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"
if self.mixed_precision_mixin is not None:
grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad)
torch.autograd.backward(tensor, grad)
if not self.require_grad_sync:
if not self._grad_store.require_grad_sync:
return
self._reduce_grad(self._partition_grads)
self._reduce_grad(self._grad_store._partition_grads)
# clear reduced grads
if self._overlap_communication:
if self._bucket_store._overlap_communication:
get_accelerator().synchronize()
self.zero_grad()
@ -566,7 +590,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
def step(self, closure=None):
assert closure is None, "closure is not supported by step()"
if not self.require_grad_sync:
if not self._grad_store.require_grad_sync:
return
if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step():
@ -585,7 +609,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# and should not be updated
real_working_params = dict()
real_master_params = dict()
grad_index = 0 if self._partition_grads else self._local_rank
grad_index = 0 if self._grad_store._partition_grads else self._bucket_store.zero_local_rank
for group_id in range(self.num_param_groups):
master_params = self._master_param_groups_of_current_rank[group_id]
real_working_params[group_id] = []
@ -598,14 +622,16 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
grads = self._grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param))
if len(grads) > 0:
# moe hybrid zero
if self.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
real_working_params[group_id].append(working_param)
if self._partition_grads:
if self._grad_store._partition_grads:
grad = grads
else:
param_slice = self._world_size // self.moe_extra_dp_pg_size
param_slice = self._bucket_store.zero_world_size // self._bucket_store.moe_extra_dp_pg_size
grad = grads[
self.moe_extra_dp_pg_rank * param_slice : (self.moe_extra_dp_pg_rank + 1) * param_slice
self._bucket_store.moe_extra_dp_pg_rank
* param_slice : (self._bucket_store.moe_extra_dp_pg_rank + 1)
* param_slice
]
grad = flatten(grad)
else:
@ -674,25 +700,25 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
master_working_param = self.optim.param_groups[group_id]["params"]
for idx, splited_param in enumerate(master_working_param):
working_param = real_working_params[group_id][idx]
if self.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
all_splited_param = [
torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
for _ in range(self.moe_extra_dp_pg_size)
for _ in range(self._bucket_store.moe_extra_dp_pg_size)
]
dist.all_gather(
all_splited_param,
splited_param.to(device).to(self._dtype),
group=self.moe_extra_dp_pg,
group=self._bucket_store.moe_extra_dp_pg,
)
else:
all_splited_param = [
torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
for _ in range(self._world_size)
for _ in range(self._bucket_store.zero_world_size)
]
dist.all_gather(
all_splited_param,
splited_param.to(device).to(self._dtype),
group=self.dp_pg,
group=self._bucket_store.torch_pg,
)
working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
@ -720,7 +746,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
device=get_accelerator().get_current_device(),
dtype=torch.float,
)
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg)
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self._bucket_store.torch_pg)
total_norm = total_norm_cuda.item()
else:
@ -738,7 +764,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
torch.distributed.all_reduce(
total_norm_exponentiated_cuda,
op=torch.distributed.ReduceOp.SUM,
group=self.dp_pg,
group=self._bucket_store.torch_pg,
)
total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
@ -773,27 +799,33 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
param_group = self._working_param_groups[group_id]
for param in param_group:
if param.requires_grad and param.grad is not None:
self._add_to_bucket(param, group_id)
LowLevelZeroOptimizer.add_to_bucket(
param,
group_id,
self._bucket_store,
self._param_store,
self._grad_store,
)
self._run_reduction()
LowLevelZeroOptimizer.run_reduction(self._bucket_store, self._grad_store)
def _reduce_grad(self, partition_grad):
# if not overlapping communication (no reduction hook is attached) when zero1
# we need to manually reduce these gradients
if not partition_grad and not self._overlap_communication:
if not partition_grad and not self._bucket_store._overlap_communication:
self._sync_grad()
else:
self._run_reduction()
LowLevelZeroOptimizer.run_reduction(self._bucket_store, self._grad_store)
# this context comes from pytorch DDP
@contextmanager
def no_sync(self):
old_require_grad_sync = self.require_grad_sync
self.require_grad_sync = False
old_require_grad_sync = self._grad_store.require_grad_sync
self._grad_store.require_grad_sync = False
try:
yield
finally:
self.require_grad_sync = old_require_grad_sync
self._grad_store.require_grad_sync = old_require_grad_sync
##############
# State Dict #
@ -833,16 +865,18 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
for k, v in state.items():
if isinstance(v, torch.Tensor) and k != "step":
working_param = self._param_store.master_to_working_param[id(param)]
if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(v):
gather_tensor = [
torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size)
torch.zeros(v.shape, device=device, dtype=v.dtype)
for _ in range(self._bucket_store.moe_extra_dp_pg_size)
]
dist.all_gather(gather_tensor, v.to(device), group=self.moe_extra_dp_pg)
dist.all_gather(gather_tensor, v.to(device), group=self._bucket_store.moe_extra_dp_pg)
else:
gather_tensor = [
torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self._world_size)
torch.zeros(v.shape, device=device, dtype=v.dtype)
for _ in range(self._bucket_store.zero_world_size)
]
dist.all_gather(gather_tensor, v.to(device), group=self.dp_pg)
dist.all_gather(gather_tensor, v.to(device), group=self._bucket_store.torch_pg)
param_state = (
torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
)
@ -862,17 +896,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
for param_idx, state in zero_state_dict["state"].items():
for k, v in state.items():
if isinstance(v, torch.Tensor) and k != "step":
padding_size = (self._world_size - v.numel() % self._world_size) % self._world_size
padding_size = (
self._bucket_store.zero_world_size - v.numel() % self._bucket_store.zero_world_size
) % self._bucket_store.zero_world_size
with torch.no_grad():
v = v.flatten()
if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size])
if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
v_list = v.split(v.numel() // self.moe_extra_dp_pg_size)
zero_state_dict["state"][param_idx][k] = v_list[self.moe_extra_dp_pg_rank].detach().clone()
if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(v):
v_list = v.split(v.numel() // self._bucket_store.moe_extra_dp_pg_size)
zero_state_dict["state"][param_idx][k] = (
v_list[self._bucket_store.moe_extra_dp_pg_rank].detach().clone()
)
else:
v_list = v.split(v.numel() // self._world_size)
zero_state_dict["state"][param_idx][k] = v_list[self._local_rank].detach().clone()
v_list = v.split(v.numel() // self._bucket_store.zero_world_size)
zero_state_dict["state"][param_idx][k] = (
v_list[self._bucket_store.zero_local_rank].detach().clone()
)
self.optim.load_state_dict(zero_state_dict)
@ -904,16 +944,18 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
for k, v in states.items():
if isinstance(v, torch.Tensor) and k != "step":
if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(v):
state_tensor = [
torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size)
torch.zeros(v.shape, device=device, dtype=v.dtype)
for _ in range(self._bucket_store.moe_extra_dp_pg_size)
]
dist.all_gather(state_tensor, v.to(device), group=self.moe_extra_dp_pg)
dist.all_gather(state_tensor, v.to(device), group=self._bucket_store.moe_extra_dp_pg)
else:
state_tensor = [
torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self._world_size)
torch.zeros(v.shape, device=device, dtype=v.dtype)
for _ in range(self._bucket_store.zero_world_size)
]
dist.all_gather(state_tensor, v.to(device), group=self.dp_pg)
dist.all_gather(state_tensor, v.to(device), group=self._bucket_store.torch_pg)
state_tensor = (
torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
)
@ -944,14 +986,30 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
working_param = p.data.view(-1)
if padding_size > 0:
working_param = torch.nn.functional.pad(working_param, [0, padding_size])
if self.moe_extra_dp_pg is not None and is_moe_tensor(p):
if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(p):
master_param.copy_(working_param.chunk(self.extra_dp_pg_size)[self.extra_dp_pg_rank])
else:
master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
master_param.copy_(
working_param.chunk(self._bucket_store.zero_world_size)[self._bucket_store.zero_local_rank]
)
if hasattr(self, "master_moe_params"):
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
master_moe_param.copy_(working_moe_param)
def remove_hooks(self) -> None:
"""remove the registered hooks
Args:
plugin (LowLevelZeroPlugin): the plugin to bound this method.
"""
for group_id in range(self.num_param_groups):
param_group = self._working_param_groups[group_id]
for param in param_group:
if param.requires_grad:
assert hasattr(param, "_grad_handle")
param._grad_handle.remove()
delattr(param, "_grad_handle")
def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
return self._param_store.working_to_master_param
@ -962,3 +1020,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
**self.moe_master_to_working_map,
}
return self._param_store.master_to_working_param
def get_param_padding_map(self) -> Dict[int, torch.Tensor]:
return self._param_store.get_padding_map()

View File

@ -413,7 +413,7 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的
环境要求:
- PyTorch >= 1.11 并且 PyTorch <= 2.1
- PyTorch >= 2.1
- Python >= 3.7
- CUDA >= 11.0
- [NVIDIA GPU Compute Capability](https://developer.nvidia.com/cuda-gpus) >= 7.0 (V100/RTX20 and higher)

View File

@ -0,0 +1,141 @@
# Distributed Optimizers
Author: [Wenxuan Tan](https://github.com/Edenzzzz), [Junwen Duan](https://github.com/duanjunwen), [Renjie Mao](https://github.com/chongqichuizi875)
**Related Paper**
- [Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235)
- [CAME: Confidence-guided Adaptive Memory Efficient Optimization] (https://arxiv.org/abs/2307.02047)
- [GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection] (https://arxiv.org/abs/2403.03507)
- [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes] (https://arxiv.org/pdf/1904.00962)
## Introduction
Apart from the widely adopted Adam and SGD, many modern optimizers require layer-wise statistics to efficiently update parameters, and are thus not directly applicable to parallel settings where model layers are sharded across multiple devices. We provide optimized distributed implementations with minimal extra communications, and seamless integrations with Tensor Parallel, DDP and ZeRO using plugins.
## Optimizers
Adafactor is a first-order Adam variant using Non-negative Matrix Factorization(NMF) to reduce memory footprint. CAME improves by introducting a confidence matrix to correct NMF. GaLore further reduces memory by projecting gradients into a low-rank space and 8-bit block-wise quantization. Lamb allows huge batch sizes without lossing accuracy via layer-wise adaptive update bounded by the inverse of its Lipschiz constant.
## API Reference
{{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }}
{{ autodoc:colossalai.nn.optimizer.distributed_lamb.DistributedLamb }}
{{ autodoc:colossalai.nn.optimizer.distributed_galore.DistGaloreAwamW }}
{{ autodoc:colossalai.nn.optimizer.distributed_came.DistributedCAME }}
## Hands-On Practice
We now demonstrate how to use Distributed Adafactor with booster API combining Tensor Parallel and ZeRO 2 with 4 GPUs.
### step 1. Import libraries
```python
from transformers import LlamaModel, LlamaConfig
from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
import colossalai
import torch
```
### step 2. Initialize Distributed Environment and Parallism Group
We need to initialize distributed environment. For demo purpose, we use `colossal run --nproc_per_node 4`. You can refer to [Launch Colossal-AI](../basics/launch_colossalai.md)
```python
colossalai.launch_from_torch()
```
### step 3. Initialize Module and Optimizer
Build our model. We created an MLP using two Linear Layer.
```python
# Init Llama from huggingface
configuration = LlamaConfig()
model = LlamaModel(configuration).cuda()
criterion = lambda x: x.mean()
dist_optim = DistributedAdaFactor(model.parameters())
```
### step 4.Init Booster
```python
plugin = HybridParallelPlugin(tp_size=2, zero_stage=2, pp_size=1, enable_all_optimization=True)
booster = Booster(plugin=plugin)
# You should also pass in your own dataset.
model, dist_optim, criterion, dataloader, _ = booster.boost(model, dist_optim, criterion)
```
### step 5.Train Your Model
```python
steps = 10
for step in range(steps):
input_ids = torch.ones(1, 100, device="cuda", dtype=torch.int)
attention_mask = input_ids.clone()
outputs = model(input_ids.cuda(), attention_mask.cuda())
loss = criterion(outputs.last_hidden_state)
booster.backward(loss, dist_optim)
dist_optim.step()
dist_optim.zero_grad()
```
### GaLore special handling
For GaLore, we need to specify projection rank for each parameter group and quantization & paged optimizer params. Please refer to bitandbytes for quantization details. Support for ZeRO is underway.
```python
from colossalai.nn.optimizer.galore import get_galore_param_groups
from colossalai.nn.optimizer import DistGaloreAwamW
optim = DistGaloreAwamW(
get_galore_param_groups(model, decay=1e-2, rank=8),
lr=lr,
betas=(beta1, beta2),
eps=eps,
nbits=8,
percentile_clipping=100,
block_wise=True,
min_8bit_size=4096,
)
```
## Plugin compatibility
<table>
<tr>
<th nowrap="nowrap">Model/Feature</th>
<th nowrap="nowrap" align="center" title="Lamb">Lamb</th>
<th nowrap="nowrap" align="center" title="GaLore">GaLore</th>
<th nowrap="nowrap" align="center" title="Adafactor">Adafactor</th>
<th nowrap="nowrap" align="center" title="CAME">CAME</th>
</tr>
<tr>
<td nowrap="nowrap">Hybrid Parallel<br />Plugin</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
</tr>
<tr>
<td nowrap="nowrap">Low Level Zero<br />Plugin</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
</tr>
<tr>
<td nowrap="nowrap">Torch DDP<br />Plugin</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
</tr>
<tr>
<td nowrap="nowrap">Gemini<br />Plugin</td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td>
</tr>
<tr>
<td nowrap="nowrap">Moe Hybrid<br />Plugin</td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td>
</tr>
<tr>
<td colspan="39"></td>
</tr>
</table>
<!-- doc-test-command: colossalai run --nproc_per_node 4 distributed_optimizers.py -->

View File

@ -1,7 +1,7 @@
# Setup
Requirements:
- PyTorch >= 1.11 and PyTorch <= 2.1
- PyTorch >= 2.1
- Python >= 3.7
- CUDA >= 11.0
- [NVIDIA GPU Compute Capability](https://developer.nvidia.com/cuda-gpus) >= 7.0 (V100/RTX20 and higher)

View File

@ -0,0 +1,141 @@
# 分布式优化器
Author: Wenxuan Tan, Junwen Duan, Renjie Mao
**相关论文**
- [Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235)
- [CAME: Confidence-guided Adaptive Memory Efficient Optimization] (https://arxiv.org/abs/2307.02047)
- [GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection] (https://arxiv.org/abs/2403.03507)
- [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes] (https://arxiv.org/pdf/1904.00962)
## 介绍
除了广泛采用的Adam和SGD外许多现代优化器需要逐层统计信息以有效更新参数因此无法直接应用于模型层在多个设备上分片的并行设置。我们以提供了优化的分布式实现并且通过插件与Tensor Parallel、DDP和ZeRO无缝集成。
## 优化器
Adafactor 是一种首次采用非负矩阵分解NMF的 Adam 变体用于减少内存占用。CAME 通过引入一个置信度矩阵来改进 NMF 的效果。GaLore 通过将梯度投影到低秩空间,并使用 8 位块状量化进一步减少内存占用。Lamb 允许使用巨大的批量大小而不失准确性,通过按其 Lipschitz 常数的倒数界定的逐层自适应更新实现
## API 参考
{{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }}
{{ autodoc:colossalai.nn.optimizer.distributed_lamb.DistributedLamb }}
{{ autodoc:colossalai.nn.optimizer.distributed_galore.DistGaloreAwamW }}
{{ autodoc:colossalai.nn.optimizer.distributed_came.DistributedCAME }}
## 使用
We now demonstrate how to use Distributed Adafactor with booster API combining Tensor Parallel and ZeRO 2 with 4 GPUs.
### step 1. 导包
```python
from transformers import LlamaModel, LlamaConfig
from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
import colossalai
import torch
```
### step 2. 初始化分布式
We need to initialize distributed environment. For demo purpose, we use `colossal run --nproc_per_node 4`. You can refer to [Launch Colossal-AI](../basics/launch_colossalai.md)
```python
colossalai.launch_from_torch()
```
### step 3. 初始化模型和优化器
Build our model. We created an MLP using two Linear Layer.
```python
configuration = LlamaConfig()
model = LlamaModel(configuration).cuda()
criterion = lambda x: x.mean()
dist_optim = DistributedAdaFactor(model.parameters())
```
### step 4.初始化booster和plugin
```python
plugin = HybridParallelPlugin(tp_size=2, zero_stage=2, pp_size=1, enable_all_optimization=True)
booster = Booster(plugin=plugin)
# You should also pass in your own dataset.
model, dist_optim, criterion, dataloader, _ = booster.boost(model, dist_optim, criterion)
```
### step 5.训练
```python
steps = 10
for step in range(steps):
input_ids = torch.ones(1, 100, device="cuda", dtype=torch.int)
attention_mask = input_ids.clone()
outputs = model(input_ids.cuda(), attention_mask.cuda())
loss = criterion(outputs.last_hidden_state)
booster.backward(loss, dist_optim)
dist_optim.step()
dist_optim.zero_grad()
```
### GaLore的特殊初期
对于 GaLore我们需要为每个参数组指定投影rank以及量化和分页优化器参数。有关量化的详细信息请参考 bitandbytes.
```python
from colossalai.nn.optimizer.galore import get_galore_param_groups
from colossalai.nn.optimizer import DistGaloreAwamW
optim = DistGaloreAwamW(
get_galore_param_groups(model, decay=1e-2, rank=8),
lr=lr,
betas=(beta1, beta2),
eps=eps,
nbits=8,
percentile_clipping=100,
block_wise=True,
min_8bit_size=4096,
)
```
## 兼容性
<table>
<tr>
<th nowrap="nowrap">Model/Feature</th>
<th nowrap="nowrap" align="center" title="Lamb">Lamb</th>
<th nowrap="nowrap" align="center" title="GaLore">GaLore</th>
<th nowrap="nowrap" align="center" title="Adafactor">Adafactor</th>
<th nowrap="nowrap" align="center" title="CAME">CAME</th>
</tr>
<tr>
<td nowrap="nowrap">Hybrid Parallel<br />Plugin</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
</tr>
<tr>
<td nowrap="nowrap">Low Level Zero<br />Plugin</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
</tr>
<tr>
<td nowrap="nowrap">Torch DDP<br />Plugin</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
</tr>
<tr>
<td nowrap="nowrap">Gemini<br />Plugin</td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td>
</tr>
<tr>
<td nowrap="nowrap">Moe Hybrid<br />Plugin</td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td>
</tr>
<tr>
<td colspan="39"></td>
</tr>
</table>
<!-- doc-test-command: colossalai run --nproc_per_node 4 distributed_optimizers.py -->

View File

@ -2,7 +2,7 @@
环境要求:
- PyTorch >= 1.11 并且 PyTorch <= 2.1
- PyTorch >= 2.1
- Python >= 3.7
- CUDA >= 11.0
- [NVIDIA GPU Compute Capability](https://developer.nvidia.com/cuda-gpus) >= 7.0 (V100/RTX20 and higher)

View File

@ -243,7 +243,12 @@ def main():
# ====================================
# gpt2 pretrained model
cfg = AutoConfig.from_pretrained(model_name, num_labels=data_builder.num_labels)
cfg = AutoConfig.from_pretrained(
model_name,
num_labels=data_builder.num_labels,
pad_token=data_builder.tokenizer.pad_token,
pad_token_id=data_builder.tokenizer.pad_token_id,
)
if model_name == "gpt2":
model = GPT2ForSequenceClassification.from_pretrained(model_name, config=cfg).cuda()

View File

@ -20,3 +20,4 @@ transformers==4.36.2
peft>=0.7.1
bitsandbytes>=0.39.0
rpyc==6.0.0
galore_torch

View File

@ -1,4 +1,5 @@
from .hanging_param_model import *
from .nested_model import *
from .repeated_computed_layers import *
from .simple_mlp import *
from .simple_net import *

View File

@ -0,0 +1,61 @@
from copy import deepcopy
import torch
import torch.nn as nn
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
from ..registry import model_zoo
_BS = 16
_IN_DIM = 32
_HID_DIM = 128
class Net(nn.Module):
def __init__(self, in_dim=_IN_DIM, hid_dim=_HID_DIM, identity=False, dtype=torch.float32):
super().__init__()
if identity:
self.fc0 = nn.Identity()
else:
self.fc0 = nn.Linear(in_dim, in_dim).to(dtype=dtype)
self.fc1 = nn.Linear(in_dim, hid_dim).to(dtype=dtype)
self.fc2 = nn.Linear(hid_dim, in_dim).to(dtype=dtype)
def forward(self, x):
return self.fc2(self.fc1(self.fc0(x)))
class TPNet(nn.Module):
def __init__(
self,
fc0=nn.Linear(_IN_DIM, _IN_DIM),
fc1=nn.Linear(_IN_DIM, _HID_DIM),
fc2=nn.Linear(_HID_DIM, _IN_DIM),
tp_group=None,
dtype=torch.float32,
):
super().__init__()
self.fc0 = deepcopy(fc0)
self.fc1 = Linear1D_Col.from_native_module(
deepcopy(fc1), process_group=tp_group, gather_output=False, overlap=True, dtype=dtype
)
self.fc2 = Linear1D_Row.from_native_module(
deepcopy(fc2), process_group=tp_group, parallel_input=True, dtype=dtype
)
def forward(self, x):
return self.fc2(self.fc1(self.fc0(x)))
def data_gen():
return torch.randn(_BS, _IN_DIM)
def output_transform(x: torch.Tensor):
return x
model_zoo.register(name="simple_mlp", model_fn=Net, data_gen_fn=data_gen, output_transform_fn=output_transform)
model_zoo.register(name="simple_tp_mlp", model_fn=TPNet, data_gen_fn=data_gen, output_transform_fn=output_transform)

View File

@ -17,3 +17,8 @@ try:
from .mistral import *
except ImportError:
print("This version of transformers doesn't support mistral.")
try:
from .qwen2 import *
except ImportError:
print("This version of transformers doesn't support qwen2.")

View File

@ -0,0 +1,89 @@
import torch
import transformers
from ..registry import ModelAttribute, model_zoo
try:
from transformers import Qwen2Config
HAS_QWEN2 = True
except ImportError:
HAS_QWEN2 = False
if HAS_QWEN2:
# ===============================
# Register Qwen2
# ===============================
def data_gen():
# the input ids are corresponding to the sentence
# 'Hello, my dog is cute'
#
# the code is give below:
# -----------------------------------
# from transformers import Qwen2TokenizerFast
# tokenizer = Qwen2TokenizerFast.from_pretrained("Qwen/Qwen1.5-7B-Chat")
# input = 'Hello, my dog is cute'
# tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
# -----------------------------------
input_ids = torch.Tensor(
[[9707, 11, 847, 5562, 374, 13, 123, 18838], [9707, 11, 847, 5562, 374, 17, 89, 18838]]
).long()
attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]]).long()
return dict(input_ids=input_ids, attention_mask=attention_mask)
# label is needed for casual lm
def data_gen_for_casual_lm():
data = data_gen()
labels = data["input_ids"].clone()
data["labels"] = labels
return data
# transform the output to a dict
output_transform_fn = lambda x: x
# function to get the loss
loss_fn = lambda output: output["last_hidden_state"].mean()
loss_fn_for_casual_lm = lambda output: output["loss"]
loss_fn_for_seq_classification = lambda output: output["logits"].mean()
config = Qwen2Config(
hidden_size=128,
intermediate_size=256,
max_window_layers=4,
num_attention_heads=16,
num_hidden_layers=4,
num_key_value_heads=16,
)
config.pad_token_id = 0
# register the following models
# transformers.Qwen2Model,
# transformers.Qwen2ForCausalLM,
# transformers.Qwen2ForSequenceClassification,
model_zoo.register(
name="transformers_qwen2",
model_fn=lambda: transformers.Qwen2Model(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_qwen2_for_casual_lm",
model_fn=lambda: transformers.Qwen2ForCausalLM(config),
data_gen_fn=data_gen_for_casual_lm,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_casual_lm,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_qwen2_for_sequence_classification",
model_fn=lambda: transformers.Qwen2ForSequenceClassification(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_seq_classification,
model_attribute=ModelAttribute(has_control_flow=True),
)

View File

@ -80,7 +80,6 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
skipped_models.append(name)
continue
err = run_fn(stage, model_fn, data_gen_fn, output_transform_fn)
get_accelerator().empty_cache()
if err is None:

View File

@ -0,0 +1,272 @@
import torch
import torch.distributed as dist
from torch.testing import assert_close
import colossalai
from colossalai.shardformer.layer._operation import _gather
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import parameterize, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_weight,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
def check_optim_states(org_optim, sharded_optim):
for group in org_optim.param_groups:
for p in group["params"]:
sharded_state = sharded_optim.state[p]
state = org_optim.state[p]
for key in sharded_state:
assert_close(state[key], sharded_state[key], rtol=1e-5, atol=1e-5)
def check_bert_fwd_bwd(
model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config, optim_class, sharded_optim_class
):
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
model_fn, loss_fn, test_config, optim_class, sharded_optim_class
)
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
)
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
bert = unwrap_model(org_model, "BertModel", "bert")
sharded_bert = unwrap_model(sharded_model, "BertModel", "bert")
weight_layer_for_check = ["encoder.layer[0].output.dense", "encoder.layer[1].output.dense"]
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
# check weights
if test_config["precision"] == "bf16":
atol, rtol = 5e-4, 1e-4
else:
atol, rtol = 5e-4, 5e-4
if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1)
# check optim states
check_optim_states(org_optimizer, sharded_optimizer.optim)
torch.cuda.empty_cache()
@parameterize(
"test_config",
[
{
"tp_size": 1,
"num_microbatches": 4,
"zero_stage": 2,
"precision": "bf16",
},
{
"tp_size": 2,
"num_microbatches": 4,
"zero_stage": 2,
"precision": "bf16",
},
{
"tp_size": 4,
"num_microbatches": 4,
"zero_stage": 2,
"precision": "bf16",
},
{
"tp_size": 1,
"num_microbatches": 4,
"zero_stage": 2,
"precision": "fp16",
},
{
"tp_size": 2,
"num_microbatches": 4,
"zero_stage": 2,
"precision": "fp16",
},
{
"tp_size": 4,
"num_microbatches": 4,
"zero_stage": 2,
"precision": "fp16",
},
{
"tp_size": 2,
"num_microbatches": 4,
"zero_stage": 1,
"precision": "bf16",
},
{
"tp_size": 2,
"num_microbatches": 4,
"zero_stage": 0,
"precision": "bf16",
},
],
)
def run_bert_test(test_config, optim_class, sharded_optim_class):
"""Only call this if you've initialized distributed backend and spawned processes"""
sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
test_config["use_lazy_init"] = False
test_config["pp_size"] = 1 # Do NOT test Pipeline Parallel
test_config["initial_scale"] = 2**15 # avoid overflow
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_bert_fwd_bwd(
model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config, optim_class, sharded_optim_class
)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
def _run_bert_test(rank, world_size, port, optim_class, sharded_optim_class):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_bert_test(optim_class, sharded_optim_class)
def check_optim_on_bert(optim_class, sharded_optim_class):
spawn(_run_bert_test, 4, optim_class, sharded_optim_class)
def check_dist_optim_state(org_optimizer, sharded_optimizer):
torch.set_default_dtype(torch.bfloat16)
for group, tp_group in zip(org_optimizer.param_groups, sharded_optimizer.param_groups):
for p, tp in zip(group["params"], tp_group["params"]):
p_state = org_optimizer.state[p]
tp_state = sharded_optimizer.state[tp]
# TODO "exp_avg_sq_col", "exp_avg_sq_row", "exp_avg_sq"
for key in ["exp_avg_sq_row"]:
if key in tp_state.keys() and type(tp_state[key]) is torch.Tensor:
tp_is_dtensor = sharded_optimizer.param_is_dtensor_dict[id(tp)]
shard_spec = sharded_optimizer.shard_spec_dict[id(tp)]
use_zero = sharded_optimizer.use_zero
tp_optim_state = tp_state[key]
p_state_shape, tp_state_shape = p_state[key].shape, tp_state[key].shape
dp_size, tp_size = (
sharded_optimizer.dp_size,
sharded_optimizer.tp_size,
)
# we start init model with first tensor parallel then zero;
# So, we gather model with first zero then tensor parallel
if tp_is_dtensor:
# col parallel
if shard_spec.sharding_sequence[0] == "R":
if use_zero:
# sq_row need gather alone dp group
if key == "exp_avg_sq_row":
tp_optim_state = _gather(
input_=tp_optim_state,
dim=-1,
process_group=sharded_optimizer.dp_group,
)
tp_optim_state.shape
# sq_col don't need gather alone dp group
if key == "exp_avg_sq_col":
pass
else:
pass
# gather from tp group
# sq_row don need gather alone tp group
if key == "exp_avg_sq_row":
pass
# sq_col need gather alone dp group
if key == "exp_avg_sq_col":
tp_optim_state = _gather(
input_=tp_optim_state, dim=-1, process_group=sharded_optimizer.tp_group
)
tp_optim_state.shape
# row parallel
if shard_spec.sharding_sequence[-1] == "R":
if use_zero:
# sq_row need gather alone dp group
if key == "exp_avg_sq_row":
if p_state[key].shape[0] // tp_size % dp_size != 0:
pass
else:
tp_optim_state = _gather(
input_=tp_optim_state,
dim=-1,
process_group=sharded_optimizer.dp_group,
)
tp_optim_state.shape
# sq_col don't need gather alone dp group
if key == "exp_avg_sq_col":
pass
else:
pass
# gather from tp group
# sq_row need gather alone tp group
if key == "exp_avg_sq_row":
tp_optim_state = _gather(
input_=tp_optim_state, dim=-1, process_group=sharded_optimizer.tp_group
)
tp_optim_state.shape
# sq_col don't need gather alone dp group
if key == "exp_avg_sq_col":
pass
else:
if use_zero:
# sq_row need gather alone dp group
if key == "exp_avg_sq_row":
# row residule; no gather
if p_state[key].shape[0] % dp_size != 0:
pass
else:
tp_optim_state = _gather(
input_=tp_optim_state,
dim=-1,
process_group=sharded_optimizer.dp_group,
)
tp_optim_state.shape
# sq_col don't need gather alone dp group
if key == "exp_avg_sq_col":
tp_optim_state = tp_optim_state.div_(dp_size)
# need a div;
else:
pass
# Sovled a New issus: different dtype;
# So far, only happen in H100 env;
# Seem torch.set_default_dtype(torch.bfloat16) not act on booster.percision;
# Or assert_close just update to check dtype;
if p_state[key].dtype != tp_optim_state.dtype:
tp_optim_state = tp_optim_state.type(p_state[key].dtype)
try:
assert_close(p_state[key], tp_optim_state, atol=5e-4, rtol=1.6e-2)
except:
pass
def check_dist_param(org_model, sharded_model, weight_layer_for_check, atol, rtol):
for (org_name, org_param), (sharded_name, sharded_param) in zip(
org_model.named_parameters(), sharded_model.named_parameters()
):
if org_name in weight_layer_for_check:
assert_close(org_param, sharded_param, atol=atol, rtol=rtol)
def check_dist_grad(sharded_optimizer, org_model, sharded_model, weight_layer_for_check, atol, rtol):
for (org_name, org_param), (sharded_name, sharded_param) in zip(
org_model.named_parameters(), sharded_model.named_parameters()
):
if org_name in weight_layer_for_check:
org_grad = org_param.grad
group_id = dist.get_rank(sharded_optimizer.optim.dp_group)
dist_grad = sharded_optimizer._grad_store.get_partitioned_gradients_by_param_id(group_id, id(sharded_param))
# dist_grad concat then reshape to org_grad shape
if dist_grad:
dist_grad = torch.cat([t for t in dist_grad], 0).view(org_grad.shape)
assert_close(org_grad, dist_grad, atol=atol, rtol=rtol)

View File

@ -0,0 +1,698 @@
import copy
import pytest
import torch
import torch.distributed as dist
from torch import nn
from torch.testing import assert_close
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.cluster import ProcessGroupMesh
from colossalai.logging import disable_existing_loggers
from colossalai.nn.optimizer.adafactor import Adafactor
from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.layer._operation import _gather
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor import (
distribute_tensor,
get_device_mesh,
get_layout,
get_sharding_spec,
is_distributed_tensor,
shard_colwise,
shard_rowwise,
)
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.tensor.d_tensor.sharding_spec import DimSpec
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed
from colossalai.zero import LowLevelZeroOptimizer
from tests.kit.model_zoo import model_zoo
from tests.test_optimizer._utils import check_dist_optim_state, check_dist_param, check_optim_states
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
build_model_from_low_level_zero_plugin,
check_weight,
run_forward_backward_with_hybrid_plugin,
run_forward_backward_with_low_level_zero_plugin,
unwrap_model,
)
HEIGHT = 4
WIDTH = 4
_TP_SPEC = DimSpec([0])
def correctness_verify(tensor1: torch.Tensor, tensor2: torch.Tensor, dtype: torch.dtype = torch.float32):
rtol = None
atol = None
if dtype is torch.float32:
rtol = 5e-04
atol = 5e-04
elif dtype is torch.float16:
rtol = 5e-2
atol = 5e-4
elif dtype is torch.bfloat16:
rtol = 4e-3
atol = 4e-3
# return torch.all(tensor1.isclose(tensor2, rtol=rtol, atol=atol))
assert_close(tensor1, tensor2, rtol=rtol, atol=atol)
# setup param groups; (For zero test optim)
def setup_param_groups_zero(model: nn.Module) -> list:
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": 0.1,
},
{
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
return optimizer_grouped_parameters
# setup param groups; (For base optim)
def setup_param_groups(model: nn.Module) -> list:
optimizer_grouped_parameters = [p for n, p in model.named_parameters()]
return optimizer_grouped_parameters
# setup flatten param groups, sharding spec and shape; (For dist optim)
def setup_flatten_param_groups_sharding_spec_shape(model: nn.Module) -> dict:
flatten_optimizer_grouped_parameters = []
sharding_spec = {} # {id(flatten param): get_layout(p).global_shape}
param_shape = {} # {id(flatten param): get_sharding_spec(p)}
for n, p in model.named_parameters():
# flatten_p = copy.deepcopy(p).flatten()
flatten_p = nn.Parameter(p.clone().flatten().requires_grad_(True))
flatten_optimizer_grouped_parameters.append(flatten_p)
if is_distributed_tensor(p):
sharding_spec[id(flatten_p)] = get_sharding_spec(p)
param_shape[id(flatten_p)] = get_layout(p).global_shape
else:
sharding_spec[id(flatten_p)] = None
param_shape[id(flatten_p)] = p.shape
return flatten_optimizer_grouped_parameters, sharding_spec, param_shape
def set_dist_grad(
dist_module: nn.Module, torch_model: nn.Module, g_dtype: torch.dtype, group: dist.ProcessGroup
) -> None:
"""
Set split grads for Tensor Parallel or ZeRO DP.
We do not need a separate treatment for ZeRO,
as the wrapper takes care of reduce-scattering grads.
"""
rank = dist.get_rank(group)
world_size = dist.get_world_size(group)
for p, torch_p in zip(dist_module.parameters(), torch_model.parameters()):
if torch_p.grad is None:
torch_p.grad = torch.zeros_like(torch_p)
is_distributed = hasattr(p, "dist_layout")
if is_distributed:
sharding = p.dist_layout.sharding_spec.sharding_sequence
split_dim = sharding.index(_TP_SPEC)
shape = torch_p.split(world_size, dim=split_dim)[rank].shape
indices = torch.arange(shape[split_dim] * rank, shape[split_dim] * (rank + 1))
# Generate grads only for the correctly split chunk
torch_p.grad.index_add_(split_dim, indices, torch.randn(shape, device=torch_p.device, dtype=g_dtype))
else:
shape = torch_p.shape
torch_p.grad += torch.randn(shape, device=torch_p.device, dtype=g_dtype)
# avoid inconsistent grad and param dtype error
orig_p = p.data
p.data = torch_p.grad.clone().to(g_dtype)
p.grad = p.data
p.data = orig_p
def set_master_param_to_shard_param(master_param_list) -> dict:
master_param_to_shard_param = {id(p): p for p in master_param_list}
return master_param_to_shard_param
class MlpModel(nn.Module):
def __init__(self):
super(MlpModel, self).__init__()
self.linear1 = nn.Linear(HEIGHT, WIDTH)
self.linear2 = nn.Linear(WIDTH, HEIGHT)
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x
class TPModel(nn.Module):
def __init__(self, linear1, linear2, tp_group=None):
super().__init__()
self.linear1 = Linear1D_Col.from_native_module(
linear1, process_group=tp_group, gather_output=False, overlap=True
)
self.linear2 = Linear1D_Row.from_native_module(linear2, process_group=tp_group, parallel_input=True)
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x
@parameterize("dtype", [torch.float32, torch.float16, torch.bfloat16]) # torch.float32, torch.float16, torch.bfloat16
@parameterize("tp_zero_size", [(4, 1)])
def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
tp_size, zero_size = tp_zero_size
local_rank = dist.get_rank()
use_zero = True if zero_size > 1 else False
proc_mesh = ProcessGroupMesh(tp_size, zero_size)
tp_group, dp_group = proc_mesh.get_group_along_axis(0), proc_mesh.get_group_along_axis(1)
torch.set_default_dtype(dtype)
set_seed(42)
# ==============================
# Base Case
# ==============================
H, W = HEIGHT, WIDTH
model_col = nn.Linear(H, W).to(local_rank) # Col parallel weight
weight, bias = model_col.weight, model_col.bias
# ==============================
# Col Parallel
# ==============================
weight_col_shard = shard_colwise(weight.clone(), tp_group)
weight_col_shard_layout = get_layout(weight_col_shard) # Layout info weight_col_shard_layout.global_shape
weight_col_shard_shard_spec = get_sharding_spec(weight_col_shard) # Shard spec
weight_col_shard_flatten = nn.Parameter(weight_col_shard.clone().flatten().requires_grad_(True))
bias_col_flatten = nn.Parameter(bias.clone().flatten().requires_grad_(True))
# ==============================
# Row Parallel
# ==============================
weight_row_shard = shard_rowwise(weight.clone(), tp_group)
weight_row_shard_layout = get_layout(weight_row_shard) # Layout info weight_row_shard_layout.global_shape
weight_row_shard_shard_spec = get_sharding_spec(weight_row_shard) # Shard spec
weight_row_shard_flatten = nn.Parameter(
weight_row_shard.clone().flatten().requires_grad_(True)
) # flatten input(not dtensor) to optimizer
bias_row_flatten = nn.Parameter(bias.clone().flatten().requires_grad_(True))
# base_param_group = setup_param_groups([weight, bias])
# cp_param_group = setup_param_groups([weight_col_shard_flatten, bias_col_flatten])
# rp_param_group = setup_param_groups([weight_row_shard_flatten, bias_row_flatten])
# ==============================
# Init Optimizer
# ==============================
# base
optimizer_base = Adafactor([weight, bias])
cp_dist_optim = DistributedAdaFactor([weight_col_shard_flatten, bias_col_flatten])
rp_dist_optim = DistributedAdaFactor([weight_row_shard_flatten, bias_row_flatten])
shard_to_param_cp = set_master_param_to_shard_param([weight_col_shard_flatten, bias_col_flatten])
cp_dist_optim.setup_distributed(
tp_group=tp_group,
dp_group=dp_group,
shard_to_working_param=shard_to_param_cp,
use_zero=use_zero,
)
shard_to_param_rp = set_master_param_to_shard_param([weight_row_shard_flatten, bias_row_flatten])
rp_dist_optim.setup_distributed(
tp_group=tp_group,
dp_group=dp_group,
shard_to_working_param=shard_to_param_rp,
use_zero=use_zero,
)
N_STEPS = 1
for _ in range(N_STEPS):
# base step
optimizer_base.zero_grad()
weight.grad = torch.rand_like(weight)
bias.grad = torch.rand_like(bias)
optimizer_base.step()
# col parallel step
cp_dist_optim.zero_grad()
weight_col_shard_flatten.grad = (
distribute_tensor(weight.grad, get_device_mesh(weight_col_shard), weight_col_shard_shard_spec)
.clone()
.flatten()
)
bias_col_flatten.grad = bias.grad.clone().flatten()
cp_dist_optim.step()
# row parallel step
rp_dist_optim.zero_grad()
weight_row_shard_flatten.grad = (
distribute_tensor(weight.grad, get_device_mesh(weight_row_shard), weight_row_shard_shard_spec)
.clone()
.flatten()
)
bias_row_flatten.grad = bias.grad.clone().flatten()
rp_dist_optim.step()
# gather result
weight_col_gather = _gather(
input_=weight_col_shard_flatten.data.view(-1, H // tp_size),
dim=-1,
process_group=tp_group,
) # gather
weight_row_gather = _gather(input_=weight_row_shard_flatten.data, dim=-1, process_group=tp_group).view(
-1, W
) # gather
# verify
correctness_verify(weight.data, weight_col_gather.data, dtype)
correctness_verify(weight.data, weight_row_gather.data, dtype)
print(f"Base Test Passed")
@parameterize("dtype", [torch.float16]) # torch.float32, torch.float16, torch.bfloat16
@parameterize("tp_zero_size", [(1, 4)]) # (2, 2), (4, 1), (1, 4)
def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
tp_size, zero_size = tp_zero_size
use_zero = True if zero_size > 1 else False
local_rank = dist.get_rank()
clear_layout_converter()
proc_mesh = ProcessGroupMesh(tp_size, zero_size)
tp_group, dp_group = proc_mesh.get_group_along_axis(0), proc_mesh.get_group_along_axis(1)
torch.set_default_dtype(dtype)
set_seed(42)
# ==============================
# Model Init
# ==============================
base_model = MlpModel().to(local_rank)
tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank)
base_param_group = setup_param_groups(base_model)
tp_param_group = setup_param_groups(tp_model)
tp_param_group_, tp_shard_spec, tp_param_shape = setup_flatten_param_groups_sharding_spec_shape(tp_model)
# ==============================
# Optimizer Init
# ==============================
base_optim = Adafactor(base_param_group)
dist_optim = DistributedAdaFactor(tp_param_group)
# Setup distributed optimizer
if zero_size > 1:
base_optim = LowLevelZeroOptimizer(
base_optim,
overlap_communication=True,
initial_scale=128,
partition_grad=True,
dp_process_group=dp_group,
verbose=True,
)
dist_optim = LowLevelZeroOptimizer(
dist_optim,
overlap_communication=True,
initial_scale=128,
partition_grad=True,
dp_process_group=dp_group,
verbose=True,
)
shard_to_param = dist_optim._param_store.master_to_working_param # {id(): param tensor} but flattened
dist_optim.optim.setup_distributed(
tp_group=tp_group,
dp_group=dp_group,
shard_to_working_param=shard_to_param,
use_zero=use_zero,
)
else:
shard_to_param = set_master_param_to_shard_param(tp_param_group)
dist_optim.setup_distributed(
tp_group=tp_group,
dp_group=dp_group,
shard_to_working_param=shard_to_param,
use_zero=use_zero,
)
# ==============================
# Correctness Verify
# ==============================
x = torch.randn(HEIGHT, WIDTH, device=local_rank)
out = base_model(x)
out_tp = tp_model(x)
if zero_size > 1:
dist_optim.backward(out_tp.sum())
base_optim.backward(out.sum())
else:
out_tp.sum().backward()
out.sum().backward()
base_optim.step()
dist_optim.step()
base_optim.zero_grad()
dist_optim.zero_grad()
for p, tp_p in zip(base_param_group, tp_param_group):
param_is_distributed = is_distributed_tensor(tp_p)
if param_is_distributed:
shard_spec = get_sharding_spec(tp_p)
if len(shard_spec.sharding_sequence) >= 2:
# Col Parallel
if shard_spec.sharding_sequence[0] == "R":
tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather
# ROW Parallel
if shard_spec.sharding_sequence[-1] == "R":
tp_p = _gather(input_=tp_p, dim=0, process_group=tp_group) # gather
else:
# TP bias
tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather
else:
# No TP bias
pass
correctness_verify(p.data, tp_p.data, dtype)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
print(f"Zero Test Passed")
@parameterize("dtype", [torch.float16])
@parameterize("tp_zero_size", [(1, 4)])
def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
tp_size, zero_size = tp_zero_size
use_zero = True if zero_size > 1 else False
local_rank = dist.get_rank()
clear_layout_converter()
proc_mesh = ProcessGroupMesh(tp_size, zero_size)
tp_group, dp_group = proc_mesh.get_group_along_axis(0), proc_mesh.get_group_along_axis(1)
torch.set_default_dtype(dtype)
set_seed(42)
# ==============================
# Model Init
# ==============================
base_model = MlpModel().to(local_rank)
# tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank)
tp_model = copy.deepcopy(base_model).to(local_rank)
base_param_group = setup_param_groups(base_model)
tp_param_group = setup_param_groups(tp_model)
tp_param_group_, tp_shard_spec, tp_param_shape = setup_flatten_param_groups_sharding_spec_shape(tp_model)
# ==============================
# Optimizer Init
# ==============================
base_optim = Adafactor(base_param_group)
dist_optim = DistributedAdaFactor(tp_param_group)
# Setup distributed optimizer
if zero_size > 1:
base_optim = LowLevelZeroOptimizer(
base_optim,
overlap_communication=True,
initial_scale=128,
partition_grad=True,
dp_process_group=dp_group,
verbose=True,
)
dist_optim = LowLevelZeroOptimizer(
dist_optim,
overlap_communication=True,
initial_scale=128,
partition_grad=True,
dp_process_group=dp_group,
verbose=True,
)
shard_to_param = dist_optim._param_store.master_to_working_param # {id(): param tensor} but flattened
dist_optim.optim.setup_distributed(
tp_group=tp_group,
dp_group=dp_group,
shard_to_working_param=shard_to_param,
use_zero=use_zero,
)
else:
shard_to_param = set_master_param_to_shard_param(tp_param_group)
dist_optim.setup_distributed(
tp_group=tp_group,
dp_group=dp_group,
shard_to_working_param=shard_to_param,
use_zero=use_zero,
)
# ==============================
# Booster Init
# ==============================
plugin = LowLevelZeroPlugin()
booster = Booster(plugin=plugin)
criterion = lambda x: x.mean()
tp_model, dist_optim, criterion, _, _ = booster.boost(tp_model, dist_optim, criterion)
# ==============================
# Correctness Verify
# ==============================
x = torch.randn(HEIGHT, WIDTH, device=local_rank)
out = base_model(x)
out_tp = tp_model(x)
if zero_size > 1:
dist_optim.backward(out_tp.sum())
base_optim.backward(out.sum())
else:
out_tp.sum().backward()
out.sum().backward()
base_optim.step()
dist_optim.step()
base_optim.zero_grad()
dist_optim.zero_grad()
for p, tp_p in zip(base_param_group, tp_param_group):
param_is_distributed = is_distributed_tensor(tp_p)
if param_is_distributed:
shard_spec = get_sharding_spec(tp_p)
if len(shard_spec.sharding_sequence) >= 2:
# Col Parallel
if shard_spec.sharding_sequence[0] == "R":
tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather
# ROW Parallel
if shard_spec.sharding_sequence[-1] == "R":
tp_p = _gather(input_=tp_p, dim=0, process_group=tp_group) # gather
else:
# TP bias
tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather
else:
# No TP bias
pass
correctness_verify(p.data, tp_p.data, dtype)
Randomizer.reset_index()
torch.cuda.empty_cache()
print(f"Booster Test Passed")
@parameterize(
"test_config",
[
{
"stage": 1,
"precision": "bf16",
},
{
"stage": 2,
"precision": "bf16",
},
],
)
def exam_bert_test_on_lowlevelzero_plugin(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
model_list = [
"transformers_bert",
"transformers_bert_for_pretraining",
"transformers_bert_lm_head_model",
"transformers_bert_for_masked_lm",
"transformers_bert_for_sequence_classification",
"transformers_bert_for_token_classification",
"transformers_bert_for_next_sentence",
"transformers_bert_for_mcq",
"transformers_bert_for_question_answering",
]
clear_layout_converter()
torch.set_default_dtype(torch.bfloat16)
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
if name in model_list:
(
org_model,
org_optimizer,
sharded_model,
sharded_optimizer,
criterion,
booster,
) = build_model_from_low_level_zero_plugin(model_fn, loss_fn, test_config, Adafactor, DistributedAdaFactor)
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_low_level_zero_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
)
# LowLevelZero not need warp
# bert = unwrap_model(org_model, "BertModel", "bert")
# sharded_bert = unwrap_model(sharded_model, "BertModel", "bert")
weight_layer_for_check = [
"bert.encoder.layer.0.output.dense.weight",
"bert.encoder.layer.0.output.dense.weight",
]
org_optimizer.step()
sharded_optimizer.step()
# check weights
if test_config["precision"] == "bf16":
atol, rtol = 5e-4, 5e-4
else:
atol, rtol = 5e-4, 5e-4
check_dist_param(org_model, sharded_model, weight_layer_for_check, atol, rtol)
check_optim_states(org_optimizer, sharded_optimizer.optim)
Randomizer.reset_index()
torch.cuda.empty_cache()
print(f"Bert Model Zoo Test Passed")
@parameterize(
"test_config",
[
{
"tp_size": 1,
"num_microbatches": 4,
"zero_stage": 2,
"precision": "bf16",
},
{
"tp_size": 2,
"num_microbatches": 4,
"zero_stage": 2,
"precision": "bf16",
},
{
"tp_size": 4,
"num_microbatches": 4,
"zero_stage": 2,
"precision": "bf16",
},
{
"tp_size": 2,
"num_microbatches": 4,
"zero_stage": 1,
"precision": "bf16",
},
# @duanjunwen TODO: fix this test case. Currently params are sharded but are not dtensor here, throwing an error.
# Probably due to HybridParallelAMPOptimizer replacing some master params ?
# {
# "tp_size": 4,
# "num_microbatches": 4,
# "zero_stage": 0,
# "precision": "bf16",
# },
],
)
def exam_bert_test_on_hybrid_plugin(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
test_config["use_lazy_init"] = False
test_config["pp_size"] = 1 # Do NOT test Pipeline Parallel
test_config["initial_scale"] = 2**16 # avoid overflow
model_list = [
"transformers_bert",
"transformers_bert_for_pretraining",
"transformers_bert_lm_head_model",
"transformers_bert_for_masked_lm",
"transformers_bert_for_sequence_classification",
"transformers_bert_for_token_classification",
"transformers_bert_for_next_sentence",
"transformers_bert_for_mcq",
"transformers_bert_for_question_answering",
]
clear_layout_converter()
torch.set_default_dtype(torch.bfloat16)
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
if name in model_list:
(
org_model,
org_optimizer,
sharded_model,
sharded_optimizer,
criterion,
booster,
) = build_model_from_hybrid_plugin(model_fn, loss_fn, test_config, Adafactor, DistributedAdaFactor)
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
)
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
bert = unwrap_model(org_model, "BertModel", "bert")
sharded_bert = unwrap_model(sharded_model, "BertModel", "bert")
weight_layer_for_check = ["encoder.layer[0].output.dense", "encoder.layer[1].output.dense"]
org_optimizer.step()
sharded_optimizer.step()
# check weights
if test_config["precision"] == "bf16":
atol, rtol = 5e-4, 5e-4
else:
atol, rtol = 5e-4, 5e-4
if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1)
# check optim states
check_dist_optim_state(org_optimizer, sharded_optimizer.optim)
Randomizer.reset_index()
torch.cuda.empty_cache()
print(f"Bert Model Zoo Test Passed")
def run_dist(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
exam_bert_test_on_lowlevelzero_plugin()
exam_bert_test_on_hybrid_plugin()
exam_dist_adafactor_base()
exam_dist_adafactor_zero()
exam_dist_adafactor_booster()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_dist_adafactor():
spawn(run_dist, nprocs=4)
if __name__ == "__main__":
test_dist_adafactor()

View File

@ -0,0 +1,475 @@
import copy
import pytest
import torch
import torch.distributed as dist
from torch import nn
from torch.testing import assert_close
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.logging import disable_existing_loggers
from colossalai.nn.optimizer.came import CAME
from colossalai.nn.optimizer.distributed_came import DistributedCAME
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.layer._operation import _gather
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor import get_layout, get_sharding_spec, is_distributed_tensor
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.tensor.d_tensor.sharding_spec import DimSpec
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from colossalai.zero import LowLevelZeroOptimizer
from tests.kit.model_zoo import model_zoo
from tests.test_optimizer._utils import check_dist_grad, check_dist_optim_state, check_dist_param, check_optim_states
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
build_model_from_low_level_zero_plugin,
run_forward_backward_with_hybrid_plugin,
run_forward_backward_with_low_level_zero_plugin,
unwrap_model,
)
HEIGHT = 128
WIDTH = 128
_TP_SPEC = DimSpec([0])
_SEED = 0
def correctness_verify(tensor1: torch.Tensor, tensor2: torch.Tensor, dtype: torch.dtype = torch.float32):
rtol = None
atol = None
if dtype is torch.float32:
rtol = 5e-04
atol = 5e-04
elif dtype is torch.float16:
rtol = 5e-2
atol = 5e-4
elif dtype is torch.bfloat16:
rtol = 4e-3
atol = 4e-3
# return torch.all(tensor1.isclose(tensor2, rtol=rtol, atol=atol))
assert_close(tensor1, tensor2, rtol=rtol, atol=atol)
# setup param groups; (For zero test optim)
def setup_param_groups_zero(model: nn.Module) -> list:
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": 0.1,
},
{
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
return optimizer_grouped_parameters
# setup param groups; (For base optim)
def setup_param_groups(model: nn.Module) -> list:
optimizer_grouped_parameters = [p for n, p in model.named_parameters()]
return optimizer_grouped_parameters
# setup flatten param groups, sharding spec and shape; (For dist optim)
def setup_flatten_param_groups_sharding_spec_shape(model: nn.Module) -> dict:
flatten_optimizer_grouped_parameters = []
sharding_spec = {} # {id(flatten param): get_layout(p).global_shape}
param_shape = {} # {id(flatten param): get_sharding_spec(p)}
for n, p in model.named_parameters():
flatten_p = nn.Parameter(p.clone().flatten().requires_grad_(True))
flatten_optimizer_grouped_parameters.append(flatten_p)
if is_distributed_tensor(p):
sharding_spec[id(flatten_p)] = get_sharding_spec(p)
param_shape[id(flatten_p)] = get_layout(p).global_shape
else:
sharding_spec[id(flatten_p)] = None
param_shape[id(flatten_p)] = p.shape
return flatten_optimizer_grouped_parameters, sharding_spec, param_shape
def set_dist_grad(
dist_module: nn.Module, torch_model: nn.Module, g_dtype: torch.dtype, group: dist.ProcessGroup
) -> None:
"""
Set split grads for Tensor Parallel or ZeRO DP.
We do not need a separate treatment for ZeRO,
as the wrapper takes care of reduce-scattering grads.
"""
rank = dist.get_rank(group)
world_size = dist.get_world_size(group)
for p, torch_p in zip(dist_module.parameters(), torch_model.parameters()):
if torch_p.grad is None:
torch_p.grad = torch.zeros_like(torch_p)
is_distributed = hasattr(p, "dist_layout")
if is_distributed:
sharding = p.dist_layout.sharding_spec.sharding_sequence
split_dim = sharding.index(_TP_SPEC)
shape = torch_p.split(world_size, dim=split_dim)[rank].shape
indices = torch.arange(shape[split_dim] * rank, shape[split_dim] * (rank + 1))
# Generate grads only for the correctly split chunk
torch_p.grad.index_add_(split_dim, indices, torch.randn(shape, device=torch_p.device, dtype=g_dtype))
else:
shape = torch_p.shape
torch_p.grad += torch.randn(shape, device=torch_p.device, dtype=g_dtype)
# avoid inconsistent grad and param dtype error
orig_p = p.data
p.data = torch_p.grad.clone().to(g_dtype)
p.grad = p.data
p.data = orig_p
def set_master_param_to_shard_param(master_param_list) -> dict:
master_param_to_shard_param = {id(p): p for p in master_param_list}
return master_param_to_shard_param
class MlpModel(nn.Module):
def __init__(self):
super(MlpModel, self).__init__()
self.linear1 = nn.Linear(HEIGHT, WIDTH)
self.linear2 = nn.Linear(WIDTH, HEIGHT)
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x
class TPModel(nn.Module):
def __init__(self, linear1, linear2, tp_group=None):
super().__init__()
self.linear1 = Linear1D_Col.from_native_module(
linear1, process_group=tp_group, gather_output=False, overlap=True
)
self.linear2 = Linear1D_Row.from_native_module(linear2, process_group=tp_group, parallel_input=True)
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x
@parameterize("dtype", [torch.float32]) # torch.float32, torch.float16, torch.bfloat16
@parameterize("tp_zero_size", [(2, 2), (4, 1), (1, 4)]) # (4, 1), (1, 4)
def exam_dist_came_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
tp_size, zero_size = tp_zero_size
use_zero = True if zero_size > 1 else False
local_rank = dist.get_rank()
clear_layout_converter()
proc_mesh = ProcessGroupMesh(tp_size, zero_size)
tp_group, dp_group = proc_mesh.get_group_along_axis(0), proc_mesh.get_group_along_axis(1)
torch.set_default_dtype(dtype)
# set_seed(42)
# ==============================
# Model Init
# ==============================
base_model = MlpModel().to(local_rank)
tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank)
base_param_group = setup_param_groups(base_model)
tp_param_group = setup_param_groups(tp_model)
tp_param_group_, tp_shard_spec, tp_param_shape = setup_flatten_param_groups_sharding_spec_shape(tp_model)
# ==============================
# Optimizer Init
# ==============================
base_optim = CAME(base_param_group, lr=1e-3)
dist_optim = DistributedCAME(tp_param_group, lr=1e-3)
# Setup distributed optimizer
if zero_size > 1:
dist_optim = LowLevelZeroOptimizer(
dist_optim,
overlap_communication=True,
initial_scale=128,
partition_grad=True,
dp_process_group=dp_group,
verbose=True,
)
shard_to_param = dist_optim._param_store.master_to_working_param # {id(): param tensor} but flattened
dist_optim.optim.setup_distributed(
tp_group=tp_group,
dp_group=dp_group,
shard_to_working_param=shard_to_param,
use_zero=use_zero,
)
else:
shard_to_param = set_master_param_to_shard_param(tp_param_group)
dist_optim.setup_distributed(
tp_group=tp_group,
dp_group=dp_group,
shard_to_working_param=shard_to_param,
use_zero=use_zero,
)
# ==============================
# Correctness Verify
# ==============================
seed_all(1024)
x = torch.randn(WIDTH, HEIGHT, device=local_rank)
out = base_model(x)
out_tp = tp_model(x)
if zero_size > 1:
dist_optim.backward(out_tp.sum())
out.sum().backward()
else:
out_tp.sum().backward()
out.sum().backward()
base_optim.step()
dist_optim.step()
base_optim.zero_grad()
dist_optim.zero_grad()
for p, tp_p in zip(base_param_group, tp_param_group):
param_is_distributed = is_distributed_tensor(tp_p)
if param_is_distributed:
shard_spec = get_sharding_spec(tp_p)
if len(shard_spec.sharding_sequence) >= 2:
# Col Parallel
if shard_spec.sharding_sequence[0] == "R":
tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather
# ROW Parallel
if shard_spec.sharding_sequence[-1] == "R":
tp_p = _gather(input_=tp_p, dim=0, process_group=tp_group) # gather
else:
# TP bias
tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather
else:
# No TP bias
pass
correctness_verify(p.data, tp_p.data, dtype)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
print(f"Fwd/Bwd Test Passed")
@parameterize(
"test_config",
[
{
"stage": 1,
"precision": "bf16",
},
{
"stage": 2,
"precision": "bf16",
},
],
)
def exam_bert_test_on_lowlevelzero_plugin(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
test_config["use_lazy_init"] = False
test_config["initial_scale"] = 2**10
# check weights
if test_config["precision"] == "bf16":
atol, rtol = 5e-4, 5e-4
else:
atol, rtol = 5e-4, 5e-4
# test_config["initial_scale"] = 1
model_list = [
"transformers_bert",
"transformers_bert_for_pretraining",
"transformers_bert_lm_head_model",
"transformers_bert_for_masked_lm",
"transformers_bert_for_sequence_classification",
"transformers_bert_for_token_classification",
"transformers_bert_for_next_sentence",
"transformers_bert_for_mcq",
"transformers_bert_for_question_answering",
"simple_mlp",
]
clear_layout_converter()
torch.set_default_dtype(torch.bfloat16)
seed_all(_SEED)
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
if name in model_list:
(
org_model,
org_optimizer,
sharded_model,
sharded_optimizer,
criterion,
booster,
) = build_model_from_low_level_zero_plugin(model_fn, loss_fn, test_config, CAME, DistributedCAME)
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_low_level_zero_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
)
# assert same output
# assert_close(org_output, org_output, atol=atol, rtol=rtol)
weight_layer_for_check = [
"bert.encoder.layer.1.intermediate.dense",
# TODO: error in layer:
# "bert.encoder.layer.0.output.dense",
# "bert.encoder.layer.1.output.dense",
]
# assert same weight before step; pass
check_dist_param(org_model, sharded_model, weight_layer_for_check, atol, rtol)
# asserr loss; pass
assert_close(org_loss, sharded_loss)
# assert same grad before step
# TODO: err here; backward diff gard; Only transformers_bert pass;
check_dist_grad(sharded_optimizer, org_model, sharded_model, weight_layer_for_check, atol, rtol)
org_optimizer.step()
sharded_optimizer.step()
# assert same weight after step
check_dist_param(org_model, sharded_model, weight_layer_for_check, atol, rtol)
check_optim_states(org_optimizer, sharded_optimizer.optim)
Randomizer.reset_index()
torch.cuda.empty_cache()
print(f"LowLevelZeroPlugin + Bert Model Zoo Test Passed")
@parameterize(
"test_config",
[
{
"tp_size": 1,
"num_microbatches": 4,
"zero_stage": 2,
"precision": "bf16",
},
{
"tp_size": 2,
"num_microbatches": 4,
"zero_stage": 2,
"precision": "bf16",
},
{
"tp_size": 4,
"num_microbatches": 4,
"zero_stage": 2,
"precision": "bf16",
},
{
"tp_size": 2,
"num_microbatches": 4,
"zero_stage": 1,
"precision": "bf16",
},
{
"tp_size": 4,
"num_microbatches": 4,
"zero_stage": 0,
"precision": "bf16",
},
],
)
def exam_bert_test_on_hybrid_plugin(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
test_config["use_lazy_init"] = False
test_config["pp_size"] = 1 # Do NOT test Pipeline Parallel
test_config["initial_scale"] = 2**16 # avoid overflow
model_list = [
"transformers_bert",
"transformers_bert_for_pretraining",
"transformers_bert_lm_head_model",
"transformers_bert_for_masked_lm",
"transformers_bert_for_sequence_classification",
"transformers_bert_for_token_classification",
"transformers_bert_for_next_sentence",
"transformers_bert_for_mcq",
"transformers_bert_for_question_answering",
]
# pass "transformers_bert",
clear_layout_converter()
torch.set_default_dtype(torch.bfloat16)
# check weights
if test_config["precision"] == "bf16":
atol, rtol = 5e-3, 5e-3
else:
atol, rtol = 5e-3, 5e-3
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
if name in model_list:
(
org_model,
org_optimizer,
sharded_model,
sharded_optimizer,
criterion,
booster,
) = build_model_from_hybrid_plugin(model_fn, loss_fn, test_config, CAME, DistributedCAME)
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
)
stage_manager = booster.plugin.stage_manager
booster.plugin.tp_group
bert = unwrap_model(org_model, "BertModel", "bert")
sharded_bert = unwrap_model(sharded_model, "BertModel", "bert")
# TODO: model
# "encoder.layer.0.output.dense.weight", "encoder.layer.1.output.dense.weight" not match
# "encoder.layer[0].output.dense", "encoder.layer[1].output.dense" not match
weight_layer_for_check = ["embeddings.word_embeddings"] # [30522, 128]
# # assert same weight before step; all pass
# check_dist_param(org_model, sharded_model, weight_layer_for_check, atol, rtol)
# # assert loss; all pass
# assert_close(org_loss, sharded_loss)
# # assert same grad before step; all pass
# check_dist_grad(org_model, sharded_model, weight_layer_for_check, atol, rtol)
org_optimizer.step()
sharded_optimizer.step()
if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
check_dist_param(bert, sharded_bert, weight_layer_for_check, atol, rtol)
# check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1)
# check optim states
check_dist_optim_state(org_optimizer, sharded_optimizer.optim)
Randomizer.reset_index()
torch.cuda.empty_cache()
print(f"HybridParallelPlugin + Bert Model Zoo Test Passed")
def run_dist(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
exam_bert_test_on_lowlevelzero_plugin() # err in TODO layer
exam_bert_test_on_hybrid_plugin() # pass
exam_dist_came_base() # pass
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_dist_came():
spawn(run_dist, nprocs=4)
if __name__ == "__main__":
test_dist_came()

View File

@ -0,0 +1,336 @@
"""Usage(requires 4 GPUs): python test_dist_galore.py"""
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close
import colossalai
from colossalai.cluster import DistCoordinator, ProcessGroupMesh
from colossalai.logging import disable_existing_loggers
from colossalai.nn.optimizer import DistGaloreAwamW, GaLoreAdamW8bit
from colossalai.nn.optimizer.galore import get_galore_param_groups
from colossalai.tensor.d_tensor import get_shard_dim_1d, is_distributed_tensor
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from colossalai.zero import LowLevelZeroOptimizer
from tests.kit.model_zoo import model_zoo
from tests.test_optimizer._utils import check_optim_states, run_bert_test
_ALLOWED_P_G_TYPES = [
(torch.float, torch.float), # pure fp32
(torch.half, torch.half), # fp16 amp
(torch.bfloat16, torch.bfloat16), # bfloat16 amp
]
# Identifiers for Tensor Parallel linear layers
_IN_DIM = 32
_HID_DIM = 128
_N_STEP = 3
_SEED = 0
coordinator = None
lr = 1e-2
beta1, beta2 = 0.9, 0.999
eps = 1e-8
decay = 1e-3
Net, data_gen, *_ = next(iter(model_zoo.get_sub_registry("simple_mlp").values()))
TPNet, *_ = next(iter(model_zoo.get_sub_registry("simple_tp_mlp").values()))
# Doesn't support ZeRO for now
test_config = [
{
"tp_size": 1,
"num_microbatches": 4,
"zero_stage": 0,
"precision": "bf16",
},
{
"tp_size": 2,
"num_microbatches": 4,
"zero_stage": 0,
"precision": "bf16",
},
{
"tp_size": 4,
"num_microbatches": 4,
"zero_stage": 0,
"precision": "bf16",
},
]
def assert_grad_close(tp_model, torch_model, tp_group):
tp_size = dist.get_world_size(tp_group)
# Check equal grads
for p, torch_p in zip(tp_model.parameters(), torch_model.parameters()):
grads = p.grad
if is_distributed_tensor(p):
split_dim = get_shard_dim_1d(p)
all_grads = [torch.empty_like(grads) for _ in range(tp_size)]
dist.all_gather(all_grads, grads.contiguous(), group=tp_group)
all_grads = torch.cat(all_grads, dim=split_dim)
else:
all_grads = grads
try:
assert (all_grads != 0).any()
assert_close(all_grads, torch_p.grad)
except Exception as e:
print(f"Before gather: {grads.shape}, after: {all_grads.shape}")
raise e
def assert_distributed_close(tp_model, torch_model, rtol, atol, tp_group):
rank = dist.get_rank(tp_group)
tp_size = dist.get_world_size(tp_group)
for (name, p), torch_p in zip(tp_model.named_parameters(), torch_model.parameters()):
# if overflow, the weight won't be updated. so there will be no nan in p
assert not torch.isnan(p).any()
try:
if is_distributed_tensor(p):
split_dim = get_shard_dim_1d(p)
torch_p = torch_p.chunk(tp_size, dim=split_dim)[rank]
assert_close(p, torch_p, rtol=rtol, atol=atol)
except AssertionError as e:
print(f"grad mismatch in {name}")
raise e
def force_assign_grad(p, g_dtype, grad=None):
"""avoid inconsistent grad and param dtype error"""
orig_p = p.data
p.data = torch.randn_like(p, device=orig_p.device, dtype=g_dtype) if grad == None else grad
p.grad = p.data
p.data = orig_p
def set_dist_grad(
dist_module: nn.Module,
torch_model: nn.Module,
g_dtype: torch.dtype,
group: dist.ProcessGroup,
) -> None:
"""
Set grads chunks for Tensor Parallel or ZeRO DP.
We do not need a separate treatment for ZeRO,
as the LowLevelOptimizer takes care of reduce-scattering grads.
"""
rank = dist.get_rank(group)
world_size = dist.get_world_size(group)
for p, torch_p in zip(dist_module.parameters(), torch_model.parameters()):
if torch_p.grad is None:
# avoid inconsistent grad and param dtype error
force_assign_grad(torch_p, g_dtype)
else:
torch_p.grad += torch.randn_like(torch_p, device=torch_p.device, dtype=g_dtype)
if p.grad is None:
force_assign_grad(p, g_dtype)
if is_distributed_tensor(p):
split_dim = get_shard_dim_1d(p)
# Add grads only to the correctly split chunk
force_assign_grad(p, g_dtype, torch_p.grad.chunk(world_size, dim=split_dim)[rank].contiguous())
# assert_close(p.grad, torch_p.grad.chunk(world_size, dim=split_dim)[rank])
else:
force_assign_grad(p, g_dtype, torch_p.grad)
@parameterize("p_g_dtype", _ALLOWED_P_G_TYPES)
@parameterize("tp_zero_size", [(4, 1), (1, 4), (2, 2)])
def run_dist_galore_basic(p_g_dtype: tuple[torch.dtype, torch.dtype], tp_zero_size: tuple[int, int]) -> None:
"""Test without forward"""
p_dtype, g_dtype = p_g_dtype
tp_size, zero_size = tp_zero_size
# Set distributed groups
rank = dist.get_rank()
clear_layout_converter() # Ensure correct sharding
proc_mesh = ProcessGroupMesh(tp_size, zero_size)
tp_group = proc_mesh.get_group_along_axis(0)
dp_group = proc_mesh.get_group_along_axis(1)
dist.get_rank(tp_group)
seed_all(_SEED) # Fix model init
torch_model = Net(in_dim=_IN_DIM, hid_dim=_HID_DIM, identity=True, dtype=p_dtype).to(rank)
tp_model = TPNet(torch_model.fc0, torch_model.fc1, torch_model.fc2, tp_group, dtype=p_dtype).to(rank)
assert_distributed_close(tp_model, torch_model, rtol=0, atol=0, tp_group=tp_group)
# Set up optimizers
torch_optim = GaLoreAdamW8bit(
get_galore_param_groups(torch_model, decay, rank=8),
lr=lr,
betas=(beta1, beta2),
eps=eps,
percentile_clipping=101,
block_wise=False,
min_8bit_size=1e10, # Disable quantization
)
optim = DistGaloreAwamW(
get_galore_param_groups(tp_model, decay, rank=8),
lr=lr,
betas=(beta1, beta2),
eps=eps,
percentile_clipping=101,
block_wise=False,
min_8bit_size=1e10,
)
optim.setup_distributed(tp_group, dp_group)
rtol, atol = 8e-7, 8e-7
if p_dtype is torch.float16 or g_dtype is torch.float16:
rtol, atol = 1e-6, 1e-6
if p_dtype is torch.bfloat16 or g_dtype is torch.bfloat16:
rtol, atol = 2e-6, 2e-6
for i in range(_N_STEP):
seed_all(_SEED + i) # NOTE: having only one manual_seed above doesn't work?
set_dist_grad(tp_model, torch_model, g_dtype, tp_group)
try:
torch_optim.step()
optim.step()
assert_grad_close(tp_model, torch_model, tp_group)
torch_optim.zero_grad()
optim.zero_grad()
assert_distributed_close(tp_model, torch_model, rtol, atol, tp_group)
check_optim_states(torch_optim, optim)
except Exception as e:
coordinator.print_on_master(f"step {i}: p_g_dtype: {p_g_dtype}, tp_zero_size: {tp_zero_size}")
raise e
@parameterize("p_g_dtype", _ALLOWED_P_G_TYPES)
@parameterize("tp_zero_size", [(4, 1), (2, 2), (1, 4)])
def run_dist_galore_fwd_bwd(p_g_dtype: tuple[torch.dtype, torch.dtype], tp_zero_size: tuple[int, int]) -> None:
p_dtype, g_dtype = p_g_dtype
tp_size, zero_size = tp_zero_size
# Set distributed groups
rank = dist.get_rank()
proc_mesh = ProcessGroupMesh(tp_size, zero_size)
tp_group = proc_mesh.get_group_along_axis(0)
dp_group = proc_mesh.get_group_along_axis(1)
dist.get_rank(tp_group)
seed_all(_SEED)
clear_layout_converter() # Ensure correct sharding
torch_model = Net(_IN_DIM, _HID_DIM, identity=True, dtype=p_dtype).to(rank)
tp_model = TPNet(torch_model.fc0, torch_model.fc1, torch_model.fc2, tp_group, dtype=p_dtype).to(rank)
assert_distributed_close(tp_model, torch_model, rtol=0, atol=0, tp_group=tp_group)
# Set up optimizers
torch_optim = GaLoreAdamW8bit(
get_galore_param_groups(torch_model, decay, rank=8),
lr=lr,
betas=(beta1, beta2),
eps=eps,
percentile_clipping=101,
block_wise=False,
min_8bit_size=1e10,
)
optim = DistGaloreAwamW(
get_galore_param_groups(tp_model, decay, rank=8),
lr=lr,
betas=(beta1, beta2),
eps=eps,
percentile_clipping=101,
block_wise=False,
min_8bit_size=1e10,
)
# Setup distributed optimizer
if zero_size > 1:
optim = LowLevelZeroOptimizer(
optim,
overlap_communication=True,
initial_scale=128,
partition_grad=True,
dp_process_group=dp_group,
verbose=True,
)
shard_to_param = optim.get_master_to_working_map()
optim.optim.setup_distributed(
tp_group, dp_group, shard_to_param, padding_map=optim.get_param_padding_map(), is_zero=True
)
else:
optim.setup_distributed(tp_group)
rtol, atol = 8e-7, 8e-7
if p_dtype is torch.float16 or g_dtype is torch.float16:
rtol, atol = 1e-6, 1e-6
if p_dtype is torch.bfloat16 or g_dtype is torch.bfloat16:
rtol, atol = 2e-6, 2e-6
seed_all(_SEED) # NOTE: having only one manual_seed above doesn't work?
x = data_gen().cuda().to(dtype=p_dtype)
out_tp = tp_model(x)
out = torch_model(x)
try:
assert_close(out, out_tp, rtol=rtol, atol=atol)
except Exception as e:
coordinator.print_on_master(f"p_g_dtype: {p_g_dtype}, tp_zero_size: {tp_zero_size}")
raise e
if zero_size > 1:
optim.backward(out_tp.sum())
out.sum().backward()
else:
out_tp.sum().backward()
out.sum().backward()
torch_optim.step()
optim.step()
torch_optim.zero_grad()
optim.zero_grad()
try:
assert_distributed_close(tp_model, torch_model, rtol, atol, tp_group)
check_optim_states(getattr(torch_optim, "optim", torch_optim), getattr(optim, "optim", optim))
except Exception as e:
coordinator.print_on_master(f"p_g_dtype: {p_g_dtype}, tp_zero_size: {tp_zero_size}")
raise e
def check_dist_galore(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
global coordinator
coordinator = DistCoordinator()
run_dist_galore_basic()
coordinator.print_on_master("Basic backward tests passed")
coordinator.print_on_master("Skipping forward-backward tests due to SVD instability")
# run_dist_galore_fwd_bwd()
# _COORDINATOR.print_on_master("Forward-backward tests passed")
coordinator.print_on_master(
"Running bert tests, which are expected to produce minor errors due to instability in SVD convergence. \
For example, a 1e-9 grad diff causes drastic difference in SVD output."
)
for config in test_config:
try:
run_bert_test(test_config=config, optim_class=GaLoreAdamW8bit, sharded_optim_class=DistGaloreAwamW)
except Exception as e:
print(e)
dist.barrier()
print(f"rank {rank} tests passed :)")
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_dist_galore():
spawn(check_dist_galore, nprocs=4)
if __name__ == "__main__":
test_dist_galore()

View File

@ -0,0 +1,303 @@
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close
import colossalai
from colossalai.cluster import DistCoordinator, ProcessGroupMesh
from colossalai.logging import disable_existing_loggers
from colossalai.nn.optimizer import DistributedLamb, Lamb
from colossalai.tensor.d_tensor import get_shard_dim_1d, is_distributed_tensor
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from colossalai.zero import LowLevelZeroOptimizer
from tests.kit.model_zoo import model_zoo
from tests.test_optimizer._utils import check_optim_states, run_bert_test
_ALLOWED_P_G_TYPES = [
(torch.float, torch.float), # pure fp32
(torch.float, torch.half), # fp16 amp
(torch.float, torch.bfloat16), # bfloat16 amp
]
_IN_DIM = 32
_HID_DIM = 128
_N_STEP = 3
_SEED = 1024
coordinator = None
Net, data_gen, *_ = next(iter(model_zoo.get_sub_registry("simple_mlp").values()))
TPNet, *_ = next(iter(model_zoo.get_sub_registry("simple_tp_mlp").values()))
def assert_distributed_close(tp_model, torch_model, rtol, atol, tp_group):
rank = dist.get_rank(tp_group)
tp_size = dist.get_world_size(tp_group)
for (name, p), torch_p in zip(tp_model.named_parameters(), torch_model.parameters()):
# if overflow, the weight won't be updated. so there will be no nan in p
assert not torch.isnan(p).any()
try:
if is_distributed_tensor(p):
split_dim = get_shard_dim_1d(p)
torch_p = torch_p.chunk(tp_size, dim=split_dim)[rank]
assert_close(p.float(), torch_p, rtol=rtol, atol=atol)
except AssertionError as e:
print(f"grad mismatch in {name}")
raise e
def setup_param_groups(bert_model: nn.Module) -> list:
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in bert_model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": 0.1,
},
{
"params": [p for n, p in bert_model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
return optimizer_grouped_parameters
def force_assign_grad(p, g_dtype, grad=None):
"""avoid inconsistent grad and param dtype error"""
orig_p = p.data
p.data = torch.randn_like(p, device=orig_p.device, dtype=g_dtype) if grad == None else grad
p.grad = p.data
p.data = orig_p
def set_dist_grad(
dist_module: nn.Module,
torch_model: nn.Module,
g_dtype: torch.dtype,
group: dist.ProcessGroup,
) -> None:
"""
Set grads chunks for Tensor Parallel or ZeRO DP.
We do not need a separate treatment for ZeRO,
as the LowLevelOptimizer takes care of reduce-scattering grads.
"""
rank = dist.get_rank(group)
world_size = dist.get_world_size(group)
for p, torch_p in zip(dist_module.parameters(), torch_model.parameters()):
if torch_p.grad is None:
# avoid inconsistent grad and param dtype error
force_assign_grad(torch_p, g_dtype)
else:
torch_p.grad += torch.randn_like(torch_p, device=torch_p.device, dtype=g_dtype)
if p.grad is None:
force_assign_grad(p, g_dtype)
if is_distributed_tensor(p):
split_dim = get_shard_dim_1d(p)
# Add grads only to the correctly split chunk
force_assign_grad(p, g_dtype, torch_p.grad.chunk(world_size, dim=split_dim)[rank])
# assert_close(p.grad, torch_p.grad.chunk(world_size, dim=split_dim)[rank])
else:
force_assign_grad(p, g_dtype, torch_p.grad)
@parameterize("p_g_dtype", _ALLOWED_P_G_TYPES)
@parameterize("bias_correction", [False, True])
@parameterize("tp_zero_size", [(1, 4), (4, 1), (2, 2)])
def run_dist_lamb_basic(
bias_correction: bool, p_g_dtype: tuple[torch.dtype, torch.dtype], tp_zero_size: tuple[int, int]
) -> None:
"""Test without forward"""
p_dtype, g_dtype = p_g_dtype
tp_size, zero_size = tp_zero_size
# Set distributed groups
rank = dist.get_rank()
clear_layout_converter() # Ensure correct sharding
proc_mesh = ProcessGroupMesh(tp_size, zero_size)
tp_group = proc_mesh.get_group_along_axis(0)
tp_rank = dist.get_rank(tp_group)
seed_all(_SEED) # Fix model init
torch_model = Net(in_dim=_IN_DIM, hid_dim=_HID_DIM, identity=True).to(rank)
tp_model = TPNet(torch_model.fc0, torch_model.fc1, torch_model.fc2, tp_group).to(rank)
# Ensure equal weight init
assert_close(
torch_model.fc1.weight[tp_rank * _HID_DIM // tp_size : (tp_rank + 1) * _HID_DIM // tp_size],
tp_model.fc1.weight,
)
assert_close(
torch_model.fc2.weight[:, tp_rank * _HID_DIM // tp_size : (tp_rank + 1) * _HID_DIM // tp_size],
tp_model.fc2.weight,
)
# Set up optimizers
lr = 1e-3
beta1, beta2 = 0.9, 0.999
eps = 1e-8
torch_optim = Lamb(
setup_param_groups(torch_model), lr=lr, betas=(beta1, beta2), eps=eps, bias_correction=bias_correction
)
optim = DistributedLamb(
setup_param_groups(tp_model),
lr=lr,
betas=(beta1, beta2),
eps=eps,
bias_correction=bias_correction,
)
optim.setup_distributed(tp_group)
rtol, atol = 8e-7, 8e-7
if p_dtype is torch.float16 or g_dtype is torch.float16:
rtol, atol = 1e-6, 1e-6
if p_dtype is torch.bfloat16 or g_dtype is torch.bfloat16:
rtol, atol = 2e-6, 2e-6
for i in range(_N_STEP):
seed_all(_SEED + i) # NOTE: having only one manual_seed above doesn't work?
set_dist_grad(tp_model, torch_model, g_dtype, tp_group)
torch_optim.step()
optim.step()
torch_optim.zero_grad()
optim.zero_grad()
try:
assert_distributed_close(tp_model, torch_model, rtol, atol, tp_group)
except Exception as e:
coordinator.print_on_master(
f"step {i + 1}: bias_correction: {bias_correction}, p_g_dtype: {p_g_dtype}, tp_zero_size: {tp_zero_size}"
)
raise e
@parameterize("p_g_dtype", _ALLOWED_P_G_TYPES)
@parameterize("bias_correction", [False, True])
@parameterize("tp_zero_size", [(2, 2), (4, 1), (1, 4)])
def run_dist_lamb_fwd_bwd(
bias_correction: bool, p_g_dtype: tuple[torch.dtype, torch.dtype], tp_zero_size: tuple[int, int]
) -> None:
p_dtype, g_dtype = p_g_dtype
tp_size, zero_size = tp_zero_size
# Set distributed groups
rank = dist.get_rank()
proc_mesh = ProcessGroupMesh(tp_size, zero_size)
tp_group = proc_mesh.get_group_along_axis(0)
dp_group = proc_mesh.get_group_along_axis(1)
tp_rank = dist.get_rank(tp_group)
seed_all(_SEED)
clear_layout_converter() # Ensure correct sharding
torch_model = Net(_IN_DIM, _HID_DIM).to(rank)
tp_model = TPNet(torch_model.fc0, torch_model.fc1, torch_model.fc2, tp_group).to(rank)
assert_close(
torch_model.fc1.weight[tp_rank * _HID_DIM // tp_size : (tp_rank + 1) * _HID_DIM // tp_size],
tp_model.fc1.weight,
)
assert_close(
torch_model.fc2.weight[:, tp_rank * _HID_DIM // tp_size : (tp_rank + 1) * _HID_DIM // tp_size],
tp_model.fc2.weight,
)
# Set up optimizers
lr = 1e-3
beta1, beta2 = 0.9, 0.999
eps = 1e-8
torch_optim = Lamb(
setup_param_groups(torch_model), lr=lr, betas=(beta1, beta2), eps=eps, bias_correction=bias_correction
)
optim = DistributedLamb(
setup_param_groups(tp_model),
lr=lr,
betas=(beta1, beta2),
eps=eps,
bias_correction=bias_correction,
)
# Setup distributed optimizer
if zero_size > 1:
optim = LowLevelZeroOptimizer(
optim,
overlap_communication=True,
initial_scale=128,
partition_grad=True,
dp_process_group=dp_group,
verbose=True,
)
shard_to_param = optim._param_store.master_to_working_param
optim.optim.setup_distributed(tp_group, dp_group, shard_to_param, is_zero=True)
else:
optim.setup_distributed(tp_group)
rtol, atol = 8e-7, 8e-7
if p_dtype is torch.float16 or g_dtype is torch.float16:
rtol, atol = 1e-6, 1e-6
if p_dtype is torch.bfloat16 or g_dtype is torch.bfloat16:
rtol, atol = 2e-6, 2e-6
seed_all(_SEED) # NOTE: having only one manual_seed above doesn't work?
x = data_gen()
x = x.cuda().to(dtype=p_dtype)
out_tp = tp_model(x)
out = torch_model(x)
try:
assert_close(out, out_tp, rtol=rtol, atol=atol)
except Exception as e:
coordinator.print_on_master(
f"bias_correction: {bias_correction}, p_g_dtype: {p_g_dtype}, tp_zero_size: {tp_zero_size}"
)
raise e
if zero_size > 1:
optim.backward(out_tp.sum())
out.sum().backward()
else:
out_tp.sum().backward()
out.sum().backward()
torch_optim.step()
optim.step()
dist.barrier()
torch_optim.zero_grad()
optim.zero_grad()
try:
assert_distributed_close(tp_model, torch_model, rtol, atol, tp_group)
check_optim_states(getattr(torch_optim, "optim", torch_optim), getattr(optim, "optim", optim))
except Exception as e:
coordinator.print_on_master(
f"bias_correction: {bias_correction}, p_g_dtype: {p_g_dtype}, tp_zero_size: {tp_zero_size}"
)
raise e
def check_dist_lamb(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
global coordinator
coordinator = DistCoordinator()
run_dist_lamb_basic()
coordinator.print_on_master("Basic tests passed")
run_dist_lamb_fwd_bwd()
coordinator.print_on_master("Forward-backward tests passed")
run_bert_test(optim_class=Lamb, sharded_optim_class=DistributedLamb)
print(f"rank {rank} tests passed :)")
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_dist_lamb():
spawn(check_dist_lamb, nprocs=4)
if __name__ == "__main__":
test_dist_lamb()

View File

@ -11,11 +11,14 @@ from torch.nn import Module
from torch.optim import Adam, Optimizer
from torch.testing import assert_close
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.booster.plugin import HybridParallelPlugin, LowLevelZeroPlugin
from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
from colossalai.checkpoint_io.utils import gather_distributed_param
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import DistGaloreAwamW
from colossalai.nn.optimizer.galore import get_galore_param_groups
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer._utils import getattr_
@ -113,7 +116,9 @@ def check_state_dict(org_model: Module, sharded_model: Module, name: str = ""):
assert torch.equal(v, shard_v), f"{name} {k} value mismatch"
def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_config: Dict[str, Any]):
def build_model_from_hybrid_plugin(
model_fn: Callable, loss_fn: Callable, test_config: Dict[str, Any], optim_class=Adam, sharded_optim_class=Adam
):
use_lazy_init = False
if "use_lazy_init" in test_config:
use_lazy_init = test_config.pop("use_lazy_init")
@ -125,8 +130,25 @@ def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_c
if use_lazy_init:
ctx.materialize(org_model)
org_model = org_model.cuda()
org_optimizer = Adam(org_model.parameters(), lr=1e-3)
sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3)
if sharded_optim_class == DistGaloreAwamW:
# Disable clipping and block-wise quantization
org_optimizer = optim_class(
get_galore_param_groups(org_model, weight_decay=0, rank=4),
lr=1e-3,
percentile_clipping=101,
block_wise=False,
min_8bit_size=1e10,
)
sharded_optimizer = sharded_optim_class(
get_galore_param_groups(sharded_model, weight_decay=0, rank=4),
lr=1e-3,
percentile_clipping=101,
block_wise=False,
min_8bit_size=1e10,
)
else:
org_optimizer = optim_class(org_model.parameters(), lr=1e-3)
sharded_optimizer = sharded_optim_class(sharded_model.parameters(), lr=1e-3)
criterion = loss_fn
plugin = HybridParallelPlugin(**test_config)
@ -143,6 +165,32 @@ def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_c
)
def build_model_from_low_level_zero_plugin(
model_fn: Callable, loss_fn: Callable, test_config: Dict[str, Any], optim_class=Adam, sharded_optim_class=Adam
):
use_lazy_init = False
if "use_lazy_init" in test_config:
use_lazy_init = test_config.pop("use_lazy_init")
ctx = LazyInitContext() if use_lazy_init else nullcontext()
with ctx:
org_model = model_fn()
sharded_model = copy.deepcopy(org_model)
if use_lazy_init:
ctx.materialize(org_model)
org_model = org_model.cuda()
org_optimizer = optim_class(org_model.parameters(), lr=1e-3)
sharded_optimizer = sharded_optim_class(sharded_model.parameters(), lr=1e-3)
criterion = loss_fn
plugin = LowLevelZeroPlugin(**test_config)
booster = Booster(plugin=plugin)
sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)
return org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster
def run_forward_backward_with_hybrid_plugin(
org_model: Module,
sharded_model: Module,
@ -209,6 +257,44 @@ def run_forward_backward_with_hybrid_plugin(
return org_loss, org_output, sharded_loss, sharded_output
def run_forward_backward_with_low_level_zero_plugin(
org_model: Module,
sharded_model: Module,
sharded_optimizer: Optimizer,
data_gen_fn: Callable,
output_transform_fn: Callable,
criterion: Callable,
booster: Booster,
):
get_accelerator().get_current_device()
org_model.cuda()
sharded_model.cuda()
def _criterion(outputs, inputs):
outputs = output_transform_fn(outputs)
loss = criterion(outputs)
return loss
data = data_gen_fn()
# data = {
# k: v.to(device) if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()
# }
data = {k: v.cuda() for k, v in data.items()}
sharded_model.train()
sharded_output = sharded_model(**data)
sharded_loss = criterion(sharded_output)
sharded_optimizer.backward(sharded_loss)
org_model.train()
org_output = org_model(**data)
org_loss = criterion(org_output)
org_loss.backward()
return org_loss, org_output, sharded_loss, sharded_output
def check_output_hidden_state(
org_output: Tensor,
sharded_output: Tensor,
@ -312,6 +398,9 @@ def check_grad(
org_grad = getattr_(org_model, suffix).weight.grad
shard_grad = getattr_(sharded_model, suffix).weight.grad
shard_weight = getattr_(sharded_model, suffix).weight
# if verbose and dist.get_rank() == 0:
# print("shard_weight", shard_weight)
# print("org_grad", org_grad)
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
shard_grad_list = [torch.zeros_like(shard_grad).to("cuda") for _ in range(dist.get_world_size(tp_group))]
dist.all_gather(shard_grad_list, shard_grad, tp_group)

View File

@ -64,7 +64,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
working_p = sharded_optimizer._param_store.master_to_working_param[id(p2)]
grads = sharded_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(working_p))
grad_index = 0 if sharded_optimizer._partition_grads else sharded_optimizer._local_rank
grad_index = (
0 if sharded_optimizer._grad_store._partition_grads else sharded_optimizer._bucket_store.zero_local_rank
)
grad = grads[grad_index]
sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()]
assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False)

View File

@ -0,0 +1,235 @@
import os
import pytest
import torch
import transformers
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
model_fn, loss_fn, test_config
)
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
)
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# unwrap model
qwen2_model = unwrap_model(org_model, "Qwen2Model", "model")
shard_qwen2_model = unwrap_model(sharded_model, "Qwen2Model", "model")
row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"]
col_layer_for_check = ["layers[0].self_attn.o_proj"]
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
grads_to_check = {}
if (stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True)) and booster.plugin.zero_stage == 0:
if test_config["precision"] == "fp32":
atol, rtol = 1e-6, 1e-4
else:
atol, rtol = 5e-3, 5e-3
row_layer_grads = get_grad_tensors_for_check(
qwen2_model, shard_qwen2_model, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False
)
col_layer_grads = get_grad_tensors_for_check(
qwen2_model, shard_qwen2_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage(ignore_chunk=True):
if test_config["precision"] == "fp32":
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == "Qwen2Model":
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# check weights
if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
if test_config["precision"] == "fp32":
atol, rtol = 1e-4, 1e-3
else:
atol, rtol = 5e-3, 5e-3
check_weight(
qwen2_model, shard_qwen2_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
)
# check grads
check_all_grad_tensors(grads_to_check)
torch.cuda.empty_cache()
@parameterize(
"test_config",
[
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 2,
"enable_all_optimization": True,
"use_lazy_init": True,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 4,
"use_lazy_init": False,
"precision": "fp32",
},
{
"tp_size": 4,
"pp_size": 1,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "fp32",
},
{
"tp_size": 1,
"pp_size": 4,
"num_microbatches": 4,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp32",
},
{"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"},
{
"tp_size": 2,
"pp_size": 1,
"enable_all_optimization": True,
"use_lazy_init": True,
"zero_stage": 2,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 2,
"enable_all_optimization": True,
"use_lazy_init": True,
"zero_stage": 1,
"precision": "fp16",
"initial_scale": 1,
},
],
)
def run_qwen2_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_qwen2")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
@parameterize(
"test_config",
[
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp32",
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp16",
"zero_stage": 1,
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 2,
"pp_style": "interleaved",
"num_model_chunks": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"precision": "fp16",
"zero_stage": 1,
"initial_scale": 1,
},
],
)
def run_qwen2_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_qwen2")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
def check_qwen2(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_qwen2_test()
def check_qwen2_3d(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_qwen2_3d_test()
@pytest.mark.skipif(transformers.__version__ < "4.39.1", reason="Requires transformers version 4.39.1 or later")
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_qwen2():
spawn(check_qwen2, 4)
@pytest.mark.skipif(transformers.__version__ < "4.39.1", reason="Requires transformers version 4.39.1 or later")
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_qwen2_3d():
spawn(check_qwen2_3d, 8)
if __name__ == "__main__":
test_qwen2()
test_qwen2_3d()

View File

@ -26,7 +26,7 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
chunk_manager = model.chunk_manager
param_list = [p for p in model.parameters()]
chunk_list = chunk_manager.get_chunks(param_list)
if not model.reuse_fp16_chunk:
if not model.chunk_manager.reuse_fp16_chunk:
chunk_list = [chunk.grad_chunk for chunk in chunk_list]
for chunk in chunk_list:
chunk_manager.access_chunk(chunk)