Merge pull request #6107 from duanjunwen/dev/zero_bubble

[Zerobubble] Merge Main.
feature/zerobubble
duanjunwen 2024-11-05 11:31:48 +08:00 committed by GitHub
commit 37b23e32b1
60 changed files with 1681 additions and 830 deletions


@ -15,21 +15,21 @@ repos:
args: ["--profile", "black"] # avoid conflict with black args: ["--profile", "black"] # avoid conflict with black
- repo: https://github.com/psf/black-pre-commit-mirror - repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.8.0 rev: 24.10.0
hooks: hooks:
- id: black - id: black
name: black formatter name: black formatter
args: ['--line-length=120', '--target-version=py37', '--target-version=py38', '--target-version=py39','--target-version=py310'] args: ['--line-length=120', '--target-version=py37', '--target-version=py38', '--target-version=py39','--target-version=py310']
- repo: https://github.com/pre-commit/mirrors-clang-format - repo: https://github.com/pre-commit/mirrors-clang-format
rev: v18.1.8 rev: v19.1.2
hooks: hooks:
- id: clang-format - id: clang-format
name: clang formatter name: clang formatter
types_or: [c++, c] types_or: [c++, c]
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0 rev: v5.0.0
hooks: hooks:
- id: check-yaml - id: check-yaml
- id: check-merge-conflict - id: check-merge-conflict


@ -25,16 +25,36 @@
</div> </div>
## GPU Cloud HPC-AI.COM Coming
For a limited time, you can access an H100 Server for just $1! This is your chance to leverage premium GPU power at an unbeatable price.
Plus, when you refer a friend, you'll receive 20% cashback or compute credits equal to 100% of their top-up!
Our platform offers on-demand premium compute, ensuring safe, permanent data storage even after stopping your instance.
Don't miss this incredible opportunity to accelerate your AI projects!
Unlock premium GPUs and register now at [HPC-AI.COM](https://hpc-ai.com) to receive $10!
Special Bonuses:
* Top up $1,000 and receive 300 credits
* Top up $500 and receive 100 credits
<div align="center">
<a href="https://youtu.be/ilMQpU71ddI?si=J4JSPzZ03ycYmlki">
<img src="https://github.com/hpcaitech/public_assets/blob/main/colossalai/img/HPCAICOM241010.jpg" width="700" />
</a>
</div>
## Latest News ## Latest News
* [2024/10] [How to build a low-cost Sora-like app? Solutions for you](https://company.hpc-ai.com/blog/how-to-build-a-low-cost-sora-like-app-solutions-for-you)
* [2024/09] [Singapore Startup HPC-AI Tech Secures 50 Million USD in Series A Funding to Build the Video Generation AI Model and GPU Platform](https://company.hpc-ai.com/blog/singapore-startup-hpc-ai-tech-secures-50-million-usd-in-series-a-funding-to-build-the-video-generation-ai-model-and-gpu-platform)
* [2024/09] [Reducing AI Large Model Training Costs by 30% Requires Just a Single Line of Code From FP8 Mixed Precision Training Upgrades](https://company.hpc-ai.com/blog/reducing-ai-large-model-training-costs-by-30-requires-just-a-single-line-of-code-from-fp8-mixed-precision-training-upgrades)
* [2024/06] [Open-Sora Continues Open Source: Generate Any 16-Second 720p HD Video with One Click, Model Weights Ready to Use](https://hpc-ai.com/blog/open-sora-from-hpc-ai-tech-team-continues-open-source-generate-any-16-second-720p-hd-video-with-one-click-model-weights-ready-to-use) * [2024/06] [Open-Sora Continues Open Source: Generate Any 16-Second 720p HD Video with One Click, Model Weights Ready to Use](https://hpc-ai.com/blog/open-sora-from-hpc-ai-tech-team-continues-open-source-generate-any-16-second-720p-hd-video-with-one-click-model-weights-ready-to-use)
* [2024/05] [Large AI Models Inference Speed Doubled, Colossal-Inference Open Source Release](https://hpc-ai.com/blog/colossal-inference) * [2024/05] [Large AI Models Inference Speed Doubled, Colossal-Inference Open Source Release](https://hpc-ai.com/blog/colossal-inference)
* [2024/04] [Open-Sora Unveils Major Upgrade: Embracing Open Source with Single-Shot 16-Second Video Generation and 720p Resolution](https://hpc-ai.com/blog/open-soras-comprehensive-upgrade-unveiled-embracing-16-second-video-generation-and-720p-resolution-in-open-source) * [2024/04] [Open-Sora Unveils Major Upgrade: Embracing Open Source with Single-Shot 16-Second Video Generation and 720p Resolution](https://hpc-ai.com/blog/open-soras-comprehensive-upgrade-unveiled-embracing-16-second-video-generation-and-720p-resolution-in-open-source)
* [2024/04] [Most cost-effective solutions for inference, fine-tuning and pretraining, tailored to LLaMA3 series](https://hpc-ai.com/blog/most-cost-effective-solutions-for-inference-fine-tuning-and-pretraining-tailored-to-llama3-series) * [2024/04] [Most cost-effective solutions for inference, fine-tuning and pretraining, tailored to LLaMA3 series](https://hpc-ai.com/blog/most-cost-effective-solutions-for-inference-fine-tuning-and-pretraining-tailored-to-llama3-series)
* [2024/03] [314 Billion Parameter Grok-1 Inference Accelerated by 3.8x, Efficient and Easy-to-Use PyTorch+HuggingFace version is Here](https://hpc-ai.com/blog/314-billion-parameter-grok-1-inference-accelerated-by-3.8x-efficient-and-easy-to-use-pytorchhuggingface-version-is-here)
* [2024/03] [Open-Sora: Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models](https://hpc-ai.com/blog/open-sora-v1.0)
* [2024/03] [Open-Sora: Sora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million](https://hpc-ai.com/blog/open-sora)
* [2024/01] [Inference Performance Improved by 46%, Open Source Solution Breaks the Length Limit of LLM for Multi-Round Conversations](https://hpc-ai.com/blog/Colossal-AI-SwiftInfer)
* [2023/07] [HPC-AI Tech Raises 22 Million USD in Series A Funding](https://www.hpc-ai.tech/blog/hpc-ai-tech-raises-22-million-usd-in-series-a-funding-to-fuel-team-expansion-and-business-growth)
## Table of Contents ## Table of Contents
<ul> <ul>


@ -27,11 +27,11 @@
- [Alternative Option For RLHF: SimPO](#alternative-option-for-rlhf-simple-preference-optimization-simpo) - [Alternative Option For RLHF: SimPO](#alternative-option-for-rlhf-simple-preference-optimization-simpo)
- [Alternative Option For RLHF: ORPO](#alternative-option-for-rlhf-odds-ratio-preference-optimization-orpo) - [Alternative Option For RLHF: ORPO](#alternative-option-for-rlhf-odds-ratio-preference-optimization-orpo)
- [Alternative Option For RLHF: KTO](#alternative-option-for-rlhf-kahneman-tversky-optimization-kto) - [Alternative Option For RLHF: KTO](#alternative-option-for-rlhf-kahneman-tversky-optimization-kto)
- [O1 Journey](#o1-journey)
- [Inference with Self-refined MCTS](#inference-with-self-refined-mcts)
- [FAQ](#faq) - [FAQ](#faq)
- [How to save/load checkpoint](#faq) - [How to save/load checkpoint](#faq)
- [How to train with limited resources](#faq) - [How to train with limited resources](#faq)
- [The Plan](#the-plan)
- [Real-time progress](#real-time-progress)
- [Invitation to open-source contribution](#invitation-to-open-source-contribution) - [Invitation to open-source contribution](#invitation-to-open-source-contribution)
- [Quick Preview](#quick-preview) - [Quick Preview](#quick-preview)
- [Authors](#authors) - [Authors](#authors)
@ -272,7 +272,7 @@ Odds Ratio Preference Optimization (ORPO) from this [paper](https://arxiv.org/pd
## Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO) ## Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO)
We support the method introduced in the paper [KTO: Model Alignment as Prospect Theoretic Optimization](https://arxiv.org/pdf/2402.01306) (KTO), an alignment method that directly maximizes the "human utility" of generation results. Read this [README](./examples/README.md) for more information.
### Inference Quantization and Serving - After Training ## Inference Quantization and Serving - After Training
We provide an online inference server and a benchmark. We aim to run inference on single GPU, so quantization is essential when using large models. We provide an online inference server and a benchmark. We aim to run inference on single GPU, so quantization is essential when using large models.
@ -281,6 +281,21 @@ We support 8-bit quantization (RTN), 4-bit quantization (GPTQ), and FP16 inferen
Online inference server scripts can help you deploy your own services. Online inference server scripts can help you deploy your own services.
For more details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/inference). For more details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/inference).
## O1 Journey
### Inference with Self-refined MCTS
We provide an implementation of the MCT Self-Refine (MCTSr) algorithm, an innovative integration of Large Language Models with Monte Carlo Tree Search.
To run inference with MCTS, simply use the following script.
```python
from coati.reasoner.guided_search.mcts import MCTS
from coati.reasoner.guided_search.prompt_store.qwen import Qwen32B_prompt_CFG
problem = "How Many R in 'Strawberry'"
search_tree = MCTS(problem=problem, max_simulations=8, cfg=Qwen32B_prompt_CFG)
answer = search_tree.simulate()
print(answer)
```
## Coati7B examples ## Coati7B examples
### Generation ### Generation


@ -153,10 +153,11 @@ class DpoLoss(nn.Module):
else: else:
# If no reference model is provided # If no reference model is provided
ref_logratios = 0.0 ref_logratios = 0.0
pi_logratios = logprob_actor_chosen.sum(-1) - logprob_actor_reject.sum(-1) pi_logratios = logprob_actor_chosen.sum(-1) - logprob_actor_reject.sum(-1)
logits = pi_logratios - ref_logratios - self.gamma / self.beta logits = pi_logratios - ref_logratios - self.gamma / self.beta
losses = -torch.nn.functional.logsigmoid(self.beta * logits) losses = -torch.nn.functional.logsigmoid(self.beta * logits)
loss = losses.mean()
# Calculate rewards for logging # Calculate rewards for logging
if logprob_ref_chosen is not None: if logprob_ref_chosen is not None:
chosen_rewards = self.beta * (logprob_actor_chosen.sum(-1) - logprob_ref_chosen.sum(-1)).detach() chosen_rewards = self.beta * (logprob_actor_chosen.sum(-1) - logprob_ref_chosen.sum(-1)).detach()
@ -167,7 +168,7 @@ class DpoLoss(nn.Module):
else: else:
rejected_rewards = self.beta * logprob_actor_reject.sum(-1).detach() rejected_rewards = self.beta * logprob_actor_reject.sum(-1).detach()
return losses, chosen_rewards, rejected_rewards return loss, chosen_rewards, rejected_rewards
class LogSigLoss(nn.Module): class LogSigLoss(nn.Module):
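For context, the loss path shown above now reduces the per-sample DPO losses to a scalar before returning. A minimal standalone sketch of that computation, using toy values and already-summed sequence log-probabilities (the real class also handles loss masks, length normalization, and reward logging):

```python
import torch
import torch.nn.functional as F

beta, gamma = 0.1, 0.0  # illustrative hyperparameters

# Toy per-sequence log-probabilities (already summed over tokens).
logprob_actor_chosen = torch.tensor([-12.0, -15.0])
logprob_actor_reject = torch.tensor([-14.0, -15.5])
logprob_ref_chosen = torch.tensor([-13.0, -15.2])
logprob_ref_reject = torch.tensor([-13.5, -15.4])

pi_logratios = logprob_actor_chosen - logprob_actor_reject
ref_logratios = logprob_ref_chosen - logprob_ref_reject
logits = pi_logratios - ref_logratios - gamma / beta
losses = -F.logsigmoid(beta * logits)  # per-sample DPO loss
loss = losses.mean()                   # scalar loss now returned by DpoLoss
print(loss)
```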


@ -50,8 +50,8 @@ def _log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.
torch.Tensor: The log probabilities corresponding to the labels. torch.Tensor: The log probabilities corresponding to the labels.
""" """
log_probs = F.log_softmax(logits, dim=-1) log_probs = F.log_softmax(logits, dim=-1)
log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)) per_label_logps = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
return log_probs_labels.squeeze(-1) return per_label_logps.squeeze(-1)
def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor: def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
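A small illustration of the gather pattern used in `_log_probs_from_logits` above (toy tensor shapes, independent of the trainer):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(2, 4, 10)          # (batch, seq_len, vocab_size)
labels = torch.randint(0, 10, (2, 4))   # (batch, seq_len)

log_probs = F.log_softmax(logits, dim=-1)
# Select the log-probability of each label token, then drop the gathered dim.
per_label_logps = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
print(per_label_logps.shape)            # torch.Size([2, 4])
```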


@ -0,0 +1,26 @@
import openai
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
API_KEY = "Dummy API Key"
def get_client(base_url: str | None = None) -> openai.Client:
return openai.Client(api_key=API_KEY, base_url=base_url)
def chat_completion(
messages: list[ChatCompletionMessageParam],
model: str,
base_url: str | None = None,
temperature: float = 0.8,
**kwargs,
) -> ChatCompletion:
client = get_client(base_url)
response = client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
**kwargs,
)
return response
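A hedged usage sketch for the helper above, assuming an OpenAI-compatible server is already running; the endpoint URL and model name below are placeholders, not values shipped with this PR:

```python
from coati.reasoner.guided_search.llm import chat_completion

response = chat_completion(
    messages=[{"role": "user", "content": "What is 2 + 2?"}],
    model="Qwen2.5-32B-Instruct",         # placeholder model name
    base_url="http://127.0.0.1:8000/v1",  # placeholder endpoint
)
print(response.choices[0].message.content)
```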


@ -0,0 +1,250 @@
"""
Implementation of MCTS + Self-refine algorithm.
Reference:
1. "Accessing GPT-4 level Mathematical Olympiad Solutions via Monte
Carlo Tree Self-refine with LLaMa-3 8B: A Technical Report"
2. https://github.com/BrendanGraham14/mcts-llm/
3. https://github.com/trotsky1997/MathBlackBox/
4. https://github.com/openreasoner/openr/blob/main/reason/guided_search/tree.py
"""
from __future__ import annotations
import math
from collections import deque
import numpy as np
import tqdm
from coati.reasoner.guided_search.llm import chat_completion
from coati.reasoner.guided_search.prompt_store.base import PromptCFG
from pydantic import BaseModel
class MCTSNode(BaseModel):
"""
Node for MCTS.
"""
answer: str
parent: MCTSNode = None
children: list[MCTSNode] = []
num_visits: int = 0
Q: int = 0
rewards: list[int] = []
def expand_node(self, node) -> None:
self.children.append(node)
def add_reward(self, reward: int) -> None:
self.rewards.append(reward)
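# Q blends the worst and the average reward observed so far, giving a pessimistic value estimate.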
self.Q = (np.min(self.rewards) + np.mean(self.rewards)) / 2
class MCTS(BaseModel):
"""
Simulation of MCTS process.
"""
problem: str
max_simulations: int
cfg: PromptCFG
C: float = 1.4
max_children: int = 2
epsilon: float = 1e-5
root: MCTSNode = None
def initialization(self):
"""
Root Initiation.
"""
# Dummy answer as root.
base_answer = self.sample_base_answer()
self.root = MCTSNode(answer=base_answer)
self.self_evaluate(self.root)
def is_fully_expanded(self, node: MCTSNode):
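# A node counts as fully expanded once it has max_children children or some child already beats its Q.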
return len(node.children) >= self.max_children or any(child.Q > node.Q for child in node.children)
def select_node(self) -> MCTSNode:
"""
Select next node to explore.
"""
candidates: list[MCTSNode] = []
to_explore = deque([self.root])
while to_explore:
current_node = to_explore.popleft()
if not self.is_fully_expanded(current_node):
candidates.append(current_node)
to_explore.extend(current_node.children)
if not candidates:
return self.root
return max(candidates, key=self.compute_uct)
def self_evaluate(self, node: MCTSNode):
"""
Sample reward of the answer.
"""
reward = self.sample_reward(node)
node.add_reward(reward)
def back_propagation(self, node: MCTSNode):
"""
Back propagate the value of the refined answer.
"""
parent = node.parent
while parent:
best_child_Q = max(child.Q for child in parent.children)
parent.Q = (parent.Q + best_child_Q) / 2
parent.num_visits += 1
parent = parent.parent
def compute_uct(self, node: MCTSNode):
"""
Compute UCT.
"""
if node.parent is None:
return -100
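# Standard UCT: exploitation term Q plus an exploration bonus; epsilon avoids division by zero for unvisited nodes.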
return node.Q + self.C * math.sqrt(math.log(node.parent.num_visits + 1) / (node.num_visits + self.epsilon))
def simulate(self):
self.initialization()
for _ in tqdm.tqdm(range(self.max_simulations)):
node = self.select_node()
child = self.self_refine(node)
node.expand_node(child)
self.self_evaluate(child)
self.back_propagation(child)
return self.get_best_answer()
def get_best_answer(self):
to_visit = deque([self.root])
best_node = self.root
while to_visit:
current_node = to_visit.popleft()
if current_node.Q > best_node.Q:
best_node = current_node
to_visit.extend(current_node.children)
return best_node.answer
def self_refine(self, node: MCTSNode):
"""
Refine node.
"""
critique_response = chat_completion(
messages=[
{
"role": "system",
"content": self.cfg.critic_system_prompt,
},
{
"role": "user",
"content": "\n\n".join(
[
f"<problem>\n{self.problem}\n</problem>",
f"<current_answer>\n{node.answer}\n</current_answer>",
]
),
},
],
model=self.cfg.model,
base_url=self.cfg.base_url,
max_tokens=self.cfg.max_tokens,
)
critique = critique_response.choices[0].message.content
assert critique is not None
refined_answer_response = chat_completion(
messages=[
{
"role": "system",
"content": self.cfg.refine_system_prompt,
},
{
"role": "user",
"content": "\n\n".join(
[
f"<problem>\n{self.problem}\n</problem>",
f"<current_answer>\n{node.answer}\n</current_answer>",
f"<critique>\n{critique}\n</critique>",
]
),
},
],
model=self.cfg.model,
base_url=self.cfg.base_url,
max_tokens=self.cfg.max_tokens,
)
refined_answer = refined_answer_response.choices[0].message.content
assert refined_answer is not None
return MCTSNode(answer=refined_answer, parent=node)
def sample_base_answer(self):
response = chat_completion(
messages=[
{
"role": "system",
"content": "The user will provide a problem. Solve the problem. The response should begin with [reasoning process]...[Verification]... and end with [Final Answer]. \nThe answer is [answer] \n#### [answer].",
},
{
"role": "user",
"content": f"<problem>\n {self.problem} \n</problem> \nLet's think step by step",
},
],
model=self.cfg.model,
base_url=self.cfg.base_url,
max_tokens=self.cfg.max_tokens,
)
assert response.choices[0].message.content is not None
return response.choices[0].message.content
def sample_reward(self, node: MCTSNode):
"""
Calculate reward.
"""
messages = [
{
"role": "system",
"content": self.cfg.evaluate_system_prompt,
},
{
"role": "user",
"content": "\n\n".join(
[
f"<problem>\n{self.problem}\n</problem>",
f"<answer>\n{node.answer}\n</answer>",
]
),
},
]
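# Query the evaluator up to three times; if the reply cannot be parsed as an integer, feed it back and ask again.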
for attempt in range(3):
try:
response = chat_completion(
messages=messages,
model=self.cfg.model,
base_url=self.cfg.base_url,
max_tokens=self.cfg.max_tokens,
)
assert response.choices[0].message.content is not None
return int(response.choices[0].message.content)
except ValueError:
messages.extend(
[
{
"role": "assistant",
"content": response.choices[0].message.content,
},
{
"role": "user",
"content": "Failed to parse reward as an integer.",
},
]
)
if attempt == 2:
raise


@ -0,0 +1,10 @@
from pydantic import BaseModel
class PromptCFG(BaseModel):
model: str
base_url: str
max_tokens: int = 4096
critic_system_prompt: str
refine_system_prompt: str
evaluate_system_prompt: str


@ -0,0 +1,20 @@
"""
Prompts for Qwen Series.
"""
from coati.reasoner.guided_search.prompt_store.base import PromptCFG
Qwen32B_prompt_CFG = PromptCFG(
base_url="http://0.0.0.0:8008/v1",
model="Qwen2.5-32B-Instruct",
critic_system_prompt="Provide a detailed and constructive critique to improve the answer. "
"Highlight specific areas that need refinement or correction.",
refine_system_prompt="""# Instruction
Refine the answer based on the critique. The response should begin with [reasoning process]...[Verification]... and end with [Final Answer].
""",
evaluate_system_prompt=(
"Analyze this answer strictly and critic, provide a reward score between -100 and 100 for the answer quality, using very strict standards. "
"Do not give a full score above 95. Make sure the reward score is an integer. "
"Return *ONLY* the score."
),
)
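The configuration above targets a locally served Qwen2.5-32B-Instruct. Other OpenAI-compatible models can be wired in by defining another `PromptCFG`; the sketch below is illustrative only, and the model name and endpoint are placeholders:

```python
from coati.reasoner.guided_search.prompt_store.base import PromptCFG

# Hypothetical config for a different locally served instruct model.
Custom_prompt_CFG = PromptCFG(
    base_url="http://127.0.0.1:8000/v1",  # placeholder endpoint
    model="my-instruct-model",            # placeholder model name
    max_tokens=4096,
    critic_system_prompt="Provide a detailed and constructive critique to improve the answer.",
    refine_system_prompt="Refine the answer based on the critique. End with [Final Answer].",
    evaluate_system_prompt=(
        "Provide an integer reward score between -100 and 100 for the answer quality, "
        "using very strict standards. Return *ONLY* the score."
    ),
)
```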


@ -6,6 +6,7 @@ import os
from typing import Any, Optional from typing import Any, Optional
import torch import torch
import torch.distributed as dist
from coati.models.loss import DpoLoss from coati.models.loss import DpoLoss
from coati.models.utils import calc_masked_log_probs from coati.models.utils import calc_masked_log_probs
from coati.trainer.utils import all_reduce_mean from coati.trainer.utils import all_reduce_mean
@ -13,10 +14,11 @@ from coati.utils import AccumulativeMeanMeter, save_checkpoint
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from tqdm import trange from tqdm import tqdm, trange
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from colossalai.booster import Booster, Plugin from colossalai.booster import Booster, Plugin
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.cluster import DistCoordinator from colossalai.cluster import DistCoordinator
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
@ -96,18 +98,25 @@ class DPOTrainer(SLTrainer):
self.train_dataloader = train_preference_dataloader self.train_dataloader = train_preference_dataloader
self.eval_dataloader = eval_preference_dataloader self.eval_dataloader = eval_preference_dataloader
self.writer = None self.writer = None
if use_wandb and is_rank_0():
init_criterion = (
dist.get_rank() == dist.get_world_size() - 1
if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1
else is_rank_0()
)
if use_wandb and init_criterion:
assert log_dir is not None, "log_dir must be provided when use_wandb is True" assert log_dir is not None, "log_dir must be provided when use_wandb is True"
import wandb import wandb
self.wandb_run = wandb.init(project="Coati-dpo", sync_tensorboard=True) self.wandb_run = wandb.init(project="Coati-dpo", sync_tensorboard=True)
if log_dir is not None and is_rank_0(): if log_dir is not None and init_criterion:
import os import os
import time import time
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
log_dir = os.path.join(log_dir, "dpo") log_dir = os.path.join(log_dir, "DPO")
log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())) log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
self.writer = SummaryWriter(log_dir=log_dir) self.writer = SummaryWriter(log_dir=log_dir)
@ -117,166 +126,147 @@ class DPOTrainer(SLTrainer):
epoch int: the number of current epoch epoch int: the number of current epoch
""" """
self.model.train() self.model.train()
self.accumulative_meter.reset() if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1:
step_bar = trange( step_bar = tqdm(
len(self.train_dataloader) // self.accumulation_steps, range(len(self.train_dataloader)),
desc=f"Epoch {epoch + 1}/{self.max_epochs}", desc="Step",
disable=not is_rank_0(), disable=not (dist.get_rank() == dist.get_world_size() - 1),
)
for i, batch in enumerate(self.train_dataloader):
batch = to_device(batch, self.device)
(
chosen_input_ids,
chosen_attention_mask,
chosen_loss_mask,
reject_input_ids,
reject_attention_mask,
reject_loss_mask,
) = (
batch["chosen_input_ids"],
batch["chosen_attention_mask"],
batch["chosen_loss_mask"],
batch["reject_input_ids"],
batch["reject_attention_mask"],
batch["reject_loss_mask"],
) )
if not self.apply_loss_mask: for i, batch in enumerate(self.train_dataloader):
chosen_loss_mask = chosen_loss_mask.fill_(1.0) batch = to_device(batch, self.device)
reject_loss_mask = reject_loss_mask.fill_(1.0) (
chosen_input_ids,
chosen_attention_mask,
chosen_loss_mask,
reject_input_ids,
reject_attention_mask,
reject_loss_mask,
) = (
batch["chosen_input_ids"],
batch["chosen_attention_mask"],
batch["chosen_loss_mask"],
batch["reject_input_ids"],
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)
batch_size = chosen_input_ids.size()[0]
# Calculate logits from reference model.
if self.ref_model is not None:
self.ref_model.eval()
with torch.no_grad():
ref_all_logits = self.ref_model(
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
)["logits"]
ref_chosen_logits = ref_all_logits[:batch_size]
ref_reject_logits = ref_all_logits[batch_size:]
logprob_ref_chosen = calc_masked_log_probs(
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
)
logprob_ref_reject = calc_masked_log_probs(
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
)
else:
logprob_ref_chosen = None
logprob_ref_reject = None
batch_size = chosen_input_ids.size()[0] # Merge chosen and reject
inputs_ids = torch.stack([item for tup in zip(chosen_input_ids, reject_input_ids) for item in tup])
attention_mask = torch.stack(
[item for tup in zip(chosen_attention_mask, reject_attention_mask) for item in tup]
)
loss_mask = torch.stack([item for tup in zip(chosen_loss_mask, reject_loss_mask) for item in tup])
logprob_ref = torch.stack([item for tup in zip(logprob_ref_chosen, logprob_ref_reject) for item in tup])
actor_all_logits = self.model( data_iter = iter(
input_ids=torch.cat([chosen_input_ids, reject_input_ids]), [
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]), {
)["logits"] "input_ids": inputs_ids,
actor_chosen_logits = actor_all_logits[:batch_size] "attention_mask": attention_mask,
actor_reject_logits = actor_all_logits[batch_size:] "loss_mask": loss_mask,
logprob_actor_chosen = calc_masked_log_probs( "logprob_ref": logprob_ref,
actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization }
) ]
)
rewards = []
logprob_actor_reject = calc_masked_log_probs( def _criterion(outputs, inputs):
actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization loss, chosen_rewards, rejected_rewards = self.actor_loss_fn(
) calc_masked_log_probs(
outputs["logits"][0::2],
if self.ref_model is not None: inputs["input_ids"][0::2],
self.ref_model.eval() inputs["loss_mask"][0::2][:, 1:],
with torch.no_grad(): self.length_normalization,
ref_all_logits = self.ref_model( ),
input_ids=torch.cat([chosen_input_ids, reject_input_ids]), calc_masked_log_probs(
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]), outputs["logits"][1::2],
)["logits"] inputs["input_ids"][1::2],
ref_chosen_logits = ref_all_logits[:batch_size] inputs["loss_mask"][1::2][:, 1:],
ref_reject_logits = ref_all_logits[batch_size:] self.length_normalization,
logprob_ref_chosen = calc_masked_log_probs( ),
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization inputs["logprob_ref"][0::2] if inputs["logprob_ref"] is not None else None,
inputs["logprob_ref"][1::2] if inputs["logprob_ref"] is not None else None,
inputs["loss_mask"][0::2][:, 1:],
inputs["loss_mask"][1::2][:, 1:],
) )
logprob_ref_reject = calc_masked_log_probs( rewards.append(chosen_rewards)
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization rewards.append(rejected_rewards)
) return loss
else:
logprob_ref_chosen = None
logprob_ref_reject = None
losses, chosen_rewards, rejected_rewards = self.actor_loss_fn( outputs = self.booster.execute_pipeline(
logprob_actor_chosen, data_iter,
logprob_actor_reject, self.model,
logprob_ref_chosen if logprob_ref_chosen is not None else None, criterion=_criterion,
logprob_ref_reject if logprob_ref_reject is not None else None, optimizer=self.optimizer,
chosen_loss_mask[:, 1:], return_loss=True,
reject_loss_mask[:, 1:], )
) loss = outputs["loss"]
reward_accuracies = (chosen_rewards > rejected_rewards).float().mean() if self.booster.plugin.stage_manager.is_last_stage():
chosen_rewards, rejected_rewards = rewards[0], rewards[1]
global_loss = all_reduce_mean(loss, self.plugin)
if dist.get_rank() == dist.get_world_size() - 1:
step_bar.set_postfix(
{
"train/loss": global_loss.item(),
"train/lr": self.actor_scheduler.get_last_lr()[0],
"train/chosen_rewards": chosen_rewards.to(torch.float16).mean().item(),
"train/rejected_rewards": rejected_rewards.to(torch.float16).mean().item(),
}
)
step_bar.update()
self.accumulative_meter.add("loss", global_loss.item())
self.accumulative_meter.add("chosen_rewards", chosen_rewards.to(torch.float16).mean().item())
self.accumulative_meter.add(
"rejected_rewards", rejected_rewards.to(torch.float16).mean().item()
)
if self.writer is not None:
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), i)
self.writer.add_scalar(
"train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), i
)
self.writer.add_scalar(
"train/rejected_rewards",
self.accumulative_meter.get("rejected_rewards"),
i,
)
self.writer.add_scalar(
"train/margin",
self.accumulative_meter.get("chosen_rewards")
- self.accumulative_meter.get("rejected_rewards"),
i,
)
# DPO Loss
loss = losses.mean()
self.booster.backward(loss=loss, optimizer=self.optimizer)
if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1:
self.optimizer.step() self.optimizer.step()
self.optimizer.zero_grad() self.optimizer.zero_grad()
self.actor_scheduler.step() self.actor_scheduler.step()
else:
# sync self.accumulative_meter.reset()
loss_mean = all_reduce_mean(tensor=loss) step_bar = trange(
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards) len(self.train_dataloader) // self.accumulation_steps,
rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards) desc=f"Epoch {epoch + 1}/{self.max_epochs}",
reward_accuracies_mean = all_reduce_mean(tensor=reward_accuracies) disable=not is_rank_0(),
self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item()) )
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item()) for i, batch in enumerate(self.train_dataloader):
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
if i % self.accumulation_steps == self.accumulation_steps - 1:
self.num_train_step += 1
step_bar.update()
# logging
if self.writer and is_rank_0():
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
self.writer.add_scalar(
"train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
)
self.writer.add_scalar(
"train/rejected_rewards",
self.accumulative_meter.get("rejected_rewards"),
self.num_train_step,
)
self.writer.add_scalar(
"train/margin",
self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"),
self.num_train_step,
)
self.writer.add_scalar(
"train/accuracy",
self.accumulative_meter.get("accuracy"),
self.num_train_step,
)
self.accumulative_meter.reset()
if self.save_dir is not None and (self.num_train_step + 1) % self.save_interval == 0:
# save checkpoint
self.coordinator.print_on_master("\nStart saving model checkpoint with running states")
save_checkpoint(
save_dir=self.save_dir,
booster=self.booster,
model=self.model,
optimizer=self.optimizer,
lr_scheduler=self.actor_scheduler,
epoch=epoch,
step=i + 1,
batch_size=batch_size,
coordinator=self.coordinator,
)
self.coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}"
)
step_bar.close()
def _eval(self, epoch: int):
"""
Args:
epoch int: the number of current epoch
"""
if self.eval_dataloader is None:
self.coordinator.print_on_master("No eval dataloader is provided, skip evaluation")
return
self.model.eval()
self.ref_model.eval()
self.coordinator.print_on_master("\nStart evaluation...")
step_bar = trange(
len(self.eval_dataloader),
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
disable=not is_rank_0(),
)
self.accumulative_meter.reset()
with torch.no_grad():
for i, batch in enumerate(self.eval_dataloader):
batch = to_device(batch, self.device) batch = to_device(batch, self.device)
( (
chosen_input_ids, chosen_input_ids,
@ -300,12 +290,11 @@ class DPOTrainer(SLTrainer):
batch_size = chosen_input_ids.size()[0] batch_size = chosen_input_ids.size()[0]
actor_all_logits = self.model( actor_all_logits = self.model(
torch.cat([chosen_input_ids, reject_input_ids]), input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
torch.cat([chosen_attention_mask, reject_attention_mask]), attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
)["logits"] )["logits"]
actor_chosen_logits = actor_all_logits[:batch_size] actor_chosen_logits = actor_all_logits[:batch_size]
actor_reject_logits = actor_all_logits[batch_size:] actor_reject_logits = actor_all_logits[batch_size:]
logprob_actor_chosen = calc_masked_log_probs( logprob_actor_chosen = calc_masked_log_probs(
actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
) )
@ -314,22 +303,26 @@ class DPOTrainer(SLTrainer):
actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
) )
self.ref_model.eval() if self.ref_model is not None:
self.ref_model.eval()
with torch.no_grad():
ref_all_logits = self.ref_model(
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
)["logits"]
ref_chosen_logits = ref_all_logits[:batch_size]
ref_reject_logits = ref_all_logits[batch_size:]
logprob_ref_chosen = calc_masked_log_probs(
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
)
logprob_ref_reject = calc_masked_log_probs(
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
)
else:
logprob_ref_chosen = None
logprob_ref_reject = None
ref_all_logits = self.ref_model( loss, chosen_rewards, rejected_rewards = self.actor_loss_fn(
torch.cat([chosen_input_ids, reject_input_ids]),
torch.cat([chosen_attention_mask, reject_attention_mask]),
)["logits"]
ref_chosen_logits = ref_all_logits[:batch_size]
ref_reject_logits = ref_all_logits[batch_size:]
logprob_ref_chosen = calc_masked_log_probs(
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
)
logprob_ref_reject = calc_masked_log_probs(
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
)
losses, chosen_rewards, rejected_rewards = self.actor_loss_fn(
logprob_actor_chosen, logprob_actor_chosen,
logprob_actor_reject, logprob_actor_reject,
logprob_ref_chosen if logprob_ref_chosen is not None else None, logprob_ref_chosen if logprob_ref_chosen is not None else None,
@ -338,7 +331,9 @@ class DPOTrainer(SLTrainer):
reject_loss_mask[:, 1:], reject_loss_mask[:, 1:],
) )
reward_accuracies = (chosen_rewards > rejected_rewards).float().mean() reward_accuracies = (chosen_rewards > rejected_rewards).float().mean()
loss = losses.mean()
self.booster.backward(loss=loss, optimizer=self.optimizer)
# sync
loss_mean = all_reduce_mean(tensor=loss) loss_mean = all_reduce_mean(tensor=loss)
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards) chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards) rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards)
@ -347,16 +342,301 @@ class DPOTrainer(SLTrainer):
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item()) self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item()) self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item()) self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
self.accumulative_meter.add(
"margin", (chosen_rewards_mean - rejected_rewards_mean).to(torch.float16).mean().item()
)
step_bar.update()
msg = "Evaluation Result:\n" if (i + 1) % self.accumulation_steps == 0:
for tag in ["loss", "chosen_rewards", "rejected_rewards", "accuracy", "margin"]: self.optimizer.step()
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n" self.optimizer.zero_grad()
self.coordinator.print_on_master(msg) self.actor_scheduler.step()
os.makedirs(self.save_dir, exist_ok=True)
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f: step_bar.set_postfix(
f.write(msg) {
"train/loss": self.accumulative_meter.get("loss"),
"train/chosen_rewards": self.accumulative_meter.get("chosen_rewards"),
"train/rejected_rewards": self.accumulative_meter.get("rejected_rewards"),
"train/accuracy": self.accumulative_meter.get("accuracy"),
}
)
step_bar.update()
if self.writer and is_rank_0():
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
self.writer.add_scalar(
"train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
)
self.writer.add_scalar(
"train/rejected_rewards",
self.accumulative_meter.get("rejected_rewards"),
self.num_train_step,
)
self.writer.add_scalar(
"train/margin",
self.accumulative_meter.get("chosen_rewards")
- self.accumulative_meter.get("rejected_rewards"),
self.num_train_step,
)
self.writer.add_scalar(
"train/accuracy",
self.accumulative_meter.get("accuracy"),
self.num_train_step,
)
self.num_train_step += 1
self.accumulative_meter.reset()
if self.save_dir is not None and self.num_train_step > 0 and self.num_train_step % self.save_interval == 0:
# save checkpoint
self.coordinator.print_on_master("\nStart saving model checkpoint with running states")
save_checkpoint(
save_dir=self.save_dir,
booster=self.booster,
model=self.model,
optimizer=self.optimizer,
lr_scheduler=self.actor_scheduler,
epoch=epoch,
step=self.num_train_step,
batch_size=batch_size,
coordinator=self.coordinator,
)
self.coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}"
)
step_bar.close()
def _eval(self, epoch: int):
"""
Args:
epoch int: the number of current epoch
"""
if self.eval_dataloader is None:
self.coordinator.print_on_master("No eval dataloader is provided, skip evaluation")
return
self.model.eval()
self.ref_model.eval()
self.accumulative_meter.reset()
self.coordinator.print_on_master("\nStart evaluation...")
if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1:
step_bar = tqdm(
range(len(self.eval_dataloader)),
desc="Step",
disable=not (dist.get_rank() == dist.get_world_size() - 1),
)
with torch.no_grad():
for _, batch in enumerate(self.eval_dataloader):
batch = to_device(batch, self.device)
(
chosen_input_ids,
chosen_attention_mask,
chosen_loss_mask,
reject_input_ids,
reject_attention_mask,
reject_loss_mask,
) = (
batch["chosen_input_ids"],
batch["chosen_attention_mask"],
batch["chosen_loss_mask"],
batch["reject_input_ids"],
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)
batch_size = chosen_input_ids.size()[0]
# Calculate logits from reference model.
if self.ref_model is not None:
self.ref_model.eval()
with torch.no_grad():
ref_all_logits = self.ref_model(
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
)["logits"]
ref_chosen_logits = ref_all_logits[:batch_size]
ref_reject_logits = ref_all_logits[batch_size:]
logprob_ref_chosen = calc_masked_log_probs(
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
)
logprob_ref_reject = calc_masked_log_probs(
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
)
else:
logprob_ref_chosen = None
logprob_ref_reject = None
# Merge chosen and reject
inputs_ids = torch.stack([item for tup in zip(chosen_input_ids, reject_input_ids) for item in tup])
attention_mask = torch.stack(
[item for tup in zip(chosen_attention_mask, reject_attention_mask) for item in tup]
)
loss_mask = torch.stack([item for tup in zip(chosen_loss_mask, reject_loss_mask) for item in tup])
logprob_ref = torch.stack(
[item for tup in zip(logprob_ref_chosen, logprob_ref_reject) for item in tup]
)
data_iter = iter(
[
{
"input_ids": inputs_ids,
"attention_mask": attention_mask,
"loss_mask": loss_mask,
"logprob_ref": logprob_ref,
}
]
)
rewards = []
def _criterion(outputs, inputs):
loss, chosen_rewards, rejected_rewards = self.actor_loss_fn(
calc_masked_log_probs(
outputs["logits"][0::2],
inputs["input_ids"][0::2],
inputs["loss_mask"][0::2][:, 1:],
self.length_normalization,
),
calc_masked_log_probs(
outputs["logits"][1::2],
inputs["input_ids"][1::2],
inputs["loss_mask"][1::2][:, 1:],
self.length_normalization,
),
inputs["logprob_ref"][0::2] if inputs["logprob_ref"] is not None else None,
inputs["logprob_ref"][1::2] if inputs["logprob_ref"] is not None else None,
inputs["loss_mask"][0::2][:, 1:],
inputs["loss_mask"][1::2][:, 1:],
)
rewards.append(chosen_rewards)
rewards.append(rejected_rewards)
return loss
outputs = self.booster.execute_pipeline(
data_iter,
self.model,
criterion=_criterion,
optimizer=self.optimizer,
return_loss=True,
)
loss = outputs["loss"]
if self.booster.plugin.stage_manager.is_last_stage():
chosen_rewards, rejected_rewards = rewards[0], rewards[1]
global_loss = all_reduce_mean(loss, self.plugin)
chosen_rewards_mean = all_reduce_mean(chosen_rewards, self.plugin)
rejected_rewards_mean = all_reduce_mean(rejected_rewards, self.plugin)
if dist.get_rank() == dist.get_world_size() - 1:
step_bar.set_postfix(
{
"eval/loss": global_loss.item(),
"eval/lr": self.actor_scheduler.get_last_lr()[0],
"eval/chosen_rewards": chosen_rewards.to(torch.float16).mean().item(),
"eval/rejected_rewards": rejected_rewards.to(torch.float16).mean().item(),
}
)
self.accumulative_meter.add(
"chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item()
)
self.accumulative_meter.add(
"rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item()
)
self.accumulative_meter.add("loss", global_loss.to(torch.float16).item())
step_bar.update()
if self.booster.plugin.stage_manager.is_last_stage():
msg = "\nEvaluation Result:\n"
for tag in ["loss", "chosen_rewards", "rejected_rewards"]:
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
if dist.get_rank() == dist.get_world_size() - 1:
print(msg)
else:
step_bar = trange(
len(self.eval_dataloader),
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
disable=not is_rank_0(),
)
with torch.no_grad():
for i, batch in enumerate(self.eval_dataloader):
batch = to_device(batch, self.device)
(
chosen_input_ids,
chosen_attention_mask,
chosen_loss_mask,
reject_input_ids,
reject_attention_mask,
reject_loss_mask,
) = (
batch["chosen_input_ids"],
batch["chosen_attention_mask"],
batch["chosen_loss_mask"],
batch["reject_input_ids"],
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)
if not self.apply_loss_mask:
chosen_loss_mask = chosen_loss_mask.fill_(1.0)
reject_loss_mask = reject_loss_mask.fill_(1.0)
batch_size = chosen_input_ids.size()[0]
actor_all_logits = self.model(
torch.cat([chosen_input_ids, reject_input_ids]),
torch.cat([chosen_attention_mask, reject_attention_mask]),
)["logits"]
actor_chosen_logits = actor_all_logits[:batch_size]
actor_reject_logits = actor_all_logits[batch_size:]
logprob_actor_chosen = calc_masked_log_probs(
actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
)
logprob_actor_reject = calc_masked_log_probs(
actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
)
ref_all_logits = self.ref_model(
torch.cat([chosen_input_ids, reject_input_ids]),
torch.cat([chosen_attention_mask, reject_attention_mask]),
)["logits"]
ref_chosen_logits = ref_all_logits[:batch_size]
ref_reject_logits = ref_all_logits[batch_size:]
logprob_ref_chosen = calc_masked_log_probs(
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
)
logprob_ref_reject = calc_masked_log_probs(
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
)
losses, chosen_rewards, rejected_rewards = self.actor_loss_fn(
logprob_actor_chosen,
logprob_actor_reject,
logprob_ref_chosen if logprob_ref_chosen is not None else None,
logprob_ref_reject if logprob_ref_reject is not None else None,
chosen_loss_mask[:, 1:],
reject_loss_mask[:, 1:],
)
reward_accuracies = (chosen_rewards > rejected_rewards).float().mean()
loss = losses.mean()
loss_mean = all_reduce_mean(tensor=loss)
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards)
reward_accuracies_mean = all_reduce_mean(tensor=reward_accuracies)
self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add(
"rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item()
)
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
self.accumulative_meter.add(
"margin", (chosen_rewards_mean - rejected_rewards_mean).to(torch.float16).mean().item()
)
step_bar.set_postfix(
{
"eval/loss": self.accumulative_meter.get("loss"),
"eval/chosen_rewards": self.accumulative_meter.get("chosen_rewards"),
"eval/rejected_rewards": self.accumulative_meter.get("rejected_rewards"),
"eval/accuracy": self.accumulative_meter.get("accuracy"),
}
)
step_bar.update()
msg = "\nEvaluation Result:\n"
for tag in ["loss", "chosen_rewards", "rejected_rewards", "accuracy", "margin"]:
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
self.coordinator.print_on_master(msg)
if self.save_dir is not None:
os.makedirs(self.save_dir, exist_ok=True)
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
f.write(msg)
step_bar.close() step_bar.close()
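The new pipeline-parallel branch interleaves chosen and rejected samples into a single batch so one forward pass serves both, then recovers them inside `_criterion` with stride-2 slicing. A minimal standalone sketch of that interleave/de-interleave pattern (tensor values are toy data):

```python
import torch

chosen = torch.tensor([[1, 1], [2, 2]])  # two "chosen" sequences
reject = torch.tensor([[9, 9], [8, 8]])  # two "rejected" sequences

# Interleave row-wise: chosen_0, reject_0, chosen_1, reject_1, ...
merged = torch.stack([row for pair in zip(chosen, reject) for row in pair])

# Recover the halves with stride-2 slicing, as _criterion does with outputs["logits"].
assert torch.equal(merged[0::2], chosen)
assert torch.equal(merged[1::2], reject)
```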


@ -73,8 +73,7 @@ def main():
"--conversation_template_config", "--conversation_template_config",
type=str, type=str,
default="conversation_template_config", default="conversation_template_config",
help="Path \ help="Path to save conversation template config files.",
to save conversation template config files.",
) )
parser.add_argument("--data_cache_dir", type=str, default="cache", help="Data cache directory") parser.add_argument("--data_cache_dir", type=str, default="cache", help="Data cache directory")
parser.add_argument( parser.add_argument(


@ -13,7 +13,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
import colossalai import colossalai
from colossalai.booster import Booster from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator from colossalai.cluster import DistCoordinator
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@ -29,8 +29,6 @@ def train(args):
# check lora compatibility # check lora compatibility
if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0: if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin") raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
# ============================== # ==============================
# Initialize Distributed Training # Initialize Distributed Training
@ -46,7 +44,7 @@ def train(args):
Default torch ddp plugin without any acceleration, for Default torch ddp plugin without any acceleration, for
debugging purpose acceleration, for debugging purpose debugging purpose acceleration, for debugging purpose
""" """
plugin = TorchDDPPlugin(find_unused_parameters=True) plugin = TorchDDPPlugin(find_unused_parameters=not args.grad_checkpoint)
elif args.plugin == "gemini": elif args.plugin == "gemini":
plugin = GeminiPlugin( plugin = GeminiPlugin(
precision=args.mixed_precision, precision=args.mixed_precision,
@ -56,14 +54,6 @@ def train(args):
enable_gradient_accumulation=True, enable_gradient_accumulation=True,
enable_flash_attention=args.use_flash_attn, enable_flash_attention=args.use_flash_attn,
) )
elif args.plugin == "gemini_auto":
plugin = GeminiPlugin(
precision=args.mixed_precision,
placement_policy="auto",
initial_scale=2**16,
max_norm=args.grad_clip,
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "zero2": elif args.plugin == "zero2":
plugin = LowLevelZeroPlugin( plugin = LowLevelZeroPlugin(
stage=2, stage=2,
@ -92,20 +82,24 @@ def train(args):
parallel_output=False, parallel_output=False,
max_norm=args.grad_clip, max_norm=args.grad_clip,
precision=args.mixed_precision, precision=args.mixed_precision,
microbatch_size=args.microbatch_size,
) )
else: else:
raise ValueError(f"Unknown plugin {args.plugin}") raise ValueError(f"Unknown plugin {args.plugin}")
booster = Booster(plugin=plugin) booster = Booster(plugin=plugin)
ref_booster = Booster(plugin=plugin)
# ====================================================== ref_plugin = HybridParallelPlugin(
# Initialize Model, Objective, Optimizer and LR Scheduler tp_size=args.ref_tp,
# ====================================================== pp_size=1,
# Temp Fix: Disable lazy init due to version conflict zero_stage=args.zero_stage,
# init_ctx = ( enable_flash_attention=args.use_flash_attn,
# LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext() cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
# ) parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
)
ref_booster = Booster(plugin=ref_plugin)
init_ctx = nullcontext() init_ctx = nullcontext()
with init_ctx: with init_ctx:
@ -130,6 +124,7 @@ def train(args):
ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain) ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain)
else: else:
ref_model = None ref_model = None
if args.lora_config is not None: if args.lora_config is not None:
model = convert_to_lora_module(model, lora_config=lora_config) model = convert_to_lora_module(model, lora_config=lora_config)
for name, module in model.named_modules(): for name, module in model.named_modules():
@ -139,7 +134,9 @@ def train(args):
disable_dropout(ref_model) disable_dropout(ref_model)
if args.grad_checkpoint: if args.grad_checkpoint:
# Note, for some models, lora may not be compatible with gradient checkpointing # Make sure gradient checkpointing can be activated.
model.train()
# Note, for some models, lora may not be compatible with gradient checkpointing.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully") coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
@ -169,7 +166,7 @@ def train(args):
adamw_mode=True, adamw_mode=True,
) )
# configure dataset # Configure dataset
coordinator.print_on_master(f"Load dataset: {args.dataset}") coordinator.print_on_master(f"Load dataset: {args.dataset}")
mode_map = {"train": "train", "valid": "validation", "test": "test"} mode_map = {"train": "train", "valid": "validation", "test": "test"}
train_dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train", mode_map=mode_map) train_dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train", mode_map=mode_map)
@ -213,14 +210,15 @@ def train(args):
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16 default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
torch.set_default_dtype(default_dtype) torch.set_default_dtype(default_dtype)
model, optim, _, train_dataloader, lr_scheduler = booster.boost( model, optim, _, train_dataloader, lr_scheduler = booster.boost(
model=model, model=model,
optimizer=optim, optimizer=optim,
lr_scheduler=lr_scheduler, lr_scheduler=lr_scheduler,
dataloader=train_dataloader, dataloader=train_dataloader,
) )
if ref_model is not None: ref_model, _, _, _, _ = ref_booster.boost(model=ref_model)
ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_dataloader)
torch.set_default_dtype(torch.float) torch.set_default_dtype(torch.float)
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB") coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
@ -312,7 +310,7 @@ if __name__ == "__main__":
"--plugin", "--plugin",
type=str, type=str,
default="gemini", default="gemini",
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"], choices=["gemini", "zero2", "zero2_cpu", "3d", "ddp"],
help="Choose which plugin to use", help="Choose which plugin to use",
) )
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value") parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
@ -342,22 +340,35 @@ if __name__ == "__main__":
parser.add_argument("--max_length", type=int, default=2048, help="Model max length") parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
parser.add_argument("--max_epochs", type=int, default=3) parser.add_argument("--max_epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=4) parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--disable_loss_mask", default=False, action="store_true")
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument("--accumulation_steps", type=int, default=1)
parser.add_argument("--log_dir", default=None, type=str)
parser.add_argument("--use_wandb", default=False, action="store_true")
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
parser.add_argument("--use_flash_attn", default=False, action="store_true")
parser.add_argument(
"--microbatch_size",
type=int,
default=2,
help="Micro batch size for PP training. To activate PP training for DPO-like algorithm, you must keep size even and the size should be equal or greater than 2.",
)
# Parameter for reference model
parser.add_argument( parser.add_argument(
"--disable_reference_model", "--disable_reference_model",
action="store_true", action="store_true",
default=False, default=False,
help="Disable the reference model (enabled by default)", help="Disable the reference model (enabled by default)",
) )
parser.add_argument("--disable_loss_mask", default=False, action="store_true") parser.add_argument(
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision") "--ref_tp",
parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path") type=int,
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints") default=1,
parser.add_argument("--lr", type=float, default=5e-6) help="TP size for reference model; used only when reference model is too large.",
parser.add_argument("--accumulation_steps", type=int, default=8) )
parser.add_argument("--log_dir", default=None, type=str)
parser.add_argument("--use_wandb", default=False, action="store_true")
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
parser.add_argument("--use_flash_attn", default=False, action="store_true")
args = parser.parse_args() args = parser.parse_args()
# fool proof hyperparameter setup # fool proof hyperparameter setup


@ -68,7 +68,7 @@ def train(args):
Default torch ddp plugin without any acceleration, for Default torch ddp plugin without any acceleration, for
debugging purpose acceleration, for debugging purpose debugging purpose acceleration, for debugging purpose
""" """
plugin = TorchDDPPlugin(find_unused_parameters=True if args.grad_checkpoint is False else False) plugin = TorchDDPPlugin(find_unused_parameters=not args.grad_checkpoint)
elif args.plugin == "gemini": elif args.plugin == "gemini":
plugin = GeminiPlugin( plugin = GeminiPlugin(
precision=args.mixed_precision, precision=args.mixed_precision,


@ -4,7 +4,7 @@ BASE_TEMP_DIR=$BASE_DIR/temp
EXAMPLES_DIR=$BASE_DIR/examples
TEST_DATA_DIR=$BASE_DIR/tests/test_data
DATA_SAVE_PATH=$BASE_TEMP_DIR/tests
- CONFIG_DIR=$BASE_DIR/config
+ CONFIG_DIR=$BASE_DIR/conversation_template
# MODELS=("colossal-llama2" "llama2" "mistral" "chatGLM2" "chatGLM3" "deepseek" "Yi" "baichuan") # for local test
MODELS=("colossal-llama2" "llama2" "chatGLM2" "chatGLM3" "deepseek" "Yi")
@ -39,23 +39,23 @@ get_pretrain() {
get_conversation_template_config() {
local model=$1
if [[ $model == "colossal-llama2" ]]; then
- echo "$CONFIG_DIR/conversation_template/colossal-llama2.json"
+ echo "$CONFIG_DIR/colossal-llama2.json"
elif [[ $model == "llama2" ]]; then
- echo "$CONFIG_DIR/conversation_template/llama2.json"
+ echo "$CONFIG_DIR/llama2.json"
elif [[ $model == "deepseek" ]]; then
- echo "$CONFIG_DIR/conversation_template/deepseek-ai_DeepSeek-V2-Lite.json"
+ echo "$CONFIG_DIR/deepseek-ai_DeepSeek-V2-Lite.json"
elif [[ $model == "mistral" ]]; then
- echo "$CONFIG_DIR/conversation_template/mistralai_Mixtral-8x7B-Instruct-v0.1.json"
+ echo "$CONFIG_DIR/mistralai_Mixtral-8x7B-Instruct-v0.1.json"
elif [[ $model == "chatGLM2" ]]; then
- echo "$CONFIG_DIR/conversation_template/THUDM_chatglm2-6b.json"
+ echo "$CONFIG_DIR/THUDM_chatglm2-6b.json"
elif [[ $model == "chatGLM3" ]]; then
- echo "$CONFIG_DIR/conversation_template/THUDM_chatglm3-6b.json"
+ echo "$CONFIG_DIR/THUDM_chatglm3-6b.json"
elif [[ $model == "phi" ]]; then
- echo "$CONFIG_DIR/conversation_template/microsoft_phi-2.json"
+ echo "$CONFIG_DIR/microsoft_phi-2.json"
elif [[ $model == "Yi" ]]; then
- echo "$CONFIG_DIR/conversation_template/01-ai_Yi-1.5-9B-Chat.json"
+ echo "$CONFIG_DIR/01-ai_Yi-1.5-9B-Chat.json"
elif [[ $model == "baichuan" ]]; then
- echo "$CONFIG_DIR/conversation_template/baichuan-inc_Baichuan2-13B-Chat.json"
+ echo "$CONFIG_DIR/baichuan-inc_Baichuan2-13B-Chat.json"
else
echo "Unknown model $model"
exit 1
@ -71,6 +71,7 @@ for model in ${MODELS[@]}; do
rm -rf $SAVE_DIR/arrow
pretrain=$(get_pretrain $model)
conversation_template_config=$(get_conversation_template_config $model)
+ echo $conversation_template_config
python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py --type sft --data_input_dirs $TEST_DATA_DIR/sft \
--tokenizer_dir $pretrain \
--conversation_template_config $conversation_template_config \

View File

@ -279,4 +279,4 @@ class CudaAccelerator(BaseAccelerator):
""" """
Return autocast function Return autocast function
""" """
return torch.cuda.amp.autocast(enabled=enabled, dtype=dtype, cache_enabled=cache_enabled) return torch.amp.autocast(device_type="cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
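torch.cuda.amp.autocast is deprecated in recent PyTorch releases in favour of the device-agnostic torch.amp.autocast; a minimal equivalent at user level (assumes a CUDA device is available):

import torch

# torch.amp.autocast(device_type=...) replaces the deprecated torch.cuda.amp.autocast(...)
# with identical semantics on CUDA.
a = torch.randn(64, 64, device="cuda")
b = torch.randn(64, 64, device="cuda")
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    c = a @ b  # runs in bf16 under autocast
assert c.dtype == torch.bfloat16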

View File

@ -322,7 +322,6 @@ class GeminiPlugin(DPPluginBase):
enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
- enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
verbose (bool, optional): verbose mode. Debug info including chunk search result will be printed. Defaults to False.
fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
@ -366,7 +365,6 @@ class GeminiPlugin(DPPluginBase):
enable_flash_attention: bool = False,
enable_sequence_parallelism: bool = False,
enable_jit_fused: bool = False,
- enable_sequence_overlap: bool = False,
enable_async_reduce: bool = True,
use_fp8: bool = False,
verbose: bool = False,
@ -428,7 +426,6 @@ class GeminiPlugin(DPPluginBase):
self.enable_flash_attention = enable_flash_attention
self.enable_sequence_parallelism = enable_sequence_parallelism if self.enable_tensor_parallelism else False
self.enable_jit_fused = enable_jit_fused
- self.enable_sequence_overlap = enable_sequence_overlap
self.verbose = verbose
self.tp_size = tp_size
@ -455,7 +452,6 @@ class GeminiPlugin(DPPluginBase):
enable_flash_attention=self.enable_flash_attention,
enable_jit_fused=self.enable_jit_fused,
enable_sequence_parallelism=self.enable_sequence_parallelism,
- enable_sequence_overlap=self.enable_sequence_overlap,
)
def __del__(self):

View File

@ -116,10 +116,15 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
super().__init__(module)
self.op_hooks = []
+ if use_fp8:
+ self.op_hooks.append(FP8Hook())
+ self.op_hooks = []
if use_fp8:
self.op_hooks.append(FP8Hook())
if overlap_allgather:
self.op_hooks.append(ZeroOpHook())
+ if use_fp8 or overlap_allgather:
+ self.op_hooks.append(ZeroOpHook())
if use_fp8 or overlap_allgather:
for p in module.parameters():
if p.requires_grad and type(p) is not ColoParameter:
@ -232,6 +237,9 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
def _hook_context(self):
return ColoParamOpHookManager.use_hooks(*self.op_hooks) if len(self.op_hooks) > 0 else nullcontext()
+ def _hook_context(self):
+ return ColoParamOpHookManager.use_hooks(*self.op_hooks) if len(self.op_hooks) > 0 else nullcontext()
def get_param_info(optim: Optimizer):
# Get a backup of necessary information of parameters for future use, which includes:
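The wrapper collects its parameter-op hooks once and only enters the hook manager when at least one hook is registered. A stand-in sketch of that pattern (FP8Hook, ZeroOpHook and ColoParamOpHookManager are the real classes; everything below is illustrative):

from contextlib import contextmanager, nullcontext

@contextmanager
def use_hooks(*hooks):
    # Stand-in for ColoParamOpHookManager.use_hooks: activate the hooks inside the block.
    yield hooks

def build_hook_context(use_fp8: bool, overlap_allgather: bool):
    op_hooks = []
    if use_fp8:
        op_hooks.append("fp8_hook")             # stands in for FP8Hook()
    if overlap_allgather:
        op_hooks.append("zero_allgather_hook")  # stands in for ZeroOpHook()
    # Mirrors _hook_context: skip the hook manager entirely when no hook is needed.
    return use_hooks(*op_hooks) if len(op_hooks) > 0 else nullcontext()

with build_hook_context(use_fp8=True, overlap_allgather=False):
    pass  # the wrapped module's forward/backward would run here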
@ -951,7 +959,6 @@ class HybridParallelPlugin(PipelinePluginBase):
enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
sequence_parallelism_mode (str): The Sequence parallelism mode. Can only be choosed from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather".
- enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True.
num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
@ -983,6 +990,8 @@ class HybridParallelPlugin(PipelinePluginBase):
make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose an faster kenel. Default to 64.
fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
+ fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
+ use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism
inner_ring_size (int, optional): The inner ring size of 2D Ring Attention when sp mode is "ring_attn".
It's advisable to not tune this (especially in single-node settings) and let it be heuristically set based on topology by default.
@ -1002,7 +1011,6 @@ class HybridParallelPlugin(PipelinePluginBase):
enable_jit_fused: bool = False,
enable_sequence_parallelism: bool = False,
sequence_parallelism_mode: str = None,
- enable_sequence_overlap: bool = False,
parallel_output: bool = True,
num_microbatches: Optional[int] = None,
microbatch_size: Optional[int] = None,
@ -1092,6 +1100,7 @@ class HybridParallelPlugin(PipelinePluginBase):
self.use_fp8 = use_fp8
if dp_outside:
self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
+ self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
if sequence_parallelism_mode == "ring_attn":
# Swap tp and sp since 2D Ring has better inter-node latency
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.sp_size, self.tp_size)
@ -1195,13 +1204,15 @@ class HybridParallelPlugin(PipelinePluginBase):
enable_jit_fused=self.enable_jit_fused,
enable_sequence_parallelism=enable_sequence_parallelism,
sequence_parallelism_mode=sequence_parallelism_mode,
- enable_sequence_overlap=enable_sequence_overlap,
parallel_output=parallel_output,
make_vocab_size_divisible_by=make_vocab_size_divisible_by,
gradient_checkpoint_config=gradient_checkpoint_config,
fp8_communication=fp8_communication,
inner_ring_size=inner_ring_size,
+ pg_mesh=self.pg_mesh,
+ sp_axis=self.sp_axis,
)
self.amp_config = dict(
initial_scale=initial_scale,
growth_factor=growth_factor,
@ -1293,6 +1304,7 @@ class HybridParallelPlugin(PipelinePluginBase):
self.dp_size == 1 and self.pp_size == 1
)
# sync gradients across DP * SP ranks
+ # sync gradients across DP * SP ranks
# Apply Hybrid ZeRO across DP * SP ranks
if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
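The plugin now always builds a 4-D ProcessGroupMesh and derives the sequence-parallel and gradient-sync groups from it. A hedged usage sketch, assuming torch.distributed is already initialized (e.g. under torchrun) and using only the mesh calls visible above:

# Sketch only; requires torch.distributed to be initialized with
# world_size == dp_size * pp_size * tp_size * sp_size.
from colossalai.cluster import ProcessGroupMesh

DP_AXIS, PP_AXIS, TP_AXIS, SP_AXIS = 0, 1, 2, 3  # dp_outside layout used above

def build_groups(dp_size: int, pp_size: int, tp_size: int, sp_size: int):
    pg_mesh = ProcessGroupMesh(dp_size, pp_size, tp_size, sp_size)
    sp_group = pg_mesh.get_group_along_axis(SP_AXIS)
    # Gradients are synchronized over the combined DP x SP ranks when the sp mode
    # does not share the TP group, exactly as in the plugin code above.
    dp_sp_group = pg_mesh.create_group_along_axis([DP_AXIS, SP_AXIS])
    return pg_mesh, sp_group, dp_sp_group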

View File

@ -290,7 +290,11 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
assert isinstance(
peft_model, PeftModel
), "The model doesn't have lora adapters, please enable lora before saving."
- return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
+ return peft_model.save_pretrained(
+ checkpoint,
+ safe_serialization=use_safetensors,
+ state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+ )
class LowLevelZeroPlugin(DPPluginBase):
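The extra state_dict argument re-wraps every tensor as plain .data before safetensors serialization, which detaches LoRA weights from autograd and any tensor wrappers. A minimal sketch of the same transform:

import torch
from torch.utils._pytree import tree_map

def detach_state_dict(state_dict: dict) -> dict:
    # Replace every tensor by its .data view; non-tensor entries pass through untouched.
    return tree_map(lambda x: x.data if torch.is_tensor(x) else x, state_dict)

example = {"lora_A.weight": torch.randn(4, 2, requires_grad=True), "step": 3}
cleaned = detach_state_dict(example)
assert cleaned["lora_A.weight"].requires_grad is False and cleaned["step"] == 3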

View File

@ -141,7 +141,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
sequence_parallelism_mode (str): The Sequence parallelism mode. Can only be choosed from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather".
- enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True.
num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
@ -190,7 +189,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
enable_jit_fused: bool = False,
enable_sequence_parallelism: bool = False,
sequence_parallelism_mode: str = None,
- enable_sequence_overlap: bool = False,
parallel_output: bool = True,
num_microbatches: Optional[int] = None,
microbatch_size: Optional[int] = None,
@ -368,7 +366,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
enable_jit_fused=self.enable_jit_fused,
enable_sequence_parallelism=enable_sequence_parallelism,
sequence_parallelism_mode=sequence_parallelism_mode,
- enable_sequence_overlap=enable_sequence_overlap,
parallel_output=parallel_output,
make_vocab_size_divisible_by=make_vocab_size_divisible_by,
gradient_checkpoint_config=gradient_checkpoint_config,

View File

@ -1,9 +1,11 @@
from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union
+ import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+ from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
@ -134,7 +136,11 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
assert isinstance(
peft_model, PeftModel
), "The model doesn't have lora adapters, please enable lora before saving."
- peft_model.save_pretrained(save_directory=checkpoint, safe_serialization=use_safetensors)
+ return peft_model.save_pretrained(
+ checkpoint,
+ safe_serialization=use_safetensors,
+ state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+ )
class TorchDDPModel(ModelWrapper):

View File

@ -11,6 +11,7 @@ import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+ from torch.utils._pytree import tree_map
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
@ -20,7 +21,7 @@ from colossalai.tensor.padded_tensor import (
to_padded_tensor,
to_unpadded_tensor,
)
- from colossalai.utils import get_current_device
+ from colossalai.utils import get_current_device, get_non_persistent_buffers_set
from .general_checkpoint_io import GeneralCheckpointIO
from .index_file import CheckpointIndexFile
@ -104,8 +105,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
yield block, block_size
# Save buffers.
+ non_persist_buffers_set = get_non_persistent_buffers_set(model)
for name, buf in model.named_buffers():
- if buf is not None and name not in model._non_persistent_buffers_set:
+ if buf is not None and name not in non_persist_buffers_set:
buffer = buf if keep_vars else buf.detach()
block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
if block is not None:
@ -351,9 +353,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
_load(name)
# Load buffers.
- non_persistent_buffers = set()
- for n, m in model.named_modules():
- non_persistent_buffers |= set(".".join((n, b)) for b in m._non_persistent_buffers_set)
+ non_persistent_buffers = get_non_persistent_buffers_set(model)
for name, buf in model.named_buffers():
if buf is not None and name not in non_persistent_buffers:
_load(name)
@ -956,4 +956,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
assert isinstance(
peft_model, PeftModel
), "The model doesn't have lora adapters, please enable lora before saving."
- return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
+ return peft_model.save_pretrained(
+ checkpoint,
+ safe_serialization=use_safetensors,
+ state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+ )
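The buffer save/load paths in this file now rely on get_non_persistent_buffers_set(model) instead of each module's private _non_persistent_buffers_set. A plausible sketch of such a helper (not necessarily the library's exact implementation):

import torch
import torch.nn as nn

def non_persistent_buffer_names(model: nn.Module) -> set:
    # Walk every submodule and collect fully qualified names of buffers that were
    # registered with persistent=False, so they can be skipped on save/load.
    names = set()
    for prefix, module in model.named_modules():
        for buf_name in module._non_persistent_buffers_set:
            names.add(f"{prefix}.{buf_name}" if prefix else buf_name)
    return names

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("kept", torch.zeros(1))                       # saved
        self.register_buffer("scratch", torch.zeros(1), persistent=False)  # skipped

assert non_persistent_buffer_names(Demo()) == {"scratch"}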

View File

@ -57,7 +57,9 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM):
target_module=NopadBaichuanMLP,
),
SubModuleReplacementDescription(
- suffix="self_attn.W_pack", target_module=FusedLinear1D_Col, kwargs={"n_fused": 3}
+ suffix="self_attn.W_pack",
+ target_module=FusedLinear1D_Col,
+ kwargs={"split_sizes": [self.model.config.hidden_size] * 3},
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",

View File

@ -1,7 +1,6 @@
import torch
from colossalai.accelerator import get_accelerator
- from colossalai.legacy.nn.layer.colossalai_layer import Embedding, Linear
from .bias_dropout_add import bias_dropout_add_fused_train
from .bias_gelu import bias_gelu_impl
@ -45,6 +44,7 @@ def warmup_jit_fusion(
dtype: torch.dtype = torch.float32,
):
"""Compile JIT functions before the main training steps"""
+ from colossalai.legacy.nn.layer.colossalai_layer import Embedding, Linear
embed = Embedding(vocab_size, hidden_size).to(get_accelerator().get_current_device())
linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_accelerator().get_current_device())
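The legacy Embedding/Linear import is deferred into warmup_jit_fusion, so importing the module no longer pulls in colossalai.legacy. A generic sketch of the deferred-import pattern (module names taken from the diff above, helper name illustrative):

def warmup():
    # Importing inside the function keeps the heavy/optional dependency off the module
    # import path; the cost is only paid when warmup() is actually called.
    from importlib import import_module

    legacy_layers = import_module("colossalai.legacy.nn.layer.colossalai_layer")
    return legacy_layers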

View File

@ -3,8 +3,9 @@ from typing import Any, List, Optional, Tuple
import torch
import torch.cuda
+ from packaging.version import Version
from torch.nn import Module
- from torch.utils._pytree import SUPPORTED_NODES, TreeSpec, _register_pytree_node, tree_flatten, tree_map, tree_unflatten
+ from torch.utils._pytree import SUPPORTED_NODES, TreeSpec, tree_flatten, tree_map, tree_unflatten
# this register are for torch under version 1.13.1, maybe removed in the future
@ -16,7 +17,12 @@ def _odict_unflatten(values: List[Any], context: Any) -> "OrderedDict[Any, Any]"
return OrderedDict((key, value) for key, value in zip(context, values))
- _register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)
+ if Version(torch.__version__) <= Version("1.13.1"):
+ try:
+ from torch.utils._pytree import register_pytree_node as _register_pytree_node
+ except ImportError:
+ from torch.utils._pytree import _register_pytree_node
+ _register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)
def tree_map_hf(fn: Any, pytree: Any):
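Newer PyTorch registers OrderedDict as a pytree itself and renamed the private registration helper, hence the version guard above. A self-contained sketch of the same guard (the flatten/unflatten bodies paraphrase this file's helpers):

from collections import OrderedDict

import torch
from packaging.version import Version

if Version(torch.__version__) <= Version("1.13.1"):
    try:
        from torch.utils._pytree import register_pytree_node as _register_pytree_node
    except ImportError:
        from torch.utils._pytree import _register_pytree_node

    def _odict_flatten(d):
        return list(d.values()), list(d.keys())

    def _odict_unflatten(values, context):
        return OrderedDict(zip(context, values))

    _register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)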

View File

@ -351,15 +351,16 @@ class InterleavedSchedule(PipelineSchedule):
if output_obj_grad is None:
optimizer.backward(output_obj)
else:
- if "backward_tensor_keys" not in output_obj:
- for k, grad in output_obj_grad.items():
- optimizer.backward_by_grad(output_obj[k], grad)
- else:
- for k, grad in output_obj_grad.items():
- output_obj[k].grad = grad
- for k in output_obj["backward_tensor_keys"]:
- tensor_to_backward = output_obj[k]
- optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad)
+ keys = output_obj.get("backward_tensor_keys", output_obj_grad.keys())
+ tensors_to_backward = []
+ grads_to_backward = []
+ for k in keys:
+ tensors_to_backward.append(output_obj[k])
+ grads_to_backward.append(output_obj_grad[k])
+ if len(tensors_to_backward) == 1:
+ optimizer.backward_by_grad(tensors_to_backward[0], grads_to_backward[0])
+ else:
+ optimizer.backward_by_grad(tensors_to_backward, grads_to_backward)
# Collect the grad of the input_obj.
input_obj_grad = None

View File

@ -305,15 +305,16 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
if output_obj_grad is None:
optimizer.backward(output_obj)
else:
- if "backward_tensor_keys" not in output_obj:
- for k, grad in output_obj_grad.items():
- optimizer.backward_by_grad(output_obj[k], grad)
- else:
- for k, grad in output_obj_grad.items():
- output_obj[k].grad = grad
- for k in output_obj["backward_tensor_keys"]:
- tensor_to_backward = output_obj[k]
- optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad)
+ keys = output_obj.get("backward_tensor_keys", output_obj_grad.keys())
+ tensors_to_backward = []
+ grads_to_backward = []
+ for k in keys:
+ tensors_to_backward.append(output_obj[k])
+ grads_to_backward.append(output_obj_grad[k])
+ if len(tensors_to_backward) == 1:
+ optimizer.backward_by_grad(tensors_to_backward[0], grads_to_backward[0])
+ else:
+ optimizer.backward_by_grad(tensors_to_backward, grads_to_backward)
# Collect the grad of the input_obj.
input_obj_grad = None
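Both schedules now gather the tensors and their upstream grads first and issue a single backward_by_grad call, using the list form when several outputs need gradients. A standalone sketch of that control flow against a toy optimizer wrapper (the wrapper below is illustrative, not ColossalAI's):

import torch

class ToyOptimWrapper:
    # Illustrative stand-in for the optimizer wrapper's backward_by_grad.
    def backward_by_grad(self, tensor, grad):
        torch.autograd.backward(tensor, grad_tensors=grad)

def backward_step(optim, output_obj: dict, output_obj_grad: dict):
    # Same logic as the schedules above: honour backward_tensor_keys when present,
    # otherwise backpropagate every output that received an upstream grad.
    keys = output_obj.get("backward_tensor_keys", output_obj_grad.keys())
    tensors, grads = [], []
    for k in keys:
        tensors.append(output_obj[k])
        grads.append(output_obj_grad[k])
    if len(tensors) == 1:
        optim.backward_by_grad(tensors[0], grads[0])
    else:
        optim.backward_by_grad(tensors, grads)

x = torch.randn(3, requires_grad=True)
backward_step(ToyOptimWrapper(), {"hidden_states": x * 2}, {"hidden_states": torch.ones(3)})
assert torch.allclose(x.grad, torch.full((3,), 2.0))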

View File

@ -8,6 +8,8 @@ import torch.nn.functional as F
from packaging.version import Version
from torch.distributed import ReduceOp
+ from .fp8_config import dynamic_kernel
SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.4.0")
SCALE_BYTES = 4
try:
@ -832,11 +834,13 @@ class _LinearFp8(torch.autograd.Function):
return x_grad.reshape(ctx.x_shape), w_grad, bias_grad
- @torch.compile(mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=False)
+ @torch.compile(mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=dynamic_kernel)
def _linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return _LinearFp8.apply(input, weight, bias)
def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+ if input.shape[-1] % 16 != 0 or np.prod(input.shape[:-1]) % 16 != 0:
+ return F.linear(input, weight, bias)
out = _linear_fp8(input, weight, bias)
return out
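The new guard in linear_fp8 skips the FP8 path whenever either GEMM dimension is not a multiple of 16, which the compiled FP8 kernel cannot tile, and falls back to the ordinary F.linear. A hedged sketch of the same shape check (the FP8 kernel below is a stand-in):

import numpy as np
import torch
import torch.nn.functional as F

def _fp8_kernel(x, w, b):
    # Stand-in for the compiled FP8 matmul; the real op quantizes to e4m3/e5m2
    # and calls a scaled GEMM that needs 16-aligned dimensions.
    return F.linear(x, w, b)

def linear_fp8_or_fallback(x: torch.Tensor, w: torch.Tensor, b=None) -> torch.Tensor:
    if x.shape[-1] % 16 != 0 or np.prod(x.shape[:-1]) % 16 != 0:
        return F.linear(x, w, b)  # shapes the FP8 kernel cannot handle
    return _fp8_kernel(x, w, b)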

View File

@ -0,0 +1 @@
dynamic_kernel: bool = False

View File

@ -6,7 +6,7 @@ from .linear import Linear1D_Col, Linear1D_Row, LinearWithGradAccum, PaddingLMHe
from .loss import cross_entropy_1d, dist_cross_entropy
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
from .parallel_module import ParallelModule
- from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
+ from .qkv_fused_linear import FusedLinear1D_Col, FusedLinear1D_Row, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
__all__ = [
"Embedding1D",
@ -35,4 +35,5 @@ __all__ = [
"RingAttention", "RingAttention",
"get_pad_info", "get_pad_info",
"all_to_all_comm", "all_to_all_comm",
"FusedLinear1D_Row",
] ]

View File

@ -106,7 +106,7 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
- if ctx.async_grad_allreduce and fp8_communication:
+ if fp8_communication or not ctx.async_grad_allreduce:
_reduce(grad_input, group=ctx.process_group, fp8_communication=fp8_communication, fp8_format="e5m2")
elif ctx.async_grad_allreduce:
# Asynchronous all-reduce
@ -364,10 +364,12 @@ def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=
for k in recv_tensors:
send_tensors[k], recv_tensors[k] = recv_tensors[k], send_tensors[k]
+ input_tensors = []
output_tensors = []
handles = communicate_step()
# first round: special case, retrive from local tensor
+ input_tensors.append(input_to_gather)
output_tensors.append(func(**input_to_gather, **input_local))
for i in range(group_size - 2):
for handle in handles:
@ -378,14 +380,25 @@ def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=
handles = communicate_step()
# actual computation
+ input_tensors.append(send_tensors)
output_tensors.append(func(**send_tensors, **input_local))
# final round: special case, no need to send/recv again
for handle in handles:
handle.wait()
+ input_tensors.append(send_tensors)
output_tensors.append(func(**recv_tensors, **input_local))
- return torch.cat(output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=gather_dim)
+ gathered_input = {}
+ for k in input_to_gather:
+ input_shards = [d[k] for d in input_tensors[group_size - cur_rank :] + input_tensors[: group_size - cur_rank]]
+ gathered_input[k] = torch.cat(input_shards, dim=gather_dim)
+ gathered_output = torch.cat(
+ output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=gather_dim
+ )
+ return gathered_output, gathered_input
class _GatherForwardReduceScatterBackward(torch.autograd.Function):
@ -441,29 +454,30 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
""" """
@staticmethod @staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True, ring=False): def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False):
ctx.save_for_backward(input_, weight, bias) ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None ctx.use_bias = bias is not None
ctx.process_group = process_group ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
ctx.dim = dim ctx.dim = dim
ctx.overlap = overlap
if ring is True: if ring is True:
input_to_gather = {"input": input_} input_to_gather = {"input": input_}
input_local = {"weight": weight} input_local = {"weight": weight}
output = _ring_as_gather( output, input_dict = _ring_as_gather(
F.linear, F.linear,
input_to_gather=input_to_gather, input_to_gather=input_to_gather,
input_local=input_local, input_local=input_local,
process_group=process_group, process_group=process_group,
) )
ctx.gathered_input = input_dict["input"]
if bias is not None: if bias is not None:
output += bias output += bias
else: else:
input_parallel = _gather(input_, dim, process_group) input_parallel = _gather(input_, dim, process_group)
ctx.gathered_input = input_parallel
if bias is not None: if bias is not None:
output = F.linear(input_parallel, weight, bias) output = F.linear(input_parallel, weight, bias)
else: else:
@ -477,100 +491,50 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
use_bias = ctx.use_bias
dim = ctx.dim
process_group = ctx.process_group
- overlap = ctx.overlap
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
if use_bias:
bias = bias.view(bias.shape)
- if not overlap:
- input_parallel = _gather(input_, dim, process_group)
+ input_parallel = ctx.gathered_input
total_input = input_parallel
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
if ctx.async_grad_reduce_scatter:
# Asynchronous reduce-scatter
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
- output = torch.empty(
- input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
- ).contiguous()
+ output = torch.empty(input_.shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
grad_weight = None
elif grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
else:
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_reduce_scatter:
handle.wait()
- else:
- input_ = input_.contiguous()
- world_size = dist.get_world_size(process_group)
- tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
- # do all gather in is async way
- gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
- # calculate gradient and prepare data asynchronously with all-gather
- # calculate
- grad_input = grad_output.matmul(weight)
- grad_output = grad_output.contiguous()
- # Convert the tensor shapes to 2D for execution compatibility
- if len(grad_output.shape) > 2:
- grad_output = grad_output.view(-1, grad_output.shape[-1])
- grad_bias = grad_output.sum(dim=0) if use_bias else None
- # prepare data
- input_list = [
- item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
- ]
- output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
- # wait until all-gather finished
- gather_handle.wait()
- # do reduce-scatter in async way
- reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
- input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
- # calculate gradient
- if len(input_parallel.shape) > 2:
- input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
- if _grad_accum_fusion_available and weight.grad is not None:
- grad = weight.grad
- if grad.dtype == torch.float32:
- fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(input_parallel, grad_output, grad)
- grad_weight = None
- elif grad.dtype == torch.float16:
- fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(input_parallel, grad_output, grad)
- grad_weight = None
- else:
- grad_weight = grad_output.t().matmul(input_parallel)
- else:
- grad_weight = grad_output.t().matmul(input_parallel)
- # grad_weight = grad_output.t().matmul(input_parallel)
- # wait until reduce-scatter finished
- reducescatter_handle.wait()
- return output, grad_weight, grad_bias, None, None, None, None, None
+ return output, grad_weight, grad_bias, None, None, None, None
def _ring_as_reducescatter(
@ -701,7 +665,7 @@ class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
- total_input = total_input.view(-1, total_input.shape[-1])
+ total_input = total_input.reshape(-1, total_input.shape[-1])
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
@ -759,34 +723,30 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
""" """
@staticmethod @staticmethod
def forward( def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication):
ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication
):
ctx.save_for_backward(input_, weight, bias) ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None ctx.use_bias = bias is not None
ctx.process_group = process_group ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
ctx.dim = dim ctx.dim = dim
ctx.overlap = overlap
ctx.fp8_communication = fp8_communication ctx.fp8_communication = fp8_communication
if ring is True: if ring is True:
input_to_gather = {} input_to_gather = {"input": input_}
input_local = {} input_local = {"other": weight}
input_to_gather["input"] = input_
input_local["other"] = weight
output = _ring_as_gather( output, input_dict = _ring_as_gather(
torch.matmul, torch.matmul,
input_to_gather=input_to_gather, input_to_gather=input_to_gather,
input_local=input_local, input_local=input_local,
process_group=process_group, process_group=process_group,
gather_dim=dim, gather_dim=dim,
) )
ctx.gathered_input = input_dict["input"]
else: else:
input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e4m3") input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e4m3")
ctx.gathered_input = input_parallel
output = torch.matmul(input_parallel, weight) output = torch.matmul(input_parallel, weight)
if bias is not None: if bias is not None:
@ -799,76 +759,39 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
use_bias = ctx.use_bias
dim = ctx.dim
process_group = ctx.process_group
- overlap = ctx.overlap
- fp8_communication = ctx.fp8_communication
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
weight = weight.view(weight.shape)
if use_bias:
bias = bias.view(bias.shape)
- if not overlap:
- input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e5m2")
+ input_parallel = ctx.gathered_input
total_input = input_parallel
grad_input = grad_output.matmul(weight.T)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
if ctx.async_grad_reduce_scatter:
# Asynchronous reduce-scatter
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
- output = torch.empty(
- input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
- ).contiguous()
+ output = torch.empty(input_.shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated
grad_weight = total_input.t().matmul(grad_output)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_reduce_scatter:
handle.wait()
- else:
- world_size = dist.get_world_size(process_group)
- tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
- # do all gather in is async way
- gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
- # calculate gradient and prepare data asynchronously with all-gather
- # calculate
- grad_input = grad_output.matmul(weight.T)
- grad_output = grad_output.contiguous()
- # Convert the tensor shapes to 2D for execution compatibility
- if len(grad_output.shape) > 2:
- grad_output = grad_output.view(-1, grad_output.shape[-1])
- grad_bias = grad_output.sum(dim=0) if use_bias else None
- # prepare data
- input_list = [
- item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
- ]
- output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
- # wait until all-gather finished
- gather_handle.wait()
- # do reduce-scatter in async way
- reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
- input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
- # calculate gradient
- if len(input_parallel.shape) > 2:
- input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
- grad_weight = input_parallel.t().matmul(grad_output)
- # wait until reduce-scatter finished
- reducescatter_handle.wait()
- return output, grad_weight, grad_bias, None, None, None, None, None, None
+ return output, grad_weight, grad_bias, None, None, None, None, None
class _SplitForwardGatherBackward(torch.autograd.Function):
@ -988,7 +911,7 @@ class _AllToAll(torch.autograd.Function):
ctx.gather_dim = gather_dim
ctx.fp8_communication = fp8_communication
world_size = dist.get_world_size(process_group)
- bsz, _, _ = input_.shape
+ bsz = input_.shape[0]
# using all_to_all_single when batch size is 1
if bsz == 1:
@ -1019,7 +942,7 @@ class _AllToAll(torch.autograd.Function):
gather_dim = ctx.scatter_dim
fp8_communication = ctx.fp8_communication
world_size = dist.get_world_size(process_group)
- bsz, _, _ = grad_output.shape
+ bsz = grad_output.shape[0]
if bsz == 1:
return_grad = _all_to_all_single(
@ -1204,10 +1127,10 @@ def linear_with_grad_accum(input_, weight, bias, async_grad_allreduce, use_zbv=F
def linear_gather_forward_reducescatter_backward(
- input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
+ input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False
):
return _LinearWithGatherForwardReduceScatterBackward.apply(
- input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring
+ input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring
)
@ -1224,10 +1147,10 @@ def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, proc
def matmul_gather_forward_reducescatter_backward(
- input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False, fp8_communication=False
+ input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False, fp8_communication=False
):
return _MatmulWithGatherForwardReduceScatterBackward.apply(
- input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication
+ input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication
)
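The net effect of the _operation.py changes: the forward caches the gathered input in ctx, the backward no longer re-gathers or uses the old overlap branch, and the grad-input reduce-scatter is launched asynchronously so it overlaps the weight-gradient GEMM (relying on CUDA_DEVICE_MAX_CONNECTIONS=1). A skeleton of that overlap, assuming torch.distributed is initialized:

import torch
import torch.distributed as dist

def backward_overlap_sketch(grad_input, total_input, grad_output, process_group, dim=0):
    # Assumes dist is initialized and CUDA_DEVICE_MAX_CONNECTIONS=1 is exported so the
    # communication kernel is scheduled before the GEMM grabs all SM resources.
    world_size = dist.get_world_size(process_group)
    chunks = [c.contiguous() for c in torch.chunk(grad_input, world_size, dim=dim)]
    output = torch.empty_like(chunks[0])
    # 1) launch the grad-input reduce-scatter asynchronously ...
    handle = dist.reduce_scatter(output, chunks, group=process_group, async_op=True)
    # 2) ... and compute the weight gradient on the default stream meanwhile.
    grad_weight = grad_output.t().matmul(total_input)
    handle.wait()
    return output, grad_weight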

View File

@ -422,16 +422,21 @@ class RingAttention(torch.autograd.Function):
ATTN_DONE: torch.cuda.Event = None
SP_STREAM: torch.cuda.Stream = None
SP_GROUP: dist.ProcessGroup = None
- # duplicate process group for concurrent NCCL streams
- # while both PyTorch and NCCL warns(https://github.com/pytorch/pytorch/commit/2dbe5cb979f674f0052a8eea1f7b6c3c0ba441d7)
- # against this, in practice it seems to work fine.
+ # NOTE: Duplicating PGs for concurrent NCCL streams is a risky hack -- while it may increase throughput,
+ # both PyTorch and NCCL warn against this. (https://github.com/pytorch/pytorch/commit/2dbe5cb979f674f0052a8eea1f7b6c3c0ba441d7)
+ # LoongTrain's original double ring impl. uses concurrent PGs
+ # (https://github.com/InternLM/InternEvo/blob/e52f2ffc9acf818e8f2b1f97dfc69ceb2f06e154/internlm/model/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py#L192)
+ # but I confirmed with Pytorch developers this can cause obscure "Software caused connection abort" errors.
+ # (https://github.com/pytorch/pytorch/issues/132852)
+ # NOTE: In general, a smarter idea is put as many P2P calls as possible into one `batch_isend_irecv`.
INNER_RING_GROUP: dist.ProcessGroup = None
- INNER_RING_GROUP_COPY: dist.ProcessGroup = None
+ # INNER_RING_GROUP_COPY: dist.ProcessGroup = None
INTER_RING_GROUP: dist.ProcessGroup = None
- INTER_RING_GROUP_COPY: dist.ProcessGroup = None
+ # INTER_RING_GROUP_COPY: dist.ProcessGroup = None
@staticmethod
- def get_double_ring_groups(sp_group, inner_ring_size=None):
+ def get_double_ring_groups(sp_axis, pg_mesh, inner_ring_size=None):
"""
Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size
shouldn't be larger than the number of NICs on each node.
@ -441,21 +446,17 @@ class RingAttention(torch.autograd.Function):
Returns:
Tuple[dist.ProcessGroup, dist.ProcessGroup]: Inner-ring process group and inter-ring process group.
"""
+ assert pg_mesh is not None, f"Error: The pg mesh is None! please check the process group initialization."
+ sp_group = pg_mesh.get_group_along_axis(sp_axis)
sp_size = dist.get_world_size(sp_group)
sp_rank = dist.get_rank(sp_group)
- if inner_ring_size is None:
- if torch.cuda.device_count() >= dist.get_world_size():
- # single node, no need to consider NICs
- return sp_group, sp_group
- if sp_size <= 4:
- inner_ring_size = min(2, sp_size)
- else:
- inner_ring_size = min(4, sp_size)
- else:
- assert (
- inner_ring_size <= sp_size and sp_size % inner_ring_size == 0
- ), f"Error: sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}"
+ assert inner_ring_size is not None
+ assert (
+ inner_ring_size <= sp_size and sp_size % inner_ring_size == 0
+ ), f"Error: sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}"
if inner_ring_size == sp_size:
return sp_group, sp_group
@ -474,14 +475,14 @@ class RingAttention(torch.autograd.Function):
# Create inner ring groups
for i in range(inner_ring_size):
ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size))
- group = dist.new_group(ranks)
+ group = pg_mesh.get_group_along_axis(sp_axis, ranks)
if sp_rank in ranks:
inner_ring_group = group
# Create inter ring groups
for i in range(num_rings):
ranks = list(range(i, sp_size, num_rings))
- group = dist.new_group(ranks)
+ group = pg_mesh.get_group_along_axis(sp_axis, ranks)
if sp_rank in ranks:
inter_ring_group = group
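Inner rings are built from consecutive blocks of sequence-parallel ranks and inter rings connect corresponding members across those blocks; the groups themselves now come from pg_mesh.get_group_along_axis rather than fresh dist.new_group calls. A simplified, pure-Python illustration of the rank bookkeeping (not the library's exact loop):

def double_ring_ranks(sp_size: int, inner_ring_size: int):
    # Simplified illustration: consecutive blocks form inner rings, equally strided
    # ranks form inter rings; each rank sits in exactly one of each.
    assert sp_size % inner_ring_size == 0
    num_rings = sp_size // inner_ring_size
    inner = [list(range(r * inner_ring_size, (r + 1) * inner_ring_size)) for r in range(num_rings)]
    inter = [list(range(start, sp_size, inner_ring_size)) for start in range(inner_ring_size)]
    return inner, inter

inner, inter = double_ring_ranks(sp_size=8, inner_ring_size=4)
assert inner == [[0, 1, 2, 3], [4, 5, 6, 7]]
assert inter == [[0, 4], [1, 5], [2, 6], [3, 7]]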
@ -492,7 +493,7 @@ class RingAttention(torch.autograd.Function):
q, # (B, H, Sq, D)
k,
v,
- sp_group,
+ sp_axis,
attention_mask_type,
cu_seqlens=None,
max_seqlen=None,
@ -502,6 +503,7 @@ class RingAttention(torch.autograd.Function):
deterministic=False,
return_softmax=False,
inner_ring_size=None,
+ pg_mesh=None,
**kwargs,
):
"""
@ -512,7 +514,7 @@ class RingAttention(torch.autograd.Function):
q (torch.Tensor): Query tensor. Shape should be [B, nHeads, Sq, D]
k (torch.Tensor): Key tensor. Shape should be [B, nHeads, Sq, Sq, D]
v (torch.Tensor): Value tensor. Shape should be [B, nHeads, Sq, Sq, D]
- sp_group (Optional[dist.ProcessGroup]): Process group for sequence parallelism
+ sp_axis (Optional[int]): Sp axis for the global pg mesh.
sp_tream (torch.cuda.Stream): An different stream for output correction.
cu_seqlens (Optional[torch.Tensor], optional): The cumulative sequence lengths
of the sequences in the batch, used to index into q.
@ -537,7 +539,6 @@ class RingAttention(torch.autograd.Function):
RingAttention.ATTN_DONE = torch.cuda.Event()
if RingAttention.SP_STREAM is None:
RingAttention.SP_STREAM = torch.cuda.Stream()
assert (
q.shape[2] == k.shape[2]
), "Q, K and V having different sequence lengths (inference or cross-attn)\
@ -546,11 +547,13 @@ class RingAttention(torch.autograd.Function):
attention_mask_type in RingAttention.SUPPORTED_MASK_TYPES
), f"Mask type {attention_mask_type} is not supported yet."
- clone_pg = lambda pg: dist.new_group(dist.get_process_group_ranks(pg))
- if RingAttention.SP_GROUP is not sp_group:
+ assert pg_mesh is not None, f"Error: The pg mesh is None! please check the process group initialization."
+ clone_pg = lambda pg: dist.new_group(dist.get_process_group_ranks(pg))
+ sp_group = pg_mesh.get_group_along_axis(sp_axis)
+ if inner_ring_size != None:
RingAttention.SP_GROUP = sp_group
- inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_group, inner_ring_size)
+ inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_axis, pg_mesh, inner_ring_size)
RingAttention.INNER_RING_GROUP = inner_ring_group
RingAttention.INTER_RING_GROUP = inter_ring_group
else:
@ -628,7 +631,13 @@ class RingAttention(torch.autograd.Function):
inner_ring_group: Optional[dist.ProcessGroup] = None,
inter_ring_group: Optional[dist.ProcessGroup] = None,
):
+ """
+ Forward supporting both packed (varlen) and batched(fixed length, no padding) sequences.
+ No separate version for batched seq (hard to maintain), which incurs
+ some overhead in sequence splitting due to python for loops.
+ Uses two CUDA streams to overlap softmax denominator correction with next flash attn
+ (see comments below).
+ """
cu_seqlens_q = cu_seqlens_kv = cu_seqlens
max_seqlen_q = max_seqlen_kv = max_seqlen
cu_seqlens_half = cu_seqlens // 2
@ -670,7 +679,8 @@ class RingAttention(torch.autograd.Function):
sp_size = dist.get_world_size(sp_group)
sp_rank = dist.get_rank(sp_group)
+ # Attempt to achieve concurrent comm in the two-stream forward
+ # Create communicators corresponding to two CUDA streams
local_kv_comms = [RingComm(inner_ring_group) for _ in range(2)]
inter_ring_comm = RingComm(inter_ring_group)
local_sp_size = dist.get_world_size(inner_ring_group)
@ -678,7 +688,7 @@ class RingAttention(torch.autograd.Function):
inter_ring_rank = dist.get_rank(inter_ring_group) if inter_ring_group is not sp_group else 0 inter_ring_rank = dist.get_rank(inter_ring_group) if inter_ring_group is not sp_group else 0
num_rings = dist.get_world_size(inter_ring_group) if inter_ring_group is not sp_group else 1 num_rings = dist.get_world_size(inter_ring_group) if inter_ring_group is not sp_group else 1
# Non-contiguous indexing copies to a new contiguous tensor, # Any type of indexing (but not slicing) copies to a new contiguous tensor,
# so only do it once # so only do it once
if sp_rank != sp_size - 1: if sp_rank != sp_size - 1:
q1 = q[half_idx_back] q1 = q[half_idx_back]
@ -695,6 +705,7 @@ class RingAttention(torch.autograd.Function):
rng_states = [None for _ in range(sp_size)] rng_states = [None for _ in range(sp_size)]
sp_streams = [torch.cuda.current_stream(), sp_stream] sp_streams = [torch.cuda.current_stream(), sp_stream]
# Helper to pass args to FA
def _forward(q, k, v, causal): def _forward(q, k, v, causal):
( (
_, _,
@ -725,6 +736,7 @@ class RingAttention(torch.autograd.Function):
if i < local_sp_size - 1: if i < local_sp_size - 1:
local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2]) local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
# Forward within a node
def _local_ring_forward(): def _local_ring_forward():
# (Hopefully) overlap output correction with next flash attn # (Hopefully) overlap output correction with next flash attn
for i in range(local_sp_size): for i in range(local_sp_size):
@ -733,6 +745,8 @@ class RingAttention(torch.autograd.Function):
# NOTE: waiting outside the current stream will NOT correctly synchronize. # NOTE: waiting outside the current stream will NOT correctly synchronize.
if i > 0: if i > 0:
local_kv_comms[(i + 1) % 2].wait() local_kv_comms[(i + 1) % 2].wait()
# Prefetch
if i == 0: if i == 0:
_kv_comm(i) _kv_comm(i)
@ -766,15 +780,22 @@ class RingAttention(torch.autograd.Function):
) = _forward(q_block, kv_block[0], kv_block[1], causal=False) ) = _forward(q_block, kv_block[0], kv_block[1], causal=False)
RingAttention.ATTN_DONE.record() RingAttention.ATTN_DONE.record()
# Pipeline the next KV comm with output correction instead of the next flash attn # Pipeline the next KV comm with output correction instead of the next flash attn
# to minimize idle time when comm takes longer than attn. # kernel, to minimize bubble when comm takes longer than attn.
_kv_comm(i + 1) _kv_comm(i + 1)
block_softmax_lse[i % 2] = ( block_softmax_lse[i % 2] = (
block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float() block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
) # (H, T) -> (T, H, 1) ) # (H, T) -> (T, H, 1)
assert block_out[i % 2].shape[:-1] == block_softmax_lse[i % 2].shape[:-1] assert block_out[i % 2].shape[:-1] == block_softmax_lse[i % 2].shape[:-1]
# Output and log sum exp correction. Ideally overlap this with the next flash attn kernel.
# In reality this always finishes before next flash attn; no need for extra sync. # Output and log sum exp correction.
# Ideally overlap this with the next flash attn kernel,
# since attn uses Tensor Cores while the rescale is element-wise, memory-bound and runs on CUDA cores.
# (NOTE that this is the same as the ping-pong scheduling idea in FA3)
# TODO: sometimes, even after the GPU has scheduled the next kernel,
# it is reluctant to launch it concurrently. Some potential causes:
# 1. need lower-level CUDA scheduling 2. further benchmark against Megatron-LM
# 3. register spilling by FA kernel.
if i == 0: if i == 0:
out = block_out[0] out = block_out[0]
softmax_lse = block_softmax_lse[0] softmax_lse = block_softmax_lse[0]
@ -790,15 +811,17 @@ class RingAttention(torch.autograd.Function):
torch.cuda.current_stream().wait_stream(sp_stream) torch.cuda.current_stream().wait_stream(sp_stream)
return out, softmax_lse return out, softmax_lse
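For reference, the correction referred to in the comments above is the standard log-sum-exp merge of two partial attention results; a minimal sketch, assuming `out_*` of shape (T, H, D) and `lse_*` of shape (T, H, 1) as in the comments:

```python
import torch

def merge_attn_outputs(out_a, lse_a, out_b, lse_b):
    # Numerically stable log(exp(lse_a) + exp(lse_b))
    lse = torch.logaddexp(lse_a, lse_b)
    # Rescale each partial output by its share of the softmax denominator
    out = torch.exp(lse_a - lse) * out_a + torch.exp(lse_b - lse) * out_b
    return out, lse
```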
# Forward for inter-node (the outer ring in 2D ring)
def _other_ring_forward(ring_num_idx, out, softmax_lse): def _other_ring_forward(ring_num_idx, out, softmax_lse):
# Loop through the inner ring after receiving # Loop through the inner ring after receiving
# all new KVs from the previous inner ring # all new KVs from another ring
for i in range(local_sp_size): for i in range(local_sp_size):
with torch.cuda.stream(sp_streams[i % 2]): with torch.cuda.stream(sp_streams[i % 2]):
# Send & recv KV # Send & recv KV
if i > 0: if i > 0:
local_kv_comms[(i + 1) % 2].wait() local_kv_comms[(i + 1) % 2].wait()
# Prefetch
if i == 0: if i == 0:
_kv_comm(i) _kv_comm(i)
@ -895,7 +918,8 @@ class RingAttention(torch.autograd.Function):
def backward(ctx, dout, _): def backward(ctx, dout, _):
""" """
During backward, we accumulate q grads on each rank locally, but iterate kv and their grads During backward, we accumulate q grads on each rank locally, but iterate kv and their grads
over all ranks for accumulation. over all ranks for accumulation. We avoid using two streams here because backward uses doubled
buffers and incurs more communication cost.
""" """
(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_kv, half_idx_front, half_idx_back) = ctx.saved_tensors[:9] (q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_kv, half_idx_front, half_idx_back) = ctx.saved_tensors[:9]
rng_states = ctx.saved_tensors[9:] rng_states = ctx.saved_tensors[9:]
@ -927,7 +951,7 @@ class RingAttention(torch.autograd.Function):
local_sp_rank = dist.get_rank(sp_group) local_sp_rank = dist.get_rank(sp_group)
sp_size = dist.get_world_size(sp_group) sp_size = dist.get_world_size(sp_group)
# Using separate streams (pg) for concurrent kv and dkv comm may # NOTE: Using separate streams (PG) for concurrent kv and dkv comm may
# cause NCCL "software caused connection abort" here... # cause NCCL "software caused connection abort" here...
local_kv_comm = RingComm(local_kv_group) local_kv_comm = RingComm(local_kv_group)
local_dkv_comm = RingComm(local_kv_group) local_dkv_comm = RingComm(local_kv_group)
@ -959,6 +983,7 @@ class RingAttention(torch.autograd.Function):
dkv_buffers = [torch.empty_like(kv, dtype=torch.float32) for kv in kv_buffers] # (T, H, D) dkv_buffers = [torch.empty_like(kv, dtype=torch.float32) for kv in kv_buffers] # (T, H, D)
del k, v del k, v
# Helper to pass args to FA
def _backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, rng_state, causal): def _backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, rng_state, causal):
_flash_attn_backward( _flash_attn_backward(
dout, dout,
@ -979,8 +1004,7 @@ class RingAttention(torch.autograd.Function):
**misc_kwargs, **misc_kwargs,
) )
# NOTE: We avoid using two streams due to doubled buffers # Backward within a node
# and that backward is more communication intensive.
def _local_ring_backward(): def _local_ring_backward():
for i in range(local_sp_size): for i in range(local_sp_size):
if i > 0: if i > 0:
@ -1043,6 +1067,7 @@ class RingAttention(torch.autograd.Function):
dkv_send = dkv_buffers[(local_sp_size - 1) % 2] dkv_send = dkv_buffers[(local_sp_size - 1) % 2]
return dq, dkv_recv, dkv_send return dq, dkv_recv, dkv_send
# Backward for inter-node (the outer ring in 2D ring)
def _other_ring_backward(ring_num_idx, dq): def _other_ring_backward(ring_num_idx, dq):
if ring_num_idx > inter_ring_rank: if ring_num_idx > inter_ring_rank:
# Indexing is expensive # Indexing is expensive
@ -1127,34 +1152,34 @@ class RingAttention(torch.autograd.Function):
@staticmethod @staticmethod
def prepare_varlen_batch( def prepare_varlen_batch(
attention_mask: torch.Tensor, padding_mask: torch.Tensor,
sp_group: dist.ProcessGroup, sp_group: dist.ProcessGroup,
inputs_embeds: torch.Tensor = None, inputs_embeds: torch.Tensor = None,
position_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None,
is_label: bool = False, is_label: bool = False,
is_2d: bool = True, is_batched_seq: bool = True,
): ):
# TODO: support setting a batch dim (fixed packing length) for packed mode, so that
# DP can be used (the dataloader needs to be modified too)
""" """
Preprocess a batch of padded sequences by splitting each input sequence by sp_size Preprocess a batch of padded sequences by splitting each input sequence by sp_size
sequence-wise and packing them into one sequence. Updates the mask info accordingly. seq-wise and packing them into one sequence. Updates the mask info accordingly.
Args: Args:
attention_mask (torch.Tensor): Contains the mask [B, Sq], where True means the token is NOT masked. padding_mask (torch.Tensor): Contains the mask [B, Sq], where True means the token is NOT masked.
sp_group (dist.ProcessGroup): Process group for sequence parallelism sp_group (dist.ProcessGroup): Process group for sequence parallelism
inputs_embeds (torch.Tensor): Input embeddings. Shape should be [B, Sq, ...] inputs_embeds (torch.Tensor): Input embeddings. Shape should be [B, Sq, ...]
position_ids (Optional[torch.Tensor], optional): Position ids of shape [Sq] or [1, Sq]. Defaults to None. position_ids (Optional[torch.Tensor], optional): Position ids of shape [Sq] or [1, Sq]. Defaults to None.
is_label (bool, optional): Whether inputs_embeds is instead a label tensor. If True, mask out the first is_label (bool, optional): Whether inputs_embeds is instead a label tensor. If True, mask out the first
token of each sequence. token of each sequence.
is_2d (bool, optional): Whether to return 2D outputs padded to max_seqlen // sp_size or flatten is_batched_seq (bool, optional): If True, then the input is a batch of (potentially padded) sequences
the batch dim to a packed 1d sequence. Contingent on model forward shape definitions. of shape [B, Sq, ...]; else a packed sequence of shape [T, ...].
Returns: Returns:
torch.Tensor: inputs_embeds (torch.Tensor):
Packed input embeddings of shape [B, Sq // sp_size, ...]. Packed input embeddings of shape [B, Sq // sp_size, ...] if is_batched_seq, else [T, ...].
mask_info (Dict[str, Any]):
Dict[str, Any]:
A dictionary containing mask info. A dictionary containing mask info.
position_ids (torch.Tensor):
torch.Tensor:
Packed position ids of shape [..., Sq // sp_size]. Packed position ids of shape [..., Sq // sp_size].
""" """
@ -1162,12 +1187,11 @@ class RingAttention(torch.autograd.Function):
sp_size = dist.get_world_size(group=sp_group) sp_size = dist.get_world_size(group=sp_group)
sp_rank = dist.get_rank(group=sp_group) sp_rank = dist.get_rank(group=sp_group)
mask_info = {} mask_info = {}
mask_info["max_seqlen"], mask_info["cu_seqlens"] = get_pad_info(attention_mask, return_indices=False) mask_info["max_seqlen"], mask_info["cu_seqlens"] = get_pad_info(padding_mask, return_indices=False)
# Unpad, split seq-wise, then pad back to (B, max_seqlen // sp_size) # Unpad, split seq-wise, then pad to (B, max_seqlen // sp_size)
# Split mask to compute local nonzero position indices
# (B, Sq) -> (B, max_seqlen // sp_size) # (B, Sq) -> (B, max_seqlen // sp_size)
attention_mask = attention_mask[:, : mask_info["max_seqlen"]] padding_mask = padding_mask[:, : mask_info["max_seqlen"]]
if inputs_embeds is not None: if inputs_embeds is not None:
inputs_embeds = inputs_embeds[:, : mask_info["max_seqlen"]] inputs_embeds = inputs_embeds[:, : mask_info["max_seqlen"]]
inputs_embeds = split_varlen_zigzag( inputs_embeds = split_varlen_zigzag(
@ -1175,11 +1199,12 @@ class RingAttention(torch.autograd.Function):
mask_info["cu_seqlens"], mask_info["cu_seqlens"],
sp_group, sp_group,
mask_info["max_seqlen"], mask_info["max_seqlen"],
is_2d=is_2d, is_batched_seq=is_batched_seq,
is_label=is_label, is_label=is_label,
) )
attention_mask = split_varlen_zigzag( # Split mask to get local nonzero seq positions
attention_mask, mask_info["cu_seqlens"], sp_group, mask_info["max_seqlen"], is_2d=is_2d padding_mask = split_varlen_zigzag(
padding_mask, mask_info["cu_seqlens"], sp_group, mask_info["max_seqlen"], is_batched_seq=is_batched_seq
) )
if position_ids is not None: if position_ids is not None:
@ -1192,7 +1217,7 @@ class RingAttention(torch.autograd.Function):
) )
mask_info["max_seqlen"] //= sp_size mask_info["max_seqlen"] //= sp_size
mask_info["valid_indices"] = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() mask_info["valid_indices"] = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
mask_info["cu_seqlens"] //= sp_size mask_info["cu_seqlens"] //= sp_size
mask_info["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL mask_info["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
return inputs_embeds, mask_info, position_ids return inputs_embeds, mask_info, position_ids

View File

@ -23,18 +23,16 @@ from colossalai.tensor.d_tensor.api import (
) )
from ._operation import ( from ._operation import (
gather_forward_reducescatter_backward,
gather_forward_split_backward, gather_forward_split_backward,
linear_gather_forward_reducescatter_backward, linear_gather_forward_reducescatter_backward,
linear_reducescatter_forward_gather_backward, linear_reducescatter_forward_gather_backward,
linear_with_async_comm, linear_with_async_comm,
linear_with_grad_accum, linear_with_grad_accum,
reduce_forward, reduce_forward,
reducescatter_forward_gather_backward,
split_forward_gather_backward, split_forward_gather_backward,
) )
from .parallel_module import PaddingParallelModule, ParallelModule from .parallel_module import PaddingParallelModule, ParallelModule
from .utils import create_randomizer_with_offset from .utils import create_randomizer_with_offset, is_share_sp_tp
__all__ = ["LinearWithGradAccum", "Linear1D_Col", "Linear1D_Row"] __all__ = ["LinearWithGradAccum", "Linear1D_Col", "Linear1D_Row"]
@ -197,7 +195,6 @@ class Linear1D_Col(ParallelModule):
to all GPUs, otherwise, every GPU will have its output to all GPUs, otherwise, every GPU will have its output
which is :math:`Y_i = XA_i`, defaults to False which is :math:`Y_i = XA_i`, defaults to False
seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False which is preserved for kernel fusion, defaults to False
weight_initializer (`typing.Callable`): weight_initializer (`typing.Callable`):
@ -220,7 +217,6 @@ class Linear1D_Col(ParallelModule):
gather_output: bool = False, gather_output: bool = False,
seq_parallel_mode: str = None, seq_parallel_mode: str = None,
seq_parallel_dim: int = 1, seq_parallel_dim: int = 1,
overlap: torch.cuda.Stream = None,
skip_bias_add: bool = False, skip_bias_add: bool = False,
weight: Optional[Parameter] = None, weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None, bias_: Optional[Parameter] = None,
@ -238,7 +234,6 @@ class Linear1D_Col(ParallelModule):
self.gather_output = gather_output self.gather_output = gather_output
self.seq_parallel_mode = seq_parallel_mode self.seq_parallel_mode = seq_parallel_mode
self.seq_parallel_dim = seq_parallel_dim self.seq_parallel_dim = seq_parallel_dim
self.overlap = overlap
self.skip_bias_add = skip_bias_add self.skip_bias_add = skip_bias_add
self.device = device self.device = device
self.process_group = process_group self.process_group = process_group
@ -345,22 +340,16 @@ class Linear1D_Col(ParallelModule):
# Matrix multiply. # Matrix multiply.
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
if self.seq_parallel_mode == "split_gather":
input_parallel = gather_forward_reducescatter_backward( if is_share_sp_tp(self.seq_parallel_mode):
input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication output_parallel = linear_gather_forward_reducescatter_backward(
)
output_parallel = linear_with_async_comm(
input_parallel, input_parallel,
self.weight, self.weight,
bias, bias,
self.process_group, self.process_group,
False, True,
fp8_communication=self.fp8_communication, self.seq_parallel_dim,
use_zbv=self.use_zbv, ring=self.seq_parallel_mode == "ring",
)
elif self.seq_parallel_mode == "ring":
output_parallel = linear_gather_forward_reducescatter_backward(
input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True
) )
else: else:
output_parallel = linear_with_async_comm( output_parallel = linear_with_async_comm(
@ -584,31 +573,17 @@ class Linear1D_Row(ParallelModule):
handle.wait() handle.wait()
output = torch.cat(output_parallel_list, dim=-1) output = torch.cat(output_parallel_list, dim=-1)
else: else:
if self.seq_parallel_mode is None: if is_share_sp_tp(self.seq_parallel_mode):
output_parallel = linear_with_async_comm(
input_, self.weight, None, self.process_group, False, use_zbv=self.use_zbv
)
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
elif self.seq_parallel_mode == "split_gather":
output_parallel = linear_with_async_comm(
input_, self.weight, None, self.process_group, False, use_zbv=self.use_zbv
)
output = reducescatter_forward_gather_backward(
output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
)
elif self.seq_parallel_mode == "ring":
output = linear_reducescatter_forward_gather_backward( output = linear_reducescatter_forward_gather_backward(
input_, input_,
self.weight, self.weight,
process_group=self.process_group, process_group=self.process_group,
dim=self.seq_parallel_dim, dim=self.seq_parallel_dim,
ring=True, ring=self.seq_parallel_mode == "ring",
) )
else: else:
output_parallel = linear_with_async_comm( output_parallel = F.linear(input_, self.weight)
input_, self.weight, None, self.process_group, False, use_zbv=self.use_zbv output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
)
output = reduce_forward(output_parallel, self.process_group)
if not self.skip_bias_add: if not self.skip_bias_add:
if self.bias is not None: if self.bias is not None:
@ -716,7 +691,6 @@ class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule):
to all GPUs, otherwise, every GPU will have its output to all GPUs, otherwise, every GPU will have its output
which is :math:`Y_i = XA_i`, defaults to False which is :math:`Y_i = XA_i`, defaults to False
seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False which is preserved for kernel fusion, defaults to False
weight_initializer (`typing.Callable`): weight_initializer (`typing.Callable`):

View File

@ -7,6 +7,7 @@ from typing import Callable, List, Optional, Tuple, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor from torch import Tensor
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
@ -24,17 +25,17 @@ from colossalai.tensor.d_tensor.api import (
) )
from ._operation import ( from ._operation import (
gather_forward_split_backward, linear_gather_forward_reducescatter_backward,
linear_reducescatter_forward_gather_backward,
linear_with_async_comm, linear_with_async_comm,
matmul_gather_forward_reducescatter_backward, matmul_gather_forward_reducescatter_backward,
matmul_with_async_comm, matmul_with_async_comm,
reduce_backward,
reduce_forward, reduce_forward,
reducescatter_forward_gather_backward, reducescatter_forward_gather_backward,
split_forward_gather_backward, split_forward_gather_backward,
) )
from .parallel_module import ParallelModule from .parallel_module import ParallelModule
from .utils import create_randomizer_with_offset from .utils import create_randomizer_with_offset, is_share_sp_tp
__all__ = ["FusedLinear1D_Col", "FusedLinear1D_Row", "GPT2FusedLinearConv1D_Col", "GPT2FusedLinearConv1D_Row"] __all__ = ["FusedLinear1D_Col", "FusedLinear1D_Row", "GPT2FusedLinearConv1D_Col", "GPT2FusedLinearConv1D_Row"]
@ -44,21 +45,25 @@ __all__ = ["FusedLinear1D_Col", "FusedLinear1D_Row", "GPT2FusedLinearConv1D_Col"
def split_fused_qkv_in_gpt2_style( def split_fused_qkv_in_gpt2_style(
qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup, is_transposed: bool = False qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup, is_transposed: bool = False
): ):
""" """
The fused qkv tensor looks like [Q1, Q2, K1, K2, V1, V2]; this function splits it into [Q1, K1, V1] and [Q2, K2, V2]. The fused qkv tensor looks like [Q1, Q2, K1, K2, V1, V2]; this function splits it into [Q1, K1, V1] and [Q2, K2, V2].
Args: Args:
qkv (torch.Tensor): The fused qkv tensor. qkv (torch.Tensor): The fused qkv tensor.
n_fused (int): The number items fused together, defaults to 3 (query, key and value). split_sizes (List[int]): The sizes of the split tensor.
process_group (ProcessGroup): The process group for distributed communication. process_group (ProcessGroup): The process group for distributed communication.
is_transposed (bool): Generally the tensor has shape (out_features, in_features). Set this to True if the tensor is in the shape (in_features, out_features). is_transposed (bool): Generally the tensor has shape (out_features, in_features). Set this to True if the tensor is in the shape (in_features, out_features).
""" """
# get the number of slice for the fused qkv # get the number of slice for the fused qkv
rank = dist.get_rank(group=process_group) rank = dist.get_rank(group=process_group)
world_size = dist.get_world_size(group=process_group) world_size = dist.get_world_size(group=process_group)
order = torch.arange(world_size * n_fused) order = torch.arange(world_size * len(split_sizes))
new_split_sizes = []
for sz in split_sizes:
assert sz % world_size == 0, f"size {sz} is not divisible by world_size {world_size}"
new_split_sizes.extend([sz // world_size] * world_size)
# split the fused qkv # split the fused qkv
# from # from
@ -66,9 +71,9 @@ def split_fused_qkv_in_gpt2_style(
# to # to
# [Q1, Q2, K1, K2, V1, V2] # [Q1, Q2, K1, K2, V1, V2]
if is_transposed: if is_transposed:
weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=-1) weight_chunks = torch.split(qkv, new_split_sizes, dim=-1)
else: else:
weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=0) weight_chunks = torch.split(qkv, new_split_sizes, dim=0)
# rearrange the slice into the final order # rearrange the slice into the final order
# from # from
@ -85,18 +90,23 @@ def split_fused_qkv_in_gpt2_style(
def gather_fused_qkv_in_gpt2_style( def gather_fused_qkv_in_gpt2_style(
qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup, is_transposed: bool = False qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup, is_transposed: bool = False
): ):
""" """
The split qkv tensors look like [Q1, K1, V1] and [Q2, K2, V2]; this function gathers them into [Q1, Q2, K1, K2, V1, V2]. The split qkv tensors look like [Q1, K1, V1] and [Q2, K2, V2]; this function gathers them into [Q1, Q2, K1, K2, V1, V2].
Args: Args:
qkv (torch.Tensor): The fused qkv tensor. qkv (torch.Tensor): The fused qkv tensor.
n_fused (int): The number items fused together, defaults to 3 (query, key and value). split_sizes (List[int]): The sizes of the split tensor.
process_group (ProcessGroup): The process group for distributed communication. process_group (ProcessGroup): The process group for distributed communication.
is_transposed (bool): Generally the tensor has shape (out_features, in_features). Set this to True if the tensor is in the shape (in_features, out_features). is_transposed (bool): Generally the tensor has shape (out_features, in_features). Set this to True if the tensor is in the shape (in_features, out_features).
""" """
world_size = dist.get_world_size(group=process_group) world_size = dist.get_world_size(group=process_group)
new_split_sizes = []
for sz in split_sizes:
assert sz % world_size == 0, f"size {sz} is not divisible by world_size {world_size}"
new_split_sizes.append(sz // world_size)
new_split_sizes = new_split_sizes * world_size
# gather the tensors # gather the tensors
# from # from
@ -121,13 +131,13 @@ def gather_fused_qkv_in_gpt2_style(
# to # to
# [Q1, Q2, K1, K2, V1, V2] # [Q1, Q2, K1, K2, V1, V2]
if is_transposed: if is_transposed:
weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=-1) weight_chunks = torch.split(gather_weight, new_split_sizes, dim=-1)
else: else:
weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=0) weight_chunks = torch.split(gather_weight, new_split_sizes, dim=0)
reordered_chunk_list = [] reordered_chunk_list = []
for i in range(n_fused): for i in range(len(split_sizes)):
reordered_chunk_list.extend(weight_chunks[i::n_fused]) reordered_chunk_list.extend(weight_chunks[i :: len(split_sizes)])
if is_transposed: if is_transposed:
reordered_gather_weight = torch.cat(reordered_chunk_list, dim=-1) reordered_gather_weight = torch.cat(reordered_chunk_list, dim=-1)
@ -136,6 +146,42 @@ def gather_fused_qkv_in_gpt2_style(
return reordered_gather_weight return reordered_gather_weight
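To make the reordering in the two helpers above concrete, here is a tiny self-contained illustration (not the library code, no process group needed) of which columns one rank keeps for world_size = 2 and split_sizes = [4, 2, 2]:

```python
import torch

def local_fused_qkv(qkv: torch.Tensor, split_sizes, world_size: int, rank: int) -> torch.Tensor:
    # Chunk each fused component (Q, K, V, ...) into one piece per rank...
    per_rank_sizes = []
    for sz in split_sizes:
        assert sz % world_size == 0
        per_rank_sizes.extend([sz // world_size] * world_size)
    chunks = torch.split(qkv, per_rank_sizes, dim=-1)
    # ...then keep every world_size-th piece starting at `rank`: Q_rank, K_rank, V_rank.
    return torch.cat(chunks[rank::world_size], dim=-1)

qkv = torch.arange(8.0).reshape(1, 8)         # columns [Q0 Q1 Q2 Q3 | K0 K1 | V0 V1]
print(local_fused_qkv(qkv, [4, 2, 2], 2, 0))  # tensor([[0., 1., 4., 6.]]) -> Q0 Q1, K0, V0
print(local_fused_qkv(qkv, [4, 2, 2], 2, 1))  # tensor([[2., 3., 5., 7.]]) -> Q2 Q3, K1, V1
```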
class _SplitForwardGatherBackwardFusedQKV(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup):
ctx.split_sizes = split_sizes
ctx.process_group = process_group
return split_fused_qkv_in_gpt2_style(qkv, split_sizes, process_group, is_transposed=True)
@staticmethod
def backward(ctx, grad_output):
grad_output = gather_fused_qkv_in_gpt2_style(
grad_output, ctx.split_sizes, ctx.process_group, is_transposed=True
)
return grad_output, None, None
def split_forward_gather_backward_fused_qkv(qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup):
return _SplitForwardGatherBackwardFusedQKV.apply(qkv, split_sizes, process_group)
class _GatherForwardSplitBackwardFusedQKV(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup):
ctx.split_sizes = split_sizes
ctx.process_group = process_group
return gather_fused_qkv_in_gpt2_style(qkv, split_sizes, process_group, is_transposed=True)
@staticmethod
def backward(ctx, grad_output):
grad_output = split_fused_qkv_in_gpt2_style(grad_output, ctx.split_sizes, ctx.process_group, is_transposed=True)
return grad_output, None, None
def gather_forward_split_backward_fused_qkv(qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup):
return _GatherForwardSplitBackwardFusedQKV.apply(qkv, split_sizes, process_group)
class GPT2FusedLinearConv1D_Col(ParallelModule): class GPT2FusedLinearConv1D_Col(ParallelModule):
r"""Linear layer with column parallelism. r"""Linear layer with column parallelism.
@ -145,10 +191,10 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
Args: Args:
in_features (int): size of each input sample. in_features (int): size of each input sample.
out_features (int): size of each output sample. out_features (int): size of each output sample.
split_sizes (List[int]): The sizes of the split tensor.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (`torch.dtype`): The dtype of parameters, defaults to None. dtype (`torch.dtype`): The dtype of parameters, defaults to None.
device (`torch.device`): The device of parameters, defaults to None. device (`torch.device`): The device of parameters, defaults to None.
n_fused (int): The number items fused, defaults to 3 (QKV).
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
seq_parallel_mode (str): If set to ``None``, it will not use sequence parallel, otherwise will use corresponding mode of sequence parallel, defaults to None. seq_parallel_mode (str): If set to ``None``, it will not use sequence parallel, otherwise will use corresponding mode of sequence parallel, defaults to None.
gather_output (bool, optional): If true, call all-gather on output and make Y available gather_output (bool, optional): If true, call all-gather on output and make Y available
@ -169,16 +215,14 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
self, self,
in_features: int, in_features: int,
out_features: int, out_features: int,
split_sizes: List[int],
bias: bool = True, bias: bool = True,
dtype: torch.dtype = None, dtype: torch.dtype = None,
device: torch.device = None, device: torch.device = None,
process_group: ProcessGroup = None, process_group: ProcessGroup = None,
async_communication: bool = False,
gather_output: bool = False, gather_output: bool = False,
seq_parallel_mode: str = None, seq_parallel_mode: str = None,
overlap: bool = False,
skip_bias_add: bool = False, skip_bias_add: bool = False,
n_fused: int = 3,
weight: Optional[Parameter] = None, weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None, bias_: Optional[Parameter] = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
@ -192,14 +236,16 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
self.out_features = out_features self.out_features = out_features
self.gather_output = gather_output self.gather_output = gather_output
self.seq_parallel_mode = seq_parallel_mode self.seq_parallel_mode = seq_parallel_mode
self.overlap = overlap
self.skip_bias_add = skip_bias_add self.skip_bias_add = skip_bias_add
self.device = device self.device = device
self.n_fused = n_fused self.split_sizes = split_sizes
self.process_group = process_group self.process_group = process_group
self.async_communication = async_communication
self.fp8_communication = fp8_communication self.fp8_communication = fp8_communication
assert (
sum(split_sizes) == out_features
), f"The sum of split_sizes({sum(split_sizes)}) should be equal to out_features({out_features})."
if skip_bias_add and not bias: if skip_bias_add and not bias:
raise ValueError("cannot skip bias addition if bias is None") raise ValueError("cannot skip bias addition if bias is None")
@ -223,10 +269,10 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
self.weight = weight self.weight = weight
def shard_fn(tensor): def shard_fn(tensor):
return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True) return split_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, True)
def gather_fn(tensor): def gather_fn(tensor):
return gather_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True) return gather_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, True)
if not is_customized_distributed_tensor(self.weight): if not is_customized_distributed_tensor(self.weight):
with torch.no_grad(): with torch.no_grad():
@ -252,7 +298,11 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
@staticmethod @staticmethod
def from_native_module( def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs module: nn.Module,
process_group: Union[ProcessGroup, List[ProcessGroup]],
split_sizes: List[int],
*args,
**kwargs,
) -> ParallelModule: ) -> ParallelModule:
r""" r"""
Convert a huggingface layer `Conv1D` in gpt2 to a parallelized linear layer. Convert a huggingface layer `Conv1D` in gpt2 to a parallelized linear layer.
@ -260,7 +310,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
Args: Args:
module (`nn.Linear`): The module to be converted. module (`nn.Linear`): The module to be converted.
process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication. process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication.
n_fused (int): The number of layers to be fused. In GPT2, Q,K,V are fused in one weight. split_sizes (List[int]): The sizes of the split tensor. In GPT2, Q,K,V are fused in one weight.
""" """
LazyInitContext.materialize(module) LazyInitContext.materialize(module)
# get the attributes # get the attributes
@ -291,6 +341,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
process_group=process_group, process_group=process_group,
weight=module.weight, weight=module.weight,
bias_=module.bias, bias_=module.bias,
split_sizes=split_sizes,
*args, *args,
**kwargs, **kwargs,
) )
@ -313,7 +364,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
# Matrix multiply. # Matrix multiply.
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
if self.seq_parallel_mode == "split_gather": if is_share_sp_tp(self.seq_parallel_mode):
input_parallel = input_ input_parallel = input_
output_parallel = matmul_gather_forward_reducescatter_backward( output_parallel = matmul_gather_forward_reducescatter_backward(
input_parallel, input_parallel,
@ -322,31 +373,18 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
self.process_group, self.process_group,
True, True,
1, 1,
self.overlap, ring=self.seq_parallel_mode == "ring",
fp8_communication=self.fp8_communication,
)
elif self.seq_parallel_mode == "ring":
input_parallel = input_
output_parallel = matmul_gather_forward_reducescatter_backward(
input_parallel,
self.weight,
bias,
self.process_group,
True,
1,
self.overlap,
True,
fp8_communication=self.fp8_communication, fp8_communication=self.fp8_communication,
) )
elif self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn": elif self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn":
# Set up backprop all-reduce. # Set up backprop all-reduce.
input_parallel = reduce_backward(input_, self.process_group) input_parallel = input_
output_parallel = matmul_with_async_comm( output_parallel = matmul_with_async_comm(
input_parallel, input_parallel,
self.weight, self.weight,
bias, bias,
self.process_group, self.process_group,
self.async_communication, True,
fp8_communication=self.fp8_communication, fp8_communication=self.fp8_communication,
) )
else: else:
@ -354,9 +392,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
if self.gather_output: if self.gather_output:
# All-gather across the partitions. # All-gather across the partitions.
output = gather_forward_split_backward( output = gather_forward_split_backward_fused_qkv(output_parallel, self.split_sizes, self.process_group)
output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
)
else: else:
output = output_parallel output = output_parallel
@ -565,7 +601,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
if self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn": if self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn":
output_parallel = torch.matmul(input_, self.weight) output_parallel = torch.matmul(input_, self.weight)
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication) output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
elif self.seq_parallel_mode == "split_gather": elif is_share_sp_tp(self.seq_parallel_mode):
output_parallel = torch.matmul(input_, self.weight) output_parallel = torch.matmul(input_, self.weight)
output = reducescatter_forward_gather_backward( output = reducescatter_forward_gather_backward(
output_parallel, output_parallel,
@ -573,13 +609,6 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
1, 1,
self.fp8_communication, self.fp8_communication,
) )
elif self.seq_parallel_mode == "ring":
output_parallel = torch.matmul(input_, self.weight)
output = reducescatter_forward_gather_backward(
output_parallel,
self.process_group,
1,
)
else: else:
raise NotImplementedError(f"seq_parallel_mode={self.seq_parallel_mode} is not supported!") raise NotImplementedError(f"seq_parallel_mode={self.seq_parallel_mode} is not supported!")
@ -605,10 +634,10 @@ class FusedLinear1D_Col(ParallelModule):
Args: Args:
in_features (int): size of each input sample. in_features (int): size of each input sample.
out_features (int): size of each output sample. out_features (int): size of each output sample.
split_sizes (List[int]): The sizes of the split tensor.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (`torch.dtype`): The dtype of parameters, defaults to None. dtype (`torch.dtype`): The dtype of parameters, defaults to None.
device (`torch.device`): The device of parameters, defaults to None. device (`torch.device`): The device of parameters, defaults to None.
n_fused (int): The number items fused, defaults to 3 (QKV).
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
gather_output (bool, optional): If true, call all-gather on output and make Y available gather_output (bool, optional): If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output to all GPUs, otherwise, every GPU will have its output
@ -628,14 +657,15 @@ class FusedLinear1D_Col(ParallelModule):
self, self,
in_features: int, in_features: int,
out_features: int, out_features: int,
split_sizes: List[int],
bias: bool = True, bias: bool = True,
dtype: torch.dtype = None, dtype: torch.dtype = None,
device: torch.device = None, device: torch.device = None,
process_group: ProcessGroup = None, process_group: ProcessGroup = None,
async_communication: bool = False,
gather_output: bool = False, gather_output: bool = False,
seq_parallel_mode: str = None,
seq_parallel_dim: int = 1,
skip_bias_add: bool = False, skip_bias_add: bool = False,
n_fused: int = 3,
weight: Optional[Parameter] = None, weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None, bias_: Optional[Parameter] = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
@ -647,13 +677,18 @@ class FusedLinear1D_Col(ParallelModule):
self.in_features = in_features self.in_features = in_features
self.out_features = out_features self.out_features = out_features
self.gather_output = gather_output self.gather_output = gather_output
self.seq_parallel_mode = seq_parallel_mode
self.seq_parallel_dim = seq_parallel_dim
self.skip_bias_add = skip_bias_add self.skip_bias_add = skip_bias_add
self.device = device self.device = device
self.n_fused = n_fused self.split_sizes = split_sizes
self.process_group = process_group self.process_group = process_group
self.async_communication = async_communication
self.fp8_communication = fp8_communication self.fp8_communication = fp8_communication
assert (
sum(split_sizes) == out_features
), f"The sum of split_sizes({sum(split_sizes)}) should be equal to out_features({out_features})."
if skip_bias_add and not bias: if skip_bias_add and not bias:
raise ValueError("cannot skip bias addition if bias is None") raise ValueError("cannot skip bias addition if bias is None")
@ -677,10 +712,10 @@ class FusedLinear1D_Col(ParallelModule):
self.weight = weight self.weight = weight
def shard_fn(tensor): def shard_fn(tensor):
return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False) return split_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, False)
def gather_fn(tensor): def gather_fn(tensor):
return gather_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False) return gather_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, False)
if not is_customized_distributed_tensor(self.weight): if not is_customized_distributed_tensor(self.weight):
with torch.no_grad(): with torch.no_grad():
@ -706,7 +741,11 @@ class FusedLinear1D_Col(ParallelModule):
@staticmethod @staticmethod
def from_native_module( def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int, *args, **kwargs module: nn.Module,
process_group: Union[ProcessGroup, List[ProcessGroup]],
split_sizes: List[int],
*args,
**kwargs,
) -> ParallelModule: ) -> ParallelModule:
r""" r"""
Convert a fused `torch.nn.linear` layer to a parallelized linear layer. Convert a fused `torch.nn.linear` layer to a parallelized linear layer.
@ -714,7 +753,7 @@ class FusedLinear1D_Col(ParallelModule):
Args: Args:
module (`nn.Linear`): The module to be converted. module (`nn.Linear`): The module to be converted.
process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication. process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication.
n_fused (int): The number of layers to be fused. In common, Q,K,V are fused in one weight. split_sizes (List[int]): The sizes of the split tensor. In common, Q,K,V are fused in one weight.
""" """
LazyInitContext.materialize(module) LazyInitContext.materialize(module)
@ -737,25 +776,11 @@ class FusedLinear1D_Col(ParallelModule):
process_group=process_group, process_group=process_group,
weight=module.weight, weight=module.weight,
bias_=module.bias, bias_=module.bias,
n_fused=n_fused, split_sizes=split_sizes,
*args, *args,
**kwargs, **kwargs,
) )
# # TODO: copy the sharded weights
# with torch.no_grad():
# sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data,
# n_fused=n_fused,
# process_group=process_group,
# is_transposed=False)
# linear_1d.weight.data.copy_(sharded_weight.data)
# if bias:
# sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data,
# n_fused=n_fused,
# process_group=process_group,
# is_transposed=False)
# linear_1d.bias.data.copy_(sharded_bias.data)
return linear_1d return linear_1d
def reset_parameters(self, weight_initializer, bias_initializer) -> None: def reset_parameters(self, weight_initializer, bias_initializer) -> None:
@ -772,19 +797,29 @@ class FusedLinear1D_Col(ParallelModule):
input_.shape, self.weight.shape, self.weight.shape[-1] input_.shape, self.weight.shape, self.weight.shape[-1]
) )
# Set up backprop all-reduce. # Set up backprop all-reduce.
# input_parallel = reduce_backward(input_, self.process_group)
input_parallel = input_ input_parallel = input_
# Matrix multiply. # Matrix multiply.
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) if is_share_sp_tp(self.seq_parallel_mode):
output_parallel = linear_gather_forward_reducescatter_backward(
input_parallel,
self.weight,
bias,
self.process_group,
True,
self.seq_parallel_dim,
ring=self.seq_parallel_mode == "ring",
)
else:
output_parallel = linear_with_async_comm(
input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication
)
if self.gather_output: if self.gather_output:
# All-gather across the partitions. # All-gather across the partitions.
output = gather_forward_split_backward( output = gather_forward_split_backward_fused_qkv(output_parallel, self.split_sizes, self.process_group)
output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
)
else: else:
output = output_parallel output = output_parallel
@ -792,3 +827,196 @@ class FusedLinear1D_Col(ParallelModule):
return output, self.bias return output, self.bias
else: else:
return output return output
class FusedLinear1D_Row(ParallelModule):
r"""Linear layer with row parallelism
Args:
in_features (int): size of each input sample.
out_features (int): size of each output sample.
split_sizes (List[int]): The sizes of the fused components along the input dimension; their sum must equal in_features.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
parallel_input (bool): If set to ``True``, it's assumed that the input is already split across ranks, defaults to ``True``.
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
seq_parallel_mode (`str`): The sequence parallelism mode; sequence parallelism is used when `seq_parallel_mode` is not None. Defaults to None.
seq_parallel_dim (`int`): The dimension along which sequence parallelism splits and gathers the sequence.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (:class:`typing.Callable`, optional):
The initializer of weight, defaults to kaiming uniform initializer.
bias_initializer (:class:`typing.Callable`, optional):
The initializer of bias, defaults to xavier uniform initializer.
For more details about ``initializer``, please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(
self,
in_features: int,
out_features: int,
split_sizes: List[int],
bias: bool = True,
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
seq_parallel_mode: str = None,
seq_parallel_dim: int = 1,
parallel_input: bool = True,
skip_bias_add: bool = False,
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
fp8_communication: bool = False,
):
super().__init__()
# Keep input parameters
self.in_features = in_features
self.out_features = out_features
self.split_sizes = split_sizes
self.parallel_input = parallel_input
self.skip_bias_add = skip_bias_add
self.process_group = process_group
self.seq_parallel_mode = seq_parallel_mode
self.seq_parallel_dim = seq_parallel_dim
self.num_partitions = dist.get_world_size(self.process_group)
self.fp8_communication = fp8_communication
assert (
sum(split_sizes) == in_features
), f"The sum of split_sizes({sum(split_sizes)}) should be equal to in_features({in_features})."
if skip_bias_add and not bias:
raise ValueError("cannot skip bias addition if bias is None")
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
# sanity check
if weight is not None:
assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None"
else:
assert bias_ is None, "bias_ must be None if weight is None"
# Parameters.
if weight is None:
# Initialize weight.
factory_kwargs = {"device": device, "dtype": dtype}
self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
else:
weight.data = weight.data.to(device=device, dtype=dtype)
self.weight = weight
def shard_fn(tensor):
return split_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, True)
def gather_fn(tensor):
return gather_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, True)
if not is_customized_distributed_tensor(self.weight):
with torch.no_grad():
sharded_weight = distribute_tensor_with_customization(self.weight.data, shard_fn, gather_fn)
customized_distributed_tensor_to_existing_param(sharded_weight, self.weight)
if bias:
if bias_ is None:
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
else:
bias_.data = bias_.data.to(device=device, dtype=dtype)
self.bias = bias_
else:
self.bias = None
if weight is None:
with self.randomizer.fork_rng(enable_cpu=True):
self.reset_parameters(weight_initializer, bias_initializer)
@staticmethod
def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], split_sizes: List[int], **kwargs
) -> ParallelModule:
r"""
Convert a native PyTorch linear layer to a parallelized linear layer.
"""
LazyInitContext.materialize(module)
# get the attributes
in_features = module.in_features
out_features = module.out_features
bias = module.bias is not None
device = module.weight.device
# ensure only one process group is passed
if isinstance(process_group, (list, tuple)):
assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
process_group = process_group[0]
linear_1d = FusedLinear1D_Row(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
process_group=process_group,
weight=module.weight,
bias_=module.bias,
split_sizes=split_sizes,
**kwargs,
)
return linear_1d
@torch.no_grad()
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.out_features
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
if self.process_group is None:
src_rank = 0
else:
src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0)
origin_device = self.bias.device
bias = self.bias.cuda()
dist.broadcast(bias, src=src_rank, group=self.process_group)
bias = bias.to(origin_device)
self.bias.copy_(bias)
def forward(self, input_: Tensor) -> Tensor:
# Set up backprop all-reduce.
if self.parallel_input:
assert (
input_.shape[-1] == self.weight.shape[-1]
), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
input_.shape, self.weight.shape, self.weight.shape[-1]
)
input_ = input_
else:
assert (
divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1]
), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions
)
input_ = split_forward_gather_backward_fused_qkv(input_, self.split_sizes, self.process_group)
if is_share_sp_tp(self.seq_parallel_mode):
output = linear_reducescatter_forward_gather_backward(
input_,
self.weight,
process_group=self.process_group,
dim=self.seq_parallel_dim,
ring=self.seq_parallel_mode == "ring",
)
else:
output_parallel = F.linear(input_, self.weight)
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
if not self.skip_bias_add:
if self.bias is not None:
output = output + self.bias
return output
else:
return output, self.bias
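One practical consequence of replacing `n_fused` with `split_sizes` in these fused layers: unequal component widths (e.g. grouped-query attention, where K/V are narrower than Q) can now be expressed. A small shape-level sketch, with purely illustrative sizes:

```python
def local_split_sizes(split_sizes, world_size):
    """Per-rank widths of each fused component under tensor parallelism."""
    assert all(sz % world_size == 0 for sz in split_sizes), "each component must be divisible by world_size"
    return [sz // world_size for sz in split_sizes]

# A fused GQA qkv projection with q width 4096 and kv width 1024, sharded over 4 ranks:
print(local_split_sizes([4096, 1024, 1024], world_size=4))  # [1024, 256, 256]
```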

View File

@ -295,8 +295,8 @@ def split_batch_zigzag(
batch: Union[torch.Tensor, List[torch.Tensor]], sp_group: ProcessGroup, seq_dim: int = 1, is_label: bool = False batch: Union[torch.Tensor, List[torch.Tensor]], sp_group: ProcessGroup, seq_dim: int = 1, is_label: bool = False
) -> Union[torch.Tensor, List[torch.Tensor]]: ) -> Union[torch.Tensor, List[torch.Tensor]]:
""" """
Split the input along the sequence dimension for Ring Attention. Naively spliting the attention mask Split the input sequence batch. Naively splitting the attention mask in the causal setting
in the causal setting will result in the preceding ranks having much less workload. will result in the preceding ranks having much less workload.
We split after "folding" the 2D attention mask in half (https://github.com/zhuzilin/ring-flash-attention/issues/2). We split after "folding" the 2D attention mask in half (https://github.com/zhuzilin/ring-flash-attention/issues/2).
For example, for sp_size = 4 and seq_len = 8, we get | s0, s7 | s1, s6 | s2, s5 | s3, s4 |. For example, for sp_size = 4 and seq_len = 8, we get | s0, s7 | s1, s6 | s2, s5 | s3, s4 |.
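The zigzag assignment in the example above can be written down directly; a standalone sketch (no process group, just index bookkeeping):

```python
import torch

def zigzag_local_indices(seq_len: int, sp_size: int, rank: int) -> torch.Tensor:
    # Fold the sequence in half: each rank keeps chunk `rank` and chunk `2*sp_size - 1 - rank`.
    assert seq_len % (2 * sp_size) == 0
    chunks = torch.arange(seq_len).chunk(2 * sp_size)
    return torch.cat([chunks[rank], chunks[2 * sp_size - 1 - rank]])

# seq_len = 8, sp_size = 4 reproduces | s0, s7 | s1, s6 | s2, s5 | s3, s4 |:
print([zigzag_local_indices(8, 4, r).tolist() for r in range(4)])
# [[0, 7], [1, 6], [2, 5], [3, 4]]
```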
@ -346,40 +346,42 @@ def split_varlen_zigzag(
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
sp_group: ProcessGroup, sp_group: ProcessGroup,
max_seqlen: int = 0, max_seqlen: int = 0,
is_2d: bool = False, is_batched_seq: bool = False,
is_label: bool = False, is_label: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]: ) -> Union[List[torch.Tensor], torch.Tensor]:
"""Split each sequence in a batch of packed sequences in a zigzag fashion. """Split a packed seq/batch of padded sequences in a Zigzag fashion.
For each tensor in batch, return packed sequences if is_2d is False; Different from split_batch_zigzag, inputs here have variable sequence lengths.
else return a padded batch of sequences.
Args: Args:
batch (List[torch.Tensor]): Packed sequences of shape (B * Sq, ...), or (B, Sq, ...) if is_2d. batch (List[torch.Tensor]): Packed sequences of shape (T, ...), or (B, Sq, ...) if is_batched_seq,
where T is the total number of tokens.
cu_seqlens (torch.Tensor): Cumulative sequence lengths of shape (B + 1) before splitting. cu_seqlens (torch.Tensor): Cumulative sequence lengths of shape (B + 1) before splitting.
sp_group (ProcessGroup): The process group for sequence parallelism. sp_group (ProcessGroup): The process group for sequence parallelism.
max_seqlen (int): The maximum sequence length in the batch before splitting. max_seqlen (int): The maximum sequence length in the batch before splitting.
is_2d (bool): If True, then input has batch size and sequence length split into two dimensions. is_batched_seq (bool): If True, then the input is a batch of sequences padded to the same len.
is_label (bool): If True, mask out the first token in each sequence (<Start of Sentence>). is_label (bool): If True, mask out the first token in each sequence (<Start of Sentence>).
Returns: Returns:
batch (List[torch.Tensor]): Packed sequences of shape (B * max_seqlen // sp_size) batch (List[torch.Tensor]): Packed sequences of shape (T, ...)
or (B, max_seqlen // sp_size, ...) if is_2d or (B, max_seqlen // sp_size, ...) if is_batched_seq
""" """
sp_size = dist.get_world_size(sp_group) sp_size = dist.get_world_size(sp_group)
sp_rank = dist.get_rank(sp_group) sp_rank = dist.get_rank(sp_group)
if sp_size == 1: if sp_size == 1:
return batch return batch
if is_2d: if is_batched_seq:
assert max_seqlen > 0, "max_seqlen must be provided for 2D input" assert max_seqlen > 0, "max_seqlen must be provided for 2D input"
if isinstance(batch, torch.Tensor): if isinstance(batch, torch.Tensor):
batch = [batch] batch = [batch]
# seq: (B, Sq, h, n)
# seq = seq[:, :rank * (seqlen // sp_size), ...]
for i, packed_seq in enumerate(batch): for i, packed_seq in enumerate(batch):
device = packed_seq.device device = packed_seq.device
dtype = packed_seq.dtype dtype = packed_seq.dtype
if is_2d: if is_batched_seq:
assert max_seqlen % (sp_size * 2) == 0 assert max_seqlen % (sp_size * 2) == 0
# Recreate a padded tensor with the new max seqlen # Recreate a padded tensor with the new max seqlen
shape = (packed_seq.shape[0], max_seqlen // sp_size, *packed_seq.shape[2:]) shape = (packed_seq.shape[0], max_seqlen // sp_size, *packed_seq.shape[2:])
@ -398,7 +400,7 @@ def split_varlen_zigzag(
seqlen % (2 * sp_size) == 0 seqlen % (2 * sp_size) == 0
), f"batch {i} seq {j}'s length ({seqlen}) must be divisible by 2 * sp_size = {2 * sp_size} for splitting" ), f"batch {i} seq {j}'s length ({seqlen}) must be divisible by 2 * sp_size = {2 * sp_size} for splitting"
if is_2d: if is_batched_seq:
seq = packed_seq[j][:seqlen] seq = packed_seq[j][:seqlen]
if is_label: if is_label:
# Shift one position to the right for next token prediction # Shift one position to the right for next token prediction
@ -415,7 +417,7 @@ def split_varlen_zigzag(
seq = seq.chunk(sp_size * 2) seq = seq.chunk(sp_size * 2)
local_seq.extend([seq[sp_rank], seq[2 * sp_size - 1 - sp_rank]]) local_seq.extend([seq[sp_rank], seq[2 * sp_size - 1 - sp_rank]])
if is_2d: if is_batched_seq:
batch[i] = local_seq.contiguous() batch[i] = local_seq.contiguous()
else: else:
batch[i] = torch.cat(local_seq, dim=0) batch[i] = torch.cat(local_seq, dim=0)
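And the variable-length counterpart, as a toy single-process illustration of the loop above: each packed sequence is chunked into 2 * sp_size pieces and a rank keeps pieces `rank` and `2*sp_size - 1 - rank`.

```python
import torch

def varlen_zigzag_local(packed: torch.Tensor, cu_seqlens: torch.Tensor, sp_size: int, rank: int):
    local = []
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        seq = packed[start:end]
        assert (end - start) % (2 * sp_size) == 0
        chunks = seq.chunk(2 * sp_size)
        local.extend([chunks[rank], chunks[2 * sp_size - 1 - rank]])
    return torch.cat(local)

packed = torch.arange(12)             # two sequences of length 4 and 8, packed
cu_seqlens = torch.tensor([0, 4, 12])
print(varlen_zigzag_local(packed, cu_seqlens, sp_size=2, rank=0).tolist())  # [0, 3, 4, 5, 10, 11]
print(varlen_zigzag_local(packed, cu_seqlens, sp_size=2, rank=1).tolist())  # [1, 2, 6, 7, 8, 9]
```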

View File

@ -857,17 +857,17 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None)
dropout_p = self.attn_dropout.p if self.training else 0.0 dropout_p = self.attn_dropout.p if self.training else 0.0
sp_mode = shard_config.sequence_parallelism_mode sp_mode = shard_config.sequence_parallelism_mode
sp_group = shard_config.sequence_parallel_process_group
if sp_mode == "ring_attn": if sp_mode == "ring_attn":
attn_output = RingAttention.attention( attn_output = RingAttention.attention(
query, query,
key, key,
value, value,
sp_group, sp_axis=shard_config.sp_axis,
**attention_mask, **attention_mask,
dropout_p=dropout_p, dropout_p=dropout_p,
scale=scale, scale=scale,
inner_ring_size=shard_config.inner_ring_size, inner_ring_size=shard_config.inner_ring_size,
pg_mesh=shard_config.pg_mesh,
) )
else: else:
attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale) attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)
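The hunk above switches RingAttention from taking an explicit sp_group to resolving the group from a ProcessGroupMesh via sp_axis. A call sketch mirroring the updated test later in this diff (shapes, dtypes, and the 4-D mesh layout are illustrative assumptions):

```python
# Illustrative only: run inside an initialized distributed job whose world size
# equals the SP size; q/k/v stand for local zigzag shards of shape (B, nHeads, Sq_local, D).
import torch
import torch.distributed as dist

from colossalai.cluster import ProcessGroupMesh
from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention

sp_size = dist.get_world_size()
sp_axis = 2                                   # index of the SP dimension in the mesh below
pg_mesh = ProcessGroupMesh(1, 1, sp_size, 1)  # (dp, pp, sp, tp)

q, k, v = (torch.randn(2, 8, 1024, 128, device="cuda", dtype=torch.bfloat16) for _ in range(3))
attn_output = RingAttention.attention(
    q,
    k,
    v,
    sp_axis,
    AttnMaskType.CAUSAL,
    dropout_p=0.0,
    pg_mesh=pg_mesh,
)
```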

View File

@ -271,6 +271,7 @@ class LlamaPipelineForwards:
hidden_states: Optional[torch.FloatTensor] = None, hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None, stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None, shard_config: ShardConfig = None,
**kwargs,
): ):
r""" r"""
Args: Args:
@ -568,9 +569,10 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
query_states, query_states,
key_states, key_states,
value_states, value_states,
sp_group, sp_axis=shard_config.sp_axis,
**attention_mask, **attention_mask,
inner_ring_size=shard_config.inner_ring_size, inner_ring_size=shard_config.inner_ring_size,
pg_mesh=shard_config.pg_mesh,
) )
elif shard_config.enable_flash_attention: elif shard_config.enable_flash_attention:

View File

@ -73,7 +73,6 @@ class BertPolicy(Policy):
) )
sp_mode = "split_gather" sp_mode = "split_gather"
overlap = self.shard_config.enable_sequence_overlap
sp_partial_derived = sp_mode == "split_gather" sp_partial_derived = sp_mode == "split_gather"
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
@ -97,7 +96,6 @@ class BertPolicy(Policy):
target_module=col_nn.Linear1D_Col, target_module=col_nn.Linear1D_Col,
kwargs={ kwargs={
"seq_parallel_mode": sp_mode, "seq_parallel_mode": sp_mode,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication, "fp8_communication": self.shard_config.fp8_communication,
}, },
), ),
@ -106,7 +104,6 @@ class BertPolicy(Policy):
target_module=col_nn.Linear1D_Col, target_module=col_nn.Linear1D_Col,
kwargs={ kwargs={
"seq_parallel_mode": sp_mode, "seq_parallel_mode": sp_mode,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication, "fp8_communication": self.shard_config.fp8_communication,
}, },
), ),
@ -115,7 +112,6 @@ class BertPolicy(Policy):
target_module=col_nn.Linear1D_Col, target_module=col_nn.Linear1D_Col,
kwargs={ kwargs={
"seq_parallel_mode": sp_mode, "seq_parallel_mode": sp_mode,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication, "fp8_communication": self.shard_config.fp8_communication,
}, },
), ),
@ -140,7 +136,6 @@ class BertPolicy(Policy):
target_module=col_nn.Linear1D_Col, target_module=col_nn.Linear1D_Col,
kwargs={ kwargs={
"seq_parallel_mode": sp_mode, "seq_parallel_mode": sp_mode,
"overlap": overlap,
"skip_bias_add": self.enable_bias_gelu_fused, "skip_bias_add": self.enable_bias_gelu_fused,
"fp8_communication": self.shard_config.fp8_communication, "fp8_communication": self.shard_config.fp8_communication,
}, },

View File

@ -71,7 +71,7 @@ class BlipPolicy(Policy):
suffix="self_attn.qkv", suffix="self_attn.qkv",
target_module=col_nn.FusedLinear1D_Col, target_module=col_nn.FusedLinear1D_Col,
kwargs={ kwargs={
"n_fused": 3, "split_sizes": [self.model.config.vision_config.hidden_size] * 3,
"fp8_communication": self.shard_config.fp8_communication, "fp8_communication": self.shard_config.fp8_communication,
}, },
), ),
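The n_fused -> split_sizes change generalizes fused-QKV sharding to segments of unequal width. A brief sketch of the new argument, following the updated FusedLinear1D test later in this diff (the 8-in/80-out sizes are made up):

```python
# Illustrative only: requires a distributed environment initialized via colossalai.launch;
# process_group=None falls back to the default group, as in the test.
import torch.nn as nn

from colossalai.shardformer.layer import FusedLinear1D_Col

fused_qkv = nn.Linear(8, 80).cuda()    # concatenated [q | k | v] projection of widths 32/32/16
qkv_col = FusedLinear1D_Col.from_native_module(
    fused_qkv,
    process_group=None,
    gather_output=True,
    split_sizes=[32, 32, 16],          # per-segment output widths, replacing n_fused=3
)
```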

View File

@ -57,7 +57,6 @@ class BloomPolicy(Policy):
) )
sp_mode = "split_gather" sp_mode = "split_gather"
overlap = self.shard_config.enable_sequence_overlap
sp_partial_derived = sp_mode == "split_gather" sp_partial_derived = sp_mode == "split_gather"
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
@ -78,7 +77,6 @@ class BloomPolicy(Policy):
target_module=col_nn.Linear1D_Col, target_module=col_nn.Linear1D_Col,
kwargs={ kwargs={
"seq_parallel_mode": sp_mode, "seq_parallel_mode": sp_mode,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication, "fp8_communication": self.shard_config.fp8_communication,
}, },
), ),
@ -99,7 +97,6 @@ class BloomPolicy(Policy):
target_module=col_nn.Linear1D_Col, target_module=col_nn.Linear1D_Col,
kwargs={ kwargs={
"seq_parallel_mode": sp_mode, "seq_parallel_mode": sp_mode,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication, "fp8_communication": self.shard_config.fp8_communication,
}, },
), ),

View File

@ -67,7 +67,6 @@ class ChatGLMPolicy(Policy):
f"For ChatGLM2, sequence parallelism doesn't support mode {sp_mode} yet, will set to be split_gather" f"For ChatGLM2, sequence parallelism doesn't support mode {sp_mode} yet, will set to be split_gather"
) )
sp_mode = "split_gather" sp_mode = "split_gather"
overlap = self.shard_config.enable_sequence_overlap
sp_partial_derived = sp_mode in ["split_gather"] sp_partial_derived = sp_mode in ["split_gather"]
if sp_mode == "all_to_all": if sp_mode == "all_to_all":
@ -127,7 +126,6 @@ class ChatGLMPolicy(Policy):
kwargs={ kwargs={
"seq_parallel_mode": sp_mode, "seq_parallel_mode": sp_mode,
"seq_parallel_dim": 0, "seq_parallel_dim": 0,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication, "fp8_communication": self.shard_config.fp8_communication,
}, },
), ),

View File

@ -65,7 +65,6 @@ class GPT2Policy(Policy):
f"For GPT2, sequence parallelism currently does not support mode {sp_mode}, will set to be split_gather" f"For GPT2, sequence parallelism currently does not support mode {sp_mode}, will set to be split_gather"
) )
self.shard_config.sequence_parallelism_mode = sp_mode = "split_gather" self.shard_config.sequence_parallelism_mode = sp_mode = "split_gather"
overlap = self.shard_config.enable_sequence_overlap
sp_partial_derived = sp_mode in ["split_gather", "ring"] sp_partial_derived = sp_mode in ["split_gather", "ring"]
use_flash_attention = self.shard_config.enable_flash_attention use_flash_attention = self.shard_config.enable_flash_attention
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
@ -92,9 +91,8 @@ class GPT2Policy(Policy):
suffix="attn.c_attn", suffix="attn.c_attn",
target_module=col_nn.GPT2FusedLinearConv1D_Col, target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={ kwargs={
"n_fused": 3, "split_sizes": [self.model.config.hidden_size] * 3,
"seq_parallel_mode": sp_mode, "seq_parallel_mode": sp_mode,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication, "fp8_communication": self.shard_config.fp8_communication,
}, },
), ),
@ -107,9 +105,8 @@ class GPT2Policy(Policy):
suffix="mlp.c_fc", suffix="mlp.c_fc",
target_module=col_nn.GPT2FusedLinearConv1D_Col, target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={ kwargs={
"n_fused": 1, "split_sizes": [self.model.config.n_inner or 4 * self.model.config.hidden_size],
"seq_parallel_mode": sp_mode, "seq_parallel_mode": sp_mode,
"overlap": overlap,
"skip_bias_add": self.enable_bias_gelu_fused, "skip_bias_add": self.enable_bias_gelu_fused,
"fp8_communication": self.shard_config.fp8_communication, "fp8_communication": self.shard_config.fp8_communication,
}, },

View File

@ -51,7 +51,6 @@ class GPTJPolicy(Policy):
self.shard_config.enable_sequence_parallelism = False self.shard_config.enable_sequence_parallelism = False
warnings.warn("GPTJ doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") warnings.warn("GPTJ doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
assert ( assert (
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
@ -76,7 +75,6 @@ class GPTJPolicy(Policy):
suffix="attn.k_proj", suffix="attn.k_proj",
target_module=col_nn.Linear1D_Col, target_module=col_nn.Linear1D_Col,
kwargs={ kwargs={
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication, "fp8_communication": self.shard_config.fp8_communication,
}, },
), ),
@ -84,7 +82,6 @@ class GPTJPolicy(Policy):
suffix="attn.q_proj", suffix="attn.q_proj",
target_module=col_nn.Linear1D_Col, target_module=col_nn.Linear1D_Col,
kwargs={ kwargs={
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication, "fp8_communication": self.shard_config.fp8_communication,
}, },
), ),
@ -92,7 +89,6 @@ class GPTJPolicy(Policy):
suffix="attn.v_proj", suffix="attn.v_proj",
target_module=col_nn.Linear1D_Col, target_module=col_nn.Linear1D_Col,
kwargs={ kwargs={
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication, "fp8_communication": self.shard_config.fp8_communication,
}, },
), ),

View File

@ -42,7 +42,7 @@ class SamPolicy(Policy):
suffix="attn.qkv", suffix="attn.qkv",
target_module=col_nn.FusedLinear1D_Col, target_module=col_nn.FusedLinear1D_Col,
kwargs={ kwargs={
"n_fused": 3, "split_sizes": [self.model.config.vision_config.hidden_size] * 3,
"fp8_communication": self.shard_config.fp8_communication, "fp8_communication": self.shard_config.fp8_communication,
}, },
), ),

View File

@ -26,7 +26,6 @@ class ShardConfig:
enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False. enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False.
enable_jit_fused (bool, optional): Whether to switch on JIT fused operators. Defaults to False. enable_jit_fused (bool, optional): Whether to switch on JIT fused operators. Defaults to False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False. enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False.
enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.
gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None. gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None.
enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False. enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
fp8_communication (bool, optional): Whether to enable fp8 communication in model parallelism. Defaults to False. fp8_communication (bool, optional): Whether to enable fp8 communication in model parallelism. Defaults to False.
@ -44,13 +43,14 @@ class ShardConfig:
enable_jit_fused: bool = False enable_jit_fused: bool = False
enable_sequence_parallelism: bool = False enable_sequence_parallelism: bool = False
sequence_parallelism_mode: str = None sequence_parallelism_mode: str = None
enable_sequence_overlap: bool = False
parallel_output: bool = True parallel_output: bool = True
make_vocab_size_divisible_by: int = 64 make_vocab_size_divisible_by: int = 64
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
extra_kwargs: Dict[str, Any] = field(default_factory=dict) extra_kwargs: Dict[str, Any] = field(default_factory=dict)
# For ring attention # For ring attention
sp_axis: Optional[int] = None
pg_mesh: Optional[int] = None
inner_ring_size: Optional[int] = None inner_ring_size: Optional[int] = None
# for moe related # for moe related
moe_dp_group: Optional[ProcessGroup] = None moe_dp_group: Optional[ProcessGroup] = None
@ -84,24 +84,12 @@ class ShardConfig:
assert ( assert (
self.enable_tensor_parallelism self.enable_tensor_parallelism
), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is True" ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is True"
elif self.sequence_parallelism_mode in ["all_to_all"]:
# assert (
# not self.enable_tensor_parallelism
# ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is False"
if self.enable_sequence_overlap:
self.enable_sequence_overlap = False
warnings.warn(
f"The enable_sequence_overlap flag will be ignored in sequence parallelism mode {self.sequence_parallelism_mode}"
)
else: else:
if self.sequence_parallelism_mode: if self.sequence_parallelism_mode:
self.sequence_parallelism_mode = None self.sequence_parallelism_mode = None
warnings.warn( warnings.warn(
f"The sequence_parallelism_mode will be ignored when enable_sequence_parallelism is False" f"The sequence_parallelism_mode will be ignored when enable_sequence_parallelism is False"
) )
assert (
not self.enable_sequence_overlap
), f"enable_sequence_overlap can only be set to True when enable_sequence_parallelism is True"
# get the tensor parallel size # get the tensor parallel size
if not self.enable_tensor_parallelism: if not self.enable_tensor_parallelism:
@ -134,4 +122,3 @@ class ShardConfig:
# This can cause non-in-place param sharding when used without ZeRO. # This can cause non-in-place param sharding when used without ZeRO.
# It may also slow down training when seq len is small. Plz enable manually. # It may also slow down training when seq len is small. Plz enable manually.
# self.enable_sequence_parallelism = True # self.enable_sequence_parallelism = True
# self.enable_sequence_overlap = True
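A hedged sketch of how the two new ring-attention fields are meant to be filled in; in practice the hybrid parallel plugin wires this up, so constructing ShardConfig by hand below is only illustrative and the mesh sizes are assumptions:

```python
# Illustrative only: assumes torch.distributed is already initialized
# (e.g. via colossalai.launch) with dp * pp * sp * tp = world size ranks.
from colossalai.cluster import ProcessGroupMesh
from colossalai.shardformer import ShardConfig

dp_size, pp_size, sp_size, tp_size = 1, 1, 4, 1
pg_mesh = ProcessGroupMesh(dp_size, pp_size, sp_size, tp_size)

shard_config = ShardConfig(
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="ring_attn",
    sp_axis=2,            # position of the SP dimension inside pg_mesh
    pg_mesh=pg_mesh,
    inner_ring_size=2,    # optional double-ring split; must divide sp_size
)
```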

View File

@ -5,6 +5,7 @@ from .common import (
ensure_path_exists, ensure_path_exists,
free_storage, free_storage,
get_current_device, get_current_device,
get_non_persistent_buffers_set,
is_ddp_ignored, is_ddp_ignored,
set_seed, set_seed,
) )
@ -25,4 +26,5 @@ __all__ = [
"set_seed", "set_seed",
"get_current_device", "get_current_device",
"is_ddp_ignored", "is_ddp_ignored",
"get_non_persistent_buffers_set",
] ]

View File

@ -5,10 +5,11 @@ import os
import random import random
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from typing import Callable from typing import Callable, Optional, Set
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn
from colossalai.accelerator import get_accelerator from colossalai.accelerator import get_accelerator
@ -76,3 +77,34 @@ def set_seed(seed):
random.seed(seed) random.seed(seed)
np.random.seed(seed) np.random.seed(seed)
torch.manual_seed(seed) torch.manual_seed(seed)
def get_non_persistent_buffers_set(
module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
):
r"""
Collect the names of all non-persistent buffers of a module and its submodules, each prefixed with its dotted module path.
Args:
memo: a memo to store the set of modules already added to the result
prefix: a prefix that will be added to the name of the module
remove_duplicate: whether to remove the duplicated module instances in the result
or not
"""
if memo is None:
memo = set()
self_non_persistent_set = set()
if module not in memo:
if remove_duplicate:
memo.add(module)
self_non_persistent_set = set(
map(lambda key: prefix + ("." if prefix else "") + key, module._non_persistent_buffers_set)
)
for name, sub_module in module._modules.items():
if sub_module is None:
continue
submodule_prefix = prefix + ("." if prefix else "") + name
child_non_persistent_set = get_non_persistent_buffers_set(
sub_module, memo, submodule_prefix, remove_duplicate
)
self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
return self_non_persistent_set
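A small sanity-check sketch of the relocated helper (module and buffer names are made up):

```python
import torch
import torch.nn as nn

from colossalai.utils import get_non_persistent_buffers_set

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("scale", torch.ones(1))                      # persistent: kept in state_dict
        self.register_buffer("cache", torch.zeros(4), persistent=False)   # non-persistent: skipped

model = nn.Sequential(Block(), Block())
print(get_non_persistent_buffers_set(model))  # {'0.cache', '1.cache'}
```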

View File

@ -0,0 +1,64 @@
# a python safetensors serializer modified from https://github.com/huggingface/safetensors/blob/41bd1acf38ad28ac559522d40596c6c802f79453/safetensors/src/tensor.rs#L214
import json
from dataclasses import asdict, dataclass
from typing import Dict, List, Tuple
import torch
from safetensors.torch import _TYPES
try:
from tensornvme.async_file_io import AsyncFileWriter
except ModuleNotFoundError:
raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")
_TYPES_INV = {v: k for k, v in _TYPES.items()}
@dataclass
class TensorInfo:
dtype: str
shape: List[int]
data_offsets: Tuple[int, int]
@dataclass
class PreparedData:
n: int
header_bytes: bytes
offset: int
def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Tensor]]:
sorted_data = sorted(data.items(), key=lambda x: (x[1].dtype, x[0]))
tensors = []
metadata = {}
offset = 0
for name, tensor in sorted_data:
n = tensor.numel() * tensor.element_size()
tensor_info = TensorInfo(
dtype=_TYPES_INV[tensor.dtype], shape=list(tensor.shape), data_offsets=(offset, offset + n)
)
offset += n
metadata[name] = asdict(tensor_info)
tensors.append(tensor)
metadata_buf = json.dumps(metadata).encode("utf-8")
extra = (8 - len(metadata_buf) % 8) % 8
metadata_buf += b" " * extra
n = len(metadata_buf)
return PreparedData(n=n, header_bytes=metadata_buf, offset=offset), tensors
def save(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:
prepared_data, tensors = prepare(state_dict)
n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
f_writer.write(n.to_bytes(8, byteorder="little"))
f_writer.write(header_bytes)
for tensor in tensors:
f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset)
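A quick illustration of what prepare() emits, using the helper defined above (tensor names and shapes are arbitrary):

```python
import json

import torch

state_dict = {"b": torch.zeros(2, 3), "a": torch.ones(4)}
prepared, tensors = prepare(state_dict)

assert prepared.n % 8 == 0            # header is padded to an 8-byte boundary
print(json.loads(prepared.header_bytes))
# -> {'a': {'dtype': 'F32', 'shape': [4], 'data_offsets': [0, 16]},
#     'b': {'dtype': 'F32', 'shape': [2, 3], 'data_offsets': [16, 40]}}
```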

View File

@ -35,7 +35,7 @@ from colossalai.tensor.padded_tensor import (
to_unpadded_tensor, to_unpadded_tensor,
) )
from colossalai.tensor.param_op_hook import ColoParamOpHookManager from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import _cast_float, free_storage, is_ddp_ignored from colossalai.utils import _cast_float, free_storage, get_non_persistent_buffers_set, is_ddp_ignored
from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
from .gemini_hook import GeminiZeROHook from .gemini_hook import GeminiZeROHook
@ -187,7 +187,7 @@ class GeminiDDP(ModelWrapper):
pin_memory=pin_memory, pin_memory=pin_memory,
) )
super().__init__(module) super().__init__(module)
self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module) self._non_persistent_buffers_set = get_non_persistent_buffers_set(module)
self._cast_buffers() self._cast_buffers()
# register grad hook # register grad hook
@ -257,36 +257,6 @@ class GeminiDDP(ModelWrapper):
for p in params_to_ignore: for p in params_to_ignore:
p._ddp_to_ignore = True p._ddp_to_ignore = True
def _get_non_persistent_buffers_set(
self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
):
r"""
Args:
memo: a memo to store the set of modules already added to the result
prefix: a prefix that will be added to the name of the module
remove_duplicate: whether to remove the duplicated module instances in the result
or not
"""
if memo is None:
memo = set()
self_non_persistent_set = set()
if module not in memo:
if remove_duplicate:
memo.add(module)
self_non_persistent_set = set(
map(lambda key: prefix + ("." if prefix else "") + key, module._non_persistent_buffers_set)
)
for name, sub_module in module._modules.items():
if sub_module is None:
continue
submodule_prefix = prefix + ("." if prefix else "") + name
child_non_persistent_set = self._get_non_persistent_buffers_set(
sub_module, memo, submodule_prefix, remove_duplicate
)
self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
return self_non_persistent_set
def _post_forward(self): def _post_forward(self):
"""This function is only triggered for inference.""" """This function is only triggered for inference."""
access_list = list(self.chunk_manager.accessed_chunks) access_list = list(self.chunk_manager.accessed_chunks)

View File

@ -1,10 +1,5 @@
import torch.nn import torch.nn
from colossalai.legacy.zero.gemini.ophooks.runtime_mem_tracer_hook import (
GradMemStats,
GradMemTracerHook,
ParamMemTracerHook,
)
from colossalai.tensor.param_op_hook import ColoParamOpHookManager from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import _cast_float from colossalai.utils import _cast_float
@ -27,6 +22,12 @@ class RuntimeMemTracer:
def __init__(self, module: torch.nn.Module, dtype: torch.dtype = torch.half): def __init__(self, module: torch.nn.Module, dtype: torch.dtype = torch.half):
super().__init__() super().__init__()
from colossalai.legacy.zero.gemini.ophooks.runtime_mem_tracer_hook import (
GradMemStats,
GradMemTracerHook,
ParamMemTracerHook,
)
self.module = module self.module = module
self.dtype = dtype self.dtype = dtype
self._gradstat = GradMemStats() self._gradstat = GradMemStats()

View File

@ -8,7 +8,6 @@ import torch
import torch.distributed as dist import torch.distributed as dist
from colossalai.accelerator import get_accelerator from colossalai.accelerator import get_accelerator
from colossalai.legacy.utils.memory import colo_device_memory_capacity
from colossalai.zero.gemini.chunk import Chunk from colossalai.zero.gemini.chunk import Chunk
from .chunk import Chunk, ChunkManager from .chunk import Chunk, ChunkManager
@ -172,6 +171,8 @@ class AutoPlacementPolicy(PlacementPolicy):
Returns: Returns:
int: the volume of memory that is evicted int: the volume of memory that is evicted
""" """
from colossalai.legacy.utils.memory import colo_device_memory_capacity
start = time() start = time()
cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device()) cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device())
used_cuda_model_data = self.chunk_manager.total_mem["cuda"] used_cuda_model_data = self.chunk_manager.total_mem["cuda"]

View File

@ -25,15 +25,13 @@
</div> </div>
## News ## News
* [2024/10] [How to build a low-cost Sora-like app? Solutions for you](https://company.hpc-ai.com/blog/how-to-build-a-low-cost-sora-like-app-solutions-for-you)
* [2024/09] [Singapore Startup HPC-AI Tech Secures 50 Million USD in Series A Funding to Build the Video Generation AI Model and GPU Platform](https://company.hpc-ai.com/blog/singapore-startup-hpc-ai-tech-secures-50-million-usd-in-series-a-funding-to-build-the-video-generation-ai-model-and-gpu-platform)
* [2024/09] [Reducing AI Large Model Training Costs by 30% Requires Just a Single Line of Code From FP8 Mixed Precision Training Upgrades](https://company.hpc-ai.com/blog/reducing-ai-large-model-training-costs-by-30-requires-just-a-single-line-of-code-from-fp8-mixed-precision-training-upgrades)
* [2024/06] [Open-Sora Continues Open Source: Generate Any 16-Second 720p HD Video with One Click, Model Weights Ready to Use](https://hpc-ai.com/blog/open-sora-from-hpc-ai-tech-team-continues-open-source-generate-any-16-second-720p-hd-video-with-one-click-model-weights-ready-to-use) * [2024/06] [Open-Sora Continues Open Source: Generate Any 16-Second 720p HD Video with One Click, Model Weights Ready to Use](https://hpc-ai.com/blog/open-sora-from-hpc-ai-tech-team-continues-open-source-generate-any-16-second-720p-hd-video-with-one-click-model-weights-ready-to-use)
* [2024/05] [Large AI Models Inference Speed Doubled, Colossal-Inference Open Source Release](https://hpc-ai.com/blog/colossal-inference) * [2024/05] [Large AI Models Inference Speed Doubled, Colossal-Inference Open Source Release](https://hpc-ai.com/blog/colossal-inference)
* [2024/04] [Open-Sora Unveils Major Upgrade: Embracing Open Source with Single-Shot 16-Second Video Generation and 720p Resolution](https://hpc-ai.com/blog/open-soras-comprehensive-upgrade-unveiled-embracing-16-second-video-generation-and-720p-resolution-in-open-source) * [2024/04] [Open-Sora Unveils Major Upgrade: Embracing Open Source with Single-Shot 16-Second Video Generation and 720p Resolution](https://hpc-ai.com/blog/open-soras-comprehensive-upgrade-unveiled-embracing-16-second-video-generation-and-720p-resolution-in-open-source)
* [2024/04] [Most cost-effective solutions for inference, fine-tuning and pretraining, tailored to LLaMA3 series](https://hpc-ai.com/blog/most-cost-effective-solutions-for-inference-fine-tuning-and-pretraining-tailored-to-llama3-series) * [2024/04] [Most cost-effective solutions for inference, fine-tuning and pretraining, tailored to LLaMA3 series](https://hpc-ai.com/blog/most-cost-effective-solutions-for-inference-fine-tuning-and-pretraining-tailored-to-llama3-series)
* [2024/03] [314 Billion Parameter Grok-1 Inference Accelerated by 3.8x, Efficient and Easy-to-Use PyTorch+HuggingFace version is Here](https://hpc-ai.com/blog/314-billion-parameter-grok-1-inference-accelerated-by-3.8x-efficient-and-easy-to-use-pytorchhuggingface-version-is-here)
* [2024/03] [Open-Sora: Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models](https://hpc-ai.com/blog/open-sora-v1.0)
* [2024/03] [Open-Sora: Sora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million](https://hpc-ai.com/blog/open-sora)
* [2024/01] [Inference Performance Improved by 46%, Open Source Solution Breaks the Length Limit of LLM for Multi-Round Conversations](https://hpc-ai.com/blog/Colossal-AI-SwiftInfer)
* [2023/07] [HPC-AI Tech Raises 22 Million USD in Series A Funding](https://www.hpc-ai.tech/blog/hpc-ai-tech-raises-22-million-usd-in-series-a-funding-to-fuel-team-expansion-and-business-growth)
## Table of Contents ## Table of Contents
<ul> <ul>

View File

@ -16,7 +16,7 @@ Author: [Mingyan Jiang](https://github.com/jiangmingyan)
AMP stands for automatic mixed precision training. AMP stands for automatic mixed precision training.
In Colossal-AI, we have incorporated different implementations of mixed precision training: In Colossal-AI, we have incorporated different implementations of mixed precision training:
1. torch.cuda.amp 1. torch.amp
2. apex.amp 2. apex.amp
3. naive amp 3. naive amp
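For context, a minimal plain-PyTorch loop using the torch.amp entry points referenced above (the tiny model, optimizer, and data are placeholders, not part of the Colossal-AI API):

```python
import torch

model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.amp.GradScaler("cuda")
data = torch.randn(4, 8, device="cuda")

with torch.amp.autocast("cuda", dtype=torch.float16):
    loss = model(data).float().mean()

scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
scaler.step(optimizer)         # unscales gradients, then applies the update
scaler.update()
```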

View File

@ -16,7 +16,7 @@
AMP stands for automatic mixed precision training. AMP stands for automatic mixed precision training.
In Colossal-AI, we have incorporated different implementations of mixed precision training: In Colossal-AI, we have incorporated different implementations of mixed precision training:
1. torch.cuda.amp 1. torch.amp
2. apex.amp 2. apex.amp
3. naive amp 3. naive amp

View File

@ -163,6 +163,8 @@ def main():
enable_async_reduce=not args.disable_async_reduce, enable_async_reduce=not args.disable_async_reduce,
use_fp8=args.use_fp8, use_fp8=args.use_fp8,
fp8_communication=args.use_fp8_comm, fp8_communication=args.use_fp8_comm,
) )
elif args.plugin == "gemini_auto": elif args.plugin == "gemini_auto":
plugin = GeminiPlugin( plugin = GeminiPlugin(
@ -177,6 +179,8 @@ def main():
enable_flash_attention=args.xformers, enable_flash_attention=args.xformers,
use_fp8=args.use_fp8, use_fp8=args.use_fp8,
fp8_communication=args.use_fp8_comm, fp8_communication=args.use_fp8_comm,
) )
elif args.plugin == "fsdp": elif args.plugin == "fsdp":
if use_empty_init: if use_empty_init:
@ -188,6 +192,7 @@ def main():
), ),
param_init_fn=empty_init(), param_init_fn=empty_init(),
fp8_communication=args.use_fp8_comm, fp8_communication=args.use_fp8_comm,
) )
else: else:
plugin = TorchFSDPPlugin( plugin = TorchFSDPPlugin(
@ -209,6 +214,7 @@ def main():
cpu_offload=CPUOffload(offload_params=True), cpu_offload=CPUOffload(offload_params=True),
param_init_fn=empty_init(), param_init_fn=empty_init(),
fp8_communication=args.use_fp8_comm, fp8_communication=args.use_fp8_comm,
) )
else: else:
plugin = TorchFSDPPlugin( plugin = TorchFSDPPlugin(
@ -219,6 +225,7 @@ def main():
), ),
cpu_offload=CPUOffload(offload_params=True), cpu_offload=CPUOffload(offload_params=True),
fp8_communication=args.use_fp8_comm, fp8_communication=args.use_fp8_comm,
) )
elif args.plugin == "3d": elif args.plugin == "3d":
if args.pp_style == "zbv": if args.pp_style == "zbv":

View File

@ -79,7 +79,7 @@ class _CppExtension(_Extension):
# check if the kernel has been built # check if the kernel has been built
compiled_before = False compiled_before = False
kernel_file_path = build_directory.joinpath(f"{self.name}.o") kernel_file_path = build_directory.joinpath(f"{self.name}.so")
if kernel_file_path.exists(): if kernel_file_path.exists():
compiled_before = True compiled_before = True

View File

@ -74,7 +74,7 @@ class _CudaExtension(_CppExtension):
# check if the kernel has been built # check if the kernel has been built
compiled_before = False compiled_before = False
kernel_file_path = build_directory.joinpath(f"{self.name}.o") kernel_file_path = build_directory.joinpath(f"{self.name}.so")
if kernel_file_path.exists(): if kernel_file_path.exists():
compiled_before = True compiled_before = True

View File

@ -41,22 +41,7 @@ class Conv1D(nn.Module):
return x return x
def rearrange(tensor: torch.Tensor, dim: int): def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str):
tensor = tensor.clone()
world_size = 2
order = torch.arange(world_size * 3)
new_order = []
for i in range(world_size):
new_order.append(order[i::world_size])
new_order = torch.cat(new_order)
tensor_chunks = torch.chunk(tensor, world_size * 3, dim=dim)
rearanged_tensor_chunks = [tensor_chunks[i] for i in new_order]
rearanged_tensor = torch.cat(rearanged_tensor_chunks, dim=dim)
return rearanged_tensor
def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: bool):
ctx = LazyInitContext() if lazy_init else nullcontext() ctx = LazyInitContext() if lazy_init else nullcontext()
linear = Conv1D(192, 48).cuda() linear = Conv1D(192, 48).cuda()
with ctx: with ctx:
@ -66,8 +51,7 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: b
process_group=None, process_group=None,
gather_output=True, gather_output=True,
seq_parallel_mode=seq_parallel_mode, seq_parallel_mode=seq_parallel_mode,
n_fused=3, split_sizes=[64] * 3,
overlap=overlap,
) )
assert linear.weight.shape == torch.Size([48, 192]) assert linear.weight.shape == torch.Size([48, 192])
@ -88,13 +72,13 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: b
x.expand_as(x.clone()) if seq_parallel_mode is None else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] x.expand_as(x.clone()) if seq_parallel_mode is None else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
) )
gather_out = linear_conv_col(x_for_shard) gather_out = linear_conv_col(x_for_shard)
assert_close(rearrange(out, -1), gather_out) assert_close(out, gather_out)
# check backward correctness # check backward correctness
out.sum().backward() out.sum().backward()
gather_out.sum().backward() gather_out.sum().backward()
target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, 3, None, True) target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, [64] * 3, None, True)
assert_close(target_grad, linear_conv_col.weight.grad) assert_close(target_grad, linear_conv_col.weight.grad)
@ -136,9 +120,8 @@ def check_linear_conv_1d_row(lazy_init: bool, seq_parallel_mode: bool):
@parameterize("lazy_init", [False, True]) @parameterize("lazy_init", [False, True])
@parameterize("seq_parallel_mode", ["split_gather", None]) @parameterize("seq_parallel_mode", ["split_gather", None])
@parameterize("overlap", [True]) def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool):
def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool, overlap: bool): check_linear_conv_1d_col(lazy_init, seq_parallel_mode)
check_linear_conv_1d_col(lazy_init, seq_parallel_mode, overlap)
check_linear_conv_1d_row(lazy_init, seq_parallel_mode) check_linear_conv_1d_row(lazy_init, seq_parallel_mode)

View File

@ -2,13 +2,12 @@ import os
from contextlib import nullcontext from contextlib import nullcontext
import torch import torch
import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from torch.testing import assert_close from torch.testing import assert_close
import colossalai import colossalai
from colossalai.lazy import LazyInitContext from colossalai.lazy import LazyInitContext
from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row from colossalai.shardformer.layer import FusedLinear1D_Col, FusedLinear1D_Row
from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
@ -16,93 +15,55 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
class Conv1D(nn.Module):
"""
1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
Basically works like a linear layer but the weights are transposed.
Args:
nf (`int`): The number of output features.
nx (`int`): The number of input features.
"""
def __init__(self, nf, nx):
super().__init__()
self.nf = nf
self.weight = nn.Parameter(torch.empty(nx, nf))
self.bias = nn.Parameter(torch.zeros(nf))
nn.init.normal_(self.weight, std=0.02)
def forward(self, x):
size_out = x.size()[:-1] + (self.nf,)
x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
x = x.view(size_out)
return x
def rearrange(tensor: torch.Tensor, dim: int):
tensor = tensor.clone()
world_size = 2
order = torch.arange(world_size * 3)
new_order = []
for i in range(world_size):
new_order.append(order[i::world_size])
new_order = torch.cat(new_order)
tensor_chunks = torch.chunk(tensor, world_size * 3, dim=dim)
rearanged_tensor_chunks = [tensor_chunks[i] for i in new_order]
rearanged_tensor = torch.cat(rearanged_tensor_chunks, dim=dim)
return rearanged_tensor
@parameterize("lazy_init", [False, True]) @parameterize("lazy_init", [False, True])
def check_linear_conv_1d_col(lazy_init: bool): def check_linear_1d_col(lazy_init: bool):
ctx = LazyInitContext() if lazy_init else nullcontext() ctx = LazyInitContext() if lazy_init else nullcontext()
linear = Conv1D(192, 48).cuda() linear = nn.Linear(8, 80).cuda()
with ctx: with ctx:
linear_copy = Conv1D(192, 48).cuda() linear_copy = nn.Linear(8, 80).cuda()
linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module( linear_col = FusedLinear1D_Col.from_native_module(
linear_copy, process_group=None, gather_output=True, n_fused=3 linear_copy, process_group=None, gather_output=True, split_sizes=[32, 32, 16]
) )
assert linear.weight.shape == torch.Size([48, 192]) assert linear.weight.shape == torch.Size([80, 8])
assert linear.bias.shape == torch.Size([192]) assert linear.bias.shape == torch.Size([80])
assert linear_conv_col.weight.shape == torch.Size([48, 96]) assert linear_col.weight.shape == torch.Size([40, 8])
assert linear_conv_col.bias.shape == torch.Size([96]) assert linear_col.bias.shape == torch.Size([40])
assert linear_copy.weight is linear_conv_col.weight assert linear_copy.weight is linear_col.weight
assert linear_copy.bias is linear_conv_col.bias assert linear_copy.bias is linear_col.bias
# ensure weights are reversibly loadable # ensure weights are reversibly loadable
linear_conv_col.load_state_dict(linear.state_dict()) linear_col.load_state_dict(linear.state_dict())
linear.load_state_dict(linear_conv_col.state_dict()) linear.load_state_dict(linear_col.state_dict())
# check computation correctness # check computation correctness
x = torch.rand(4, 48).cuda() x = torch.rand(4, 8).cuda()
out = linear(x) out = linear(x)
gather_out = linear_conv_col(x) gather_out = linear_col(x)
assert_close(rearrange(out, 1), gather_out) assert_close(out, gather_out)
# check backward correctness # check backward correctness
out.sum().backward() out.sum().backward()
gather_out.sum().backward() gather_out.sum().backward()
target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, 3, None, True) target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, [32, 32, 16], None, False)
assert_close(target_grad, linear_conv_col.weight.grad) assert_close(target_grad, linear_col.weight.grad)
@parameterize("lazy_init", [False, True]) @parameterize("lazy_init", [False, True])
def check_linear_conv_1d_row(lazy_init: bool): def check_linear_1d_row(lazy_init: bool):
ctx = LazyInitContext() if lazy_init else nullcontext() ctx = LazyInitContext() if lazy_init else nullcontext()
linear = Conv1D(192, 48).cuda() linear = nn.Linear(80, 8).cuda()
with ctx: with ctx:
linear_copy = Conv1D(192, 48).cuda() linear_copy = nn.Linear(80, 8).cuda()
linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False) linear_row = FusedLinear1D_Row.from_native_module(
linear_copy, process_group=None, split_sizes=[32, 32, 16], parallel_input=False
)
assert linear.weight.shape == torch.Size([48, 192]) assert linear.weight.shape == torch.Size([8, 80])
assert linear_row.weight.shape == torch.Size([24, 192]) assert linear_row.weight.shape == torch.Size([8, 40])
assert linear_row.bias.shape == torch.Size([192]) assert linear_row.bias.shape == torch.Size([8])
assert linear_copy.weight is linear_row.weight assert linear_copy.weight is linear_row.weight
assert linear_copy.bias is linear_row.bias assert linear_copy.bias is linear_row.bias
@ -111,7 +72,7 @@ def check_linear_conv_1d_row(lazy_init: bool):
linear.load_state_dict(linear_row.state_dict()) linear.load_state_dict(linear_row.state_dict())
# check computation correctness # check computation correctness
x = torch.rand(4, 48).cuda() x = torch.rand(4, 80).cuda()
out = linear(x) out = linear(x)
gather_out = linear_row(x) gather_out = linear_row(x)
assert_close(out, gather_out) assert_close(out, gather_out)
@ -120,17 +81,51 @@ def check_linear_conv_1d_row(lazy_init: bool):
out.sum().backward() out.sum().backward()
gather_out.sum().backward() gather_out.sum().backward()
rank = dist.get_rank() target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, [32, 32, 16], None, True)
target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank]
assert_close(target_grad, linear_row.weight.grad) assert_close(target_grad, linear_row.weight.grad)
@parameterize("lazy_init", [False, True])
def check_linear_1d_col_row(lazy_init: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear1 = nn.Linear(8, 80).cuda()
linear2 = nn.Linear(80, 8).cuda()
with ctx:
linear1_copy = nn.Linear(8, 80).cuda()
linear2_copy = nn.Linear(80, 8).cuda()
linear_col = FusedLinear1D_Col.from_native_module(linear1_copy, process_group=None, split_sizes=[32, 32, 16])
linear_row = FusedLinear1D_Row.from_native_module(
linear2_copy,
process_group=None,
split_sizes=[32, 32, 16],
)
# ensure weights are reversibly loadable
linear_col.load_state_dict(linear1.state_dict())
linear_row.load_state_dict(linear2.state_dict())
# check computation correctness
x = torch.rand(4, 8).cuda()
target_out = linear2(linear1(x))
out = linear_row(linear_col(x))
assert_close(out, target_out)
# check backward correctness
target_out.sum().backward()
out.sum().backward()
target_grad1 = split_fused_qkv_in_gpt2_style(linear1.weight.grad, [32, 32, 16], None, False)
assert_close(target_grad1, linear_col.weight.grad)
target_grad2 = split_fused_qkv_in_gpt2_style(linear2.weight.grad, [32, 32, 16], None, True)
assert_close(target_grad2, linear_row.weight.grad)
def run_dist(rank, world_size, port): def run_dist(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
# test for linear conv check_linear_1d_col()
check_linear_conv_1d_col() check_linear_1d_row()
check_linear_conv_1d_row() check_linear_1d_col_row()
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()

View File

@ -5,6 +5,7 @@ from flash_attn import flash_attn_qkvpacked_func, flash_attn_varlen_qkvpacked_fu
from torch.testing import assert_close from torch.testing import assert_close
import colossalai import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.shardformer.layer import AttnMaskType from colossalai.shardformer.layer import AttnMaskType
from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention
from colossalai.shardformer.layer.utils import split_batch_zigzag, split_varlen_zigzag from colossalai.shardformer.layer.utils import split_batch_zigzag, split_varlen_zigzag
@ -17,11 +18,14 @@ from colossalai.utils import get_current_device
@parameterize("nheads", [5]) @parameterize("nheads", [5])
@parameterize("d", [128]) @parameterize("d", [128])
@parameterize("dtype", [torch.bfloat16, torch.float16]) @parameterize("dtype", [torch.bfloat16, torch.float16])
def check_ring_attn(seq_len, bs, nheads, d, dtype): def check_ring_attn(seq_len, bs, nheads, d, dtype, inner_ring_size):
torch.cuda.manual_seed(2) torch.cuda.manual_seed(2)
device = get_current_device() device = get_current_device()
sp_group = dist.group.WORLD sp_group = dist.group.WORLD
dp_size, pp_size, tp_size = 1, 1, 1
sp_size = dist.get_world_size() sp_size = dist.get_world_size()
sp_axis = 2
pg_mesh = ProcessGroupMesh(dp_size, pp_size, sp_size, tp_size)
# Some outliers may seem large, but our errors are still lower # Some outliers may seem large, but our errors are still lower
# than Megatron-LM context parallel's # than Megatron-LM context parallel's
# (https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215) # (https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215)
@ -40,11 +44,11 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype):
q, q,
k, k,
v, v,
sp_group, sp_axis,
AttnMaskType.CAUSAL, AttnMaskType.CAUSAL,
return_softmax=True, return_softmax=True,
inner_ring_size=max(2, sp_size // 2), inner_ring_size=inner_ring_size,
# inner_ring_size=4 pg_mesh=pg_mesh,
) )
ring_out = ring_out.transpose(1, 2) ring_out = ring_out.transpose(1, 2)
out, lse, _ = flash_attn_qkvpacked_func( out, lse, _ = flash_attn_qkvpacked_func(
@ -83,6 +87,7 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype):
device = get_current_device() device = get_current_device()
sp_group = dist.group.WORLD sp_group = dist.group.WORLD
sp_size = dist.get_world_size() sp_size = dist.get_world_size()
sp_axis = 2
atol = rtol = 7e-3 atol = rtol = 7e-3
torch.cuda.manual_seed(2) torch.cuda.manual_seed(2)
# Prepare varlen attention mask # Prepare varlen attention mask
@ -123,10 +128,11 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype):
q_ring, q_ring,
k_ring, k_ring,
v_ring, v_ring,
sp_group, sp_axis,
**mask_info, **mask_info,
pad_output=False, pad_output=False,
return_softmax=True, return_softmax=True,
pg_mesh=ProcessGroupMesh(1, 1, sp_size, 1),
# deterministic=True # deterministic=True
) )
ring_out = ring_out.transpose(1, 2).reshape(-1, nheads, d) ring_out = ring_out.transpose(1, 2).reshape(-1, nheads, d)
@ -161,12 +167,12 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype):
def launch_single_ring(rank, world_size, port): def launch_single_ring(rank, world_size, port):
colossalai.launch(rank, world_size, "localhost", port) colossalai.launch(rank, world_size, "localhost", port)
check_packed_seq() check_packed_seq()
check_ring_attn() check_ring_attn(inner_ring_size=None)
def launch_double_ring(rank, world_size, port): def launch_double_ring(rank, world_size, port):
colossalai.launch(rank, world_size, "localhost", port) colossalai.launch(rank, world_size, "localhost", port)
check_ring_attn() check_ring_attn(inner_ring_size=2)
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()

View File

@ -1 +1 @@
0.4.4 0.4.5