Merge pull request #6107 from duanjunwen/dev/zero_bubble

[Zerobubble] Merge Main.
feature/zerobubble
duanjunwen 3 weeks ago committed by GitHub
commit 37b23e32b1

@ -15,21 +15,21 @@ repos:
args: ["--profile", "black"] # avoid conflict with black
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.8.0
rev: 24.10.0
hooks:
- id: black
name: black formatter
args: ['--line-length=120', '--target-version=py37', '--target-version=py38', '--target-version=py39','--target-version=py310']
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v18.1.8
rev: v19.1.2
hooks:
- id: clang-format
name: clang formatter
types_or: [c++, c]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
rev: v5.0.0
hooks:
- id: check-yaml
- id: check-merge-conflict

@ -25,16 +25,36 @@
</div>
## GPU Cloud HPC-AI.COM Coming
For a limited time, you can access an H100 Server for just $1! This is your chance to leverage premium GPU power at an unbeatable price.
Plus, when you refer a friend, you'll receive 20% cashback or compute credits equal to 100% of their top-up!
Our platform offers on-demand premium compute, ensuring safe, permanent data storage even after stopping your instance.
Don't miss this incredible opportunity to accelerate your AI projects!
Unlock premium GPUs and register now at [HPC-AI.COM](https://hpc-ai.com) to receive $10!
Special Bonuses:
* Top up $1,000 and receive 300 credits
* Top up $500 and receive 100 credits
<div align="center">
<a href="https://youtu.be/ilMQpU71ddI?si=J4JSPzZ03ycYmlki">
<img src="https://github.com/hpcaitech/public_assets/blob/main/colossalai/img/HPCAICOM241010.jpg" width="700" />
</a>
</div>
## Latest News
* [2024/10] [How to build a low-cost Sora-like app? Solutions for you](https://company.hpc-ai.com/blog/how-to-build-a-low-cost-sora-like-app-solutions-for-you)
* [2024/09] [Singapore Startup HPC-AI Tech Secures 50 Million USD in Series A Funding to Build the Video Generation AI Model and GPU Platform](https://company.hpc-ai.com/blog/singapore-startup-hpc-ai-tech-secures-50-million-usd-in-series-a-funding-to-build-the-video-generation-ai-model-and-gpu-platform)
* [2024/09] [Reducing AI Large Model Training Costs by 30% Requires Just a Single Line of Code From FP8 Mixed Precision Training Upgrades](https://company.hpc-ai.com/blog/reducing-ai-large-model-training-costs-by-30-requires-just-a-single-line-of-code-from-fp8-mixed-precision-training-upgrades)
* [2024/06] [Open-Sora Continues Open Source: Generate Any 16-Second 720p HD Video with One Click, Model Weights Ready to Use](https://hpc-ai.com/blog/open-sora-from-hpc-ai-tech-team-continues-open-source-generate-any-16-second-720p-hd-video-with-one-click-model-weights-ready-to-use)
* [2024/05] [Large AI Models Inference Speed Doubled, Colossal-Inference Open Source Release](https://hpc-ai.com/blog/colossal-inference)
* [2024/04] [Open-Sora Unveils Major Upgrade: Embracing Open Source with Single-Shot 16-Second Video Generation and 720p Resolution](https://hpc-ai.com/blog/open-soras-comprehensive-upgrade-unveiled-embracing-16-second-video-generation-and-720p-resolution-in-open-source)
* [2024/04] [Most cost-effective solutions for inference, fine-tuning and pretraining, tailored to LLaMA3 series](https://hpc-ai.com/blog/most-cost-effective-solutions-for-inference-fine-tuning-and-pretraining-tailored-to-llama3-series)
* [2024/03] [314 Billion Parameter Grok-1 Inference Accelerated by 3.8x, Efficient and Easy-to-Use PyTorch+HuggingFace version is Here](https://hpc-ai.com/blog/314-billion-parameter-grok-1-inference-accelerated-by-3.8x-efficient-and-easy-to-use-pytorchhuggingface-version-is-here)
* [2024/03] [Open-Sora: Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models](https://hpc-ai.com/blog/open-sora-v1.0)
* [2024/03] [Open-Sora: Sora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million](https://hpc-ai.com/blog/open-sora)
* [2024/01] [Inference Performance Improved by 46%, Open Source Solution Breaks the Length Limit of LLM for Multi-Round Conversations](https://hpc-ai.com/blog/Colossal-AI-SwiftInfer)
* [2023/07] [HPC-AI Tech Raises 22 Million USD in Series A Funding](https://www.hpc-ai.tech/blog/hpc-ai-tech-raises-22-million-usd-in-series-a-funding-to-fuel-team-expansion-and-business-growth)
## Table of Contents
<ul>

@ -27,11 +27,11 @@
- [Alternative Option For RLHF: SimPO](#alternative-option-for-rlhf-simple-preference-optimization-simpo)
- [Alternative Option For RLHF: ORPO](#alternative-option-for-rlhf-odds-ratio-preference-optimization-orpo)
- [Alternative Option For RLHF: KTO](#alternative-option-for-rlhf-kahneman-tversky-optimization-kto)
- [O1 Journey](#o1-journey)
- [Inference with Self-refined MCTS](#inference-with-self-refined-mcts)
- [FAQ](#faq)
- [How to save/load checkpoint](#faq)
- [How to train with limited resources](#faq)
- [The Plan](#the-plan)
- [Real-time progress](#real-time-progress)
- [Invitation to open-source contribution](#invitation-to-open-source-contribution)
- [Quick Preview](#quick-preview)
- [Authors](#authors)
@ -272,7 +272,7 @@ Odds Ratio Preference Optimization (ORPO) from this [paper](https://arxiv.org/pd
## Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO)
We support the method introduced in the paper [KTO: Model Alignment as Prospect Theoretic Optimization](https://arxiv.org/pdf/2402.01306) (KTO), an alignment method that directly maximizes the "human utility" of generation results. Read this [README](./examples/README.md) for more information.
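As a rough, illustrative sketch only (not the implementation in this repository), a KTO-style objective scores each response individually against a reference model instead of in chosen/rejected pairs. The function name, default coefficients, and the simplified KL reference point below are assumptions made for the sketch:
```python
import torch

def kto_loss_sketch(policy_logps, ref_logps, is_desirable,
                    beta=0.1, lambda_d=1.0, lambda_u=1.0):
    """Illustrative KTO-style loss: each response is scored on its own,
    relative to a reference model, rather than in chosen/rejected pairs."""
    # Log-ratio between policy and reference model for each response.
    r = policy_logps - ref_logps
    # Crude stand-in for the KL reference point used in the paper.
    z0 = r.mean().detach()
    value = torch.where(
        is_desirable,
        lambda_d * torch.sigmoid(beta * (r - z0)),
        lambda_u * torch.sigmoid(beta * (z0 - r)),
    )
    weight = torch.where(
        is_desirable,
        torch.full_like(r, lambda_d),
        torch.full_like(r, lambda_u),
    )
    return (weight - value).mean()
```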
### Inference Quantization and Serving - After Training
## Inference Quantization and Serving - After Training
We provide an online inference server and a benchmark. We aim to run inference on a single GPU, so quantization is essential when using large models.
@ -281,6 +281,21 @@ We support 8-bit quantization (RTN), 4-bit quantization (GPTQ), and FP16 inferen
Online inference server scripts can help you deploy your own services.
For more details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/inference).
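As an orientation-only sketch (not the repository's inference scripts), 8-bit weight-quantized loading of a causal LM on a single GPU can look like the following with Hugging Face Transformers and bitsandbytes; the checkpoint path is a placeholder:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_path = "path/to/your-coati-checkpoint"  # placeholder, not a real checkpoint name

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),  # 8-bit weight quantization
    device_map="auto",  # keep the quantized model on a single GPU
)

inputs = tokenizer("Hello, how are you?", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=32)[0]))
```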
## O1 Journey
### Inference with Self-refined MCTS
We provide the implementation of MCT Self-Refine (MCTSr) algorithm, an innovative integration of Large Language Models with Monte Carlo Tree Search.
To run inference with MCTS, simply use the following script.
```python
from coati.reasoner.guided_search.mcts import MCTS
from coati.reasoner.guided_search.prompt_store.qwen import Qwen32B_prompt_CFG
problem = "How Many R in 'Strawberry'"
search_tree = MCTS(problem=problem, max_simulations=8, cfg=Qwen32B_prompt_CFG)
answer = search_tree.simulate()
print(answer)
```
## Coati7B examples
### Generation

@ -153,10 +153,11 @@ class DpoLoss(nn.Module):
else:
# If no reference model is provided
ref_logratios = 0.0
pi_logratios = logprob_actor_chosen.sum(-1) - logprob_actor_reject.sum(-1)
logits = pi_logratios - ref_logratios - self.gamma / self.beta
losses = -torch.nn.functional.logsigmoid(self.beta * logits)
loss = losses.mean()
# Calculate rewards for logging
if logprob_ref_chosen is not None:
chosen_rewards = self.beta * (logprob_actor_chosen.sum(-1) - logprob_ref_chosen.sum(-1)).detach()
@ -167,7 +168,7 @@ class DpoLoss(nn.Module):
else:
rejected_rewards = self.beta * logprob_actor_reject.sum(-1).detach()
return losses, chosen_rewards, rejected_rewards
return loss, chosen_rewards, rejected_rewards
class LogSigLoss(nn.Module):
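For readers skimming this hunk, the loss it computes can be restated as a small standalone sketch (function and argument names are illustrative, and the inputs stand for the masked per-token log-probabilities the trainer passes in):
```python
import torch
import torch.nn.functional as F

def dpo_loss_sketch(logp_chosen, logp_reject, ref_logp_chosen, ref_logp_reject,
                    beta=0.1, gamma=0.0):
    """Standalone restatement of the loss computed in the hunk above."""
    pi_logratios = logp_chosen.sum(-1) - logp_reject.sum(-1)
    ref_logratios = ref_logp_chosen.sum(-1) - ref_logp_reject.sum(-1)
    logits = pi_logratios - ref_logratios - gamma / beta
    losses = -F.logsigmoid(beta * logits)
    return losses.mean()
```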

@ -50,8 +50,8 @@ def _log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.
torch.Tensor: The log probabilities corresponding to the labels.
"""
log_probs = F.log_softmax(logits, dim=-1)
log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
return log_probs_labels.squeeze(-1)
per_label_logps = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
return per_label_logps.squeeze(-1)
def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
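A tiny self-contained check of the gather pattern in this hunk (shapes are illustrative):
```python
import torch
import torch.nn.functional as F

logits = torch.randn(2, 4, 10)           # (batch, seq_len, vocab)
labels = torch.randint(0, 10, (2, 4))    # (batch, seq_len)
log_probs = F.log_softmax(logits, dim=-1)
# Pick out the log-probability of each label token, as in _log_probs_from_logits above.
per_label_logps = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
assert per_label_logps.shape == labels.shape
```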

@ -0,0 +1,26 @@
import openai
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
API_KEY = "Dummy API Key"
def get_client(base_url: str | None = None) -> openai.Client:
return openai.Client(api_key=API_KEY, base_url=base_url)
def chat_completion(
messages: list[ChatCompletionMessageParam],
model: str,
base_url: str | None = None,
temperature: float = 0.8,
**kwargs,
) -> ChatCompletion:
client = get_client(base_url)
response = client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
**kwargs,
)
return response
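A hedged usage sketch of this helper, assuming an OpenAI-compatible server is already running at the given base URL (the URL and model name are taken from the Qwen prompt config added later in this PR, and the question is a placeholder):
```python
# Illustrative only: requires an OpenAI-compatible endpoint at base_url.
from coati.reasoner.guided_search.llm import chat_completion

response = chat_completion(
    messages=[{"role": "user", "content": "What is 2 + 2?"}],
    model="Qwen2.5-32B-Instruct",
    base_url="http://0.0.0.0:8008/v1",
    max_tokens=256,
)
print(response.choices[0].message.content)
```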

@ -0,0 +1,250 @@
"""
Implementation of MCTS + Self-refine algorithm.
Reference:
1. "Accessing GPT-4 level Mathematical Olympiad Solutions via Monte
Carlo Tree Self-refine with LLaMa-3 8B: A Technical Report"
2. https://github.com/BrendanGraham14/mcts-llm/
3. https://github.com/trotsky1997/MathBlackBox/
4. https://github.com/openreasoner/openr/blob/main/reason/guided_search/tree.py
"""
from __future__ import annotations
import math
from collections import deque
import numpy as np
import tqdm
from coati.reasoner.guided_search.llm import chat_completion
from coati.reasoner.guided_search.prompt_store.base import PromptCFG
from pydantic import BaseModel
class MCTSNode(BaseModel):
"""
Node for MCTS.
"""
answer: str
parent: MCTSNode = None
children: list[MCTSNode] = []
num_visits: int = 0
Q: int = 0
rewards: list[int] = []
def expand_node(self, node) -> None:
self.children.append(node)
def add_reward(self, reward: int) -> None:
self.rewards.append(reward)
self.Q = (np.min(self.rewards) + np.mean(self.rewards)) / 2
class MCTS(BaseModel):
"""
Simulation of MCTS process.
"""
problem: str
max_simulations: int
cfg: PromptCFG
C: float = 1.4
max_children: int = 2
epsilon: float = 1e-5
root: MCTSNode = None
def initialization(self):
"""
Root initialization.
"""
# Dummy answer as root.
base_answer = self.sample_base_answer()
self.root = MCTSNode(answer=base_answer)
self.self_evaluate(self.root)
def is_fully_expanded(self, node: MCTSNode):
return len(node.children) >= self.max_children or any(child.Q > node.Q for child in node.children)
def select_node(self) -> MCTSNode:
"""
Select next node to explore.
"""
candidates: list[MCTSNode] = []
to_explore = deque([self.root])
while to_explore:
current_node = to_explore.popleft()
if not self.is_fully_expanded(current_node):
candidates.append(current_node)
to_explore.extend(current_node.children)
if not candidates:
return self.root
return max(candidates, key=self.compute_uct)
def self_evaluate(self, node: MCTSNode):
"""
Sample reward of the answer.
"""
reward = self.sample_reward(node)
node.add_reward(reward)
def back_propagation(self, node: MCTSNode):
"""
Back propagate the value of the refined answer.
"""
parent = node.parent
while parent:
best_child_Q = max(child.Q for child in parent.children)
parent.Q = (parent.Q + best_child_Q) / 2
parent.num_visits += 1
parent = parent.parent
def compute_uct(self, node: MCTSNode):
"""
Compute UCT.
"""
if node.parent is None:
return -100
return node.Q + self.C * math.sqrt(math.log(node.parent.num_visits + 1) / (node.num_visits + self.epsilon))
def simulate(self):
self.initialization()
for _ in tqdm.tqdm(range(self.max_simulations)):
node = self.select_node()
child = self.self_refine(node)
node.expand_node(child)
self.self_evaluate(child)
self.back_propagation(child)
return self.get_best_answer()
def get_best_answer(self):
to_visit = deque([self.root])
best_node = self.root
while to_visit:
current_node = to_visit.popleft()
if current_node.Q > best_node.Q:
best_node = current_node
to_visit.extend(current_node.children)
return best_node.answer
def self_refine(self, node: MCTSNode):
"""
Refine node.
"""
critique_response = chat_completion(
messages=[
{
"role": "system",
"content": self.cfg.critic_system_prompt,
},
{
"role": "user",
"content": "\n\n".join(
[
f"<problem>\n{self.problem}\n</problem>",
f"<current_answer>\n{node.answer}\n</current_answer>",
]
),
},
],
model=self.cfg.model,
base_url=self.cfg.base_url,
max_tokens=self.cfg.max_tokens,
)
critique = critique_response.choices[0].message.content
assert critique is not None
refined_answer_response = chat_completion(
messages=[
{
"role": "system",
"content": self.cfg.refine_system_prompt,
},
{
"role": "user",
"content": "\n\n".join(
[
f"<problem>\n{self.problem}\n</problem>",
f"<current_answer>\n{node.answer}\n</current_answer>",
f"<critique>\n{critique}\n</critique>",
]
),
},
],
model=self.cfg.model,
base_url=self.cfg.base_url,
max_tokens=self.cfg.max_tokens,
)
refined_answer = refined_answer_response.choices[0].message.content
assert refined_answer is not None
return MCTSNode(answer=refined_answer, parent=node)
def sample_base_answer(self):
response = chat_completion(
messages=[
{
"role": "system",
"content": "The user will provide a problem. Solve the problem. The response should begin with [reasoning process]...[Verification]... and end with [Final Answer]. \nThe answer is [answer] \n#### [answer].",
},
{
"role": "user",
"content": f"<problem>\n {self.problem} \n</problem> \nLet's think step by step",
},
],
model=self.cfg.model,
base_url=self.cfg.base_url,
max_tokens=self.cfg.max_tokens,
)
assert response.choices[0].message.content is not None
return response.choices[0].message.content
def sample_reward(self, node: MCTSNode):
"""
Calculate reward.
"""
messages = [
{
"role": "system",
"content": self.cfg.evaluate_system_prompt,
},
{
"role": "user",
"content": "\n\n".join(
[
f"<problem>\n{self.problem}\n</problem>",
f"<answer>\n{node.answer}\n</answer>",
]
),
},
]
for attempt in range(3):
try:
response = chat_completion(
messages=messages,
model=self.cfg.model,
base_url=self.cfg.base_url,
max_tokens=self.cfg.max_tokens,
)
assert response.choices[0].message.content is not None
return int(response.choices[0].message.content)
except ValueError:
messages.extend(
[
{
"role": "assistant",
"content": response.choices[0].message.content,
},
{
"role": "user",
"content": "Failed to parse reward as an integer.",
},
]
)
if attempt == 2:
raise
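For reference, the selection score in `compute_uct` above is the standard UCT formula, with a sentinel of -100 returned for the root (which has no parent):

$$\mathrm{UCT}(v) = Q(v) + C\,\sqrt{\frac{\ln\bigl(N(\mathrm{parent}(v)) + 1\bigr)}{n(v) + \epsilon}}$$

where $Q(v)$ is the node's value, $N(\cdot)$ and $n(v)$ are visit counts, and $C = 1.4$, $\epsilon = 10^{-5}$ by default.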

@ -0,0 +1,10 @@
from pydantic import BaseModel
class PromptCFG(BaseModel):
model: str
base_url: str
max_tokens: int = 4096
critic_system_prompt: str
refine_system_prompt: str
evaluate_system_prompt: str

@ -0,0 +1,20 @@
"""
Prompts for Qwen Series.
"""
from coati.reasoner.guided_search.prompt_store.base import PromptCFG
Qwen32B_prompt_CFG = PromptCFG(
base_url="http://0.0.0.0:8008/v1",
model="Qwen2.5-32B-Instruct",
critic_system_prompt="Provide a detailed and constructive critique to improve the answer. "
"Highlight specific areas that need refinement or correction.",
refine_system_prompt="""# Instruction
Refine the answer based on the critique. The response should begin with [reasoning process]...[Verification]... and end with [Final Answer].
""",
evaluate_system_prompt=(
"Analyze this answer strictly and critic, provide a reward score between -100 and 100 for the answer quality, using very strict standards. "
"Do not give a full score above 95. Make sure the reward score is an integer. "
"Return *ONLY* the score."
),
)

@ -6,6 +6,7 @@ import os
from typing import Any, Optional
import torch
import torch.distributed as dist
from coati.models.loss import DpoLoss
from coati.models.utils import calc_masked_log_probs
from coati.trainer.utils import all_reduce_mean
@ -13,10 +14,11 @@ from coati.utils import AccumulativeMeanMeter, save_checkpoint
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from tqdm import trange
from tqdm import tqdm, trange
from transformers import PreTrainedTokenizerBase
from colossalai.booster import Booster, Plugin
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.utils import get_current_device
@ -96,18 +98,25 @@ class DPOTrainer(SLTrainer):
self.train_dataloader = train_preference_dataloader
self.eval_dataloader = eval_preference_dataloader
self.writer = None
if use_wandb and is_rank_0():
init_criterion = (
dist.get_rank() == dist.get_world_size() - 1
if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1
else is_rank_0()
)
if use_wandb and init_criterion:
assert log_dir is not None, "log_dir must be provided when use_wandb is True"
import wandb
self.wandb_run = wandb.init(project="Coati-dpo", sync_tensorboard=True)
if log_dir is not None and is_rank_0():
if log_dir is not None and init_criterion:
import os
import time
from torch.utils.tensorboard import SummaryWriter
log_dir = os.path.join(log_dir, "dpo")
log_dir = os.path.join(log_dir, "DPO")
log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
self.writer = SummaryWriter(log_dir=log_dir)
@ -117,166 +126,147 @@ class DPOTrainer(SLTrainer):
epoch (int): the current epoch number
"""
self.model.train()
self.accumulative_meter.reset()
step_bar = trange(
len(self.train_dataloader) // self.accumulation_steps,
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
disable=not is_rank_0(),
)
for i, batch in enumerate(self.train_dataloader):
batch = to_device(batch, self.device)
(
chosen_input_ids,
chosen_attention_mask,
chosen_loss_mask,
reject_input_ids,
reject_attention_mask,
reject_loss_mask,
) = (
batch["chosen_input_ids"],
batch["chosen_attention_mask"],
batch["chosen_loss_mask"],
batch["reject_input_ids"],
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)
if not self.apply_loss_mask:
chosen_loss_mask = chosen_loss_mask.fill_(1.0)
reject_loss_mask = reject_loss_mask.fill_(1.0)
batch_size = chosen_input_ids.size()[0]
actor_all_logits = self.model(
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
)["logits"]
actor_chosen_logits = actor_all_logits[:batch_size]
actor_reject_logits = actor_all_logits[batch_size:]
logprob_actor_chosen = calc_masked_log_probs(
actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
)
logprob_actor_reject = calc_masked_log_probs(
actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1:
step_bar = tqdm(
range(len(self.train_dataloader)),
desc="Step",
disable=not (dist.get_rank() == dist.get_world_size() - 1),
)
if self.ref_model is not None:
self.ref_model.eval()
with torch.no_grad():
ref_all_logits = self.ref_model(
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
)["logits"]
ref_chosen_logits = ref_all_logits[:batch_size]
ref_reject_logits = ref_all_logits[batch_size:]
logprob_ref_chosen = calc_masked_log_probs(
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
)
logprob_ref_reject = calc_masked_log_probs(
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
for i, batch in enumerate(self.train_dataloader):
batch = to_device(batch, self.device)
(
chosen_input_ids,
chosen_attention_mask,
chosen_loss_mask,
reject_input_ids,
reject_attention_mask,
reject_loss_mask,
) = (
batch["chosen_input_ids"],
batch["chosen_attention_mask"],
batch["chosen_loss_mask"],
batch["reject_input_ids"],
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)
batch_size = chosen_input_ids.size()[0]
# Calculate logits from reference model.
if self.ref_model is not None:
self.ref_model.eval()
with torch.no_grad():
ref_all_logits = self.ref_model(
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
)["logits"]
ref_chosen_logits = ref_all_logits[:batch_size]
ref_reject_logits = ref_all_logits[batch_size:]
logprob_ref_chosen = calc_masked_log_probs(
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
)
logprob_ref_reject = calc_masked_log_probs(
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
)
else:
logprob_ref_chosen = None
logprob_ref_reject = None
# Merge chosen and reject
inputs_ids = torch.stack([item for tup in zip(chosen_input_ids, reject_input_ids) for item in tup])
attention_mask = torch.stack(
[item for tup in zip(chosen_attention_mask, reject_attention_mask) for item in tup]
)
loss_mask = torch.stack([item for tup in zip(chosen_loss_mask, reject_loss_mask) for item in tup])
logprob_ref = torch.stack([item for tup in zip(logprob_ref_chosen, logprob_ref_reject) for item in tup])
data_iter = iter(
[
{
"input_ids": inputs_ids,
"attention_mask": attention_mask,
"loss_mask": loss_mask,
"logprob_ref": logprob_ref,
}
]
)
rewards = []
def _criterion(outputs, inputs):
loss, chosen_rewards, rejected_rewards = self.actor_loss_fn(
calc_masked_log_probs(
outputs["logits"][0::2],
inputs["input_ids"][0::2],
inputs["loss_mask"][0::2][:, 1:],
self.length_normalization,
),
calc_masked_log_probs(
outputs["logits"][1::2],
inputs["input_ids"][1::2],
inputs["loss_mask"][1::2][:, 1:],
self.length_normalization,
),
inputs["logprob_ref"][0::2] if inputs["logprob_ref"] is not None else None,
inputs["logprob_ref"][1::2] if inputs["logprob_ref"] is not None else None,
inputs["loss_mask"][0::2][:, 1:],
inputs["loss_mask"][1::2][:, 1:],
)
else:
logprob_ref_chosen = None
logprob_ref_reject = None
losses, chosen_rewards, rejected_rewards = self.actor_loss_fn(
logprob_actor_chosen,
logprob_actor_reject,
logprob_ref_chosen if logprob_ref_chosen is not None else None,
logprob_ref_reject if logprob_ref_reject is not None else None,
chosen_loss_mask[:, 1:],
reject_loss_mask[:, 1:],
)
reward_accuracies = (chosen_rewards > rejected_rewards).float().mean()
# DPO Loss
loss = losses.mean()
rewards.append(chosen_rewards)
rewards.append(rejected_rewards)
return loss
outputs = self.booster.execute_pipeline(
data_iter,
self.model,
criterion=_criterion,
optimizer=self.optimizer,
return_loss=True,
)
loss = outputs["loss"]
if self.booster.plugin.stage_manager.is_last_stage():
chosen_rewards, rejected_rewards = rewards[0], rewards[1]
global_loss = all_reduce_mean(loss, self.plugin)
if dist.get_rank() == dist.get_world_size() - 1:
step_bar.set_postfix(
{
"train/loss": global_loss.item(),
"train/lr": self.actor_scheduler.get_last_lr()[0],
"train/chosen_rewards": chosen_rewards.to(torch.float16).mean().item(),
"train/rejected_rewards": rejected_rewards.to(torch.float16).mean().item(),
}
)
step_bar.update()
self.accumulative_meter.add("loss", global_loss.item())
self.accumulative_meter.add("chosen_rewards", chosen_rewards.to(torch.float16).mean().item())
self.accumulative_meter.add(
"rejected_rewards", rejected_rewards.to(torch.float16).mean().item()
)
if self.writer is not None:
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), i)
self.writer.add_scalar(
"train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), i
)
self.writer.add_scalar(
"train/rejected_rewards",
self.accumulative_meter.get("rejected_rewards"),
i,
)
self.writer.add_scalar(
"train/margin",
self.accumulative_meter.get("chosen_rewards")
- self.accumulative_meter.get("rejected_rewards"),
i,
)
self.booster.backward(loss=loss, optimizer=self.optimizer)
if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1:
self.optimizer.step()
self.optimizer.zero_grad()
self.actor_scheduler.step()
# sync
loss_mean = all_reduce_mean(tensor=loss)
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards)
reward_accuracies_mean = all_reduce_mean(tensor=reward_accuracies)
self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
if i % self.accumulation_steps == self.accumulation_steps - 1:
self.num_train_step += 1
step_bar.update()
# logging
if self.writer and is_rank_0():
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
self.writer.add_scalar(
"train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
)
self.writer.add_scalar(
"train/rejected_rewards",
self.accumulative_meter.get("rejected_rewards"),
self.num_train_step,
)
self.writer.add_scalar(
"train/margin",
self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"),
self.num_train_step,
)
self.writer.add_scalar(
"train/accuracy",
self.accumulative_meter.get("accuracy"),
self.num_train_step,
)
self.accumulative_meter.reset()
if self.save_dir is not None and (self.num_train_step + 1) % self.save_interval == 0:
# save checkpoint
self.coordinator.print_on_master("\nStart saving model checkpoint with running states")
save_checkpoint(
save_dir=self.save_dir,
booster=self.booster,
model=self.model,
optimizer=self.optimizer,
lr_scheduler=self.actor_scheduler,
epoch=epoch,
step=i + 1,
batch_size=batch_size,
coordinator=self.coordinator,
)
self.coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}"
)
step_bar.close()
def _eval(self, epoch: int):
"""
Args:
epoch (int): the current epoch number
"""
if self.eval_dataloader is None:
self.coordinator.print_on_master("No eval dataloader is provided, skip evaluation")
return
self.model.eval()
self.ref_model.eval()
self.coordinator.print_on_master("\nStart evaluation...")
step_bar = trange(
len(self.eval_dataloader),
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
disable=not is_rank_0(),
)
self.accumulative_meter.reset()
with torch.no_grad():
for i, batch in enumerate(self.eval_dataloader):
else:
self.accumulative_meter.reset()
step_bar = trange(
len(self.train_dataloader) // self.accumulation_steps,
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
disable=not is_rank_0(),
)
for i, batch in enumerate(self.train_dataloader):
batch = to_device(batch, self.device)
(
chosen_input_ids,
@ -300,12 +290,11 @@ class DPOTrainer(SLTrainer):
batch_size = chosen_input_ids.size()[0]
actor_all_logits = self.model(
torch.cat([chosen_input_ids, reject_input_ids]),
torch.cat([chosen_attention_mask, reject_attention_mask]),
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
)["logits"]
actor_chosen_logits = actor_all_logits[:batch_size]
actor_reject_logits = actor_all_logits[batch_size:]
logprob_actor_chosen = calc_masked_log_probs(
actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
)
@ -314,22 +303,26 @@ class DPOTrainer(SLTrainer):
actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
)
self.ref_model.eval()
ref_all_logits = self.ref_model(
torch.cat([chosen_input_ids, reject_input_ids]),
torch.cat([chosen_attention_mask, reject_attention_mask]),
)["logits"]
ref_chosen_logits = ref_all_logits[:batch_size]
ref_reject_logits = ref_all_logits[batch_size:]
logprob_ref_chosen = calc_masked_log_probs(
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
)
logprob_ref_reject = calc_masked_log_probs(
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
)
losses, chosen_rewards, rejected_rewards = self.actor_loss_fn(
if self.ref_model is not None:
self.ref_model.eval()
with torch.no_grad():
ref_all_logits = self.ref_model(
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
)["logits"]
ref_chosen_logits = ref_all_logits[:batch_size]
ref_reject_logits = ref_all_logits[batch_size:]
logprob_ref_chosen = calc_masked_log_probs(
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
)
logprob_ref_reject = calc_masked_log_probs(
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
)
else:
logprob_ref_chosen = None
logprob_ref_reject = None
loss, chosen_rewards, rejected_rewards = self.actor_loss_fn(
logprob_actor_chosen,
logprob_actor_reject,
logprob_ref_chosen if logprob_ref_chosen is not None else None,
@ -338,7 +331,9 @@ class DPOTrainer(SLTrainer):
reject_loss_mask[:, 1:],
)
reward_accuracies = (chosen_rewards > rejected_rewards).float().mean()
loss = losses.mean()
self.booster.backward(loss=loss, optimizer=self.optimizer)
# sync
loss_mean = all_reduce_mean(tensor=loss)
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards)
@ -347,16 +342,301 @@ class DPOTrainer(SLTrainer):
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
self.accumulative_meter.add(
"margin", (chosen_rewards_mean - rejected_rewards_mean).to(torch.float16).mean().item()
if (i + 1) % self.accumulation_steps == 0:
self.optimizer.step()
self.optimizer.zero_grad()
self.actor_scheduler.step()
step_bar.set_postfix(
{
"train/loss": self.accumulative_meter.get("loss"),
"train/chosen_rewards": self.accumulative_meter.get("chosen_rewards"),
"train/rejected_rewards": self.accumulative_meter.get("rejected_rewards"),
"train/accuracy": self.accumulative_meter.get("accuracy"),
}
)
step_bar.update()
if self.writer and is_rank_0():
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
self.writer.add_scalar(
"train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
)
self.writer.add_scalar(
"train/rejected_rewards",
self.accumulative_meter.get("rejected_rewards"),
self.num_train_step,
)
self.writer.add_scalar(
"train/margin",
self.accumulative_meter.get("chosen_rewards")
- self.accumulative_meter.get("rejected_rewards"),
self.num_train_step,
)
self.writer.add_scalar(
"train/accuracy",
self.accumulative_meter.get("accuracy"),
self.num_train_step,
)
self.num_train_step += 1
self.accumulative_meter.reset()
if self.save_dir is not None and self.num_train_step > 0 and self.num_train_step % self.save_interval == 0:
# save checkpoint
self.coordinator.print_on_master("\nStart saving model checkpoint with running states")
save_checkpoint(
save_dir=self.save_dir,
booster=self.booster,
model=self.model,
optimizer=self.optimizer,
lr_scheduler=self.actor_scheduler,
epoch=epoch,
step=self.num_train_step,
batch_size=batch_size,
coordinator=self.coordinator,
)
step_bar.update()
msg = "Evaluation Result:\n"
for tag in ["loss", "chosen_rewards", "rejected_rewards", "accuracy", "margin"]:
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
self.coordinator.print_on_master(msg)
os.makedirs(self.save_dir, exist_ok=True)
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
f.write(msg)
self.coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}"
)
step_bar.close()
def _eval(self, epoch: int):
"""
Args:
epoch (int): the current epoch number
"""
if self.eval_dataloader is None:
self.coordinator.print_on_master("No eval dataloader is provided, skip evaluation")
return
self.model.eval()
self.ref_model.eval()
self.accumulative_meter.reset()
self.coordinator.print_on_master("\nStart evaluation...")
if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1:
step_bar = tqdm(
range(len(self.eval_dataloader)),
desc="Step",
disable=not (dist.get_rank() == dist.get_world_size() - 1),
)
with torch.no_grad():
for _, batch in enumerate(self.eval_dataloader):
batch = to_device(batch, self.device)
(
chosen_input_ids,
chosen_attention_mask,
chosen_loss_mask,
reject_input_ids,
reject_attention_mask,
reject_loss_mask,
) = (
batch["chosen_input_ids"],
batch["chosen_attention_mask"],
batch["chosen_loss_mask"],
batch["reject_input_ids"],
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)
batch_size = chosen_input_ids.size()[0]
# Calculate logits from reference model.
if self.ref_model is not None:
self.ref_model.eval()
with torch.no_grad():
ref_all_logits = self.ref_model(
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
)["logits"]
ref_chosen_logits = ref_all_logits[:batch_size]
ref_reject_logits = ref_all_logits[batch_size:]
logprob_ref_chosen = calc_masked_log_probs(
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
)
logprob_ref_reject = calc_masked_log_probs(
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
)
else:
logprob_ref_chosen = None
logprob_ref_reject = None
# Merge chosen and reject
inputs_ids = torch.stack([item for tup in zip(chosen_input_ids, reject_input_ids) for item in tup])
attention_mask = torch.stack(
[item for tup in zip(chosen_attention_mask, reject_attention_mask) for item in tup]
)
loss_mask = torch.stack([item for tup in zip(chosen_loss_mask, reject_loss_mask) for item in tup])
logprob_ref = torch.stack(
[item for tup in zip(logprob_ref_chosen, logprob_ref_reject) for item in tup]
)
data_iter = iter(
[
{
"input_ids": inputs_ids,
"attention_mask": attention_mask,
"loss_mask": loss_mask,
"logprob_ref": logprob_ref,
}
]
)
rewards = []
def _criterion(outputs, inputs):
loss, chosen_rewards, rejected_rewards = self.actor_loss_fn(
calc_masked_log_probs(
outputs["logits"][0::2],
inputs["input_ids"][0::2],
inputs["loss_mask"][0::2][:, 1:],
self.length_normalization,
),
calc_masked_log_probs(
outputs["logits"][1::2],
inputs["input_ids"][1::2],
inputs["loss_mask"][1::2][:, 1:],
self.length_normalization,
),
inputs["logprob_ref"][0::2] if inputs["logprob_ref"] is not None else None,
inputs["logprob_ref"][1::2] if inputs["logprob_ref"] is not None else None,
inputs["loss_mask"][0::2][:, 1:],
inputs["loss_mask"][1::2][:, 1:],
)
rewards.append(chosen_rewards)
rewards.append(rejected_rewards)
return loss
outputs = self.booster.execute_pipeline(
data_iter,
self.model,
criterion=_criterion,
optimizer=self.optimizer,
return_loss=True,
)
loss = outputs["loss"]
if self.booster.plugin.stage_manager.is_last_stage():
chosen_rewards, rejected_rewards = rewards[0], rewards[1]
global_loss = all_reduce_mean(loss, self.plugin)
chosen_rewards_mean = all_reduce_mean(chosen_rewards, self.plugin)
rejected_rewards_mean = all_reduce_mean(rejected_rewards, self.plugin)
if dist.get_rank() == dist.get_world_size() - 1:
step_bar.set_postfix(
{
"eval/loss": global_loss.item(),
"eval/lr": self.actor_scheduler.get_last_lr()[0],
"eval/chosen_rewards": chosen_rewards.to(torch.float16).mean().item(),
"eval/rejected_rewards": rejected_rewards.to(torch.float16).mean().item(),
}
)
self.accumulative_meter.add(
"chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item()
)
self.accumulative_meter.add(
"rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item()
)
self.accumulative_meter.add("loss", global_loss.to(torch.float16).item())
step_bar.update()
if self.booster.plugin.stage_manager.is_last_stage():
msg = "\nEvaluation Result:\n"
for tag in ["loss", "chosen_rewards", "rejected_rewards"]:
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
if dist.get_rank() == dist.get_world_size() - 1:
print(msg)
else:
step_bar = trange(
len(self.eval_dataloader),
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
disable=not is_rank_0(),
)
with torch.no_grad():
for i, batch in enumerate(self.eval_dataloader):
batch = to_device(batch, self.device)
(
chosen_input_ids,
chosen_attention_mask,
chosen_loss_mask,
reject_input_ids,
reject_attention_mask,
reject_loss_mask,
) = (
batch["chosen_input_ids"],
batch["chosen_attention_mask"],
batch["chosen_loss_mask"],
batch["reject_input_ids"],
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)
if not self.apply_loss_mask:
chosen_loss_mask = chosen_loss_mask.fill_(1.0)
reject_loss_mask = reject_loss_mask.fill_(1.0)
batch_size = chosen_input_ids.size()[0]
actor_all_logits = self.model(
torch.cat([chosen_input_ids, reject_input_ids]),
torch.cat([chosen_attention_mask, reject_attention_mask]),
)["logits"]
actor_chosen_logits = actor_all_logits[:batch_size]
actor_reject_logits = actor_all_logits[batch_size:]
logprob_actor_chosen = calc_masked_log_probs(
actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
)
logprob_actor_reject = calc_masked_log_probs(
actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
)
ref_all_logits = self.ref_model(
torch.cat([chosen_input_ids, reject_input_ids]),
torch.cat([chosen_attention_mask, reject_attention_mask]),
)["logits"]
ref_chosen_logits = ref_all_logits[:batch_size]
ref_reject_logits = ref_all_logits[batch_size:]
logprob_ref_chosen = calc_masked_log_probs(
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
)
logprob_ref_reject = calc_masked_log_probs(
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
)
losses, chosen_rewards, rejected_rewards = self.actor_loss_fn(
logprob_actor_chosen,
logprob_actor_reject,
logprob_ref_chosen if logprob_ref_chosen is not None else None,
logprob_ref_reject if logprob_ref_reject is not None else None,
chosen_loss_mask[:, 1:],
reject_loss_mask[:, 1:],
)
reward_accuracies = (chosen_rewards > rejected_rewards).float().mean()
loss = losses.mean()
loss_mean = all_reduce_mean(tensor=loss)
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards)
reward_accuracies_mean = all_reduce_mean(tensor=reward_accuracies)
self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add(
"rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item()
)
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
self.accumulative_meter.add(
"margin", (chosen_rewards_mean - rejected_rewards_mean).to(torch.float16).mean().item()
)
step_bar.set_postfix(
{
"eval/loss": self.accumulative_meter.get("loss"),
"eval/chosen_rewards": self.accumulative_meter.get("chosen_rewards"),
"eval/rejected_rewards": self.accumulative_meter.get("rejected_rewards"),
"eval/accuracy": self.accumulative_meter.get("accuracy"),
}
)
step_bar.update()
msg = "\nEvaluation Result:\n"
for tag in ["loss", "chosen_rewards", "rejected_rewards", "accuracy", "margin"]:
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
self.coordinator.print_on_master(msg)
if self.save_dir is not None:
os.makedirs(self.save_dir, exist_ok=True)
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
f.write(msg)
step_bar.close()
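The pipeline-parallel branch above interleaves chosen and rejected samples so that even indices recover the chosen batch and odd indices the rejected one; a minimal sanity check of that pattern on toy tensors:
```python
import torch

chosen = torch.tensor([[1, 1], [2, 2]])
reject = torch.tensor([[9, 9], [8, 8]])
# Same interleaving as the trainer: chosen_0, reject_0, chosen_1, reject_1, ...
merged = torch.stack([item for pair in zip(chosen, reject) for item in pair])
assert torch.equal(merged[0::2], chosen)  # even rows -> chosen
assert torch.equal(merged[1::2], reject)  # odd rows  -> rejected
```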

@ -73,8 +73,7 @@ def main():
"--conversation_template_config",
type=str,
default="conversation_template_config",
help="Path \
to save conversation template config files.",
help="Path to save conversation template config files.",
)
parser.add_argument("--data_cache_dir", type=str, default="cache", help="Data cache directory")
parser.add_argument(

@ -13,7 +13,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@ -29,8 +29,6 @@ def train(args):
# check lora compatibility
if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
# ==============================
# Initialize Distributed Training
@ -46,7 +44,7 @@ def train(args):
Default torch ddp plugin without any acceleration, for debugging purposes.
"""
plugin = TorchDDPPlugin(find_unused_parameters=True)
plugin = TorchDDPPlugin(find_unused_parameters=not args.grad_checkpoint)
elif args.plugin == "gemini":
plugin = GeminiPlugin(
precision=args.mixed_precision,
@ -56,14 +54,6 @@ def train(args):
enable_gradient_accumulation=True,
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "gemini_auto":
plugin = GeminiPlugin(
precision=args.mixed_precision,
placement_policy="auto",
initial_scale=2**16,
max_norm=args.grad_clip,
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "zero2":
plugin = LowLevelZeroPlugin(
stage=2,
@ -92,20 +82,24 @@ def train(args):
parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
microbatch_size=args.microbatch_size,
)
else:
raise ValueError(f"Unknown plugin {args.plugin}")
booster = Booster(plugin=plugin)
ref_booster = Booster(plugin=plugin)
# ======================================================
# Initialize Model, Objective, Optimizer and LR Scheduler
# ======================================================
# Temp Fix: Disable lazy init due to version conflict
# init_ctx = (
# LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
# )
ref_plugin = HybridParallelPlugin(
tp_size=args.ref_tp,
pp_size=1,
zero_stage=args.zero_stage,
enable_flash_attention=args.use_flash_attn,
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
)
ref_booster = Booster(plugin=ref_plugin)
init_ctx = nullcontext()
with init_ctx:
@ -130,6 +124,7 @@ def train(args):
ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain)
else:
ref_model = None
if args.lora_config is not None:
model = convert_to_lora_module(model, lora_config=lora_config)
for name, module in model.named_modules():
@ -139,7 +134,9 @@ def train(args):
disable_dropout(ref_model)
if args.grad_checkpoint:
# Note, for some models, lora may not be compatible with gradient checkpointing
# Make sure gradient checkpointing can be activated.
model.train()
# Note, for some models, lora may not be compatible with gradient checkpointing.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
@ -169,7 +166,7 @@ def train(args):
adamw_mode=True,
)
# configure dataset
# Configure dataset
coordinator.print_on_master(f"Load dataset: {args.dataset}")
mode_map = {"train": "train", "valid": "validation", "test": "test"}
train_dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train", mode_map=mode_map)
@ -213,14 +210,15 @@ def train(args):
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
torch.set_default_dtype(default_dtype)
model, optim, _, train_dataloader, lr_scheduler = booster.boost(
model=model,
optimizer=optim,
lr_scheduler=lr_scheduler,
dataloader=train_dataloader,
)
if ref_model is not None:
ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_dataloader)
ref_model, _, _, _, _ = ref_booster.boost(model=ref_model)
torch.set_default_dtype(torch.float)
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
@ -312,7 +310,7 @@ if __name__ == "__main__":
"--plugin",
type=str,
default="gemini",
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
choices=["gemini", "zero2", "zero2_cpu", "3d", "ddp"],
help="Choose which plugin to use",
)
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
@ -342,22 +340,35 @@ if __name__ == "__main__":
parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
parser.add_argument("--max_epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument(
"--disable_reference_model",
action="store_true",
default=False,
help="Disable the reference model (enabled by default)",
)
parser.add_argument("--disable_loss_mask", default=False, action="store_true")
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument("--accumulation_steps", type=int, default=8)
parser.add_argument("--accumulation_steps", type=int, default=1)
parser.add_argument("--log_dir", default=None, type=str)
parser.add_argument("--use_wandb", default=False, action="store_true")
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
parser.add_argument("--use_flash_attn", default=False, action="store_true")
parser.add_argument(
"--microbatch_size",
type=int,
default=2,
help="Micro batch size for PP training. To activate PP training for DPO-like algorithm, you must keep size even and the size should be equal or greater than 2.",
)
# Parameter for reference model
parser.add_argument(
"--disable_reference_model",
action="store_true",
default=False,
help="Disable the reference model (enabled by default)",
)
parser.add_argument(
"--ref_tp",
type=int,
default=1,
help="TP size for reference model; used only when reference model is too large.",
)
args = parser.parse_args()
# fool proof hyperparameter setup

@ -68,7 +68,7 @@ def train(args):
Default torch ddp plugin without any acceleration, for debugging purposes.
"""
plugin = TorchDDPPlugin(find_unused_parameters=True if args.grad_checkpoint is False else False)
plugin = TorchDDPPlugin(find_unused_parameters=not args.grad_checkpoint)
elif args.plugin == "gemini":
plugin = GeminiPlugin(
precision=args.mixed_precision,

@ -4,7 +4,7 @@ BASE_TEMP_DIR=$BASE_DIR/temp
EXAMPLES_DIR=$BASE_DIR/examples
TEST_DATA_DIR=$BASE_DIR/tests/test_data
DATA_SAVE_PATH=$BASE_TEMP_DIR/tests
CONFIG_DIR=$BASE_DIR/config
CONFIG_DIR=$BASE_DIR/conversation_template
# MODELS=("colossal-llama2" "llama2" "mistral" "chatGLM2" "chatGLM3" "deepseek" "Yi" "baichuan") # for local test
MODELS=("colossal-llama2" "llama2" "chatGLM2" "chatGLM3" "deepseek" "Yi")
@ -39,23 +39,23 @@ get_pretrain() {
get_conversation_template_config() {
local model=$1
if [[ $model == "colossal-llama2" ]]; then
echo "$CONFIG_DIR/conversation_template/colossal-llama2.json"
echo "$CONFIG_DIR/colossal-llama2.json"
elif [[ $model == "llama2" ]]; then
echo "$CONFIG_DIR/conversation_template/llama2.json"
echo "$CONFIG_DIR/llama2.json"
elif [[ $model == "deepseek" ]]; then
echo "$CONFIG_DIR/conversation_template/deepseek-ai_DeepSeek-V2-Lite.json"
echo "$CONFIG_DIR/deepseek-ai_DeepSeek-V2-Lite.json"
elif [[ $model == "mistral" ]]; then
echo "$CONFIG_DIR/conversation_template/mistralai_Mixtral-8x7B-Instruct-v0.1.json"
echo "$CONFIG_DIR/mistralai_Mixtral-8x7B-Instruct-v0.1.json"
elif [[ $model == "chatGLM2" ]]; then
echo "$CONFIG_DIR/conversation_template/THUDM_chatglm2-6b.json"
echo "$CONFIG_DIR/THUDM_chatglm2-6b.json"
elif [[ $model == "chatGLM3" ]]; then
echo "$CONFIG_DIR/conversation_template/THUDM_chatglm3-6b.json"
echo "$CONFIG_DIR/THUDM_chatglm3-6b.json"
elif [[ $model == "phi" ]]; then
echo "$CONFIG_DIR/conversation_template/microsoft_phi-2.json"
echo "$CONFIG_DIR/microsoft_phi-2.json"
elif [[ $model == "Yi" ]]; then
echo "$CONFIG_DIR/conversation_template/01-ai_Yi-1.5-9B-Chat.json"
echo "$CONFIG_DIR/01-ai_Yi-1.5-9B-Chat.json"
elif [[ $model == "baichuan" ]]; then
echo "$CONFIG_DIR/conversation_template/baichuan-inc_Baichuan2-13B-Chat.json"
echo "$CONFIG_DIR/baichuan-inc_Baichuan2-13B-Chat.json"
else
echo "Unknown model $model"
exit 1
@ -71,6 +71,7 @@ for model in ${MODELS[@]}; do
rm -rf $SAVE_DIR/arrow
pretrain=$(get_pretrain $model)
conversation_template_config=$(get_conversation_template_config $model)
echo $conversation_template_config
python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py --type sft --data_input_dirs $TEST_DATA_DIR/sft \
--tokenizer_dir $pretrain \
--conversation_template_config $conversation_template_config \

@ -279,4 +279,4 @@ class CudaAccelerator(BaseAccelerator):
"""
Return autocast function
"""
return torch.cuda.amp.autocast(enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
return torch.amp.autocast(device_type="cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)

@ -322,7 +322,6 @@ class GeminiPlugin(DPPluginBase):
enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
verbose (bool, optional): verbose mode. Debug info including chunk search result will be printed. Defaults to False.
fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
@ -366,7 +365,6 @@ class GeminiPlugin(DPPluginBase):
enable_flash_attention: bool = False,
enable_sequence_parallelism: bool = False,
enable_jit_fused: bool = False,
enable_sequence_overlap: bool = False,
enable_async_reduce: bool = True,
use_fp8: bool = False,
verbose: bool = False,
@ -428,7 +426,6 @@ class GeminiPlugin(DPPluginBase):
self.enable_flash_attention = enable_flash_attention
self.enable_sequence_parallelism = enable_sequence_parallelism if self.enable_tensor_parallelism else False
self.enable_jit_fused = enable_jit_fused
self.enable_sequence_overlap = enable_sequence_overlap
self.verbose = verbose
self.tp_size = tp_size
@ -455,7 +452,6 @@ class GeminiPlugin(DPPluginBase):
enable_flash_attention=self.enable_flash_attention,
enable_jit_fused=self.enable_jit_fused,
enable_sequence_parallelism=self.enable_sequence_parallelism,
enable_sequence_overlap=self.enable_sequence_overlap,
)
def __del__(self):

@ -116,10 +116,15 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
super().__init__(module)
self.op_hooks = []
if use_fp8:
self.op_hooks.append(FP8Hook())
self.op_hooks = []
if use_fp8:
self.op_hooks.append(FP8Hook())
if overlap_allgather:
self.op_hooks.append(ZeroOpHook())
if use_fp8 or overlap_allgather:
self.op_hooks.append(ZeroOpHook())
if use_fp8 or overlap_allgather:
for p in module.parameters():
if p.requires_grad and type(p) is not ColoParameter:
@ -232,6 +237,9 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
def _hook_context(self):
return ColoParamOpHookManager.use_hooks(*self.op_hooks) if len(self.op_hooks) > 0 else nullcontext()
def _hook_context(self):
return ColoParamOpHookManager.use_hooks(*self.op_hooks) if len(self.op_hooks) > 0 else nullcontext()
def get_param_info(optim: Optimizer):
# Get a backup of necessary information of parameters for future use, which includes:
@ -951,7 +959,6 @@ class HybridParallelPlugin(PipelinePluginBase):
enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
sequence_parallelism_mode (str): The Sequence parallelism mode. Can only be choosed from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather".
enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True.
num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
@ -983,6 +990,8 @@ class HybridParallelPlugin(PipelinePluginBase):
make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose a faster kernel. Defaults to 64.
fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism
inner_ring_size (int, optional): The inner ring size of 2D Ring Attention when sp mode is "ring_attn".
It's advisable to not tune this (especially in single-node settings) and let it be heuristically set based on topology by default.
@ -1002,7 +1011,6 @@ class HybridParallelPlugin(PipelinePluginBase):
enable_jit_fused: bool = False,
enable_sequence_parallelism: bool = False,
sequence_parallelism_mode: str = None,
enable_sequence_overlap: bool = False,
parallel_output: bool = True,
num_microbatches: Optional[int] = None,
microbatch_size: Optional[int] = None,
@ -1092,6 +1100,7 @@ class HybridParallelPlugin(PipelinePluginBase):
self.use_fp8 = use_fp8
if dp_outside:
self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
if sequence_parallelism_mode == "ring_attn":
# Swap tp and sp since 2D Ring has better inter-node latency
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.sp_size, self.tp_size)
@ -1195,13 +1204,15 @@ class HybridParallelPlugin(PipelinePluginBase):
enable_jit_fused=self.enable_jit_fused,
enable_sequence_parallelism=enable_sequence_parallelism,
sequence_parallelism_mode=sequence_parallelism_mode,
enable_sequence_overlap=enable_sequence_overlap,
parallel_output=parallel_output,
make_vocab_size_divisible_by=make_vocab_size_divisible_by,
gradient_checkpoint_config=gradient_checkpoint_config,
fp8_communication=fp8_communication,
inner_ring_size=inner_ring_size,
pg_mesh=self.pg_mesh,
sp_axis=self.sp_axis,
)
self.amp_config = dict(
initial_scale=initial_scale,
growth_factor=growth_factor,
@ -1293,6 +1304,7 @@ class HybridParallelPlugin(PipelinePluginBase):
self.dp_size == 1 and self.pp_size == 1
)
# sync gradients across DP * SP ranks
# sync gradients across DP * SP ranks
# Apply Hybrid ZeRO across DP * SP ranks
if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])

@ -290,7 +290,11 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
assert isinstance(
peft_model, PeftModel
), "The model doesn't have lora adapters, please enable lora before saving."
return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
return peft_model.save_pretrained(
checkpoint,
safe_serialization=use_safetensors,
state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
)
class LowLevelZeroPlugin(DPPluginBase):

@ -141,7 +141,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
sequence_parallelism_mode (str): The Sequence parallelism mode. Can only be choosed from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather".
enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True.
num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
@ -190,7 +189,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
enable_jit_fused: bool = False,
enable_sequence_parallelism: bool = False,
sequence_parallelism_mode: str = None,
enable_sequence_overlap: bool = False,
parallel_output: bool = True,
num_microbatches: Optional[int] = None,
microbatch_size: Optional[int] = None,
@ -368,7 +366,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
enable_jit_fused=self.enable_jit_fused,
enable_sequence_parallelism=enable_sequence_parallelism,
sequence_parallelism_mode=sequence_parallelism_mode,
enable_sequence_overlap=enable_sequence_overlap,
parallel_output=parallel_output,
make_vocab_size_divisible_by=make_vocab_size_divisible_by,
gradient_checkpoint_config=gradient_checkpoint_config,

@ -1,9 +1,11 @@
from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
@ -134,7 +136,11 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
assert isinstance(
peft_model, PeftModel
), "The model doesn't have lora adapters, please enable lora before saving."
peft_model.save_pretrained(save_directory=checkpoint, safe_serialization=use_safetensors)
return peft_model.save_pretrained(
checkpoint,
safe_serialization=use_safetensors,
state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
)
class TorchDDPModel(ModelWrapper):

@ -11,6 +11,7 @@ import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils._pytree import tree_map
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
@ -20,7 +21,7 @@ from colossalai.tensor.padded_tensor import (
to_padded_tensor,
to_unpadded_tensor,
)
from colossalai.utils import get_current_device
from colossalai.utils import get_current_device, get_non_persistent_buffers_set
from .general_checkpoint_io import GeneralCheckpointIO
from .index_file import CheckpointIndexFile
@ -104,8 +105,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
yield block, block_size
# Save buffers.
non_persist_buffers_set = get_non_persistent_buffers_set(model)
for name, buf in model.named_buffers():
if buf is not None and name not in model._non_persistent_buffers_set:
if buf is not None and name not in non_persist_buffers_set:
buffer = buf if keep_vars else buf.detach()
block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
if block is not None:
@ -351,9 +353,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
_load(name)
# Load buffers.
non_persistent_buffers = set()
for n, m in model.named_modules():
non_persistent_buffers |= set(".".join((n, b)) for b in m._non_persistent_buffers_set)
non_persistent_buffers = get_non_persistent_buffers_set(model)
for name, buf in model.named_buffers():
if buf is not None and name not in non_persistent_buffers:
_load(name)
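The inlined set construction removed above is now delegated to `get_non_persistent_buffers_set` from `colossalai.utils`; a minimal sketch of such a helper, assumed from the call sites in this hunk:
# Sketch (assumed behavior): collect fully-qualified names of buffers registered
# with persistent=False, mirroring the inlined loop this diff removes.
from typing import Set
import torch.nn as nn

def get_non_persistent_buffers_set(model: nn.Module) -> Set[str]:
    names: Set[str] = set()
    for mod_name, module in model.named_modules():
        for buf_name in module._non_persistent_buffers_set:
            names.add(f"{mod_name}.{buf_name}" if mod_name else buf_name)
    return names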
@ -956,4 +956,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
assert isinstance(
peft_model, PeftModel
), "The model doesn't have lora adapters, please enable lora before saving."
return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
return peft_model.save_pretrained(
checkpoint,
safe_serialization=use_safetensors,
state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
)

@ -57,7 +57,9 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM):
target_module=NopadBaichuanMLP,
),
SubModuleReplacementDescription(
suffix="self_attn.W_pack", target_module=FusedLinear1D_Col, kwargs={"n_fused": 3}
suffix="self_attn.W_pack",
target_module=FusedLinear1D_Col,
kwargs={"split_sizes": [self.model.config.hidden_size] * 3},
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",

@ -1,7 +1,6 @@
import torch
from colossalai.accelerator import get_accelerator
from colossalai.legacy.nn.layer.colossalai_layer import Embedding, Linear
from .bias_dropout_add import bias_dropout_add_fused_train
from .bias_gelu import bias_gelu_impl
@ -45,6 +44,7 @@ def warmup_jit_fusion(
dtype: torch.dtype = torch.float32,
):
"""Compile JIT functions before the main training steps"""
from colossalai.legacy.nn.layer.colossalai_layer import Embedding, Linear
embed = Embedding(vocab_size, hidden_size).to(get_accelerator().get_current_device())
linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_accelerator().get_current_device())

@ -3,8 +3,9 @@ from typing import Any, List, Optional, Tuple
import torch
import torch.cuda
from packaging.version import Version
from torch.nn import Module
from torch.utils._pytree import SUPPORTED_NODES, TreeSpec, _register_pytree_node, tree_flatten, tree_map, tree_unflatten
from torch.utils._pytree import SUPPORTED_NODES, TreeSpec, tree_flatten, tree_map, tree_unflatten
# this register are for torch under version 1.13.1, maybe removed in the future
@ -16,7 +17,12 @@ def _odict_unflatten(values: List[Any], context: Any) -> "OrderedDict[Any, Any]"
return OrderedDict((key, value) for key, value in zip(context, values))
_register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)
if Version(torch.__version__) <= Version("1.13.1"):
try:
from torch.utils._pytree import register_pytree_node as _register_pytree_node
except ImportError:
from torch.utils._pytree import _register_pytree_node
_register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)
def tree_map_hf(fn: Any, pytree: Any):

@ -351,15 +351,16 @@ class InterleavedSchedule(PipelineSchedule):
if output_obj_grad is None:
optimizer.backward(output_obj)
else:
if "backward_tensor_keys" not in output_obj:
for k, grad in output_obj_grad.items():
optimizer.backward_by_grad(output_obj[k], grad)
keys = output_obj.get("backward_tensor_keys", output_obj_grad.keys())
tensors_to_backward = []
grads_to_backward = []
for k in keys:
tensors_to_backward.append(output_obj[k])
grads_to_backward.append(output_obj_grad[k])
if len(tensors_to_backward) == 1:
optimizer.backward_by_grad(tensors_to_backward[0], grads_to_backward[0])
else:
for k, grad in output_obj_grad.items():
output_obj[k].grad = grad
for k in output_obj["backward_tensor_keys"]:
tensor_to_backward = output_obj[k]
optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad)
optimizer.backward_by_grad(tensors_to_backward, grads_to_backward)
# Collect the grad of the input_obj.
input_obj_grad = None

@ -305,15 +305,16 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
if output_obj_grad is None:
optimizer.backward(output_obj)
else:
if "backward_tensor_keys" not in output_obj:
for k, grad in output_obj_grad.items():
optimizer.backward_by_grad(output_obj[k], grad)
keys = output_obj.get("backward_tensor_keys", output_obj_grad.keys())
tensors_to_backward = []
grads_to_backward = []
for k in keys:
tensors_to_backward.append(output_obj[k])
grads_to_backward.append(output_obj_grad[k])
if len(tensors_to_backward) == 1:
optimizer.backward_by_grad(tensors_to_backward[0], grads_to_backward[0])
else:
for k, grad in output_obj_grad.items():
output_obj[k].grad = grad
for k in output_obj["backward_tensor_keys"]:
tensor_to_backward = output_obj[k]
optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad)
optimizer.backward_by_grad(tensors_to_backward, grads_to_backward)
# Collect the grad of the input_obj.
input_obj_grad = None
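For reference, batching several tensor/grad pairs into one backward call is standard PyTorch behavior; a minimal sketch independent of `OptimizerWrapper.backward_by_grad` (whose exact signature is not shown in this diff):
# Minimal sketch: torch.autograd.backward accepts matching sequences of outputs
# and their gradients, so one call covers several "backward_tensor_keys".
import torch

x = torch.randn(3, requires_grad=True)
outs = [x * 2, x.sum(keepdim=True)]
grads = [torch.ones_like(outs[0]), torch.ones_like(outs[1])]
torch.autograd.backward(outs, grad_tensors=grads)   # accumulates into x.grad once
print(x.grad)                                        # tensor([3., 3., 3.])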

@ -8,6 +8,8 @@ import torch.nn.functional as F
from packaging.version import Version
from torch.distributed import ReduceOp
from .fp8_config import dynamic_kernel
SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.4.0")
SCALE_BYTES = 4
try:
@ -832,11 +834,13 @@ class _LinearFp8(torch.autograd.Function):
return x_grad.reshape(ctx.x_shape), w_grad, bias_grad
@torch.compile(mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=False)
@torch.compile(mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=dynamic_kernel)
def _linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return _LinearFp8.apply(input, weight, bias)
def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
if input.shape[-1] % 16 != 0 or np.prod(input.shape[:-1]) % 16 != 0:
return F.linear(input, weight, bias)
out = _linear_fp8(input, weight, bias)
return out
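A quick illustration of the guard above: the FP8 path is only taken when both the last dim and the product of the leading dims are multiples of 16; otherwise the call degrades to plain `F.linear`.
# Illustrative check mirroring the condition in linear_fp8 above.
import numpy as np
import torch

for shape in [(2, 7, 64), (2, 8, 64), (4, 8, 60)]:
    t = torch.randn(*shape)
    eligible = t.shape[-1] % 16 == 0 and np.prod(t.shape[:-1]) % 16 == 0
    print(shape, "fp8 eligible:", eligible)
# (2, 7, 64) -> False: 2*7 = 14 is not a multiple of 16
# (2, 8, 64) -> True:  2*8 = 16 and 64 are both multiples of 16
# (4, 8, 60) -> False: 60 is not a multiple of 16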

@ -0,0 +1 @@
dynamic_kernel: bool = False

@ -6,7 +6,7 @@ from .linear import Linear1D_Col, Linear1D_Row, LinearWithGradAccum, PaddingLMHe
from .loss import cross_entropy_1d, dist_cross_entropy
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
from .parallel_module import ParallelModule
from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
from .qkv_fused_linear import FusedLinear1D_Col, FusedLinear1D_Row, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
__all__ = [
"Embedding1D",
@ -35,4 +35,5 @@ __all__ = [
"RingAttention",
"get_pad_info",
"all_to_all_comm",
"FusedLinear1D_Row",
]

@ -106,7 +106,7 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
if ctx.async_grad_allreduce and fp8_communication:
if fp8_communication or not ctx.async_grad_allreduce:
_reduce(grad_input, group=ctx.process_group, fp8_communication=fp8_communication, fp8_format="e5m2")
elif ctx.async_grad_allreduce:
# Asynchronous all-reduce
@ -364,10 +364,12 @@ def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=
for k in recv_tensors:
send_tensors[k], recv_tensors[k] = recv_tensors[k], send_tensors[k]
input_tensors = []
output_tensors = []
handles = communicate_step()
# first round: special case, retrieve from local tensor
input_tensors.append(input_to_gather)
output_tensors.append(func(**input_to_gather, **input_local))
for i in range(group_size - 2):
for handle in handles:
@ -378,14 +380,25 @@ def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=
handles = communicate_step()
# actual computation
input_tensors.append(send_tensors)
output_tensors.append(func(**send_tensors, **input_local))
# final round: special case, no need to send/recv again
for handle in handles:
handle.wait()
input_tensors.append(send_tensors)
output_tensors.append(func(**recv_tensors, **input_local))
return torch.cat(output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=gather_dim)
gathered_input = {}
for k in input_to_gather:
input_shards = [d[k] for d in input_tensors[group_size - cur_rank :] + input_tensors[: group_size - cur_rank]]
gathered_input[k] = torch.cat(input_shards, dim=gather_dim)
gathered_output = torch.cat(
output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=gather_dim
)
return gathered_output, gathered_input
class _GatherForwardReduceScatterBackward(torch.autograd.Function):
@ -441,29 +454,30 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True, ring=False):
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
ctx.dim = dim
ctx.overlap = overlap
if ring is True:
input_to_gather = {"input": input_}
input_local = {"weight": weight}
output = _ring_as_gather(
output, input_dict = _ring_as_gather(
F.linear,
input_to_gather=input_to_gather,
input_local=input_local,
process_group=process_group,
)
ctx.gathered_input = input_dict["input"]
if bias is not None:
output += bias
else:
input_parallel = _gather(input_, dim, process_group)
ctx.gathered_input = input_parallel
if bias is not None:
output = F.linear(input_parallel, weight, bias)
else:
@ -477,100 +491,50 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
use_bias = ctx.use_bias
dim = ctx.dim
process_group = ctx.process_group
overlap = ctx.overlap
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
if use_bias:
bias = bias.view(bias.shape)
if not overlap:
input_parallel = _gather(input_, dim, process_group)
total_input = input_parallel
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
if ctx.async_grad_reduce_scatter:
# Asynchronous reduce-scatter
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(
input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
grad_weight = None
elif grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
else:
grad_weight = grad_output.t().matmul(total_input)
input_parallel = ctx.gathered_input
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_reduce_scatter:
handle.wait()
total_input = input_parallel
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
else:
input_ = input_.contiguous()
world_size = dist.get_world_size(process_group)
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
# do all-gather in an async way
gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
# calculate gradient and prepare data asynchronously with all-gather
# calculate
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
grad_bias = grad_output.sum(dim=0) if use_bias else None
# prepare data
if ctx.async_grad_reduce_scatter:
# Asynchronous reduce-scatter
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
# wait until all-gather finished
gather_handle.wait()
# do reduce-scatter in async way
reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
# calculate gradient
if len(input_parallel.shape) > 2:
input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(input_parallel, grad_output, grad)
grad_weight = None
elif grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(input_parallel, grad_output, grad)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(input_parallel)
output = torch.empty(input_.shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
grad_weight = None
elif grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(input_parallel)
# grad_weight = grad_output.t().matmul(input_parallel)
# wait until reduce-scatter finished
reducescatter_handle.wait()
grad_weight = grad_output.t().matmul(total_input)
else:
grad_weight = grad_output.t().matmul(total_input)
return output, grad_weight, grad_bias, None, None, None, None, None
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_reduce_scatter:
handle.wait()
return output, grad_weight, grad_bias, None, None, None, None
def _ring_as_reducescatter(
@ -701,7 +665,7 @@ class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
total_input = total_input.reshape(-1, total_input.shape[-1])
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
@ -759,34 +723,30 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""
@staticmethod
def forward(
ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication
):
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
ctx.dim = dim
ctx.overlap = overlap
ctx.fp8_communication = fp8_communication
if ring is True:
input_to_gather = {}
input_local = {}
input_to_gather["input"] = input_
input_local["other"] = weight
input_to_gather = {"input": input_}
input_local = {"other": weight}
output = _ring_as_gather(
output, input_dict = _ring_as_gather(
torch.matmul,
input_to_gather=input_to_gather,
input_local=input_local,
process_group=process_group,
gather_dim=dim,
)
ctx.gathered_input = input_dict["input"]
else:
input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e4m3")
ctx.gathered_input = input_parallel
output = torch.matmul(input_parallel, weight)
if bias is not None:
@ -799,76 +759,39 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
use_bias = ctx.use_bias
dim = ctx.dim
process_group = ctx.process_group
overlap = ctx.overlap
fp8_communication = ctx.fp8_communication
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
weight = weight.view(weight.shape)
if use_bias:
bias = bias.view(bias.shape)
if not overlap:
input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e5m2")
total_input = input_parallel
grad_input = grad_output.matmul(weight.T)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
if ctx.async_grad_reduce_scatter:
# Asynchronous reduce-scatter
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(
input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated
grad_weight = total_input.t().matmul(grad_output)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_reduce_scatter:
handle.wait()
input_parallel = ctx.gathered_input
else:
world_size = dist.get_world_size(process_group)
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
# do all-gather in an async way
gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
# calculate gradient and prepare data asynchronously with all-gather
# calculate
grad_input = grad_output.matmul(weight.T)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
grad_bias = grad_output.sum(dim=0) if use_bias else None
# prepare data
total_input = input_parallel
grad_input = grad_output.matmul(weight.T)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
if ctx.async_grad_reduce_scatter:
# Asynchronous reduce-scatter
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
# wait until all-gather finished
gather_handle.wait()
output = torch.empty(input_.shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated
# do reduce-scatter in async way
reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
# calculate gradient
if len(input_parallel.shape) > 2:
input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
grad_weight = input_parallel.t().matmul(grad_output)
# wait until reduce-scatter finished
reducescatter_handle.wait()
grad_weight = total_input.t().matmul(grad_output)
grad_bias = grad_output.sum(dim=0) if use_bias else None
return output, grad_weight, grad_bias, None, None, None, None, None, None
if ctx.async_grad_reduce_scatter:
handle.wait()
return output, grad_weight, grad_bias, None, None, None, None, None
class _SplitForwardGatherBackward(torch.autograd.Function):
@ -988,7 +911,7 @@ class _AllToAll(torch.autograd.Function):
ctx.gather_dim = gather_dim
ctx.fp8_communication = fp8_communication
world_size = dist.get_world_size(process_group)
bsz, _, _ = input_.shape
bsz = input_.shape[0]
# using all_to_all_single when batch size is 1
if bsz == 1:
@ -1019,7 +942,7 @@ class _AllToAll(torch.autograd.Function):
gather_dim = ctx.scatter_dim
fp8_communication = ctx.fp8_communication
world_size = dist.get_world_size(process_group)
bsz, _, _ = grad_output.shape
bsz = grad_output.shape[0]
if bsz == 1:
return_grad = _all_to_all_single(
@ -1204,10 +1127,10 @@ def linear_with_grad_accum(input_, weight, bias, async_grad_allreduce, use_zbv=F
def linear_gather_forward_reducescatter_backward(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False
):
return _LinearWithGatherForwardReduceScatterBackward.apply(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring
)
@ -1224,10 +1147,10 @@ def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, proc
def matmul_gather_forward_reducescatter_backward(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False, fp8_communication=False
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False, fp8_communication=False
):
return _MatmulWithGatherForwardReduceScatterBackward.apply(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication
)

@ -422,16 +422,21 @@ class RingAttention(torch.autograd.Function):
ATTN_DONE: torch.cuda.Event = None
SP_STREAM: torch.cuda.Stream = None
SP_GROUP: dist.ProcessGroup = None
# duplicate process group for concurrent NCCL streams
# while both PyTorch and NCCL warns(https://github.com/pytorch/pytorch/commit/2dbe5cb979f674f0052a8eea1f7b6c3c0ba441d7)
# against this, in practice it seems to work fine.
# NOTE: Duplicating PGs for concurrent NCCL streams is a risky hack -- while it may increase throughput,
# both PyTorch and NCCL warn against this. (https://github.com/pytorch/pytorch/commit/2dbe5cb979f674f0052a8eea1f7b6c3c0ba441d7)
# LoongTrain's original double ring impl. uses concurrent PGs
# (https://github.com/InternLM/InternEvo/blob/e52f2ffc9acf818e8f2b1f97dfc69ceb2f06e154/internlm/model/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py#L192)
# but I confirmed with Pytorch developers this can cause obscure "Software caused connection abort" errors.
# (https://github.com/pytorch/pytorch/issues/132852)
# NOTE: In general, a smarter idea is to put as many P2P calls as possible into one `batch_isend_irecv`.
INNER_RING_GROUP: dist.ProcessGroup = None
INNER_RING_GROUP_COPY: dist.ProcessGroup = None
# INNER_RING_GROUP_COPY: dist.ProcessGroup = None
INTER_RING_GROUP: dist.ProcessGroup = None
INTER_RING_GROUP_COPY: dist.ProcessGroup = None
# INTER_RING_GROUP_COPY: dist.ProcessGroup = None
@staticmethod
def get_double_ring_groups(sp_group, inner_ring_size=None):
def get_double_ring_groups(sp_axis, pg_mesh, inner_ring_size=None):
"""
Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size
shouldn't be larger than the number of NICs on each node.
@ -441,21 +446,17 @@ class RingAttention(torch.autograd.Function):
Returns:
Tuple[dist.ProcessGroup, dist.ProcessGroup]: Inner-ring process group and inter-ring process group.
"""
assert pg_mesh is not None, f"Error: The pg mesh is None! please check the process group initialization."
sp_group = pg_mesh.get_group_along_axis(sp_axis)
sp_size = dist.get_world_size(sp_group)
sp_rank = dist.get_rank(sp_group)
if inner_ring_size is None:
if torch.cuda.device_count() >= dist.get_world_size():
# single node, no need to consider NICs
return sp_group, sp_group
if sp_size <= 4:
inner_ring_size = min(2, sp_size)
else:
inner_ring_size = min(4, sp_size)
else:
assert (
inner_ring_size <= sp_size and sp_size % inner_ring_size == 0
), f"Error: sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}"
assert inner_ring_size is not None
assert (
inner_ring_size <= sp_size and sp_size % inner_ring_size == 0
), f"Error: sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}"
if inner_ring_size == sp_size:
return sp_group, sp_group
@ -474,14 +475,14 @@ class RingAttention(torch.autograd.Function):
# Create inner ring groups
for i in range(inner_ring_size):
ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size))
group = dist.new_group(ranks)
group = pg_mesh.get_group_along_axis(sp_axis, ranks)
if sp_rank in ranks:
inner_ring_group = group
# Create inter ring groups
for i in range(num_rings):
ranks = list(range(i, sp_size, num_rings))
group = dist.new_group(ranks)
group = pg_mesh.get_group_along_axis(sp_axis, ranks)
if sp_rank in ranks:
inter_ring_group = group
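A small worked example of the resulting layout (illustrative; numbers chosen so the two group-building loops above coincide): with sp_size = 4 and inner_ring_size = 2, every rank sits on one short inner ring and one inter ring.
# Illustrative 2D ring layout for sp_size = 4, inner_ring_size = 2 (num_rings = 2).
sp_size, inner_ring_size = 4, 2
num_rings = sp_size // inner_ring_size

inner_rings = [list(range(i * inner_ring_size, (i + 1) * inner_ring_size)) for i in range(num_rings)]
inter_rings = [list(range(i, sp_size, num_rings)) for i in range(num_rings)]
print(inner_rings)  # [[0, 1], [2, 3]] -> short rings for the latency-critical inner loop
print(inter_rings)  # [[0, 2], [1, 3]] -> one rank per inner ring forms the outer ring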
@ -492,7 +493,7 @@ class RingAttention(torch.autograd.Function):
q, # (B, H, Sq, D)
k,
v,
sp_group,
sp_axis,
attention_mask_type,
cu_seqlens=None,
max_seqlen=None,
@ -502,6 +503,7 @@ class RingAttention(torch.autograd.Function):
deterministic=False,
return_softmax=False,
inner_ring_size=None,
pg_mesh=None,
**kwargs,
):
"""
@ -512,7 +514,7 @@ class RingAttention(torch.autograd.Function):
q (torch.Tensor): Query tensor. Shape should be [B, nHeads, Sq, D]
k (torch.Tensor): Key tensor. Shape should be [B, nHeads, Sq, D]
v (torch.Tensor): Value tensor. Shape should be [B, nHeads, Sq, D]
sp_group (Optional[dist.ProcessGroup]): Process group for sequence parallelism
sp_axis (Optional[int]): Sp axis for the global pg mesh.
sp_stream (torch.cuda.Stream): A different stream for output correction.
cu_seqlens (Optional[torch.Tensor], optional): The cumulative sequence lengths
of the sequences in the batch, used to index into q.
@ -537,7 +539,6 @@ class RingAttention(torch.autograd.Function):
RingAttention.ATTN_DONE = torch.cuda.Event()
if RingAttention.SP_STREAM is None:
RingAttention.SP_STREAM = torch.cuda.Stream()
assert (
q.shape[2] == k.shape[2]
), "Q, K and V having different sequence lengths (inference or cross-attn)\
@ -546,11 +547,13 @@ class RingAttention(torch.autograd.Function):
attention_mask_type in RingAttention.SUPPORTED_MASK_TYPES
), f"Mask type {attention_mask_type} is not supported yet."
clone_pg = lambda pg: dist.new_group(dist.get_process_group_ranks(pg))
assert pg_mesh is not None, f"Error: The pg mesh is None! please check the process group initialization."
if RingAttention.SP_GROUP is not sp_group:
clone_pg = lambda pg: dist.new_group(dist.get_process_group_ranks(pg))
sp_group = pg_mesh.get_group_along_axis(sp_axis)
if inner_ring_size != None:
RingAttention.SP_GROUP = sp_group
inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_group, inner_ring_size)
inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_axis, pg_mesh, inner_ring_size)
RingAttention.INNER_RING_GROUP = inner_ring_group
RingAttention.INTER_RING_GROUP = inter_ring_group
else:
@ -628,7 +631,13 @@ class RingAttention(torch.autograd.Function):
inner_ring_group: Optional[dist.ProcessGroup] = None,
inter_ring_group: Optional[dist.ProcessGroup] = None,
):
"""
Forward supporting both packed (varlen) and batched (fixed-length, no padding) sequences.
There is no separate version for batched sequences (hard to maintain), which incurs
some overhead in sequence splitting due to Python for loops.
Uses two CUDA streams to overlap the softmax denominator correction with the next flash attn
kernel (see comments below).
"""
cu_seqlens_q = cu_seqlens_kv = cu_seqlens
max_seqlen_q = max_seqlen_kv = max_seqlen
cu_seqlens_half = cu_seqlens // 2
@ -670,7 +679,8 @@ class RingAttention(torch.autograd.Function):
sp_size = dist.get_world_size(sp_group)
sp_rank = dist.get_rank(sp_group)
# Attempt to achieve concurrent comm in the two-stream forward
# Create communicators corresponding to two CUDA streams
local_kv_comms = [RingComm(inner_ring_group) for _ in range(2)]
inter_ring_comm = RingComm(inter_ring_group)
local_sp_size = dist.get_world_size(inner_ring_group)
@ -678,7 +688,7 @@ class RingAttention(torch.autograd.Function):
inter_ring_rank = dist.get_rank(inter_ring_group) if inter_ring_group is not sp_group else 0
num_rings = dist.get_world_size(inter_ring_group) if inter_ring_group is not sp_group else 1
# Non-contiguous indexing copies to a new contiguous tensor,
# Any type of indexing (but not slicing) copies to a new contiguous tensor,
# so only do it once
if sp_rank != sp_size - 1:
q1 = q[half_idx_back]
@ -695,6 +705,7 @@ class RingAttention(torch.autograd.Function):
rng_states = [None for _ in range(sp_size)]
sp_streams = [torch.cuda.current_stream(), sp_stream]
# Helper to pass args to FA
def _forward(q, k, v, causal):
(
_,
@ -725,6 +736,7 @@ class RingAttention(torch.autograd.Function):
if i < local_sp_size - 1:
local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
# Forward within a node
def _local_ring_forward():
# (Hopefully) overlap output correction with next flash attn
for i in range(local_sp_size):
@ -733,6 +745,8 @@ class RingAttention(torch.autograd.Function):
# NOTE: waiting outside the current stream will NOT correctly synchronize.
if i > 0:
local_kv_comms[(i + 1) % 2].wait()
# Prefetch
if i == 0:
_kv_comm(i)
@ -766,15 +780,22 @@ class RingAttention(torch.autograd.Function):
) = _forward(q_block, kv_block[0], kv_block[1], causal=False)
RingAttention.ATTN_DONE.record()
# Pipeline the next KV comm with output correction instead of the next flash attn
# to minimize idle time when comm takes longer than attn.
# kernel, to minimize bubble when comm takes longer than attn.
_kv_comm(i + 1)
block_softmax_lse[i % 2] = (
block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
) # (H, T) -> (T, H, 1)
assert block_out[i % 2].shape[:-1] == block_softmax_lse[i % 2].shape[:-1]
# Output and log sum exp correction. Ideally overlap this with the next flash attn kernel.
# In reality this always finishes before next flash attn; no need for extra sync.
# Output and log sum exp correction.
# Ideally overlap this with the next flash attn kernel,
# since attn uses Tensor Core and rescale is element-wise, memory-bound and uses CUDA cores.
# (NOTE that this is the same as ping-pong scheduling idea in FA3)
# TODO: However, sometimes even though the GPU has scheduled the next kernel,
# it's reluctant to launch it in overlap. Potential causes:
# 1. lower-level CUDA scheduling may be needed 2. further benchmarking against Megatron-LM
# 3. register spilling by the FA kernel.
if i == 0:
out = block_out[0]
softmax_lse = block_softmax_lse[0]
@ -790,15 +811,17 @@ class RingAttention(torch.autograd.Function):
torch.cuda.current_stream().wait_stream(sp_stream)
return out, softmax_lse
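For context, the correction above follows the usual online-softmax merge of partial attention results; a hedged sketch of that math (shapes follow the (T, H, 1) layout of block_softmax_lse, not the exact fused kernel used here):
# Sketch of the log-sum-exp based output correction: merge two partial attention
# outputs computed over disjoint KV blocks into the output over their union.
import torch

def merge_partial_attn(out1, lse1, out2, lse2):
    # out*: (T, H, D) partial outputs; lse*: (T, H, 1) log-sum-exp of softmax denominators
    lse = torch.logaddexp(lse1, lse2)            # denominator of the union, in log space
    out = torch.exp(lse1 - lse) * out1 + torch.exp(lse2 - lse) * out2
    return out, lse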
# Forward for inter-node (the outer ring in 2D ring)
def _other_ring_forward(ring_num_idx, out, softmax_lse):
# Loop through the inner ring after receiving
# all new KVs from the previous inner ring
# all new KVs from another ring
for i in range(local_sp_size):
with torch.cuda.stream(sp_streams[i % 2]):
# Send & recv KV
if i > 0:
local_kv_comms[(i + 1) % 2].wait()
# Prefetch
if i == 0:
_kv_comm(i)
@ -895,7 +918,8 @@ class RingAttention(torch.autograd.Function):
def backward(ctx, dout, _):
"""
During backward, we accumulate q grads on each rank locally, but iterate kv and their grads
over all ranks for accumulation.
over all ranks for accumulation. We avoid using two streams here since backward already uses
doubled buffers and is more communication-intensive.
"""
(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_kv, half_idx_front, half_idx_back) = ctx.saved_tensors[:9]
rng_states = ctx.saved_tensors[9:]
@ -927,7 +951,7 @@ class RingAttention(torch.autograd.Function):
local_sp_rank = dist.get_rank(sp_group)
sp_size = dist.get_world_size(sp_group)
# Using separate streams (pg) for concurrent kv and dkv comm may
# NOTE: Using separate streams (PG) for concurrent kv and dkv comm may
# cause NCCL "software caused connection abort" here...
local_kv_comm = RingComm(local_kv_group)
local_dkv_comm = RingComm(local_kv_group)
@ -959,6 +983,7 @@ class RingAttention(torch.autograd.Function):
dkv_buffers = [torch.empty_like(kv, dtype=torch.float32) for kv in kv_buffers] # (T, H, D)
del k, v
# Helper to pass args to FA
def _backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, rng_state, causal):
_flash_attn_backward(
dout,
@ -979,8 +1004,7 @@ class RingAttention(torch.autograd.Function):
**misc_kwargs,
)
# NOTE: We avoid using two streams due to doubled buffers
# and that backward is more communication intensive.
# Backward within a node
def _local_ring_backward():
for i in range(local_sp_size):
if i > 0:
@ -1043,6 +1067,7 @@ class RingAttention(torch.autograd.Function):
dkv_send = dkv_buffers[(local_sp_size - 1) % 2]
return dq, dkv_recv, dkv_send
# Backward for inter-node (the outer ring in 2D ring)
def _other_ring_backward(ring_num_idx, dq):
if ring_num_idx > inter_ring_rank:
# Indexing is expensive
@ -1127,34 +1152,34 @@ class RingAttention(torch.autograd.Function):
@staticmethod
def prepare_varlen_batch(
attention_mask: torch.Tensor,
padding_mask: torch.Tensor,
sp_group: dist.ProcessGroup,
inputs_embeds: torch.Tensor = None,
position_ids: Optional[torch.Tensor] = None,
is_label: bool = False,
is_2d: bool = True,
is_batched_seq: bool = True,
):
# TODO: support setting a batch dim (fix packing length) for packed mode, so that
# DP can be used (needs to modify dataloader too)
"""
Preprocess a batch of padded sequences by splitting each input sequence by sp_size
sequence-wise and packing them into one sequence. Updates the mask info accordingly.
seq-wise and packing them into one sequence. Updates the mask info accordingly.
Args:
attention_mask (torch.Tensor): Contains the mask [B, Sq], where True means the token is NOT masked.
padding_mask (torch.Tensor): Contains the mask [B, Sq], where True means the token is NOT masked.
sp_group (dist.ProcessGroup): Process group for sequence parallelism
inputs_embeds (torch.Tensor): Input embeddings. Shape should be [B, Sq, ...]
position_ids (Optional[torch.Tensor], optional): Position ids of shape [Sq] or [1, Sq]. Defaults to None.
is_label (bool, optional): Whether inputs_embeds is instead a label tensor. If True, mask out the first
token of each sequence.
is_2d (bool, optional): Whether to return 2D outputs padded to max_seqlen // sp_size or flatten
the batch dim to a packed 1d sequence. Contingent on model forward shape definitions.
is_batched_seq (bool, optional): If True, then the input is a batch of (potentially padded) sequences
of shape [B, Sq, ...]; else a packed sequence of shape [T, ...].
Returns:
torch.Tensor:
Packed input embeddings of shape [B, Sq // sp_size, ...].
Dict[str, Any]:
inputs_embeds (torch.Tensor):
Packed input embeddings of shape [B, Sq // sp_size, ...] if is_batched_seq, else [T, ...].
mask_info (Dict[str, Any]):
A dictionary containing mask info.
torch.Tensor:
position_ids (torch.Tensor):
Packed position ids of shape [..., Sq // sp_size].
"""
@ -1162,12 +1187,11 @@ class RingAttention(torch.autograd.Function):
sp_size = dist.get_world_size(group=sp_group)
sp_rank = dist.get_rank(group=sp_group)
mask_info = {}
mask_info["max_seqlen"], mask_info["cu_seqlens"] = get_pad_info(attention_mask, return_indices=False)
mask_info["max_seqlen"], mask_info["cu_seqlens"] = get_pad_info(padding_mask, return_indices=False)
# Unpad, split seq-wise, then pad back to (B, max_seqlen // sp_size)
# Split mask to compute local nonzero position indices
# Unpad, split seq-wise, then pad to (B, max_seqlen // sp_size)
# (B, Sq) -> (B, max_seqlen // sp_size)
attention_mask = attention_mask[:, : mask_info["max_seqlen"]]
padding_mask = padding_mask[:, : mask_info["max_seqlen"]]
if inputs_embeds is not None:
inputs_embeds = inputs_embeds[:, : mask_info["max_seqlen"]]
inputs_embeds = split_varlen_zigzag(
@ -1175,11 +1199,12 @@ class RingAttention(torch.autograd.Function):
mask_info["cu_seqlens"],
sp_group,
mask_info["max_seqlen"],
is_2d=is_2d,
is_batched_seq=is_batched_seq,
is_label=is_label,
)
attention_mask = split_varlen_zigzag(
attention_mask, mask_info["cu_seqlens"], sp_group, mask_info["max_seqlen"], is_2d=is_2d
# Split mask to get local nonzero seq positions
padding_mask = split_varlen_zigzag(
padding_mask, mask_info["cu_seqlens"], sp_group, mask_info["max_seqlen"], is_batched_seq=is_batched_seq
)
if position_ids is not None:
@ -1192,7 +1217,7 @@ class RingAttention(torch.autograd.Function):
)
mask_info["max_seqlen"] //= sp_size
mask_info["valid_indices"] = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
mask_info["valid_indices"] = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
mask_info["cu_seqlens"] //= sp_size
mask_info["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
return inputs_embeds, mask_info, position_ids
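A hedged call sketch, inferred only from the signature and docstring above; tensor sizes are placeholder assumptions and an initialized sequence-parallel process group is required, so this is not executable standalone:
# Placeholders (assumptions): 2 sequences of lengths 3 and 5 padded to Sq = 6, hidden size H = 8.
import torch
from colossalai.shardformer.layer import RingAttention

B, Sq, H = 2, 6, 8
padding_mask = torch.zeros(B, Sq, dtype=torch.bool)
padding_mask[0, :3] = True
padding_mask[1, :5] = True
inputs_embeds = torch.randn(B, Sq, H)
# sp_group must be an already-initialized sequence-parallel dist.ProcessGroup (not built here).
inputs_embeds, mask_info, position_ids = RingAttention.prepare_varlen_batch(
    padding_mask, sp_group, inputs_embeds=inputs_embeds, is_label=False, is_batched_seq=True
)
# mask_info then carries the per-shard "cu_seqlens", "max_seqlen", "valid_indices"
# and "attention_mask_type" == AttnMaskType.PADDED_CAUSAL, as set at the end of this method.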

@ -23,18 +23,16 @@ from colossalai.tensor.d_tensor.api import (
)
from ._operation import (
gather_forward_reducescatter_backward,
gather_forward_split_backward,
linear_gather_forward_reducescatter_backward,
linear_reducescatter_forward_gather_backward,
linear_with_async_comm,
linear_with_grad_accum,
reduce_forward,
reducescatter_forward_gather_backward,
split_forward_gather_backward,
)
from .parallel_module import PaddingParallelModule, ParallelModule
from .utils import create_randomizer_with_offset
from .utils import create_randomizer_with_offset, is_share_sp_tp
__all__ = ["LinearWithGradAccum", "Linear1D_Col", "Linear1D_Row"]
@ -197,7 +195,6 @@ class Linear1D_Col(ParallelModule):
to all GPUs, otherwise, every GPU will have its output
which is :math:`Y_i = XA_i`, defaults to False
seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (`typing.Callable`):
@ -220,7 +217,6 @@ class Linear1D_Col(ParallelModule):
gather_output: bool = False,
seq_parallel_mode: str = None,
seq_parallel_dim: int = 1,
overlap: torch.cuda.Stream = None,
skip_bias_add: bool = False,
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
@ -238,7 +234,6 @@ class Linear1D_Col(ParallelModule):
self.gather_output = gather_output
self.seq_parallel_mode = seq_parallel_mode
self.seq_parallel_dim = seq_parallel_dim
self.overlap = overlap
self.skip_bias_add = skip_bias_add
self.device = device
self.process_group = process_group
@ -345,22 +340,16 @@ class Linear1D_Col(ParallelModule):
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
if self.seq_parallel_mode == "split_gather":
input_parallel = gather_forward_reducescatter_backward(
input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
)
output_parallel = linear_with_async_comm(
if is_share_sp_tp(self.seq_parallel_mode):
output_parallel = linear_gather_forward_reducescatter_backward(
input_parallel,
self.weight,
bias,
self.process_group,
False,
fp8_communication=self.fp8_communication,
use_zbv=self.use_zbv,
)
elif self.seq_parallel_mode == "ring":
output_parallel = linear_gather_forward_reducescatter_backward(
input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True
True,
self.seq_parallel_dim,
ring=self.seq_parallel_mode == "ring",
)
else:
output_parallel = linear_with_async_comm(
@ -584,31 +573,17 @@ class Linear1D_Row(ParallelModule):
handle.wait()
output = torch.cat(output_parallel_list, dim=-1)
else:
if self.seq_parallel_mode is None:
output_parallel = linear_with_async_comm(
input_, self.weight, None, self.process_group, False, use_zbv=self.use_zbv
)
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
elif self.seq_parallel_mode == "split_gather":
output_parallel = linear_with_async_comm(
input_, self.weight, None, self.process_group, False, use_zbv=self.use_zbv
)
output = reducescatter_forward_gather_backward(
output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
)
elif self.seq_parallel_mode == "ring":
if is_share_sp_tp(self.seq_parallel_mode):
output = linear_reducescatter_forward_gather_backward(
input_,
self.weight,
process_group=self.process_group,
dim=self.seq_parallel_dim,
ring=True,
ring=self.seq_parallel_mode == "ring",
)
else:
output_parallel = linear_with_async_comm(
input_, self.weight, None, self.process_group, False, use_zbv=self.use_zbv
)
output = reduce_forward(output_parallel, self.process_group)
output_parallel = F.linear(input_, self.weight)
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
if not self.skip_bias_add:
if self.bias is not None:
@ -716,7 +691,6 @@ class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule):
to all GPUs, otherwise, every GPU will have its output
which is :math:`Y_i = XA_i`, defaults to False
seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (`typing.Callable`):
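The forward paths above now branch on `is_share_sp_tp` instead of matching mode strings; a plausible minimal helper (an assumption about `shardformer.layer.utils`, not confirmed by this hunk) would be:
# Assumed helper: "split_gather" and "ring" both reuse the TP group along the
# sequence dim, so layers can treat them with one gather/reduce-scatter code path.
from typing import Optional

def is_share_sp_tp(sp_mode: Optional[str]) -> bool:
    return sp_mode in ("split_gather", "ring")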

@ -7,6 +7,7 @@ from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
@ -24,17 +25,17 @@ from colossalai.tensor.d_tensor.api import (
)
from ._operation import (
gather_forward_split_backward,
linear_gather_forward_reducescatter_backward,
linear_reducescatter_forward_gather_backward,
linear_with_async_comm,
matmul_gather_forward_reducescatter_backward,
matmul_with_async_comm,
reduce_backward,
reduce_forward,
reducescatter_forward_gather_backward,
split_forward_gather_backward,
)
from .parallel_module import ParallelModule
from .utils import create_randomizer_with_offset
from .utils import create_randomizer_with_offset, is_share_sp_tp
__all__ = ["FusedLinear1D_Col", "FusedLinear1D_Row", "GPT2FusedLinearConv1D_Col", "GPT2FusedLinearConv1D_Row"]
@ -44,21 +45,25 @@ __all__ = ["FusedLinear1D_Col", "FusedLinear1D_Row", "GPT2FusedLinearConv1D_Col"
def split_fused_qkv_in_gpt2_style(
qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup, is_transposed: bool = False
qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup, is_transposed: bool = False
):
"""
The fused qkv tensor looks like [Q1, Q2, K1, K2, V1, V2]; this function splits it into [Q1, K1, V1] and [Q2, K2, V2].
Args:
qkv (torch.Tensor): The fused qkv tensor.
n_fused (int): The number items fused together, defaults to 3 (query, key and value).
split_sizes (List[int]): The sizes of the split tensor.
process_group (ProcessGroup): The process group for distributed communication.
is_transposed (bool): Generally the tensor has shape (out_features, in_features). Set this to True if the tensor has shape (in_features, out_features).
"""
# get the number of slices for the fused qkv
rank = dist.get_rank(group=process_group)
world_size = dist.get_world_size(group=process_group)
order = torch.arange(world_size * n_fused)
order = torch.arange(world_size * len(split_sizes))
new_split_sizes = []
for sz in split_sizes:
assert sz % world_size == 0, f"size {sz} is not divisible by world_size {world_size}"
new_split_sizes.extend([sz // world_size] * world_size)
# split the fused qkv
# from
@ -66,9 +71,9 @@ def split_fused_qkv_in_gpt2_style(
# to
# [Q1, Q2, K1, K2, V1, V2]
if is_transposed:
weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=-1)
weight_chunks = torch.split(qkv, new_split_sizes, dim=-1)
else:
weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=0)
weight_chunks = torch.split(qkv, new_split_sizes, dim=0)
# rearrange the slice into the final order
# from
@ -85,18 +90,23 @@ def split_fused_qkv_in_gpt2_style(
def gather_fused_qkv_in_gpt2_style(
qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup, is_transposed: bool = False
qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup, is_transposed: bool = False
):
"""
The split qkv tensors look like [Q1, K1, V1] and [Q2, K2, V2]; this function gathers them into [Q1, Q2, K1, K2, V1, V2].
Args:
qkv (torch.Tensor): The fused qkv tensor.
n_fused (int): The number items fused together, defaults to 3 (query, key and value).
split_sizes (List[int]): The sizes of the split tensor.
process_group (ProcessGroup): The process group for distributed communication.
is_transposed (bool): Generally the tensor has shape (out_features, in_features). Set this to True if the tensor has shape (in_features, out_features).
"""
world_size = dist.get_world_size(group=process_group)
new_split_sizes = []
for sz in split_sizes:
assert sz % world_size == 0, f"size {sz} is not divisible by world_size {world_size}"
new_split_sizes.append(sz // world_size)
new_split_sizes = new_split_sizes * world_size
# gather the tensors
# from
@ -121,13 +131,13 @@ def gather_fused_qkv_in_gpt2_style(
# to
# [Q1, Q2, K1, K2, V1, V2]
if is_transposed:
weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=-1)
weight_chunks = torch.split(gather_weight, new_split_sizes, dim=-1)
else:
weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=0)
weight_chunks = torch.split(gather_weight, new_split_sizes, dim=0)
reordered_chunk_list = []
for i in range(n_fused):
reordered_chunk_list.extend(weight_chunks[i::n_fused])
for i in range(len(split_sizes)):
reordered_chunk_list.extend(weight_chunks[i :: len(split_sizes)])
if is_transposed:
reordered_gather_weight = torch.cat(reordered_chunk_list, dim=-1)
@ -136,6 +146,42 @@ def gather_fused_qkv_in_gpt2_style(
return reordered_gather_weight
class _SplitForwardGatherBackwardFusedQKV(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup):
ctx.split_sizes = split_sizes
ctx.process_group = process_group
return split_fused_qkv_in_gpt2_style(qkv, split_sizes, process_group, is_transposed=True)
@staticmethod
def backward(ctx, grad_output):
grad_output = gather_fused_qkv_in_gpt2_style(
grad_output, ctx.split_sizes, ctx.process_group, is_transposed=True
)
return grad_output, None, None
def split_forward_gather_backward_fused_qkv(qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup):
return _SplitForwardGatherBackwardFusedQKV.apply(qkv, split_sizes, process_group)
class _GatherForwardSplitBackwardFusedQKV(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup):
ctx.split_sizes = split_sizes
ctx.process_group = process_group
return gather_fused_qkv_in_gpt2_style(qkv, split_sizes, process_group, is_transposed=True)
@staticmethod
def backward(ctx, grad_output):
grad_output = split_fused_qkv_in_gpt2_style(grad_output, ctx.split_sizes, ctx.process_group, is_transposed=True)
return grad_output, None, None
def gather_forward_split_backward_fused_qkv(qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup):
return _GatherForwardSplitBackwardFusedQKV.apply(qkv, split_sizes, process_group)
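A worked example of the `split_sizes` semantics introduced above (illustrative shapes; the per-rank reordering mirrors the logic elided by the hunk): a fused [Q | K | V] weight with column sizes [4, 2, 2] on world_size = 2 yields per-rank columns [Q_r | K_r | V_r] of sizes [2, 1, 1].
# Illustrative split of a fused QKV weight along its output columns.
import torch

qkv = torch.arange(24.0).reshape(3, 8)        # in_features = 3, fused out = 4 (Q) + 2 (K) + 2 (V)
split_sizes, world_size, rank = [4, 2, 2], 2, 0

new_split_sizes = []
for sz in split_sizes:                         # same expansion as in the diff above
    new_split_sizes.extend([sz // world_size] * world_size)
chunks = torch.split(qkv, new_split_sizes, dim=-1)     # [Q1, Q2, K1, K2, V1, V2]
local = torch.cat(chunks[rank::world_size], dim=-1)    # rank 0 keeps [Q1, K1, V1]
print([c.shape[-1] for c in chunks])                   # [2, 2, 1, 1, 1, 1]
print(local.shape)                                     # torch.Size([3, 4])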
class GPT2FusedLinearConv1D_Col(ParallelModule):
r"""Linear layer with column parallelism.
@ -145,10 +191,10 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
Args:
in_features (int): size of each input sample.
out_features (int): size of each output sample.
split_sizes (List[int]): The sizes of the split tensor.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
device (`torch.device`): The device of parameters, defaults to None.
n_fused (int): The number items fused, defaults to 3 (QKV).
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
seq_parallel_mode (str): If set to ``None``, it will not use sequence parallel, otherwise will use corresponding mode of sequence parallel, defaults to None.
gather_output (bool, optional): If true, call all-gather on output and make Y available
@ -169,16 +215,14 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
self,
in_features: int,
out_features: int,
split_sizes: List[int],
bias: bool = True,
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
async_communication: bool = False,
gather_output: bool = False,
seq_parallel_mode: str = None,
overlap: bool = False,
skip_bias_add: bool = False,
n_fused: int = 3,
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
@ -192,14 +236,16 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
self.out_features = out_features
self.gather_output = gather_output
self.seq_parallel_mode = seq_parallel_mode
self.overlap = overlap
self.skip_bias_add = skip_bias_add
self.device = device
self.n_fused = n_fused
self.split_sizes = split_sizes
self.process_group = process_group
self.async_communication = async_communication
self.fp8_communication = fp8_communication
assert (
sum(split_sizes) == out_features
), f"The sum of split_sizes({sum(split_sizes)}) should be equal to out_features({out_features})."
if skip_bias_add and not bias:
raise ValueError("cannot skip bias addition if bias is None")
@ -223,10 +269,10 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
self.weight = weight
def shard_fn(tensor):
return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True)
return split_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, True)
def gather_fn(tensor):
return gather_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True)
return gather_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, True)
if not is_customized_distributed_tensor(self.weight):
with torch.no_grad():
@ -252,7 +298,11 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
@staticmethod
def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
module: nn.Module,
process_group: Union[ProcessGroup, List[ProcessGroup]],
split_sizes: List[int],
*args,
**kwargs,
) -> ParallelModule:
r"""
Convert a huggingface layer `Conv1D` in gpt2 to a parallelized linear layer.
@ -260,7 +310,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
Args:
module (`nn.Linear`): The module to be converted.
process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication.
n_fused (int): The number of layers to be fused. In GPT2, Q,K,V are fused in one weight.
split_sizes (List[int]): The sizes of the split tensor. In GPT2, Q,K,V are fused in one weight.
"""
LazyInitContext.materialize(module)
# get the attributes
@ -291,6 +341,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
process_group=process_group,
weight=module.weight,
bias_=module.bias,
split_sizes=split_sizes,
*args,
**kwargs,
)
@ -313,19 +364,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
if self.seq_parallel_mode == "split_gather":
input_parallel = input_
output_parallel = matmul_gather_forward_reducescatter_backward(
input_parallel,
self.weight,
bias,
self.process_group,
True,
1,
self.overlap,
fp8_communication=self.fp8_communication,
)
elif self.seq_parallel_mode == "ring":
if is_share_sp_tp(self.seq_parallel_mode):
input_parallel = input_
output_parallel = matmul_gather_forward_reducescatter_backward(
input_parallel,
@ -334,19 +373,18 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
self.process_group,
True,
1,
self.overlap,
True,
ring=self.seq_parallel_mode == "ring",
fp8_communication=self.fp8_communication,
)
elif self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn":
# Set up backprop all-reduce.
input_parallel = reduce_backward(input_, self.process_group)
input_parallel = input_
output_parallel = matmul_with_async_comm(
input_parallel,
self.weight,
bias,
self.process_group,
self.async_communication,
True,
fp8_communication=self.fp8_communication,
)
else:
@ -354,9 +392,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
if self.gather_output:
# All-gather across the partitions.
output = gather_forward_split_backward(
output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
)
output = gather_forward_split_backward_fused_qkv(output_parallel, self.split_sizes, self.process_group)
else:
output = output_parallel
@ -565,7 +601,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
if self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn":
output_parallel = torch.matmul(input_, self.weight)
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
elif self.seq_parallel_mode == "split_gather":
elif is_share_sp_tp(self.seq_parallel_mode):
output_parallel = torch.matmul(input_, self.weight)
output = reducescatter_forward_gather_backward(
output_parallel,
@ -573,13 +609,6 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
1,
self.fp8_communication,
)
elif self.seq_parallel_mode == "ring":
output_parallel = torch.matmul(input_, self.weight)
output = reducescatter_forward_gather_backward(
output_parallel,
self.process_group,
1,
)
else:
raise NotImplementedError(f"seq_parallel_mode={self.seq_parallel_mode} is not supported!")
@ -605,10 +634,10 @@ class FusedLinear1D_Col(ParallelModule):
Args:
in_features (int): size of each input sample.
out_features (int): size of each output sample.
split_sizes (List[int]): The sizes of the split tensor.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
device (`torch.device`): The device of parameters, defaults to None.
n_fused (int): The number items fused, defaults to 3 (QKV).
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
gather_output (bool, optional): If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
@ -628,14 +657,15 @@ class FusedLinear1D_Col(ParallelModule):
self,
in_features: int,
out_features: int,
split_sizes: List[int],
bias: bool = True,
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
async_communication: bool = False,
gather_output: bool = False,
seq_parallel_mode: str = None,
seq_parallel_dim: int = 1,
skip_bias_add: bool = False,
n_fused: int = 3,
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
@ -647,13 +677,18 @@ class FusedLinear1D_Col(ParallelModule):
self.in_features = in_features
self.out_features = out_features
self.gather_output = gather_output
self.seq_parallel_mode = seq_parallel_mode
self.seq_parallel_dim = seq_parallel_dim
self.skip_bias_add = skip_bias_add
self.device = device
self.n_fused = n_fused
self.split_sizes = split_sizes
self.process_group = process_group
self.async_communication = async_communication
self.fp8_communication = fp8_communication
assert (
sum(split_sizes) == out_features
), f"The sum of split_sizes({sum(split_sizes)}) should be equal to out_features({out_features})."
if skip_bias_add and not bias:
raise ValueError("cannot skip bias addition if bias is None")
@ -677,10 +712,10 @@ class FusedLinear1D_Col(ParallelModule):
self.weight = weight
def shard_fn(tensor):
return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False)
return split_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, False)
def gather_fn(tensor):
return gather_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False)
return gather_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, False)
if not is_customized_distributed_tensor(self.weight):
with torch.no_grad():
@ -706,7 +741,11 @@ class FusedLinear1D_Col(ParallelModule):
@staticmethod
def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int, *args, **kwargs
module: nn.Module,
process_group: Union[ProcessGroup, List[ProcessGroup]],
split_sizes: List[int],
*args,
**kwargs,
) -> ParallelModule:
r"""
Convert a fused `torch.nn.linear` layer to a parallelized linear layer.
@ -714,7 +753,7 @@ class FusedLinear1D_Col(ParallelModule):
Args:
module (`nn.Linear`): The module to be converted.
process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication.
n_fused (int): The number of layers to be fused. In common, Q,K,V are fused in one weight.
split_sizes (List[int]): The sizes of the split tensor. Commonly, Q, K, V are fused in one weight.
"""
LazyInitContext.materialize(module)
@ -737,25 +776,11 @@ class FusedLinear1D_Col(ParallelModule):
process_group=process_group,
weight=module.weight,
bias_=module.bias,
n_fused=n_fused,
split_sizes=split_sizes,
*args,
**kwargs,
)
# # TODO: copy the sharded weights
# with torch.no_grad():
# sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data,
# n_fused=n_fused,
# process_group=process_group,
# is_transposed=False)
# linear_1d.weight.data.copy_(sharded_weight.data)
# if bias:
# sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data,
# n_fused=n_fused,
# process_group=process_group,
# is_transposed=False)
# linear_1d.bias.data.copy_(sharded_bias.data)
return linear_1d
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
@ -772,19 +797,29 @@ class FusedLinear1D_Col(ParallelModule):
input_.shape, self.weight.shape, self.weight.shape[-1]
)
# Set up backprop all-reduce.
# input_parallel = reduce_backward(input_, self.process_group)
input_parallel = input_
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
if is_share_sp_tp(self.seq_parallel_mode):
output_parallel = linear_gather_forward_reducescatter_backward(
input_parallel,
self.weight,
bias,
self.process_group,
True,
self.seq_parallel_dim,
ring=self.seq_parallel_mode == "ring",
)
else:
output_parallel = linear_with_async_comm(
input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication
)
if self.gather_output:
# All-gather across the partitions.
output = gather_forward_split_backward(
output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
)
output = gather_forward_split_backward_fused_qkv(output_parallel, self.split_sizes, self.process_group)
else:
output = output_parallel
@ -792,3 +827,196 @@ class FusedLinear1D_Col(ParallelModule):
return output, self.bias
else:
return output
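Before the row-parallel counterpart below, here is a small standalone illustration of how `split_sizes` shards a fused weight across ranks. This is an assumption-laden sketch in plain torch, not the library's `split_fused_qkv_in_gpt2_style` itself: each fused chunk is split separately, so a rank holds a slice of Q, K and V rather than a contiguous block of rows.

# Standalone illustration of the fused sharding pattern (plain torch, no process group).
import torch

tp_size, split_sizes = 2, [32, 32, 16]          # fused out_features = 80
weight = torch.randn(sum(split_sizes), 8)       # nn.Linear weight: (out_features, in_features)
chunks = torch.split(weight, split_sizes, dim=0)                     # Q-like, K-like, V-like blocks
shards = [
    torch.cat([c.chunk(tp_size, dim=0)[r] for c in chunks], dim=0)   # each rank: 16 + 16 + 8 rows
    for r in range(tp_size)
]
print([tuple(s.shape) for s in shards])         # [(40, 8), (40, 8)]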
class FusedLinear1D_Row(ParallelModule):
r"""Linear layer with row parallelism
Args:
in_features (int): size of each input sample.
out_features (int): size of each output sample.
split_sizes (List[int]): The sizes of the split tensor. Their sum must equal ``in_features``.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
parallel_input (bool): If set to ``True``, it's assumed that the input is already split across ranks, defaults to ``True``.
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
seq_parallel_mode (`str`): The sequence parallelism mode; sequence parallelism is used when `seq_parallel_mode` is not None. Defaults to None.
seq_parallel_dim (`int`): The dimension along which sequence parallelism splits and gathers the sequence.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (:class:`typing.Callable`, optional):
The initializer of weight, defaults to kaiming uniform initializer.
bias_initializer (:class:`typing.Callable`, optional):
The initializer of bias, defaults to xavier uniform initializer.
For more details about ``initializer``, please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(
self,
in_features: int,
out_features: int,
split_sizes: List[int],
bias: bool = True,
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
seq_parallel_mode: str = None,
seq_parallel_dim: int = 1,
parallel_input: bool = True,
skip_bias_add: bool = False,
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
fp8_communication: bool = False,
):
super().__init__()
# Keep input parameters
self.in_features = in_features
self.out_features = out_features
self.split_sizes = split_sizes
self.parallel_input = parallel_input
self.skip_bias_add = skip_bias_add
self.process_group = process_group
self.seq_parallel_mode = seq_parallel_mode
self.seq_parallel_dim = seq_parallel_dim
self.num_partitions = dist.get_world_size(self.process_group)
self.fp8_communication = fp8_communication
assert (
sum(split_sizes) == in_features
), f"The sum of split_sizes({sum(split_sizes)}) should be equal to in_features({in_features})."
if skip_bias_add and not bias:
raise ValueError("cannot skip bias addition if bias is None")
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
# sanity check
if weight is not None:
assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None"
else:
assert bias_ is None, "bias_ must be None if weight is None"
# Parameters.
if weight is None:
# Initialize weight.
factory_kwargs = {"device": device, "dtype": dtype}
self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
else:
weight.data = weight.data.to(device=device, dtype=dtype)
self.weight = weight
def shard_fn(tensor):
return split_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, True)
def gather_fn(tensor):
return gather_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, True)
if not is_customized_distributed_tensor(self.weight):
with torch.no_grad():
sharded_weight = distribute_tensor_with_customization(self.weight.data, shard_fn, gather_fn)
customized_distributed_tensor_to_existing_param(sharded_weight, self.weight)
if bias:
if bias_ is None:
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
else:
bias_.data = bias_.data.to(device=device, dtype=dtype)
self.bias = bias_
else:
self.bias = None
if weight is None:
with self.randomizer.fork_rng(enable_cpu=True):
self.reset_parameters(weight_initializer, bias_initializer)
@staticmethod
def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], split_sizes: List[int], **kwargs
) -> ParallelModule:
r"""
Convert a native PyTorch linear layer to a parallelized linear layer.
"""
LazyInitContext.materialize(module)
# get the attributes
in_features = module.in_features
out_features = module.out_features
bias = module.bias is not None
device = module.weight.device
# ensure only one process group is passed
if isinstance(process_group, (list, tuple)):
assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
process_group = process_group[0]
linear_1d = FusedLinear1D_Row(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
process_group=process_group,
weight=module.weight,
bias_=module.bias,
split_sizes=split_sizes,
**kwargs,
)
return linear_1d
@torch.no_grad()
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.out_features
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
if self.process_group is None:
src_rank = 0
else:
src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0)
origin_device = self.bias.device
bias = self.bias.cuda()
dist.broadcast(bias, src=src_rank, group=self.process_group)
bias = bias.to(origin_device)
self.bias.copy_(bias)
def forward(self, input_: Tensor) -> Tensor:
# Set up backprop all-reduce.
if self.parallel_input:
assert (
input_.shape[-1] == self.weight.shape[-1]
), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
input_.shape, self.weight.shape, self.weight.shape[-1]
)
input_ = input_
else:
assert (
divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1]
), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions
)
input_ = split_forward_gather_backward_fused_qkv(input_, self.split_sizes, self.process_group)
if is_share_sp_tp(self.seq_parallel_mode):
output = linear_reducescatter_forward_gather_backward(
input_,
self.weight,
process_group=self.process_group,
dim=self.seq_parallel_dim,
ring=self.seq_parallel_mode == "ring",
)
else:
output_parallel = F.linear(input_, self.weight)
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
if not self.skip_bias_add:
if self.bias is not None:
output = output + self.bias
return output
else:
return output, self.bias
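For orientation, here is a hedged usage sketch of the FusedLinear1D_Col / FusedLinear1D_Row pair, mirroring check_linear_1d_col_row in the test diff further below. It has to run inside an already-launched multi-rank ColossalAI/NCCL context, and the sizes are just the test's toy values.

# Sketch only: column-parallel fused projection followed by its row-parallel counterpart.
import torch
import torch.nn as nn
from colossalai.shardformer.layer import FusedLinear1D_Col, FusedLinear1D_Row

def fused_col_row_roundtrip():
    split_sizes = [32, 32, 16]             # the fused projection packs chunks of 32, 32 and 16
    linear1 = nn.Linear(8, 80).cuda()      # fused "QKV-like" projection, sum(split_sizes) == 80
    linear2 = nn.Linear(80, 8).cuda()      # output projection, sum(split_sizes) == in_features

    col = FusedLinear1D_Col.from_native_module(linear1, process_group=None, split_sizes=split_sizes)
    row = FusedLinear1D_Row.from_native_module(linear2, process_group=None, split_sizes=split_sizes)

    x = torch.rand(4, 8).cuda()
    # col emits each rank's shard of every fused chunk; row consumes it and reduces across
    # ranks, so the composition matches the unsharded linear2(linear1(x)).
    return row(col(x))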

@ -295,8 +295,8 @@ def split_batch_zigzag(
batch: Union[torch.Tensor, List[torch.Tensor]], sp_group: ProcessGroup, seq_dim: int = 1, is_label: bool = False
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""
Split the input along the sequence dimension for Ring Attention. Naively spliting the attention mask
in the causal setting will result in the preceding ranks having much less workload.
Split the input sequence batch. Naively splitting the attention mask in the causal setting
will result in the preceding ranks having much less workload.
We split after "folding" the 2D attention mask in half (https://github.com/zhuzilin/ring-flash-attention/issues/2).
For example, for sp_size = 4 and seq_len = 8, we get | s0, s7 | s1, s6 | s2, s5 | s3, s4 |.
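A tiny single-process illustration of that folding pattern follows (plain torch, no process group involved; the printed layout is only meant to match the example above).

import torch

sp_size, seq_len = 4, 8
seq = torch.arange(seq_len)                     # stand-in for the sequence dimension
chunks = seq.chunk(2 * sp_size)                 # s0 .. s7
for rank in range(sp_size):
    local = torch.cat([chunks[rank], chunks[2 * sp_size - 1 - rank]])
    print(rank, local.tolist())                 # 0 -> [0, 7], 1 -> [1, 6], 2 -> [2, 5], 3 -> [3, 4]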
@ -346,40 +346,42 @@ def split_varlen_zigzag(
cu_seqlens: torch.Tensor,
sp_group: ProcessGroup,
max_seqlen: int = 0,
is_2d: bool = False,
is_batched_seq: bool = False,
is_label: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
"""Split each sequence in a batch of packed sequences in a zigzag fashion.
For each tensor in batch, return packed sequences if is_2d is False;
else return a padded batch of sequences.
"""Split a packed seq/batch of padded sequences in a Zigzag fashion.
Different from split_batch_zigzag, inputs here have variable sequence lengths.
Args:
batch (List[torch.Tensor]): Packed sequences of shape (B * Sq, ...), or (B, Sq, ...) if is_2d.
batch (List[torch.Tensor]): Packed sequences of shape (T, ...), or (B, Sq, ...) if is_batched_seq,
where T is the total number of tokens.
cu_seqlens (torch.Tensor): Cumulative sequence lengths of shape (B + 1) before splitting.
sp_group (ProcessGroup): The process group for sequence parallelism.
max_seqlen (int): The maximum sequence length in the batch before splitting.
is_2d (bool): If True, then input has batch size and sequence length split into two dimensions.
is_batched_seq (bool): If True, then the input is a batch of sequences padded to the same length.
is_label (bool): If True, mask out the first token in each sequence (<Start of Sentence>).
Returns:
batch (List[torch.Tensor]): Packed sequences of shape (B * max_seqlen // sp_size)
or (B, max_seqlen // sp_size, ...) if is_2d
batch (List[torch.Tensor]): Packed sequences of shape (T, ...)
or (B, max_seqlen // sp_size, ...) if is_batched_seq
"""
sp_size = dist.get_world_size(sp_group)
sp_rank = dist.get_rank(sp_group)
if sp_size == 1:
return batch
if is_2d:
if is_batched_seq:
assert max_seqlen > 0, "max_seqlen must be provided for batched input"
if isinstance(batch, torch.Tensor):
batch = [batch]
# seq: (B, Sq, h, n)
# seq = seq[:, :rank * (seqlen // sp_size), ...]
for i, packed_seq in enumerate(batch):
device = packed_seq.device
dtype = packed_seq.dtype
if is_2d:
if is_batched_seq:
assert max_seqlen % (sp_size * 2) == 0
# Recreate a padded tensor with the new max seqlen
shape = (packed_seq.shape[0], max_seqlen // sp_size, *packed_seq.shape[2:])
@ -398,7 +400,7 @@ def split_varlen_zigzag(
seqlen % (2 * sp_size) == 0
), f"batch {i} seq {j}'s length ({seqlen}) must be divisible by 2 * sp_size = {2 * sp_size} for splitting"
if is_2d:
if is_batched_seq:
seq = packed_seq[j][:seqlen]
if is_label:
# Shift one position to the right for next token prediction
@ -415,7 +417,7 @@ def split_varlen_zigzag(
seq = seq.chunk(sp_size * 2)
local_seq.extend([seq[sp_rank], seq[2 * sp_size - 1 - sp_rank]])
if is_2d:
if is_batched_seq:
batch[i] = local_seq.contiguous()
else:
batch[i] = torch.cat(local_seq, dim=0)

@ -857,17 +857,17 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None)
dropout_p = self.attn_dropout.p if self.training else 0.0
sp_mode = shard_config.sequence_parallelism_mode
sp_group = shard_config.sequence_parallel_process_group
if sp_mode == "ring_attn":
attn_output = RingAttention.attention(
query,
key,
value,
sp_group,
sp_axis=shard_config.sp_axis,
**attention_mask,
dropout_p=dropout_p,
scale=scale,
inner_ring_size=shard_config.inner_ring_size,
pg_mesh=shard_config.pg_mesh,
)
else:
attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)

@ -271,6 +271,7 @@ class LlamaPipelineForwards:
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
**kwargs,
):
r"""
Args:
@ -568,9 +569,10 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
query_states,
key_states,
value_states,
sp_group,
sp_axis=shard_config.sp_axis,
**attention_mask,
inner_ring_size=shard_config.inner_ring_size,
pg_mesh=shard_config.pg_mesh,
)
elif shard_config.enable_flash_attention:

@ -73,7 +73,6 @@ class BertPolicy(Policy):
)
sp_mode = "split_gather"
overlap = self.shard_config.enable_sequence_overlap
sp_partial_derived = sp_mode == "split_gather"
if self.shard_config.enable_tensor_parallelism:
@ -97,7 +96,6 @@ class BertPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
@ -106,7 +104,6 @@ class BertPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
@ -115,7 +112,6 @@ class BertPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
@ -140,7 +136,6 @@ class BertPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"skip_bias_add": self.enable_bias_gelu_fused,
"fp8_communication": self.shard_config.fp8_communication,
},

@ -71,7 +71,7 @@ class BlipPolicy(Policy):
suffix="self_attn.qkv",
target_module=col_nn.FusedLinear1D_Col,
kwargs={
"n_fused": 3,
"split_sizes": [self.model.config.vision_config.hidden_size] * 3,
"fp8_communication": self.shard_config.fp8_communication,
},
),

@ -57,7 +57,6 @@ class BloomPolicy(Policy):
)
sp_mode = "split_gather"
overlap = self.shard_config.enable_sequence_overlap
sp_partial_derived = sp_mode == "split_gather"
if self.shard_config.enable_tensor_parallelism:
@ -78,7 +77,6 @@ class BloomPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
@ -99,7 +97,6 @@ class BloomPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),

@ -67,7 +67,6 @@ class ChatGLMPolicy(Policy):
f"For ChatGLM2, sequence parallelism doesn't support mode {sp_mode} yet, will set to be split_gather"
)
sp_mode = "split_gather"
overlap = self.shard_config.enable_sequence_overlap
sp_partial_derived = sp_mode in ["split_gather"]
if sp_mode == "all_to_all":
@ -127,7 +126,6 @@ class ChatGLMPolicy(Policy):
kwargs={
"seq_parallel_mode": sp_mode,
"seq_parallel_dim": 0,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),

@ -65,7 +65,6 @@ class GPT2Policy(Policy):
f"For GPT2, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather"
)
self.shard_config.sequence_parallelism_mode = sp_mode = "split_gather"
overlap = self.shard_config.enable_sequence_overlap
sp_partial_derived = sp_mode in ["split_gather", "ring"]
use_flash_attention = self.shard_config.enable_flash_attention
if self.shard_config.enable_tensor_parallelism:
@ -92,9 +91,8 @@ class GPT2Policy(Policy):
suffix="attn.c_attn",
target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={
"n_fused": 3,
"split_sizes": [self.model.config.hidden_size] * 3,
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
@ -107,9 +105,8 @@ class GPT2Policy(Policy):
suffix="mlp.c_fc",
target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={
"n_fused": 1,
"split_sizes": [self.model.config.n_inner or 4 * self.model.config.hidden_size],
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"skip_bias_add": self.enable_bias_gelu_fused,
"fp8_communication": self.shard_config.fp8_communication,
},

@ -51,7 +51,6 @@ class GPTJPolicy(Policy):
self.shard_config.enable_sequence_parallelism = False
warnings.warn("GPTJ doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
assert (
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
@ -76,7 +75,6 @@ class GPTJPolicy(Policy):
suffix="attn.k_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
@ -84,7 +82,6 @@ class GPTJPolicy(Policy):
suffix="attn.q_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
@ -92,7 +89,6 @@ class GPTJPolicy(Policy):
suffix="attn.v_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),

@ -42,7 +42,7 @@ class SamPolicy(Policy):
suffix="attn.qkv",
target_module=col_nn.FusedLinear1D_Col,
kwargs={
"n_fused": 3,
"split_sizes": [self.model.config.vision_config.hidden_size] * 3,
"fp8_communication": self.shard_config.fp8_communication,
},
),

@ -26,7 +26,6 @@ class ShardConfig:
enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False.
enable_jit_fused (bool, optional): Whether to switch on JIT fused operators. Defaults to False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False.
enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.
gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None.
enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
fp8_communication (bool, optional): Whether to enable fp8 communication in model parallelism. Defaults to False.
@ -44,13 +43,14 @@ class ShardConfig:
enable_jit_fused: bool = False
enable_sequence_parallelism: bool = False
sequence_parallelism_mode: str = None
enable_sequence_overlap: bool = False
parallel_output: bool = True
make_vocab_size_divisible_by: int = 64
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
extra_kwargs: Dict[str, Any] = field(default_factory=dict)
# For ring attention
sp_axis: Optional[int] = None
pg_mesh: Optional[int] = None
inner_ring_size: Optional[int] = None
# for moe related
moe_dp_group: Optional[ProcessGroup] = None
@ -84,24 +84,12 @@ class ShardConfig:
assert (
self.enable_tensor_parallelism
), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is True"
elif self.sequence_parallelism_mode in ["all_to_all"]:
# assert (
# not self.enable_tensor_parallelism
# ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is False"
if self.enable_sequence_overlap:
self.enable_sequence_overlap = False
warnings.warn(
f"The enable_sequence_overlap flag will be ignored in sequence parallelism mode {self.sequence_parallelism_mode}"
)
else:
if self.sequence_parallelism_mode:
self.sequence_parallelism_mode = None
warnings.warn(
f"The sequence_parallelism_mode will be ignored when enable_sequence_parallelism is False"
)
assert (
not self.enable_sequence_overlap
), f"enable_sequence_overlap can only be set to True when enable_sequence_parallelism is True"
# get the tensor parallel size
if not self.enable_tensor_parallelism:
@ -134,4 +122,3 @@ class ShardConfig:
# This can cause non-in-place param sharding when used without ZeRO.
# It may also slow down training when seq len is small. Plz enable manually.
# self.enable_sequence_parallelism = True
# self.enable_sequence_overlap = True

@ -5,6 +5,7 @@ from .common import (
ensure_path_exists,
free_storage,
get_current_device,
get_non_persistent_buffers_set,
is_ddp_ignored,
set_seed,
)
@ -25,4 +26,5 @@ __all__ = [
"set_seed",
"get_current_device",
"is_ddp_ignored",
"get_non_persistent_buffers_set",
]

@ -5,10 +5,11 @@ import os
import random
from contextlib import contextmanager
from pathlib import Path
from typing import Callable
from typing import Callable, Optional, Set
import numpy as np
import torch
import torch.nn as nn
from colossalai.accelerator import get_accelerator
@ -76,3 +77,34 @@ def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
def get_non_persistent_buffers_set(
module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
):
r"""
Args:
memo: a memo to store the set of modules already added to the result
prefix: a prefix that will be added to the name of the module
remove_duplicate: whether to remove the duplicated module instances in the result
or not
"""
if memo is None:
memo = set()
self_non_persistent_set = set()
if module not in memo:
if remove_duplicate:
memo.add(module)
self_non_persistent_set = set(
map(lambda key: prefix + ("." if prefix else "") + key, module._non_persistent_buffers_set)
)
for name, sub_module in module._modules.items():
if sub_module is None:
continue
submodule_prefix = prefix + ("." if prefix else "") + name
child_non_persistent_set = get_non_persistent_buffers_set(
sub_module, memo, submodule_prefix, remove_duplicate
)
self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
return self_non_persistent_set
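A hedged usage sketch for the helper above; the Block module is a made-up example. It shows that only buffers registered with persistent=False are reported, prefixed with their dotted module path.

import torch
import torch.nn as nn
from colossalai.utils import get_non_persistent_buffers_set

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("running_stat", torch.zeros(4), persistent=False)  # reported
        self.register_buffer("step", torch.zeros(1))                            # persistent -> skipped

model = nn.Sequential(Block(), Block())
print(get_non_persistent_buffers_set(model))    # e.g. {'0.running_stat', '1.running_stat'}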

@ -0,0 +1,64 @@
# a python safetensors serializer modified from https://github.com/huggingface/safetensors/blob/41bd1acf38ad28ac559522d40596c6c802f79453/safetensors/src/tensor.rs#L214
import json
from dataclasses import asdict, dataclass
from typing import Dict, List, Tuple
import torch
from safetensors.torch import _TYPES
try:
from tensornvme.async_file_io import AsyncFileWriter
except ModuleNotFoundError:
raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")
_TYPES_INV = {v: k for k, v in _TYPES.items()}
@dataclass
class TensorInfo:
dtype: str
shape: List[int]
data_offsets: Tuple[int, int]
@dataclass
class PreparedData:
n: int
header_bytes: bytes
offset: int
def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Tensor]]:
sorted_data = sorted(data.items(), key=lambda x: (x[1].dtype, x[0]))
tensors = []
metadata = {}
offset = 0
for name, tensor in sorted_data:
n = tensor.numel() * tensor.element_size()
tensor_info = TensorInfo(
dtype=_TYPES_INV[tensor.dtype], shape=list(tensor.shape), data_offsets=(offset, offset + n)
)
offset += n
metadata[name] = asdict(tensor_info)
tensors.append(tensor)
metadata_buf = json.dumps(metadata).encode("utf-8")
extra = (8 - len(metadata_buf) % 8) % 8
metadata_buf += b" " * extra
n = len(metadata_buf)
return PreparedData(n=n, header_bytes=metadata_buf, offset=offset), tensors
def save(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:
prepared_data, tensors = prepare(state_dict)
n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
f_writer.write(n.to_bytes(8, byteorder="little"))
f_writer.write(header_bytes)
for tensor in tensors:
f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset)
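Since the writer above follows the standard safetensors layout (an 8-byte little-endian header length, the padded JSON header, then raw tensor bytes), a plain-Python header reader makes a handy sanity check. This is a hedged sketch and the file name is hypothetical.

import json

def read_safetensors_header(path: str) -> dict:
    # Mirrors what `save` writes: 8-byte little-endian length, then the JSON header.
    with open(path, "rb") as f:
        header_len = int.from_bytes(f.read(8), byteorder="little")
        return json.loads(f.read(header_len))    # name -> {"dtype", "shape", "data_offsets"}

# print(read_safetensors_header("model.safetensors"))   # hypothetical file name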

@ -35,7 +35,7 @@ from colossalai.tensor.padded_tensor import (
to_unpadded_tensor,
)
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import _cast_float, free_storage, is_ddp_ignored
from colossalai.utils import _cast_float, free_storage, get_non_persistent_buffers_set, is_ddp_ignored
from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
from .gemini_hook import GeminiZeROHook
@ -187,7 +187,7 @@ class GeminiDDP(ModelWrapper):
pin_memory=pin_memory,
)
super().__init__(module)
self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module)
self._non_persistent_buffers_set = get_non_persistent_buffers_set(module)
self._cast_buffers()
# register grad hook
@ -257,36 +257,6 @@ class GeminiDDP(ModelWrapper):
for p in params_to_ignore:
p._ddp_to_ignore = True
def _get_non_persistent_buffers_set(
self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
):
r"""
Args:
memo: a memo to store the set of modules already added to the result
prefix: a prefix that will be added to the name of the module
remove_duplicate: whether to remove the duplicated module instances in the result
or not
"""
if memo is None:
memo = set()
self_non_persistent_set = set()
if module not in memo:
if remove_duplicate:
memo.add(module)
self_non_persistent_set = set(
map(lambda key: prefix + ("." if prefix else "") + key, module._non_persistent_buffers_set)
)
for name, sub_module in module._modules.items():
if sub_module is None:
continue
submodule_prefix = prefix + ("." if prefix else "") + name
child_non_persistent_set = self._get_non_persistent_buffers_set(
sub_module, memo, submodule_prefix, remove_duplicate
)
self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
return self_non_persistent_set
def _post_forward(self):
"""This function is only triggered for inference."""
access_list = list(self.chunk_manager.accessed_chunks)

@ -1,10 +1,5 @@
import torch.nn
from colossalai.legacy.zero.gemini.ophooks.runtime_mem_tracer_hook import (
GradMemStats,
GradMemTracerHook,
ParamMemTracerHook,
)
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import _cast_float
@ -27,6 +22,12 @@ class RuntimeMemTracer:
def __init__(self, module: torch.nn.Module, dtype: torch.dtype = torch.half):
super().__init__()
from colossalai.legacy.zero.gemini.ophooks.runtime_mem_tracer_hook import (
GradMemStats,
GradMemTracerHook,
ParamMemTracerHook,
)
self.module = module
self.dtype = dtype
self._gradstat = GradMemStats()

@ -8,7 +8,6 @@ import torch
import torch.distributed as dist
from colossalai.accelerator import get_accelerator
from colossalai.legacy.utils.memory import colo_device_memory_capacity
from colossalai.zero.gemini.chunk import Chunk
from .chunk import Chunk, ChunkManager
@ -172,6 +171,8 @@ class AutoPlacementPolicy(PlacementPolicy):
Returns:
int: the volume of memory that is evicted
"""
from colossalai.legacy.utils.memory import colo_device_memory_capacity
start = time()
cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device())
used_cuda_model_data = self.chunk_manager.total_mem["cuda"]

@ -25,15 +25,13 @@
</div>
## News
* [2024/10] [How to build a low-cost Sora-like app? Solutions for you](https://company.hpc-ai.com/blog/how-to-build-a-low-cost-sora-like-app-solutions-for-you)
* [2024/09] [Singapore Startup HPC-AI Tech Secures 50 Million USD in Series A Funding to Build the Video Generation AI Model and GPU Platform](https://company.hpc-ai.com/blog/singapore-startup-hpc-ai-tech-secures-50-million-usd-in-series-a-funding-to-build-the-video-generation-ai-model-and-gpu-platform)
* [2024/09] [Reducing AI Large Model Training Costs by 30% Requires Just a Single Line of Code From FP8 Mixed Precision Training Upgrades](https://company.hpc-ai.com/blog/reducing-ai-large-model-training-costs-by-30-requires-just-a-single-line-of-code-from-fp8-mixed-precision-training-upgrades)
* [2024/06] [Open-Sora Continues Open Source: Generate Any 16-Second 720p HD Video with One Click, Model Weights Ready to Use](https://hpc-ai.com/blog/open-sora-from-hpc-ai-tech-team-continues-open-source-generate-any-16-second-720p-hd-video-with-one-click-model-weights-ready-to-use)
* [2024/05] [Large AI Models Inference Speed Doubled, Colossal-Inference Open Source Release](https://hpc-ai.com/blog/colossal-inference)
* [2024/04] [Open-Sora Unveils Major Upgrade: Embracing Open Source with Single-Shot 16-Second Video Generation and 720p Resolution](https://hpc-ai.com/blog/open-soras-comprehensive-upgrade-unveiled-embracing-16-second-video-generation-and-720p-resolution-in-open-source)
* [2024/04] [Most cost-effective solutions for inference, fine-tuning and pretraining, tailored to LLaMA3 series](https://hpc-ai.com/blog/most-cost-effective-solutions-for-inference-fine-tuning-and-pretraining-tailored-to-llama3-series)
* [2024/03] [314 Billion Parameter Grok-1 Inference Accelerated by 3.8x, Efficient and Easy-to-Use PyTorch+HuggingFace version is Here](https://hpc-ai.com/blog/314-billion-parameter-grok-1-inference-accelerated-by-3.8x-efficient-and-easy-to-use-pytorchhuggingface-version-is-here)
* [2024/03] [Open-Sora: Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models](https://hpc-ai.com/blog/open-sora-v1.0)
* [2024/03] [Open-Sora: Sora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million](https://hpc-ai.com/blog/open-sora)
* [2024/01] [Inference Performance Improved by 46%, Open Source Solution Breaks the Length Limit of LLM for Multi-Round Conversations](https://hpc-ai.com/blog/Colossal-AI-SwiftInfer)
* [2023/07] [HPC-AI Tech Raises 22 Million USD in Series A Funding](https://www.hpc-ai.tech/blog/hpc-ai-tech-raises-22-million-usd-in-series-a-funding-to-fuel-team-expansion-and-business-growth)
## Table of Contents
<ul>

@ -16,7 +16,7 @@ Author: [Mingyan Jiang](https://github.com/jiangmingyan)
AMP stands for automatic mixed precision training.
In Colossal-AI, we have incorporated different implementations of mixed precision training:
1. torch.cuda.amp
1. torch.amp
2. apex.amp
3. naive amp
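For comparison, a minimal vanilla torch.amp training step looks roughly like the sketch below. This is not the Colossal-AI API; the toy model and data are placeholders, a CUDA device is assumed, and `torch.amp.GradScaler("cuda")` requires a recent PyTorch (older releases expose the same functionality as `torch.cuda.amp.GradScaler()`).

import torch
import torch.nn as nn

model = nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = torch.amp.GradScaler("cuda")                        # dynamic loss scaling for fp16

for _ in range(3):                                           # toy loop on random data
    inputs = torch.randn(8, 16, device="cuda")
    targets = torch.randn(8, 4, device="cuda")
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = nn.functional.mse_loss(model(inputs), targets)  # forward runs in mixed precision
    scaler.scale(loss).backward()                              # scale to avoid gradient underflow
    scaler.step(optimizer)                                     # unscale, then optimizer step
    scaler.update()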

@ -16,7 +16,7 @@
AMP stands for automatic mixed precision training.
In Colossal-AI, we have incorporated different implementations of mixed precision training:
1. torch.cuda.amp
1. torch.amp
2. apex.amp
3. naive amp

@ -163,6 +163,8 @@ def main():
enable_async_reduce=not args.disable_async_reduce,
use_fp8=args.use_fp8,
fp8_communication=args.use_fp8_comm,
)
elif args.plugin == "gemini_auto":
plugin = GeminiPlugin(
@ -177,6 +179,8 @@ def main():
enable_flash_attention=args.xformers,
use_fp8=args.use_fp8,
fp8_communication=args.use_fp8_comm,
)
elif args.plugin == "fsdp":
if use_empty_init:
@ -188,6 +192,7 @@ def main():
),
param_init_fn=empty_init(),
fp8_communication=args.use_fp8_comm,
)
else:
plugin = TorchFSDPPlugin(
@ -209,6 +214,7 @@ def main():
cpu_offload=CPUOffload(offload_params=True),
param_init_fn=empty_init(),
fp8_communication=args.use_fp8_comm,
)
else:
plugin = TorchFSDPPlugin(
@ -219,6 +225,7 @@ def main():
),
cpu_offload=CPUOffload(offload_params=True),
fp8_communication=args.use_fp8_comm,
)
elif args.plugin == "3d":
if args.pp_style == "zbv":

@ -79,7 +79,7 @@ class _CppExtension(_Extension):
# check if the kernel has been built
compiled_before = False
kernel_file_path = build_directory.joinpath(f"{self.name}.o")
kernel_file_path = build_directory.joinpath(f"{self.name}.so")
if kernel_file_path.exists():
compiled_before = True

@ -74,7 +74,7 @@ class _CudaExtension(_CppExtension):
# check if the kernel has been built
compiled_before = False
kernel_file_path = build_directory.joinpath(f"{self.name}.o")
kernel_file_path = build_directory.joinpath(f"{self.name}.so")
if kernel_file_path.exists():
compiled_before = True

@ -41,22 +41,7 @@ class Conv1D(nn.Module):
return x
def rearrange(tensor: torch.Tensor, dim: int):
tensor = tensor.clone()
world_size = 2
order = torch.arange(world_size * 3)
new_order = []
for i in range(world_size):
new_order.append(order[i::world_size])
new_order = torch.cat(new_order)
tensor_chunks = torch.chunk(tensor, world_size * 3, dim=dim)
rearanged_tensor_chunks = [tensor_chunks[i] for i in new_order]
rearanged_tensor = torch.cat(rearanged_tensor_chunks, dim=dim)
return rearanged_tensor
def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: bool):
def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = Conv1D(192, 48).cuda()
with ctx:
@ -66,8 +51,7 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: b
process_group=None,
gather_output=True,
seq_parallel_mode=seq_parallel_mode,
n_fused=3,
overlap=overlap,
split_sizes=[64] * 3,
)
assert linear.weight.shape == torch.Size([48, 192])
@ -88,13 +72,13 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: b
x.expand_as(x.clone()) if seq_parallel_mode is None else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
)
gather_out = linear_conv_col(x_for_shard)
assert_close(rearrange(out, -1), gather_out)
assert_close(out, gather_out)
# check backward correctness
out.sum().backward()
gather_out.sum().backward()
target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, 3, None, True)
target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, [64] * 3, None, True)
assert_close(target_grad, linear_conv_col.weight.grad)
@ -136,9 +120,8 @@ def check_linear_conv_1d_row(lazy_init: bool, seq_parallel_mode: bool):
@parameterize("lazy_init", [False, True])
@parameterize("seq_parallel_mode", ["split_gather", None])
@parameterize("overlap", [True])
def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool, overlap: bool):
check_linear_conv_1d_col(lazy_init, seq_parallel_mode, overlap)
def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool):
check_linear_conv_1d_col(lazy_init, seq_parallel_mode)
check_linear_conv_1d_row(lazy_init, seq_parallel_mode)

@ -2,13 +2,12 @@ import os
from contextlib import nullcontext
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close
import colossalai
from colossalai.lazy import LazyInitContext
from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
from colossalai.shardformer.layer import FusedLinear1D_Col, FusedLinear1D_Row
from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
@ -16,93 +15,55 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
class Conv1D(nn.Module):
"""
1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
Basically works like a linear layer but the weights are transposed.
Args:
nf (`int`): The number of output features.
nx (`int`): The number of input features.
"""
def __init__(self, nf, nx):
super().__init__()
self.nf = nf
self.weight = nn.Parameter(torch.empty(nx, nf))
self.bias = nn.Parameter(torch.zeros(nf))
nn.init.normal_(self.weight, std=0.02)
def forward(self, x):
size_out = x.size()[:-1] + (self.nf,)
x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
x = x.view(size_out)
return x
def rearrange(tensor: torch.Tensor, dim: int):
tensor = tensor.clone()
world_size = 2
order = torch.arange(world_size * 3)
new_order = []
for i in range(world_size):
new_order.append(order[i::world_size])
new_order = torch.cat(new_order)
tensor_chunks = torch.chunk(tensor, world_size * 3, dim=dim)
rearanged_tensor_chunks = [tensor_chunks[i] for i in new_order]
rearanged_tensor = torch.cat(rearanged_tensor_chunks, dim=dim)
return rearanged_tensor
@parameterize("lazy_init", [False, True])
def check_linear_conv_1d_col(lazy_init: bool):
def check_linear_1d_col(lazy_init: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = Conv1D(192, 48).cuda()
linear = nn.Linear(8, 80).cuda()
with ctx:
linear_copy = Conv1D(192, 48).cuda()
linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(
linear_copy, process_group=None, gather_output=True, n_fused=3
linear_copy = nn.Linear(8, 80).cuda()
linear_col = FusedLinear1D_Col.from_native_module(
linear_copy, process_group=None, gather_output=True, split_sizes=[32, 32, 16]
)
assert linear.weight.shape == torch.Size([48, 192])
assert linear.bias.shape == torch.Size([192])
assert linear_conv_col.weight.shape == torch.Size([48, 96])
assert linear_conv_col.bias.shape == torch.Size([96])
assert linear_copy.weight is linear_conv_col.weight
assert linear_copy.bias is linear_conv_col.bias
assert linear.weight.shape == torch.Size([80, 8])
assert linear.bias.shape == torch.Size([80])
assert linear_col.weight.shape == torch.Size([40, 8])
assert linear_col.bias.shape == torch.Size([40])
assert linear_copy.weight is linear_col.weight
assert linear_copy.bias is linear_col.bias
# ensure weights are reversibly loadable
linear_conv_col.load_state_dict(linear.state_dict())
linear.load_state_dict(linear_conv_col.state_dict())
linear_col.load_state_dict(linear.state_dict())
linear.load_state_dict(linear_col.state_dict())
# check computation correctness
x = torch.rand(4, 48).cuda()
x = torch.rand(4, 8).cuda()
out = linear(x)
gather_out = linear_conv_col(x)
assert_close(rearrange(out, 1), gather_out)
gather_out = linear_col(x)
assert_close(out, gather_out)
# check backward correctness
out.sum().backward()
gather_out.sum().backward()
target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, 3, None, True)
assert_close(target_grad, linear_conv_col.weight.grad)
target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, [32, 32, 16], None, False)
assert_close(target_grad, linear_col.weight.grad)
@parameterize("lazy_init", [False, True])
def check_linear_conv_1d_row(lazy_init: bool):
def check_linear_1d_row(lazy_init: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = Conv1D(192, 48).cuda()
linear = nn.Linear(80, 8).cuda()
with ctx:
linear_copy = Conv1D(192, 48).cuda()
linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False)
linear_copy = nn.Linear(80, 8).cuda()
linear_row = FusedLinear1D_Row.from_native_module(
linear_copy, process_group=None, split_sizes=[32, 32, 16], parallel_input=False
)
assert linear.weight.shape == torch.Size([48, 192])
assert linear_row.weight.shape == torch.Size([24, 192])
assert linear_row.bias.shape == torch.Size([192])
assert linear.weight.shape == torch.Size([8, 80])
assert linear_row.weight.shape == torch.Size([8, 40])
assert linear_row.bias.shape == torch.Size([8])
assert linear_copy.weight is linear_row.weight
assert linear_copy.bias is linear_row.bias
@ -111,7 +72,7 @@ def check_linear_conv_1d_row(lazy_init: bool):
linear.load_state_dict(linear_row.state_dict())
# check computation correctness
x = torch.rand(4, 48).cuda()
x = torch.rand(4, 80).cuda()
out = linear(x)
gather_out = linear_row(x)
assert_close(out, gather_out)
@ -120,17 +81,51 @@ def check_linear_conv_1d_row(lazy_init: bool):
out.sum().backward()
gather_out.sum().backward()
rank = dist.get_rank()
target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank]
target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, [32, 32, 16], None, True)
assert_close(target_grad, linear_row.weight.grad)
@parameterize("lazy_init", [False, True])
def check_linear_1d_col_row(lazy_init: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear1 = nn.Linear(8, 80).cuda()
linear2 = nn.Linear(80, 8).cuda()
with ctx:
linear1_copy = nn.Linear(8, 80).cuda()
linear2_copy = nn.Linear(80, 8).cuda()
linear_col = FusedLinear1D_Col.from_native_module(linear1_copy, process_group=None, split_sizes=[32, 32, 16])
linear_row = FusedLinear1D_Row.from_native_module(
linear2_copy,
process_group=None,
split_sizes=[32, 32, 16],
)
# ensure weights are reversibly loadable
linear_col.load_state_dict(linear1.state_dict())
linear_row.load_state_dict(linear2.state_dict())
# check computation correctness
x = torch.rand(4, 8).cuda()
target_out = linear2(linear1(x))
out = linear_row(linear_col(x))
assert_close(out, target_out)
# check backward correctness
target_out.sum().backward()
out.sum().backward()
target_grad1 = split_fused_qkv_in_gpt2_style(linear1.weight.grad, [32, 32, 16], None, False)
assert_close(target_grad1, linear_col.weight.grad)
target_grad2 = split_fused_qkv_in_gpt2_style(linear2.weight.grad, [32, 32, 16], None, True)
assert_close(target_grad2, linear_row.weight.grad)
def run_dist(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
# test for linear conv
check_linear_conv_1d_col()
check_linear_conv_1d_row()
check_linear_1d_col()
check_linear_1d_row()
check_linear_1d_col_row()
@rerun_if_address_is_in_use()

@ -5,6 +5,7 @@ from flash_attn import flash_attn_qkvpacked_func, flash_attn_varlen_qkvpacked_fu
from torch.testing import assert_close
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.shardformer.layer import AttnMaskType
from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention
from colossalai.shardformer.layer.utils import split_batch_zigzag, split_varlen_zigzag
@ -17,11 +18,14 @@ from colossalai.utils import get_current_device
@parameterize("nheads", [5])
@parameterize("d", [128])
@parameterize("dtype", [torch.bfloat16, torch.float16])
def check_ring_attn(seq_len, bs, nheads, d, dtype):
def check_ring_attn(seq_len, bs, nheads, d, dtype, inner_ring_size):
torch.cuda.manual_seed(2)
device = get_current_device()
sp_group = dist.group.WORLD
dp_size, pp_size, tp_size = 1, 1, 1
sp_size = dist.get_world_size()
sp_axis = 2
pg_mesh = ProcessGroupMesh(dp_size, pp_size, sp_size, tp_size)
# Some outliers may seem large, but our errors are still lower
# than Megatron-LM context parallel's
# (https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215)
@ -40,11 +44,11 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype):
q,
k,
v,
sp_group,
sp_axis,
AttnMaskType.CAUSAL,
return_softmax=True,
inner_ring_size=max(2, sp_size // 2),
# inner_ring_size=4
inner_ring_size=inner_ring_size,
pg_mesh=pg_mesh,
)
ring_out = ring_out.transpose(1, 2)
out, lse, _ = flash_attn_qkvpacked_func(
@ -83,6 +87,7 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype):
device = get_current_device()
sp_group = dist.group.WORLD
sp_size = dist.get_world_size()
sp_axis = 2
atol = rtol = 7e-3
torch.cuda.manual_seed(2)
# Prepare varlen attention mask
@ -123,10 +128,11 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype):
q_ring,
k_ring,
v_ring,
sp_group,
sp_axis,
**mask_info,
pad_output=False,
return_softmax=True,
pg_mesh=ProcessGroupMesh(1, 1, sp_size, 1),
# deterministic=True
)
ring_out = ring_out.transpose(1, 2).reshape(-1, nheads, d)
@ -161,12 +167,12 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype):
def launch_single_ring(rank, world_size, port):
colossalai.launch(rank, world_size, "localhost", port)
check_packed_seq()
check_ring_attn()
check_ring_attn(inner_ring_size=None)
def launch_double_ring(rank, world_size, port):
colossalai.launch(rank, world_size, "localhost", port)
check_ring_attn()
check_ring_attn(inner_ring_size=2)
@rerun_if_address_is_in_use()

@ -1 +1 @@
0.4.4
0.4.5
