mirror of https://github.com/hpcaitech/ColossalAI
Merge pull request #6107 from duanjunwen/dev/zero_bubble
[Zerobubble] Merge Main.feature/zerobubble
commit 37b23e32b1
@@ -15,21 +15,21 @@ repos:
         args: ["--profile", "black"] # avoid conflict with black

   - repo: https://github.com/psf/black-pre-commit-mirror
-    rev: 24.8.0
+    rev: 24.10.0
     hooks:
       - id: black
         name: black formatter
         args: ['--line-length=120', '--target-version=py37', '--target-version=py38', '--target-version=py39','--target-version=py310']

   - repo: https://github.com/pre-commit/mirrors-clang-format
-    rev: v18.1.8
+    rev: v19.1.2
     hooks:
       - id: clang-format
         name: clang formatter
         types_or: [c++, c]

   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.6.0
+    rev: v5.0.0
     hooks:
       - id: check-yaml
       - id: check-merge-conflict
README.md
@@ -25,16 +25,36 @@

 </div>

+## GPU Cloud HPC-AI.COM Coming!!
+
+For a limited time, you can access an H100 Server for just $1! This is your chance to leverage premium GPU power at an unbeatable price.
+Plus, when you refer a friend, you’ll receive 20% cashback or compute credits equal to 100% of their top-up!
+
+Our platform offers on-demand premium compute, ensuring safe, permanent data storage even after stopping your instance.
+Don’t miss this incredible opportunity to accelerate your AI projects!
+
+Unlock premium GPUs and register now at [HPC-AI.COM](https://hpc-ai.com) to receive $10!
+
+Special Bonuses:
+
+* Top up $1,000 and receive 300 credits
+* Top up $500 and receive 100 credits
+
+<div align="center">
+<a href="https://youtu.be/ilMQpU71ddI?si=J4JSPzZ03ycYmlki">
+<img src="https://github.com/hpcaitech/public_assets/blob/main/colossalai/img/HPCAICOM241010.jpg" width="700" />
+</a>
+</div>
+
 ## Latest News
+* [2024/10] [How to build a low-cost Sora-like app? Solutions for you](https://company.hpc-ai.com/blog/how-to-build-a-low-cost-sora-like-app-solutions-for-you)
+* [2024/09] [Singapore Startup HPC-AI Tech Secures 50 Million USD in Series A Funding to Build the Video Generation AI Model and GPU Platform](https://company.hpc-ai.com/blog/singapore-startup-hpc-ai-tech-secures-50-million-usd-in-series-a-funding-to-build-the-video-generation-ai-model-and-gpu-platform)
+* [2024/09] [Reducing AI Large Model Training Costs by 30% Requires Just a Single Line of Code From FP8 Mixed Precision Training Upgrades](https://company.hpc-ai.com/blog/reducing-ai-large-model-training-costs-by-30-requires-just-a-single-line-of-code-from-fp8-mixed-precision-training-upgrades)
 * [2024/06] [Open-Sora Continues Open Source: Generate Any 16-Second 720p HD Video with One Click, Model Weights Ready to Use](https://hpc-ai.com/blog/open-sora-from-hpc-ai-tech-team-continues-open-source-generate-any-16-second-720p-hd-video-with-one-click-model-weights-ready-to-use)
 * [2024/05] [Large AI Models Inference Speed Doubled, Colossal-Inference Open Source Release](https://hpc-ai.com/blog/colossal-inference)
 * [2024/04] [Open-Sora Unveils Major Upgrade: Embracing Open Source with Single-Shot 16-Second Video Generation and 720p Resolution](https://hpc-ai.com/blog/open-soras-comprehensive-upgrade-unveiled-embracing-16-second-video-generation-and-720p-resolution-in-open-source)
 * [2024/04] [Most cost-effective solutions for inference, fine-tuning and pretraining, tailored to LLaMA3 series](https://hpc-ai.com/blog/most-cost-effective-solutions-for-inference-fine-tuning-and-pretraining-tailored-to-llama3-series)
-* [2024/03] [314 Billion Parameter Grok-1 Inference Accelerated by 3.8x, Efficient and Easy-to-Use PyTorch+HuggingFace version is Here](https://hpc-ai.com/blog/314-billion-parameter-grok-1-inference-accelerated-by-3.8x-efficient-and-easy-to-use-pytorchhuggingface-version-is-here)
-* [2024/03] [Open-Sora: Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models](https://hpc-ai.com/blog/open-sora-v1.0)
-* [2024/03] [Open-Sora: Sora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million](https://hpc-ai.com/blog/open-sora)
-* [2024/01] [Inference Performance Improved by 46%, Open Source Solution Breaks the Length Limit of LLM for Multi-Round Conversations](https://hpc-ai.com/blog/Colossal-AI-SwiftInfer)
-* [2023/07] [HPC-AI Tech Raises 22 Million USD in Series A Funding](https://www.hpc-ai.tech/blog/hpc-ai-tech-raises-22-million-usd-in-series-a-funding-to-fuel-team-expansion-and-business-growth)

 ## Table of Contents
 <ul>
@@ -27,11 +27,11 @@
 - [Alternative Option For RLHF: SimPO](#alternative-option-for-rlhf-simple-preference-optimization-simpo)
 - [Alternative Option For RLHF: ORPO](#alternative-option-for-rlhf-odds-ratio-preference-optimization-orpo)
 - [Alternative Option For RLHF: KTO](#alternative-option-for-rlhf-kahneman-tversky-optimization-kto)
+- [O1 Journey](#o1-journey)
+  - [Inference with Self-refined MCTS](#inference-with-self-refined-mcts)
 - [FAQ](#faq)
   - [How to save/load checkpoint](#faq)
   - [How to train with limited resources](#faq)
-- [The Plan](#the-plan)
-  - [Real-time progress](#real-time-progress)
 - [Invitation to open-source contribution](#invitation-to-open-source-contribution)
 - [Quick Preview](#quick-preview)
 - [Authors](#authors)
@@ -272,7 +272,7 @@ Odds Ratio Preference Optimization (ORPO) from this [paper](https://arxiv.org/pd
 ## Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO)
 We support the method introduced in the paper [KTO: Model Alignment as Prospect Theoretic Optimization](https://arxiv.org/pdf/2402.01306) (KTO), an alignment method that directly maximizes the "human utility" of generation results. Read this [README](./examples/README.md) for more information.

-### Inference Quantization and Serving - After Training
+## Inference Quantization and Serving - After Training

 We provide an online inference server and a benchmark. We aim to run inference on a single GPU, so quantization is essential when using large models.
@@ -281,6 +281,21 @@ We support 8-bit quantization (RTN), 4-bit quantization (GPTQ), and FP16 inferen
 Online inference server scripts can help you deploy your own services.
 For more details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/inference).

+## O1 Journey
+### Inference with Self-refined MCTS
+We provide the implementation of the MCT Self-Refine (MCTSr) algorithm, an innovative integration of Large Language Models with Monte Carlo Tree Search.
+To run inference with MCTS, simply use the following script.
+```python
+from coati.reasoner.guided_search.mcts import MCTS
+from coati.reasoner.guided_search.prompt_store.qwen import Qwen32B_prompt_CFG
+
+problem = "How Many R in 'Strawberry'"
+
+search_tree = MCTS(problem=problem, max_simulations=8, cfg=Qwen32B_prompt_CFG)
+answer = search_tree.simulate()
+print(answer)
+```
+
 ## Coati7B examples

 ### Generation
@@ -153,10 +153,11 @@ class DpoLoss(nn.Module):
         else:
             # If no reference model is provided
             ref_logratios = 0.0

         pi_logratios = logprob_actor_chosen.sum(-1) - logprob_actor_reject.sum(-1)
         logits = pi_logratios - ref_logratios - self.gamma / self.beta
         losses = -torch.nn.functional.logsigmoid(self.beta * logits)
+        loss = losses.mean()
         # Calculate rewards for logging
         if logprob_ref_chosen is not None:
             chosen_rewards = self.beta * (logprob_actor_chosen.sum(-1) - logprob_ref_chosen.sum(-1)).detach()
@@ -167,7 +168,7 @@ class DpoLoss(nn.Module):
         else:
             rejected_rewards = self.beta * logprob_actor_reject.sum(-1).detach()

-        return losses, chosen_rewards, rejected_rewards
+        return loss, chosen_rewards, rejected_rewards


 class LogSigLoss(nn.Module):
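For context, a hedged sketch of the objective the hunk above computes, using conventional DPO notation (y_w the chosen response, y_l the rejected one) and the symbols from the code (β the inverse temperature, γ an optional SimPO-style target-reward margin; γ = 0 recovers standard DPO):

```latex
\mathcal{L}(\theta) = -\log \sigma\!\left(
    \beta\Big[\log\tfrac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)}
            - \log\tfrac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)}\Big] - \gamma \right)
```

The newly added `loss = losses.mean()` simply averages this per-pair quantity over the batch, so the loss function now returns a scalar that can be fed directly to pipeline-parallel backward passes.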
@@ -50,8 +50,8 @@ def _log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.
         torch.Tensor: The log probabilities corresponding to the labels.
     """
     log_probs = F.log_softmax(logits, dim=-1)
-    log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
-    return log_probs_labels.squeeze(-1)
+    per_label_logps = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
+    return per_label_logps.squeeze(-1)


 def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
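The rename above does not change behaviour; as a quick standalone illustration of what the `gather` pattern computes (a toy example, not code from the repository):

```python
import torch
import torch.nn.functional as F

# Toy logits: batch of 1 sequence, 3 positions, vocabulary of 4 tokens.
logits = torch.randn(1, 3, 4)
labels = torch.tensor([[2, 0, 3]])

log_probs = F.log_softmax(logits, dim=-1)  # shape (1, 3, 4)
# Pick, at every position, the log-probability of the observed label:
per_label_logps = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)

# Each entry equals log_probs[b, t, labels[b, t]].
assert per_label_logps.shape == labels.shape
```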
@@ -0,0 +1,26 @@
import openai
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam

API_KEY = "Dummy API Key"


def get_client(base_url: str | None = None) -> openai.Client:
    return openai.Client(api_key=API_KEY, base_url=base_url)


def chat_completion(
    messages: list[ChatCompletionMessageParam],
    model: str,
    base_url: str | None = None,
    temperature: float = 0.8,
    **kwargs,
) -> ChatCompletion:
    client = get_client(base_url)
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=temperature,
        **kwargs,
    )
    return response
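One possible way to exercise this helper against a locally hosted, OpenAI-compatible endpoint; the URL and model name below are assumptions for illustration, mirroring the Qwen config introduced later in this PR:

```python
# Hypothetical usage sketch; requires an OpenAI-compatible server listening on base_url
# (for example the Qwen endpoint configured in prompt_store/qwen.py).
from coati.reasoner.guided_search.llm import chat_completion

response = chat_completion(
    messages=[{"role": "user", "content": "Say hello in one word."}],
    model="Qwen2.5-32B-Instruct",
    base_url="http://0.0.0.0:8008/v1",
    temperature=0.2,
)
print(response.choices[0].message.content)
```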
@@ -0,0 +1,250 @@
"""
Implementation of MCTS + Self-refine algorithm.

Reference:
1. "Accessing GPT-4 level Mathematical Olympiad Solutions via Monte
Carlo Tree Self-refine with LLaMa-3 8B: A Technical Report"
2. https://github.com/BrendanGraham14/mcts-llm/
3. https://github.com/trotsky1997/MathBlackBox/
4. https://github.com/openreasoner/openr/blob/main/reason/guided_search/tree.py
"""

from __future__ import annotations

import math
from collections import deque

import numpy as np
import tqdm
from coati.reasoner.guided_search.llm import chat_completion
from coati.reasoner.guided_search.prompt_store.base import PromptCFG
from pydantic import BaseModel


class MCTSNode(BaseModel):
    """
    Node for MCTS.
    """

    answer: str
    parent: MCTSNode = None
    children: list[MCTSNode] = []
    num_visits: int = 0
    Q: int = 0
    rewards: list[int] = []

    def expand_node(self, node) -> None:
        self.children.append(node)

    def add_reward(self, reward: int) -> None:
        self.rewards.append(reward)
        self.Q = (np.min(self.rewards) + np.mean(self.rewards)) / 2


class MCTS(BaseModel):
    """
    Simulation of MCTS process.
    """

    problem: str
    max_simulations: int
    cfg: PromptCFG
    C: float = 1.4
    max_children: int = 2
    epsilon: float = 1e-5
    root: MCTSNode = None

    def initialization(self):
        """
        Root Initiation.
        """
        # Dummy answer as root.
        base_answer = self.sample_base_answer()
        self.root = MCTSNode(answer=base_answer)
        self.self_evaluate(self.root)

    def is_fully_expanded(self, node: MCTSNode):
        return len(node.children) >= self.max_children or any(child.Q > node.Q for child in node.children)

    def select_node(self) -> MCTSNode:
        """
        Select next node to explore.
        """
        candidates: list[MCTSNode] = []
        to_explore = deque([self.root])

        while to_explore:
            current_node = to_explore.popleft()
            if not self.is_fully_expanded(current_node):
                candidates.append(current_node)
            to_explore.extend(current_node.children)

        if not candidates:
            return self.root

        return max(candidates, key=self.compute_uct)

    def self_evaluate(self, node: MCTSNode):
        """
        Sample reward of the answer.
        """
        reward = self.sample_reward(node)
        node.add_reward(reward)

    def back_propagation(self, node: MCTSNode):
        """
        Back propagate the value of the refined answer.
        """
        parent = node.parent
        while parent:
            best_child_Q = max(child.Q for child in parent.children)
            parent.Q = (parent.Q + best_child_Q) / 2
            parent.num_visits += 1
            parent = parent.parent

    def compute_uct(self, node: MCTSNode):
        """
        Compute UCT.
        """
        if node.parent is None:
            return -100
        return node.Q + self.C * math.sqrt(math.log(node.parent.num_visits + 1) / (node.num_visits + self.epsilon))

    def simulate(self):
        self.initialization()
        for _ in tqdm.tqdm(range(self.max_simulations)):
            node = self.select_node()
            child = self.self_refine(node)
            node.expand_node(child)
            self.self_evaluate(child)
            self.back_propagation(child)

        return self.get_best_answer()

    def get_best_answer(self):
        to_visit = deque([self.root])
        best_node = self.root

        while to_visit:
            current_node = to_visit.popleft()
            if current_node.Q > best_node.Q:
                best_node = current_node
            to_visit.extend(current_node.children)

        return best_node.answer

    def self_refine(self, node: MCTSNode):
        """
        Refine node.
        """
        critique_response = chat_completion(
            messages=[
                {
                    "role": "system",
                    "content": self.cfg.critic_system_prompt,
                },
                {
                    "role": "user",
                    "content": "\n\n".join(
                        [
                            f"<problem>\n{self.problem}\n</problem>",
                            f"<current_answer>\n{node.answer}\n</current_answer>",
                        ]
                    ),
                },
            ],
            model=self.cfg.model,
            base_url=self.cfg.base_url,
            max_tokens=self.cfg.max_tokens,
        )
        critique = critique_response.choices[0].message.content
        assert critique is not None
        refined_answer_response = chat_completion(
            messages=[
                {
                    "role": "system",
                    "content": self.cfg.refine_system_prompt,
                },
                {
                    "role": "user",
                    "content": "\n\n".join(
                        [
                            f"<problem>\n{self.problem}\n</problem>",
                            f"<current_answer>\n{node.answer}\n</current_answer>",
                            f"<critique>\n{critique}\n</critique>",
                        ]
                    ),
                },
            ],
            model=self.cfg.model,
            base_url=self.cfg.base_url,
            max_tokens=self.cfg.max_tokens,
        )
        refined_answer = refined_answer_response.choices[0].message.content
        assert refined_answer is not None

        return MCTSNode(answer=refined_answer, parent=node)

    def sample_base_answer(self):
        response = chat_completion(
            messages=[
                {
                    "role": "system",
                    "content": "The user will provide a problem. Solve the problem. The response should begin with [reasoning process]...[Verification]... and end with [Final Answer]. \nThe answer is [answer] \n#### [answer].",
                },
                {
                    "role": "user",
                    "content": f"<problem>\n {self.problem} \n</problem> \nLet's think step by step",
                },
            ],
            model=self.cfg.model,
            base_url=self.cfg.base_url,
            max_tokens=self.cfg.max_tokens,
        )
        assert response.choices[0].message.content is not None
        return response.choices[0].message.content

    def sample_reward(self, node: MCTSNode):
        """
        Calculate reward.
        """
        messages = [
            {
                "role": "system",
                "content": self.cfg.evaluate_system_prompt,
            },
            {
                "role": "user",
                "content": "\n\n".join(
                    [
                        f"<problem>\n{self.problem}\n</problem>",
                        f"<answer>\n{node.answer}\n</answer>",
                    ]
                ),
            },
        ]
        for attempt in range(3):
            try:
                response = chat_completion(
                    messages=messages,
                    model=self.cfg.model,
                    base_url=self.cfg.base_url,
                    max_tokens=self.cfg.max_tokens,
                )
                assert response.choices[0].message.content is not None
                return int(response.choices[0].message.content)
            except ValueError:
                messages.extend(
                    [
                        {
                            "role": "assistant",
                            "content": response.choices[0].message.content,
                        },
                        {
                            "role": "user",
                            "content": "Failed to parse reward as an integer.",
                        },
                    ]
                )
                if attempt == 2:
                    raise
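For reference, the selection score implemented in `compute_uct` above and the node-quality update in `add_reward` can be written as follows (using the symbols from the code: C is the exploration constant, ε avoids division by zero, r_i are the sampled rewards of node v, N the parent's visit count and n(v) the node's own visit count):

```latex
\mathrm{UCT}(v) = Q(v) + C \sqrt{\frac{\ln\!\big(N(\mathrm{parent}(v)) + 1\big)}{n(v) + \epsilon}},
\qquad
Q(v) = \frac{\min_i r_i + \frac{1}{|R|}\sum_i r_i}{2}
```

Blending the minimum and the mean of the sampled rewards makes Q conservative toward a node's worst observed self-evaluation, which is the MCTSr variant of the usual average-return estimate.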
@@ -0,0 +1,10 @@
from pydantic import BaseModel


class PromptCFG(BaseModel):
    model: str
    base_url: str
    max_tokens: int = 4096
    critic_system_prompt: str
    refine_system_prompt: str
    evaluate_system_prompt: str
@@ -0,0 +1,20 @@
"""
Prompts for Qwen Series.
"""

from coati.reasoner.guided_search.prompt_store.base import PromptCFG

Qwen32B_prompt_CFG = PromptCFG(
    base_url="http://0.0.0.0:8008/v1",
    model="Qwen2.5-32B-Instruct",
    critic_system_prompt="Provide a detailed and constructive critique to improve the answer. "
    "Highlight specific areas that need refinement or correction.",
    refine_system_prompt="""# Instruction
Refine the answer based on the critique. The response should begin with [reasoning process]...[Verification]... and end with [Final Answer].
""",
    evaluate_system_prompt=(
        "Analyze this answer strictly and critic, provide a reward score between -100 and 100 for the answer quality, using very strict standards. "
        "Do not give a full score above 95. Make sure the reward score is an integer. "
        "Return *ONLY* the score."
    ),
)
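Configs for other backends follow the same `PromptCFG` shape. The snippet below is a hypothetical example only (the model name, URL, and variable name are placeholders, not part of this PR), showing how an additional prompt config could be declared:

```python
from coati.reasoner.guided_search.prompt_store.base import PromptCFG

# Placeholder values for illustration only.
Llama3_prompt_CFG = PromptCFG(
    base_url="http://localhost:8000/v1",
    model="Meta-Llama-3-8B-Instruct",
    critic_system_prompt="Provide a detailed and constructive critique to improve the answer.",
    refine_system_prompt="Refine the answer based on the critique.",
    evaluate_system_prompt=(
        "Provide an integer reward score between -100 and 100 for the answer quality. Return ONLY the score."
    ),
)
```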
@@ -6,6 +6,7 @@ import os
 from typing import Any, Optional

 import torch
+import torch.distributed as dist
 from coati.models.loss import DpoLoss
 from coati.models.utils import calc_masked_log_probs
 from coati.trainer.utils import all_reduce_mean
@@ -13,10 +14,11 @@ from coati.utils import AccumulativeMeanMeter, save_checkpoint
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
-from tqdm import trange
+from tqdm import tqdm, trange
 from transformers import PreTrainedTokenizerBase

 from colossalai.booster import Booster, Plugin
+from colossalai.booster.plugin import HybridParallelPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.utils import get_current_device
@@ -96,18 +98,25 @@ class DPOTrainer(SLTrainer):
         self.train_dataloader = train_preference_dataloader
         self.eval_dataloader = eval_preference_dataloader
         self.writer = None
-        if use_wandb and is_rank_0():
+
+        init_criterion = (
+            dist.get_rank() == dist.get_world_size() - 1
+            if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1
+            else is_rank_0()
+        )
+
+        if use_wandb and init_criterion:
             assert log_dir is not None, "log_dir must be provided when use_wandb is True"
             import wandb

             self.wandb_run = wandb.init(project="Coati-dpo", sync_tensorboard=True)
-        if log_dir is not None and is_rank_0():
+        if log_dir is not None and init_criterion:
             import os
             import time

             from torch.utils.tensorboard import SummaryWriter

-            log_dir = os.path.join(log_dir, "dpo")
+            log_dir = os.path.join(log_dir, "DPO")
             log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
             self.writer = SummaryWriter(log_dir=log_dir)
@@ -117,166 +126,147 @@ class DPOTrainer(SLTrainer):
             epoch int: the number of current epoch
         """
         self.model.train()
-        self.accumulative_meter.reset()
-        step_bar = trange(
-            len(self.train_dataloader) // self.accumulation_steps,
-            desc=f"Epoch {epoch + 1}/{self.max_epochs}",
-            disable=not is_rank_0(),
-        )
-        for i, batch in enumerate(self.train_dataloader):
-            batch = to_device(batch, self.device)
-            (
-                chosen_input_ids,
-                chosen_attention_mask,
-                chosen_loss_mask,
-                reject_input_ids,
-                reject_attention_mask,
-                reject_loss_mask,
-            ) = (
-                batch["chosen_input_ids"],
-                batch["chosen_attention_mask"],
-                batch["chosen_loss_mask"],
-                batch["reject_input_ids"],
-                batch["reject_attention_mask"],
-                batch["reject_loss_mask"],
-            )
-            if not self.apply_loss_mask:
-                chosen_loss_mask = chosen_loss_mask.fill_(1.0)
-                reject_loss_mask = reject_loss_mask.fill_(1.0)
-
-            batch_size = chosen_input_ids.size()[0]
-
-            actor_all_logits = self.model(
-                input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
-                attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
-            )["logits"]
-            actor_chosen_logits = actor_all_logits[:batch_size]
-            actor_reject_logits = actor_all_logits[batch_size:]
-            logprob_actor_chosen = calc_masked_log_probs(
-                actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
-            )
-            logprob_actor_reject = calc_masked_log_probs(
-                actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
-            )
-            if self.ref_model is not None:
-                self.ref_model.eval()
-                with torch.no_grad():
-                    ref_all_logits = self.ref_model(
-                        input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
-                        attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
-                    )["logits"]
-                    ref_chosen_logits = ref_all_logits[:batch_size]
-                    ref_reject_logits = ref_all_logits[batch_size:]
-                    logprob_ref_chosen = calc_masked_log_probs(
-                        ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
-                    )
-                    logprob_ref_reject = calc_masked_log_probs(
-                        ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
-                    )
-            else:
-                logprob_ref_chosen = None
-                logprob_ref_reject = None
-
-            losses, chosen_rewards, rejected_rewards = self.actor_loss_fn(
-                logprob_actor_chosen,
-                logprob_actor_reject,
-                logprob_ref_chosen if logprob_ref_chosen is not None else None,
-                logprob_ref_reject if logprob_ref_reject is not None else None,
-                chosen_loss_mask[:, 1:],
-                reject_loss_mask[:, 1:],
-            )
-            reward_accuracies = (chosen_rewards > rejected_rewards).float().mean()
-
-            # DPO Loss
-            loss = losses.mean()
-
-            self.booster.backward(loss=loss, optimizer=self.optimizer)
-            if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1:
-                self.optimizer.step()
-                self.optimizer.zero_grad()
-                self.actor_scheduler.step()
-
-            # sync
-            loss_mean = all_reduce_mean(tensor=loss)
-            chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
-            rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards)
-            reward_accuracies_mean = all_reduce_mean(tensor=reward_accuracies)
-            self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
-            self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
-            self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
-            self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
-
-            if i % self.accumulation_steps == self.accumulation_steps - 1:
-                self.num_train_step += 1
-                step_bar.update()
-                # logging
-                if self.writer and is_rank_0():
-                    self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
-                    self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
-                    self.writer.add_scalar(
-                        "train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
-                    )
-                    self.writer.add_scalar(
-                        "train/rejected_rewards",
-                        self.accumulative_meter.get("rejected_rewards"),
-                        self.num_train_step,
-                    )
-                    self.writer.add_scalar(
-                        "train/margin",
-                        self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"),
-                        self.num_train_step,
-                    )
-                    self.writer.add_scalar(
-                        "train/accuracy",
-                        self.accumulative_meter.get("accuracy"),
-                        self.num_train_step,
-                    )
-                self.accumulative_meter.reset()
-
-                if self.save_dir is not None and (self.num_train_step + 1) % self.save_interval == 0:
-                    # save checkpoint
-                    self.coordinator.print_on_master("\nStart saving model checkpoint with running states")
-                    save_checkpoint(
-                        save_dir=self.save_dir,
-                        booster=self.booster,
-                        model=self.model,
-                        optimizer=self.optimizer,
-                        lr_scheduler=self.actor_scheduler,
-                        epoch=epoch,
-                        step=i + 1,
-                        batch_size=batch_size,
-                        coordinator=self.coordinator,
-                    )
-                    self.coordinator.print_on_master(
-                        f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}"
-                    )
-
-        step_bar.close()
-
-    def _eval(self, epoch: int):
-        """
-        Args:
-            epoch int: the number of current epoch
-        """
-        if self.eval_dataloader is None:
-            self.coordinator.print_on_master("No eval dataloader is provided, skip evaluation")
-            return
-        self.model.eval()
-        self.ref_model.eval()
-        self.coordinator.print_on_master("\nStart evaluation...")
-
-        step_bar = trange(
-            len(self.eval_dataloader),
-            desc=f"Epoch {epoch + 1}/{self.max_epochs}",
-            disable=not is_rank_0(),
-        )
-
-        self.accumulative_meter.reset()
-
-        with torch.no_grad():
-            for i, batch in enumerate(self.eval_dataloader):
+        if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1:
+            step_bar = tqdm(
+                range(len(self.train_dataloader)),
+                desc="Step",
+                disable=not (dist.get_rank() == dist.get_world_size() - 1),
+            )
+            for i, batch in enumerate(self.train_dataloader):
+                batch = to_device(batch, self.device)
+                (
+                    chosen_input_ids,
+                    chosen_attention_mask,
+                    chosen_loss_mask,
+                    reject_input_ids,
+                    reject_attention_mask,
+                    reject_loss_mask,
+                ) = (
+                    batch["chosen_input_ids"],
+                    batch["chosen_attention_mask"],
+                    batch["chosen_loss_mask"],
+                    batch["reject_input_ids"],
+                    batch["reject_attention_mask"],
+                    batch["reject_loss_mask"],
+                )
+                batch_size = chosen_input_ids.size()[0]
+                # Calculate logits from reference model.
+                if self.ref_model is not None:
+                    self.ref_model.eval()
+                    with torch.no_grad():
+                        ref_all_logits = self.ref_model(
+                            input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
+                            attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
+                        )["logits"]
+                        ref_chosen_logits = ref_all_logits[:batch_size]
+                        ref_reject_logits = ref_all_logits[batch_size:]
+                        logprob_ref_chosen = calc_masked_log_probs(
+                            ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
+                        )
+                        logprob_ref_reject = calc_masked_log_probs(
+                            ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
+                        )
+                else:
+                    logprob_ref_chosen = None
+                    logprob_ref_reject = None
+
+                # Merge chosen and reject
+                inputs_ids = torch.stack([item for tup in zip(chosen_input_ids, reject_input_ids) for item in tup])
+                attention_mask = torch.stack(
+                    [item for tup in zip(chosen_attention_mask, reject_attention_mask) for item in tup]
+                )
+                loss_mask = torch.stack([item for tup in zip(chosen_loss_mask, reject_loss_mask) for item in tup])
+                logprob_ref = torch.stack([item for tup in zip(logprob_ref_chosen, logprob_ref_reject) for item in tup])
+
+                data_iter = iter(
+                    [
+                        {
+                            "input_ids": inputs_ids,
+                            "attention_mask": attention_mask,
+                            "loss_mask": loss_mask,
+                            "logprob_ref": logprob_ref,
+                        }
+                    ]
+                )
+                rewards = []
+
+                def _criterion(outputs, inputs):
+                    loss, chosen_rewards, rejected_rewards = self.actor_loss_fn(
+                        calc_masked_log_probs(
+                            outputs["logits"][0::2],
+                            inputs["input_ids"][0::2],
+                            inputs["loss_mask"][0::2][:, 1:],
+                            self.length_normalization,
+                        ),
+                        calc_masked_log_probs(
+                            outputs["logits"][1::2],
+                            inputs["input_ids"][1::2],
+                            inputs["loss_mask"][1::2][:, 1:],
+                            self.length_normalization,
+                        ),
+                        inputs["logprob_ref"][0::2] if inputs["logprob_ref"] is not None else None,
+                        inputs["logprob_ref"][1::2] if inputs["logprob_ref"] is not None else None,
+                        inputs["loss_mask"][0::2][:, 1:],
+                        inputs["loss_mask"][1::2][:, 1:],
+                    )
+                    rewards.append(chosen_rewards)
+                    rewards.append(rejected_rewards)
+                    return loss
+
+                outputs = self.booster.execute_pipeline(
+                    data_iter,
+                    self.model,
+                    criterion=_criterion,
+                    optimizer=self.optimizer,
+                    return_loss=True,
+                )
+                loss = outputs["loss"]
+                if self.booster.plugin.stage_manager.is_last_stage():
+                    chosen_rewards, rejected_rewards = rewards[0], rewards[1]
+                    global_loss = all_reduce_mean(loss, self.plugin)
+                    if dist.get_rank() == dist.get_world_size() - 1:
+                        step_bar.set_postfix(
+                            {
+                                "train/loss": global_loss.item(),
+                                "train/lr": self.actor_scheduler.get_last_lr()[0],
+                                "train/chosen_rewards": chosen_rewards.to(torch.float16).mean().item(),
+                                "train/rejected_rewards": rejected_rewards.to(torch.float16).mean().item(),
+                            }
+                        )
+                        step_bar.update()
+                        self.accumulative_meter.add("loss", global_loss.item())
+                        self.accumulative_meter.add("chosen_rewards", chosen_rewards.to(torch.float16).mean().item())
+                        self.accumulative_meter.add(
+                            "rejected_rewards", rejected_rewards.to(torch.float16).mean().item()
+                        )
+                        if self.writer is not None:
+                            self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), i)
+                            self.writer.add_scalar(
+                                "train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), i
+                            )
+                            self.writer.add_scalar(
+                                "train/rejected_rewards",
+                                self.accumulative_meter.get("rejected_rewards"),
+                                i,
+                            )
+                            self.writer.add_scalar(
+                                "train/margin",
+                                self.accumulative_meter.get("chosen_rewards")
+                                - self.accumulative_meter.get("rejected_rewards"),
+                                i,
+                            )
+
+                self.optimizer.step()
+                self.optimizer.zero_grad()
+                self.actor_scheduler.step()
+        else:
+            self.accumulative_meter.reset()
+            step_bar = trange(
+                len(self.train_dataloader) // self.accumulation_steps,
+                desc=f"Epoch {epoch + 1}/{self.max_epochs}",
+                disable=not is_rank_0(),
+            )
+            for i, batch in enumerate(self.train_dataloader):
                 batch = to_device(batch, self.device)
                 (
                     chosen_input_ids,
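A short, self-contained sketch of the chosen/rejected interleaving that the pipeline path above relies on: chosen and rejected samples are zipped into one batch so that even indices (`[0::2]`) recover the chosen half and odd indices (`[1::2]`) the rejected half inside `_criterion` (toy shapes, not repository code):

```python
import torch

chosen = torch.tensor([[1, 1], [2, 2]])  # two "chosen" sequences
reject = torch.tensor([[9, 9], [8, 8]])  # two "rejected" sequences

# Same merge as in the trainer: interleave row-wise -> [c0, r0, c1, r1].
merged = torch.stack([item for tup in zip(chosen, reject) for item in tup])

assert torch.equal(merged[0::2], chosen)  # even rows are the chosen half
assert torch.equal(merged[1::2], reject)  # odd rows are the rejected half
```

Packing both halves into a single tensor lets one pipeline forward pass serve both the chosen and rejected sequences, with the loss function slicing them back apart on the last stage.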
@@ -300,12 +290,11 @@ class DPOTrainer(SLTrainer):
                 batch_size = chosen_input_ids.size()[0]

                 actor_all_logits = self.model(
-                    torch.cat([chosen_input_ids, reject_input_ids]),
-                    torch.cat([chosen_attention_mask, reject_attention_mask]),
+                    input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
+                    attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
                 )["logits"]
                 actor_chosen_logits = actor_all_logits[:batch_size]
                 actor_reject_logits = actor_all_logits[batch_size:]

                 logprob_actor_chosen = calc_masked_log_probs(
                     actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
                 )
@@ -314,22 +303,26 @@ class DPOTrainer(SLTrainer):
                     actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
                 )

-                self.ref_model.eval()
-                ref_all_logits = self.ref_model(
-                    torch.cat([chosen_input_ids, reject_input_ids]),
-                    torch.cat([chosen_attention_mask, reject_attention_mask]),
-                )["logits"]
-                ref_chosen_logits = ref_all_logits[:batch_size]
-                ref_reject_logits = ref_all_logits[batch_size:]
-                logprob_ref_chosen = calc_masked_log_probs(
-                    ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
-                )
-                logprob_ref_reject = calc_masked_log_probs(
-                    ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
-                )
-
-                losses, chosen_rewards, rejected_rewards = self.actor_loss_fn(
+                if self.ref_model is not None:
+                    self.ref_model.eval()
+                    with torch.no_grad():
+                        ref_all_logits = self.ref_model(
+                            input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
+                            attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
+                        )["logits"]
+                        ref_chosen_logits = ref_all_logits[:batch_size]
+                        ref_reject_logits = ref_all_logits[batch_size:]
+                        logprob_ref_chosen = calc_masked_log_probs(
+                            ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
+                        )
+                        logprob_ref_reject = calc_masked_log_probs(
+                            ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
+                        )
+                else:
+                    logprob_ref_chosen = None
+                    logprob_ref_reject = None
+
+                loss, chosen_rewards, rejected_rewards = self.actor_loss_fn(
                     logprob_actor_chosen,
                     logprob_actor_reject,
                     logprob_ref_chosen if logprob_ref_chosen is not None else None,
@@ -338,7 +331,9 @@ class DPOTrainer(SLTrainer):
                     reject_loss_mask[:, 1:],
                 )
                 reward_accuracies = (chosen_rewards > rejected_rewards).float().mean()
-                loss = losses.mean()

+                self.booster.backward(loss=loss, optimizer=self.optimizer)
+                # sync
                 loss_mean = all_reduce_mean(tensor=loss)
                 chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
                 rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards)
@ -347,16 +342,301 @@ class DPOTrainer(SLTrainer):
|
||||||
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
|
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
|
||||||
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
|
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
|
||||||
self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
|
self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
|
||||||
self.accumulative_meter.add(
|
|
||||||
"margin", (chosen_rewards_mean - rejected_rewards_mean).to(torch.float16).mean().item()
|
|
||||||
)
|
|
||||||
step_bar.update()
|
|
||||||
|
|
||||||
msg = "Evaluation Result:\n"
|
if (i + 1) % self.accumulation_steps == 0:
|
||||||
for tag in ["loss", "chosen_rewards", "rejected_rewards", "accuracy", "margin"]:
|
self.optimizer.step()
|
||||||
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
|
self.optimizer.zero_grad()
|
||||||
self.coordinator.print_on_master(msg)
|
self.actor_scheduler.step()
|
||||||
os.makedirs(self.save_dir, exist_ok=True)
|
|
||||||
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
|
step_bar.set_postfix(
|
||||||
f.write(msg)
|
{
|
||||||
|
"train/loss": self.accumulative_meter.get("loss"),
|
||||||
|
"train/chosen_rewards": self.accumulative_meter.get("chosen_rewards"),
|
||||||
|
"train/rejected_rewards": self.accumulative_meter.get("rejected_rewards"),
|
||||||
|
"train/accuracy": self.accumulative_meter.get("accuracy"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
step_bar.update()
|
||||||
|
if self.writer and is_rank_0():
|
||||||
|
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
|
||||||
|
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
|
||||||
|
self.writer.add_scalar(
|
||||||
|
"train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
|
||||||
|
)
|
||||||
|
self.writer.add_scalar(
|
||||||
|
"train/rejected_rewards",
|
||||||
|
self.accumulative_meter.get("rejected_rewards"),
|
||||||
|
self.num_train_step,
|
||||||
|
)
|
||||||
|
self.writer.add_scalar(
|
||||||
|
"train/margin",
|
||||||
|
self.accumulative_meter.get("chosen_rewards")
|
||||||
|
- self.accumulative_meter.get("rejected_rewards"),
|
||||||
|
self.num_train_step,
|
||||||
|
)
|
||||||
|
self.writer.add_scalar(
|
||||||
|
"train/accuracy",
|
||||||
|
self.accumulative_meter.get("accuracy"),
|
||||||
|
self.num_train_step,
|
||||||
|
)
|
||||||
|
self.num_train_step += 1
|
||||||
|
self.accumulative_meter.reset()
|
||||||
|
|
||||||
|
if self.save_dir is not None and self.num_train_step > 0 and self.num_train_step % self.save_interval == 0:
|
||||||
|
# save checkpoint
|
||||||
|
self.coordinator.print_on_master("\nStart saving model checkpoint with running states")
|
||||||
|
save_checkpoint(
|
||||||
|
save_dir=self.save_dir,
|
||||||
|
booster=self.booster,
|
||||||
|
model=self.model,
|
||||||
|
optimizer=self.optimizer,
|
||||||
|
lr_scheduler=self.actor_scheduler,
|
||||||
|
epoch=epoch,
|
||||||
|
step=self.num_train_step,
|
||||||
|
batch_size=batch_size,
|
||||||
|
coordinator=self.coordinator,
|
||||||
|
)
|
||||||
|
self.coordinator.print_on_master(
|
||||||
|
f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}"
|
||||||
|
)
|
||||||
|
|
||||||
|
step_bar.close()
|
||||||
|
|
||||||
|
def _eval(self, epoch: int):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
epoch int: the number of current epoch
|
||||||
|
"""
|
||||||
|
if self.eval_dataloader is None:
|
||||||
|
self.coordinator.print_on_master("No eval dataloader is provided, skip evaluation")
|
||||||
|
return
|
||||||
|
self.model.eval()
|
||||||
|
self.ref_model.eval()
|
||||||
|
self.accumulative_meter.reset()
|
||||||
|
self.coordinator.print_on_master("\nStart evaluation...")
|
||||||
|
|
||||||
|
if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1:
|
||||||
|
step_bar = tqdm(
|
||||||
|
range(len(self.eval_dataloader)),
|
||||||
|
desc="Step",
|
||||||
|
disable=not (dist.get_rank() == dist.get_world_size() - 1),
|
||||||
|
)
|
||||||
|
with torch.no_grad():
|
||||||
|
for _, batch in enumerate(self.eval_dataloader):
|
||||||
|
batch = to_device(batch, self.device)
|
||||||
|
(
|
||||||
|
chosen_input_ids,
|
||||||
|
chosen_attention_mask,
|
||||||
|
chosen_loss_mask,
|
||||||
|
reject_input_ids,
|
||||||
|
reject_attention_mask,
|
||||||
|
reject_loss_mask,
|
||||||
|
) = (
|
||||||
|
batch["chosen_input_ids"],
|
||||||
|
batch["chosen_attention_mask"],
|
||||||
|
batch["chosen_loss_mask"],
|
||||||
|
batch["reject_input_ids"],
|
||||||
|
batch["reject_attention_mask"],
|
||||||
|
batch["reject_loss_mask"],
|
||||||
|
)
|
||||||
|
batch_size = chosen_input_ids.size()[0]
|
||||||
|
# Calculate logits from reference model.
|
||||||
|
if self.ref_model is not None:
|
||||||
|
self.ref_model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
ref_all_logits = self.ref_model(
|
||||||
|
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
|
||||||
|
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
|
||||||
|
)["logits"]
|
||||||
|
ref_chosen_logits = ref_all_logits[:batch_size]
|
||||||
|
ref_reject_logits = ref_all_logits[batch_size:]
|
||||||
|
logprob_ref_chosen = calc_masked_log_probs(
|
||||||
|
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
|
||||||
|
)
|
||||||
|
logprob_ref_reject = calc_masked_log_probs(
|
||||||
|
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logprob_ref_chosen = None
|
||||||
|
logprob_ref_reject = None
|
||||||
|
|
||||||
|
# Merge chosen and reject
|
||||||
|
inputs_ids = torch.stack([item for tup in zip(chosen_input_ids, reject_input_ids) for item in tup])
|
||||||
|
attention_mask = torch.stack(
|
||||||
|
[item for tup in zip(chosen_attention_mask, reject_attention_mask) for item in tup]
|
||||||
|
)
|
||||||
|
loss_mask = torch.stack([item for tup in zip(chosen_loss_mask, reject_loss_mask) for item in tup])
|
||||||
|
logprob_ref = torch.stack(
|
||||||
|
[item for tup in zip(logprob_ref_chosen, logprob_ref_reject) for item in tup]
|
||||||
|
)
|
||||||
|
|
||||||
|
data_iter = iter(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"input_ids": inputs_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"loss_mask": loss_mask,
|
||||||
|
"logprob_ref": logprob_ref,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
rewards = []
|
||||||
|
|
||||||
|
def _criterion(outputs, inputs):
|
||||||
|
loss, chosen_rewards, rejected_rewards = self.actor_loss_fn(
|
||||||
|
calc_masked_log_probs(
|
||||||
|
outputs["logits"][0::2],
|
||||||
|
inputs["input_ids"][0::2],
|
||||||
|
inputs["loss_mask"][0::2][:, 1:],
|
||||||
|
self.length_normalization,
|
||||||
|
),
|
||||||
|
calc_masked_log_probs(
|
||||||
|
outputs["logits"][1::2],
|
||||||
|
inputs["input_ids"][1::2],
|
||||||
|
inputs["loss_mask"][1::2][:, 1:],
|
||||||
|
self.length_normalization,
|
||||||
|
),
|
||||||
|
inputs["logprob_ref"][0::2] if inputs["logprob_ref"] is not None else None,
|
||||||
|
inputs["logprob_ref"][1::2] if inputs["logprob_ref"] is not None else None,
|
||||||
|
inputs["loss_mask"][0::2][:, 1:],
|
||||||
|
inputs["loss_mask"][1::2][:, 1:],
|
||||||
|
)
|
||||||
|
rewards.append(chosen_rewards)
|
||||||
|
rewards.append(rejected_rewards)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
outputs = self.booster.execute_pipeline(
|
||||||
|
data_iter,
|
||||||
|
self.model,
|
||||||
|
criterion=_criterion,
|
||||||
|
optimizer=self.optimizer,
|
||||||
|
return_loss=True,
|
||||||
|
)
|
||||||
|
loss = outputs["loss"]
|
||||||
|
if self.booster.plugin.stage_manager.is_last_stage():
|
||||||
|
chosen_rewards, rejected_rewards = rewards[0], rewards[1]
|
||||||
|
global_loss = all_reduce_mean(loss, self.plugin)
|
||||||
|
chosen_rewards_mean = all_reduce_mean(chosen_rewards, self.plugin)
|
||||||
|
rejected_rewards_mean = all_reduce_mean(rejected_rewards, self.plugin)
|
||||||
|
if dist.get_rank() == dist.get_world_size() - 1:
|
||||||
|
step_bar.set_postfix(
|
||||||
|
{
|
||||||
|
"eval/loss": global_loss.item(),
|
||||||
|
"eval/lr": self.actor_scheduler.get_last_lr()[0],
|
||||||
|
"eval/chosen_rewards": chosen_rewards.to(torch.float16).mean().item(),
|
||||||
|
"eval/rejected_rewards": rejected_rewards.to(torch.float16).mean().item(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self.accumulative_meter.add(
|
||||||
|
"chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item()
|
||||||
|
)
|
||||||
|
self.accumulative_meter.add(
|
||||||
|
"rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item()
|
||||||
|
)
|
||||||
|
self.accumulative_meter.add("loss", global_loss.to(torch.float16).item())
|
||||||
|
step_bar.update()
|
||||||
|
if self.booster.plugin.stage_manager.is_last_stage():
|
||||||
|
msg = "\nEvaluation Result:\n"
|
||||||
|
for tag in ["loss", "chosen_rewards", "rejected_rewards"]:
|
||||||
|
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
|
||||||
|
if dist.get_rank() == dist.get_world_size() - 1:
|
||||||
|
print(msg)
|
||||||
|
else:
|
||||||
|
step_bar = trange(
|
||||||
|
len(self.eval_dataloader),
|
||||||
|
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
|
||||||
|
disable=not is_rank_0(),
|
||||||
|
)
|
||||||
|
with torch.no_grad():
|
||||||
|
for i, batch in enumerate(self.eval_dataloader):
|
||||||
|
batch = to_device(batch, self.device)
|
||||||
|
(
|
||||||
|
chosen_input_ids,
|
||||||
|
chosen_attention_mask,
|
||||||
|
chosen_loss_mask,
|
||||||
|
reject_input_ids,
|
||||||
|
reject_attention_mask,
|
||||||
|
reject_loss_mask,
|
||||||
|
) = (
|
||||||
|
batch["chosen_input_ids"],
|
||||||
|
batch["chosen_attention_mask"],
|
||||||
|
batch["chosen_loss_mask"],
|
||||||
|
batch["reject_input_ids"],
|
+                batch["reject_attention_mask"],
+                batch["reject_loss_mask"],
+            )
+            if not self.apply_loss_mask:
+                chosen_loss_mask = chosen_loss_mask.fill_(1.0)
+                reject_loss_mask = reject_loss_mask.fill_(1.0)
+
+            batch_size = chosen_input_ids.size()[0]
+
+            actor_all_logits = self.model(
+                torch.cat([chosen_input_ids, reject_input_ids]),
+                torch.cat([chosen_attention_mask, reject_attention_mask]),
+            )["logits"]
+            actor_chosen_logits = actor_all_logits[:batch_size]
+            actor_reject_logits = actor_all_logits[batch_size:]
+
+            logprob_actor_chosen = calc_masked_log_probs(
+                actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
+            )
+
+            logprob_actor_reject = calc_masked_log_probs(
+                actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
+            )
+            ref_all_logits = self.ref_model(
+                torch.cat([chosen_input_ids, reject_input_ids]),
+                torch.cat([chosen_attention_mask, reject_attention_mask]),
+            )["logits"]
+            ref_chosen_logits = ref_all_logits[:batch_size]
+            ref_reject_logits = ref_all_logits[batch_size:]
+            logprob_ref_chosen = calc_masked_log_probs(
+                ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
+            )
+            logprob_ref_reject = calc_masked_log_probs(
+                ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
+            )
+
+            losses, chosen_rewards, rejected_rewards = self.actor_loss_fn(
+                logprob_actor_chosen,
+                logprob_actor_reject,
+                logprob_ref_chosen if logprob_ref_chosen is not None else None,
+                logprob_ref_reject if logprob_ref_reject is not None else None,
+                chosen_loss_mask[:, 1:],
+                reject_loss_mask[:, 1:],
+            )
+            reward_accuracies = (chosen_rewards > rejected_rewards).float().mean()
+            loss = losses.mean()
+            loss_mean = all_reduce_mean(tensor=loss)
+            chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
+            rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards)
+            reward_accuracies_mean = all_reduce_mean(tensor=reward_accuracies)
+            self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
+            self.accumulative_meter.add(
+                "rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item()
+            )
+            self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
+            self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
+            self.accumulative_meter.add(
+                "margin", (chosen_rewards_mean - rejected_rewards_mean).to(torch.float16).mean().item()
+            )
+            step_bar.set_postfix(
+                {
+                    "eval/loss": self.accumulative_meter.get("loss"),
+                    "eval/chosen_rewards": self.accumulative_meter.get("chosen_rewards"),
+                    "eval/rejected_rewards": self.accumulative_meter.get("rejected_rewards"),
+                    "eval/accuracy": self.accumulative_meter.get("accuracy"),
+                }
+            )
+            step_bar.update()
+
+        msg = "\nEvaluation Result:\n"
+        for tag in ["loss", "chosen_rewards", "rejected_rewards", "accuracy", "margin"]:
+            msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
+        self.coordinator.print_on_master(msg)
+        if self.save_dir is not None:
+            os.makedirs(self.save_dir, exist_ok=True)
+            with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
+                f.write(msg)
         step_bar.close()
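For orientation, the block above reduces each evaluation metric across data-parallel ranks, accumulates it, and writes a per-epoch summary on the master rank. A minimal sketch of that bookkeeping pattern, assuming an initialized `torch.distributed` process group (the helper names mirror, but are not taken verbatim from, the ColossalChat utilities):

```python
import os
import torch
import torch.distributed as dist


def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    # Average a scalar metric across all data-parallel ranks (assumed helper).
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    tensor.div_(dist.get_world_size())
    return tensor


def dump_eval_result(metrics: dict, save_dir: str, epoch: int) -> None:
    # Only rank 0 writes the summary file, mirroring coordinator.print_on_master.
    if dist.get_rank() == 0:
        os.makedirs(save_dir, exist_ok=True)
        with open(os.path.join(save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
            f.write("\n".join(f"{k}: {v}" for k, v in metrics.items()))
```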
@@ -73,8 +73,7 @@ def main():
         "--conversation_template_config",
         type=str,
         default="conversation_template_config",
-        help="Path \
-            to save conversation template config files.",
+        help="Path to save conversation template config files.",
     )
     parser.add_argument("--data_cache_dir", type=str, default="cache", help="Data cache directory")
     parser.add_argument(
@@ -13,7 +13,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer

 import colossalai
 from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.logging import get_dist_logger
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -29,8 +29,6 @@ def train(args):
     # check lora compatibility
     if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
         raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
-    if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
-        raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")

     # ==============================
     # Initialize Distributed Training
@@ -46,7 +44,7 @@ def train(args):
         Default torch ddp plugin without any acceleration, for
         debugging purpose acceleration, for debugging purpose
         """
-        plugin = TorchDDPPlugin(find_unused_parameters=True)
+        plugin = TorchDDPPlugin(find_unused_parameters=not args.grad_checkpoint)
     elif args.plugin == "gemini":
         plugin = GeminiPlugin(
             precision=args.mixed_precision,
@@ -56,14 +54,6 @@ def train(args):
             enable_gradient_accumulation=True,
             enable_flash_attention=args.use_flash_attn,
         )
-    elif args.plugin == "gemini_auto":
-        plugin = GeminiPlugin(
-            precision=args.mixed_precision,
-            placement_policy="auto",
-            initial_scale=2**16,
-            max_norm=args.grad_clip,
-            enable_flash_attention=args.use_flash_attn,
-        )
     elif args.plugin == "zero2":
         plugin = LowLevelZeroPlugin(
             stage=2,
@@ -92,20 +82,24 @@ def train(args):
             parallel_output=False,
             max_norm=args.grad_clip,
             precision=args.mixed_precision,
+            microbatch_size=args.microbatch_size,
         )
     else:
         raise ValueError(f"Unknown plugin {args.plugin}")

     booster = Booster(plugin=plugin)
-    ref_booster = Booster(plugin=plugin)

-    # ======================================================
-    # Initialize Model, Objective, Optimizer and LR Scheduler
-    # ======================================================
-    # Temp Fix: Disable lazy init due to version conflict
-    # init_ctx = (
-    #     LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
-    # )
+    ref_plugin = HybridParallelPlugin(
+        tp_size=args.ref_tp,
+        pp_size=1,
+        zero_stage=args.zero_stage,
+        enable_flash_attention=args.use_flash_attn,
+        cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
+        parallel_output=False,
+        max_norm=args.grad_clip,
+        precision=args.mixed_precision,
+    )
+    ref_booster = Booster(plugin=ref_plugin)

     init_ctx = nullcontext()
     with init_ctx:
|
||||||
ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain)
|
ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain)
|
||||||
else:
|
else:
|
||||||
ref_model = None
|
ref_model = None
|
||||||
|
|
||||||
if args.lora_config is not None:
|
if args.lora_config is not None:
|
||||||
model = convert_to_lora_module(model, lora_config=lora_config)
|
model = convert_to_lora_module(model, lora_config=lora_config)
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
|
@ -139,7 +134,9 @@ def train(args):
|
||||||
disable_dropout(ref_model)
|
disable_dropout(ref_model)
|
||||||
|
|
||||||
if args.grad_checkpoint:
|
if args.grad_checkpoint:
|
||||||
# Note, for some models, lora may not be compatible with gradient checkpointing
|
# Make sure gradient checkpointing can be activated.
|
||||||
|
model.train()
|
||||||
|
# Note, for some models, lora may not be compatible with gradient checkpointing.
|
||||||
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
||||||
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
|
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
|
||||||
|
|
||||||
|
@ -169,7 +166,7 @@ def train(args):
|
||||||
adamw_mode=True,
|
adamw_mode=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# configure dataset
|
# Configure dataset
|
||||||
coordinator.print_on_master(f"Load dataset: {args.dataset}")
|
coordinator.print_on_master(f"Load dataset: {args.dataset}")
|
||||||
mode_map = {"train": "train", "valid": "validation", "test": "test"}
|
mode_map = {"train": "train", "valid": "validation", "test": "test"}
|
||||||
train_dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train", mode_map=mode_map)
|
train_dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train", mode_map=mode_map)
|
||||||
|
@ -213,14 +210,15 @@ def train(args):
|
||||||
|
|
||||||
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
|
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
|
||||||
torch.set_default_dtype(default_dtype)
|
torch.set_default_dtype(default_dtype)
|
||||||
|
|
||||||
model, optim, _, train_dataloader, lr_scheduler = booster.boost(
|
model, optim, _, train_dataloader, lr_scheduler = booster.boost(
|
||||||
model=model,
|
model=model,
|
||||||
optimizer=optim,
|
optimizer=optim,
|
||||||
lr_scheduler=lr_scheduler,
|
lr_scheduler=lr_scheduler,
|
||||||
dataloader=train_dataloader,
|
dataloader=train_dataloader,
|
||||||
)
|
)
|
||||||
if ref_model is not None:
|
ref_model, _, _, _, _ = ref_booster.boost(model=ref_model)
|
||||||
ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_dataloader)
|
|
||||||
torch.set_default_dtype(torch.float)
|
torch.set_default_dtype(torch.float)
|
||||||
|
|
||||||
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
|
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
|
||||||
|
@ -312,7 +310,7 @@ if __name__ == "__main__":
|
||||||
"--plugin",
|
"--plugin",
|
||||||
type=str,
|
type=str,
|
||||||
default="gemini",
|
default="gemini",
|
||||||
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
|
choices=["gemini", "zero2", "zero2_cpu", "3d", "ddp"],
|
||||||
help="Choose which plugin to use",
|
help="Choose which plugin to use",
|
||||||
)
|
)
|
||||||
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
|
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
|
||||||
|
@ -342,22 +340,35 @@ if __name__ == "__main__":
|
||||||
parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
|
parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
|
||||||
parser.add_argument("--max_epochs", type=int, default=3)
|
parser.add_argument("--max_epochs", type=int, default=3)
|
||||||
parser.add_argument("--batch_size", type=int, default=4)
|
parser.add_argument("--batch_size", type=int, default=4)
|
||||||
|
parser.add_argument("--disable_loss_mask", default=False, action="store_true")
|
||||||
|
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
|
||||||
|
parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
|
||||||
|
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
|
||||||
|
parser.add_argument("--lr", type=float, default=5e-6)
|
||||||
|
parser.add_argument("--accumulation_steps", type=int, default=1)
|
||||||
|
parser.add_argument("--log_dir", default=None, type=str)
|
||||||
|
parser.add_argument("--use_wandb", default=False, action="store_true")
|
||||||
|
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
|
||||||
|
parser.add_argument("--use_flash_attn", default=False, action="store_true")
|
||||||
|
parser.add_argument(
|
||||||
|
"--microbatch_size",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="Micro batch size for PP training. To activate PP training for DPO-like algorithm, you must keep size even and the size should be equal or greater than 2.",
|
||||||
|
)
|
||||||
|
# Parameter for reference model
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--disable_reference_model",
|
"--disable_reference_model",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
default=False,
|
default=False,
|
||||||
help="Disable the reference model (enabled by default)",
|
help="Disable the reference model (enabled by default)",
|
||||||
)
|
)
|
||||||
parser.add_argument("--disable_loss_mask", default=False, action="store_true")
|
parser.add_argument(
|
||||||
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
|
"--ref_tp",
|
||||||
parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
|
type=int,
|
||||||
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
|
default=1,
|
||||||
parser.add_argument("--lr", type=float, default=5e-6)
|
help="TP size for reference model; used only when reference model is too large.",
|
||||||
parser.add_argument("--accumulation_steps", type=int, default=8)
|
)
|
||||||
parser.add_argument("--log_dir", default=None, type=str)
|
|
||||||
parser.add_argument("--use_wandb", default=False, action="store_true")
|
|
||||||
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
|
|
||||||
parser.add_argument("--use_flash_attn", default=False, action="store_true")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# fool proof hyperparameter setup
|
# fool proof hyperparameter setup
|
||||||
|
|
|
@@ -68,7 +68,7 @@ def train(args):
         Default torch ddp plugin without any acceleration, for
         debugging purpose acceleration, for debugging purpose
         """
-        plugin = TorchDDPPlugin(find_unused_parameters=True if args.grad_checkpoint is False else False)
+        plugin = TorchDDPPlugin(find_unused_parameters=not args.grad_checkpoint)
     elif args.plugin == "gemini":
         plugin = GeminiPlugin(
             precision=args.mixed_precision,
@@ -4,7 +4,7 @@ BASE_TEMP_DIR=$BASE_DIR/temp
 EXAMPLES_DIR=$BASE_DIR/examples
 TEST_DATA_DIR=$BASE_DIR/tests/test_data
 DATA_SAVE_PATH=$BASE_TEMP_DIR/tests
-CONFIG_DIR=$BASE_DIR/config
+CONFIG_DIR=$BASE_DIR/conversation_template

 # MODELS=("colossal-llama2" "llama2" "mistral" "chatGLM2" "chatGLM3" "deepseek" "Yi" "baichuan") # for local test
 MODELS=("colossal-llama2" "llama2" "chatGLM2" "chatGLM3" "deepseek" "Yi")
@@ -39,23 +39,23 @@ get_pretrain() {
 get_conversation_template_config() {
     local model=$1
     if [[ $model == "colossal-llama2" ]]; then
-        echo "$CONFIG_DIR/conversation_template/colossal-llama2.json"
+        echo "$CONFIG_DIR/colossal-llama2.json"
     elif [[ $model == "llama2" ]]; then
-        echo "$CONFIG_DIR/conversation_template/llama2.json"
+        echo "$CONFIG_DIR/llama2.json"
     elif [[ $model == "deepseek" ]]; then
-        echo "$CONFIG_DIR/conversation_template/deepseek-ai_DeepSeek-V2-Lite.json"
+        echo "$CONFIG_DIR/deepseek-ai_DeepSeek-V2-Lite.json"
     elif [[ $model == "mistral" ]]; then
-        echo "$CONFIG_DIR/conversation_template/mistralai_Mixtral-8x7B-Instruct-v0.1.json"
+        echo "$CONFIG_DIR/mistralai_Mixtral-8x7B-Instruct-v0.1.json"
     elif [[ $model == "chatGLM2" ]]; then
-        echo "$CONFIG_DIR/conversation_template/THUDM_chatglm2-6b.json"
+        echo "$CONFIG_DIR/THUDM_chatglm2-6b.json"
     elif [[ $model == "chatGLM3" ]]; then
-        echo "$CONFIG_DIR/conversation_template/THUDM_chatglm3-6b.json"
+        echo "$CONFIG_DIR/THUDM_chatglm3-6b.json"
     elif [[ $model == "phi" ]]; then
-        echo "$CONFIG_DIR/conversation_template/microsoft_phi-2.json"
+        echo "$CONFIG_DIR/microsoft_phi-2.json"
     elif [[ $model == "Yi" ]]; then
-        echo "$CONFIG_DIR/conversation_template/01-ai_Yi-1.5-9B-Chat.json"
+        echo "$CONFIG_DIR/01-ai_Yi-1.5-9B-Chat.json"
     elif [[ $model == "baichuan" ]]; then
-        echo "$CONFIG_DIR/conversation_template/baichuan-inc_Baichuan2-13B-Chat.json"
+        echo "$CONFIG_DIR/baichuan-inc_Baichuan2-13B-Chat.json"
     else
         echo "Unknown model $model"
         exit 1
@@ -71,6 +71,7 @@ for model in ${MODELS[@]}; do
     rm -rf $SAVE_DIR/arrow
     pretrain=$(get_pretrain $model)
     conversation_template_config=$(get_conversation_template_config $model)
+    echo $conversation_template_config
     python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py --type sft --data_input_dirs $TEST_DATA_DIR/sft \
         --tokenizer_dir $pretrain \
         --conversation_template_config $conversation_template_config \
@@ -279,4 +279,4 @@ class CudaAccelerator(BaseAccelerator):
         """
         Return autocast function
         """
-        return torch.cuda.amp.autocast(enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
+        return torch.amp.autocast(device_type="cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
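The hunk above swaps the deprecated `torch.cuda.amp.autocast` entry point for the unified `torch.amp.autocast(device_type=..., ...)` API. A small usage sketch in plain PyTorch, assuming a CUDA device is available:

```python
import torch

model = torch.nn.Linear(8, 8).cuda()
x = torch.randn(4, 8, device="cuda")

# New-style autocast: the device type is passed explicitly instead of
# being implied by the torch.cuda.amp namespace.
with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
    y = model(x)
```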
@@ -322,7 +322,6 @@ class GeminiPlugin(DPPluginBase):
         enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
         enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
         enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
-        enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
         use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
         verbose (bool, optional): verbose mode. Debug info including chunk search result will be printed. Defaults to False.
         fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
@@ -366,7 +365,6 @@ class GeminiPlugin(DPPluginBase):
         enable_flash_attention: bool = False,
         enable_sequence_parallelism: bool = False,
         enable_jit_fused: bool = False,
-        enable_sequence_overlap: bool = False,
         enable_async_reduce: bool = True,
         use_fp8: bool = False,
         verbose: bool = False,
@@ -428,7 +426,6 @@ class GeminiPlugin(DPPluginBase):
         self.enable_flash_attention = enable_flash_attention
         self.enable_sequence_parallelism = enable_sequence_parallelism if self.enable_tensor_parallelism else False
         self.enable_jit_fused = enable_jit_fused
-        self.enable_sequence_overlap = enable_sequence_overlap
         self.verbose = verbose

         self.tp_size = tp_size
@@ -455,7 +452,6 @@ class GeminiPlugin(DPPluginBase):
             enable_flash_attention=self.enable_flash_attention,
             enable_jit_fused=self.enable_jit_fused,
             enable_sequence_parallelism=self.enable_sequence_parallelism,
-            enable_sequence_overlap=self.enable_sequence_overlap,
         )

     def __del__(self):
@@ -116,10 +116,15 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):

         super().__init__(module)
         self.op_hooks = []
+        if use_fp8:
+            self.op_hooks.append(FP8Hook())
+        self.op_hooks = []
         if use_fp8:
             self.op_hooks.append(FP8Hook())
         if overlap_allgather:
             self.op_hooks.append(ZeroOpHook())
+        if use_fp8 or overlap_allgather:
+            self.op_hooks.append(ZeroOpHook())
         if use_fp8 or overlap_allgather:
             for p in module.parameters():
                 if p.requires_grad and type(p) is not ColoParameter:
@@ -232,6 +237,9 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
     def _hook_context(self):
         return ColoParamOpHookManager.use_hooks(*self.op_hooks) if len(self.op_hooks) > 0 else nullcontext()

+    def _hook_context(self):
+        return ColoParamOpHookManager.use_hooks(*self.op_hooks) if len(self.op_hooks) > 0 else nullcontext()
+

 def get_param_info(optim: Optimizer):
     # Get a backup of necessary information of parameters for future use, which includes:
@@ -951,7 +959,6 @@ class HybridParallelPlugin(PipelinePluginBase):
         enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
         enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
         sequence_parallelism_mode (str): The Sequence parallelism mode. Can only be choosed from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather".
-        enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
         parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True.
         num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
         microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
@@ -983,6 +990,8 @@ class HybridParallelPlugin(PipelinePluginBase):
         make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose an faster kenel. Default to 64.
         fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
         use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
+        fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
+        use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
         overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism
         inner_ring_size (int, optional): The inner ring size of 2D Ring Attention when sp mode is "ring_attn".
             It's advisable to not tune this (especially in single-node settings) and let it be heuristically set based on topology by default.
@@ -1002,7 +1011,6 @@ class HybridParallelPlugin(PipelinePluginBase):
         enable_jit_fused: bool = False,
         enable_sequence_parallelism: bool = False,
         sequence_parallelism_mode: str = None,
-        enable_sequence_overlap: bool = False,
         parallel_output: bool = True,
         num_microbatches: Optional[int] = None,
         microbatch_size: Optional[int] = None,
@@ -1092,6 +1100,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.use_fp8 = use_fp8
         if dp_outside:
             self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
+            self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
             if sequence_parallelism_mode == "ring_attn":
                 # Swap tp and sp since 2D Ring has better inter-node latency
                 self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.sp_size, self.tp_size)
@@ -1195,13 +1204,15 @@ class HybridParallelPlugin(PipelinePluginBase):
             enable_jit_fused=self.enable_jit_fused,
             enable_sequence_parallelism=enable_sequence_parallelism,
             sequence_parallelism_mode=sequence_parallelism_mode,
-            enable_sequence_overlap=enable_sequence_overlap,
             parallel_output=parallel_output,
             make_vocab_size_divisible_by=make_vocab_size_divisible_by,
             gradient_checkpoint_config=gradient_checkpoint_config,
             fp8_communication=fp8_communication,
             inner_ring_size=inner_ring_size,
+            pg_mesh=self.pg_mesh,
+            sp_axis=self.sp_axis,
         )

         self.amp_config = dict(
             initial_scale=initial_scale,
             growth_factor=growth_factor,
@@ -1293,6 +1304,7 @@ class HybridParallelPlugin(PipelinePluginBase):
             self.dp_size == 1 and self.pp_size == 1
         )
         # sync gradients across DP * SP ranks
+        # sync gradients across DP * SP ranks
         # Apply Hybrid ZeRO across DP * SP ranks
         if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
             dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
@@ -290,7 +290,11 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
+        return peft_model.save_pretrained(
+            checkpoint,
+            safe_serialization=use_safetensors,
+            state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+        )


 class LowLevelZeroPlugin(DPPluginBase):
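The change above materializes a plain-tensor `state_dict` before handing it to PEFT's `save_pretrained`, so wrapped parameter types used by the ZeRO plugins do not leak into the checkpoint. A hedged sketch of the same idea outside the plugin:

```python
import torch
from torch.utils._pytree import tree_map


def plain_state_dict(module: torch.nn.Module) -> dict:
    # Replace every tensor leaf by its underlying .data so custom parameter
    # subclasses serialize as ordinary tensors; non-tensor leaves pass through.
    return tree_map(lambda x: x.data if torch.is_tensor(x) else x, module.state_dict())
```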
@@ -141,7 +141,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
         enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
         sequence_parallelism_mode (str): The Sequence parallelism mode. Can only be choosed from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather".
-        enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
         parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True.
         num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
         microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
@@ -190,7 +189,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         enable_jit_fused: bool = False,
         enable_sequence_parallelism: bool = False,
         sequence_parallelism_mode: str = None,
-        enable_sequence_overlap: bool = False,
         parallel_output: bool = True,
         num_microbatches: Optional[int] = None,
         microbatch_size: Optional[int] = None,
@@ -368,7 +366,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             enable_jit_fused=self.enable_jit_fused,
             enable_sequence_parallelism=enable_sequence_parallelism,
             sequence_parallelism_mode=sequence_parallelism_mode,
-            enable_sequence_overlap=enable_sequence_overlap,
             parallel_output=parallel_output,
             make_vocab_size_divisible_by=make_vocab_size_divisible_by,
             gradient_checkpoint_config=gradient_checkpoint_config,
@@ -1,9 +1,11 @@
 from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union

+import torch
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils._pytree import tree_map
 from torch.utils.data import DataLoader

 from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
@@ -134,7 +136,11 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        peft_model.save_pretrained(save_directory=checkpoint, safe_serialization=use_safetensors)
+        return peft_model.save_pretrained(
+            checkpoint,
+            safe_serialization=use_safetensors,
+            state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+        )


 class TorchDDPModel(ModelWrapper):
@@ -11,6 +11,7 @@ import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed import ProcessGroup
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils._pytree import tree_map

 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
@@ -20,7 +21,7 @@ from colossalai.tensor.padded_tensor import (
     to_padded_tensor,
     to_unpadded_tensor,
 )
-from colossalai.utils import get_current_device
+from colossalai.utils import get_current_device, get_non_persistent_buffers_set

 from .general_checkpoint_io import GeneralCheckpointIO
 from .index_file import CheckpointIndexFile
@@ -104,8 +105,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                 yield block, block_size

         # Save buffers.
+        non_persist_buffers_set = get_non_persistent_buffers_set(model)
         for name, buf in model.named_buffers():
-            if buf is not None and name not in model._non_persistent_buffers_set:
+            if buf is not None and name not in non_persist_buffers_set:
                 buffer = buf if keep_vars else buf.detach()
                 block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
                 if block is not None:
@@ -351,9 +353,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                 _load(name)

         # Load buffers.
-        non_persistent_buffers = set()
-        for n, m in model.named_modules():
-            non_persistent_buffers |= set(".".join((n, b)) for b in m._non_persistent_buffers_set)
+        non_persistent_buffers = get_non_persistent_buffers_set(model)
         for name, buf in model.named_buffers():
             if buf is not None and name not in non_persistent_buffers:
                 _load(name)
@@ -956,4 +956,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
+        return peft_model.save_pretrained(
+            checkpoint,
+            safe_serialization=use_safetensors,
+            state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+        )
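`get_non_persistent_buffers_set` replaces the ad-hoc walk over each module's `_non_persistent_buffers_set`; conceptually it collects fully qualified names of buffers registered with `persistent=False`. A sketch of what such a helper can look like (the real utility lives in `colossalai.utils`; this standalone version is an assumption for illustration):

```python
import torch.nn as nn


def get_non_persistent_buffers_set(model: nn.Module) -> set:
    names = set()
    for mod_name, module in model.named_modules():
        # _non_persistent_buffers_set holds the local names of buffers
        # registered with persistent=False on this submodule.
        for buf_name in getattr(module, "_non_persistent_buffers_set", set()):
            names.add(f"{mod_name}.{buf_name}" if mod_name else buf_name)
    return names
```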
@@ -57,7 +57,9 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM):
                     target_module=NopadBaichuanMLP,
                 ),
                 SubModuleReplacementDescription(
-                    suffix="self_attn.W_pack", target_module=FusedLinear1D_Col, kwargs={"n_fused": 3}
+                    suffix="self_attn.W_pack",
+                    target_module=FusedLinear1D_Col,
+                    kwargs={"split_sizes": [self.model.config.hidden_size] * 3},
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.o_proj",
@@ -1,7 +1,6 @@
 import torch

 from colossalai.accelerator import get_accelerator
-from colossalai.legacy.nn.layer.colossalai_layer import Embedding, Linear

 from .bias_dropout_add import bias_dropout_add_fused_train
 from .bias_gelu import bias_gelu_impl
@@ -45,6 +44,7 @@ def warmup_jit_fusion(
     dtype: torch.dtype = torch.float32,
 ):
     """Compile JIT functions before the main training steps"""
+    from colossalai.legacy.nn.layer.colossalai_layer import Embedding, Linear

     embed = Embedding(vocab_size, hidden_size).to(get_accelerator().get_current_device())
     linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_accelerator().get_current_device())
@@ -3,8 +3,9 @@ from typing import Any, List, Optional, Tuple

 import torch
 import torch.cuda
+from packaging.version import Version
 from torch.nn import Module
-from torch.utils._pytree import SUPPORTED_NODES, TreeSpec, _register_pytree_node, tree_flatten, tree_map, tree_unflatten
+from torch.utils._pytree import SUPPORTED_NODES, TreeSpec, tree_flatten, tree_map, tree_unflatten


 # this register are for torch under version 1.13.1, maybe removed in the future
@@ -16,7 +17,12 @@ def _odict_unflatten(values: List[Any], context: Any) -> "OrderedDict[Any, Any]"
     return OrderedDict((key, value) for key, value in zip(context, values))


-_register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)
+if Version(torch.__version__) <= Version("1.13.1"):
+    try:
+        from torch.utils._pytree import register_pytree_node as _register_pytree_node
+    except ImportError:
+        from torch.utils._pytree import _register_pytree_node
+    _register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)


 def tree_map_hf(fn: Any, pytree: Any):
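The guard above only registers the `OrderedDict` flatten/unflatten pair on old PyTorch builds, and tolerates the rename of the registration helper between releases. The same defensive pattern in isolation, as a sketch (adjust the version bound to your needs; on newer PyTorch the block is simply skipped):

```python
from collections import OrderedDict

import torch
from packaging.version import Version

if Version(torch.__version__) <= Version("1.13.1"):
    try:
        # Some 1.13.x builds expose the public name instead of the private one.
        from torch.utils._pytree import register_pytree_node as _register_pytree_node
    except ImportError:
        from torch.utils._pytree import _register_pytree_node

    def _odict_flatten(d):
        return list(d.values()), list(d.keys())

    def _odict_unflatten(values, context):
        return OrderedDict(zip(context, values))

    _register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)
```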
@@ -351,15 +351,16 @@ class InterleavedSchedule(PipelineSchedule):
             if output_obj_grad is None:
                 optimizer.backward(output_obj)
             else:
-                if "backward_tensor_keys" not in output_obj:
-                    for k, grad in output_obj_grad.items():
-                        optimizer.backward_by_grad(output_obj[k], grad)
+                keys = output_obj.get("backward_tensor_keys", output_obj_grad.keys())
+                tensors_to_backward = []
+                grads_to_backward = []
+                for k in keys:
+                    tensors_to_backward.append(output_obj[k])
+                    grads_to_backward.append(output_obj_grad[k])
+                if len(tensors_to_backward) == 1:
+                    optimizer.backward_by_grad(tensors_to_backward[0], grads_to_backward[0])
                 else:
-                    for k, grad in output_obj_grad.items():
-                        output_obj[k].grad = grad
-                    for k in output_obj["backward_tensor_keys"]:
-                        tensor_to_backward = output_obj[k]
-                        optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad)
+                    optimizer.backward_by_grad(tensors_to_backward, grads_to_backward)

         # Collect the grad of the input_obj.
         input_obj_grad = None
@@ -305,15 +305,16 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
             if output_obj_grad is None:
                 optimizer.backward(output_obj)
             else:
-                if "backward_tensor_keys" not in output_obj:
-                    for k, grad in output_obj_grad.items():
-                        optimizer.backward_by_grad(output_obj[k], grad)
+                keys = output_obj.get("backward_tensor_keys", output_obj_grad.keys())
+                tensors_to_backward = []
+                grads_to_backward = []
+                for k in keys:
+                    tensors_to_backward.append(output_obj[k])
+                    grads_to_backward.append(output_obj_grad[k])
+                if len(tensors_to_backward) == 1:
+                    optimizer.backward_by_grad(tensors_to_backward[0], grads_to_backward[0])
                 else:
-                    for k, grad in output_obj_grad.items():
-                        output_obj[k].grad = grad
-                    for k in output_obj["backward_tensor_keys"]:
-                        tensor_to_backward = output_obj[k]
-                        optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad)
+                    optimizer.backward_by_grad(tensors_to_backward, grads_to_backward)

         # Collect the grad of the input_obj.
         input_obj_grad = None
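Both schedules now gather the output tensors and their incoming gradients into lists and issue a single `backward_by_grad` call, instead of setting `.grad` attributes and backpropagating key by key. A minimal sketch of the selection logic, assuming an optimizer wrapper whose `backward_by_grad` accepts either a tensor or a list of tensors (as in the diff above):

```python
def backward_step(optimizer, output_obj: dict, output_obj_grad: dict) -> None:
    # Fall back to every gradient key when the stage did not declare an
    # explicit "backward_tensor_keys" entry.
    keys = output_obj.get("backward_tensor_keys", output_obj_grad.keys())
    tensors = [output_obj[k] for k in keys]
    grads = [output_obj_grad[k] for k in keys]
    if len(tensors) == 1:
        optimizer.backward_by_grad(tensors[0], grads[0])
    else:
        optimizer.backward_by_grad(tensors, grads)
```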
@@ -8,6 +8,8 @@ import torch.nn.functional as F
 from packaging.version import Version
 from torch.distributed import ReduceOp

+from .fp8_config import dynamic_kernel
+
 SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.4.0")
 SCALE_BYTES = 4
 try:
@@ -832,11 +834,13 @@ class _LinearFp8(torch.autograd.Function):
         return x_grad.reshape(ctx.x_shape), w_grad, bias_grad


-@torch.compile(mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=False)
+@torch.compile(mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=dynamic_kernel)
 def _linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
     return _LinearFp8.apply(input, weight, bias)


 def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    if input.shape[-1] % 16 != 0 or np.prod(input.shape[:-1]) % 16 != 0:
+        return F.linear(input, weight, bias)
     out = _linear_fp8(input, weight, bias)
     return out
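The guard added to `linear_fp8` falls back to a plain `F.linear` whenever the flattened input shape is not a multiple of 16, since FP8 GEMM kernels require 16-aligned dimensions. The shape check in isolation, as a sketch with a hypothetical wrapper name (`fp8_linear` stands in for the compiled FP8 path):

```python
import numpy as np
import torch
import torch.nn.functional as F


def linear_maybe_fp8(x: torch.Tensor, weight: torch.Tensor, bias=None, fp8_linear=None):
    # FP8 GEMMs need both the reduction dim and the flattened batch dim to be
    # divisible by 16; otherwise fall back to the regular fp16/bf16 matmul.
    if x.shape[-1] % 16 != 0 or np.prod(x.shape[:-1]) % 16 != 0:
        return F.linear(x, weight, bias)
    return fp8_linear(x, weight, bias)
```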
@@ -0,0 +1 @@
+dynamic_kernel: bool = False
@@ -6,7 +6,7 @@ from .linear import Linear1D_Col, Linear1D_Row, LinearWithGradAccum, PaddingLMHe
 from .loss import cross_entropy_1d, dist_cross_entropy
 from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
 from .parallel_module import ParallelModule
-from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
+from .qkv_fused_linear import FusedLinear1D_Col, FusedLinear1D_Row, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row

 __all__ = [
     "Embedding1D",
@@ -35,4 +35,5 @@ __all__ = [
     "RingAttention",
     "get_pad_info",
     "all_to_all_comm",
+    "FusedLinear1D_Row",
 ]
@@ -106,7 +106,7 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
             grad_output = grad_output.view(-1, grad_output.shape[-1])
             total_input = total_input.view(-1, total_input.shape[-1])

-        if ctx.async_grad_allreduce and fp8_communication:
+        if fp8_communication or not ctx.async_grad_allreduce:
             _reduce(grad_input, group=ctx.process_group, fp8_communication=fp8_communication, fp8_format="e5m2")
         elif ctx.async_grad_allreduce:
             # Asynchronous all-reduce
@@ -364,10 +364,12 @@ def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=
     for k in recv_tensors:
         send_tensors[k], recv_tensors[k] = recv_tensors[k], send_tensors[k]

+    input_tensors = []
     output_tensors = []

     handles = communicate_step()
     # first round: special case, retrive from local tensor
+    input_tensors.append(input_to_gather)
     output_tensors.append(func(**input_to_gather, **input_local))
     for i in range(group_size - 2):
         for handle in handles:
@@ -378,14 +380,25 @@ def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=
         handles = communicate_step()

         # actual computation
+        input_tensors.append(send_tensors)
         output_tensors.append(func(**send_tensors, **input_local))

     # final round: special case, no need to send/recv again
     for handle in handles:
         handle.wait()
+    input_tensors.append(send_tensors)
     output_tensors.append(func(**recv_tensors, **input_local))

-    return torch.cat(output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=gather_dim)
+    gathered_input = {}
+    for k in input_to_gather:
+        input_shards = [d[k] for d in input_tensors[group_size - cur_rank :] + input_tensors[: group_size - cur_rank]]
+        gathered_input[k] = torch.cat(input_shards, dim=gather_dim)
+
+    gathered_output = torch.cat(
+        output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=gather_dim
+    )
+
+    return gathered_output, gathered_input


 class _GatherForwardReduceScatterBackward(torch.autograd.Function):
@@ -441,29 +454,30 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True, ring=False):
+    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False):
         ctx.save_for_backward(input_, weight, bias)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
         ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
         ctx.dim = dim
-        ctx.overlap = overlap

         if ring is True:
             input_to_gather = {"input": input_}
             input_local = {"weight": weight}

-            output = _ring_as_gather(
+            output, input_dict = _ring_as_gather(
                 F.linear,
                 input_to_gather=input_to_gather,
                 input_local=input_local,
                 process_group=process_group,
             )
+            ctx.gathered_input = input_dict["input"]

             if bias is not None:
                 output += bias
         else:
             input_parallel = _gather(input_, dim, process_group)
+            ctx.gathered_input = input_parallel
             if bias is not None:
                 output = F.linear(input_parallel, weight, bias)
             else:
@ -477,100 +491,50 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
||||||
use_bias = ctx.use_bias
|
use_bias = ctx.use_bias
|
||||||
dim = ctx.dim
|
dim = ctx.dim
|
||||||
process_group = ctx.process_group
|
process_group = ctx.process_group
|
||||||
overlap = ctx.overlap
|
|
||||||
|
|
||||||
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
|
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
|
||||||
if use_bias:
|
if use_bias:
|
||||||
bias = bias.view(bias.shape)
|
bias = bias.view(bias.shape)
|
||||||
|
|
-        if not overlap:
-            input_parallel = _gather(input_, dim, process_group)
+        input_parallel = ctx.gathered_input

         total_input = input_parallel
         grad_input = grad_output.matmul(weight)
         grad_output = grad_output.contiguous()
         # Convert the tensor shapes to 2D for execution compatibility
         if len(grad_output.shape) > 2:
             grad_output = grad_output.view(-1, grad_output.shape[-1])
             total_input = total_input.view(-1, total_input.shape[-1])

         if ctx.async_grad_reduce_scatter:
             # Asynchronous reduce-scatter
-            input_list = [
-                item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
-            ]
-            output = torch.empty(
-                input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
-            ).contiguous()
-            handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
-            # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
-            # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
-
-            if _grad_accum_fusion_available and weight.grad is not None:
-                grad = weight.grad
-                if grad.dtype == torch.float32:
-                    fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
-                    grad_weight = None
-                elif grad.dtype == torch.float16:
-                    fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
-                    grad_weight = None
-                else:
-                    grad_weight = grad_output.t().matmul(total_input)
-            else:
-                grad_weight = grad_output.t().matmul(total_input)
-
-            grad_bias = grad_output.sum(dim=0) if use_bias else None
-
-            if ctx.async_grad_reduce_scatter:
-                handle.wait()
-
-        else:
-            input_ = input_.contiguous()
-            world_size = dist.get_world_size(process_group)
-            tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
-
-            # do all gather in is async way
-            gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
-            # calculate gradient and prepare data asynchronously with all-gather
-            # calculate
-            grad_input = grad_output.matmul(weight)
-            grad_output = grad_output.contiguous()
-            # Convert the tensor shapes to 2D for execution compatibility
-            if len(grad_output.shape) > 2:
-                grad_output = grad_output.view(-1, grad_output.shape[-1])
-            grad_bias = grad_output.sum(dim=0) if use_bias else None
-            # prepare data
             input_list = [
                 item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
             ]
-            output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
-            # wait until all-gather finished
-            gather_handle.wait()
-
-            # do reduce-scatter in async way
-            reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
-            input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
-            # calculate gradient
-            if len(input_parallel.shape) > 2:
-                input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
-
-            if _grad_accum_fusion_available and weight.grad is not None:
-                grad = weight.grad
-                if grad.dtype == torch.float32:
-                    fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(input_parallel, grad_output, grad)
-                    grad_weight = None
-                elif grad.dtype == torch.float16:
-                    fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(input_parallel, grad_output, grad)
-                    grad_weight = None
-                else:
-                    grad_weight = grad_output.t().matmul(input_parallel)
-            else:
-                grad_weight = grad_output.t().matmul(input_parallel)
-                # grad_weight = grad_output.t().matmul(input_parallel)
-            # wait until reduce-scatter finished
-            reducescatter_handle.wait()
-
-        return output, grad_weight, grad_bias, None, None, None, None, None
+            output = torch.empty(input_.shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous()
+            handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
+            # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
+            # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
+
+        if _grad_accum_fusion_available and weight.grad is not None:
+            grad = weight.grad
+            if grad.dtype == torch.float32:
+                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
+                grad_weight = None
+            elif grad.dtype == torch.float16:
+                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
+                grad_weight = None
+            else:
+                grad_weight = grad_output.t().matmul(total_input)
+        else:
+            grad_weight = grad_output.t().matmul(total_input)
+
+        grad_bias = grad_output.sum(dim=0) if use_bias else None
+
+        if ctx.async_grad_reduce_scatter:
+            handle.wait()
+
+        return output, grad_weight, grad_bias, None, None, None, None


 def _ring_as_reducescatter(
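The backward above no longer re-gathers the sharded input; it reads the copy that the forward pass stashed on `ctx`. As a rough illustration of that caching pattern (a minimal sketch with made-up names, not the ColossalAI implementation, and assuming an already-initialized process group), the trade-off is one extra gathered activation kept alive versus one fewer all-gather in backward:

```python
import torch
import torch.distributed as dist


class _GatherLinearCacheInput(torch.autograd.Function):
    """Toy column-style linear: all-gather the sequence-sharded input in forward,
    cache the gathered copy on ctx, and reuse it in backward instead of gathering again."""

    @staticmethod
    def forward(ctx, input_, weight, process_group, dim):
        world_size = dist.get_world_size(process_group)
        buffers = [torch.empty_like(input_) for _ in range(world_size)]
        dist.all_gather(buffers, input_.contiguous(), group=process_group)
        gathered = torch.cat(buffers, dim=dim)

        ctx.save_for_backward(weight)
        ctx.gathered_input = gathered  # costs memory, saves one all-gather in backward
        ctx.process_group = process_group
        ctx.dim = dim
        return torch.matmul(gathered, weight.t())

    @staticmethod
    def backward(ctx, grad_output):
        (weight,) = ctx.saved_tensors
        gathered = ctx.gathered_input  # no second all-gather here
        grad_gathered = grad_output.matmul(weight)

        # reduce-scatter the input gradient back to the per-rank sequence shards
        chunks = [
            c.contiguous()
            for c in torch.chunk(grad_gathered, dist.get_world_size(ctx.process_group), dim=ctx.dim)
        ]
        grad_input = torch.empty_like(chunks[0])
        dist.reduce_scatter(grad_input, chunks, group=ctx.process_group)

        grad_weight = grad_output.reshape(-1, grad_output.shape[-1]).t().matmul(
            gathered.reshape(-1, gathered.shape[-1])
        )
        return grad_input, grad_weight, None, None
```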
@@ -701,7 +665,7 @@ class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
         # Convert the tensor shapes to 2D for execution compatibility
         if len(grad_output.shape) > 2:
             grad_output = grad_output.view(-1, grad_output.shape[-1])
-            total_input = total_input.view(-1, total_input.shape[-1])
+            total_input = total_input.reshape(-1, total_input.shape[-1])
         grad_weight = grad_output.t().matmul(total_input)
         grad_bias = grad_output.sum(dim=0) if use_bias else None

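The one-line change above swaps `.view` for `.reshape` because `view` only works when the requested shape is compatible with the tensor's memory layout, while `reshape` falls back to a copy for non-contiguous inputs. A tiny illustration:

```python
import torch

x = torch.randn(2, 3, 4).transpose(0, 1)  # non-contiguous: strides no longer match the shape
# x.view(-1, 4)                           # RuntimeError: view size is not compatible with input tensor's size and stride
y = x.reshape(-1, 4)                      # ok: reshape silently copies when a zero-copy view is impossible
```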
@@ -759,34 +723,30 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
     """

     @staticmethod
-    def forward(
-        ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication
-    ):
+    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication):
         ctx.save_for_backward(input_, weight, bias)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
         ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
         ctx.dim = dim
-        ctx.overlap = overlap
         ctx.fp8_communication = fp8_communication

         if ring is True:
-            input_to_gather = {}
-            input_local = {}
-            input_to_gather["input"] = input_
-            input_local["other"] = weight
+            input_to_gather = {"input": input_}
+            input_local = {"other": weight}

-            output = _ring_as_gather(
+            output, input_dict = _ring_as_gather(
                 torch.matmul,
                 input_to_gather=input_to_gather,
                 input_local=input_local,
                 process_group=process_group,
                 gather_dim=dim,
             )
+            ctx.gathered_input = input_dict["input"]

         else:
             input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e4m3")
+            ctx.gathered_input = input_parallel
             output = torch.matmul(input_parallel, weight)

         if bias is not None:
@@ -799,76 +759,39 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
         use_bias = ctx.use_bias
         dim = ctx.dim
         process_group = ctx.process_group
-        overlap = ctx.overlap
-        fp8_communication = ctx.fp8_communication

         # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
         weight = weight.view(weight.shape)
         if use_bias:
             bias = bias.view(bias.shape)

-        if not overlap:
-            input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e5m2")
+        input_parallel = ctx.gathered_input

         total_input = input_parallel
         grad_input = grad_output.matmul(weight.T)
         grad_output = grad_output.contiguous()
         # Convert the tensor shapes to 2D for execution compatibility
         if len(grad_output.shape) > 2:
             grad_output = grad_output.view(-1, grad_output.shape[-1])
             total_input = total_input.view(-1, total_input.shape[-1])

         if ctx.async_grad_reduce_scatter:
             # Asynchronous reduce-scatter
-            input_list = [
-                item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
-            ]
-            output = torch.empty(
-                input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
-            ).contiguous()
-            handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
-            # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
-            # all-reduce scheduled first and have GPU resources allocated
-
-            grad_weight = total_input.t().matmul(grad_output)
-            grad_bias = grad_output.sum(dim=0) if use_bias else None
-
-            if ctx.async_grad_reduce_scatter:
-                handle.wait()
-
-        else:
-            world_size = dist.get_world_size(process_group)
-            tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
-
-            # do all gather in is async way
-            gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
-            # calculate gradient and prepare data asynchronously with all-gather
-            # calculate
-            grad_input = grad_output.matmul(weight.T)
-            grad_output = grad_output.contiguous()
-            # Convert the tensor shapes to 2D for execution compatibility
-            if len(grad_output.shape) > 2:
-                grad_output = grad_output.view(-1, grad_output.shape[-1])
-            grad_bias = grad_output.sum(dim=0) if use_bias else None
-            # prepare data
             input_list = [
                 item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
             ]
-            output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
-            # wait until all-gather finished
-            gather_handle.wait()
-
-            # do reduce-scatter in async way
-            reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
-            input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
-            # calculate gradient
-            if len(input_parallel.shape) > 2:
-                input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
-            grad_weight = input_parallel.t().matmul(grad_output)
-            # wait until reduce-scatter finished
-            reducescatter_handle.wait()
-
-        return output, grad_weight, grad_bias, None, None, None, None, None, None
+            output = torch.empty(input_.shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous()
+            handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
+            # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
+            # all-reduce scheduled first and have GPU resources allocated
+
+        grad_weight = total_input.t().matmul(grad_output)
+        grad_bias = grad_output.sum(dim=0) if use_bias else None
+
+        if ctx.async_grad_reduce_scatter:
+            handle.wait()
+
+        return output, grad_weight, grad_bias, None, None, None, None, None

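The async branch above chunks the input gradient, launches `dist.reduce_scatter` with `async_op=True`, computes the weight gradient while the collective is still in flight, and only then waits on the handle. A minimal standalone sketch of that overlap pattern (illustrative names, assuming an already-initialized process group; not the library function itself):

```python
import torch
import torch.distributed as dist


def reduce_scatter_overlapped_wgrad(grad_input, grad_output, total_input, process_group, dim=1):
    """Launch the reduce-scatter of the input gradient asynchronously, overlap the
    weight-gradient GEMM with it, then wait before returning both results."""
    world_size = dist.get_world_size(process_group)
    chunks = [c.contiguous() for c in torch.chunk(grad_input, world_size, dim=dim)]
    sharded_grad_input = torch.empty_like(chunks[0])
    handle = dist.reduce_scatter(sharded_grad_input, chunks, group=process_group, async_op=True)

    # Overlap: this GEMM runs while NCCL reduces and scatters grad_input.
    # (With NCCL, CUDA_DEVICE_MAX_CONNECTIONS=1 helps ensure the collective is scheduled first.)
    grad_weight = grad_output.reshape(-1, grad_output.shape[-1]).t().matmul(
        total_input.reshape(-1, total_input.shape[-1])
    )

    handle.wait()
    return sharded_grad_input, grad_weight
```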
 class _SplitForwardGatherBackward(torch.autograd.Function):

@@ -988,7 +911,7 @@ class _AllToAll(torch.autograd.Function):
         ctx.gather_dim = gather_dim
         ctx.fp8_communication = fp8_communication
         world_size = dist.get_world_size(process_group)
-        bsz, _, _ = input_.shape
+        bsz = input_.shape[0]

         # using all_to_all_single when batch size is 1
         if bsz == 1:
@@ -1019,7 +942,7 @@ class _AllToAll(torch.autograd.Function):
         gather_dim = ctx.scatter_dim
         fp8_communication = ctx.fp8_communication
         world_size = dist.get_world_size(process_group)
-        bsz, _, _ = grad_output.shape
+        bsz = grad_output.shape[0]

         if bsz == 1:
             return_grad = _all_to_all_single(
@@ -1204,10 +1127,10 @@ def linear_with_grad_accum(input_, weight, bias, async_grad_allreduce, use_zbv=F


 def linear_gather_forward_reducescatter_backward(
-    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
+    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False
 ):
     return _LinearWithGatherForwardReduceScatterBackward.apply(
-        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring
+        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring
     )


@@ -1224,10 +1147,10 @@ def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, proc


 def matmul_gather_forward_reducescatter_backward(
-    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False, fp8_communication=False
+    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False, fp8_communication=False
 ):
     return _MatmulWithGatherForwardReduceScatterBackward.apply(
-        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication
+        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication
    )

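With `overlap` dropped from these wrappers, callers pass one fewer positional argument. A hedged usage sketch of the updated signature (shapes, group setup, and the launch method are assumptions and not taken from the PR; it presumes a torchrun launch with NCCL and one GPU per rank):

```python
import os

import torch
import torch.distributed as dist

from colossalai.shardformer.layer._operation import linear_gather_forward_reducescatter_backward

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

sp_group = dist.new_group(list(range(dist.get_world_size())))
x = torch.randn(2, 128, 1024, device="cuda", requires_grad=True)  # (batch, local seq shard, hidden)
w = torch.randn(4096, 1024, device="cuda", requires_grad=True)

# positional args: input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring
out = linear_gather_forward_reducescatter_backward(x, w, None, sp_group, True, 1, ring=False)
out.sum().backward()  # backward reduce-scatters the input gradient back to this rank's shard
```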

@@ -422,16 +422,21 @@ class RingAttention(torch.autograd.Function):
     ATTN_DONE: torch.cuda.Event = None
     SP_STREAM: torch.cuda.Stream = None
     SP_GROUP: dist.ProcessGroup = None
-    # duplicate process group for concurrent NCCL streams
-    # while both PyTorch and NCCL warns(https://github.com/pytorch/pytorch/commit/2dbe5cb979f674f0052a8eea1f7b6c3c0ba441d7)
-    # against this, in practice it seems to work fine.
+    # NOTE: Duplicating PGs for concurrent NCCL streams is a risky hack -- while it may increase throughput,
+    # both PyTorch and NCCL warn against this. (https://github.com/pytorch/pytorch/commit/2dbe5cb979f674f0052a8eea1f7b6c3c0ba441d7)
+    # LoongTrain's original double ring impl. uses concurrent PGs
+    # (https://github.com/InternLM/InternEvo/blob/e52f2ffc9acf818e8f2b1f97dfc69ceb2f06e154/internlm/model/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py#L192)
+    # but I confirmed with Pytorch developers this can cause obscure "Software caused connection abort" errors.
+    # (https://github.com/pytorch/pytorch/issues/132852)
+    # NOTE: In general, a smarter idea is put as many P2P calls as possible into one `batch_isend_irecv`.
     INNER_RING_GROUP: dist.ProcessGroup = None
-    INNER_RING_GROUP_COPY: dist.ProcessGroup = None
+    # INNER_RING_GROUP_COPY: dist.ProcessGroup = None
     INTER_RING_GROUP: dist.ProcessGroup = None
-    INTER_RING_GROUP_COPY: dist.ProcessGroup = None
+    # INTER_RING_GROUP_COPY: dist.ProcessGroup = None

     @staticmethod
-    def get_double_ring_groups(sp_group, inner_ring_size=None):
+    def get_double_ring_groups(sp_axis, pg_mesh, inner_ring_size=None):
         """
         Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size
         shouldn't be larger than the number of NICs on each node.
@@ -441,21 +446,17 @@ class RingAttention(torch.autograd.Function):
         Returns:
             Tuple[dist.ProcessGroup, dist.ProcessGroup]: Inner-ring process group and inter-ring process group.
         """
+        assert pg_mesh is not None, f"Error: The pg mesh is None! please check the process group initialization."

+        sp_group = pg_mesh.get_group_along_axis(sp_axis)
         sp_size = dist.get_world_size(sp_group)
         sp_rank = dist.get_rank(sp_group)

-        if inner_ring_size is None:
-            if torch.cuda.device_count() >= dist.get_world_size():
-                # single node, no need to consider NICs
-                return sp_group, sp_group
-            if sp_size <= 4:
-                inner_ring_size = min(2, sp_size)
-            else:
-                inner_ring_size = min(4, sp_size)
-        else:
-            assert (
-                inner_ring_size <= sp_size and sp_size % inner_ring_size == 0
-            ), f"Error: sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}"
+        assert inner_ring_size is not None
+        assert (
+            inner_ring_size <= sp_size and sp_size % inner_ring_size == 0
+        ), f"Error: sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}"

         if inner_ring_size == sp_size:
             return sp_group, sp_group
@@ -474,14 +475,14 @@ class RingAttention(torch.autograd.Function):
         # Create inner ring groups
         for i in range(inner_ring_size):
             ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size))
-            group = dist.new_group(ranks)
+            group = pg_mesh.get_group_along_axis(sp_axis, ranks)
             if sp_rank in ranks:
                 inner_ring_group = group

         # Create inter ring groups
         for i in range(num_rings):
             ranks = list(range(i, sp_size, num_rings))
-            group = dist.new_group(ranks)
+            group = pg_mesh.get_group_along_axis(sp_axis, ranks)
             if sp_rank in ranks:
                 inter_ring_group = group

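The two loops above build the double ring: inner rings are contiguous blocks of sequence-parallel ranks (ideally confined to one node), and inter rings connect corresponding ranks across those blocks so KV blocks can hop between nodes. A tiny sketch of the intended partition for `sp_size=8`, `inner_ring_size=4`, assuming ranks on a node are contiguous (this only illustrates the idea; it does not reproduce the grouping code above verbatim):

```python
sp_size, inner_ring_size = 8, 4          # e.g. 2 nodes x 4 GPUs per node
num_rings = sp_size // inner_ring_size   # 2 inner rings -> 2-hop outer ring

# Inner rings: contiguous blocks, one per node.
inner_rings = [list(range(r * inner_ring_size, (r + 1) * inner_ring_size)) for r in range(num_rings)]
# Inter rings: the i-th member of every inner ring, linking the rings across nodes.
inter_rings = [list(range(i, sp_size, inner_ring_size)) for i in range(inner_ring_size)]

print(inner_rings)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(inter_rings)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```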
||||||
|
@ -492,7 +493,7 @@ class RingAttention(torch.autograd.Function):
|
||||||
q, # (B, H, Sq, D)
|
q, # (B, H, Sq, D)
|
||||||
k,
|
k,
|
||||||
v,
|
v,
|
||||||
sp_group,
|
sp_axis,
|
||||||
attention_mask_type,
|
attention_mask_type,
|
||||||
cu_seqlens=None,
|
cu_seqlens=None,
|
||||||
max_seqlen=None,
|
max_seqlen=None,
|
||||||
|
@ -502,6 +503,7 @@ class RingAttention(torch.autograd.Function):
|
||||||
deterministic=False,
|
deterministic=False,
|
||||||
return_softmax=False,
|
return_softmax=False,
|
||||||
inner_ring_size=None,
|
inner_ring_size=None,
|
||||||
|
pg_mesh=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
@ -512,7 +514,7 @@ class RingAttention(torch.autograd.Function):
|
||||||
q (torch.Tensor): Query tensor. Shape should be [B, nHeads, Sq, D]
|
q (torch.Tensor): Query tensor. Shape should be [B, nHeads, Sq, D]
|
||||||
k (torch.Tensor): Key tensor. Shape should be [B, nHeads, Sq, Sq, D]
|
k (torch.Tensor): Key tensor. Shape should be [B, nHeads, Sq, Sq, D]
|
||||||
v (torch.Tensor): Value tensor. Shape should be [B, nHeads, Sq, Sq, D]
|
v (torch.Tensor): Value tensor. Shape should be [B, nHeads, Sq, Sq, D]
|
||||||
sp_group (Optional[dist.ProcessGroup]): Process group for sequence parallelism
|
sp_axis (Optional[int]): Sp axis for the global pg mesh.
|
||||||
sp_tream (torch.cuda.Stream): An different stream for output correction.
|
sp_tream (torch.cuda.Stream): An different stream for output correction.
|
||||||
cu_seqlens (Optional[torch.Tensor], optional): The cumulative sequence lengths
|
cu_seqlens (Optional[torch.Tensor], optional): The cumulative sequence lengths
|
||||||
of the sequences in the batch, used to index into q.
|
of the sequences in the batch, used to index into q.
|
||||||
|
@ -537,7 +539,6 @@ class RingAttention(torch.autograd.Function):
|
||||||
RingAttention.ATTN_DONE = torch.cuda.Event()
|
RingAttention.ATTN_DONE = torch.cuda.Event()
|
||||||
if RingAttention.SP_STREAM is None:
|
if RingAttention.SP_STREAM is None:
|
||||||
RingAttention.SP_STREAM = torch.cuda.Stream()
|
RingAttention.SP_STREAM = torch.cuda.Stream()
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
q.shape[2] == k.shape[2]
|
q.shape[2] == k.shape[2]
|
||||||
), "Q, K and V having different sequence lengths (inference or cross-attn)\
|
), "Q, K and V having different sequence lengths (inference or cross-attn)\
|
||||||
|
@ -546,11 +547,13 @@ class RingAttention(torch.autograd.Function):
|
||||||
attention_mask_type in RingAttention.SUPPORTED_MASK_TYPES
|
attention_mask_type in RingAttention.SUPPORTED_MASK_TYPES
|
||||||
), f"Mask type {attention_mask_type} is not supported yet."
|
), f"Mask type {attention_mask_type} is not supported yet."
|
||||||
|
|
||||||
clone_pg = lambda pg: dist.new_group(dist.get_process_group_ranks(pg))
|
assert pg_mesh is not None, f"Error: The pg mesh is None! please check the process group initialization."
|
||||||
|
|
||||||
if RingAttention.SP_GROUP is not sp_group:
|
clone_pg = lambda pg: dist.new_group(dist.get_process_group_ranks(pg))
|
||||||
|
sp_group = pg_mesh.get_group_along_axis(sp_axis)
|
||||||
|
if inner_ring_size != None:
|
||||||
RingAttention.SP_GROUP = sp_group
|
RingAttention.SP_GROUP = sp_group
|
||||||
inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_group, inner_ring_size)
|
inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_axis, pg_mesh, inner_ring_size)
|
||||||
RingAttention.INNER_RING_GROUP = inner_ring_group
|
RingAttention.INNER_RING_GROUP = inner_ring_group
|
||||||
RingAttention.INTER_RING_GROUP = inter_ring_group
|
RingAttention.INTER_RING_GROUP = inter_ring_group
|
||||||
else:
|
else:
|
||||||
|
@ -628,7 +631,13 @@ class RingAttention(torch.autograd.Function):
|
||||||
inner_ring_group: Optional[dist.ProcessGroup] = None,
|
inner_ring_group: Optional[dist.ProcessGroup] = None,
|
||||||
inter_ring_group: Optional[dist.ProcessGroup] = None,
|
inter_ring_group: Optional[dist.ProcessGroup] = None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Forward supporting both packed (varlen) and batched(fixed length, no padding) sequences.
|
||||||
|
No separate version for batched seq (hard to maintain), which incurs
|
||||||
|
some overhead in sequence splitting due to python for loops.
|
||||||
|
Uses two CUDA streams to overlap softmax denominator correction with next flash attn
|
||||||
|
(see comments below).
|
||||||
|
"""
|
||||||
cu_seqlens_q = cu_seqlens_kv = cu_seqlens
|
cu_seqlens_q = cu_seqlens_kv = cu_seqlens
|
||||||
max_seqlen_q = max_seqlen_kv = max_seqlen
|
max_seqlen_q = max_seqlen_kv = max_seqlen
|
||||||
cu_seqlens_half = cu_seqlens // 2
|
cu_seqlens_half = cu_seqlens // 2
|
||||||
|
@ -670,7 +679,8 @@ class RingAttention(torch.autograd.Function):
|
||||||
|
|
||||||
sp_size = dist.get_world_size(sp_group)
|
sp_size = dist.get_world_size(sp_group)
|
||||||
sp_rank = dist.get_rank(sp_group)
|
sp_rank = dist.get_rank(sp_group)
|
||||||
# Attempt to achieve concurrent comm in the two-stream forward
|
|
||||||
|
# Create communicators corresponding to two CUDA streams
|
||||||
local_kv_comms = [RingComm(inner_ring_group) for _ in range(2)]
|
local_kv_comms = [RingComm(inner_ring_group) for _ in range(2)]
|
||||||
inter_ring_comm = RingComm(inter_ring_group)
|
inter_ring_comm = RingComm(inter_ring_group)
|
||||||
local_sp_size = dist.get_world_size(inner_ring_group)
|
local_sp_size = dist.get_world_size(inner_ring_group)
|
||||||
|
@ -678,7 +688,7 @@ class RingAttention(torch.autograd.Function):
|
||||||
inter_ring_rank = dist.get_rank(inter_ring_group) if inter_ring_group is not sp_group else 0
|
inter_ring_rank = dist.get_rank(inter_ring_group) if inter_ring_group is not sp_group else 0
|
||||||
num_rings = dist.get_world_size(inter_ring_group) if inter_ring_group is not sp_group else 1
|
num_rings = dist.get_world_size(inter_ring_group) if inter_ring_group is not sp_group else 1
|
||||||
|
|
||||||
# Non-contiguous indexing copies to a new contiguous tensor,
|
# Any type of indexing(but not slicing) copies to a new contiguous tensor,
|
||||||
# so only do it once
|
# so only do it once
|
||||||
if sp_rank != sp_size - 1:
|
if sp_rank != sp_size - 1:
|
||||||
q1 = q[half_idx_back]
|
q1 = q[half_idx_back]
|
||||||
|
@ -695,6 +705,7 @@ class RingAttention(torch.autograd.Function):
|
||||||
rng_states = [None for _ in range(sp_size)]
|
rng_states = [None for _ in range(sp_size)]
|
||||||
sp_streams = [torch.cuda.current_stream(), sp_stream]
|
sp_streams = [torch.cuda.current_stream(), sp_stream]
|
||||||
|
|
||||||
|
# Helper to pass args to FA
|
||||||
def _forward(q, k, v, causal):
|
def _forward(q, k, v, causal):
|
||||||
(
|
(
|
||||||
_,
|
_,
|
||||||
|
@ -725,6 +736,7 @@ class RingAttention(torch.autograd.Function):
|
||||||
if i < local_sp_size - 1:
|
if i < local_sp_size - 1:
|
||||||
local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
|
local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
|
||||||
|
|
||||||
|
# Forward within a node
|
||||||
def _local_ring_forward():
|
def _local_ring_forward():
|
||||||
# (Hopefully) overlap output correction with next flash attn
|
# (Hopefully) overlap output correction with next flash attn
|
||||||
for i in range(local_sp_size):
|
for i in range(local_sp_size):
|
||||||
|
@ -733,6 +745,8 @@ class RingAttention(torch.autograd.Function):
|
||||||
# NOTE: waiting outside the current stream will NOT correctly synchronize.
|
# NOTE: waiting outside the current stream will NOT correctly synchronize.
|
||||||
if i > 0:
|
if i > 0:
|
||||||
local_kv_comms[(i + 1) % 2].wait()
|
local_kv_comms[(i + 1) % 2].wait()
|
||||||
|
|
||||||
|
# Prefetch
|
||||||
if i == 0:
|
if i == 0:
|
||||||
_kv_comm(i)
|
_kv_comm(i)
|
||||||
|
|
||||||
|
@ -766,15 +780,22 @@ class RingAttention(torch.autograd.Function):
|
||||||
) = _forward(q_block, kv_block[0], kv_block[1], causal=False)
|
) = _forward(q_block, kv_block[0], kv_block[1], causal=False)
|
||||||
RingAttention.ATTN_DONE.record()
|
RingAttention.ATTN_DONE.record()
|
||||||
# Pipeline the next KV comm with output correction instead of the next flash attn
|
# Pipeline the next KV comm with output correction instead of the next flash attn
|
||||||
# to minimize idle time when comm takes longer than attn.
|
# kernel, to minimize bubble when comm takes longer than attn.
|
||||||
_kv_comm(i + 1)
|
_kv_comm(i + 1)
|
||||||
|
|
||||||
block_softmax_lse[i % 2] = (
|
block_softmax_lse[i % 2] = (
|
||||||
block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
|
block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
|
||||||
) # (H, T) -> (T, H, 1)
|
) # (H, T) -> (T, H, 1)
|
||||||
assert block_out[i % 2].shape[:-1] == block_softmax_lse[i % 2].shape[:-1]
|
assert block_out[i % 2].shape[:-1] == block_softmax_lse[i % 2].shape[:-1]
|
||||||
# Output and log sum exp correction. Ideally overlap this with the next flash attn kernel.
|
|
||||||
# In reality this always finishes before next flash attn; no need for extra sync.
|
# Output and log sum exp correction.
|
||||||
|
# Ideally overlap this with the next flash attn kernel,
|
||||||
|
# since attn uses Tensor Core and rescale is element-wise, memory-bound and uses CUDA cores.
|
||||||
|
# (NOTE that this is the same as ping-pong scheduling idea in FA3)
|
||||||
|
# TODO However sometimes while the GPU has scheduled the next kernel,
|
||||||
|
# it's reluctant to launch it in overlap. Some potential causes:
|
||||||
|
# 1. need lower-level CUDA scheduling 2. further benchmark against Megatron-LM
|
||||||
|
# 3. register spilling by FA kernel.
|
||||||
if i == 0:
|
if i == 0:
|
||||||
out = block_out[0]
|
out = block_out[0]
|
||||||
softmax_lse = block_softmax_lse[0]
|
softmax_lse = block_softmax_lse[0]
|
||||||
|
@ -790,15 +811,17 @@ class RingAttention(torch.autograd.Function):
|
||||||
torch.cuda.current_stream().wait_stream(sp_stream)
|
torch.cuda.current_stream().wait_stream(sp_stream)
|
||||||
return out, softmax_lse
|
return out, softmax_lse
|
||||||
|
|
||||||
|
# Forward for inter-node (the outer ring in 2D ring)
|
||||||
def _other_ring_forward(ring_num_idx, out, softmax_lse):
|
def _other_ring_forward(ring_num_idx, out, softmax_lse):
|
||||||
# Loop through the inner ring after receiving
|
# Loop through the inner ring after receiving
|
||||||
# all new KVs from the previous inner ring
|
# all new KVs from another ring
|
||||||
for i in range(local_sp_size):
|
for i in range(local_sp_size):
|
||||||
with torch.cuda.stream(sp_streams[i % 2]):
|
with torch.cuda.stream(sp_streams[i % 2]):
|
||||||
# Send & recv KV
|
# Send & recv KV
|
||||||
if i > 0:
|
if i > 0:
|
||||||
local_kv_comms[(i + 1) % 2].wait()
|
local_kv_comms[(i + 1) % 2].wait()
|
||||||
|
|
||||||
|
# Prefetch
|
||||||
if i == 0:
|
if i == 0:
|
||||||
_kv_comm(i)
|
_kv_comm(i)
|
||||||
|
|
||||||
|
@ -895,7 +918,8 @@ class RingAttention(torch.autograd.Function):
|
||||||
def backward(ctx, dout, _):
|
def backward(ctx, dout, _):
|
||||||
"""
|
"""
|
||||||
During backward, we accumulate q grads on each rank locally, but iterate kv and their grads
|
During backward, we accumulate q grads on each rank locally, but iterate kv and their grads
|
||||||
over all ranks for accumulation.
|
over all ranks for accumulation. We avoid using two streams due to backward using doubled
|
||||||
|
buffers and more comm cost.
|
||||||
"""
|
"""
|
||||||
(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_kv, half_idx_front, half_idx_back) = ctx.saved_tensors[:9]
|
(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_kv, half_idx_front, half_idx_back) = ctx.saved_tensors[:9]
|
||||||
rng_states = ctx.saved_tensors[9:]
|
rng_states = ctx.saved_tensors[9:]
|
||||||
|
@ -927,7 +951,7 @@ class RingAttention(torch.autograd.Function):
|
||||||
local_sp_rank = dist.get_rank(sp_group)
|
local_sp_rank = dist.get_rank(sp_group)
|
||||||
sp_size = dist.get_world_size(sp_group)
|
sp_size = dist.get_world_size(sp_group)
|
||||||
|
|
||||||
# Using separate streams (pg) for concurrent kv and dkv comm may
|
# NOTE: Using separate streams (PG) for concurrent kv and dkv comm may
|
||||||
# cause NCCL "software caused connection abort" here...
|
# cause NCCL "software caused connection abort" here...
|
||||||
local_kv_comm = RingComm(local_kv_group)
|
local_kv_comm = RingComm(local_kv_group)
|
||||||
local_dkv_comm = RingComm(local_kv_group)
|
local_dkv_comm = RingComm(local_kv_group)
|
||||||
|
@ -959,6 +983,7 @@ class RingAttention(torch.autograd.Function):
|
||||||
dkv_buffers = [torch.empty_like(kv, dtype=torch.float32) for kv in kv_buffers] # (T, H, D)
|
dkv_buffers = [torch.empty_like(kv, dtype=torch.float32) for kv in kv_buffers] # (T, H, D)
|
||||||
del k, v
|
del k, v
|
||||||
|
|
||||||
|
# Helper to pass args to FA
|
||||||
def _backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, rng_state, causal):
|
def _backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, rng_state, causal):
|
||||||
_flash_attn_backward(
|
_flash_attn_backward(
|
||||||
dout,
|
dout,
|
||||||
|
@ -979,8 +1004,7 @@ class RingAttention(torch.autograd.Function):
|
||||||
**misc_kwargs,
|
**misc_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# NOTE: We avoid using two streams due to doubled buffers
|
# Backward within a node
|
||||||
# and that backward is more communication intensive.
|
|
||||||
def _local_ring_backward():
|
def _local_ring_backward():
|
||||||
for i in range(local_sp_size):
|
for i in range(local_sp_size):
|
||||||
if i > 0:
|
if i > 0:
|
||||||
|
@ -1043,6 +1067,7 @@ class RingAttention(torch.autograd.Function):
|
||||||
dkv_send = dkv_buffers[(local_sp_size - 1) % 2]
|
dkv_send = dkv_buffers[(local_sp_size - 1) % 2]
|
||||||
return dq, dkv_recv, dkv_send
|
return dq, dkv_recv, dkv_send
|
||||||
|
|
||||||
|
# Backward for inter-node (the outer ring in 2D ring)
|
||||||
def _other_ring_backward(ring_num_idx, dq):
|
def _other_ring_backward(ring_num_idx, dq):
|
||||||
if ring_num_idx > inter_ring_rank:
|
if ring_num_idx > inter_ring_rank:
|
||||||
# Indexing is expensive
|
# Indexing is expensive
|
||||||
|
@ -1127,34 +1152,34 @@ class RingAttention(torch.autograd.Function):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def prepare_varlen_batch(
|
def prepare_varlen_batch(
|
||||||
attention_mask: torch.Tensor,
|
padding_mask: torch.Tensor,
|
||||||
sp_group: dist.ProcessGroup,
|
sp_group: dist.ProcessGroup,
|
||||||
inputs_embeds: torch.Tensor = None,
|
inputs_embeds: torch.Tensor = None,
|
||||||
position_ids: Optional[torch.Tensor] = None,
|
position_ids: Optional[torch.Tensor] = None,
|
||||||
is_label: bool = False,
|
is_label: bool = False,
|
||||||
is_2d: bool = True,
|
is_batched_seq: bool = True,
|
||||||
):
|
):
|
||||||
|
# TODO: support setting a batch dim (fix packing length) for packed mode, so that
|
||||||
|
# DP can be used (needs to modify dataloader too)
|
||||||
"""
|
"""
|
||||||
Preprocess a batch of padded sequence by splitting input sequence by sp_size
|
Preprocess a batch of padded sequence by splitting input sequence by sp_size
|
||||||
sequence-wise and packing them into one sequence. Updates the mask info accordingly.
|
seq-wise and packing them into one sequence. Updates the mask info accordingly.
|
||||||
Args:
|
Args:
|
||||||
attention_mask (torch.Tensor): Contains the mask [B, Sq], where True means the token is NOT masked.
|
padding_mask (torch.Tensor): Contains the mask [B, Sq], where True means the token is NOT masked.
|
||||||
sp_group (dist.ProcessGroup): Process group for sequence parallelism
|
sp_group (dist.ProcessGroup): Process group for sequence parallelism
|
||||||
inputs_embeds (torch.Tensor): Input embeddings. Shape should be [B, Sq, ...]
|
inputs_embeds (torch.Tensor): Input embeddings. Shape should be [B, Sq, ...]
|
||||||
position_ids (Optional[torch.Tensor], optional): Position ids of shape [Sq] or [1, Sq]. Defaults to None.
|
position_ids (Optional[torch.Tensor], optional): Position ids of shape [Sq] or [1, Sq]. Defaults to None.
|
||||||
is_label (bool, optional): Whether inputs_embeds is instead a label tensor. If True, mask out the first
|
is_label (bool, optional): Whether inputs_embeds is instead a label tensor. If True, mask out the first
|
||||||
token of each sequence.
|
token of each sequence.
|
||||||
is_2d (bool, optional): Whether to return 2D outputs padded to max_seqlen // sp_size or flatten
|
is_batched_seq (bool, optional): If True, then the input is a batch of (potentially padded) sequences
|
||||||
the batch dim to a packed 1d sequence. Contingent on model forward shape definitions.
|
of shape [B, Sq, ...]; else a packed sequence of shape [T, ...].
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
torch.Tensor:
|
inputs_embeds (torch.Tensor):
|
||||||
Packed input embeddings of shape [B, Sq // sp_size, ...].
|
Packed input embeddings of shape [B, Sq // sp_size, ...] if is_batched_seq, else [T, ...].
|
||||||
|
mask_info (Dict[str, Any]):
|
||||||
Dict[str, Any]:
|
|
||||||
A dictionary containing mask info.
|
A dictionary containing mask info.
|
||||||
|
position_ids (torch.Tensor):
|
||||||
torch.Tensor:
|
|
||||||
Packed position ids of shape [..., Sq // sp_size].
|
Packed position ids of shape [..., Sq // sp_size].
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
@ -1162,12 +1187,11 @@ class RingAttention(torch.autograd.Function):
|
||||||
sp_size = dist.get_world_size(group=sp_group)
|
sp_size = dist.get_world_size(group=sp_group)
|
||||||
sp_rank = dist.get_rank(group=sp_group)
|
sp_rank = dist.get_rank(group=sp_group)
|
||||||
mask_info = {}
|
mask_info = {}
|
||||||
mask_info["max_seqlen"], mask_info["cu_seqlens"] = get_pad_info(attention_mask, return_indices=False)
|
mask_info["max_seqlen"], mask_info["cu_seqlens"] = get_pad_info(padding_mask, return_indices=False)
|
||||||
|
|
||||||
# Unpad, split seq-wise, then pad back to (B, max_seqlen // sp_size)
|
# Unpad, split seq-wise, then pad to (B, max_seqlen // sp_size)
|
||||||
# Split mask to compute local nonzero position indices
|
|
||||||
# (B, Sq) -> (B, max_seqlen // sp_size)
|
# (B, Sq) -> (B, max_seqlen // sp_size)
|
||||||
attention_mask = attention_mask[:, : mask_info["max_seqlen"]]
|
padding_mask = padding_mask[:, : mask_info["max_seqlen"]]
|
||||||
if inputs_embeds is not None:
|
if inputs_embeds is not None:
|
||||||
inputs_embeds = inputs_embeds[:, : mask_info["max_seqlen"]]
|
inputs_embeds = inputs_embeds[:, : mask_info["max_seqlen"]]
|
||||||
inputs_embeds = split_varlen_zigzag(
|
inputs_embeds = split_varlen_zigzag(
|
||||||
|
@ -1175,11 +1199,12 @@ class RingAttention(torch.autograd.Function):
|
||||||
mask_info["cu_seqlens"],
|
mask_info["cu_seqlens"],
|
||||||
sp_group,
|
sp_group,
|
||||||
mask_info["max_seqlen"],
|
mask_info["max_seqlen"],
|
||||||
is_2d=is_2d,
|
is_batched_seq=is_batched_seq,
|
||||||
is_label=is_label,
|
is_label=is_label,
|
||||||
)
|
)
|
||||||
attention_mask = split_varlen_zigzag(
|
# Split mask to get local nonzero seq positions
|
||||||
attention_mask, mask_info["cu_seqlens"], sp_group, mask_info["max_seqlen"], is_2d=is_2d
|
padding_mask = split_varlen_zigzag(
|
||||||
|
padding_mask, mask_info["cu_seqlens"], sp_group, mask_info["max_seqlen"], is_batched_seq=is_batched_seq
|
||||||
)
|
)
|
||||||
|
|
||||||
if position_ids is not None:
|
if position_ids is not None:
|
||||||
|
@ -1192,7 +1217,7 @@ class RingAttention(torch.autograd.Function):
|
||||||
)
|
)
|
||||||
|
|
||||||
mask_info["max_seqlen"] //= sp_size
|
mask_info["max_seqlen"] //= sp_size
|
||||||
mask_info["valid_indices"] = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
mask_info["valid_indices"] = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
|
||||||
mask_info["cu_seqlens"] //= sp_size
|
mask_info["cu_seqlens"] //= sp_size
|
||||||
mask_info["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
|
mask_info["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
|
||||||
return inputs_embeds, mask_info, position_ids
|
return inputs_embeds, mask_info, position_ids
|
||||||
|
|
|
@ -23,18 +23,16 @@ from colossalai.tensor.d_tensor.api import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from ._operation import (
|
from ._operation import (
|
||||||
gather_forward_reducescatter_backward,
|
|
||||||
gather_forward_split_backward,
|
gather_forward_split_backward,
|
||||||
linear_gather_forward_reducescatter_backward,
|
linear_gather_forward_reducescatter_backward,
|
||||||
linear_reducescatter_forward_gather_backward,
|
linear_reducescatter_forward_gather_backward,
|
||||||
linear_with_async_comm,
|
linear_with_async_comm,
|
||||||
linear_with_grad_accum,
|
linear_with_grad_accum,
|
||||||
reduce_forward,
|
reduce_forward,
|
||||||
reducescatter_forward_gather_backward,
|
|
||||||
split_forward_gather_backward,
|
split_forward_gather_backward,
|
||||||
)
|
)
|
||||||
from .parallel_module import PaddingParallelModule, ParallelModule
|
from .parallel_module import PaddingParallelModule, ParallelModule
|
||||||
from .utils import create_randomizer_with_offset
|
from .utils import create_randomizer_with_offset, is_share_sp_tp
|
||||||
|
|
||||||
__all__ = ["LinearWithGradAccum", "Linear1D_Col", "Linear1D_Row"]
|
__all__ = ["LinearWithGradAccum", "Linear1D_Col", "Linear1D_Row"]
|
||||||
|
|
||||||
|
@ -197,7 +195,6 @@ class Linear1D_Col(ParallelModule):
|
||||||
to all GPUs, otherwise, every GPU will have its output
|
to all GPUs, otherwise, every GPU will have its output
|
||||||
which is :math:`Y_i = XA_i`, defaults to False
|
which is :math:`Y_i = XA_i`, defaults to False
|
||||||
seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
|
seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
|
||||||
overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
|
|
||||||
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
|
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
|
||||||
which is preserved for kernel fusion, defaults to False
|
which is preserved for kernel fusion, defaults to False
|
||||||
weight_initializer (`typing.Callable`):
|
weight_initializer (`typing.Callable`):
|
||||||
|
@ -220,7 +217,6 @@ class Linear1D_Col(ParallelModule):
|
||||||
gather_output: bool = False,
|
gather_output: bool = False,
|
||||||
seq_parallel_mode: str = None,
|
seq_parallel_mode: str = None,
|
||||||
seq_parallel_dim: int = 1,
|
seq_parallel_dim: int = 1,
|
||||||
overlap: torch.cuda.Stream = None,
|
|
||||||
skip_bias_add: bool = False,
|
skip_bias_add: bool = False,
|
||||||
weight: Optional[Parameter] = None,
|
weight: Optional[Parameter] = None,
|
||||||
bias_: Optional[Parameter] = None,
|
bias_: Optional[Parameter] = None,
|
||||||
|
@ -238,7 +234,6 @@ class Linear1D_Col(ParallelModule):
|
||||||
self.gather_output = gather_output
|
self.gather_output = gather_output
|
||||||
self.seq_parallel_mode = seq_parallel_mode
|
self.seq_parallel_mode = seq_parallel_mode
|
||||||
self.seq_parallel_dim = seq_parallel_dim
|
self.seq_parallel_dim = seq_parallel_dim
|
||||||
self.overlap = overlap
|
|
||||||
self.skip_bias_add = skip_bias_add
|
self.skip_bias_add = skip_bias_add
|
||||||
self.device = device
|
self.device = device
|
||||||
self.process_group = process_group
|
self.process_group = process_group
|
||||||
|
@ -345,22 +340,16 @@ class Linear1D_Col(ParallelModule):
|
||||||
|
|
||||||
# Matrix multiply.
|
# Matrix multiply.
|
||||||
bias = self.bias if not self.skip_bias_add else None
|
bias = self.bias if not self.skip_bias_add else None
|
||||||
if self.seq_parallel_mode == "split_gather":
|
|
||||||
input_parallel = gather_forward_reducescatter_backward(
|
if is_share_sp_tp(self.seq_parallel_mode):
|
||||||
input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
|
output_parallel = linear_gather_forward_reducescatter_backward(
|
||||||
)
|
|
||||||
output_parallel = linear_with_async_comm(
|
|
||||||
input_parallel,
|
input_parallel,
|
||||||
self.weight,
|
self.weight,
|
||||||
bias,
|
bias,
|
||||||
self.process_group,
|
self.process_group,
|
||||||
False,
|
True,
|
||||||
fp8_communication=self.fp8_communication,
|
self.seq_parallel_dim,
|
||||||
use_zbv=self.use_zbv,
|
ring=self.seq_parallel_mode == "ring",
|
||||||
)
|
|
||||||
elif self.seq_parallel_mode == "ring":
|
|
||||||
output_parallel = linear_gather_forward_reducescatter_backward(
|
|
||||||
input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
output_parallel = linear_with_async_comm(
|
output_parallel = linear_with_async_comm(
|
||||||
|
@ -584,31 +573,17 @@ class Linear1D_Row(ParallelModule):
|
||||||
handle.wait()
|
handle.wait()
|
||||||
output = torch.cat(output_parallel_list, dim=-1)
|
output = torch.cat(output_parallel_list, dim=-1)
|
||||||
else:
|
else:
|
||||||
if self.seq_parallel_mode is None:
|
if is_share_sp_tp(self.seq_parallel_mode):
|
||||||
output_parallel = linear_with_async_comm(
|
|
||||||
input_, self.weight, None, self.process_group, False, use_zbv=self.use_zbv
|
|
||||||
)
|
|
||||||
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
|
|
||||||
elif self.seq_parallel_mode == "split_gather":
|
|
||||||
output_parallel = linear_with_async_comm(
|
|
||||||
input_, self.weight, None, self.process_group, False, use_zbv=self.use_zbv
|
|
||||||
)
|
|
||||||
output = reducescatter_forward_gather_backward(
|
|
||||||
output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
|
|
||||||
)
|
|
||||||
elif self.seq_parallel_mode == "ring":
|
|
||||||
output = linear_reducescatter_forward_gather_backward(
|
output = linear_reducescatter_forward_gather_backward(
|
||||||
input_,
|
input_,
|
||||||
self.weight,
|
self.weight,
|
||||||
process_group=self.process_group,
|
process_group=self.process_group,
|
||||||
dim=self.seq_parallel_dim,
|
dim=self.seq_parallel_dim,
|
||||||
ring=True,
|
ring=self.seq_parallel_mode == "ring",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
output_parallel = linear_with_async_comm(
|
output_parallel = F.linear(input_, self.weight)
|
||||||
input_, self.weight, None, self.process_group, False, use_zbv=self.use_zbv
|
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
|
||||||
)
|
|
||||||
output = reduce_forward(output_parallel, self.process_group)
|
|
||||||
|
|
||||||
if not self.skip_bias_add:
|
if not self.skip_bias_add:
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
|
@ -716,7 +691,6 @@ class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule):
|
||||||
to all GPUs, otherwise, every GPU will have its output
|
to all GPUs, otherwise, every GPU will have its output
|
||||||
which is :math:`Y_i = XA_i`, defaults to False
|
which is :math:`Y_i = XA_i`, defaults to False
|
||||||
seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
|
seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
|
||||||
overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
|
|
||||||
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
|
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
|
||||||
which is preserved for kernel fusion, defaults to False
|
which is preserved for kernel fusion, defaults to False
|
||||||
weight_initializer (`typing.Callable`):
|
weight_initializer (`typing.Callable`):
|
||||||
|
|
|
@ -7,6 +7,7 @@ from typing import Callable, List, Optional, Tuple, Union
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.distributed import ProcessGroup
|
from torch.distributed import ProcessGroup
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
@ -24,17 +25,17 @@ from colossalai.tensor.d_tensor.api import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from ._operation import (
|
from ._operation import (
|
||||||
gather_forward_split_backward,
|
linear_gather_forward_reducescatter_backward,
|
||||||
|
linear_reducescatter_forward_gather_backward,
|
||||||
linear_with_async_comm,
|
linear_with_async_comm,
|
||||||
matmul_gather_forward_reducescatter_backward,
|
matmul_gather_forward_reducescatter_backward,
|
||||||
matmul_with_async_comm,
|
matmul_with_async_comm,
|
||||||
reduce_backward,
|
|
||||||
reduce_forward,
|
reduce_forward,
|
||||||
reducescatter_forward_gather_backward,
|
reducescatter_forward_gather_backward,
|
||||||
split_forward_gather_backward,
|
split_forward_gather_backward,
|
||||||
)
|
)
|
||||||
from .parallel_module import ParallelModule
|
from .parallel_module import ParallelModule
|
||||||
from .utils import create_randomizer_with_offset
|
from .utils import create_randomizer_with_offset, is_share_sp_tp
|
||||||
|
|
||||||
__all__ = ["FusedLinear1D_Col", "FusedLinear1D_Row", "GPT2FusedLinearConv1D_Col", "GPT2FusedLinearConv1D_Row"]
|
__all__ = ["FusedLinear1D_Col", "FusedLinear1D_Row", "GPT2FusedLinearConv1D_Col", "GPT2FusedLinearConv1D_Row"]
|
||||||
|
|
||||||
|
@ -44,21 +45,25 @@ __all__ = ["FusedLinear1D_Col", "FusedLinear1D_Row", "GPT2FusedLinearConv1D_Col"
|
||||||
|
|
||||||
|
|
||||||
def split_fused_qkv_in_gpt2_style(
|
def split_fused_qkv_in_gpt2_style(
|
||||||
qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup, is_transposed: bool = False
|
qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup, is_transposed: bool = False
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
The fused qkv tensor looks like [Q1, Q2, K1, K2, V1, V2], this function will split them into [Q1, K1, V1] and [Q2, K2, V2].
|
The fused qkv tensor looks like [Q1, Q2, K1, K2, V1, V2], this function will split them into [Q1, K1, V1] and [Q2, K2, V2].
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
qkv (torch.Tensor): The fused qkv tensor.
|
qkv (torch.Tensor): The fused qkv tensor.
|
||||||
n_fused (int): The number items fused together, defaults to 3 (query, key and value).
|
split_sizes (List[int]): The sizes of the split tensor.
|
||||||
process_group (ProcessGroup): The process group for distributed communication.
|
process_group (ProcessGroup): The process group for distributed communication.
|
||||||
is_transposed (bool): generally the tensor is the shape of (out_features, in_features). Set this to True if the tensor is in the shape (in_features, out_features).
|
is_transposed (bool): generally the tensor is the shape of (out_features, in_features). Set this to True if the tensor is in the shape (in_features, out_features).
|
||||||
"""
|
"""
|
||||||
# get the number of slice for the fused qkv
|
# get the number of slice for the fused qkv
|
||||||
rank = dist.get_rank(group=process_group)
|
rank = dist.get_rank(group=process_group)
|
||||||
world_size = dist.get_world_size(group=process_group)
|
world_size = dist.get_world_size(group=process_group)
|
||||||
order = torch.arange(world_size * n_fused)
|
order = torch.arange(world_size * len(split_sizes))
|
||||||
|
new_split_sizes = []
|
||||||
|
for sz in split_sizes:
|
||||||
|
assert sz % world_size == 0, f"size {sz} is not divisible by world_size {world_size}"
|
||||||
|
new_split_sizes.extend([sz // world_size] * world_size)
|
||||||
|
|
||||||
# split the fused qkv
|
# split the fused qkv
|
||||||
# from
|
# from
|
||||||
|
@ -66,9 +71,9 @@ def split_fused_qkv_in_gpt2_style(
|
||||||
# to
|
# to
|
||||||
# [Q1, Q2, K1, K2, V1, V2]
|
# [Q1, Q2, K1, K2, V1, V2]
|
||||||
if is_transposed:
|
if is_transposed:
|
||||||
weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=-1)
|
weight_chunks = torch.split(qkv, new_split_sizes, dim=-1)
|
||||||
else:
|
else:
|
||||||
weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=0)
|
weight_chunks = torch.split(qkv, new_split_sizes, dim=0)
|
||||||
|
|
||||||
# rearrange the slice into the final order
|
# rearrange the slice into the final order
|
||||||
# from
|
# from
|
||||||
|
@ -85,18 +90,23 @@ def split_fused_qkv_in_gpt2_style(
|
||||||
|
|
||||||
|
|
||||||
def gather_fused_qkv_in_gpt2_style(
|
def gather_fused_qkv_in_gpt2_style(
|
||||||
qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup, is_transposed: bool = False
|
qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup, is_transposed: bool = False
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
The splitted qkv tensor looks like [Q1, K1, V1] and [Q2, K2, V2], this function will gather them into [Q1, Q2, K1, K2, V1, V2].
|
The splitted qkv tensor looks like [Q1, K1, V1] and [Q2, K2, V2], this function will gather them into [Q1, Q2, K1, K2, V1, V2].
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
qkv (torch.Tensor): The fused qkv tensor.
|
qkv (torch.Tensor): The fused qkv tensor.
|
||||||
n_fused (int): The number items fused together, defaults to 3 (query, key and value).
|
split_sizes (List[int]): The sizes of the split tensor.
|
||||||
process_group (ProcessGroup): The process group for distributed communication.
|
process_group (ProcessGroup): The process group for distributed communication.
|
||||||
is_transposed (bool): generally the tensor is the shape of (out_features, in_features). Set this to True if the tensor is in the shape (in_features, out_features).
|
is_transposed (bool): generally the tensor is the shape of (out_features, in_features). Set this to True if the tensor is in the shape (in_features, out_features).
|
||||||
"""
|
"""
|
||||||
world_size = dist.get_world_size(group=process_group)
|
world_size = dist.get_world_size(group=process_group)
|
||||||
|
new_split_sizes = []
|
||||||
|
for sz in split_sizes:
|
||||||
|
assert sz % world_size == 0, f"size {sz} is not divisible by world_size {world_size}"
|
||||||
|
new_split_sizes.append(sz // world_size)
|
||||||
|
new_split_sizes = new_split_sizes * world_size
|
||||||
|
|
||||||
# gather the tensors
|
# gather the tensors
|
||||||
# from
|
# from
|
||||||
|
@ -121,13 +131,13 @@ def gather_fused_qkv_in_gpt2_style(
|
||||||
# to
|
# to
|
||||||
# [Q1, Q2, K1, K2, V1, V2]
|
# [Q1, Q2, K1, K2, V1, V2]
|
||||||
if is_transposed:
|
if is_transposed:
|
||||||
weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=-1)
|
weight_chunks = torch.split(gather_weight, new_split_sizes, dim=-1)
|
||||||
else:
|
else:
|
||||||
weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=0)
|
weight_chunks = torch.split(gather_weight, new_split_sizes, dim=0)
|
||||||
|
|
||||||
reordered_chunk_list = []
|
reordered_chunk_list = []
|
||||||
for i in range(n_fused):
|
for i in range(len(split_sizes)):
|
||||||
reordered_chunk_list.extend(weight_chunks[i::n_fused])
|
reordered_chunk_list.extend(weight_chunks[i :: len(split_sizes)])
|
||||||
|
|
||||||
if is_transposed:
|
if is_transposed:
|
||||||
reordered_gather_weight = torch.cat(reordered_chunk_list, dim=-1)
|
reordered_gather_weight = torch.cat(reordered_chunk_list, dim=-1)
|
@@ -136,6 +146,42 @@ def gather_fused_qkv_in_gpt2_style(
     return reordered_gather_weight


+class _SplitForwardGatherBackwardFusedQKV(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup):
+        ctx.split_sizes = split_sizes
+        ctx.process_group = process_group
+        return split_fused_qkv_in_gpt2_style(qkv, split_sizes, process_group, is_transposed=True)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_output = gather_fused_qkv_in_gpt2_style(
+            grad_output, ctx.split_sizes, ctx.process_group, is_transposed=True
+        )
+        return grad_output, None, None
+
+
+def split_forward_gather_backward_fused_qkv(qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup):
+    return _SplitForwardGatherBackwardFusedQKV.apply(qkv, split_sizes, process_group)
+
+
+class _GatherForwardSplitBackwardFusedQKV(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup):
+        ctx.split_sizes = split_sizes
+        ctx.process_group = process_group
+        return gather_fused_qkv_in_gpt2_style(qkv, split_sizes, process_group, is_transposed=True)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_output = split_fused_qkv_in_gpt2_style(grad_output, ctx.split_sizes, ctx.process_group, is_transposed=True)
+        return grad_output, None, None
+
+
+def gather_forward_split_backward_fused_qkv(qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup):
+    return _GatherForwardSplitBackwardFusedQKV.apply(qkv, split_sizes, process_group)
+
+
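The helpers above carve a fused `[Q | K | V]` projection so that each tensor-parallel rank keeps its own Q/K/V columns, using per-chunk sizes with `torch.split` instead of equal `torch.chunk`s (which is what allows unequal Q/K/V widths, e.g. for grouped-query attention). A toy, single-process sketch of that slicing (hypothetical helper name and made-up sizes, not the library function):

```python
from typing import List

import torch


def shard_fused_qkv(qkv: torch.Tensor, split_sizes: List[int], world_size: int, rank: int) -> torch.Tensor:
    """Toy per-rank split: qkv columns are laid out as [Q | K | V] (is_transposed=True layout),
    and each logical chunk is divided evenly across ranks before reassembly."""
    per_rank = []
    for sz in split_sizes:
        assert sz % world_size == 0, f"size {sz} is not divisible by world_size {world_size}"
        per_rank.extend([sz // world_size] * world_size)
    chunks = torch.split(qkv, per_rank, dim=-1)  # [Q_0..Q_{w-1}, K_0..K_{w-1}, V_0..V_{w-1}]
    mine = chunks[rank::world_size]              # (Q_rank, K_rank, V_rank)
    return torch.cat(mine, dim=-1)


qkv = torch.arange(2 * 12, dtype=torch.float32).reshape(2, 12)  # hidden=2, fused out_features=12
local = shard_fused_qkv(qkv, split_sizes=[4, 4, 4], world_size=2, rank=0)
print(local.shape)  # torch.Size([2, 6]) -> this rank's [Q_0 | K_0 | V_0]
```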
class GPT2FusedLinearConv1D_Col(ParallelModule):
|
class GPT2FusedLinearConv1D_Col(ParallelModule):
|
||||||
r"""Linear layer with column parallelism.
|
r"""Linear layer with column parallelism.
|
||||||
|
|
||||||
@@ -145,10 +191,10 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
     Args:
         in_features (int): size of each input sample.
         out_features (int): size of each output sample.
+        split_sizes (List[int]): The sizes of the split tensor.
         bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
         dtype (`torch.dtype`): The dtype of parameters, defaults to None.
         device (`torch.device`): The device of parameters, defaults to None.
-        n_fused (int): The number items fused, defaults to 3 (QKV).
         process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
         seq_parallel_mode (str): If set to ``None``, it will not use sequence parallel, otherwise will use corresponding mode of sequence parallel, defaults to None.
         gather_output (bool, optional): If true, call all-gather on output and make Y available
@@ -169,16 +215,14 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
         self,
         in_features: int,
         out_features: int,
+        split_sizes: List[int],
         bias: bool = True,
         dtype: torch.dtype = None,
         device: torch.device = None,
         process_group: ProcessGroup = None,
-        async_communication: bool = False,
         gather_output: bool = False,
         seq_parallel_mode: str = None,
-        overlap: bool = False,
         skip_bias_add: bool = False,
-        n_fused: int = 3,
         weight: Optional[Parameter] = None,
         bias_: Optional[Parameter] = None,
         weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
@@ -192,14 +236,16 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
         self.out_features = out_features
         self.gather_output = gather_output
         self.seq_parallel_mode = seq_parallel_mode
-        self.overlap = overlap
         self.skip_bias_add = skip_bias_add
         self.device = device
-        self.n_fused = n_fused
+        self.split_sizes = split_sizes
         self.process_group = process_group
-        self.async_communication = async_communication
         self.fp8_communication = fp8_communication

+        assert (
+            sum(split_sizes) == out_features
+        ), f"The sum of split_sizes({sum(split_sizes)}) should be equal to out_features({out_features})."
+
         if skip_bias_add and not bias:
             raise ValueError("cannot skip bias addition if bias is None")

@@ -223,10 +269,10 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
             self.weight = weight

         def shard_fn(tensor):
-            return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True)
+            return split_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, True)

         def gather_fn(tensor):
-            return gather_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True)
+            return gather_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, True)

         if not is_customized_distributed_tensor(self.weight):
             with torch.no_grad():
@@ -252,7 +298,11 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):

     @staticmethod
     def from_native_module(
-        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+        module: nn.Module,
+        process_group: Union[ProcessGroup, List[ProcessGroup]],
+        split_sizes: List[int],
+        *args,
+        **kwargs,
     ) -> ParallelModule:
         r"""
         Convert a huggingface layer `Conv1D` in gpt2 to a parallelized linear layer.
@@ -260,7 +310,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
         Args:
             module (`nn.Linear`): The module to be converted.
            process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication.
-            n_fused (int): The number of layers to be fused. In GPT2, Q,K,V are fused in one weight.
+            split_sizes (List[int]): The sizes of the split tensor. In GPT2, Q,K,V are fused in one weight.
         """
         LazyInitContext.materialize(module)
         # get the attributes
@@ -291,6 +341,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
             process_group=process_group,
             weight=module.weight,
             bias_=module.bias,
+            split_sizes=split_sizes,
             *args,
             **kwargs,
         )
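A migration sketch (not part of the diff) for callers of GPT2FusedLinearConv1D_Col.from_native_module: the old n_fused=3 argument becomes explicit per-projection sizes. tp_group is an assumed, already-created tensor-parallel ProcessGroup.

    from transformers.pytorch_utils import Conv1D

    hidden_size = 768                                   # hypothetical GPT-2 small width
    c_attn = Conv1D(3 * hidden_size, hidden_size)       # fused QKV projection
    parallel_c_attn = GPT2FusedLinearConv1D_Col.from_native_module(
        c_attn,
        process_group=tp_group,                         # assumed existing ProcessGroup
        split_sizes=[hidden_size] * 3,                  # replaces the old n_fused=3
    )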
@@ -313,7 +364,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):

         # Matrix multiply.
         bias = self.bias if not self.skip_bias_add else None
-        if self.seq_parallel_mode == "split_gather":
+        if is_share_sp_tp(self.seq_parallel_mode):
             input_parallel = input_
             output_parallel = matmul_gather_forward_reducescatter_backward(
                 input_parallel,
@@ -322,31 +373,18 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
                 self.process_group,
                 True,
                 1,
-                self.overlap,
-                fp8_communication=self.fp8_communication,
-            )
-        elif self.seq_parallel_mode == "ring":
-            input_parallel = input_
-            output_parallel = matmul_gather_forward_reducescatter_backward(
-                input_parallel,
-                self.weight,
-                bias,
-                self.process_group,
-                True,
-                1,
-                self.overlap,
-                True,
+                ring=self.seq_parallel_mode == "ring",
                 fp8_communication=self.fp8_communication,
             )
         elif self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn":
             # Set up backprop all-reduce.
-            input_parallel = reduce_backward(input_, self.process_group)
+            input_parallel = input_
             output_parallel = matmul_with_async_comm(
                 input_parallel,
                 self.weight,
                 bias,
                 self.process_group,
-                self.async_communication,
+                True,
                 fp8_communication=self.fp8_communication,
             )
         else:
@@ -354,9 +392,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):

         if self.gather_output:
             # All-gather across the partitions.
-            output = gather_forward_split_backward(
-                output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
-            )
+            output = gather_forward_split_backward_fused_qkv(output_parallel, self.split_sizes, self.process_group)
         else:
             output = output_parallel

@@ -565,7 +601,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
         if self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn":
             output_parallel = torch.matmul(input_, self.weight)
             output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
-        elif self.seq_parallel_mode == "split_gather":
+        elif is_share_sp_tp(self.seq_parallel_mode):
             output_parallel = torch.matmul(input_, self.weight)
             output = reducescatter_forward_gather_backward(
                 output_parallel,
@@ -573,13 +609,6 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
                 1,
                 self.fp8_communication,
             )
-        elif self.seq_parallel_mode == "ring":
-            output_parallel = torch.matmul(input_, self.weight)
-            output = reducescatter_forward_gather_backward(
-                output_parallel,
-                self.process_group,
-                1,
-            )
         else:
             raise NotImplementedError(f"seq_parallel_mode={self.seq_parallel_mode} is not supported!")

@@ -605,10 +634,10 @@ class FusedLinear1D_Col(ParallelModule):
     Args:
         in_features (int): size of each input sample.
         out_features (int): size of each output sample.
+        split_sizes (List[int]): The sizes of the split tensor.
         bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
         dtype (`torch.dtype`): The dtype of parameters, defaults to None.
         device (`torch.device`): The device of parameters, defaults to None.
-        n_fused (int): The number items fused, defaults to 3 (QKV).
         process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
         gather_output (bool, optional): If true, call all-gather on output and make Y available
             to all GPUs, otherwise, every GPU will have its output
@@ -628,14 +657,15 @@ class FusedLinear1D_Col(ParallelModule):
         self,
         in_features: int,
         out_features: int,
+        split_sizes: List[int],
         bias: bool = True,
         dtype: torch.dtype = None,
         device: torch.device = None,
         process_group: ProcessGroup = None,
-        async_communication: bool = False,
         gather_output: bool = False,
+        seq_parallel_mode: str = None,
+        seq_parallel_dim: int = 1,
         skip_bias_add: bool = False,
-        n_fused: int = 3,
         weight: Optional[Parameter] = None,
         bias_: Optional[Parameter] = None,
         weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
@@ -647,13 +677,18 @@ class FusedLinear1D_Col(ParallelModule):
         self.in_features = in_features
         self.out_features = out_features
         self.gather_output = gather_output
+        self.seq_parallel_mode = seq_parallel_mode
+        self.seq_parallel_dim = seq_parallel_dim
         self.skip_bias_add = skip_bias_add
         self.device = device
-        self.n_fused = n_fused
+        self.split_sizes = split_sizes
         self.process_group = process_group
-        self.async_communication = async_communication
         self.fp8_communication = fp8_communication

+        assert (
+            sum(split_sizes) == out_features
+        ), f"The sum of split_sizes({sum(split_sizes)}) should be equal to out_features({out_features})."
+
         if skip_bias_add and not bias:
             raise ValueError("cannot skip bias addition if bias is None")

@@ -677,10 +712,10 @@ class FusedLinear1D_Col(ParallelModule):
             self.weight = weight

         def shard_fn(tensor):
-            return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False)
+            return split_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, False)

         def gather_fn(tensor):
-            return gather_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False)
+            return gather_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, False)

         if not is_customized_distributed_tensor(self.weight):
             with torch.no_grad():
@@ -706,7 +741,11 @@ class FusedLinear1D_Col(ParallelModule):

     @staticmethod
     def from_native_module(
-        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int, *args, **kwargs
+        module: nn.Module,
+        process_group: Union[ProcessGroup, List[ProcessGroup]],
+        split_sizes: List[int],
+        *args,
+        **kwargs,
     ) -> ParallelModule:
         r"""
         Convert a fused `torch.nn.linear` layer to a parallelized linear layer.
@@ -714,7 +753,7 @@ class FusedLinear1D_Col(ParallelModule):
         Args:
             module (`nn.Linear`): The module to be converted.
             process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication.
-            n_fused (int): The number of layers to be fused. In common, Q,K,V are fused in one weight.
+            split_sizes (List[int]): The sizes of the split tensor. In common, Q,K,V are fused in one weight.
         """
         LazyInitContext.materialize(module)

@@ -737,25 +776,11 @@ class FusedLinear1D_Col(ParallelModule):
             process_group=process_group,
             weight=module.weight,
             bias_=module.bias,
-            n_fused=n_fused,
+            split_sizes=split_sizes,
             *args,
             **kwargs,
         )

-        # # TODO: copy the sharded weights
-        # with torch.no_grad():
-        #     sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data,
-        #                                                    n_fused=n_fused,
-        #                                                    process_group=process_group,
-        #                                                    is_transposed=False)
-        #     linear_1d.weight.data.copy_(sharded_weight.data)
-
-        # if bias:
-        #     sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data,
-        #                                                  n_fused=n_fused,
-        #                                                  process_group=process_group,
-        #                                                  is_transposed=False)
-        #     linear_1d.bias.data.copy_(sharded_bias.data)
         return linear_1d

     def reset_parameters(self, weight_initializer, bias_initializer) -> None:
@@ -772,19 +797,29 @@ class FusedLinear1D_Col(ParallelModule):
             input_.shape, self.weight.shape, self.weight.shape[-1]
         )
         # Set up backprop all-reduce.
-        # input_parallel = reduce_backward(input_, self.process_group)
         input_parallel = input_

         # Matrix multiply.
         bias = self.bias if not self.skip_bias_add else None

-        output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
+        if is_share_sp_tp(self.seq_parallel_mode):
+            output_parallel = linear_gather_forward_reducescatter_backward(
+                input_parallel,
+                self.weight,
+                bias,
+                self.process_group,
+                True,
+                self.seq_parallel_dim,
+                ring=self.seq_parallel_mode == "ring",
+            )
+        else:
+            output_parallel = linear_with_async_comm(
+                input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication
+            )

         if self.gather_output:
             # All-gather across the partitions.
-            output = gather_forward_split_backward(
-                output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
-            )
+            output = gather_forward_split_backward_fused_qkv(output_parallel, self.split_sizes, self.process_group)
         else:
             output = output_parallel

@@ -792,3 +827,196 @@ class FusedLinear1D_Col(ParallelModule):
             return output, self.bias
         else:
             return output
+
+
+class FusedLinear1D_Row(ParallelModule):
+    r"""Linear layer with row parallelism
+
+    Args:
+        in_features (int): size of each input sample.
+        out_features (int): size of each output sample.
+        bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
+        dtype (`torch.dtype`): The dtype of parameters, defaults to None.
+        parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False.
+        process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
+        seq_parallel_mode (`str`): The type of sp mode, it will use sequence parallel when `seq_parallel_mode` is not None. Defaults to None.
+        seq_parallel_dim (`int`): Which dim will sequence parallelism split and gather the sequence.
+        skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
+            which is preserved for kernel fusion, defaults to False
+        weight_initializer (:class:`typing.Callable`, optional):
+            The initializer of weight, defaults to kaiming uniform initializer.
+        bias_initializer (:class:`typing.Callable`, optional):
+            The initializer of bias, defaults to xavier uniform initializer.
+
+    More details about ``initializer`` please refer to
+    `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        split_sizes: List[int],
+        bias: bool = True,
+        dtype: torch.dtype = None,
+        device: torch.device = None,
+        process_group: ProcessGroup = None,
+        seq_parallel_mode: str = None,
+        seq_parallel_dim: int = 1,
+        parallel_input: bool = True,
+        skip_bias_add: bool = False,
+        weight: Optional[Parameter] = None,
+        bias_: Optional[Parameter] = None,
+        weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
+        bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
+        fp8_communication: bool = False,
+    ):
+        super().__init__()
+        # Keep input parameters
+        self.in_features = in_features
+        self.out_features = out_features
+        self.split_sizes = split_sizes
+        self.parallel_input = parallel_input
+        self.skip_bias_add = skip_bias_add
+        self.process_group = process_group
+        self.seq_parallel_mode = seq_parallel_mode
+        self.seq_parallel_dim = seq_parallel_dim
+        self.num_partitions = dist.get_world_size(self.process_group)
+        self.fp8_communication = fp8_communication
+
+        assert (
+            sum(split_sizes) == in_features
+        ), f"The sum of split_sizes({sum(split_sizes)}) should be equal to in_features({in_features})."
+
+        if skip_bias_add and not bias:
+            raise ValueError("cannot skip bias addition if bias is None")
+
+        # offset the seed with randomizer index and rank
+        seed = torch.random.initial_seed()
+        self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
+
+        # sanity check
+        if weight is not None:
+            assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None"
+        else:
+            assert bias_ is None, "bias_ must be None if weight is None"
+
+        # Parameters.
+        if weight is None:
+            # Initialize weight.
+            factory_kwargs = {"device": device, "dtype": dtype}
+            self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
+        else:
+            weight.data = weight.data.to(device=device, dtype=dtype)
+            self.weight = weight
+
+        def shard_fn(tensor):
+            return split_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, True)
+
+        def gather_fn(tensor):
+            return gather_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, True)
+
+        if not is_customized_distributed_tensor(self.weight):
+            with torch.no_grad():
+                sharded_weight = distribute_tensor_with_customization(self.weight.data, shard_fn, gather_fn)
+            customized_distributed_tensor_to_existing_param(sharded_weight, self.weight)
+
+        if bias:
+            if bias_ is None:
+                self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
+            else:
+                bias_.data = bias_.data.to(device=device, dtype=dtype)
+                self.bias = bias_
+        else:
+            self.bias = None
+
+        if weight is None:
+            with self.randomizer.fork_rng(enable_cpu=True):
+                self.reset_parameters(weight_initializer, bias_initializer)
+
+    @staticmethod
+    def from_native_module(
+        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], split_sizes: List[int], **kwargs
+    ) -> ParallelModule:
+        r"""
+        Convert a native PyTorch linear layer to a parallelized linear layer.
+        """
+        LazyInitContext.materialize(module)
+        # get the attributes
+        in_features = module.in_features
+        out_features = module.out_features
+        bias = module.bias is not None
+        device = module.weight.device
+
+        # ensure only one process group is passed
+        if isinstance(process_group, (list, tuple)):
+            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
+            process_group = process_group[0]
+
+        linear_1d = FusedLinear1D_Row(
+            in_features=in_features,
+            out_features=out_features,
+            bias=bias,
+            device=device,
+            process_group=process_group,
+            weight=module.weight,
+            bias_=module.bias,
+            split_sizes=split_sizes,
+            **kwargs,
+        )
+
+        return linear_1d
+
+    @torch.no_grad()
+    def reset_parameters(self, weight_initializer, bias_initializer) -> None:
+        fan_in, fan_out = self.in_features, self.out_features
+        weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
+
+        if self.bias is not None:
+            bias_initializer(self.bias, fan_in=fan_in)
+            if self.process_group is None:
+                src_rank = 0
+            else:
+                src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0)
+
+            origin_device = self.bias.device
+            bias = self.bias.cuda()
+            dist.broadcast(bias, src=src_rank, group=self.process_group)
+            bias = bias.to(origin_device)
+            self.bias.copy_(bias)
+
+    def forward(self, input_: Tensor) -> Tensor:
+        # Set up backprop all-reduce.
+        if self.parallel_input:
+            assert (
+                input_.shape[-1] == self.weight.shape[-1]
+            ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
+                input_.shape, self.weight.shape, self.weight.shape[-1]
+            )
+            input_ = input_
+        else:
+            assert (
+                divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1]
+            ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
+                input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions
+            )
+            input_ = split_forward_gather_backward_fused_qkv(input_, self.split_sizes, self.process_group)
+
+        if is_share_sp_tp(self.seq_parallel_mode):
+            output = linear_reducescatter_forward_gather_backward(
+                input_,
+                self.weight,
+                process_group=self.process_group,
+                dim=self.seq_parallel_dim,
+                ring=self.seq_parallel_mode == "ring",
+            )
+        else:
+            output_parallel = F.linear(input_, self.weight)
+            output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
+
+        if not self.skip_bias_add:
+            if self.bias is not None:
+                output = output + self.bias
+            return output
+        else:
+            return output, self.bias
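A hedged sketch of converting a fused row-parallel projection with the new FusedLinear1D_Row (assuming an initialized tensor-parallel group tp_group; the sizes are illustrative and must sum to in_features):

    import torch.nn as nn

    hidden_size = 768
    fused_in_proj = nn.Linear(3 * hidden_size, hidden_size)    # input is the concatenated Q/K/V width
    row_parallel = FusedLinear1D_Row.from_native_module(
        fused_in_proj,
        process_group=tp_group,                 # assumed existing ProcessGroup
        split_sizes=[hidden_size] * 3,          # must sum to in_features
    )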
@@ -295,8 +295,8 @@ def split_batch_zigzag(
     batch: Union[torch.Tensor, List[torch.Tensor]], sp_group: ProcessGroup, seq_dim: int = 1, is_label: bool = False
 ) -> Union[torch.Tensor, List[torch.Tensor]]:
     """
-    Split the input along the sequence dimension for Ring Attention. Naively spliting the attention mask
-    in the causal setting will result in the preceding ranks having much less workload.
+    Split the input sequence batch. Naively spliting the attention mask in the causal setting
+    will result in the preceding ranks having much less workload.
     We split after "folding" the 2D attention mask in half (https://github.com/zhuzilin/ring-flash-attention/issues/2).
     For example, for sp_size = 4 and seq_len = 8, we get | s0, s7 | s1, s6 | s2, s5 | s3, s4 |.
@@ -346,40 +346,42 @@ def split_varlen_zigzag(
     cu_seqlens: torch.Tensor,
     sp_group: ProcessGroup,
     max_seqlen: int = 0,
-    is_2d: bool = False,
+    is_batched_seq: bool = False,
     is_label: bool = False,
 ) -> Union[List[torch.Tensor], torch.Tensor]:
-    """Split each sequence in a batch of packed sequences in a zigzag fashion.
-    For each tensor in batch, return packed sequences if is_2d is False;
-    else return a padded batch of sequences.
+    """Split a packed seq/batch of padded sequences in a Zigzag fashion.
+    Different from split_batch_zigzag, inputs here have variable sequence lengths.

     Args:
-        batch (List[torch.Tensor]): Packed sequences of shape (B * Sq, ...), or (B, Sq, ...) if is_2d.
+        batch (List[torch.Tensor]): Packed sequences of shape (T, ...), or (B, Sq, ...) if is_batched_seq,
+            where T is the total number of tokens.
         cu_seqlens (torch.Tensor): Cumulative sequence lengths of shape (B + 1) before splitting.
         sp_group (ProcessGroup): The process group for sequence parallelism.
         max_seqlen (int): The maximum sequence length in the batch before splitting.
-        is_2d (bool): If True, then input has batch size and sequence length split into two dimensions.
+        is_batched_seq (bool): If True, then the input is a batch of sequences padded to the same len.
         is_label (bool): If True, mask out the first token in each sequence (<Start of Sentence>).

     Returns:
-        batch (List[torch.Tensor]): Packed sequences of shape (B * max_seqlen // sp_size)
-            or (B, max_seqlen // sp_size, ...) if is_2d
+        batch (List[torch.Tensor]): Packed sequences of shape (T, ..)
+            or (B, max_seqlen // sp_size, ...) if is_batched_seq
     """
     sp_size = dist.get_world_size(sp_group)
     sp_rank = dist.get_rank(sp_group)
     if sp_size == 1:
         return batch

-    if is_2d:
+    if is_batched_seq:
         assert max_seqlen > 0, "max_seqlen must be provided for 2D input"

     if isinstance(batch, torch.Tensor):
         batch = [batch]
+    # seq: (B, Sq, h, n)
+    # seq = seq[:, :rank * (seqlen // sp_size), ...]

     for i, packed_seq in enumerate(batch):
         device = packed_seq.device
         dtype = packed_seq.dtype

-        if is_2d:
+        if is_batched_seq:
             assert max_seqlen % (sp_size * 2) == 0
             # Recreate a padded tensor with the new max seqlen
             shape = (packed_seq.shape[0], max_seqlen // sp_size, *packed_seq.shape[2:])
@@ -398,7 +400,7 @@ def split_varlen_zigzag(
                 seqlen % (2 * sp_size) == 0
             ), f"batch {i} seq {j}'s length ({seqlen}) must be divisible by 2 * sp_size = {2 * sp_size} for splitting"

-            if is_2d:
+            if is_batched_seq:
                 seq = packed_seq[j][:seqlen]
                 if is_label:
                     # Shift one position to the right for next token prediction
@@ -415,7 +417,7 @@ def split_varlen_zigzag(
             seq = seq.chunk(sp_size * 2)
             local_seq.extend([seq[sp_rank], seq[2 * sp_size - 1 - sp_rank]])

-        if is_2d:
+        if is_batched_seq:
             batch[i] = local_seq.contiguous()
         else:
             batch[i] = torch.cat(local_seq, dim=0)
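An illustrative call of split_varlen_zigzag with the renamed is_batched_seq flag (values are made up; each sequence length must be divisible by 2 * sp_size, and a sequence-parallel process group is assumed to exist):

    import torch
    import torch.distributed as dist

    tokens = torch.arange(16).unsqueeze(-1)      # packed sequences, T = 16 tokens in total
    cu_seqlens = torch.tensor([0, 8, 16])        # two sequences of length 8
    sp_group = dist.group.WORLD                  # assumed sequence-parallel group
    local = split_varlen_zigzag(tokens, cu_seqlens, sp_group, is_batched_seq=False)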
@@ -857,17 +857,17 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None)
         dropout_p = self.attn_dropout.p if self.training else 0.0

         sp_mode = shard_config.sequence_parallelism_mode
-        sp_group = shard_config.sequence_parallel_process_group
         if sp_mode == "ring_attn":
             attn_output = RingAttention.attention(
                 query,
                 key,
                 value,
-                sp_group,
+                sp_axis=shard_config.sp_axis,
                 **attention_mask,
                 dropout_p=dropout_p,
                 scale=scale,
                 inner_ring_size=shard_config.inner_ring_size,
+                pg_mesh=shard_config.pg_mesh,
             )
         else:
             attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)
@@ -271,6 +271,7 @@ class LlamaPipelineForwards:
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
         shard_config: ShardConfig = None,
+        **kwargs,
     ):
         r"""
         Args:
@@ -568,9 +569,10 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
             query_states,
             key_states,
             value_states,
-            sp_group,
+            sp_axis=shard_config.sp_axis,
             **attention_mask,
             inner_ring_size=shard_config.inner_ring_size,
+            pg_mesh=shard_config.pg_mesh,
         )

     elif shard_config.enable_flash_attention:
@ -73,7 +73,6 @@ class BertPolicy(Policy):
|
||||||
)
|
)
|
||||||
sp_mode = "split_gather"
|
sp_mode = "split_gather"
|
||||||
|
|
||||||
overlap = self.shard_config.enable_sequence_overlap
|
|
||||||
sp_partial_derived = sp_mode == "split_gather"
|
sp_partial_derived = sp_mode == "split_gather"
|
||||||
|
|
||||||
if self.shard_config.enable_tensor_parallelism:
|
if self.shard_config.enable_tensor_parallelism:
|
||||||
|
@ -97,7 +96,6 @@ class BertPolicy(Policy):
|
||||||
target_module=col_nn.Linear1D_Col,
|
target_module=col_nn.Linear1D_Col,
|
||||||
kwargs={
|
kwargs={
|
||||||
"seq_parallel_mode": sp_mode,
|
"seq_parallel_mode": sp_mode,
|
||||||
"overlap": overlap,
|
|
||||||
"fp8_communication": self.shard_config.fp8_communication,
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
|
@ -106,7 +104,6 @@ class BertPolicy(Policy):
|
||||||
target_module=col_nn.Linear1D_Col,
|
target_module=col_nn.Linear1D_Col,
|
||||||
kwargs={
|
kwargs={
|
||||||
"seq_parallel_mode": sp_mode,
|
"seq_parallel_mode": sp_mode,
|
||||||
"overlap": overlap,
|
|
||||||
"fp8_communication": self.shard_config.fp8_communication,
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
|
@ -115,7 +112,6 @@ class BertPolicy(Policy):
|
||||||
target_module=col_nn.Linear1D_Col,
|
target_module=col_nn.Linear1D_Col,
|
||||||
kwargs={
|
kwargs={
|
||||||
"seq_parallel_mode": sp_mode,
|
"seq_parallel_mode": sp_mode,
|
||||||
"overlap": overlap,
|
|
||||||
"fp8_communication": self.shard_config.fp8_communication,
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
|
@ -140,7 +136,6 @@ class BertPolicy(Policy):
|
||||||
target_module=col_nn.Linear1D_Col,
|
target_module=col_nn.Linear1D_Col,
|
||||||
kwargs={
|
kwargs={
|
||||||
"seq_parallel_mode": sp_mode,
|
"seq_parallel_mode": sp_mode,
|
||||||
"overlap": overlap,
|
|
||||||
"skip_bias_add": self.enable_bias_gelu_fused,
|
"skip_bias_add": self.enable_bias_gelu_fused,
|
||||||
"fp8_communication": self.shard_config.fp8_communication,
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
},
|
},
|
||||||
|
|
|
@@ -71,7 +71,7 @@ class BlipPolicy(Policy):
                    suffix="self_attn.qkv",
                    target_module=col_nn.FusedLinear1D_Col,
                    kwargs={
-                        "n_fused": 3,
+                        "split_sizes": [self.model.config.vision_config.hidden_size] * 3,
                        "fp8_communication": self.shard_config.fp8_communication,
                    },
                ),
@@ -57,7 +57,6 @@ class BloomPolicy(Policy):
            )
            sp_mode = "split_gather"

-        overlap = self.shard_config.enable_sequence_overlap
        sp_partial_derived = sp_mode == "split_gather"

        if self.shard_config.enable_tensor_parallelism:
@@ -78,7 +77,6 @@ class BloomPolicy(Policy):
                    target_module=col_nn.Linear1D_Col,
                    kwargs={
                        "seq_parallel_mode": sp_mode,
-                        "overlap": overlap,
                        "fp8_communication": self.shard_config.fp8_communication,
                    },
                ),
@@ -99,7 +97,6 @@ class BloomPolicy(Policy):
                    target_module=col_nn.Linear1D_Col,
                    kwargs={
                        "seq_parallel_mode": sp_mode,
-                        "overlap": overlap,
                        "fp8_communication": self.shard_config.fp8_communication,
                    },
                ),
@@ -67,7 +67,6 @@ class ChatGLMPolicy(Policy):
                f"For ChatGLM2, sequence parallelism doesn't support mode {sp_mode} yet, will set to be split_gather"
            )
            sp_mode = "split_gather"
-            overlap = self.shard_config.enable_sequence_overlap
        sp_partial_derived = sp_mode in ["split_gather"]

        if sp_mode == "all_to_all":
@@ -127,7 +126,6 @@ class ChatGLMPolicy(Policy):
                    kwargs={
                        "seq_parallel_mode": sp_mode,
                        "seq_parallel_dim": 0,
-                        "overlap": overlap,
                        "fp8_communication": self.shard_config.fp8_communication,
                    },
                ),
@@ -65,7 +65,6 @@ class GPT2Policy(Policy):
                f"For GPT2, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather"
            )
            self.shard_config.sequence_parallelism_mode = sp_mode = "split_gather"
-        overlap = self.shard_config.enable_sequence_overlap
        sp_partial_derived = sp_mode in ["split_gather", "ring"]
        use_flash_attention = self.shard_config.enable_flash_attention
        if self.shard_config.enable_tensor_parallelism:
@@ -92,9 +91,8 @@ class GPT2Policy(Policy):
                    suffix="attn.c_attn",
                    target_module=col_nn.GPT2FusedLinearConv1D_Col,
                    kwargs={
-                        "n_fused": 3,
+                        "split_sizes": [self.model.config.hidden_size] * 3,
                        "seq_parallel_mode": sp_mode,
-                        "overlap": overlap,
                        "fp8_communication": self.shard_config.fp8_communication,
                    },
                ),
@@ -107,9 +105,8 @@ class GPT2Policy(Policy):
                    suffix="mlp.c_fc",
                    target_module=col_nn.GPT2FusedLinearConv1D_Col,
                    kwargs={
-                        "n_fused": 1,
+                        "split_sizes": [self.model.config.n_inner or 4 * self.model.config.hidden_size],
                        "seq_parallel_mode": sp_mode,
-                        "overlap": overlap,
                        "skip_bias_add": self.enable_bias_gelu_fused,
                        "fp8_communication": self.shard_config.fp8_communication,
                    },
@@ -51,7 +51,6 @@ class GPTJPolicy(Policy):
            self.shard_config.enable_sequence_parallelism = False
            warnings.warn("GPTJ doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")

-        overlap = self.shard_config.enable_sequence_overlap
        if self.shard_config.enable_tensor_parallelism:
            assert (
                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
@@ -76,7 +75,6 @@ class GPTJPolicy(Policy):
                    suffix="attn.k_proj",
                    target_module=col_nn.Linear1D_Col,
                    kwargs={
-                        "overlap": overlap,
                        "fp8_communication": self.shard_config.fp8_communication,
                    },
                ),
@@ -84,7 +82,6 @@ class GPTJPolicy(Policy):
                    suffix="attn.q_proj",
                    target_module=col_nn.Linear1D_Col,
                    kwargs={
-                        "overlap": overlap,
                        "fp8_communication": self.shard_config.fp8_communication,
                    },
                ),
@@ -92,7 +89,6 @@ class GPTJPolicy(Policy):
                    suffix="attn.v_proj",
                    target_module=col_nn.Linear1D_Col,
                    kwargs={
-                        "overlap": overlap,
                        "fp8_communication": self.shard_config.fp8_communication,
                    },
                ),
@@ -42,7 +42,7 @@ class SamPolicy(Policy):
                    suffix="attn.qkv",
                    target_module=col_nn.FusedLinear1D_Col,
                    kwargs={
-                        "n_fused": 3,
+                        "split_sizes": [self.model.config.vision_config.hidden_size] * 3,
                        "fp8_communication": self.shard_config.fp8_communication,
                    },
                ),
@@ -26,7 +26,6 @@ class ShardConfig:
        enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False.
        enable_jit_fused (bool, optional): Whether to switch on JIT fused operators. Defaults to False.
        enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False.
-        enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.
        gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None.
        enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
        fp8_communication (bool, optional): Whether to enable fp8 communication in model parallelism. Defaults to False.
@@ -44,13 +43,14 @@ class ShardConfig:
    enable_jit_fused: bool = False
    enable_sequence_parallelism: bool = False
    sequence_parallelism_mode: str = None
-    enable_sequence_overlap: bool = False
    parallel_output: bool = True
    make_vocab_size_divisible_by: int = 64
    gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
    extra_kwargs: Dict[str, Any] = field(default_factory=dict)

    # For ring attention
+    sp_axis: Optional[int] = None
+    pg_mesh: Optional[int] = None
    inner_ring_size: Optional[int] = None
    # for moe related
    moe_dp_group: Optional[ProcessGroup] = None
@@ -84,24 +84,12 @@ class ShardConfig:
            assert (
                self.enable_tensor_parallelism
            ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is True"
-        elif self.sequence_parallelism_mode in ["all_to_all"]:
-            # assert (
-            #     not self.enable_tensor_parallelism
-            # ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is False"
-            if self.enable_sequence_overlap:
-                self.enable_sequence_overlap = False
-                warnings.warn(
-                    f"The enable_sequence_overlap flag will be ignored in sequence parallelism mode {self.sequence_parallelism_mode}"
-                )
        else:
            if self.sequence_parallelism_mode:
                self.sequence_parallelism_mode = None
                warnings.warn(
                    f"The sequence_parallelism_mode will be ignored when enable_sequence_parallelism is False"
                )
-            assert (
-                not self.enable_sequence_overlap
-            ), f"enable_sequence_overlap can only be set to True when enable_sequence_parallelism is True"

        # get the tensor parallel size
        if not self.enable_tensor_parallelism:
@@ -134,4 +122,3 @@ class ShardConfig:
            # This can cause non-in-place param sharding when used without ZeRO.
            # It may also slow down training when seq len is small. Plz enable manually.
            # self.enable_sequence_parallelism = True
-            # self.enable_sequence_overlap = True
@@ -5,6 +5,7 @@ from .common import (
    ensure_path_exists,
    free_storage,
    get_current_device,
+    get_non_persistent_buffers_set,
    is_ddp_ignored,
    set_seed,
 )
@@ -25,4 +26,5 @@ __all__ = [
    "set_seed",
    "get_current_device",
    "is_ddp_ignored",
+    "get_non_persistent_buffers_set",
 ]
@@ -5,10 +5,11 @@ import os
 import random
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Callable
+from typing import Callable, Optional, Set

 import numpy as np
 import torch
+import torch.nn as nn

 from colossalai.accelerator import get_accelerator

@@ -76,3 +77,34 @@ def set_seed(seed):
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
+
+
+def get_non_persistent_buffers_set(
+    module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
+):
+    r"""
+    Args:
+        memo: a memo to store the set of modules already added to the result
+        prefix: a prefix that will be added to the name of the module
+        remove_duplicate: whether to remove the duplicated module instances in the result
+            or not
+    """
+
+    if memo is None:
+        memo = set()
+    self_non_persistent_set = set()
+    if module not in memo:
+        if remove_duplicate:
+            memo.add(module)
+        self_non_persistent_set = set(
+            map(lambda key: prefix + ("." if prefix else "") + key, module._non_persistent_buffers_set)
+        )
+    for name, sub_module in module._modules.items():
+        if sub_module is None:
+            continue
+        submodule_prefix = prefix + ("." if prefix else "") + name
+        child_non_persistent_set = get_non_persistent_buffers_set(
+            sub_module, memo, submodule_prefix, remove_duplicate
+        )
+        self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
+    return self_non_persistent_set
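A minimal sketch of the relocated helper (assuming it is imported from colossalai.utils, per the __init__ hunk above): it walks the module tree and returns the fully qualified names of buffers registered with persistent=False.

    import torch
    import torch.nn as nn

    class TinyNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("running_stat", torch.zeros(4), persistent=False)

    names = get_non_persistent_buffers_set(TinyNet())
    print(names)    # {'running_stat'}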
@@ -0,0 +1,64 @@
+# a python safetensors serializer modified from https://github.com/huggingface/safetensors/blob/41bd1acf38ad28ac559522d40596c6c802f79453/safetensors/src/tensor.rs#L214
+import json
+from dataclasses import asdict, dataclass
+from typing import Dict, List, Tuple
+
+import torch
+from safetensors.torch import _TYPES
+
+try:
+    from tensornvme.async_file_io import AsyncFileWriter
+except ModuleNotFoundError:
+    raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")
+_TYPES_INV = {v: k for k, v in _TYPES.items()}
+
+
+@dataclass
+class TensorInfo:
+    dtype: str
+    shape: List[int]
+    data_offsets: Tuple[int, int]
+
+
+@dataclass
+class PreparedData:
+    n: int
+    header_bytes: bytes
+    offset: int
+
+
+def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Tensor]]:
+    sorted_data = sorted(data.items(), key=lambda x: (x[1].dtype, x[0]))
+
+    tensors = []
+    metadata = {}
+    offset = 0
+
+    for name, tensor in sorted_data:
+        n = tensor.numel() * tensor.element_size()
+        tensor_info = TensorInfo(
+            dtype=_TYPES_INV[tensor.dtype], shape=list(tensor.shape), data_offsets=(offset, offset + n)
+        )
+        offset += n
+        metadata[name] = asdict(tensor_info)
+        tensors.append(tensor)
+
+    metadata_buf = json.dumps(metadata).encode("utf-8")
+
+    extra = (8 - len(metadata_buf) % 8) % 8
+    metadata_buf += b" " * extra
+
+    n = len(metadata_buf)
+
+    return PreparedData(n=n, header_bytes=metadata_buf, offset=offset), tensors
+
+
+def save(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:
+    prepared_data, tensors = prepare(state_dict)
+    n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
+
+    f_writer.write(n.to_bytes(8, byteorder="little"))
+    f_writer.write(header_bytes)
+
+    for tensor in tensors:
+        f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset)

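Also for reference, a small sketch of what `prepare` above produces; the serializer's module path is not shown in this diff, so the import below is hypothetical:

```python
import torch

from colossalai.utils.safetensors import prepare  # hypothetical path; adjust to wherever the file above lives

state_dict = {"weight": torch.randn(4, 4), "bias": torch.randn(4)}
prepared, tensors = prepare(state_dict)

print(prepared.n % 8 == 0)         # True: the JSON header is padded to an 8-byte boundary
print(prepared.offset)             # 80: total payload bytes (4*4*4 + 4*4 for float32)
print([t.shape for t in tensors])  # tensors come back sorted by (dtype, name): bias first, then weight
```

`save` then writes the 8-byte little-endian header length, the JSON header, and each tensor's raw bytes through the `AsyncFileWriter`.
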
@@ -35,7 +35,7 @@ from colossalai.tensor.padded_tensor import (
     to_unpadded_tensor,
 )
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
-from colossalai.utils import _cast_float, free_storage, is_ddp_ignored
+from colossalai.utils import _cast_float, free_storage, get_non_persistent_buffers_set, is_ddp_ignored
 
 from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
 from .gemini_hook import GeminiZeROHook
@@ -187,7 +187,7 @@ class GeminiDDP(ModelWrapper):
             pin_memory=pin_memory,
         )
         super().__init__(module)
-        self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module)
+        self._non_persistent_buffers_set = get_non_persistent_buffers_set(module)
         self._cast_buffers()
 
         # register grad hook
@@ -257,36 +257,6 @@ class GeminiDDP(ModelWrapper):
         for p in params_to_ignore:
             p._ddp_to_ignore = True
 
-    def _get_non_persistent_buffers_set(
-        self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
-    ):
-        r"""
-        Args:
-            memo: a memo to store the set of modules already added to the result
-            prefix: a prefix that will be added to the name of the module
-            remove_duplicate: whether to remove the duplicated module instances in the result
-                or not
-        """
-        if memo is None:
-            memo = set()
-        self_non_persistent_set = set()
-        if module not in memo:
-            if remove_duplicate:
-                memo.add(module)
-            self_non_persistent_set = set(
-                map(lambda key: prefix + ("." if prefix else "") + key, module._non_persistent_buffers_set)
-            )
-            for name, sub_module in module._modules.items():
-                if sub_module is None:
-                    continue
-                submodule_prefix = prefix + ("." if prefix else "") + name
-                child_non_persistent_set = self._get_non_persistent_buffers_set(
-                    sub_module, memo, submodule_prefix, remove_duplicate
-                )
-                self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
-        return self_non_persistent_set
-
     def _post_forward(self):
         """This function is only triggered for inference."""
         access_list = list(self.chunk_manager.accessed_chunks)

@@ -1,10 +1,5 @@
 import torch.nn
 
-from colossalai.legacy.zero.gemini.ophooks.runtime_mem_tracer_hook import (
-    GradMemStats,
-    GradMemTracerHook,
-    ParamMemTracerHook,
-)
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.utils import _cast_float
 
@@ -27,6 +22,12 @@ class RuntimeMemTracer:
 
     def __init__(self, module: torch.nn.Module, dtype: torch.dtype = torch.half):
        super().__init__()
+        from colossalai.legacy.zero.gemini.ophooks.runtime_mem_tracer_hook import (
+            GradMemStats,
+            GradMemTracerHook,
+            ParamMemTracerHook,
+        )
+
        self.module = module
        self.dtype = dtype
        self._gradstat = GradMemStats()

@@ -8,7 +8,6 @@ import torch
 import torch.distributed as dist
 
 from colossalai.accelerator import get_accelerator
-from colossalai.legacy.utils.memory import colo_device_memory_capacity
 from colossalai.zero.gemini.chunk import Chunk
 
 from .chunk import Chunk, ChunkManager
@@ -172,6 +171,8 @@ class AutoPlacementPolicy(PlacementPolicy):
         Returns:
             int: the volume of memory that is evicted
         """
+        from colossalai.legacy.utils.memory import colo_device_memory_capacity
+
         start = time()
         cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device())
         used_cuda_model_data = self.chunk_manager.total_mem["cuda"]

@@ -25,15 +25,13 @@
 </div>
 
 ## 新闻
+* [2024/10] [How to build a low-cost Sora-like app? Solutions for you](https://company.hpc-ai.com/blog/how-to-build-a-low-cost-sora-like-app-solutions-for-you)
+* [2024/09] [Singapore Startup HPC-AI Tech Secures 50 Million USD in Series A Funding to Build the Video Generation AI Model and GPU Platform](https://company.hpc-ai.com/blog/singapore-startup-hpc-ai-tech-secures-50-million-usd-in-series-a-funding-to-build-the-video-generation-ai-model-and-gpu-platform)
+* [2024/09] [Reducing AI Large Model Training Costs by 30% Requires Just a Single Line of Code From FP8 Mixed Precision Training Upgrades](https://company.hpc-ai.com/blog/reducing-ai-large-model-training-costs-by-30-requires-just-a-single-line-of-code-from-fp8-mixed-precision-training-upgrades)
 * [2024/06] [Open-Sora Continues Open Source: Generate Any 16-Second 720p HD Video with One Click, Model Weights Ready to Use](https://hpc-ai.com/blog/open-sora-from-hpc-ai-tech-team-continues-open-source-generate-any-16-second-720p-hd-video-with-one-click-model-weights-ready-to-use)
 * [2024/05] [Large AI Models Inference Speed Doubled, Colossal-Inference Open Source Release](https://hpc-ai.com/blog/colossal-inference)
 * [2024/04] [Open-Sora Unveils Major Upgrade: Embracing Open Source with Single-Shot 16-Second Video Generation and 720p Resolution](https://hpc-ai.com/blog/open-soras-comprehensive-upgrade-unveiled-embracing-16-second-video-generation-and-720p-resolution-in-open-source)
 * [2024/04] [Most cost-effective solutions for inference, fine-tuning and pretraining, tailored to LLaMA3 series](https://hpc-ai.com/blog/most-cost-effective-solutions-for-inference-fine-tuning-and-pretraining-tailored-to-llama3-series)
-* [2024/03] [314 Billion Parameter Grok-1 Inference Accelerated by 3.8x, Efficient and Easy-to-Use PyTorch+HuggingFace version is Here](https://hpc-ai.com/blog/314-billion-parameter-grok-1-inference-accelerated-by-3.8x-efficient-and-easy-to-use-pytorchhuggingface-version-is-here)
-* [2024/03] [Open-Sora: Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models](https://hpc-ai.com/blog/open-sora-v1.0)
-* [2024/03] [Open-Sora: Sora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million](https://hpc-ai.com/blog/open-sora)
-* [2024/01] [Inference Performance Improved by 46%, Open Source Solution Breaks the Length Limit of LLM for Multi-Round Conversations](https://hpc-ai.com/blog/Colossal-AI-SwiftInfer)
-* [2023/07] [HPC-AI Tech Raises 22 Million USD in Series A Funding](https://www.hpc-ai.tech/blog/hpc-ai-tech-raises-22-million-usd-in-series-a-funding-to-fuel-team-expansion-and-business-growth)
 
 ## 目录
 <ul>

@@ -16,7 +16,7 @@ Author: [Mingyan Jiang](https://github.com/jiangmingyan)
 AMP stands for automatic mixed precision training.
 In Colossal-AI, we have incorporated different implementations of mixed precision training:
 
-1. torch.cuda.amp
+1. torch.amp
 2. apex.amp
 3. naive amp
@@ -16,7 +16,7 @@
 AMP 代表自动混合精度训练。
 在 Colossal-AI 中, 我们结合了混合精度训练的不同实现:
 
-1. torch.cuda.amp
+1. torch.amp
 2. apex.amp
 3. naive amp

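The two documentation hunks above only rename `torch.cuda.amp` to `torch.amp`. For context, a minimal sketch of the plain `torch.amp` loop this naming refers to (assumes a recent PyTorch where `GradScaler` is exposed under `torch.amp`):

```python
import torch
import torch.nn as nn

model = nn.Linear(32, 32).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.amp.GradScaler("cuda")  # formerly torch.cuda.amp.GradScaler

for _ in range(3):
    x = torch.randn(8, 32, device="cuda")
    optimizer.zero_grad()
    # run the forward pass in fp16 where it is numerically safe
    with torch.amp.autocast("cuda", dtype=torch.float16):
        loss = model(x).pow(2).mean()
    scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)
    scaler.update()
```
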
@@ -163,6 +163,8 @@ def main():
             enable_async_reduce=not args.disable_async_reduce,
             use_fp8=args.use_fp8,
             fp8_communication=args.use_fp8_comm,
+            use_fp8=args.use_fp8,
+            fp8_communication=args.use_fp8_comm,
         )
     elif args.plugin == "gemini_auto":
         plugin = GeminiPlugin(
@@ -177,6 +179,8 @@ def main():
             enable_flash_attention=args.xformers,
             use_fp8=args.use_fp8,
             fp8_communication=args.use_fp8_comm,
+            use_fp8=args.use_fp8,
+            fp8_communication=args.use_fp8_comm,
         )
     elif args.plugin == "fsdp":
         if use_empty_init:
@@ -188,6 +192,7 @@ def main():
                 ),
                 param_init_fn=empty_init(),
                 fp8_communication=args.use_fp8_comm,
+                fp8_communication=args.use_fp8_comm,
             )
         else:
             plugin = TorchFSDPPlugin(
@@ -209,6 +214,7 @@ def main():
                 cpu_offload=CPUOffload(offload_params=True),
                 param_init_fn=empty_init(),
                 fp8_communication=args.use_fp8_comm,
+                fp8_communication=args.use_fp8_comm,
             )
         else:
             plugin = TorchFSDPPlugin(
@@ -219,6 +225,7 @@ def main():
                 ),
                 cpu_offload=CPUOffload(offload_params=True),
                 fp8_communication=args.use_fp8_comm,
+                fp8_communication=args.use_fp8_comm,
             )
     elif args.plugin == "3d":
         if args.pp_style == "zbv":

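The benchmark hunks above wire `--use_fp8` and `--use_fp8_comm` into the plugins. A minimal sketch of what those two flags end up toggling on `GeminiPlugin` (other arguments left at their defaults; run under `torchrun` so `launch_from_torch` can read the distributed environment):

```python
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

colossalai.launch_from_torch()  # expects torchrun-style environment variables

plugin = GeminiPlugin(
    use_fp8=True,            # fp8 mixed-precision compute, as passed via use_fp8=args.use_fp8 above
    fp8_communication=True,  # fp8 payloads for collectives, as passed via fp8_communication=args.use_fp8_comm above
)
booster = Booster(plugin=plugin)
```
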
@@ -79,7 +79,7 @@ class _CppExtension(_Extension):
 
         # check if the kernel has been built
         compiled_before = False
-        kernel_file_path = build_directory.joinpath(f"{self.name}.o")
+        kernel_file_path = build_directory.joinpath(f"{self.name}.so")
         if kernel_file_path.exists():
             compiled_before = True
 
@@ -74,7 +74,7 @@ class _CudaExtension(_CppExtension):
 
         # check if the kernel has been built
         compiled_before = False
-        kernel_file_path = build_directory.joinpath(f"{self.name}.o")
+        kernel_file_path = build_directory.joinpath(f"{self.name}.so")
         if kernel_file_path.exists():
             compiled_before = True
 
|
@ -41,22 +41,7 @@ class Conv1D(nn.Module):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def rearrange(tensor: torch.Tensor, dim: int):
|
def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str):
|
||||||
tensor = tensor.clone()
|
|
||||||
world_size = 2
|
|
||||||
order = torch.arange(world_size * 3)
|
|
||||||
new_order = []
|
|
||||||
for i in range(world_size):
|
|
||||||
new_order.append(order[i::world_size])
|
|
||||||
new_order = torch.cat(new_order)
|
|
||||||
|
|
||||||
tensor_chunks = torch.chunk(tensor, world_size * 3, dim=dim)
|
|
||||||
rearanged_tensor_chunks = [tensor_chunks[i] for i in new_order]
|
|
||||||
rearanged_tensor = torch.cat(rearanged_tensor_chunks, dim=dim)
|
|
||||||
return rearanged_tensor
|
|
||||||
|
|
||||||
|
|
||||||
def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: bool):
|
|
||||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||||
linear = Conv1D(192, 48).cuda()
|
linear = Conv1D(192, 48).cuda()
|
||||||
with ctx:
|
with ctx:
|
||||||
|
@ -66,8 +51,7 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: b
|
||||||
process_group=None,
|
process_group=None,
|
||||||
gather_output=True,
|
gather_output=True,
|
||||||
seq_parallel_mode=seq_parallel_mode,
|
seq_parallel_mode=seq_parallel_mode,
|
||||||
n_fused=3,
|
split_sizes=[64] * 3,
|
||||||
overlap=overlap,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert linear.weight.shape == torch.Size([48, 192])
|
assert linear.weight.shape == torch.Size([48, 192])
|
||||||
|
@ -88,13 +72,13 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: b
|
||||||
x.expand_as(x.clone()) if seq_parallel_mode is None else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
|
x.expand_as(x.clone()) if seq_parallel_mode is None else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
|
||||||
)
|
)
|
||||||
gather_out = linear_conv_col(x_for_shard)
|
gather_out = linear_conv_col(x_for_shard)
|
||||||
assert_close(rearrange(out, -1), gather_out)
|
assert_close(out, gather_out)
|
||||||
|
|
||||||
# check backward correctness
|
# check backward correctness
|
||||||
out.sum().backward()
|
out.sum().backward()
|
||||||
gather_out.sum().backward()
|
gather_out.sum().backward()
|
||||||
|
|
||||||
target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, 3, None, True)
|
target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, [64] * 3, None, True)
|
||||||
assert_close(target_grad, linear_conv_col.weight.grad)
|
assert_close(target_grad, linear_conv_col.weight.grad)
|
||||||
|
|
||||||
|
|
||||||
|
@ -136,9 +120,8 @@ def check_linear_conv_1d_row(lazy_init: bool, seq_parallel_mode: bool):
|
||||||
|
|
||||||
@parameterize("lazy_init", [False, True])
|
@parameterize("lazy_init", [False, True])
|
||||||
@parameterize("seq_parallel_mode", ["split_gather", None])
|
@parameterize("seq_parallel_mode", ["split_gather", None])
|
||||||
@parameterize("overlap", [True])
|
def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool):
|
||||||
def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool, overlap: bool):
|
check_linear_conv_1d_col(lazy_init, seq_parallel_mode)
|
||||||
check_linear_conv_1d_col(lazy_init, seq_parallel_mode, overlap)
|
|
||||||
check_linear_conv_1d_row(lazy_init, seq_parallel_mode)
|
check_linear_conv_1d_row(lazy_init, seq_parallel_mode)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@@ -2,13 +2,12 @@ import os
 from contextlib import nullcontext
 
 import torch
-import torch.distributed as dist
 import torch.nn as nn
 from torch.testing import assert_close
 
 import colossalai
 from colossalai.lazy import LazyInitContext
-from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
+from colossalai.shardformer.layer import FusedLinear1D_Col, FusedLinear1D_Row
 from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 
@@ -16,93 +15,55 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
 
 
-class Conv1D(nn.Module):
-    """
-    1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
-
-    Basically works like a linear layer but the weights are transposed.
-
-    Args:
-        nf (`int`): The number of output features.
-        nx (`int`): The number of input features.
-    """
-
-    def __init__(self, nf, nx):
-        super().__init__()
-        self.nf = nf
-        self.weight = nn.Parameter(torch.empty(nx, nf))
-        self.bias = nn.Parameter(torch.zeros(nf))
-        nn.init.normal_(self.weight, std=0.02)
-
-    def forward(self, x):
-        size_out = x.size()[:-1] + (self.nf,)
-        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
-        x = x.view(size_out)
-        return x
-
-
-def rearrange(tensor: torch.Tensor, dim: int):
-    tensor = tensor.clone()
-    world_size = 2
-    order = torch.arange(world_size * 3)
-    new_order = []
-    for i in range(world_size):
-        new_order.append(order[i::world_size])
-    new_order = torch.cat(new_order)
-
-    tensor_chunks = torch.chunk(tensor, world_size * 3, dim=dim)
-    rearanged_tensor_chunks = [tensor_chunks[i] for i in new_order]
-    rearanged_tensor = torch.cat(rearanged_tensor_chunks, dim=dim)
-    return rearanged_tensor
-
-
 @parameterize("lazy_init", [False, True])
-def check_linear_conv_1d_col(lazy_init: bool):
+def check_linear_1d_col(lazy_init: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()
-    linear = Conv1D(192, 48).cuda()
+    linear = nn.Linear(8, 80).cuda()
     with ctx:
-        linear_copy = Conv1D(192, 48).cuda()
-    linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(
-        linear_copy, process_group=None, gather_output=True, n_fused=3
+        linear_copy = nn.Linear(8, 80).cuda()
+    linear_col = FusedLinear1D_Col.from_native_module(
+        linear_copy, process_group=None, gather_output=True, split_sizes=[32, 32, 16]
     )
 
-    assert linear.weight.shape == torch.Size([48, 192])
-    assert linear.bias.shape == torch.Size([192])
-    assert linear_conv_col.weight.shape == torch.Size([48, 96])
-    assert linear_conv_col.bias.shape == torch.Size([96])
-    assert linear_copy.weight is linear_conv_col.weight
-    assert linear_copy.bias is linear_conv_col.bias
+    assert linear.weight.shape == torch.Size([80, 8])
+    assert linear.bias.shape == torch.Size([80])
+    assert linear_col.weight.shape == torch.Size([40, 8])
+    assert linear_col.bias.shape == torch.Size([40])
+    assert linear_copy.weight is linear_col.weight
+    assert linear_copy.bias is linear_col.bias
 
     # ensure weights are reversibly loadable
-    linear_conv_col.load_state_dict(linear.state_dict())
-    linear.load_state_dict(linear_conv_col.state_dict())
+    linear_col.load_state_dict(linear.state_dict())
+    linear.load_state_dict(linear_col.state_dict())
 
     # check computation correctness
-    x = torch.rand(4, 48).cuda()
+    x = torch.rand(4, 8).cuda()
     out = linear(x)
-    gather_out = linear_conv_col(x)
-    assert_close(rearrange(out, 1), gather_out)
+    gather_out = linear_col(x)
+    assert_close(out, gather_out)
 
     # check backward correctness
     out.sum().backward()
     gather_out.sum().backward()
 
-    target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, 3, None, True)
-    assert_close(target_grad, linear_conv_col.weight.grad)
+    target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, [32, 32, 16], None, False)
+    assert_close(target_grad, linear_col.weight.grad)
 
 
 @parameterize("lazy_init", [False, True])
-def check_linear_conv_1d_row(lazy_init: bool):
+def check_linear_1d_row(lazy_init: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()
 
-    linear = Conv1D(192, 48).cuda()
+    linear = nn.Linear(80, 8).cuda()
     with ctx:
-        linear_copy = Conv1D(192, 48).cuda()
-    linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False)
+        linear_copy = nn.Linear(80, 8).cuda()
+    linear_row = FusedLinear1D_Row.from_native_module(
+        linear_copy, process_group=None, split_sizes=[32, 32, 16], parallel_input=False
+    )
 
-    assert linear.weight.shape == torch.Size([48, 192])
-    assert linear_row.weight.shape == torch.Size([24, 192])
-    assert linear_row.bias.shape == torch.Size([192])
+    assert linear.weight.shape == torch.Size([8, 80])
+    assert linear_row.weight.shape == torch.Size([8, 40])
+    assert linear_row.bias.shape == torch.Size([8])
     assert linear_copy.weight is linear_row.weight
     assert linear_copy.bias is linear_row.bias
 
@@ -111,7 +72,7 @@ def check_linear_conv_1d_row(lazy_init: bool):
     linear.load_state_dict(linear_row.state_dict())
 
     # check computation correctness
-    x = torch.rand(4, 48).cuda()
+    x = torch.rand(4, 80).cuda()
     out = linear(x)
     gather_out = linear_row(x)
     assert_close(out, gather_out)
@@ -120,17 +81,51 @@ def check_linear_conv_1d_row(lazy_init: bool):
     out.sum().backward()
     gather_out.sum().backward()
 
-    rank = dist.get_rank()
-    target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank]
+    target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, [32, 32, 16], None, True)
     assert_close(target_grad, linear_row.weight.grad)
 
 
+@parameterize("lazy_init", [False, True])
+def check_linear_1d_col_row(lazy_init: bool):
+    ctx = LazyInitContext() if lazy_init else nullcontext()
+
+    linear1 = nn.Linear(8, 80).cuda()
+    linear2 = nn.Linear(80, 8).cuda()
+    with ctx:
+        linear1_copy = nn.Linear(8, 80).cuda()
+        linear2_copy = nn.Linear(80, 8).cuda()
+    linear_col = FusedLinear1D_Col.from_native_module(linear1_copy, process_group=None, split_sizes=[32, 32, 16])
+    linear_row = FusedLinear1D_Row.from_native_module(
+        linear2_copy,
+        process_group=None,
+        split_sizes=[32, 32, 16],
+    )
+    # ensure weights are reversibly loadable
+    linear_col.load_state_dict(linear1.state_dict())
+    linear_row.load_state_dict(linear2.state_dict())
+
+    # check computation correctness
+    x = torch.rand(4, 8).cuda()
+    target_out = linear2(linear1(x))
+    out = linear_row(linear_col(x))
+    assert_close(out, target_out)
+
+    # check backward correctness
+    target_out.sum().backward()
+    out.sum().backward()
+
+    target_grad1 = split_fused_qkv_in_gpt2_style(linear1.weight.grad, [32, 32, 16], None, False)
+    assert_close(target_grad1, linear_col.weight.grad)
+    target_grad2 = split_fused_qkv_in_gpt2_style(linear2.weight.grad, [32, 32, 16], None, True)
+    assert_close(target_grad2, linear_row.weight.grad)
+
+
 def run_dist(rank, world_size, port):
     colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
 
-    # test for linear conv
-    check_linear_conv_1d_col()
-    check_linear_conv_1d_row()
+    check_linear_1d_col()
+    check_linear_1d_row()
+    check_linear_1d_col_row()
 
 
 @rerun_if_address_is_in_use()

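The test rewrite above replaces the old `n_fused=3` interface with explicit `split_sizes`. A condensed sketch of the same usage under stated assumptions, mirroring the test (two spawned workers; `FusedLinear1D_Col` and the expected shard shape are taken directly from the assertions above, and the `spawn`/`launch` helpers are used the same way the test uses them):

```python
import torch
import torch.nn as nn

import colossalai
from colossalai.shardformer.layer import FusedLinear1D_Col
from colossalai.testing import spawn


def run(rank, world_size, port):
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    fused = nn.Linear(8, 80).cuda()  # 80 outputs = fused chunks of width 32 / 32 / 16
    fused_col = FusedLinear1D_Col.from_native_module(
        fused, process_group=None, gather_output=True, split_sizes=[32, 32, 16]
    )
    # each of the 2 ranks keeps half of every chunk: 16 + 16 + 8 = 40 rows
    assert fused_col.weight.shape == torch.Size([40, 8])


if __name__ == "__main__":
    spawn(run, nprocs=2)
```
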
@@ -5,6 +5,7 @@ from flash_attn import flash_attn_qkvpacked_func, flash_attn_varlen_qkvpacked_fu
 from torch.testing import assert_close
 
 import colossalai
+from colossalai.cluster import ProcessGroupMesh
 from colossalai.shardformer.layer import AttnMaskType
 from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention
 from colossalai.shardformer.layer.utils import split_batch_zigzag, split_varlen_zigzag
@@ -17,11 +18,14 @@ from colossalai.utils import get_current_device
 @parameterize("nheads", [5])
 @parameterize("d", [128])
 @parameterize("dtype", [torch.bfloat16, torch.float16])
-def check_ring_attn(seq_len, bs, nheads, d, dtype):
+def check_ring_attn(seq_len, bs, nheads, d, dtype, inner_ring_size):
     torch.cuda.manual_seed(2)
     device = get_current_device()
     sp_group = dist.group.WORLD
+    dp_size, pp_size, tp_size = 1, 1, 1
     sp_size = dist.get_world_size()
+    sp_axis = 2
+    pg_mesh = ProcessGroupMesh(dp_size, pp_size, sp_size, tp_size)
     # Some outliers may seem large, but our errors are still lower than
     # than Megatron-LM context parallel's
     # (https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215)
@@ -40,11 +44,11 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype):
         q,
         k,
         v,
-        sp_group,
+        sp_axis,
         AttnMaskType.CAUSAL,
         return_softmax=True,
-        inner_ring_size=max(2, sp_size // 2),
-        # inner_ring_size=4
+        inner_ring_size=inner_ring_size,
+        pg_mesh=pg_mesh,
     )
     ring_out = ring_out.transpose(1, 2)
     out, lse, _ = flash_attn_qkvpacked_func(
@@ -83,6 +87,7 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype):
     device = get_current_device()
     sp_group = dist.group.WORLD
     sp_size = dist.get_world_size()
+    sp_axis = 2
     atol = rtol = 7e-3
     torch.cuda.manual_seed(2)
     # Prepare varlen attention mask
@@ -123,10 +128,11 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype):
         q_ring,
         k_ring,
         v_ring,
-        sp_group,
+        sp_axis,
         **mask_info,
         pad_output=False,
         return_softmax=True,
+        pg_mesh=ProcessGroupMesh(1, 1, sp_size, 1),
         # deterministic=True
     )
     ring_out = ring_out.transpose(1, 2).reshape(-1, nheads, d)
@@ -161,12 +167,12 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype):
 def launch_single_ring(rank, world_size, port):
     colossalai.launch(rank, world_size, "localhost", port)
     check_packed_seq()
-    check_ring_attn()
+    check_ring_attn(inner_ring_size=None)
 
 
 def launch_double_ring(rank, world_size, port):
     colossalai.launch(rank, world_size, "localhost", port)
-    check_ring_attn()
+    check_ring_attn(inner_ring_size=2)
 
 
 @rerun_if_address_is_in_use()

@@ -1 +1 @@
-0.4.4
+0.4.5