Merge branch 'main' of https://github.com/hpcaitech/ColossalAI into rlhf_SimPO

pull/5850/head
YeAnbang 2024-07-10 02:32:07 +00:00
commit 16f3451fe2
165 changed files with 4783 additions and 4722 deletions

View File

@ -90,7 +90,7 @@ jobs:
runs-on: [self-hosted, gpu]
container:
image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
options: --gpus all --rm -v /dev/shm -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
options: --gpus all --rm -v /dev/shm -v /data/scratch:/data/scratch
timeout-minutes: 90
defaults:
run:
@ -165,6 +165,7 @@ jobs:
env:
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors
- name: Collate artifact
env:

View File

@ -13,7 +13,7 @@ jobs:
runs-on: [self-hosted, gpu]
container:
image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
options: --gpus all --rm -v /dev/shm -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
options: --gpus all --rm -v /dev/shm -v /data/scratch/:/data/scratch/
timeout-minutes: 90
steps:
- name: Check GPU Availability # ensure all GPUs have enough memory
@ -69,6 +69,7 @@ jobs:
env:
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors
- name: Notify Lark
id: message-preparation

View File

@ -50,7 +50,7 @@ jobs:
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: ${{ matrix.container }}
options: --gpus all --rm -v /dev/shm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
options: --gpus all --rm -v /dev/shm -v /data/scratch/:/data/scratch/
timeout-minutes: 200
steps:
- name: Install dependencies
@ -92,3 +92,4 @@ jobs:
DATA: /data/scratch/cifar-10
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors

View File

@ -41,7 +41,7 @@ jobs:
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: ${{ matrix.container }}
options: --gpus all --rm -v /dev/shm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
options: --gpus all --rm -v /dev/shm -v /data/scratch/:/data/scratch/
timeout-minutes: 200
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-run-test-${{ matrix.container }}
@ -87,3 +87,4 @@ jobs:
DATA: /data/scratch/cifar-10
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors

View File

@ -38,7 +38,7 @@ jobs:
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: ${{ matrix.container }}
options: --gpus all --rm -v /dev/shm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
options: --gpus all --rm -v /dev/shm -v /data/scratch/:/data/scratch/
timeout-minutes: 200
steps:
- name: Install dependencies
@ -85,6 +85,7 @@ jobs:
DATA: /data/scratch/cifar-10
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors
- name: Notify Lark
id: message-preparation

View File

@ -1,34 +1,34 @@
repos:
- repo: https://github.com/PyCQA/autoflake
rev: v2.2.1
rev: v2.3.1
hooks:
- id: autoflake
name: autoflake (python)
args: ['--in-place', '--remove-unused-variables', '--remove-all-unused-imports', '--ignore-init-module-imports']
- repo: https://github.com/pycqa/isort
rev: 5.12.0
rev: 5.13.2
hooks:
- id: isort
name: sort all imports (python)
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 23.9.1
rev: 24.4.2
hooks:
- id: black
name: black formatter
args: ['--line-length=120', '--target-version=py37', '--target-version=py38', '--target-version=py39','--target-version=py310']
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v13.0.1
rev: v18.1.8
hooks:
- id: clang-format
name: clang formatter
types_or: [c++, c]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
rev: v4.6.0
hooks:
- id: check-yaml
- id: check-merge-conflict

View File

@ -83,15 +83,19 @@ class DataCollatorForSupervisedDataset(object):
# `List[torch.Tensor]`
batch_input_ids = [
torch.LongTensor(instance["input_ids"][: self.max_length])
if len(instance["input_ids"]) > self.max_length
else torch.LongTensor(instance["input_ids"])
(
torch.LongTensor(instance["input_ids"][: self.max_length])
if len(instance["input_ids"]) > self.max_length
else torch.LongTensor(instance["input_ids"])
)
for instance in instances
]
batch_labels = [
torch.LongTensor(instance["labels"][: self.max_length])
if len(instance["labels"]) > self.max_length
else torch.LongTensor(instance["labels"])
(
torch.LongTensor(instance["labels"][: self.max_length])
if len(instance["labels"]) > self.max_length
else torch.LongTensor(instance["labels"])
)
for instance in instances
]
if self.tokenizer.padding_side == "right":

View File

@ -1,6 +1,7 @@
"""
loss functions
"""
from typing import Optional, Tuple
import torch

View File

@ -1,6 +1,7 @@
"""
reward model
"""
from typing import Optional
import torch

View File

@ -1,6 +1,7 @@
"""
Training utilities for Coati.
"""
from typing import Any
import torch

View File

@ -78,7 +78,9 @@ def get_prompt(line: Dict, dataset_name: str, logger: DistributedLogger) -> Dict
option_string = "ABCDEFG"
count = len(line["options"])
input = "问题:" + line["question"] + " " + "从以下选项中选择:" + " ".join(line["options"]) + "\n" + "答案:"
input = (
"问题:" + line["question"] + " " + "从以下选项中选择:" + " ".join(line["options"]) + "\n" + "答案:"
)
all_classes = list(option_string[0:count])
@ -150,7 +152,15 @@ def combine_prompt(prompt_path, dataset_name, load_explanation=True, chat_mode=F
)
elif dataset_name in chinese_qa_datasets:
question_input = (
"问题:" + passage + " " + question + "\n" + "从以下选项中选择:" + " ".join(options) + "\n" + "答案:{}".format(label)
"问题:"
+ passage
+ " "
+ question
+ "\n"
+ "从以下选项中选择:"
+ " ".join(options)
+ "\n"
+ "答案:{}".format(label)
)
elif dataset_name in english_cloze_datasets:
question_input = "Question: ".format(idx + 1) + question + "\n" + "Answer: {}".format(answer)

View File

@ -57,7 +57,11 @@ ceval_subject_mapping = {
"urban_and_rural_planner": ["Urban and Rural Planner", "注册城乡规划师", "Other"],
"accountant": ["Accountant", "注册会计师", "Other"],
"fire_engineer": ["Fire Engineer", "注册消防工程师", "Other"],
"environmental_impact_assessment_engineer": ["Environmental Impact Assessment Engineer", "环境影响评价工程师", "Other"],
"environmental_impact_assessment_engineer": [
"Environmental Impact Assessment Engineer",
"环境影响评价工程师",
"Other",
],
"tax_accountant": ["Tax Accountant", "税务师", "Other"],
"physician": ["Physician", "医师资格", "Other"],
}

View File

@ -56,9 +56,11 @@ class MTBenchDataset(BaseDataset):
"instruction": question["turns"],
"input": "",
"output": [],
"target": [""] * turn_number
if question["question_id"] not in reference
else reference[question["question_id"]],
"target": (
[""] * turn_number
if question["question_id"] not in reference
else reference[question["question_id"]]
),
}
if category in dataset["test"]:

View File

@ -77,7 +77,9 @@ class HuggingFaceModel(BaseModel):
self.indices_for_choices[0].append(
self.tokenizer(f"Answer: {choice}", add_special_tokens=False).input_ids[-1]
)
self.indices_for_choices[1].append(self.tokenizer(f"答案:{choice}", add_special_tokens=False).input_ids[-1])
self.indices_for_choices[1].append(
self.tokenizer(f"答案:{choice}", add_special_tokens=False).input_ids[-1]
)
def _load_tokenizer(self, path: str, tokenizer_path: Optional[str], tokenizer_kwargs: dict):
"""

View File

@ -1,92 +0,0 @@
import torch
import torch.distributed as dist
import torch.nn.functional as F
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
from colossalai.lazy import LazyInitContext
from colossalai.moe import MOE_MANAGER
from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
from colossalai.shardformer.shard.utils import set_tensors_to_none
from colossalai.tensor.moe_tensor.api import set_moe_tensor_info
class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
def __init__(self, config):
super().__init__(config)
self.setup_ep()
def setup_ep(self):
_, moe_info = MOE_MANAGER.get_info(self.num_experts)
ep_group = moe_info.ep_group
self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
assert self.num_experts % self.ep_size == 0
self.ep_group = ep_group
self.num_experts_per_ep = self.num_experts // self.ep_size
self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
set_tensors_to_none(self.experts, exclude=set(held_experts))
for p in self.experts.parameters():
set_moe_tensor_info(p, moe_info)
@staticmethod
def from_native_module(module: MixtralSparseMoeBlock, *args, **kwargs) -> "EPMixtralSparseMoeBlock":
LazyInitContext.materialize(module)
module.__class__ = EPMixtralSparseMoeBlock
module.setup_ep()
return module
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)
selected_experts = selected_experts.t().reshape(-1)
selected_experts_idx = selected_experts.argsort()
dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx]
input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
output_split_sizes = torch.zeros_like(input_split_sizes)
dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
# compute expert output
output_states = MoeInGradScaler.apply(output_states, self.ep_size)
if output_states.size(0) > 0:
if self.num_experts_per_ep == 1:
# no need to split
expert = self.experts[self.expert_start_idx]
output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states)
output_states = expert.w2(output_states)
else:
output_states_splits = output_states.split(output_split_sizes.tolist())
output_states_list = []
for i, split_states in enumerate(output_states_splits):
if split_states.size(0) == 0:
continue
expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states)
split_states = expert.w2(split_states)
output_states_list.append(split_states)
output_states = torch.cat(output_states_list)
output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
recover_experts_idx = torch.empty_like(selected_experts_idx)
recover_experts_idx[selected_experts_idx] = torch.arange(
selected_experts_idx.size(0), device=selected_experts_idx.device
)
dispatch_states = dispatch_states[recover_experts_idx]
k_hidden_states = dispatch_states.chunk(self.top_k)
output_states = k_hidden_states[0] * routing_weights[:, 0, None]
for i in range(1, self.top_k):
output_states += k_hidden_states[i] * routing_weights[:, i, None]
output_states = output_states.reshape(batch_size, sequence_length, hidden_dim)
return output_states, router_logits

View File

@ -2,8 +2,6 @@ import argparse
import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from transformers import AutoTokenizer
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
@ -70,8 +68,6 @@ def main():
ep_size=ep_size,
zero_stage=1,
precision=args.precision,
custom_policy=MixtralForCausalLMPolicy(),
checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
enable_fused_normalization=args.use_layernorm_kernel,
enable_jit_fused=args.use_kernel,
)

View File

@ -1,5 +1,6 @@
NUM_GPU=2
MODEL="mistralai/Mixtral-8x7B-v0.1"
# MODEL="mistralai/Mixtral-8x7B-v0.1"
MODEL="mistralai/Mixtral-8x7B-Instruct-v0.1"
# ep
torchrun --standalone --nproc_per_node $NUM_GPU infer.py \

View File

@ -1,146 +0,0 @@
from copy import deepcopy
import pytest
import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from torch.optim import Adam
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.testing.utils import spawn
tokens, n_experts = 7, 4
hidden_size = 8
top_k = 2
def check_model_equal(model1, model2):
assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
for p1, p2 in zip(model1.parameters(), model2.parameters()):
assert torch.equal(p1.half(), p2.half())
def get_optimizer_snapshot(optim):
state = {id(k): deepcopy(v) for k, v in optim.state.items()}
param_groups = []
for group in optim.param_groups:
params = [id(p) for p in group["params"]]
new_group = {"params": params}
for k, v in group.items():
if k != "params":
new_group[k] = v
param_groups.append(new_group)
return {
"state": state,
"param_groups": param_groups,
}
def check_optimizer_snapshot_equal(snapshot1, snapshot2):
# check param_groups
assert len(snapshot1["param_groups"]) == len(snapshot2["param_groups"])
for group1, group2 in zip(snapshot1["param_groups"], snapshot2["param_groups"]):
assert set(group1.keys()) == set(group2.keys())
for k in group1.keys():
assert group1[k] == group2[k]
# check state
assert set(snapshot1["state"].keys()) == set(
snapshot2["state"].keys()
), f"{snapshot1['state'].keys()}, {snapshot2['state'].keys()}"
for pid in snapshot1["state"].keys():
state1, state2 = snapshot1["state"][pid], snapshot2["state"][pid]
assert set(state1.keys()) == set(state2.keys())
for k in state1.keys():
if isinstance(state1[k], torch.Tensor):
assert torch.equal(state1[k], state2[k]), f"{k}, {state1[k]}, {state2[k]}"
else:
assert state1[k] == state2[k]
def check_mixtral_moe_layer():
torch.cuda.set_device(dist.get_rank())
config = MixtralConfig(
hidden_size=hidden_size,
intermediate_size=hidden_size * 2,
num_local_experts=n_experts,
num_experts_per_tok=top_k,
num_attention_heads=2,
num_key_value_heads=2,
)
torch.manual_seed(0)
input_ids = torch.randint(0, 100, (2, tokens)).cuda()
orig_model = MixtralForCausalLM(config).cuda()
model = deepcopy(orig_model)
optimizer = Adam(model.parameters(), lr=1e-3)
plugin = MoeHybridParallelPlugin(
tp_size=1,
pp_size=2,
ep_size=2,
custom_policy=MixtralForCausalLMPolicy(),
checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
microbatch_size=1,
zero_stage=1,
)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model=model, optimizer=optimizer)
# initialize grads
data_iter = iter(
[{"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids), "labels": input_ids.clone()}]
)
booster.execute_pipeline(
data_iter,
model,
lambda outputs, inputs: outputs.loss,
optimizer,
)
# check save model
booster.save_model(model, "mixtral_model", shard=True)
dist.barrier()
if dist.get_rank() == 0:
saved_model = MixtralForCausalLM.from_pretrained("mixtral_model").cuda()
check_model_equal(orig_model, saved_model)
saved_model.save_pretrained("mixtral_hf_model")
dist.barrier()
# check load model
new_model = MixtralForCausalLM(config).cuda()
new_optimizer = Adam(new_model.parameters(), lr=1e-3)
new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer)
booster.load_model(new_model, "mixtral_hf_model")
check_model_equal(model, new_model)
# check save optimizer
optimizer.step()
for group in optimizer.param_groups:
group["lr"] = 0.1
snapshot = get_optimizer_snapshot(optimizer.unwrap())
booster.save_optimizer(optimizer, "mixtral_optim", shard=True)
dist.barrier()
# reset optimizer state
for state in optimizer.unwrap().state.values():
for v in state.values():
if isinstance(v, torch.Tensor):
v.zero_()
booster.load_optimizer(optimizer, "mixtral_optim")
loaded_snapshot = get_optimizer_snapshot(optimizer.unwrap())
check_optimizer_snapshot_equal(snapshot, loaded_snapshot)
def run_dist(rank: int, world_size: int, port: int):
colossalai.launch(rank, world_size, "localhost", port)
check_mixtral_moe_layer()
@pytest.mark.parametrize("world_size", [4])
def test_mixtral_moe_layer(world_size: int):
spawn(run_dist, world_size)
if __name__ == "__main__":
test_mixtral_moe_layer(4)

View File

@ -2,13 +2,11 @@ import argparse
import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from colossal_moe.utils import load_checkpoint, move_to_cuda, save_checkpoint
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers.models.mixtral import MixtralForCausalLM
from utils import load_checkpoint, move_to_cuda, save_checkpoint
import colossalai
from colossalai.booster import Booster
@ -155,12 +153,10 @@ def main():
pp_size=args.pp_size,
ep_size=args.ep_size,
microbatch_size=args.microbatch_size,
custom_policy=MixtralForCausalLMPolicy(),
enable_fused_normalization=args.use_layernorm_kernel,
enable_jit_fused=args.use_kernel,
precision=args.precision,
zero_stage=args.zero_stage,
checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
)
else:

View File

@ -7,6 +7,7 @@ This code is based on LangChain Ai's langchain, which can be found at
https://github.com/langchain-ai/langchain
The original code is licensed under the MIT license.
"""
from __future__ import annotations
import copy

View File

@ -8,6 +8,7 @@ This code is based on LangChain Ai's langchain, which can be found at
https://github.com/langchain-ai/langchain
The original code is licensed under the MIT license.
"""
import copy
from typing import Any, Mapping, Optional, Protocol

View File

@ -7,6 +7,7 @@ This code is based on LangChain Ai's langchain, which can be found at
https://github.com/langchain-ai/langchain
The original code is licensed under the MIT license.
"""
import copy
from typing import Any, List

View File

@ -2,7 +2,6 @@
Class for loading table type data. please refer to Pandas-Input/Output for file format details.
"""
import glob
import os

View File

@ -20,6 +20,7 @@ resp = llm(prompt='What do you call a three-ton kangaroo?', **gen_config)
print(resp) # super-heavyweight awesome-natured yawning Australian creature!
"""
import json
from typing import Any, Mapping

View File

@ -12,6 +12,7 @@ TEST_PROMPT_CHATGLM="续写文章:惊蛰一过,春寒加剧。先是料料
logger.info(llm(TEST_PROMPT_CHATGLM, max_new_tokens=100), verbose=True)
"""
from typing import Any, List, Mapping, Optional
import torch

View File

@ -1,6 +1,7 @@
"""
Generation utilities
"""
import json
from typing import List

View File

@ -2,6 +2,7 @@
Implement a memory class for storing conversation history
Support long term and short term memory
"""
from typing import Any, Dict, List
from colossalqa.chain.memory.summary import ConversationSummaryMemory

View File

@ -1,6 +1,7 @@
"""
Class for logging with extra control for debugging
"""
import logging

View File

@ -1,6 +1,7 @@
"""
Script for Chinese retrieval based conversation system backed by ChatGLM
"""
from typing import Tuple
from colossalqa.chain.retrieval_qa.base import RetrievalQA

View File

@ -1,6 +1,7 @@
"""
Multilingual retrieval based conversation system
"""
from typing import List
from colossalqa.data_loader.document_loader import DocumentLoader

View File

@ -1,6 +1,7 @@
"""
Script for Chinese retrieval based conversation system backed by ChatGLM
"""
from typing import Tuple
from colossalqa.chain.retrieval_qa.base import RetrievalQA

View File

@ -1,6 +1,7 @@
"""
Code for custom retriver with incremental update
"""
import copy
import hashlib
import os

View File

@ -1,6 +1,7 @@
"""
Code for Chinese text splitter
"""
from typing import Any, List, Optional
from colossalqa.text_splitter.utils import get_cleaned_paragraph

View File

@ -1,6 +1,7 @@
"""
Script for English retrieval based conversation system backed by LLaMa2
"""
import argparse
import os

View File

@ -1,6 +1,7 @@
"""
Script for English retrieval based conversation system backed by LLaMa2
"""
import argparse
import json
import os

View File

@ -1,6 +1,7 @@
"""
Script for Chinese retrieval based conversation system backed by ChatGLM
"""
import argparse
import os

View File

@ -1,6 +1,7 @@
"""
Script for English retrieval based conversation system backed by LLaMa2
"""
import argparse
import os

View File

@ -107,20 +107,22 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_memory_cost = MemoryCost(
activation=compute_size_in_bytes([input_tensor, output_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
if has_bias
else compute_size_in_bytes(weight_tensor),
parameter=(
compute_size_in_bytes([weight_tensor, bias_tensor]) if has_bias else compute_size_in_bytes(weight_tensor)
),
temp=0,
buffer=0,
)
bwd_memory_cost = MemoryCost(
activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor])
if has_bias
else compute_size_in_bytes([input_tensor, weight_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
if has_bias
else compute_size_in_bytes(weight_tensor),
activation=(
compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor])
if has_bias
else compute_size_in_bytes([input_tensor, weight_tensor])
),
parameter=(
compute_size_in_bytes([weight_tensor, bias_tensor]) if has_bias else compute_size_in_bytes(weight_tensor)
),
temp=0,
buffer=0,
)

View File

@ -1,10 +1,18 @@
from .gemini_plugin import GeminiPlugin
from .hybrid_parallel_plugin import HybridParallelPlugin
from .low_level_zero_plugin import LowLevelZeroPlugin
from .moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from .plugin_base import Plugin
from .torch_ddp_plugin import TorchDDPPlugin
__all__ = ["Plugin", "TorchDDPPlugin", "GeminiPlugin", "LowLevelZeroPlugin", "HybridParallelPlugin"]
__all__ = [
"Plugin",
"TorchDDPPlugin",
"GeminiPlugin",
"LowLevelZeroPlugin",
"HybridParallelPlugin",
"MoeHybridParallelPlugin",
]
import torch
from packaging import version

View File

@ -655,7 +655,6 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
self.param_info = param_info
self.stage_manager = model.stage_manager
self.shared_params = model.shared_params
self.dp_pg = dp_process_group
self.tp_pg = tp_process_group
self.pp_pg = pp_process_group
if use_pipeline:
@ -718,7 +717,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
"""Retrieve all working gradients from different parameter groups."""
all_working_grads = []
for group_id in range(self.num_param_groups):
working_grads = self._grad_store.get_working_grads_by_group_id(group_id)
working_grads = self.get_working_grads_by_group_id(group_id)
all_working_grads.extend(working_grads)
return all_working_grads
@ -726,7 +725,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
"""Identify gradients to be synchronized in the sequence parallelism."""
grads_to_sync = []
for grad in all_working_grads:
param_id_for_grad = self._grad_store.get_param_id_for_grad(grad)
param_id_for_grad = self.get_param_id_for_grad(grad)
param_for_grad = ctypes.cast(param_id_for_grad, ctypes.py_object).value
if SeqParallelUtils.is_sp_partial_derived_param(param_for_grad):
grads_to_sync.append(grad)
@ -739,7 +738,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
# Get all working gradients and gradients to be synchronized.
all_working_grads = _get_all_working_grads()
grads_to_sync = _get_grads_to_sync(all_working_grads)
if self._grad_store.require_grad_sync and grads_to_sync is not None:
if self.require_grad_sync and grads_to_sync is not None:
# Synchronize sequence parallelism gradients if required.
SeqParallelUtils.allreduce_partial_data_grad(process_group=self.tp_pg, grads=grads_to_sync)
else:
@ -763,7 +762,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
# Call the superclass backward method to compute gradients.
super().backward(loss, retain_graph)
if self._grad_store.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
# If gradient synchronization is required, sync sequence parallelism gradients.
self._sync_sp_grads()
else:
@ -788,14 +787,14 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
# Call the superclass backward_by_grad method to compute gradients.
super().backward_by_grad(tensor, grad)
if self._grad_store.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
# If gradient synchronization is required, sync sequence parallelism gradients.
self._sync_sp_grads()
else:
# If gradient synchronization is is not required, return.
return
def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
def _compute_grad_norm(self, dp_pg, gradients: List[Tensor], norm_type: int = 2) -> float:
r"""
Compute and return the gradient norm for gradient clipping.
@ -811,7 +810,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
if len(gradients) == 0:
return 0.0
dp_size = get_world_size(self.dp_pg) if self.dp_pg is not None else 1
dp_size = get_world_size(dp_pg) if dp_pg is not None else 1
tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
norm_type = float(norm_type)
@ -842,7 +841,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
# However, we still perform the 'all_reduce' operation for the sake of good coding practices.
# To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.'
if tp_size > 1:
param_id_for_grad = self._grad_store.get_param_id_for_grad(grad)
param_id_for_grad = self.get_param_id_for_grad(grad)
param_for_grad = ctypes.cast(param_id_for_grad, ctypes.py_object).value
if not is_distributed_tensor(param_for_grad):
@ -856,7 +855,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
for shared_param in self.shared_params:
if self.stage_manager.stage in shared_param:
stage_shared_param = shared_param[self.stage_manager.stage]
working_grad = self._grad_store.get_working_grad_by_param_id(id(stage_shared_param))
working_grad = self.get_working_grad_by_param_id(id(stage_shared_param))
if grad is working_grad:
grad_norm_exponentiated /= len(shared_param)
@ -867,7 +866,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
)
if dp_size > 1:
# compute norm in dp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.dp_pg)
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=dp_pg)
if tp_size > 1:
# compute norm in tp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
@ -1206,6 +1205,7 @@ class HybridParallelPlugin(PipelinePluginBase):
and self.enable_sequence_parallelism
and self.sequence_parallelism_mode == "all_to_all"
)
# sync gradients across DP * SP ranks
if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
else:
@ -1309,7 +1309,7 @@ class HybridParallelPlugin(PipelinePluginBase):
# run with gradients accumulation
if model.require_grad_sync == False or (
isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer._grad_store.require_grad_sync == False
isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
):
return outputs

View File

@ -1,4 +1,5 @@
import random
import warnings
from types import MethodType
from typing import Callable, Optional, OrderedDict, Tuple
@ -20,19 +21,19 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
get_param_info,
init_pipeline_optimizer,
)
from colossalai.checkpoint_io import MoECheckpointIO
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.moe import MOE_MANAGER, MoECheckpointIO
from colossalai.logging import get_dist_logger
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.zero.low_level import LowLevelZeroOptimizer
PP_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2
class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
def __init__(
self,
optimizer: Optimizer,
@ -67,8 +68,20 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
self.pp_pg = pp_process_group
if use_pipeline:
init_pipeline_optimizer(optimizer, model)
pg_param_list = {
dp_process_group: [],
moe_extra_dp_process_group: [],
}
for param in model.parameters():
if is_moe_tensor(param):
pg_param_list[moe_extra_dp_process_group].append(param)
else:
pg_param_list[dp_process_group].append(param)
super().__init__(
optimizer=optimizer,
pg_to_param_list=pg_param_list,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
@ -83,9 +96,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
overlap_communication=overlap_communication,
partition_grad=partition_grad,
cpu_offload=cpu_offload,
dp_process_group=dp_process_group,
forced_dtype=forced_dtype,
moe_extra_dp_process_group=moe_extra_dp_process_group,
)
@ -107,8 +118,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
>>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
Args:
tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
precision (str, optional): Specifies the precision of parameters during training.
Auto-mixied precision will be used when this argument is set to 'fp16' or 'bf16', otherwise model is trained with 'fp32'.
Defaults to 'fp16'.
@ -144,14 +155,15 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False.
communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
use_ep_inside (bool, Optional): Whether to use ep inside dp (intra-node) for moe params.
"""
def __init__(
self,
tp_size: int,
pp_size: int,
ep_size: int,
extra_dp_size: int = 1,
tp_size: int = 1,
sp_size: int = 1,
precision: str = "fp16",
zero_stage: int = 0,
enable_all_optimization: bool = False,
@ -184,32 +196,22 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
custom_policy: Policy = None,
checkpoint_io: Optional[MoECheckpointIO] = None,
) -> None:
assert (
dist.get_world_size() % (tp_size * pp_size) == 0
), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
world_size = dist.get_world_size()
assert tp_size == 1, "Tensor parallel is not supported in MoE yet"
assert sp_size == 1 and enable_sequence_parallelism is False, "Sequence parallelism it not supported in MoE yet"
if enable_sequence_parallelism:
assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism"
assert (
dist.get_world_size() % (tp_size * pp_size) == 0
), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
world_size % (tp_size * pp_size) == 0
), f"world size {world_size} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
assert (
dist.get_world_size() % (tp_size * pp_size * ep_size) == 0
), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size} * ep_size {ep_size}"
self.real_dp_size = dist.get_world_size() // (tp_size * pp_size * ep_size)
MOE_MANAGER.setup(
parallel="EP",
mode="fixed",
fixed_dp_size=self.real_dp_size,
fixed_ep_size=ep_size,
fixed_pp_size=pp_size,
use_ep_inside=use_ep_inside,
)
world_size % (tp_size * pp_size * ep_size) == 0
), f"world size {world_size} is not divisible by tp_size {tp_size} * pp_size {pp_size} * ep_size {ep_size}"
self.dp_size = world_size // (tp_size * pp_size)
self.tp_size = tp_size
self.pp_size = pp_size
self.dp_size = dist.get_world_size() // (tp_size * pp_size)
self.ep_size = ep_size
self.moe_info = MOE_MANAGER.get_info(0)[1]
self.sp_size = sp_size
self.precision = precision
self.zero_stage = zero_stage
self.cpu_offload = cpu_offload
@ -219,43 +221,57 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.enable_jit_fused = enable_jit_fused
self.enable_sequence_parallelism = enable_sequence_parallelism
self.checkpoint_io = checkpoint_io
logger = get_dist_logger()
# NOTE: Two process meshes: global dp for non-moe param; dp + ep for moe param
# See https://hpc-ai.com/blog/enhanced-moe-parallelism-open-source-moe-model-training-can-be-9-times-more-efficient
# we change pg mesh to (pp, dp, tp) for better moe performance
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size)
assert (
self.ep_size <= self.dp_size
), f"Not enough devices({self.dp_size}) for expert parallelism size({self.ep_size})."
# sync moe in outer dp group, and sync other param in global dp group
if extra_dp_size > 1:
ep_size = self.dp_size // extra_dp_size
if use_ep_inside:
self.pg_mesh_moe = ProcessGroupMesh(self.pp_size, extra_dp_size, ep_size)
self.moe_extra_dp_group = self.pg_mesh_moe.get_group_along_axis(1)
if dist.get_rank() == 0:
print(f"Zero Parallel: pp {self.pp_size}, outer_dp {extra_dp_size}, inner_dp {ep_size}")
else:
self.pg_mesh_moe = ProcessGroupMesh(self.pp_size, ep_size, extra_dp_size)
self.moe_extra_dp_group = self.pg_mesh_moe.get_group_along_axis(2)
if dist.get_rank() == 0:
print(f"Zero Parallel: pp {self.pp_size}, outer_dp {ep_size}, inner_dp {extra_dp_size}")
self.moe_dp_size = self.dp_size // self.ep_size
self.use_ep_inside = use_ep_inside
if self.use_ep_inside:
logger.info(f"MoE Parallel use ep inside dp.", ranks=[0])
self.pp_axis, self.dp_axis, self.ep_axis, self.tp_axis = 0, 1, 2, 3
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, ep_size, tp_size)
else:
self.moe_extra_dp_group = None
logger.info(f"MoE Parallel use ep outside dp.", ranks=[0])
warnings.warn("Using ep outside dp (cross-node) is strongly discouraged due to communication costs.")
self.pp_axis, self.dp_axis, self.ep_axis, self.tp_axis = 0, 2, 1, 3
self.pg_mesh = ProcessGroupMesh(self.pp_size, ep_size, self.moe_dp_size, tp_size)
self.moe_dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
self.ep_group = self.pg_mesh.get_group_along_axis(self.ep_axis)
logger.info(f"Non-MoE Parameter Parallel: pp {self.pp_size}, dp {self.dp_size}, tp {tp_size}", ranks=[0])
logger.info(
f"MoE Parallel: pp {self.pp_size}, ep {ep_size}, moe dp {self.moe_dp_size}, tp {tp_size}", ranks=[0]
)
self.tp_group = self.pg_mesh.get_group_along_axis(
self.tp_axis
) # TODO: support custom tp size for mixtral lm head
self.global_dp_group = self.pg_mesh.get_group_along_axis((self.dp_axis, self.ep_axis))
self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis)
# TODO: Currently moe only support partially sequence parallel
self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
self.custom_policy = custom_policy
self.stage_manager = None
self.schedule = None
self.custom_policy = custom_policy
assert zero_stage in (0, 1, 2)
if self.pp_size > 1:
assert (
num_microbatches is not None or microbatch_size is not None
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
self.stage_manager = PipelineStageManager(self.pg_mesh, self.pp_axis)
self.schedule = OneForwardOneBackwardSchedule(
self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size
)
self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
# TODO: Currently moe only support partially sequence parallel
self.sp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
@ -267,6 +283,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
enable_jit_fused=self.enable_jit_fused,
enable_sequence_parallelism=enable_sequence_parallelism,
enable_sequence_overlap=enable_sequence_overlap,
ep_group=self.ep_group,
)
self.amp_config = dict(
initial_scale=initial_scale,
@ -323,7 +340,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
"""
_kwargs = kwargs.copy()
sampler = DistributedSampler(
dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle
dataset,
num_replicas=self.dp_size,
rank=dist.get_rank(self.global_dp_group),
shuffle=shuffle,
)
# Deterministic dataloader
@ -346,9 +366,20 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
def get_checkpoint_io(self) -> MoECheckpointIO:
if self.checkpoint_io is None:
self.checkpoint_io = MoECheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
self.checkpoint_io = MoECheckpointIO(
self.global_dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage
)
else:
self.checkpoint_io = self.checkpoint_io(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
self.checkpoint_io = self.checkpoint_io(
self.global_dp_group,
self.pp_group,
self.tp_group,
ep_group=self.ep_group,
moe_dp_group=self.moe_dp_group,
zero_stage=self.zero_stage,
)
if hasattr(self.checkpoint_io, "moe_info"):
self.checkpoint_io.moe_info = self.moe_info
return self.checkpoint_io
def configure(
@ -366,7 +397,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
module=model,
precision=self.precision,
shard_config=self.shard_config,
dp_group=self.dp_group,
dp_group=self.global_dp_group,
tp_group=self.tp_group,
sp_group=self.sp_group,
use_ddp=use_ddp,
@ -392,15 +423,15 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
else:
assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
optimizer = HybridParallelZeroOptimizer(
optimizer = MoeHybridParallelZeroOptimizer(
optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info,
dp_process_group=self.dp_group,
dp_process_group=self.global_dp_group,
tp_process_group=self.tp_group,
pp_process_group=self.pp_group,
moe_extra_dp_process_group=self.moe_extra_dp_group,
moe_extra_dp_process_group=self.moe_dp_group,
verbose=True,
clip_grad_norm=self.max_norm,
**self.zero_config,

View File

@ -2,5 +2,12 @@ from .checkpoint_io_base import CheckpointIO
from .general_checkpoint_io import GeneralCheckpointIO
from .hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
from .index_file import CheckpointIndexFile
from .moe_checkpoint import MoECheckpointIO
__all__ = ["CheckpointIO", "CheckpointIndexFile", "GeneralCheckpointIO", "HybridParallelCheckpointIO"]
__all__ = [
"CheckpointIO",
"CheckpointIndexFile",
"GeneralCheckpointIO",
"HybridParallelCheckpointIO",
"MoECheckpointIO",
]

View File

@ -70,13 +70,13 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
verbose: bool = True,
) -> None:
super().__init__()
self.dp_group = dp_group
self.global_dp_group = dp_group
self.pp_group = pp_group
self.tp_group = tp_group
self.dp_rank = dist.get_rank(self.dp_group)
self.dp_rank = dist.get_rank(self.global_dp_group)
self.tp_rank = dist.get_rank(self.tp_group)
self.pp_rank = dist.get_rank(self.pp_group)
self.dp_size = dist.get_world_size(dp_group)
self.global_dp_size = dist.get_world_size(dp_group)
self.pp_size = dist.get_world_size(pp_group)
self.tp_size = dist.get_world_size(tp_group)
self.use_zero = zero_stage > 0
@ -433,7 +433,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
state_dict_shard = HybridParallelCheckpointIO._optimizer_sharder(
optimizer,
use_zero=self.use_zero,
dp_group=self.dp_group,
dp_group=self.global_dp_group,
tp_group=self.tp_group,
size_per_shard=size_per_shard,
)
@ -727,7 +727,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
state,
working_param,
original_shape=original_shape,
dp_group=self.dp_group,
dp_group=self.global_dp_group,
tp_group=self.tp_group,
use_zero=self.use_zero,
inplace=False,
@ -932,12 +932,12 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# Shard state along data parallel group when using Zero.
if self.use_zero:
padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size
padding_size = (self.global_dp_size - v.numel() % self.global_dp_size) % self.global_dp_size
with torch.no_grad():
v = v.flatten()
if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size])
slice_size = v.numel() // self.dp_size
slice_size = v.numel() // self.global_dp_size
v = v.split(slice_size, dim=0)[self.dp_rank]
state_[k] = v.detach().clone().to(device)

View File

@ -9,6 +9,7 @@ import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import get_global_rank
from colossalai.checkpoint_io import CheckpointIndexFile
from colossalai.checkpoint_io.hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
@ -19,15 +20,16 @@ from colossalai.checkpoint_io.utils import (
get_model_base_filenames,
get_optimizer_base_filenames,
load_shard_state_dict,
load_state_dict,
load_states_into_optimizer,
save_config_file,
save_param_groups,
save_state_dict,
save_state_dict_shards,
search_tp_partition_dim,
sharded_optimizer_loading_epilogue,
)
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.moe import MOE_MANAGER
from colossalai.tensor.moe_tensor.api import is_moe_tensor
try:
@ -36,21 +38,30 @@ except ImportError:
_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
class MoECheckpointIO(HybridParallelCheckpointIO):
def __init__(
self,
dp_group: ProcessGroup,
global_dp_group: ProcessGroup,
pp_group: ProcessGroup,
tp_group: ProcessGroup,
ep_group: ProcessGroup,
moe_dp_group: ProcessGroup,
zero_stage: int,
verbose: bool = True,
) -> None:
super().__init__(dp_group, pp_group, tp_group, zero_stage, verbose)
moe_info = MOE_MANAGER.parallel_info_dict[MOE_MANAGER.ep_size]
self.ep_group = moe_info.ep_group
self.ep_size = moe_info.ep_size
self.ep_rank = moe_info.ep_rank
self.real_dp_rank = moe_info.dp_rank
super().__init__(global_dp_group, pp_group, tp_group, zero_stage, verbose)
self.global_dp_group = global_dp_group
self.global_dp_rank = dist.get_rank(global_dp_group)
self.global_dp_size = dist.get_world_size(global_dp_group)
self.pp_group = pp_group
self.tp_group = tp_group
self.moe_dp_group = moe_dp_group
self.moe_dp_size = dist.get_world_size(moe_dp_group)
self.moe_dp_rank = dist.get_rank(moe_dp_group)
self.ep_group = ep_group
self.ep_size = dist.get_world_size(ep_group)
self.ep_rank = dist.get_rank(ep_group)
@staticmethod
def _model_sharder(
@ -134,7 +145,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
Path(checkpoint).mkdir(parents=True, exist_ok=True)
if self.real_dp_rank != 0:
if self.moe_dp_rank != 0:
dist.barrier()
return
@ -144,7 +155,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
# Then collect the sharded parameters & buffers along tp_group.
# Only devices with tp_rank == 0 are responsible for model saving.
state_dict_shard = MixtralMoEHybridParallelCheckpointIO._model_sharder(
state_dict_shard = MoECheckpointIO._model_sharder(
model, size_per_shard=size_per_shard, param_name_pattern=ep_param_pattern
)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
@ -234,11 +245,12 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
state: OrderedDict,
param: torch.Tensor,
original_shape: torch.Size,
dp_group: ProcessGroup,
global_dp_group: ProcessGroup,
tp_group: ProcessGroup,
use_zero: bool,
inplace: bool,
is_moe_param: bool,
moe_dp_group: ProcessGroup = None,
device: torch.device = torch.device("cpu"),
) -> OrderedDict:
"""
@ -248,7 +260,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
state (OrderedDict): Optimizer states of given parameter, might be distributed among tp/dp group if using TP/Zero.
param (torch.Tensor): The given parameter. It should be working_param when using Zero.
original_shape (torch.Size): The size of parameter before sharding.
dp_group (ProcessGroup): The process group of data parallel.
global_dp_group (ProcessGroup): The process group of data parallel.
tp_group (ProcessGroup): The process group of tensor parallel.
use_zero (bool): Whether Zero is used.
inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
@ -257,27 +269,47 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
Returns:
OrderedDict: The complete optimizer state of given parameter.
"""
dp_size = dist.get_world_size(dp_group)
global_dp_size = dist.get_world_size(global_dp_group)
tp_size = dist.get_world_size(tp_group)
moe_dp_size = dist.get_world_size(moe_dp_group) if moe_dp_group is not None else 1
current_shape = param.shape
state_ = state if inplace else copy.deepcopy(state)
for k, v in state_.items():
if isinstance(v, torch.Tensor) and k != "step":
v = v.cuda()
# First gather Zero shards.
if use_zero and not is_moe_param:
v = v.cuda()
gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
dist.all_gather(gather_tensor, v, group=dp_group)
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
if use_zero and is_moe_param and moe_dp_size > 1:
moe_dp_rank = dist.get_rank(moe_dp_group)
dst = get_global_rank(moe_dp_group, 0)
if moe_dp_rank == 0:
gather_tensor = [torch.zeros_like(v) for _ in range(moe_dp_size)]
dist.gather(v, gather_tensor, group=moe_dp_group, dst=dst)
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
else:
dist.gather(v, group=moe_dp_group, dst=dst)
elif use_zero and not is_moe_param and global_dp_size > 1:
dp_rank = dist.get_rank(global_dp_group)
dst = get_global_rank(global_dp_group, 0)
if dp_rank == 0:
gather_tensor = [torch.zeros_like(v) for _ in range(global_dp_size)]
dist.gather(v, gather_tensor, group=global_dp_group, dst=dst)
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
else:
dist.gather(v, group=global_dp_group, dst=dst)
# Then gather TP shards.
partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size)
if partition_dim is not None:
gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)]
dist.all_gather(gather_tensor, v, group=tp_group)
v = torch.cat(gather_tensor, dim=partition_dim)
tp_rank = dist.get_rank(tp_group)
dst = get_global_rank(tp_group, 0)
if tp_rank == 0:
gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)]
dist.gather(v, gather_tensor, group=tp_group, dst=dst)
v = torch.cat(gather_tensor, dim=partition_dim)
else:
dist.gather(v, group=tp_group, dst=dst)
state_[k] = v.detach().clone().to(device)
return state_
@ -286,8 +318,9 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
def _optimizer_sharder(
optimizer: OptimizerWrapper,
use_zero: bool,
dp_group: ProcessGroup,
global_dp_group: ProcessGroup,
tp_group: ProcessGroup,
moe_dp_group: ProcessGroup,
size_per_shard: int = 1024,
only_moe_param: bool = False,
):
@ -296,7 +329,6 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
state_dict_sharder = StateDictSharder(size_per_shard)
param_info = optimizer.param_info
master_to_working_map = optimizer.get_master_to_working_map()
for param, state in optimizer.optim.state.items():
if param is None:
continue
@ -305,22 +337,23 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
working_param = master_to_working_map[id(param)]
else:
working_param = param
param_id = param_info["param2id"][id(working_param)]
original_shape = param_info["param2shape"][id(working_param)]
state_ = MixtralMoEHybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
state_ = MoECheckpointIO.gather_from_sharded_optimizer_state(
state,
working_param,
original_shape=original_shape,
dp_group=dp_group,
global_dp_group=global_dp_group,
moe_dp_group=moe_dp_group,
tp_group=tp_group,
use_zero=use_zero,
inplace=False,
is_moe_param=is_moe_tensor(working_param),
is_moe_param=is_moe_tensor(working_param), # TODO: Check correctness here
)
if only_moe_param and not is_moe_tensor(working_param):
continue
block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
if block is not None:
yield block, block_size
@ -359,25 +392,28 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
Path(checkpoint).mkdir(parents=True, exist_ok=True)
# Devices along the same dp_group share the same copies of states when zero is not used.
# In this case only let the device with dp_rank == 0 save the model.
if not self.use_zero and self.real_dp_rank != 0:
# If optim states are not sharded, other ranks don't need to participate in gather.
if not self.use_zero and self.moe_dp_rank != 0:
dist.barrier()
return
# Then collect the sharded states along dp_group(if using zero)/tp_group.
# Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
state_dict_shard = MixtralMoEHybridParallelCheckpointIO._optimizer_sharder(
state_dict_shard = MoECheckpointIO._optimizer_sharder(
optimizer,
use_zero=self.use_zero,
dp_group=self.dp_group,
global_dp_group=self.global_dp_group,
tp_group=self.tp_group,
moe_dp_group=self.moe_dp_group,
size_per_shard=size_per_shard,
only_moe_param=self.ep_rank != 0,
)
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
index_file = CheckpointIndexFile(checkpoint)
control_saving = self.real_dp_rank == 0 and self.tp_rank == 0
# e.g. dp_size = 4, moe_dp_size = 2, ep_size = 2 and use gather
# rank 0 saves moe & non-moe params; rank 1 only saves moe params
# rank 3 & 4 save nothing
control_saving = self.tp_rank == 0 and self.moe_dp_rank == 0
if self.pp_size == 1 and self.ep_size == 1:
# When pipeline is not used, save the optimizer shards as in general checkpointIO
@ -596,7 +632,6 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
OrderedDict: The sharded optimizer state of the given parameter.
"""
state_ = state if inplace else copy.deepcopy(state)
for k, v in state_.items():
if isinstance(v, torch.Tensor) and k != "step":
# Shard state along tensor parallel group.
@ -606,24 +641,218 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
v = v.split(slice_size, dim=partition_dim)[self.tp_rank]
# Shard state along data parallel group when using Zero.
if self.use_zero and not is_moe_param:
padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size
if self.use_zero and not is_moe_param and self.global_dp_size > 1:
padding_size = (self.global_dp_size - v.numel() % self.global_dp_size) % self.global_dp_size
with torch.no_grad():
v = v.flatten()
if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size])
slice_size = v.numel() // self.dp_size
v = v.split(slice_size, dim=0)[self.dp_rank]
slice_size = v.numel() // self.global_dp_size
v = v.split(slice_size, dim=0)[self.global_dp_rank]
elif self.use_zero and is_moe_param and self.moe_dp_size > 1:
# LowLevelZeRO pads by global dp size for now.
# TODO: update both to use moe dp size
padding_size = (self.global_dp_size - v.numel() % self.global_dp_size) % self.global_dp_size
with torch.no_grad():
v = v.flatten()
if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size])
slice_size = v.numel() // self.moe_dp_size
v = v.split(slice_size, dim=0)[self.moe_dp_rank]
state_[k] = v.detach().clone().to(device)
return state_
def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
raise NotImplementedError
"""Migration from MoEHybridParallelCheckpointIO. These functions mostly deals with unsharded saving,
and can be savely deleted since large MoE models are often saved in shards.
"""
# Copied from colossalai.moe
def pre_save_model(self, model: nn.Module) -> dict:
state_dict = model.state_dict()
for name, param in model.named_parameters():
if ".experts." in name and is_moe_tensor(param):
ep_group = param.ep_group
ep_rank = dist.get_rank(ep_group)
ep_size = dist.get_world_size(ep_group)
# TODO: check correctness here
# dp_rank = get_dp_rank(param)
dp_rank = dist.get_rank(self.global_dp_group)
if dp_rank == 0:
param = param.data.cuda()
if ep_rank == 0:
all_param = [torch.zeros_like(param) for _ in range(ep_size)]
else:
all_param = None
# gather param from every ep rank
# dist.all_gather(all_param, param, group=ep_group)
dist.gather(param, all_param, group=ep_group)
if ep_rank == 0:
all_param = torch.cat(all_param, dim=0)
state_dict[name] = all_param.cpu()
if self.pp_size > 1:
if self.dp_rank == 0:
out = [None for _ in range(self.pp_size)]
dist.gather_object(state_dict, out, group=self.pp_group)
if self.pp_rank == 0:
new_state_dict = {}
for o in out:
new_state_dict.update(o)
state_dict = new_state_dict
dist.barrier()
return state_dict
def save_unsharded_model(
self,
model: nn.Module,
checkpoint: str,
gather_dtensor: bool,
use_safetensors: bool,
):
state_dict = self.pre_save_model(model)
if dist.get_rank() == 0:
torch.save(state_dict, checkpoint)
dist.barrier()
# Copied from colossalai.moe
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
raise NotImplementedError
"""
Save optimizer state dict to a file with given path.
Args:
optimizer (OptimizerWrapper): Optimizer to save sharded state_dict.
checkpoint (str): Path to save optimizer state_dict.
gather_dtensor (bool): Whether to gather_dtensor, not used.
"""
if self.coordinator.is_master():
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
# optimizer states of parameters kept by local device('s pipeline stage)
local_states = dict()
for param, state in optimizer.optim.state.items():
if param is None:
continue
# working param is needed for obtaining correct param_id
master_to_working_map = optimizer.get_master_to_working_map()
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
else:
working_param = param
# gather complete state from tp shards & dp shards
param_id = optimizer.param_info["param2id"][id(working_param)]
local_states[param_id] = self.pre_save_optim(
state,
working_param,
inplace=False,
device=torch.device("cuda"),
)
if self.pp_size == 1:
# When pipeline is not used, let master rank directly save the collected state_dict.
state_dict = {"param_groups": optimizer.optim.param_groups, "state": local_states}
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors=False)
else:
# When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
states_list = [None for _ in range(self.pp_size)]
dist.barrier(self.pp_group)
# dist.all_gather_object(states_list, local_states, self.pp_group)
dist.gather_object(local_states, states_list, self.pp_group)
# Only the master rank do the saving.
if self.coordinator.is_master():
state_dict = {"param_groups": optimizer.optim.param_groups, "state": dict()}
for _states in states_list:
state_dict["state"].update(_states)
save_state_dict(state_dict, checkpoint, use_safetensors=False)
dist.barrier()
# Copied from colossalai.moe
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, strict: bool = False):
raise NotImplementedError
"""
Load optimizer from a file with given path.
Args:
optimizer (OptimizerWrapper): The optimizer to be loaded.
checkpoint_index_file (str): Path to the checkpoint file.
"""
def _get_param_id_from_optimizer_param(
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
):
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
else:
working_param = param
if id(working_param) in optimizer.param_info["param2id"]:
return optimizer.param_info["param2id"][id(working_param)]
else:
None
if self.coordinator.is_master():
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
# Complete optimizer state_dict loaded from checkpoint, need to be processed later.
state_dict = load_state_dict(checkpoint)
# Load param_groups.
updated_groups = []
saved_groups = state_dict["param_groups"]
for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = old_pg["params"] # Only keep the parameters kept by current pipeline stage.
updated_groups.append(new_pg)
# ep extra group
# if MOE_MANAGER.parallel == "EP":
if self.ep_size > 1:
new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = optimizer.optim.param_groups[-1][
"params"
] # Only keep the parameters kept by current pipeline stage.
for param in new_pg["params"]:
param.data = param.data.to(torch.float32)
updated_groups.append(new_pg)
optimizer.optim.__dict__.update({"param_groups": updated_groups})
# Load saved states to optimizer. First discard those states not belonging to current pipeline stage.
master_to_working_map = optimizer.get_master_to_working_map()
id_map = {}
for pg in optimizer.optim.param_groups:
for param in pg["params"]:
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
if param_id is not None:
id_map[param_id] = param
load_states_into_optimizer(optimizer.optim, state_dict["state"], id_map, strict=True)
# Then shard the loaded optimizer states if using tp/zero.
for param, state in optimizer.optim.state.items():
if param is None:
continue
device = param.device
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
else:
working_param = param
original_shape = optimizer.param_info["param2shape"][id(working_param)]
sharded_state = self.pre_load_optim(
state,
param,
current_shape=working_param.shape,
original_shape=original_shape,
device=device,
inplace=True,
)
optimizer.optim.state[param] = sharded_state
sharded_optimizer_loading_epilogue(optimizer.optim)
dist.barrier()

View File

@ -242,6 +242,7 @@ def save_state_dict_shards(
shard_filenames = []
for idx, shard_pair in enumerate(sharded_state_dict):
shard, current_size = shard_pair
# Just loop over the sharder and gather to other ranks if not master
if not is_master:
del shard
continue

View File

@ -147,7 +147,7 @@ class ProcessGroupMesh:
ProcessGroup: The process group with the given ranks.
"""
ranks_in_group = sorted(ranks_in_group)
if tuple(ranks_in_group) not in self._group_to_ranks:
if tuple(ranks_in_group) not in self._ranks_to_group:
group = dist.new_group(ranks_in_group, backend=backend)
self._ranks_to_group[tuple(ranks_in_group)] = group
self._group_to_ranks[group] = tuple(ranks_in_group)
@ -244,19 +244,25 @@ class ProcessGroupMesh:
return target_group
def get_group_along_axis(
self, axis: int, indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
self, axis: Union[int, List[int]], indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
) -> ProcessGroup:
"""Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created.
Args:
axis (int): Axis along which the process groups are created.
axis (int or list of int): Axes along which the process groups are created.
indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None.
backend (Optional[str], optional): Backend of the process group. Defaults to None.
Returns:
ProcessGroup: The process group along the given axis which the current process belongs to.
"""
indices_at_axis = indices_at_axis or list(range(self._shape[axis]))
indices_at_axis = indices_at_axis
if indices_at_axis is None:
if isinstance(axis, (list, tuple)):
indices_at_axis = list(list(range(self._shape[ax])) for ax in axis)
else:
indices_at_axis = list(range(self._shape[axis]))
coords_in_group = ProcessGroupMesh.get_coords_along_axis(self._coord, axis, indices_at_axis)
ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
if ranks_in_group not in self._ranks_to_group:

View File

@ -247,16 +247,16 @@ class BatchBucket:
self._sequences_dict[seq.request_id] = seq
self._sequences_indexes[seq.request_id] = self._current_batch_size + i
# TODO external (rename): modify Sequence.sentence_len to seq_len
self._sequence_lengths[
self._current_batch_size : self._current_batch_size + num_seqs_to_add
] = torch.tensor([seq.sentence_len for seq in seqs[:num_seqs_to_add]], dtype=torch.int32)
self._sequence_lengths[self._current_batch_size : self._current_batch_size + num_seqs_to_add] = (
torch.tensor([seq.sentence_len for seq in seqs[:num_seqs_to_add]], dtype=torch.int32)
)
# NOTE block tables to be updated by kvcache manager
block_tables = self._block_tables[self._current_batch_size : self._current_batch_size + num_seqs_to_add]
if alloc_block_tables is not None:
# copy block ids from provided block tables
self._block_tables[
self._current_batch_size : self._current_batch_size + num_seqs_to_add
] = alloc_block_tables
self._block_tables[self._current_batch_size : self._current_batch_size + num_seqs_to_add] = (
alloc_block_tables
)
elif alloc_block_tables_fn:
alloc_block_tables_fn(
block_tables,

View File

@ -1,10 +1,11 @@
"""
Our config contains various options for inference optimization, it is a unified API that wraps all the configurations for inference.
"""
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from typing import Any, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from transformers.generation import GenerationConfig
@ -82,9 +83,9 @@ class InputMetaData(RPC_PARAM):
dtype: torch.dtype = torch.float32
use_spec_dec: bool = False
num_tokens_to_verify: int = 0
batch_token_ids: Optional[
List[List[int]]
] = None # for `repetition_penalty`, `no_repeat_ngram_size` in sampler process
batch_token_ids: Optional[List[List[int]]] = (
None # for `repetition_penalty`, `no_repeat_ngram_size` in sampler process
)
def to_rpc_param(self) -> Dict[str, any]:
return {
@ -202,9 +203,9 @@ class InferenceConfig(RPC_PARAM):
prompt_template: Optional[str] = None
do_sample: bool = False
beam_width: int = 1 # TODO: beam search is not support for now
prefill_ratio: Optional[
float
] = 1.2 # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
prefill_ratio: Optional[float] = (
1.2 # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
)
pad_input: bool = False
early_stopping: Optional[bool] = False
top_k: Optional[int] = 50
@ -234,7 +235,9 @@ class InferenceConfig(RPC_PARAM):
high_precision: Optional[bool] = False
# cuda_graph
use_cuda_graph: bool = False # NOTE only when we have the graph for specific decoding batch size can we use the cuda graph for inference
use_cuda_graph: bool = (
False # NOTE only when we have the graph for specific decoding batch size can we use the cuda graph for inference
)
max_context_len_to_capture: int = 512
# StreamingLLM (sliding window attention with attention sinks)
@ -393,3 +396,49 @@ class ModelShardInferenceConfig:
use_cuda_kernel: bool = False
use_spec_dec: bool = False
use_flash_attn: bool = False
@dataclass
class DiffusionGenerationConfig:
"""
Param for diffusion model forward
"""
prompt_2: Optional[Union[str, List[str]]] = None
prompt_3: Optional[Union[str, List[str]]] = None
height: Optional[int] = None
width: Optional[int] = None
num_inference_steps: int = None
timesteps: List[int] = None
guidance_scale: float = None
negative_prompt: Optional[Union[str, List[str]]] = (
None # NOTE(@lry89757) in pixart default to "", in sd3 default to None
)
negative_prompt_2: Optional[Union[str, List[str]]] = None
negative_prompt_3: Optional[Union[str, List[str]]] = None
num_images_per_prompt: Optional[int] = None
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None
latents: Optional[torch.FloatTensor] = None
prompt_embeds: Optional[torch.FloatTensor] = None
negative_prompt_embeds: Optional[torch.FloatTensor] = None
pooled_prompt_embeds: Optional[torch.FloatTensor] = None
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None
output_type: Optional[str] = None # "pil"
return_dict: bool = None
joint_attention_kwargs: Optional[Dict[str, Any]] = None
clip_skip: Optional[int] = None
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None
callback_on_step_end_tensor_inputs: List[str] = None
def to_dict(self) -> Dict[str, Any]:
# NOTE(@lry89757) Only return the dict that not the default value None
result = {}
for field in fields(self):
value = getattr(self, field.name)
if value is not None:
result[field.name] = value
return result
@classmethod
def from_kwargs(cls, **kwargs) -> "DiffusionGenerationConfig":
return cls(**kwargs)

View File

@ -0,0 +1,90 @@
from abc import ABC, abstractmethod
import torch
import torch.nn as nn
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.base_policy import Policy
class BaseEngine(ABC):
@abstractmethod
def __init__(self, model_or_path, inference_config=None, verbose=False, model_policy=None):
pass
@abstractmethod
def init_model(self, model_or_path, model_policy=None, model_shard_infer_config=None):
"""
Init Model for Engine
"""
@abstractmethod
def generate(self, request_ids=None, prompts=None, generation_config=None, **kwargs):
"""
Generate ouptput for coming requests
"""
@abstractmethod
def add_request(self, prompts, request_ids=None, **kwargs):
"""
Add new request to Engine
"""
@abstractmethod
def step(self):
"""
Perform one new step forward
"""
@abstractmethod
def _verify_args(self):
"""
Verify the parameters and members of class
"""
@torch.inference_mode()
def capture_model(self):
"""
Use cuda graph to capture model
"""
return NotImplementedError("This method should be implemented by subclasses")
def _shardformer(
self,
model: nn.Module,
model_policy: Policy,
model_shard_infer_config: ModelShardInferenceConfig = None,
stage_manager: PipelineStageManager = None,
tp_group: ProcessGroupMesh = None,
**kwargs,
) -> nn.Module:
"""
Initialize ShardConfig and replace the model with shardformer.
Args:
model (nn.Module): Path or nn.Module of this model.
model_policy (Policy): The policy to shardformer model which is determined by the model type.
stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None.
Returns:
nn.Module: The model optimized by Shardformer.
"""
shardconfig = ShardConfig(
tensor_parallel_process_group=tp_group,
pipeline_stage_manager=stage_manager,
enable_tensor_parallelism=(self.inference_config.tp_size > 1),
enable_fused_normalization=False,
enable_all_optimization=False,
enable_flash_attention=False,
enable_jit_fused=False,
enable_sequence_parallelism=False,
extra_kwargs={"model_shard_infer_config": model_shard_infer_config, **kwargs},
)
shardformer = ShardFormer(shard_config=shardconfig)
shard_model, _ = shardformer.optimize(model, model_policy)
return shard_model

View File

@ -0,0 +1,200 @@
from itertools import count
from typing import List, Tuple, Type, Union
import numpy as np
import PIL.Image
import torch
import torch.nn as nn
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from torch import distributed as dist
from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig, ModelShardInferenceConfig
from colossalai.inference.modeling.models.diffusion import DiffusionPipe
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.struct import DiffusionSequence
from colossalai.inference.utils import get_model_size, get_model_type
from colossalai.logging import get_dist_logger
from colossalai.shardformer.policies.base_policy import Policy
from .base_engine import BaseEngine
from .request_handler import NaiveRequestHandler
PP_AXIS, TP_AXIS = 0, 1
class DiffusionEngine(BaseEngine):
def __init__(
self,
model_or_path: DiffusionPipeline | str,
inference_config: InferenceConfig = None,
verbose: bool = False,
model_policy: Policy | type[Policy] = None,
) -> None:
self.inference_config = inference_config
self.dtype = inference_config.dtype
self.high_precision = inference_config.high_precision
self.verbose = verbose
self.logger = get_dist_logger(__name__)
self.model_shard_infer_config = inference_config.to_model_shard_inference_config()
self.model_type = get_model_type(model_or_path=model_or_path)
self.init_model(model_or_path, model_policy, self.model_shard_infer_config)
self.request_handler = NaiveRequestHandler()
self.counter = count()
self._verify_args()
def _verify_args(self) -> None:
assert isinstance(self.model, DiffusionPipe), "model must be DiffusionPipe"
def init_model(
self,
model_or_path: Union[str, nn.Module, DiffusionPipeline],
model_policy: Union[Policy, Type[Policy]] = None,
model_shard_infer_config: ModelShardInferenceConfig = None,
):
"""
Shard model or/and Load weight
Args:
model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format.
model_policy (Policy): the policy to replace the model.
model_inference_config: the configuration for modeling initialization when inference.
model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference.
"""
if isinstance(model_or_path, str):
model = DiffusionPipeline.from_pretrained(model_or_path, torch_dtype=self.dtype)
policy_map_key = model.__class__.__name__
model = DiffusionPipe(model)
elif isinstance(model_or_path, DiffusionPipeline):
policy_map_key = model_or_path.__class__.__name__
model = DiffusionPipe(model_or_path)
else:
self.logger.error(f"model_or_path support only str or DiffusionPipeline currently!")
torch.cuda.empty_cache()
init_gpu_memory = torch.cuda.mem_get_info()[0]
self.device = get_accelerator().get_current_device()
if self.verbose:
self.logger.info(f"the device is {self.device}")
if self.verbose:
self.logger.info(
f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
)
if model_policy is None:
model_policy = model_policy_map.get(policy_map_key)
if not isinstance(model_policy, Policy):
try:
model_policy = model_policy()
except Exception as e:
raise ValueError(f"Unable to instantiate model policy: {e}")
assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
self.model = self._shardformer(
model,
model_policy,
model_shard_infer_config,
None,
tp_group=tp_group,
)
self.model = model.to(self.device)
if self.verbose:
self.logger.info(
f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
)
free_gpu_memory, _ = torch.cuda.mem_get_info()
peak_memory = init_gpu_memory - free_gpu_memory
if self.verbose:
self.logger.info(
f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB"
)
def generate(
self,
request_ids: Union[List[int], int] = None,
prompts: Union[List[str], str] = None,
generation_config: DiffusionGenerationConfig = None,
**kwargs,
) -> Union[List[Union[str, List[PIL.Image.Image], np.ndarray]], Tuple[List[str], List[List[int]]]]:
""" """
gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
prompts = [prompts] if isinstance(prompts, str) else prompts
request_ids = [request_ids] if isinstance(request_ids, int) else request_ids
with torch.inference_mode():
if prompts is not None:
self.add_request(
request_ids=request_ids,
prompts=prompts,
**gen_config_dict,
**kwargs,
)
output_reqs_list = []
# intuition: If user provide a generation config, we should replace the existing one.
if generation_config is not None:
self.generation_config = generation_config
self.generation_config_dict = gen_config_dict
while self.request_handler.check_unfinished_reqs():
output_reqs_list += self.step()
return output_reqs_list
def add_request(
self,
prompts: Union[List[str], str],
request_ids: Union[List[int], int] = None,
**kwargs,
):
if request_ids is not None and not isinstance(request_ids, list):
request_ids = [request_ids]
if not isinstance(prompts, list):
prompts = [prompts]
generation_config = DiffusionGenerationConfig.from_kwargs(**kwargs)
prompts_num = len(prompts)
for i in range(prompts_num):
if request_ids:
assert isinstance(
request_ids[0], int
), f"The request_id type must be int, but got {type(request_ids[0])}"
assert len(request_ids) == prompts_num
request_id = request_ids[i]
else:
request_id = next(self.counter)
seq = DiffusionSequence(request_id=request_id, prompt=prompts[i], generation_config=generation_config)
self.request_handler.add_sequence(seq)
def step(self) -> List[PIL.Image.Image]:
"""
In each step, do the follows:
1. Run RequestHandler.schedule() and get the batch used for inference.
2. run forward to get List[Image]
Returns:
List[PIL.Image.Image]: Image Generated by one step.
"""
input = self.request_handler.schedule()
ret = self.model(prompt=input.prompt, **input.generation_config.to_dict())
return ret

View File

@ -1,58 +1,24 @@
import time
from itertools import count
from typing import Dict, List, Optional, Tuple, Type, Union
from typing import List, Tuple, Type, Union
import numpy as np
import torch
import PIL.Image
import torch.nn as nn
from torch import distributed as dist
from transformers import (
AutoConfig,
AutoModelForCausalLM,
GenerationConfig,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
)
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from diffusers import DiffusionPipeline
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData, ModelShardInferenceConfig
from colossalai.inference.graph_runner import CUDAGraphRunner
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.sampler import search_tokens
from colossalai.inference.spec import Drafter, GlideInput
from colossalai.inference.struct import Sequence
from colossalai.inference.utils import get_model_size, has_index_file
from colossalai.interface import ModelWrapper
from colossalai.lazy import LazyInitContext
from colossalai.logging import get_dist_logger
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.inference.config import InferenceConfig
from colossalai.inference.utils import ModelType, get_model_type
from colossalai.shardformer.policies.base_policy import Policy
from .request_handler import RequestHandler
__all__ = ["InferenceEngine"]
PP_AXIS, TP_AXIS = 0, 1
_supported_models = {
"LlamaForCausalLM": LlamaForCausalLM,
"BaichuanForCausalLM": AutoModelForCausalLM,
}
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
class InferenceEngine:
"""
InferenceEngine which manages the inference process..
Args:
model_or_path (nn.Module or str): Path or nn.Module of this model.
model_or_path (nn.Module or DiffusionPipeline or str): Path or nn.Module or DiffusionPipeline of this model.
tokenizer Optional[(Union[PreTrainedTokenizer, PreTrainedTokenizerFast])]: Path of the tokenizer to use.
inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference.
verbose (bool): Determine whether or not to log the generation process.
@ -61,567 +27,68 @@ class InferenceEngine:
def __init__(
self,
model_or_path: Union[nn.Module, str],
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
inference_config: InferenceConfig,
model_or_path: Union[nn.Module, str, DiffusionPipeline],
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
inference_config: InferenceConfig = None,
verbose: bool = False,
model_policy: Union[Policy, Type[Policy]] = None,
) -> None:
self.inference_config = inference_config
self.dtype = inference_config.dtype
self.high_precision = inference_config.high_precision
self.__dict__["_initialized"] = False # use __dict__ directly to avoid calling __setattr__
self.model_type = get_model_type(model_or_path=model_or_path)
self.engine = None
if self.model_type == ModelType.LLM:
from .llm_engine import LLMEngine
self.verbose = verbose
self.logger = get_dist_logger(__name__)
self.model_shard_infer_config = inference_config.to_model_shard_inference_config()
self.engine = LLMEngine(
model_or_path=model_or_path,
tokenizer=tokenizer,
inference_config=inference_config,
verbose=verbose,
model_policy=model_policy,
)
elif self.model_type == ModelType.DIFFUSION_MODEL:
from .diffusion_engine import DiffusionEngine
self.init_model(model_or_path, model_policy, self.model_shard_infer_config)
self.generation_config = inference_config.to_generation_config(self.model_config)
self.generation_config_dict = self.generation_config.to_dict()
self.tokenizer = tokenizer
self.tokenizer.pad_token = self.tokenizer.eos_token
self.request_handler = RequestHandler(self.inference_config, self.model_config)
self.k_cache, self.v_cache = self.request_handler.get_kvcache()
# DISCUSS maybe move this into batch info?
self.counter = count()
self.use_cuda_graph = self.inference_config.use_cuda_graph
if self.use_cuda_graph:
self.graph_runners: Dict[int, CUDAGraphRunner] = {}
self.graph_memory_pool = None # Set during graph capture.
if verbose:
self.logger.info("Colossal AI CUDA Graph Capture on")
self.capture_model(self.k_cache, self.v_cache)
# Model and relatable attrs of speculative decoding will be set by `enable_spec_dec`
self.use_spec_dec = self.inference_config.use_spec_dec
self.drafter_model = None
self.drafter = None
self.use_glide = False
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
self.engine = DiffusionEngine(
model_or_path=model_or_path,
inference_config=inference_config,
verbose=verbose,
model_policy=model_policy,
)
elif self.model_type == ModelType.UNKNOWN:
self.logger.error(f"Model Type either Difffusion or LLM!")
self._initialized = True
self._verify_args()
def init_model(
self,
model_or_path: Union[nn.Module, str],
model_policy: Union[Policy, Type[Policy]] = None,
model_shard_infer_config: ModelShardInferenceConfig = None,
):
"""
Shard model or/and Load weight
Args:
model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format.
model_policy (Policy): the policy to replace the model.
model_inference_config: the configuration for modeling initialization when inference.
model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference.
"""
pretrained_path = None
if isinstance(model_or_path, str):
import colossalai.interface.pretrained as pretrained_utils
try:
hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype)
arch = getattr(hf_config, "architectures")[0]
if arch in _supported_models.keys():
if arch is "BaichuanForCausalLM":
self.logger.warning(
"Attention ! We use lazy init by default, which could be faster for model loading. For baichuan model, the output maybe have a slight difference with transformers"
)
ctx = LazyInitContext(default_device="cuda")
with ctx:
model = _supported_models[arch].from_pretrained(
model_or_path, trust_remote_code=True, torch_dtype=self.dtype
)
pretrained_path = pretrained_utils.get_pretrained_path(model)
else:
# TODO(char-1ee): if the model not supported, use transformers APIs to load and generate
raise ValueError(f"Model {arch} is not supported.")
except Exception as e:
self.logger.error(
f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
)
else:
model = model_or_path
self.model_config = model.config
torch.cuda.empty_cache()
init_gpu_memory = torch.cuda.mem_get_info()[0]
self.device = get_accelerator().get_current_device()
if self.verbose:
self.logger.info(f"the device is {self.device}")
model = model.to(self.dtype).eval()
if self.verbose:
self.logger.info(
f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
)
if model_policy is None:
prefix = "nopadding" if not self.inference_config.pad_input else "padding"
model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}"
model_policy = model_policy_map.get(model_policy_key)
if not isinstance(model_policy, Policy):
try:
model_policy = model_policy()
except Exception as e:
raise ValueError(f"Unable to instantiate model policy: {e}")
assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
self.model = self._shardformer(
model,
model_policy,
model_shard_infer_config,
None,
tp_group=tp_group,
)
self.model = ModelWrapper(model).to(self.device)
if self.verbose:
self.logger.info(
f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
)
if pretrained_path:
from colossalai.inference.core.plugin import InferCheckpoint_io
cpt_io = InferCheckpoint_io()
if_has_index_file, model_index_file = has_index_file(pretrained_path)
assert if_has_index_file, "the model path is invalid"
cpt_io.load_model(self.model, model_index_file)
free_gpu_memory, _ = torch.cuda.mem_get_info()
peak_memory = init_gpu_memory - free_gpu_memory
if self.verbose:
self.logger.info(
f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB"
)
@torch.inference_mode()
def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]):
assert self.use_cuda_graph, "please turn on the cuda graph"
if self.verbose:
self.logger.info("Colossal AI CUDA Graph Capture begin")
t_capture_begin = time.perf_counter()
block_size = self.inference_config.block_size
head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads
# Prepare dummy inputs. These will be reused for all batch sizes.
max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
max_context_len_to_capture = self.inference_config.max_context_len_to_capture
max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size
input_tokens_ids = torch.zeros(max_batch_size, dtype=torch.long).cuda()
# self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
self.graph_block_tables = np.full((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), -1, dtype=np.int32)
self.graph_block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(_BATCH_SIZES_TO_CAPTURE))
self.graph_block_tables[0, :] = np.arange(
0, max_num_blocks
) # NOTE this is a hack to insure cuda grpah could capture the fixed cuda kernel grid in flash decoding, to make the first seqlen as the max_seq_len
block_tables = torch.from_numpy(self.graph_block_tables).cuda()
output_tensor = torch.zeros(
(max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device
)
fd_inter_tensor = self.request_handler.running_bb.fd_inter_tensor
max_num_seqs = self.inference_config.max_batch_size
batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs]
sequence_lengths = torch.ones(max_batch_size, dtype=torch.int).cuda()
# NOTE this is a hack to insure cuda grpah could capture the fixed cuda kernel grid in flash decoding, to make the first seqlen as the max_seq_len
sequence_lengths[0] = torch.tensor(
self.inference_config.max_context_len_to_capture - 1, dtype=torch.int32
).cuda()
# NOTE: Capturing the largest batch size first may help reduce the
# memory usage of CUDA graph.
for batch_size in reversed(batch_size_capture_list):
if self.verbose:
self.logger.info(f"batch size {batch_size} graph capturing")
input_meta_data = InputMetaData(
block_tables=block_tables[:batch_size],
sequence_lengths=sequence_lengths[:batch_size],
fd_inter_tensor=fd_inter_tensor,
batch_size=batch_size,
is_prompts=False,
use_cuda_graph=True,
high_precision=False,
kv_seq_len=sequence_lengths[:batch_size].max().item(),
head_dim=head_dim,
dtype=self.dtype,
)
graph_runner = CUDAGraphRunner(self.model)
graph_runner.capture(
input_tokens_ids[:batch_size],
output_tensor[:batch_size],
input_meta_data,
k_caches=k_cache,
v_caches=v_cache,
memory_pool=self.graph_memory_pool,
)
self.graph_memory_pool = graph_runner.graph.pool()
self.graph_runners[batch_size] = graph_runner
t_capture_end = time.perf_counter()
if self.verbose:
self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s")
def _verify_args(self) -> None:
"""Verify the input args"""
if not isinstance(self.inference_config, InferenceConfig):
raise TypeError("Invalid type of inference config provided.")
if not isinstance(self.model, nn.Module):
raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)):
raise TypeError(
f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
)
if isinstance(self.model, ModelWrapper):
model = self.model.module
assert (
model.__class__.__name__ in _supported_models.keys()
), f"Model {self.model.__class__.__name__} is not supported."
def _shardformer(
self,
model: nn.Module,
model_policy: Policy,
model_shard_infer_config: ModelShardInferenceConfig = None,
stage_manager: PipelineStageManager = None,
tp_group: ProcessGroupMesh = None,
) -> nn.Module:
"""
Initialize ShardConfig and replace the model with shardformer.
Args:
model (nn.Module): Path or nn.Module of this model.
model_policy (Policy): The policy to shardformer model which is determined by the model type.
stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None.
Returns:
nn.Module: The model optimized by Shardformer.
"""
shardconfig = ShardConfig(
tensor_parallel_process_group=tp_group,
pipeline_stage_manager=stage_manager,
enable_tensor_parallelism=(self.inference_config.tp_size > 1),
enable_fused_normalization=False,
enable_all_optimization=False,
enable_flash_attention=False,
enable_jit_fused=False,
enable_sequence_parallelism=False,
extra_kwargs={"model_shard_infer_config": model_shard_infer_config},
)
shardformer = ShardFormer(shard_config=shardconfig)
shard_model, _ = shardformer.optimize(model, model_policy)
return shard_model
def enable_spec_dec(
self,
drafter_model: nn.Module = None,
n_spec_tokens: int = None,
use_glide_drafter: bool = False,
) -> None:
"""Initialize drafter (if it has not yet), and enable Speculative Decoding for subsequent generations.
Args:
drafter_model (nn.Module): The drafter model (small model) used to speculate tokens.
If provided, the previous drafter and drafter model, if exist, will be overwritten.
n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying.
If not provided, `max_n_spec_tokens` in InferenceConfig will be used.
use_glide_drafter (bool): Whether to use glide model for speculative decoding. Defaults to False.
If True, the drafter model will be replaced by a glide model.
```python
...
engine = InferenceEngine(model, tokenizer, inference_config)
engine.enable_spec_dec(drafter_model, n_spec_tokens=5)
engine.generate(...) # Speculative Decoding
engine.disable_spec_dec()
engine.generate(...) # Normal generation
engine.enable_spec_dec()
engine.generate(...) # Speculative-Decoding using previously set drafter model and number of spec tokens
engine.clear_spec_dec()
```
"""
if drafter_model is None and self.drafter is None:
raise ValueError("Drafter not initialized. Please provide a Drafter Model")
if n_spec_tokens is not None:
assert 1 < n_spec_tokens <= self.inference_config.max_n_spec_tokens
self.n_spec_tokens = n_spec_tokens
if drafter_model is not None:
assert isinstance(drafter_model, nn.Module)
# overwrite the drafter, if exists
self.clear_spec_dec()
self.drafter_model = drafter_model
self.drafter = Drafter(
self.drafter_model,
self.tokenizer,
device=self.device,
dtype=self.dtype,
)
# check if the provided drafter model is compatible with GLIDE structure
# when `use_glide_drafter` is set to True
if (
use_glide_drafter
and hasattr(drafter_model, "model")
and hasattr(drafter_model.model, "layers")
and hasattr(drafter_model.model.layers[0], "cross_attn")
):
self.use_glide = use_glide_drafter
elif use_glide_drafter:
self.logger.warning(
f"`use_glide_drafter` is provided as {use_glide_drafter}, "
f"but the provided drafter model is not compatible with GLIDE structure."
f"Falling back to use the default drafter model (non-GLIDE)."
)
self.request_handler.set_spec_dec_mode(self.n_spec_tokens)
# using speculative decoding for subsequent generations
self.use_spec_dec = True
def disable_spec_dec(self) -> None:
"""Disable using speculative decoding for subsequent generations."""
self.request_handler.unset_spec_dec_mode()
# set back to the maximum number of tokens to speculate
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
self.use_glide = False
self.use_spec_dec = False
def clear_spec_dec(self) -> None:
"""Clear relatable structures of speculative decoding, if exist."""
if self.use_spec_dec:
self.disable_spec_dec()
if self.drafter_model or self.drafter:
self.drafter_model = None
self.drafter = None
torch.cuda.empty_cache()
self.use_glide = False
self.use_spec_dec = False
def steps_spec_dec(self) -> List[Sequence]:
"""
Run Speculative Decoding steps. This is like retrieving a single batch and launch inference
with many steps of speculating by a drafter model as well as verifying by a main model.
Returns:
List[Sequence]: finished sequences generated by one step.
"""
batch = self.request_handler.schedule() # prefill batch
assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
if input_meta_data.use_cuda_graph:
model_executable = self.graph_runners[input_meta_data.batch_size]
else:
model_executable = self.model
# 1. Prefill small model (Drafter) - fill past kv cache for drafter model
# NOTE For glide drafter models, we won't actually apply glide during prefill stage
drafter_out = self.drafter.speculate(input_token_ids, 1, None)
next_token_ids_spec = drafter_out.next_tokens
drafter_past_key_values = drafter_out.past_key_values
# 2. Prefill main model (Verifier) - fill past kv cache for main model
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
# append new inputs to the batch, temporarily
batch.append_batch_tokens(next_tokens)
self.request_handler.allocate_batch_spec_dec(batch, 1)
already_allocated_kv_len = batch.seq_lengths[0].item()
input_token_ids = batch.get_1D_inputs_spec_dec(1)
finished_sequences = self.request_handler.update()
while True:
# HACK Retrieve the running batch
# Using RequestHandler.schedule here will re-allocate same kv cache for the batch
batch = self.request_handler.running_bb # running batch
assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
# 3. Decoding - Drafter model speculates `n` tokens
glide_input = None
if self.use_glide:
glide_input = GlideInput(
batch.get_block_table_tensor(),
self.k_cache[-1], # use kv cahces of the last layer
self.v_cache[-1],
batch.get_sequence_lengths(),
n_spec_tokens=self.n_spec_tokens,
)
drafter_out = self.drafter.speculate(
input_token_ids,
self.n_spec_tokens,
drafter_past_key_values,
glide_input=glide_input,
)
next_token_ids_spec = drafter_out.next_tokens
drafter_past_key_values = drafter_out.past_key_values
drafter_spec_length = drafter_out.speculated_length
for next_token_id_spec in next_token_ids_spec:
self.request_handler.append_next_tokens(next_token_id_spec.unsqueeze(0))
cur_length = batch.seq_lengths[0].item()
if already_allocated_kv_len < cur_length:
self.request_handler.allocate_batch_spec_dec(batch, n=cur_length - already_allocated_kv_len)
already_allocated_kv_len = cur_length
# 4. Decoding - Main model verifies `n` tokens in parallel
if drafter_spec_length < batch.num_tokens_to_verify:
batch.set_use_spec_dec(num_tokens_to_verify=drafter_spec_length)
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
# 5. Compare and process the results
diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
n_matches = drafter_spec_length if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()
# revoke appended tokens for each Sequence in the current batch
batch.revoke_batch_tokens(drafter_spec_length - n_matches) # revoke drafted tokens
# append the last correct token generated by the main model
self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0))
# trim past key values of the drafter model
drafter_past_key_values = Drafter.trim_kv_cache(
drafter_past_key_values, drafter_spec_length - n_matches - 1
)
# prepare inputs for the next round of speculation
n = 1 if n_matches < drafter_spec_length else 2
input_token_ids = batch.get_1D_inputs_spec_dec(n)
self.request_handler.update_batch_finished(batch, generation_config=self.generation_config)
finished_sequences = self.request_handler.update()
if len(finished_sequences) > 0:
break
# Reset back the number of speculated tokens of the batch,
# this is used to handle the last round of speculation, in which case the number of speculated tokens
# by the drafter is less than the number of speculated tokens set to the engine.
batch.set_use_spec_dec(num_tokens_to_verify=self.n_spec_tokens)
return finished_sequences
assert self.engine is not None, "Please init Engine first"
assert self._initialized, "Engine must be initialized"
def generate(
self,
request_ids: Union[List[int], int] = None,
prompts: Union[List[str], str] = None,
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
return_token_ids: bool = False,
generation_config: Optional[GenerationConfig] = None,
) -> Union[List[str], Tuple[List[str], List[List[int]]]]:
*args,
**kwargs,
) -> Union[List[Union[str, List[PIL.Image.Image], np.ndarray]], Tuple[List[str], List[List[int]]]]:
"""
Executing the inference step.
Args:
request_ids (List[int], optional): The request ID. Defaults to None.
prompts (Union[List[str], optional): Input prompts. Defaults to None.
prompts_token_ids (Union[List[int], torch.Tensor, np.ndarray], optional): token ids of input prompts. Defaults to None.
return_token_ids (bool, optional): Whether to return output token ids. Defaults to False.
generation_config (Optional[GenerationConfig], optional): Huggingface GenerationConfig used for inference. Defaults to None.
Returns:
Union[List[str], Tuple[List[str], List[List[int]]]]: Inference result returned by one generation.
"""
gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
prompts = [prompts] if isinstance(prompts, str) else prompts
request_ids = [request_ids] if isinstance(request_ids, int) else request_ids
with torch.inference_mode():
if prompts is not None or prompts_token_ids is not None:
self.add_request(
request_ids=request_ids,
prompts=prompts,
prompts_token_ids=prompts_token_ids,
**gen_config_dict,
)
output_seqs_list = []
total_tokens_list = []
# intuition: If user provide a generation config, we should replace the existing one.
if generation_config is not None:
self.generation_config = generation_config
self.generation_config_dict = gen_config_dict
if self.use_spec_dec:
assert self.drafter is not None, "Drafter Model is not initialized."
while self.request_handler.check_unfinished_seqs():
output_seqs_list += self.steps_spec_dec()
else:
while self.request_handler.check_unfinished_seqs():
output_seqs_list += self.step()
output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))
for seq in output_seqs_list:
total_tokens_list.append(seq.input_token_id + seq.output_token_id)
output_str = self.tokenizer.batch_decode(total_tokens_list, skip_special_tokens=True)
if return_token_ids:
output_tokens_list = [seq.output_token_id for seq in output_seqs_list]
return output_str, output_tokens_list
else:
return output_str
@property
def has_prompt_template(self) -> bool:
""" """
return self.inference_config.prompt_template is not None
def format_prompt(self, prompts: Union[List[str], str]) -> Union[List[str], str]:
"""
This method will format the input prompt according to the prompt template given to the InferenceConfig.
"""
assert (
self.has_prompt_template
), "Found the prompt_template is None. Please provide a valid prompt_template in InferenceConfig."
if isinstance(prompts, (list, tuple)):
return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts]
elif isinstance(prompts, str):
return self.inference_config.prompt_template.format(input_text=prompts)
else:
raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.")
assert self.engine is not None, "Please init Engine first"
return self.engine.generate(request_ids=request_ids, prompts=prompts, *args, **kwargs)
def add_request(
self,
request_ids: Union[List[int], int] = None,
prompts: Union[List[str], str] = None,
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
*args,
**kwargs,
) -> None:
"""
@ -631,168 +98,36 @@ class InferenceEngine:
request_ids (List[int], optional): The request ID. Defaults to None.
prompts (Union[List[str], optional): Input prompts. Defaults to None.
prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
kwargs: for LLM, it could be max_length, max_new_tokens, etc
for diffusion, it could be prompt_2, prompt_3, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, negative_prompt_2, negative_prompt_3, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, clip_skip, which aligns with diffusers
"""
assert self.engine is not None, "Please init Engine first"
self.engine.add_request(request_ids=request_ids, prompts=prompts, *args, **kwargs)
# apply the prompt template to the input prompts
def step(self):
assert self.engine is not None, "Please init Engine first"
return self.engine.step()
if self.has_prompt_template and prompts is not None:
prompts = self.format_prompt(prompts)
block_size = self.inference_config.block_size
if request_ids is not None and not isinstance(request_ids, list):
request_ids = [request_ids]
if prompts is not None and not isinstance(prompts, list):
prompts = [prompts]
if prompts_token_ids is None:
assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[
"input_ids"
]
# list of torch Tensor
if isinstance(prompts_token_ids, list):
if isinstance(prompts_token_ids[0], torch.Tensor):
prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
prompts_token_ids = prompts_token_ids.tolist()
else:
raise TypeError(
f"The dtype of prompts_token_ids must be one of list, torch.Tensor, np.ndarray, but got {type(prompts_token_ids)}."
)
assert (
len(prompts_token_ids[0]) <= self.inference_config.max_input_len
), f"The length of input prompts {len(prompts_token_ids[0])} must be less than max_input_len {self.inference_config.max_input_len}."
prompts_num = len(prompts_token_ids)
for i in range(prompts_num):
if request_ids:
assert isinstance(
request_ids[0], int
), f"The request_id type must be int, but got {type(request_ids[0])}"
assert len(request_ids) == prompts_num
request_id = request_ids[i]
def __getattr__(self, name):
"""
The Design logic of getattr, setattr:
1. Since InferenceEngine is a wrapper for DiffusionEngine/LLMEngine, we hope to invoke all the member of DiffusionEngine/LLMEngine like we just call the member of InferenceEngine.
2. When we call the __init__ of InferenceEngine, we don't want to setattr using self.__dict__["xxx"] = xxx, we want to use origin ways like self.xxx = xxx
So we set the attribute `_initialized`. And after initialized, if we couldn't get the member from InferenceEngine, we will try to get the member from self.engine(DiffusionEngine/LLMEngine)
"""
if self.__dict__.get("_initialized", False):
if name in self.__dict__:
return self.__dict__[name]
else:
request_id = next(self.counter)
if prompts == None:
prompt = None
return getattr(self.engine, name)
else:
return self.__dict__[name]
def __setattr__(self, name, value):
if self.__dict__.get("_initialized", False):
if name in self.__dict__:
self.__dict__[name] = value
else:
prompt = prompts[i]
max_length = kwargs.get("max_length", None)
max_new_tokens = kwargs.get("max_new_tokens", None)
if max_length is None and max_new_tokens is None:
max_new_tokens = self.generation_config.max_new_tokens or self.inference_config.max_output_len
elif max_length is not None:
max_new_tokens = max_length - len(prompts_token_ids[i])
if not self.inference_config.enable_streamingllm:
assert (
self.inference_config.max_output_len >= max_new_tokens
), f"max_new_tokens={max_new_tokens} must be less than max_output_len={self.inference_config.max_output_len}."
sequence = Sequence(
request_id,
prompt,
prompts_token_ids[i],
block_size,
None,
self.tokenizer.eos_token_id,
self.tokenizer.pad_token_id,
max_output_len=max_new_tokens,
ignore_eos=self.inference_config.ignore_eos,
)
self.request_handler.add_sequence(sequence)
def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]:
input_ids = batch.get_1D_inputs()
sequence_lengths = batch.get_sequence_lengths()
if batch.is_prompts:
n_tokens = sequence_lengths.sum().item()
setattr(self.engine, name, value)
else:
n_tokens = batch.current_batch_size
if batch.use_spec_dec:
n_tokens = batch.num_tokens_to_verify + 1
assert n_tokens == input_ids.size(0)
n_tokens = n_tokens * batch.current_batch_size
output_tensor = torch.zeros(
(n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
)
batch_token_ids = None
if (
self.generation_config.repetition_penalty != 1.0
or self.generation_config.no_repeat_ngram_size > 0
or self.generation_config.forced_eos_token_id is not None
):
batch_token_ids = batch.batch_token_ids
# only when we have the graph for specific decoding batch size can we use the cuda graph for inference
use_cuda_graph = False
if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys():
use_cuda_graph = True
input_meta_data = InputMetaData(
block_tables=batch.get_block_table_tensor(),
sequence_lengths=sequence_lengths,
fd_inter_tensor=batch.fd_inter_tensor,
batch_size=batch.current_batch_size,
is_prompts=batch.is_prompts,
use_cuda_kernel=self.inference_config.use_cuda_kernel,
use_cuda_graph=use_cuda_graph,
high_precision=self.high_precision,
kv_seq_len=sequence_lengths.max().item(),
head_dim=batch.head_dim,
dtype=batch.dtype,
use_spec_dec=batch.use_spec_dec,
num_tokens_to_verify=batch.num_tokens_to_verify,
batch_token_ids=batch_token_ids,
)
return input_ids, output_tensor, input_meta_data
def step(self) -> List[str]:
"""
In each step, do the follows:
1. Run RequestHandler.schedule() and get the batch used for inference.
2. Get the input, inputinfo and output placeholder from the batchbucket
3. Run model to generate the next token
4. Update waiting list and running list in RequestHandler and get finished sequences.
5. Decode and return finished sequences.
Returns:
List[str]: Decoded finished sequences generated by one step.
"""
batch = self.request_handler.schedule()
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
if input_meta_data.use_cuda_graph:
model_executable = self.graph_runners[input_meta_data.batch_size]
else:
model_executable = self.model
# TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
if self.inference_config.pad_input:
logits = logits[:, -1, :]
if self.inference_config.enable_streamingllm:
updated_block_ids = batch.streamingllm_update_batch(
self.inference_config.start_token_size, self.inference_config.generated_token_size
)
self.request_handler.streamingllm_free_block_tables(updated_block_ids)
next_tokens = search_tokens(
self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
)
self.request_handler.append_next_tokens(next_tokens)
finished_sequences = self.request_handler.update()
return finished_sequences
self.__dict__[name] = value

View File

@ -0,0 +1,758 @@
import time
from itertools import count
from typing import Dict, List, Optional, Tuple, Type, Union
import numpy as np
import torch
import torch.nn as nn
from torch import distributed as dist
from transformers import (
AutoConfig,
AutoModelForCausalLM,
GenerationConfig,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
)
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData, ModelShardInferenceConfig
from colossalai.inference.graph_runner import CUDAGraphRunner
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.sampler import search_tokens
from colossalai.inference.spec import Drafter, GlideInput
from colossalai.inference.struct import Sequence
from colossalai.inference.utils import get_model_size, has_index_file
from colossalai.interface import ModelWrapper
from colossalai.lazy import LazyInitContext
from colossalai.logging import get_dist_logger
from colossalai.shardformer.policies.base_policy import Policy
from .base_engine import BaseEngine
from .request_handler import RequestHandler
PP_AXIS, TP_AXIS = 0, 1
_supported_models = {
"LlamaForCausalLM": LlamaForCausalLM,
"BaichuanForCausalLM": AutoModelForCausalLM,
}
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
class LLMEngine(BaseEngine):
"""
InferenceEngine which manages the inference process..
Args:
model_or_path (nn.Module or str): Path or nn.Module of this model.
tokenizer Optional[(Union[PreTrainedTokenizer, PreTrainedTokenizerFast])]: Path of the tokenizer to use.
inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference.
verbose (bool): Determine whether or not to log the generation process.
model_policy ("Policy"): the policy to shardformer model. It will be determined by the model type if not provided.
"""
def __init__(
self,
model_or_path: Union[nn.Module, str],
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
inference_config: InferenceConfig = None,
verbose: bool = False,
model_policy: Union[Policy, type[Policy]] = None,
) -> None:
self.inference_config = inference_config
self.dtype = inference_config.dtype
self.high_precision = inference_config.high_precision
self.verbose = verbose
self.logger = get_dist_logger(__name__)
self.model_shard_infer_config = inference_config.to_model_shard_inference_config()
self.init_model(model_or_path, model_policy, self.model_shard_infer_config)
self.generation_config = inference_config.to_generation_config(self.model_config)
self.generation_config_dict = self.generation_config.to_dict()
self.tokenizer = tokenizer
self.tokenizer.pad_token = self.tokenizer.eos_token
self.request_handler = RequestHandler(self.inference_config, self.model_config)
self.k_cache, self.v_cache = self.request_handler.get_kvcache()
# DISCUSS maybe move this into batch info?
self.counter = count()
self.use_cuda_graph = self.inference_config.use_cuda_graph
if self.use_cuda_graph:
self.graph_runners: Dict[int, CUDAGraphRunner] = {}
self.graph_memory_pool = None # Set during graph capture.
if verbose:
self.logger.info("Colossal AI CUDA Graph Capture on")
self.capture_model(self.k_cache, self.v_cache)
# Model and relatable attrs of speculative decoding will be set by `enable_spec_dec`
self.use_spec_dec = self.inference_config.use_spec_dec
self.drafter_model = None
self.drafter = None
self.use_glide = False
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
self._verify_args()
def init_model(
self,
model_or_path: Union[nn.Module, str],
model_policy: Union[Policy, Type[Policy]] = None,
model_shard_infer_config: ModelShardInferenceConfig = None,
):
"""
Shard model or/and Load weight
Args:
model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format.
model_policy (Policy): the policy to replace the model.
model_inference_config: the configuration for modeling initialization when inference.
model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference.
"""
pretrained_path = None
if isinstance(model_or_path, str):
import colossalai.interface.pretrained as pretrained_utils
try:
hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype)
arch = getattr(hf_config, "architectures")[0]
if arch in _supported_models.keys():
if arch == "BaichuanForCausalLM":
self.logger.warning(
"Attention ! We use lazy init by default, which could be faster for model loading. For baichuan model, the output maybe have a slight difference with transformers"
)
ctx = LazyInitContext(default_device="cuda")
with ctx:
model = _supported_models[arch].from_pretrained(
model_or_path, trust_remote_code=True, torch_dtype=self.dtype
)
pretrained_path = pretrained_utils.get_pretrained_path(model)
else:
# TODO(char-1ee): if the model not supported, use transformers APIs to load and generate
raise ValueError(f"Model {arch} is not supported.")
except Exception as e:
self.logger.error(
f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
)
else:
model = model_or_path
self.model_config = model.config
torch.cuda.empty_cache()
init_gpu_memory = torch.cuda.mem_get_info()[0]
self.device = get_accelerator().get_current_device()
if self.verbose:
self.logger.info(f"the device is {self.device}")
model = model.to(self.dtype).eval()
if self.verbose:
self.logger.info(
f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
)
if model_policy is None:
prefix = "nopadding" if not self.inference_config.pad_input else "padding"
model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}"
model_policy = model_policy_map.get(model_policy_key)
if not isinstance(model_policy, Policy):
try:
model_policy = model_policy()
except Exception as e:
raise ValueError(f"Unable to instantiate model policy: {e}")
assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
self.model = self._shardformer(
model,
model_policy,
model_shard_infer_config,
None,
tp_group=tp_group,
)
self.model = ModelWrapper(model).to(self.device)
if self.verbose:
self.logger.info(
f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
)
if pretrained_path:
from colossalai.inference.core.plugin import InferCheckpoint_io
cpt_io = InferCheckpoint_io()
if_has_index_file, model_index_file = has_index_file(pretrained_path)
assert if_has_index_file, "the model path is invalid"
cpt_io.load_model(self.model, model_index_file)
free_gpu_memory, _ = torch.cuda.mem_get_info()
peak_memory = init_gpu_memory - free_gpu_memory
if self.verbose:
self.logger.info(
f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB"
)
@torch.inference_mode()
def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]):
assert self.use_cuda_graph, "please turn on the cuda graph"
if self.verbose:
self.logger.info("Colossal AI CUDA Graph Capture begin")
t_capture_begin = time.perf_counter()
block_size = self.inference_config.block_size
head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads
# Prepare dummy inputs. These will be reused for all batch sizes.
max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
max_context_len_to_capture = self.inference_config.max_context_len_to_capture
max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size
input_tokens_ids = torch.zeros(max_batch_size, dtype=torch.long).cuda()
# self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
self.graph_block_tables = np.full((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), -1, dtype=np.int32)
self.graph_block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(_BATCH_SIZES_TO_CAPTURE))
self.graph_block_tables[0, :] = np.arange(
0, max_num_blocks
) # NOTE this is a hack to insure cuda grpah could capture the fixed cuda kernel grid in flash decoding, to make the first seqlen as the max_seq_len
block_tables = torch.from_numpy(self.graph_block_tables).cuda()
output_tensor = torch.zeros(
(max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device
)
fd_inter_tensor = self.request_handler.running_bb.fd_inter_tensor
max_num_seqs = self.inference_config.max_batch_size
batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs]
sequence_lengths = torch.ones(max_batch_size, dtype=torch.int).cuda()
# NOTE this is a hack to insure cuda grpah could capture the fixed cuda kernel grid in flash decoding, to make the first seqlen as the max_seq_len
sequence_lengths[0] = torch.tensor(
self.inference_config.max_context_len_to_capture - 1, dtype=torch.int32
).cuda()
# NOTE: Capturing the largest batch size first may help reduce the
# memory usage of CUDA graph.
for batch_size in reversed(batch_size_capture_list):
if self.verbose:
self.logger.info(f"batch size {batch_size} graph capturing")
input_meta_data = InputMetaData(
block_tables=block_tables[:batch_size],
sequence_lengths=sequence_lengths[:batch_size],
fd_inter_tensor=fd_inter_tensor,
batch_size=batch_size,
is_prompts=False,
use_cuda_graph=True,
high_precision=False,
kv_seq_len=sequence_lengths[:batch_size].max().item(),
head_dim=head_dim,
dtype=self.dtype,
)
graph_runner = CUDAGraphRunner(self.model)
graph_runner.capture(
input_tokens_ids[:batch_size],
output_tensor[:batch_size],
input_meta_data,
k_caches=k_cache,
v_caches=v_cache,
memory_pool=self.graph_memory_pool,
)
self.graph_memory_pool = graph_runner.graph.pool()
self.graph_runners[batch_size] = graph_runner
t_capture_end = time.perf_counter()
if self.verbose:
self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s")
def _verify_args(self) -> None:
"""Verify the input args"""
if not isinstance(self.inference_config, InferenceConfig):
raise TypeError("Invalid type of inference config provided.")
if not isinstance(self.model, nn.Module):
raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)):
raise TypeError(
f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
)
if isinstance(self.model, ModelWrapper):
model = self.model.module
assert (
model.__class__.__name__ in _supported_models.keys()
), f"Model {self.model.__class__.__name__} is not supported."
def enable_spec_dec(
self,
drafter_model: nn.Module = None,
n_spec_tokens: int = None,
use_glide_drafter: bool = False,
) -> None:
"""Initialize drafter (if it has not yet), and enable Speculative Decoding for subsequent generations.
Args:
drafter_model (nn.Module): The drafter model (small model) used to speculate tokens.
If provided, the previous drafter and drafter model, if exist, will be overwritten.
n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying.
If not provided, `max_n_spec_tokens` in InferenceConfig will be used.
use_glide_drafter (bool): Whether to use glide model for speculative decoding. Defaults to False.
If True, the drafter model will be replaced by a glide model.
```python
...
engine = InferenceEngine(model, tokenizer, inference_config)
engine.enable_spec_dec(drafter_model, n_spec_tokens=5)
engine.generate(...) # Speculative Decoding
engine.disable_spec_dec()
engine.generate(...) # Normal generation
engine.enable_spec_dec()
engine.generate(...) # Speculative-Decoding using previously set drafter model and number of spec tokens
engine.clear_spec_dec()
```
"""
if drafter_model is None and self.drafter is None:
raise ValueError("Drafter not initialized. Please provide a Drafter Model")
if n_spec_tokens is not None:
assert 1 < n_spec_tokens <= self.inference_config.max_n_spec_tokens
self.n_spec_tokens = n_spec_tokens
if drafter_model is not None:
assert isinstance(drafter_model, nn.Module)
# overwrite the drafter, if exists
self.clear_spec_dec()
self.drafter_model = drafter_model
self.drafter = Drafter(
self.drafter_model,
self.tokenizer,
device=self.device,
dtype=self.dtype,
)
# check if the provided drafter model is compatible with GLIDE structure
# when `use_glide_drafter` is set to True
if (
use_glide_drafter
and hasattr(drafter_model, "model")
and hasattr(drafter_model.model, "layers")
and hasattr(drafter_model.model.layers[0], "cross_attn")
):
self.use_glide = use_glide_drafter
elif use_glide_drafter:
self.logger.warning(
f"`use_glide_drafter` is provided as {use_glide_drafter}, "
f"but the provided drafter model is not compatible with GLIDE structure."
f"Falling back to use the default drafter model (non-GLIDE)."
)
self.request_handler.set_spec_dec_mode(self.n_spec_tokens)
# using speculative decoding for subsequent generations
self.use_spec_dec = True
def disable_spec_dec(self) -> None:
"""Disable using speculative decoding for subsequent generations."""
self.request_handler.unset_spec_dec_mode()
# set back to the maximum number of tokens to speculate
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
self.use_glide = False
self.use_spec_dec = False
def clear_spec_dec(self) -> None:
"""Clear relatable structures of speculative decoding, if exist."""
if self.use_spec_dec:
self.disable_spec_dec()
if self.drafter_model or self.drafter:
self.drafter_model = None
self.drafter = None
torch.cuda.empty_cache()
self.use_glide = False
self.use_spec_dec = False
def steps_spec_dec(self) -> List[Sequence]:
"""
Run Speculative Decoding steps. This is like retrieving a single batch and launch inference
with many steps of speculating by a drafter model as well as verifying by a main model.
Returns:
List[Sequence]: finished sequences generated by one step.
"""
batch = self.request_handler.schedule() # prefill batch
assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
if input_meta_data.use_cuda_graph:
model_executable = self.graph_runners[input_meta_data.batch_size]
else:
model_executable = self.model
# 1. Prefill small model (Drafter) - fill past kv cache for drafter model
# NOTE For glide drafter models, we won't actually apply glide during prefill stage
drafter_out = self.drafter.speculate(input_token_ids, 1, None)
next_token_ids_spec = drafter_out.next_tokens
drafter_past_key_values = drafter_out.past_key_values
# 2. Prefill main model (Verifier) - fill past kv cache for main model
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
# append new inputs to the batch, temporarily
batch.append_batch_tokens(next_tokens)
self.request_handler.allocate_batch_spec_dec(batch, 1)
already_allocated_kv_len = batch.seq_lengths[0].item()
input_token_ids = batch.get_1D_inputs_spec_dec(1)
finished_sequences = self.request_handler.update()
while True:
# HACK Retrieve the running batch
# Using RequestHandler.schedule here will re-allocate same kv cache for the batch
batch = self.request_handler.running_bb # running batch
assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
# 3. Decoding - Drafter model speculates `n` tokens
glide_input = None
if self.use_glide:
glide_input = GlideInput(
batch.get_block_table_tensor(),
self.k_cache[-1], # use kv cahces of the last layer
self.v_cache[-1],
batch.get_sequence_lengths(),
n_spec_tokens=self.n_spec_tokens,
)
drafter_out = self.drafter.speculate(
input_token_ids,
self.n_spec_tokens,
drafter_past_key_values,
glide_input=glide_input,
)
next_token_ids_spec = drafter_out.next_tokens
drafter_past_key_values = drafter_out.past_key_values
drafter_spec_length = drafter_out.speculated_length
for next_token_id_spec in next_token_ids_spec:
self.request_handler.append_next_tokens(next_token_id_spec.unsqueeze(0))
cur_length = batch.seq_lengths[0].item()
if already_allocated_kv_len < cur_length:
self.request_handler.allocate_batch_spec_dec(batch, n=cur_length - already_allocated_kv_len)
already_allocated_kv_len = cur_length
# 4. Decoding - Main model verifies `n` tokens in parallel
if drafter_spec_length < batch.num_tokens_to_verify:
batch.set_use_spec_dec(num_tokens_to_verify=drafter_spec_length)
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
# 5. Compare and process the results
diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
n_matches = drafter_spec_length if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()
# revoke appended tokens for each Sequence in the current batch
batch.revoke_batch_tokens(drafter_spec_length - n_matches) # revoke drafted tokens
# append the last correct token generated by the main model
self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0))
# trim past key values of the drafter model
drafter_past_key_values = Drafter.trim_kv_cache(
drafter_past_key_values, drafter_spec_length - n_matches - 1
)
# prepare inputs for the next round of speculation
n = 1 if n_matches < drafter_spec_length else 2
input_token_ids = batch.get_1D_inputs_spec_dec(n)
self.request_handler.update_batch_finished(batch, generation_config=self.generation_config)
finished_sequences = self.request_handler.update()
if len(finished_sequences) > 0:
break
# Reset back the number of speculated tokens of the batch,
# this is used to handle the last round of speculation, in which case the number of speculated tokens
# by the drafter is less than the number of speculated tokens set to the engine.
batch.set_use_spec_dec(num_tokens_to_verify=self.n_spec_tokens)
return finished_sequences
def generate(
self,
request_ids: Union[List[int], int] = None,
prompts: Union[List[str], str] = None,
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
return_token_ids: bool = False,
generation_config: Optional[GenerationConfig] = None,
) -> Union[List[str], Tuple[List[str], List[List[int]]]]:
"""
Executing the inference step.
Args:
request_ids (List[int], optional): The request ID. Defaults to None.
prompts (Union[List[str], optional): Input prompts. Defaults to None.
prompts_token_ids (Union[List[int], torch.Tensor, np.ndarray], optional): token ids of input prompts. Defaults to None.
return_token_ids (bool, optional): Whether to return output token ids. Defaults to False.
generation_config (Optional[GenerationConfig], optional): Huggingface GenerationConfig used for inference. Defaults to None.
Returns:
Union[List[str], Tuple[List[str], List[List[int]]]]: Inference result returned by one generation.
"""
gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
prompts = [prompts] if isinstance(prompts, str) else prompts
request_ids = [request_ids] if isinstance(request_ids, int) else request_ids
with torch.inference_mode():
if prompts is not None or prompts_token_ids is not None:
self.add_request(
request_ids=request_ids,
prompts=prompts,
prompts_token_ids=prompts_token_ids,
**gen_config_dict,
)
output_seqs_list = []
total_tokens_list = []
# intuition: If user provide a generation config, we should replace the existing one.
if generation_config is not None:
self.generation_config = generation_config
self.generation_config_dict = gen_config_dict
if self.use_spec_dec:
assert self.drafter is not None, "Drafter Model is not initialized."
while self.request_handler.check_unfinished_reqs():
output_seqs_list += self.steps_spec_dec()
else:
while self.request_handler.check_unfinished_reqs():
output_seqs_list += self.step()
output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))
for seq in output_seqs_list:
total_tokens_list.append(seq.input_token_id + seq.output_token_id)
output_str = self.tokenizer.batch_decode(total_tokens_list, skip_special_tokens=True)
if return_token_ids:
output_tokens_list = [seq.output_token_id for seq in output_seqs_list]
return output_str, output_tokens_list
else:
return output_str
@property
def has_prompt_template(self) -> bool:
""" """
return self.inference_config.prompt_template is not None
def format_prompt(self, prompts: Union[List[str], str]) -> Union[List[str], str]:
"""
This method will format the input prompt according to the prompt template given to the InferenceConfig.
"""
assert (
self.has_prompt_template
), "Found the prompt_template is None. Please provide a valid prompt_template in InferenceConfig."
if isinstance(prompts, (list, tuple)):
return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts]
elif isinstance(prompts, str):
return self.inference_config.prompt_template.format(input_text=prompts)
else:
raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.")
def add_request(
self,
request_ids: Union[List[int], int] = None,
prompts: Union[List[str], str] = None,
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
**kwargs,
) -> None:
"""
Add requests.
Args:
request_ids (List[int], optional): The request ID. Defaults to None.
prompts (Union[List[str], optional): Input prompts. Defaults to None.
prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
"""
# apply the prompt template to the input prompts
if self.has_prompt_template and prompts is not None:
prompts = self.format_prompt(prompts)
block_size = self.inference_config.block_size
if request_ids is not None and not isinstance(request_ids, list):
request_ids = [request_ids]
if prompts is not None and not isinstance(prompts, list):
prompts = [prompts]
if prompts_token_ids is None:
assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[
"input_ids"
]
# list of torch Tensor
if isinstance(prompts_token_ids, list):
if isinstance(prompts_token_ids[0], torch.Tensor):
prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
prompts_token_ids = prompts_token_ids.tolist()
else:
raise TypeError(
f"The dtype of prompts_token_ids must be one of list, torch.Tensor, np.ndarray, but got {type(prompts_token_ids)}."
)
assert (
len(prompts_token_ids[0]) <= self.inference_config.max_input_len
), f"The length of input prompts {len(prompts_token_ids[0])} must be less than max_input_len {self.inference_config.max_input_len}."
prompts_num = len(prompts_token_ids)
for i in range(prompts_num):
if request_ids:
assert isinstance(
request_ids[0], int
), f"The request_id type must be int, but got {type(request_ids[0])}"
assert len(request_ids) == prompts_num
request_id = request_ids[i]
else:
request_id = next(self.counter)
if prompts == None:
prompt = None
else:
prompt = prompts[i]
max_length = kwargs.get("max_length", None)
max_new_tokens = kwargs.get("max_new_tokens", None)
if max_length is None and max_new_tokens is None:
max_new_tokens = self.generation_config.max_new_tokens or self.inference_config.max_output_len
elif max_length is not None:
max_new_tokens = max_length - len(prompts_token_ids[i])
if not self.inference_config.enable_streamingllm:
assert (
self.inference_config.max_output_len >= max_new_tokens
), f"max_new_tokens={max_new_tokens} must be less than max_output_len={self.inference_config.max_output_len}."
sequence = Sequence(
request_id,
prompt,
prompts_token_ids[i],
block_size,
None,
self.tokenizer.eos_token_id,
self.tokenizer.pad_token_id,
max_output_len=max_new_tokens,
ignore_eos=self.inference_config.ignore_eos,
)
self.request_handler.add_sequence(sequence)
def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]:
input_ids = batch.get_1D_inputs()
sequence_lengths = batch.get_sequence_lengths()
if batch.is_prompts:
n_tokens = sequence_lengths.sum().item()
else:
n_tokens = batch.current_batch_size
if batch.use_spec_dec:
n_tokens = batch.num_tokens_to_verify + 1
assert n_tokens == input_ids.size(0)
n_tokens = n_tokens * batch.current_batch_size
output_tensor = torch.zeros(
(n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
)
batch_token_ids = None
if (
self.generation_config.repetition_penalty != 1.0
or self.generation_config.no_repeat_ngram_size > 0
or self.generation_config.forced_eos_token_id is not None
):
batch_token_ids = batch.batch_token_ids
# only when we have the graph for specific decoding batch size can we use the cuda graph for inference
use_cuda_graph = False
if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys():
use_cuda_graph = True
input_meta_data = InputMetaData(
block_tables=batch.get_block_table_tensor(),
sequence_lengths=sequence_lengths,
fd_inter_tensor=batch.fd_inter_tensor,
batch_size=batch.current_batch_size,
is_prompts=batch.is_prompts,
use_cuda_kernel=self.inference_config.use_cuda_kernel,
use_cuda_graph=use_cuda_graph,
high_precision=self.high_precision,
kv_seq_len=sequence_lengths.max().item(),
head_dim=batch.head_dim,
dtype=batch.dtype,
use_spec_dec=batch.use_spec_dec,
num_tokens_to_verify=batch.num_tokens_to_verify,
batch_token_ids=batch_token_ids,
)
return input_ids, output_tensor, input_meta_data
def step(self) -> List[str]:
"""
In each step, do the follows:
1. Run RequestHandler.schedule() and get the batch used for inference.
2. Get the input, inputinfo and output placeholder from the batchbucket
3. Run model to generate the next token
4. Update waiting list and running list in RequestHandler and get finished sequences.
5. Decode and return finished sequences.
Returns:
List[str]: Decoded finished sequences generated by one step.
"""
batch = self.request_handler.schedule()
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
if input_meta_data.use_cuda_graph:
model_executable = self.graph_runners[input_meta_data.batch_size]
else:
model_executable = self.model
# TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
if self.inference_config.pad_input:
logits = logits[:, -1, :]
if self.inference_config.enable_streamingllm:
updated_block_ids = batch.streamingllm_update_batch(
self.inference_config.start_token_size, self.inference_config.generated_token_size
)
self.request_handler.streamingllm_free_block_tables(updated_block_ids)
next_tokens = search_tokens(
self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
)
self.request_handler.append_next_tokens(next_tokens)
finished_sequences = self.request_handler.update()
return finished_sequences

View File

@ -8,7 +8,7 @@ from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.kv_cache import KVCacheManager, RPCKVCacheManager
from colossalai.inference.struct import RequestStatus, Sequence
from colossalai.inference.struct import DiffusionSequence, RequestStatus, Sequence
from colossalai.logging import get_dist_logger
logger = get_dist_logger(__name__)
@ -98,7 +98,46 @@ class RunningList:
self._decoding[seq_id] = self._prefill.pop(seq_id)
class RequestHandler:
class NaiveRequestHandler:
def __init__(self) -> None:
self.running_list: List[DiffusionSequence] = []
self.waiting_list: List[str] = []
def _has_waiting(self) -> bool:
return any(lst for lst in self.waiting_list)
def _has_running(self) -> bool:
return any(lst for lst in self.running_list)
def check_unfinished_reqs(self):
return self._has_waiting() or self._has_running()
def add_sequence(self, seq: DiffusionSequence):
"""
Add the request to waiting list.
"""
assert not self._find_sequence(seq.request_id), f"Sequence {seq.request_id} already exists."
self.waiting_list.append(seq)
def _find_sequence(self, request_id: int) -> DiffusionSequence:
"""
Find the request by request_id.
"""
for lst in enumerate(self.waiting_list + self.running_list):
for seq in lst:
if seq.request_id == request_id:
return seq
return None
def schedule(self):
ret = None
if self._has_waiting:
ret = self.waiting_list[0]
self.waiting_list = self.waiting_list[1:]
return ret
class RequestHandler(NaiveRequestHandler):
"""
RequestHandler is the core for handling existing requests and updating current batch.
During generation process, we call schedule function each iteration to update current batch.
@ -176,12 +215,12 @@ class RequestHandler:
generated_token_size=inference_config.generated_token_size,
)
def _has_running(self) -> bool:
return not self.running_bb.is_empty()
def _init_cache(self, model_config):
self.cache_manager = KVCacheManager(self.inference_config, model_config)
def _has_waiting(self) -> bool:
return any(lst for lst in self.waiting_list)
def get_kvcache(self):
return self.cache_manager.get_kv_cache()
@ -318,7 +357,7 @@ class RequestHandler:
if seq.output_token_id[-1] == generation_config.eos_token_id or seq.output_len >= max_new_tokens:
seq.mark_finished()
def check_unfinished_seqs(self) -> bool:
def check_unfinished_reqs(self) -> bool:
return self._has_waiting() or not self.running_list.is_empty()
def total_requests_in_batch_bucket(self) -> int:

View File

@ -34,7 +34,6 @@ def run_server(host, port, event: mp.Event = None):
class RPCInferenceEngine(InferenceEngine):
"""
InferenceEngine which manages the inference process..

View File

@ -42,7 +42,6 @@ logger = get_dist_logger(__name__)
class rpcWorkerService(rpyc.Service):
"""
Execute the computation tasks and manage its own kv cache

View File

@ -279,9 +279,11 @@ class KVCacheManager:
block.add_ref()
self._allocate_on_block(
block,
block.block_size
if context_lengths[i] % block.block_size == 0
else context_lengths[i].item() % block.block_size,
(
block.block_size
if context_lengths[i] % block.block_size == 0
else context_lengths[i].item() % block.block_size
),
)
for block_id in alloc_block_ids:
if block_id in alloc_block_ids[last_block_locs]:

View File

@ -0,0 +1,54 @@
import inspect
import types
import torch
from torch import nn
class DiffusionPipe(nn.Module):
"""
This Class convert a class of `DiffusionPipeline` into `nn.Module` and reserve most of origin attr,function and property.
"""
def __init__(self, source_obj) -> None:
super(DiffusionPipe, self).__init__()
for k, v in source_obj.__dict__.items():
if isinstance(v, nn.Module):
self.add_module(k, v)
else:
setattr(self, k, v)
skip_list = ["_execution_device", "to", "device"] # this
for name, member in inspect.getmembers(source_obj.__class__):
if name in skip_list:
continue
if not name.startswith("__") and not name.endswith("__"):
if isinstance(member, property):
setattr(self.__class__, name, member)
elif inspect.isfunction(member) or inspect.ismethod(member):
bound_method = types.MethodType(member, self)
setattr(self, name, bound_method)
elif not callable(member) and not isinstance(member, property):
setattr(self, name, member)
elif name == "__call__":
bound_method = types.MethodType(member, self)
setattr(self, "_forward", bound_method)
@property
def _execution_device(self):
r"""
Returns the device on which the pipeline's models will be executed. After calling
[`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from
Accelerate's module hooks.
"""
# return self.device
return torch.device("cuda")
@property
def device(self):
next(self.parameters()).device
def forward(self, *args, **kwargs):
return self._forward(*args, **kwargs)

View File

@ -0,0 +1,220 @@
# Code adapted from:
# https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
from typing import Callable, List, Optional, Union
import PIL.Image
import torch
from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha import (
ASPECT_RATIO_256_BIN,
ASPECT_RATIO_512_BIN,
ASPECT_RATIO_1024_BIN,
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
from colossalai.logging import get_dist_logger
from .diffusion import DiffusionPipe
logger = get_dist_logger(__name__)
@torch.no_grad()
def pixart_alpha_forward(
self: DiffusionPipe,
prompt: Union[str, List[str]] = None,
negative_prompt: str = "",
num_inference_steps: int = 20,
timesteps: List[int] = None,
sigmas: List[float] = None,
guidance_scale: float = 4.5,
num_images_per_prompt: Optional[int] = 1,
height: Optional[int] = None,
width: Optional[int] = None,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
callback_steps: int = 1,
clean_caption: bool = True,
use_resolution_binning: bool = True,
max_sequence_length: int = 120,
**kwargs,
) -> PIL.Image:
# 1. Check inputs. Raise error if not correct
height = height or self.transformer.config.sample_size * self.vae_scale_factor
width = width or self.transformer.config.sample_size * self.vae_scale_factor
if use_resolution_binning:
if self.transformer.config.sample_size == 128:
aspect_ratio_bin = ASPECT_RATIO_1024_BIN
elif self.transformer.config.sample_size == 64:
aspect_ratio_bin = ASPECT_RATIO_512_BIN
elif self.transformer.config.sample_size == 32:
aspect_ratio_bin = ASPECT_RATIO_256_BIN
else:
raise ValueError("Invalid sample size")
orig_height, orig_width = height, width
height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
self.check_inputs(
prompt,
height,
width,
negative_prompt,
callback_steps,
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
)
# 2. Default height and width to transformer
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
(
prompt_embeds,
prompt_attention_mask,
negative_prompt_embeds,
negative_prompt_attention_mask,
) = self.encode_prompt(
prompt,
do_classifier_free_guidance,
negative_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
clean_caption=clean_caption,
max_sequence_length=max_sequence_length,
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, sigmas)
# 5. Prepare latents.
latent_channels = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
latent_channels,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 6.1 Prepare micro-conditions.
added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
if self.transformer.config.sample_size == 128:
resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)
if do_classifier_free_guidance:
resolution = torch.cat([resolution, resolution], dim=0)
aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
# 7. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
current_timestep = t
if not torch.is_tensor(current_timestep):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = latent_model_input.device.type == "mps"
if isinstance(current_timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
elif len(current_timestep.shape) == 0:
current_timestep = current_timestep[None].to(latent_model_input.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
current_timestep = current_timestep.expand(latent_model_input.shape[0])
# predict noise model_output
noise_pred = self.transformer(
latent_model_input,
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=prompt_attention_mask,
timestep=current_timestep,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# learned sigma
if self.transformer.config.out_channels // 2 == latent_channels:
noise_pred = noise_pred.chunk(2, dim=1)[0]
else:
noise_pred = noise_pred
# compute previous image: x_t -> x_t-1
if num_inference_steps == 1:
# For DMD one step sampling: https://arxiv.org/abs/2311.18828
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).pred_original_sample
else:
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
output_type = "pil" # TODO(@lry89757) temporarily image, please support more return output
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
if use_resolution_binning:
image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
else:
image = latents
if not output_type == "latent":
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
# self.maybe_free_model_hooks()
return image

View File

@ -0,0 +1,178 @@
# This code is adapted from huggingface diffusers: https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
from .diffusion import DiffusionPipe
# TODO(@lry89757) temporarily image, please support more return output
@torch.no_grad()
def sd3_forward(
self: DiffusionPipe,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
prompt_3: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 28,
timesteps: List[int] = None,
guidance_scale: float = 7.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
negative_prompt_3: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
):
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
prompt_2,
prompt_3,
height,
width,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
negative_prompt_3=negative_prompt_3,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.encode_prompt(
prompt=prompt,
prompt_2=prompt_2,
prompt_3=prompt_3,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
negative_prompt_3=negative_prompt_3,
do_classifier_free_guidance=self.do_classifier_free_guidance,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
device=device,
clip_skip=self.clip_skip,
num_images_per_prompt=num_images_per_prompt,
)
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
pooled_projections=pooled_prompt_embeds,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
negative_pooled_prompt_embeds = callback_outputs.pop(
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if output_type == "latent":
image = latents
else:
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
return image

View File

@ -1,16 +1,22 @@
from .glide_llama import GlideLlamaModelPolicy
from .nopadding_baichuan import NoPaddingBaichuanModelInferPolicy
from .nopadding_llama import NoPaddingLlamaModelInferPolicy
from .pixart_alpha import PixArtAlphaInferPolicy
from .stablediffusion3 import StableDiffusion3InferPolicy
model_policy_map = {
"nopadding_llama": NoPaddingLlamaModelInferPolicy,
"nopadding_baichuan": NoPaddingBaichuanModelInferPolicy,
"glide_llama": GlideLlamaModelPolicy,
"StableDiffusion3Pipeline": StableDiffusion3InferPolicy,
"PixArtAlphaPipeline": PixArtAlphaInferPolicy,
}
__all__ = [
"NoPaddingLlamaModelInferPolicy",
"NoPaddingBaichuanModelInferPolicy",
"GlideLlamaModelPolicy",
"StableDiffusion3InferPolicy",
"PixArtAlphaInferPolicy",
"model_polic_map",
]

View File

@ -0,0 +1,34 @@
from torch import nn
from colossalai.inference.config import RPC_PARAM
from colossalai.inference.modeling.models.diffusion import DiffusionPipe
from colossalai.inference.modeling.models.pixart_alpha import pixart_alpha_forward
from colossalai.shardformer.policies.base_policy import Policy
class PixArtAlphaInferPolicy(Policy, RPC_PARAM):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = {}
self.append_or_create_method_replacement(
description={"forward": pixart_alpha_forward}, policy=policy, target_key=DiffusionPipe
)
return policy
def preprocess(self) -> nn.Module:
return self.model
def postprocess(self):
return self.model
def config_sanity_check(self):
pass
def to_rpc_param(self) -> str:
return __class__.__name__
@staticmethod
def from_rpc_param() -> "PixArtAlphaInferPolicy":
return PixArtAlphaInferPolicy()

View File

@ -0,0 +1,34 @@
from torch import nn
from colossalai.inference.config import RPC_PARAM
from colossalai.inference.modeling.models.diffusion import DiffusionPipe
from colossalai.inference.modeling.models.stablediffusion3 import sd3_forward
from colossalai.shardformer.policies.base_policy import Policy
class StableDiffusion3InferPolicy(Policy, RPC_PARAM):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = {}
self.append_or_create_method_replacement(
description={"forward": sd3_forward}, policy=policy, target_key=DiffusionPipe
)
return policy
def preprocess(self) -> nn.Module:
return self.model
def postprocess(self):
return self.model
def config_sanity_check(self):
pass
def to_rpc_param(self) -> str:
return __class__.__name__
@staticmethod
def from_rpc_param() -> "StableDiffusion3InferPolicy":
return StableDiffusion3InferPolicy()

View File

@ -2,6 +2,7 @@ import enum
from dataclasses import dataclass
from typing import Any, List
from colossalai.inference.config import DiffusionGenerationConfig
from colossalai.logging import get_dist_logger
logger = get_dist_logger(__name__)
@ -46,6 +47,17 @@ class RequestStatus(enum.Enum):
return status == RequestStatus.WAITING
@dataclass
class DiffusionSequence:
"""
parameters for diffusion
"""
request_id: int
prompt: str
generation_config: DiffusionGenerationConfig
@dataclass
class Sequence:
"""Store information of input sequence.

View File

@ -1,13 +1,16 @@
"""
Utils for model inference
"""
import math
import os
import re
from enum import Enum
from pathlib import Path
from typing import Optional, Tuple
from typing import Optional, Tuple, Union
import torch
from diffusers import DiffusionPipeline
from torch import nn
from colossalai.logging import get_dist_logger
@ -158,3 +161,36 @@ def can_use_flash_attn2(dtype: torch.dtype) -> bool:
except ImportError:
logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
return False
class ModelType(Enum):
DIFFUSION_MODEL = "Diffusion Model"
LLM = "Large Language Model (LLM)"
UNKNOWN = "Unknown Model Type"
def get_model_type(model_or_path: Union[nn.Module, str, DiffusionPipeline]):
if isinstance(model_or_path, DiffusionPipeline):
return ModelType.DIFFUSION_MODEL
elif isinstance(model_or_path, nn.Module):
return ModelType.LLM
elif isinstance(model_or_path, str):
try:
from transformers import AutoConfig
hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
return ModelType.LLM
except:
"""
model type is not `ModelType.LLM`
"""
try:
DiffusionPipeline.load_config(model_or_path)
return ModelType.DIFFUSION_MODEL
except:
"""
model type is not `ModelType.DIFFUSION_MODEL`
"""
else:
return ModelType.UNKNOWN

View File

@ -3,6 +3,12 @@
import os
# set CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that when overlapping communication and computation,
# the order of of kernel launches on GPUs are the same as on the CPU so that comm is launched first.
# see https://github.com/NVIDIA/Megatron-LM/issues/533
# https://forums.developer.nvidia.com/t/how-many-streams-maximum-number-of-streams/6571/16
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
import torch.distributed as dist
from colossalai.accelerator import get_accelerator

View File

@ -138,9 +138,7 @@ class Initializer_2D(ProcessGroupInitializer):
self.num_group = self.world_size // self.tensor_parallel_size
self.summa_dim = int(math.sqrt(self.tensor_parallel_size))
assert (
self.tensor_parallel_size == self.summa_dim**2
), "2D summa dim should equal to tensor parallel size ^ 0.5"
assert self.tensor_parallel_size == self.summa_dim**2, "2D summa dim should equal to tensor parallel size ^ 0.5"
_check_summa_env_var(self.summa_dim)
self.col_initializer = Initializer_2D_Col(self.num_group, self.summa_dim, *args, **kwargs)

View File

@ -54,7 +54,6 @@ class RequestTracker:
class Async_Engine:
"""
Use an engine to launch RAY Driver --> RAY Worker --> Async_Manager
Background loop: inference reqs in waiting list (Listen)

View File

@ -118,16 +118,16 @@ class Batch:
class BatchTokenIdOut:
def __init__(self):
self.reqs_infs: List[
Tuple[str, int, Dict, bool, bool]
] = [] # [req_id, new_token_id, gen_metadata, finished_state, abort_state]
self.reqs_infs: List[Tuple[str, int, Dict, bool, bool]] = (
[]
) # [req_id, new_token_id, gen_metadata, finished_state, abort_state]
class BatchStrOut:
def __init__(self):
self.reqs_infs: List[
Tuple[str, str, Dict, bool, bool]
] = [] # [req_id, token_str, gen_metadata, finished_state, abort_state]
self.reqs_infs: List[Tuple[str, str, Dict, bool, bool]] = (
[]
) # [req_id, token_str, gen_metadata, finished_state, abort_state]
class AbortReq:

View File

@ -1,6 +1,7 @@
"""
Utils for model inference
"""
import os
import torch

View File

@ -14,6 +14,7 @@ class BatchInferState:
Information to be passed and used for a batch of inputs during
a single model forward
"""
batch_size: int
max_len_in_batch: int

View File

@ -4,6 +4,7 @@ of the ModelTC/lightllm GitHub repository
https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py
we slightly changed it to make it suitable for our colossal-ai shardformer TP-engine design.
"""
import torch
from transformers.utils import logging

View File

@ -1,6 +1,7 @@
"""
Utils for model inference
"""
import os
import torch

View File

@ -81,7 +81,6 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# all-reduce scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None

View File

@ -1,20 +1,5 @@
from .checkpoint import MoECheckpointIO
from .experts import MLPExperts
from .layers import SparseMLP, apply_load_balance
from .manager import MOE_MANAGER
from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter
from .utils import NormalNoiseGenerator, UniformNoiseGenerator
__all__ = [
"MLPExperts",
"MoeRouter",
"Top1Router",
"Top2Router",
"TopKRouter",
"NormalNoiseGenerator",
"UniformNoiseGenerator",
"SparseMLP",
"MoECheckpointIO",
"MOE_MANAGER",
"apply_load_balance",
]

View File

@ -1,792 +0,0 @@
import copy
import logging
import os
from pathlib import Path
from shutil import rmtree
from typing import Dict, Iterator, Optional, OrderedDict, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from colossalai.checkpoint_io import CheckpointIndexFile, HybridParallelCheckpointIO
from colossalai.checkpoint_io.utils import (
StateDictSharder,
gather_distributed_param,
get_model_base_filenames,
get_optimizer_base_filenames,
is_safetensors_available,
load_shard_state_dict,
load_state_dict,
load_state_dict_into_model,
load_states_into_optimizer,
save_config_file,
save_param_groups,
save_state_dict,
save_state_dict_shards,
sharded_optimizer_loading_epilogue,
)
from colossalai.interface import OptimizerWrapper
from colossalai.moe.manager import MOE_MANAGER
from colossalai.tensor.moe_tensor.api import (
get_dp_group,
get_dp_rank,
get_dp_size,
get_ep_group,
get_ep_rank,
get_ep_size,
is_moe_tensor,
)
class MoECheckpointIO(HybridParallelCheckpointIO):
def __init__(
self,
dp_group: ProcessGroup,
pp_group: ProcessGroup,
tp_group: ProcessGroup,
zero_stage: int,
) -> None:
assert zero_stage in [
0,
1,
2,
], f"zero_stage should be 0 or 1 or 2, got {zero_stage}"
super().__init__(dp_group, pp_group, tp_group, zero_stage)
self.parallel = MOE_MANAGER.parallel
def pre_load_model(self, model: nn.Module, state_dict: dict) -> dict:
"""
Preprocess state_dict before loading and slice the state_dict of MOE tensors.
"""
for name, param in state_dict.items():
if ".experts." in name:
if name in dict(model.named_parameters()):
model_param = dict(model.named_parameters())[name]
if is_moe_tensor(model_param):
ep_rank = get_ep_rank(model_param)
ep_size = get_ep_size(model_param)
expert_num = param.shape[0] // ep_size
assert param.shape[0] % ep_size == 0
param = param[ep_rank * expert_num : (ep_rank + 1) * expert_num]
state_dict[name] = param
dist.barrier()
return state_dict
def _model_sharder(
self,
state_dict: nn.Module,
prefix: str = "",
keep_vars: bool = False,
size_per_shard: int = 1024,
) -> Iterator[Tuple[OrderedDict, int]]:
# An internel method that breaks state_dict of model into shards within limited size.
state_dict_sharder = StateDictSharder(size_per_shard)
for name, param in state_dict.items():
if param is None:
continue
# Gather tensor pieces when using tensor parallel.
param_ = gather_distributed_param(param, keep_vars=False)
block, block_size = state_dict_sharder.append_param(prefix + name, param_)
if block is not None:
yield block, block_size
# Return the last block in sharder.
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool) -> None:
state_dict = torch.load(checkpoint)
state_dict = self.pre_load_model(model, state_dict)
model.load_state_dict(state_dict, strict=strict if self.pp_size == 1 else False)
def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
"""
Load sharded model with the given path to index file of checkpoint folder.
Args:
model (nn.Module): The model to be loaded.
checkpoint_index_file (str): Path to the index file of checkpointing folder.
strict (bool, optional): For name matching during loading state_dict. Defaults to False.
This argument should be manually set to False since params on same device might be stored in different files.
"""
# Check whether the checkpoint uses safetensors.
use_safetensors = False
if "safetensors" in checkpoint_index_file.name:
use_safetensors = True
if use_safetensors and not is_safetensors_available():
raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
ckpt_root_path = ckpt_index_file.root_path
weight_map = ckpt_index_file.weight_map
strict = False
# Load params & buffers to model.
# Keep a record of loaded files so that file will not be repeatedly loaded.
loaded_file = set()
def _load(name: str):
if name not in weight_map:
raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!")
filename = weight_map[name]
# If this param/buffer has been loaded before, directly return.
if filename in loaded_file:
return
file_path = os.path.join(ckpt_root_path, filename)
state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
state_dict = self.pre_load_model(model, state_dict)
missing_keys = []
load_state_dict_into_model(
model,
state_dict,
missing_keys=missing_keys,
strict=strict,
load_sub_module=True,
)
loaded_file.add(filename)
# Load parameters.
for name, _ in model.named_parameters():
_load(name)
if self.verbose:
logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
def pre_save_model(self, model: nn.Module) -> dict:
state_dict = model.state_dict()
for name, param in model.named_parameters():
if ".experts." in name and is_moe_tensor(param):
ep_group = get_ep_group(param)
ep_rank = get_ep_rank(param)
ep_size = get_ep_size(param)
dp_rank = get_dp_rank(param)
if dp_rank == 0:
param = param.data.cuda()
all_param = [torch.zeros_like(param) for _ in range(ep_size)]
# gather param from every ep rank
dist.all_gather(all_param, param, group=ep_group)
if ep_rank == 0:
all_param = torch.cat(all_param, dim=0)
state_dict[name] = all_param.cpu()
if self.pp_size > 1:
if self.dp_rank == 0:
out = [None for _ in range(self.pp_size)]
dist.all_gather_object(out, state_dict, group=self.pp_group)
if self.pp_rank == 0:
new_state_dict = {}
for o in out:
new_state_dict.update(o)
state_dict = new_state_dict
dist.barrier()
return state_dict
def save_unsharded_model(
self,
model: nn.Module,
checkpoint: str,
gather_dtensor: bool,
use_safetensors: bool,
):
state_dict = self.pre_save_model(model)
if dist.get_rank() == 0:
torch.save(state_dict, checkpoint)
dist.barrier()
def save_sharded_model(
self,
model: nn.Module,
checkpoint: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_safetensors: bool = False,
) -> None:
"""
Save sharded model checkpoint under the given checkpointing path.
The following files will be created under the path:
- An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names.
- Multiple files that store state tensors of models.
The filenames are in the form of "pytorch_model.<prefix>-000XX.bin"
Args:
model (nn.Module): Model on local device to be saved.
checkpoint (str): Checkpointing path which should be a directory path.
gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True.
prefix (str, optional): Perfix of file to save. Defaults to None.
size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
"""
torch.cuda.empty_cache()
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
Path(checkpoint).mkdir(parents=True, exist_ok=True)
# Then collect the sharded parameters & buffers along tp_group.
# Only devices with tp_rank == 0 are responsible for model saving.
state_dict = self.pre_save_model(model)
if dist.get_rank() == 0:
state_dict_shard = self._model_sharder(state_dict, size_per_shard=size_per_shard)
# Devices along the same dp_group share the same copies of model.
# So only let the device with dp_rank == 0 save the model.
if self.dp_rank != 0:
return
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
index_file = CheckpointIndexFile(checkpoint)
control_saving = self.tp_rank == 0
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
)
if control_saving:
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint)
if self.verbose:
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
dist.barrier()
torch.cuda.empty_cache()
# ========================================================
# Abstract methods for optimizer loading/saving implementation
# ========================================================
def pre_load_optim(
self,
state: OrderedDict,
working_param,
current_shape: torch.Size,
original_shape: torch.Size,
device: torch.device,
inplace: bool,
) -> OrderedDict:
"""
With complete optimizer states of a specific parameter loaded from checkpoint,
slice out the sharded optimizer states kept by current device.
Args:
state (OrderedDict): Complete optimizer states of a given parameter, loaded from checkpoint.
current_shape (torch.Size): The size of parameter after sharding.
original_shape (torch.Size): The size of parameter before sharding.
device (torch.device): The destination device of loaded optimizer states.
inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
Returns:
OrderedDict: The sharded optimizer state of the given parameter.
"""
state_ = state if inplace else copy.deepcopy(state)
is_moe_tensor_flag = is_moe_tensor(working_param)
if is_moe_tensor_flag:
ep_rank = get_ep_rank(working_param)
ep_size = get_ep_size(working_param)
for k, v in state_.items():
if isinstance(v, torch.Tensor) and k != "step":
if is_moe_tensor_flag:
with torch.no_grad():
expert_num = v.shape[0] // ep_size
assert v.shape[0] % ep_size == 0
v = v[ep_rank * expert_num : (ep_rank + 1) * expert_num]
else:
# Shard state along data parallel group when using Zero.
padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size
with torch.no_grad():
v = v.flatten()
if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size])
slice_size = v.numel() // self.dp_size
v = v.split(slice_size, dim=0)[self.dp_rank]
state_[k] = v.detach().clone().to(device)
return state_
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
"""
Load sharded optimizer with the given path to index file of checkpoint folder.
Args:
optimizer (OptimizerWrapper): The optimizer to be loaded.
checkpoint_index_file (str): Path to the index file of checkpointing folder.
prefix (str): Not used.
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
def _get_param_id_from_optimizer_param(
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None, optimizer=None
):
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
working_param = optimizer.moe_master_to_working_map[id(param)]
else:
working_param = param
return optimizer.param_info["param2id"][id(working_param)]
# id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects.
# When Zero is used, the mapped parameter objects should be fp32 master parameters.
# IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info.
id_map = {}
master_to_working_map = optimizer.get_master_to_working_map()
for pg in optimizer.optim.param_groups:
for param in pg["params"]:
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer)
id_map[param_id] = param
# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
ckpt_root_path = ckpt_index_file.root_path
weight_map = ckpt_index_file.weight_map
weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int
# Load param_groups
param_group_path = ckpt_index_file.get_param_group_filename()
if param_group_path is None:
raise RuntimeError(
f"Invalid index file path {checkpoint_index_file} for an optimizer. \
Lacking param group file under current directory."
)
saved_groups = torch.load(param_group_path)
updated_groups = []
for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
# obtain updated param group
new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = old_pg["params"] # The parameters in the same group shouldn't change.
updated_groups.append(new_pg)
# ep param group
if len(optimizer.optim.param_groups) > len(saved_groups):
new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = optimizer.optim.param_groups[-1]["params"]
updated_groups.append(new_pg)
optimizer.optim.__dict__.update({"param_groups": updated_groups})
# Load saved states to optimizer.
# Keep a record of loaded files so that file will not be repeatedly loaded.
loaded_file = set()
for pg in optimizer.optim.param_groups:
for param in pg["params"]:
if param is None:
continue
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer)
if param_id not in weight_map:
continue
filename = weight_map[param_id]
# If this param's states has been loaded before, directly return.
if filename in loaded_file:
continue
file_path = os.path.join(ckpt_root_path, filename)
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
# Then shard the loaded optimizer states if using tp/zero.
for pid, state in list(state_dict.items()):
if pid in id_map:
param = id_map[pid]
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
elif (
hasattr(optimizer, "moe_master_to_working_map")
and id(param) in optimizer.moe_master_to_working_map
):
working_param = optimizer.moe_master_to_working_map[id(param)]
else:
working_param = param
original_shape = optimizer.param_info["param2shape"][id(working_param)]
sharded_state = self.pre_load_optim(
state,
working_param,
current_shape=working_param.shape,
original_shape=original_shape,
device="cpu",
inplace=True,
)
state_dict[pid] = sharded_state
load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
loaded_file.add(filename)
sharded_optimizer_loading_epilogue(optimizer.optim)
if self.verbose and self.coordinator.is_master():
logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
dist.barrier()
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
"""
Load optimizer from a file with given path.
Args:
optimizer (OptimizerWrapper): The optimizer to be loaded.
checkpoint_index_file (str): Path to the checkpoint file.
"""
def _get_param_id_from_optimizer_param(
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
):
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
else:
working_param = param
if id(working_param) in optimizer.param_info["param2id"]:
return optimizer.param_info["param2id"][id(working_param)]
else:
None
if self.coordinator.is_master():
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
# Complete optimizer state_dict loaded from checkpoint, need to be processed later.
state_dict = load_state_dict(checkpoint)
# Load param_groups.
updated_groups = []
saved_groups = state_dict["param_groups"]
for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = old_pg["params"] # Only keep the parameters kept by current pipeline stage.
updated_groups.append(new_pg)
# ep extra group
if MOE_MANAGER.parallel == "EP":
new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = optimizer.optim.param_groups[-1][
"params"
] # Only keep the parameters kept by current pipeline stage.
for param in new_pg["params"]:
param.data = param.data.to(torch.float32)
updated_groups.append(new_pg)
optimizer.optim.__dict__.update({"param_groups": updated_groups})
# Load saved states to optimizer. First discard those states not belonging to current pipeline stage.
master_to_working_map = optimizer.get_master_to_working_map()
id_map = {}
for pg in optimizer.optim.param_groups:
for param in pg["params"]:
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
if param_id is not None:
id_map[param_id] = param
load_states_into_optimizer(optimizer.optim, state_dict["state"], id_map, strict=True)
# Then shard the loaded optimizer states if using tp/zero.
for param, state in optimizer.optim.state.items():
if param is None:
continue
device = param.device
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
else:
working_param = param
original_shape = optimizer.param_info["param2shape"][id(working_param)]
sharded_state = self.pre_load_optim(
state,
param,
current_shape=working_param.shape,
original_shape=original_shape,
device=device,
inplace=True,
)
optimizer.optim.state[param] = sharded_state
sharded_optimizer_loading_epilogue(optimizer.optim)
dist.barrier()
def pre_save_optim(
self,
state: OrderedDict,
param: torch.Tensor,
inplace: bool,
device: torch.device = torch.device("cpu"),
) -> OrderedDict:
"""
With given parameter and its optimizer states, gather the complete optimizer state for saving.
Args:
state (OrderedDict): Optimizer states of given parameter, might be distributed among tp/dp group if using TP/Zero.
param (torch.Tensor): The given parameter. It should be working_param when using Zero.
original_shape (torch.Size): The size of parameter before sharding.
dp_group (ProcessGroup): The process group of data parallel.
tp_group (ProcessGroup): The process group of tensor parallel.
use_zero (bool): Whether Zero is used.
inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
device (torch.device): The destination device of loaded optimizer states. Defaults to torch.device('cpu').
Returns:
OrderedDict: The complete optimizer state of given parameter.
"""
if is_moe_tensor(param):
moe_dp_group = get_dp_group(param)
moe_dp_size = get_dp_size(param)
moe_ep_group = get_ep_group(param)
moe_ep_size = get_ep_size(param)
state_ = state if inplace else copy.deepcopy(state)
for k, v in state_.items():
if isinstance(v, torch.Tensor) and k != "step":
# moe param
if is_moe_tensor(param):
# dp gather
v = v.cuda()
gather_tensor = [torch.zeros_like(v) for _ in range(moe_dp_size)]
dist.all_gather(gather_tensor, v, group=moe_dp_group)
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
# ep gather
gather_tensor = [torch.zeros_like(v) for _ in range(moe_ep_size)]
dist.all_gather(gather_tensor, v, group=moe_ep_group)
v = torch.cat(gather_tensor, dim=0)
else:
# global dp
v = v.cuda()
gather_tensor = [torch.zeros_like(v) for _ in range(dist.get_world_size(self.dp_group))]
dist.all_gather(gather_tensor, v, group=self.dp_group)
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
state_[k] = v.detach().clone().to(device)
return state_
def _optimizer_sharder(
self,
optimizer: OptimizerWrapper,
size_per_shard: int = 1024,
):
# An internel method that breaks state_dict of optimizer into shards within limited size.
state_dict_sharder = StateDictSharder(size_per_shard)
param_info = optimizer.param_info
master_to_working_map = optimizer.get_master_to_working_map()
for param, state in optimizer.optim.state.items():
if param is None:
continue
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
working_param = optimizer.moe_master_to_working_map[id(param)]
else:
working_param = param
param_id = param_info["param2id"][id(working_param)]
state_ = self.pre_save_optim(
state,
working_param,
inplace=False,
device=torch.device("cuda"),
)
block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
if block is not None:
yield block, block_size
# Return the last block in sharder.
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
def save_sharded_optimizer(
self,
optimizer: OptimizerWrapper,
checkpoint: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
):
"""
Save sharded optimizer checkpoint under the given checkpointing path.
The following files will be created under the path:
- An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
- A group file (pytorch_optim_group.bin) recording information of param_groups
- Multiple files that store state tensors of optimizers.
If pipeline parallelism is used, the filenames are in the form of "pytorch_optim.<prefix>-stage-000XX-shard-000XX.bin".
If pipeline parallelism is not used, "pytorch_optim.<prefix>-000XX.bin"
Args:
optimizer (OptimizerWrapper): Optimizer to save sharded state_dict
checkpoint (str): Path to save optimizer state_dict
gather_dtensor (bool): Whether to gather_dtensor, not used
prefix (str): Perfix of file to save
size_per_shard (int): Max file size of each file shard that store state tensors
"""
torch.cuda.empty_cache()
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
Path(checkpoint).mkdir(parents=True, exist_ok=True)
# Devices along the same dp_group share the same copies of states when zero is not used.
# In this case only let the device with dp_rank == 0 save the model.
if not self.use_zero and self.dp_rank != 0:
return
# Then collect the sharded states along dp_group(if using zero)/tp_group.
# Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
state_dict_shard = self._optimizer_sharder(
optimizer,
size_per_shard=size_per_shard,
)
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
index_file = CheckpointIndexFile(checkpoint)
control_saving = self.dp_rank == 0 and self.tp_rank == 0
if self.pp_size == 1:
# When pipeline is not used, save the optimizer shards as in general checkpointIO
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
)
if control_saving:
# Store param groups.
index_file.append_meta_data("param_groups", param_group_file)
group_file_path = os.path.join(checkpoint, param_group_file)
save_param_groups(optimizer.param_info, group_file_path)
# Store index file.
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
if self.verbose and self.coordinator.is_master():
logging.info(
f"The optimizer is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
else:
# When pipeline is used, each stage produces its own shard files and index files.
# Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
# After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
final_index_file_path = copy.deepcopy(save_index_file)
tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
# Manage filenames of sharded weights and index file for each pipeline stage.
states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file)
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
use_pp_format=True,
)
if control_saving:
assert (
self.dp_rank == 0 and self.tp_rank == 0
), "The saving process should have both dp_rank and tp_rank as 0."
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
else:
return
dist.barrier(self.pp_group)
# The global master rank integrates the index files and clean the folder.
if self.pp_rank == 0:
final_index_file = CheckpointIndexFile(checkpoint)
final_index_file.append_meta_data("total_size", 0)
for filename in os.listdir(tmp_index_file_folder):
stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename))
final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"]
for param_id, state_filename in stage_index_file.weight_map.items():
final_index_file.append_weight_map(param_id, state_filename)
# Store param groups.
final_index_file.append_meta_data("param_groups", param_group_file)
group_file_path = os.path.join(checkpoint, param_group_file)
save_param_groups(optimizer.param_info, group_file_path)
final_index_file.write_index_file(final_index_file_path)
rmtree(tmp_index_file_folder)
if self.verbose and self.coordinator.is_master():
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {final_index_file_path}."
)
torch.cuda.empty_cache()
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
"""
Save optimizer state dict to a file with given path.
Args:
optimizer (OptimizerWrapper): Optimizer to save sharded state_dict.
checkpoint (str): Path to save optimizer state_dict.
gather_dtensor (bool): Whether to gather_dtensor, not used.
"""
if self.coordinator.is_master():
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
# optimizer states of parameters kept by local device('s pipeline stage)
local_states = dict()
for param, state in optimizer.optim.state.items():
if param is None:
continue
# working param is needed for obtaining correct param_id
master_to_working_map = optimizer.get_master_to_working_map()
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
else:
working_param = param
# gather complete state from tp shards & dp shards
param_id = optimizer.param_info["param2id"][id(working_param)]
local_states[param_id] = self.pre_save_optim(
state,
working_param,
inplace=False,
device=torch.device("cuda"),
)
if self.pp_size == 1:
# When pipeline is not used, let master rank directly save the collected state_dict.
state_dict = {"param_groups": optimizer.optim.param_groups, "state": local_states}
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors=False)
else:
# When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
states_list = [None for _ in range(self.pp_size)]
dist.barrier(self.pp_group)
dist.all_gather_object(states_list, local_states, self.pp_group)
# Only the master rank do the saving.
if self.coordinator.is_master():
state_dict = {"param_groups": optimizer.optim.param_groups, "state": dict()}
for _states in states_list:
state_dict["state"].update(_states)
save_state_dict(state_dict, checkpoint, use_safetensors=False)
dist.barrier()

View File

@ -7,8 +7,8 @@ from torch import Tensor, nn
from torch.distributed import ProcessGroup
from colossalai.cluster import ProcessGroupMesh
from colossalai.moe.experts import MLPExperts
from colossalai.moe.manager import MOE_MANAGER
from colossalai.shardformer.layer.moe import MLPExperts
from colossalai.zero.low_level import LowLevelZeroOptimizer
@ -292,7 +292,7 @@ class LoadBalancer:
exp_avg_ptr = optim.optim.state[working_weight_ptr]["exp_avg"]
exp_avg_sq_ptr = optim.optim.state[working_weight_ptr]["exp_avg_sq"]
else:
master_weight_ptr = optim._param_store.working_to_master_param[id(weight)]
master_weight_ptr = optim.working_to_master_param[id(weight)]
working_weight_ptr = weight
exp_avg_ptr = optim.optim.state[master_weight_ptr]["exp_avg"]
exp_avg_sq_ptr = optim.optim.state[master_weight_ptr]["exp_avg_sq"]
@ -344,7 +344,7 @@ class LoadBalancer:
# gate optim should be obtained first
gate_shape = self.gate.shape
# get master weight and optim
master_gate_weight = optim._param_store.working_to_master_param[id(self.gate)]
master_gate_weight = optim.working_to_master_param[id(self.gate)]
gate_exp_avg = optim.optim.state[master_gate_weight]["exp_avg"]
gate_exp_avg_sq = optim.optim.state[master_gate_weight]["exp_avg_sq"]
# gather

View File

@ -1,78 +0,0 @@
import torch.nn as nn
from torch.nn.modules.loss import _Loss
from colossalai.moe.manager import MOE_MANAGER
class MoeCrossEntropyLoss(_Loss):
r"""torch.nn.CrossEntropyLoss added with auxiliary loss.
Args:
input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
aux_weight (float, optional): Weight of auxiliary loss in total loss.Defaults 0.01.
The ``args`` and ``kwargs`` should include parameters below:
::
weight (Tensor, optional)
size_average (bool, optional)
ignore_index (int, optional)
reduce (bool, optional)
reduction (str, optional)
label_smoothing (float, optional)
More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in
`Cross_entropy <https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html#torch.nn.functional.cross_entropy>`_.
"""
def __init__(self, aux_weight: float = 0.01, *args, **kwargs):
super().__init__()
self.loss = nn.CrossEntropyLoss(*args, **kwargs)
self.aux_weight = aux_weight
def forward(self, *args):
"""
The ``args`` should at least include parameters below:
::
input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in
`Cross_entropy <https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html#torch.nn.functional.cross_entropy>`_.
"""
main_loss = self.loss(*args)
aux_loss = MOE_MANAGER.get_loss()
return main_loss + self.aux_weight * aux_loss
class MoeLoss(_Loss):
"""A wrapper class for any loss module to add with auxiliary loss.
Args:
aux_weight (float): Weight of auxiliary loss in total loss.
loss_fn (``Callable``): Loss function.
args (list): Args in loss function.
kwargs (dict): Kwargs in loss function
"""
def __init__(self, aux_weight: float, loss_fn, *args, **kwargs):
super().__init__()
self.loss_fn = loss_fn(*args, **kwargs)
self.aux_weight = aux_weight
def forward(self, *args, **kwargs):
"""
The ``args`` and ``kwargs`` should at least include parameters below:
::
input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
Note:
The ``args`` and ``kwargs`` may include different parameters varying with different loss function.
"""
main_loss = self.loss_fn(*args, **kwargs)
aux_loss = MOE_MANAGER.get_loss()
return main_loss + self.aux_weight * aux_loss

View File

@ -1,466 +0,0 @@
import math
from abc import ABC
from typing import Callable, Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed import ProcessGroup
from colossalai.accelerator import get_accelerator
from colossalai.moe._operation import moe_cumsum
from colossalai.moe.manager import MOE_MANAGER
class MoeRouter(nn.Module, ABC):
"""Base class for all MoE routers.
Args:
k_value (int): The value of top_k.
capacity_factor_train (float): Capacity factor in routing of training.
capacity_factor_eval (float): Capacity factor in routing of evaluation.
min_capacity (int): The minimum number of the capacity of each expert.
noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
drop_tks (bool, optional): Whether drops tokens in evaluation
"""
def __init__(
self,
k_value: int,
capacity_factor_train: float,
capacity_factor_eval: float,
min_capacity: int,
noisy_func: Optional[Callable] = None,
drop_tks: bool = True,
use_kernel: bool = False,
):
super().__init__()
self.k_value = k_value
self.capacity_factor_train = capacity_factor_train
self.capacity_factor_eval = capacity_factor_eval
self.min_capacity = min_capacity
self.noisy_func = noisy_func
self.drop_tks = drop_tks
self._aux_loss = None
self._z_loss = None
self.use_kernel = use_kernel
def get_capacity(self, num_tokens, num_experts, ep_group=None):
if ep_group is not None:
num_tokens_tensor = torch.tensor(num_tokens, device=get_accelerator().get_current_device())
dist.all_reduce(num_tokens_tensor, group=ep_group)
num_tokens = num_tokens_tensor.item() // dist.get_world_size(ep_group)
capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
capacity = math.floor(self.k_value * capacity_factor * num_tokens / num_experts)
capacity += capacity % 2
capacity = max(capacity, self.min_capacity)
assert capacity > 0
return int(capacity)
def set_aux_loss(self, router_probs: torch.Tensor, expert_indices: torch.Tensor, num_experts: int) -> None:
"""Computes auxiliary load balancing loss as in Switch Transformer.
See Switch Transformer (https://arxiv.org/abs/2101.03961). This function
implements the loss function presented in equations (4) - (6). It aims to
penalize those cases where the routing between experts is unbalanced.
Args:
router_probs: Probability assigned to each expert per token. Shape:
<float32>[num_groups, tokens_per_group, num_experts].
expert_indices: <int>[num_groups, tokens_per_group, num_selected_experts]
indices identifying the top num_selected_experts for a given token.
"""
assert self._aux_loss is None
if router_probs.dim() == expert_indices.dim() == 2:
router_probs = router_probs.unsqueeze(0)
expert_indices = expert_indices.unsqueeze(0)
assert (
router_probs.dim() == expert_indices.dim() == 3
), "router_probs must be 3D tensor and expert_indices must be 4D tensor"
# Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts].
expert_mask = F.one_hot(expert_indices, num_experts)
# For a given token, determine if it was routed to a given expert.
# Shape: [num_groups, tokens_per_group, num_experts]
expert_mask = expert_mask.max(dim=-2)[0]
tokens_per_group_and_expert = torch.mean(expert_mask.float(), dim=-2)
router_prob_per_group_and_expert = torch.mean(router_probs.float(), dim=-2)
aux_loss = num_experts**2 * torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert)
self._aux_loss = aux_loss
def set_z_loss(self, router_logits: torch.Tensor):
"""Compute router z-loss.
The router z-loss was introduced in Designing Effective Sparse Expert Models
(https://arxiv.org/abs/2202.08906). It encourages router logits to remain
small in an effort to improve stability.
Args:
router_logits: <float>[num_groups, tokens_per_group, num_experts] router logits.
"""
assert self._z_loss is None
if router_logits.dim() == 2:
router_logits = router_logits.unsqueeze(0)
assert router_logits.dim() == 3, "router_logits must be 3D tensor"
num_groups, tokens_per_group, _ = router_logits.shape
log_z = torch.logsumexp(router_logits, dim=-1)
z_loss = torch.sum(log_z**2, dtype=torch.float32) / (num_groups * tokens_per_group)
self._z_loss = z_loss
def pop_router_loss(self) -> torch.Tensor:
assert self._aux_loss is not None
MOE_MANAGER.add_loss(self._aux_loss, self._z_loss)
self._aux_loss = None
self._z_loss = None
class Top1Router(MoeRouter):
"""Top1 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity)
and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More detailed
function can be found in the paper about Switch Transformer of Google.
Args:
capacity_factor_train (float, optional): Capacity factor in routing of training.
capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
min_capacity (int, optional): The minimum number of the capacity of each expert.
select_policy (str, optional): The policy about tokens selection.
noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
drop_tks (bool, optional): Whether drops tokens in evaluation
"""
def __init__(
self,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
select_policy: str = "first",
noisy_func: Optional[Callable] = None,
drop_tks: bool = True,
):
super().__init__(
k_value=1,
capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval,
min_capacity=min_capacity,
noisy_func=noisy_func,
drop_tks=drop_tks,
)
self.select_policy = select_policy
assert select_policy in {"first", "random"}
if select_policy == "random":
self.uniform = torch.distributions.uniform.Uniform(
low=torch.tensor(0.0, device=get_accelerator().get_current_device()),
high=torch.tensor(1.0, device=get_accelerator().get_current_device()),
).rsample
def forward(
self,
inputs: torch.Tensor,
use_kernel: bool = False,
ep_group: Optional[ProcessGroup] = None,
use_loss: bool = False,
use_norm: bool = False,
) -> Tuple:
"""
Args:
inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
Returns:
1. use_kernel is False:
The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity).
The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity).
2. use_kernel is True:
...
"""
if self.noisy_func is not None and self.training:
inputs = self.noisy_func(inputs)
assert inputs.dtype == torch.float
probs = F.softmax(inputs, dim=-1)
num_experts = probs.size(-1)
num_tokens = inputs.size(0)
capacity = self.get_capacity(num_tokens, num_experts, ep_group)
top1_idx = torch.argmax(inputs, dim=-1)
mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
# calculate router loss
self.set_aux_loss(probs, top1_idx.unsqueeze(-1), num_experts)
self.set_z_loss(inputs)
self.pop_router_loss()
if not self.training and not self.drop_tks and ep_group is not None:
max_num = torch.max(torch.sum(mask, dim=0))
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
capacity = max_num.item()
if self.select_policy == "random":
rand_mask = mask * self.uniform(mask.shape)
_, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0)
mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1)
ranks = moe_cumsum(mask, use_kernel=self.use_kernel)
elif self.select_policy == "first":
ranks = moe_cumsum(mask, use_kernel=self.use_kernel)
mask = mask * torch.lt(ranks, capacity)
else:
raise NotImplementedError("Not support such select policy yet.")
ranks = torch.sum(mask * ranks, dim=-1)
used_capacity = mask.sum(dim=0)
if use_kernel:
mask = torch.sum(mask, dim=-1)
mask = torch.stack([mask], dim=0).to(torch.int32)
dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32)
return used_capacity, probs, mask, dest_idx, num_experts * capacity
else:
ranks = F.one_hot(ranks, num_classes=capacity)
weight = mask * probs.type_as(inputs)
combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
sec_mask = combine_weights.bool()
return used_capacity, combine_weights, sec_mask, probs
class Top2Router(MoeRouter):
"""Top2 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity)
and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More detailed
function can be found in the paper about ViT-MoE.
Args:
capacity_factor_train (float, optional): Capacity factor in routing of training.
capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
min_capacity (int, optional): The minimum number of the capacity of each expert
noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
drop_tks (bool, optional): Whether drops tokens in evaluation.
"""
def __init__(
self,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
noisy_func: Optional[Callable] = None,
drop_tks: bool = True,
):
super().__init__(
k_value=2,
capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval,
min_capacity=min_capacity,
noisy_func=noisy_func,
drop_tks=drop_tks,
)
def forward(
self,
inputs: torch.Tensor,
use_kernel: bool = False,
ep_group: Optional[ProcessGroup] = None,
use_norm: bool = False,
use_loss: bool = True,
) -> Tuple:
"""
Args:
inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
Returns:
1. use_kernel is False:
The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity).
The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity).
2. use_kernel is True:
...
"""
if self.noisy_func is not None and self.training:
inputs = self.noisy_func(inputs)
assert inputs.dtype == torch.float
probs = F.softmax(inputs, dim=-1)
if use_norm:
routing_weights, _ = torch.topk(probs, 2, dim=-1)
probs = probs / routing_weights.sum(dim=-1, keepdim=True)
num_experts = probs.size(-1)
num_tokens = inputs.size(0)
capacity = self.get_capacity(num_tokens, num_experts, ep_group)
top1_idx = torch.argmax(probs, dim=-1)
mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
logits_except1 = probs.masked_fill(mask1.bool(), float("-inf"))
top2_idx = torch.argmax(logits_except1, dim=-1)
mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)
cmask = mask1 + mask2 # loss: [s, e]
cmask = cmask.float() / 2.0 # div 2 to normalize it to 1
# calculate loss
if use_loss:
expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
self.set_aux_loss(probs, expert_indices, num_experts)
self.set_z_loss(inputs)
self.pop_router_loss()
if not self.training and not self.drop_tks and ep_group is not None:
max_num = torch.max(torch.sum(cmask, dim=0))
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
capacity = max_num.item()
rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel) # rank1: [s, e]
rank2 = moe_cumsum(mask2, use_kernel=self.use_kernel)
rank2 += torch.sum(mask1, dim=-2, keepdim=True)
mask1 *= torch.lt(rank1, capacity)
mask2 *= torch.lt(rank2, capacity)
used_capacity = mask1.sum(dim=0) + mask2.sum(dim=0)
rank1 = torch.sum(mask1 * rank1, dim=-1)
rank2 = torch.sum(mask2 * rank2, dim=-1)
if use_kernel:
mask1 = torch.sum(mask1, dim=-1)
mask2 = torch.sum(mask2, dim=-1)
mask = torch.stack([mask1, mask2], dim=0).to(torch.int32)
dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32)
return used_capacity, probs, mask, dest_idx, num_experts * capacity
else:
"""
The following code is equivalent to:
```
weight1 = mask1 * probs.type_as(inputs)
weight2 = mask2 * probs.type_as(inputs)
rank1_sc = F.one_hot(rank1, num_classes=capacity)
rank2_sc = F.one_hot(rank2, num_classes=capacity)
cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1)
cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1)
cb_weight = cb_weight1 + cb_weight2
sec_mask = cb_weight.bool()
```
"""
weight1 = mask1 * probs.type_as(inputs)
weight2 = mask2 * probs.type_as(inputs)
cb_weight = torch.zeros(inputs.shape + (capacity,), device=inputs.device)
sec_mask = torch.zeros_like(cb_weight, dtype=torch.bool)
indices = torch.arange(0, inputs.shape[0], device=inputs.device)
cb_weight[indices, top1_idx[indices], rank1[indices]] += weight1[indices, top1_idx[indices]]
cb_weight[indices, top2_idx[indices], rank2[indices]] += weight2[indices, top2_idx[indices]]
sec_mask[indices, top1_idx[indices], rank1[indices]] |= mask1.bool()[indices, top1_idx[indices]]
sec_mask[indices, top2_idx[indices], rank2[indices]] |= mask2.bool()[indices, top2_idx[indices]]
return used_capacity, cb_weight, sec_mask
class TopKRouter(MoeRouter):
"""Masked matmul router using tokens choose top-k experts assignment.
NOTE: this is modified from flaxformer.
This router uses the same mechanism as in Switch Transformer
(https://arxiv.org/abs/2101.03961) and V-MoE
(https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are
sorted by router_probs and then routed to their choice of expert until the
expert's expert_capacity is reached. There is no guarantee that each token is
processed by an expert, or that each expert receives at least one token.
Attributes:
num_selected_experts: Maximum number of experts to which each token is
routed. Tokens may be routed to fewer experts if particular experts are
oversubscribed / reach capacity.
"""
def __init__(
self,
num_selected_experts: int,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
noisy_func: Optional[Callable] = None,
drop_tks: bool = True,
):
super().__init__(
num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func, drop_tks
)
def forward(
self,
router_probs: torch.Tensor,
expert_capacity: int,
) -> Tuple:
"""Computes masks for the top-k experts per token.
Args:
router_probs: <float32>[num_groups, tokens_per_group, num_experts]
probabilities used to determine the routing of tokens to the experts.
Returns:
Dispatch and combine arrays for routing with masked matmuls.
"""
# TODO: FIXME: add parallel group
num_groups, _, num_experts = router_probs.shape
# Top-k router probability and corresponding expert indices for each token.
# Shape: [num_groups, tokens_per_group, num_selected_experts].
expert_gate, expert_index = torch.topk(router_probs, self.k_value)
self.set_aux_loss(router_probs, expert_index, num_experts)
self.pop_router_loss()
# Make num_selected_experts the leading axis to ensure that top-1 choices
# have priority over top-2 choices, which have priority over top-3 choices,
# etc.
expert_index = torch.transpose(expert_index, 1, 2)
# Shape: [num_groups, num_selected_experts * tokens_per_group]
expert_index = expert_index.reshape(num_groups, -1)
# Create mask out of indices.
# Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts].
expert_mask = F.one_hot(expert_index, num_experts).to(torch.int32)
# Experts have a fixed capacity that we cannot exceed. A token's priority
# within the expert's buffer is given by the masked, cumulative capacity of
# its target expert.
# Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts].
token_priority = torch.cumsum(expert_mask, dim=1) * expert_mask - 1
# Shape: [num_groups, num_selected_experts, tokens_per_group, num_experts].
token_priority = token_priority.reshape((num_groups, self.k_value, -1, num_experts))
# Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts].
token_priority = torch.transpose(token_priority, 1, 2)
# For each token, across all selected experts, select the only non-negative
# (unmasked) priority. Now, for group G routing to expert E, token T has
# non-negative priority (i.e. token_priority[G,T,E] >= 0) if and only if E
# is its targeted expert.
# Shape: [num_groups, tokens_per_group, num_experts].
token_priority = torch.max(token_priority, dim=2)[0]
# Token T can only be routed to expert E if its priority is positive and
# less than the expert capacity. One-hot matrix will ignore indices outside
# the range [0, expert_capacity).
# Shape: [num_groups, tokens_per_group, num_experts, expert_capacity].
valid_mask = torch.logical_and(token_priority >= 0, token_priority < expert_capacity)
token_priority = torch.masked_fill(token_priority, ~valid_mask, 0)
dispatch_mask = F.one_hot(token_priority, expert_capacity).to(torch.bool)
valid_mask = valid_mask.unsqueeze(-1).expand(-1, -1, -1, expert_capacity)
dispatch_mask = torch.masked_fill(dispatch_mask, ~valid_mask, 0)
# The combine array will be used for combining expert outputs, scaled by the
# router probabilities. Shape: [num_groups, tokens_per_group, num_experts,
# expert_capacity].
combine_array = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask)
return combine_array, dispatch_mask
def get_router_cls(top_k: int, grouped: bool = False) -> MoeRouter:
if not grouped:
if top_k == 1:
return Top1Router
elif top_k == 2:
return Top2Router
else:
raise NotImplementedError("top_k > 2 is not supported yet")
else:
return TopKRouter

View File

@ -6,10 +6,11 @@ import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.distributed_c10d import get_process_group_ranks
from colossalai.accelerator import get_accelerator
from colossalai.moe.manager import MOE_MANAGER
from colossalai.tensor.moe_tensor.api import get_dp_group, get_dp_group_ranks, get_ep_size, is_moe_tensor
from colossalai.tensor.moe_tensor.api import is_moe_tensor
class ForceFP32Parameter(torch.nn.Parameter):
@ -145,7 +146,7 @@ def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]
if not is_moe_tensor(param):
ep_size = 1 # set ep_size to 1 for dp parameters
else:
ep_size = get_ep_size(param)
ep_size = dist.get_world_size(param.ep_group)
if ep_size not in epsize_param_dict:
epsize_param_dict[ep_size] = []
epsize_param_dict[ep_size].append(param)
@ -170,8 +171,8 @@ def sync_moe_model_param(model: nn.Module):
# When ep_size = world_size, communication is not needed
if ep_size != 1 and ep_size != MOE_MANAGER.world_size:
for param in param_dict[ep_size]:
src_rank = get_dp_group_ranks(param)[0]
dist.broadcast(param, src=src_rank, group=get_dp_group(param))
src_rank = get_process_group_ranks(param.dp_group)[0]
dist.broadcast(param, src=src_rank, group=param.dp_group)
def set_moe_args(config: Any, args: dict):

View File

@ -1,17 +1,25 @@
# adapted from Hugging Face accelerate/utils/bnb.py accelerate/utils/modeling.py
import importlib.metadata
import logging
import torch
import torch.nn as nn
from packaging.version import Version
from .bnb_config import BnbQuantizationConfig
try:
import bitsandbytes as bnb
IS_4BIT_BNB_AVAILABLE = bnb.__version__ >= "0.39.0"
IS_8BIT_BNB_AVAILABLE = bnb.__version__ >= "0.37.2"
try:
# in case lower version of bitsandbytes does not have __version__ attribute
BNB_VERSION = Version(bnb.__version__)
except AttributeError:
BNB_VERSION = Version(importlib.metadata.version("bitsandbytes"))
IS_4BIT_BNB_AVAILABLE = BNB_VERSION >= Version("0.39.0")
IS_8BIT_BNB_AVAILABLE = BNB_VERSION >= Version("0.37.2")
except ImportError:
pass

View File

@ -3,7 +3,7 @@ from .attn import AttnMaskType, ColoAttention
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
from .loss import cross_entropy_1d
from .loss import cross_entropy_1d, dist_cross_entropy
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
from .parallel_module import ParallelModule
from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
@ -18,6 +18,7 @@ __all__ = [
"DropoutForParallelInput",
"DropoutForReplicatedInput",
"cross_entropy_1d",
"dist_cross_entropy",
"BaseLayerNorm",
"LayerNorm",
"RMSNorm",

View File

@ -2,8 +2,11 @@ import torch
import torch.distributed as dist
from torch.autograd import Function
from torch.distributed import ProcessGroup
from torch.nn import CrossEntropyLoss
__all__ = ["DistCrossEntropy", "cross_entropy_1d"]
from colossalai.shardformer.shard import ShardConfig
__all__ = ["DistCrossEntropy", "cross_entropy_1d", "dist_cross_entropy"]
class DistCrossEntropy(Function):
@ -132,3 +135,43 @@ def cross_entropy_1d(
dtype: torch.dtype = None,
) -> torch.Tensor:
return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype)
def dist_cross_entropy(
labels: torch.Tensor,
logits: torch.Tensor,
shard_config: ShardConfig,
out_features: int,
vocab_size: int,
dtype: torch.dtype,
) -> torch.Tensor:
"""
Helper to compute cross entropy loss for most shardformer models,
compatible with PP, TP and SP.
"""
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_labels = shift_labels.view(-1)
shift_labels = shift_labels.to(shift_logits.device)
if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
# Cross entropy with all-reduce for TP
new_vocab_size = logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
loss = cross_entropy_1d(
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=out_features,
dtype=dtype,
)
else:
# NOTE if use TP and not parallel_output, the output is gathered.
# see VocabParallelLMHead1D
shift_logits = shift_logits.view(-1, vocab_size)
loss = loss_fct(shift_logits, shift_labels)
return loss

View File

@ -0,0 +1,3 @@
from .experts import *
from .layers import *
from .routers import *

View File

@ -9,7 +9,7 @@ from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import get_activation
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size, set_moe_tensor_info
from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size
if HAS_TRITON:
from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine
@ -35,7 +35,7 @@ class MLPExperts(nn.Module):
num_experts: int,
hidden_size: int,
intermediate_size: int,
expert_parallel: Optional[str] = None,
expert_parallel: Optional[str] = "EP",
activation: Optional[Callable] = None,
drop_rate: Optional[float] = 0,
gated: Optional[bool] = False,

View File

@ -8,11 +8,9 @@ import torch.nn as nn
import torch.nn.functional as F
from colossalai.moe._operation import AllGather, AllToAll, HierarchicalAllToAll, MoeCombine, MoeDispatch, ReduceScatter
from colossalai.moe.experts import MLPExperts
from colossalai.moe.load_balance import LoadBalancer
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.routers import MoeRouter, get_router_cls
from colossalai.moe.utils import create_ep_hierarchical_group, get_noise_generator
from colossalai.shardformer.layer.moe import MLPExperts
from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_group_ranks, get_ep_size
@ -23,6 +21,7 @@ class SparseMLP(nn.Module):
dim_model (int): Hidden dimension of training model
num_experts (int): The number experts
top_k (int, optional): The number of experts for dispatchment of each token
parallel (str): parallel mode. Should be "EP", "TP" or None
capacity_factor_train (float, optional): Capacity factor in routing during training
capacity_factor_eval (float, optional): Capacity factor in routing during evaluation
min_capacity (int, optional): The minimum number of the capacity of each expert
@ -51,6 +50,7 @@ class SparseMLP(nn.Module):
hidden_size: int,
intermediate_size: int,
router_top_k: int = 1,
parallel: str = "EP",
router_loss: bool = True,
router_norm: bool = False,
router_capacity_factor_train: float = 1.25,
@ -66,7 +66,7 @@ class SparseMLP(nn.Module):
load_balance_group_swap_factor: float = 0.4,
enable_kernel: bool = False,
enable_comm_overlap: bool = False,
enable_hierarchical_comm: bool = False,
enable_hierarchical_comm: bool = True,
return_gate_logits: bool = False,
):
super().__init__()
@ -77,7 +77,9 @@ class SparseMLP(nn.Module):
self.return_gate_logits = return_gate_logits
self.enable_kernel = enable_kernel
self.enable_comm_overlap = enable_comm_overlap
self.expert_parallel = MOE_MANAGER.get_parallel()
# self.expert_parallel = MOE_MANAGER.get_parallel()
assert parallel in ["EP", "TP", None], "parallel mode must be EP, TP or None"
self.parallel = parallel
self.router_loss = router_loss
self.router_norm = router_norm
@ -99,7 +101,7 @@ class SparseMLP(nn.Module):
# moe experts
self.experts = MLPExperts(
num_experts=self.num_experts,
expert_parallel=self.expert_parallel,
expert_parallel=self.parallel,
hidden_size=self.hidden_size,
intermediate_size=self.intermediate_size,
activation=mlp_activation,
@ -108,11 +110,12 @@ class SparseMLP(nn.Module):
)
# get parallel settings
if self.expert_parallel is not None:
if self.parallel is not None:
self.ep_group = get_ep_group(self.experts)
self.ep_size = get_ep_size(self.experts)
self.ep_hierarchical_group = None
if enable_hierarchical_comm:
# TODO: move to plugin
self.ep_intra_src_rank, *self.ep_hierarchical_group = create_ep_hierarchical_group(
get_ep_group_ranks(self.experts)
)
@ -186,11 +189,11 @@ class SparseMLP(nn.Module):
dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
# expert_output: (num_groups, num_experts, capacity, hidden_size)
if self.expert_parallel == "EP":
if self.parallel == "EP":
expert_output = self._ep_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
elif self.expert_parallel == "TP":
elif self.parallel == "TP":
expert_output = self._tp_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
elif self.expert_parallel is None:
elif self.parallel is None:
expert_output = self._local_process(dispatch_data)
else:
raise NotImplementedError(

View File

@ -0,0 +1,161 @@
import math
from typing import Callable, Optional, Tuple
import torch
import torch.nn as nn
from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import get_activation
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size
if HAS_TRITON:
from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine
class MLPExperts(nn.Module):
"""
SparseMLP is a multi-layer perceptron with sparse expert parallel layers.
Args:
num_experts (int): The number of experts
hidden_size (int): The hidden size of MLP
intermediate_size (int): The intermediate size of MLP
expert_parallel (str, optional): The parallelism of experts. Now we have None, EP and TP.
activation (optional): The activation function of MLP
drop_rate (float, optional): The drop rate of MLP
gated (bool, optional): Whether to use gated MLP
use_kernel (bool, optional): Whether to use kernel optimization
"""
def __init__(
self,
num_experts: int,
hidden_size: int,
intermediate_size: int,
expert_parallel: Optional[str] = "EP",
activation: Optional[Callable] = None,
drop_rate: Optional[float] = 0,
gated: Optional[bool] = False,
use_kernel: Optional[bool] = False,
):
super().__init__()
assert expert_parallel in ["EP", "TP", None]
self.expert_parallel = expert_parallel
self.num_total_experts = num_experts
self.gated = gated
self.use_kernel = use_kernel
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
# get expert parallel info
if expert_parallel is not None:
self.num_local_experts, self.moe_info = MOE_MANAGER.get_info(
num_experts, use_tp=True if expert_parallel == "TP" else False
)
# get settings for different parallel
self.ep_size = get_ep_size(self)
if expert_parallel == "TP":
intermediate_size = intermediate_size // self.ep_size
num_experts = self.num_total_experts
else:
num_experts = self.num_local_experts
else:
self.num_local_experts = self.num_total_experts
self.ep_size = 1
if gated:
self.wi_gate = nn.Parameter(
torch.empty(
num_experts, hidden_size, intermediate_size * 2 if activation == "swiglu" else intermediate_size
)
)
self.wi_up = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
else:
self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
self.wo = nn.Parameter(torch.empty(num_experts, intermediate_size, hidden_size))
self.act_name = activation
self.act = get_activation(activation)
self.drop = nn.Dropout(p=drop_rate)
if expert_parallel is not None:
for param in self.parameters():
set_moe_tensor_info(param, self.moe_info)
# init param
self.reset_parameters()
@torch.no_grad()
def reset_parameters(self):
# expert param should be different
if self.expert_parallel is not None:
seed_ctx = Randomizer(get_ep_rank(self)).fork_rng(enable_cpu=True)
else:
seed_ctx = Randomizer(42).fork_rng(enable_cpu=True)
with seed_ctx:
if self.gated:
torch.nn.init.normal_(self.wi_gate, std=math.sqrt(0.1 / self.hidden_size))
torch.nn.init.normal_(self.wi_up, std=math.sqrt(0.1 / self.hidden_size))
else:
torch.nn.init.normal_(self.wi, std=math.sqrt(0.1 / self.hidden_size))
torch.nn.init.normal_(self.wo, std=math.sqrt(0.1 / self.intermediate_size))
def forward(
self,
x: torch.Tensor,
param_slice: Tuple[slice] = (slice(None),),
use_sparse: bool = True,
) -> torch.Tensor:
"""
forward: hidden_size --> intermediate_size --> hidden_size
Args:
x (torch.Tensor): The input tensor of shape (num_groups, num_experts, capacity, hidden_size)
Returns:
torch.Tensor: The output tensor of shape (num_groups, num_experts, capacity, hidden_size)
"""
x = MoeInGradScaler.apply(x, self.ep_size)
e = x.size(1)
h = x.size(-1)
x = x.transpose(0, 1)
inshape = x.shape
x = x.reshape(e, -1, h)
if self.use_kernel and use_sparse:
seq_len = x.shape[1]
with torch.no_grad():
mask = x[:, :, 0] != 0.0
mask = torch.sum(mask, dim=-1)
x_list = []
for i in range(e):
x_list.append(x[i, : mask[i]])
x = x_list
if self.gated:
x_gate = [torch.mm(x[i], self.wi_gate[param_slice][i]) for i in range(e)]
x_up = [torch.mm(x[i], self.wi_up[param_slice][i]) for i in range(e)]
if self.use_kernel and HAS_TRITON and self.act_name == "swiglu":
x = [LlamaActCombine.apply(x_gate[i], x_up[i]) for i in range(e)]
else:
x = [self.act(x_gate[i]) * x_up[i] for i in range(e)]
else:
x = [torch.mm(x[i], self.wi[param_slice][i]) for i in range(e)]
x = [self.act(x[i]) for i in range(e)]
x = [self.drop(x[i]) for i in range(e)]
x = [torch.mm(x[i], self.wo[param_slice][i]) for i in range(e)]
if self.use_kernel and use_sparse:
for i in range(e):
x[i] = torch.nn.functional.pad(x[i], (0, 0, 0, seq_len - x[i].shape[0]), mode="constant", value=0)
x = torch.cat([x[i].unsqueeze(0) for i in range(e)], dim=0)
x = x.reshape(inshape)
x = x.transpose(0, 1).contiguous()
x = MoeOutGradScaler.apply(x, self.ep_size)
return x

View File

@ -1,4 +1,3 @@
import math
import warnings
from typing import List, Optional, Tuple, Union
@ -1005,115 +1004,6 @@ class BertPipelineForwards:
return {"hidden_states": hidden_states}
def get_bert_flash_attention_forward():
try:
from xformers.ops import memory_efficient_attention as me_attention
except:
raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
from transformers.models.bert.modeling_bert import BertAttention
def forward(
self: BertAttention,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_layer = past_key_value[0]
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
use_cache = past_key_value is not None
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)
final_attention_mask = None
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
if use_cache:
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(-1, 1)
else:
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
final_attention_mask = relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
final_attention_mask = relative_position_scores_query + relative_position_scores_key
scale = 1 / math.sqrt(self.attention_head_size)
if attention_mask is not None:
if final_attention_mask != None:
final_attention_mask = final_attention_mask * scale + attention_mask
else:
final_attention_mask = attention_mask
if final_attention_mask is not None:
batch_size, src_len = query_layer.size()[0], query_layer.size()[2]
tgt_len = key_layer.size()[2]
final_attention_mask = final_attention_mask.expand(
batch_size, self.num_attention_heads, src_len, tgt_len
).contiguous()
query_layer = query_layer.permute(0, 2, 1, 3).contiguous()
key_layer = key_layer.permute(0, 2, 1, 3).contiguous()
value_layer = value_layer.permute(0, 2, 1, 3).contiguous()
context_layer = me_attention(
query_layer, key_layer, value_layer, attn_bias=final_attention_mask, p=self.dropout.p, scale=scale
)
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, None)
if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs
return forward
def get_jit_fused_bert_self_output_forward():
from transformers.models.bert.modeling_bert import BertSelfOutput

View File

@ -28,7 +28,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
from colossalai.shardformer.shard import ShardConfig
from ..layer import cross_entropy_1d
from ..layer import dist_cross_entropy
logger = logging.get_logger(__name__)
@ -359,30 +359,14 @@ class BloomPipelineForwards:
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states).contiguous()
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
batch_size, seq_length, vocab_size = shift_logits.shape
# Flatten the tokens
if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
new_vocab_size = lm_logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
shift_labels = shift_labels.view(-1)
loss = cross_entropy_1d(
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.transformer.dtype,
)
else:
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
loss = loss_fct(shift_logits, shift_labels.view(-1))
loss = dist_cross_entropy(
labels,
lm_logits,
shard_config,
self.lm_head.out_features,
self.config.vocab_size,
self.transformer.dtype,
)
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
@ -714,93 +698,6 @@ class BloomPipelineForwards:
return {"hidden_states": hidden_states}
def get_bloom_flash_attention_forward(enable_jit_fused=False):
try:
from xformers.ops import memory_efficient_attention as me_attention
except:
raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
from transformers.models.bloom.modeling_bloom import BloomAttention
def forward(
self: BloomAttention,
hidden_states: torch.Tensor,
residual: torch.Tensor,
alibi: torch.Tensor,
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
fused_qkv = self.query_key_value(hidden_states)
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
batch_size, tgt_len, _, _ = query_layer.size()
_, kv_length, _, _ = key_layer.size()
proj_shape = (batch_size, tgt_len, self.num_heads, self.head_dim)
query_layer = query_layer.contiguous().view(*proj_shape)
key_layer = key_layer.contiguous().view(*proj_shape)
value_layer = value_layer.contiguous().view(*proj_shape)
if layer_past is not None:
past_key, past_value = layer_past
# concatenate along seq_length dimension:
# - key: [batch_size * self.num_heads, head_dim, kv_length]
# - value: [batch_size * self.num_heads, kv_length, head_dim]
key_layer = torch.cat((past_key, key_layer), dim=1)
value_layer = torch.cat((past_value, value_layer), dim=1)
if use_cache is True:
present = (key_layer, value_layer)
else:
present = None
tgt_len = key_layer.size()[1]
attention_numerical_mask = torch.zeros(
(batch_size, self.num_heads, tgt_len, kv_length),
dtype=torch.float32,
device=query_layer.device,
requires_grad=True,
)
attention_numerical_mask = (
attention_numerical_mask + alibi.view(batch_size, self.num_heads, 1, kv_length) * self.beta
)
attention_numerical_mask = torch.masked_fill(
attention_numerical_mask, attention_mask, torch.finfo(torch.float32).min
)
attention_numerical_mask = attention_numerical_mask.to(query_layer.dtype)
context_layer = me_attention(
query_layer,
key_layer,
value_layer,
attn_bias=attention_numerical_mask,
scale=self.inv_norm_factor,
p=self.attention_dropout.p,
)
context_layer = context_layer.reshape(-1, kv_length, self.hidden_size)
if self.pretraining_tp > 1 and self.slow_but_exact:
slices = self.hidden_size / self.pretraining_tp
output_tensor = torch.zeros_like(context_layer)
for i in range(self.pretraining_tp):
output_tensor = output_tensor + F.linear(
context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
)
else:
output_tensor = self.dense(context_layer)
# TODO to replace with the bias_dropout_add function in jit
output_tensor = self.dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
outputs = (output_tensor, present, None)
return outputs
return forward
def get_jit_fused_bloom_attention_forward():
from transformers.models.bloom.modeling_bloom import BloomAttention
@ -1127,24 +1024,10 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
new_vocab_size = lm_logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
shift_labels = shift_labels.view(-1)
loss = cross_entropy_1d(
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.transformer.dtype,
)
loss = dist_cross_entropy(
labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
)
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output

View File

@ -33,6 +33,7 @@ This license shall be governed and construed in accordance with the laws of Peop
Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at glm-130b@googlegroups.com.
"""
""" PyTorch ChatGLM model. """
import copy

View File

@ -5,7 +5,6 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.cohere.modeling_cohere import (
@ -25,7 +24,7 @@ from colossalai.shardformer.layer._operation import (
)
from colossalai.shardformer.shard import ShardConfig
from ..layer import ColoAttention, cross_entropy_1d
from ..layer import ColoAttention, dist_cross_entropy
class CommandPipelineForwards:
@ -300,29 +299,9 @@ class CommandPipelineForwards:
logits = self.lm_head(hidden_states)
logits = logits * self.logit_scale
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
new_vocab_size = logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
loss = cross_entropy_1d(
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.model.dtype,
)
else:
shift_logits = shift_logits.view(-1, self.config.vocab_size)
loss = loss_fct(shift_logits, shift_labels)
loss = dist_cross_entropy(
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
)
if not return_dict:
output = (logits,) + outputs[1:]
@ -658,24 +637,14 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
logits = self.lm_head(hidden_states)
logits = logits * self.logit_scale
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
new_vocab_size = logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
loss = cross_entropy_1d(
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.model.dtype,
)
loss = dist_cross_entropy(
labels,
logits,
shard_config,
self.lm_head.out_features,
self.config.vocab_size,
self.model.dtype,
)
if not return_dict:
output = (logits,) + outputs[1:]

View File

@ -0,0 +1,429 @@
from typing import List, Optional, Union
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
# from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo
from torch.nn import CrossEntropyLoss
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import is_flash_attn_2_available, logging
from colossalai.lazy import LazyInitContext
from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.shard import ShardConfig
from colossalai.shardformer.shard.utils import set_tensors_to_none
# copied from modeling_deepseek.py
class AddAuxiliaryLoss(torch.autograd.Function):
"""
The trick function of adding auxiliary (aux) loss,
which includes the gradient of the aux loss during backpropagation.
"""
@staticmethod
def forward(ctx, x, loss):
assert loss.numel() == 1
ctx.dtype = loss.dtype
ctx.required_aux_loss = loss.requires_grad
return x
@staticmethod
def backward(ctx, grad_output):
grad_loss = None
if ctx.required_aux_loss:
grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
return grad_output, grad_loss
class EPDeepseekMoE(nn.Module):
def __init__(self):
super(EPDeepseekMoE, self).__init__()
def setup_ep(self, ep_group: ProcessGroup):
ep_group = ep_group
self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
self.num_experts = self.config.n_routed_experts
assert self.num_experts % self.ep_size == 0
self.ep_group = ep_group
self.num_experts_per_ep = self.num_experts // self.ep_size
self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
set_tensors_to_none(self.experts, exclude=set(held_experts))
for p in self.experts.parameters():
p.ep_group = ep_group
@staticmethod
def from_native_module(module: Union["DeepseekMoE", "DeepseekMLP"], *args, **kwargs) -> "EPDeepseekMoE":
LazyInitContext.materialize(module)
if module.__class__.__name__ == "DeepseekMLP":
return module
module.__class__ = EPDeepseekMoE
assert "ep_group" in kwargs, "You should pass ep_group in SubModuleReplacementDescription via shard_config!!"
module.setup_ep(kwargs["ep_group"])
return module
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
identity = hidden_states
orig_shape = hidden_states.shape
topk_experts_idx, topk_experts_weight, aux_loss = self.gate(hidden_states)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) # [t0, t1, t2 ...]
hidden_states = hidden_states.repeat_interleave(
self.num_experts_per_tok, dim=0
) # after repeat_interleave: [t0 t0 t1 t1 t2 t2 ... ]
flat_topk_experts_idx = topk_experts_idx.view(-1) # [e0 e1 e2 ...]
# The elements of flat_topk_token_idx are token ids, which are arranged in ascending order of expert ids.
flat_topk_token_idx = flat_topk_experts_idx.argsort()
# Now we adjust the order of the hidden states, also in ascending order of expert id
dispatch_states = hidden_states[flat_topk_token_idx]
input_split_sizes = flat_topk_experts_idx.bincount(minlength=self.num_experts) # [n0, n1, n2, n3]
output_split_sizes = torch.zeros_like(input_split_sizes)
# [n0, n1, n2, n3] [m0, m1, m2, m3] -> [n0, n1, m0, m1] [n2, n3, m2, m3]
dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
output_states = MoeInGradScaler.apply(output_states, self.ep_size)
if output_states.size(0) > 0:
if self.num_experts_per_ep == 1:
expert = self.experts[self.expert_start_idx]
output_states = expert(output_states)
else:
output_states_splits = output_states.split(output_split_sizes.tolist())
output_states_list = []
for i, split_states in enumerate(output_states_splits):
if split_states.size(0) == 0: # no token routed to this experts
continue
expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
split_states = expert(split_states)
output_states_list.append(split_states)
output_states = torch.cat(output_states_list)
output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
recover_token_idx = torch.empty_like(flat_topk_token_idx)
recover_token_idx[flat_topk_token_idx] = torch.arange(
flat_topk_token_idx.size(0), device=flat_topk_token_idx.device
)
output_hidden_states = dispatch_states[recover_token_idx] # t0 t0 t1 t1 t2 t2
output_hidden_states = output_hidden_states.view(-1, self.num_experts_per_tok, orig_shape[-1])
output_hidden_states = (output_hidden_states * topk_experts_weight[:, :, None]).sum(dim=-2) # (B*S, h)
output_hidden_states = output_hidden_states.view(*orig_shape)
output_hidden_states = AddAuxiliaryLoss.apply(output_hidden_states, aux_loss)
if self.config.n_shared_experts is not None:
output_hidden_states = output_hidden_states + self.shared_experts(identity)
return output_hidden_states
class DeepseekPipelineForwards:
"""
This class serves as a micro library for forward function substitution of Llama models
under pipeline setting.
"""
@staticmethod
def deepseek_model_forward(
self: "DeepseekModel",
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
):
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
logger = logging.get_logger(__name__)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if stage_manager.is_first_stage():
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
else:
input_shape = hidden_states.shape[:-1]
batch_size, seq_length = input_shape
device = hidden_states.device
seq_length_with_past = seq_length
past_key_values_length = 0
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
if use_cache:
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
use_cache = False
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
# embed positions, for the first stage, hidden_states is the input embeddings,
# for the other stages, hidden_states is the output of the previous stage
if is_flash_attn_2_available():
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask,
(batch_size, seq_length),
hidden_states,
past_key_values_length,
sliding_window=self.config.sliding_window,
)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
start_idx, end_idx = stage_index[0], stage_index[1]
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
None,
output_attentions,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if stage_manager.is_last_stage():
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if stage_manager.is_last_stage():
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
# always return dict for imediate stage
return {
"hidden_states": hidden_states,
}
@staticmethod
def deepseek_for_causal_lm_forward(
self: "DeepseekForCausalLM",
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
):
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, MixtralForCausalLM
>>> model = DeepseekForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
logger = logging.get_logger(__name__)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = DeepseekPipelineForwards.deepseek_model_forward(
self.model,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
)
past_key_values = None
if stage_manager.is_last_stage():
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=None,
hidden_states=outputs[0],
attentions=None,
)
else:
out = {}
hidden_states = outputs.get("hidden_states")
out["hidden_states"] = hidden_states
return out

View File

@ -25,7 +25,7 @@ from colossalai.shardformer.layer import ColoAttention
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
from colossalai.shardformer.shard import ShardConfig
from ..layer import cross_entropy_1d
from ..layer import dist_cross_entropy
logger = logging.get_logger(__name__)
@ -372,27 +372,9 @@ class GPT2PipelineForwards:
hidden_states = outputs[0]
lm_logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
shift_labels = shift_labels.view(-1)
if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
loss = cross_entropy_1d(
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.transformer.dtype,
)
else:
loss = loss_fct(shift_logits, shift_labels)
loss = dist_cross_entropy(
labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
)
if not return_dict:
output = (lm_logits,) + outputs[1:]
@ -1282,24 +1264,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
shift_labels = shift_labels.view(-1)
loss = cross_entropy_1d(
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.transformer.dtype,
)
loss = dist_cross_entropy(
labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
)
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]

View File

@ -31,7 +31,7 @@ from colossalai.shardformer.layer._operation import (
)
from colossalai.shardformer.shard import ShardConfig
from ..layer import ColoAttention, cross_entropy_1d
from ..layer import ColoAttention, dist_cross_entropy
class LlamaPipelineForwards:
@ -86,13 +86,20 @@ class LlamaPipelineForwards:
device = input_ids.device if input_ids is not None else inputs_embeds.device
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
else:
input_shape = hidden_states.shape[:-1]
batch_size, seq_length = input_shape
device = hidden_states.device
# Support SP + PP
sp_mode = shard_config.sequence_parallelism_mode
sp_group = shard_config.sequence_parallel_process_group
sp_size = shard_config.sequence_parallel_size
if sp_mode == "all_to_all" and not stage_manager.is_first_stage():
# For correct positions ids. The states will be gather along the seq dim in the attention layer later.
seq_length *= sp_size
past_seen_tokens = 0
if use_cache: # kept for BC (cache positions)
if not isinstance(past_key_values, StaticCache):
@ -101,7 +108,7 @@ class LlamaPipelineForwards:
if cache_position is None:
if isinstance(past_key_values, StaticCache):
raise ValueError("cache_position is a required argument when using StaticCache.")
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=device)
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=device)
seq_length_with_past = seq_length + past_seen_tokens
@ -118,7 +125,6 @@ class LlamaPipelineForwards:
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
# embed positions, for the first stage, hidden_states is the input embeddings,
# for the other stages, hidden_states is the output of the previous stage
if shard_config.enable_flash_attention:
@ -134,6 +140,13 @@ class LlamaPipelineForwards:
else:
attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position)
# Support SP + PP
if stage_manager.is_first_stage():
if sp_mode in ["ring", "split_gather"]:
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group)
elif sp_mode == "all_to_all":
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size)
if self.gradient_checkpointing and self.training and use_cache:
if use_cache:
logger.warning_once(
@ -196,6 +209,10 @@ class LlamaPipelineForwards:
if stage_manager.is_last_stage():
hidden_states = self.norm(hidden_states)
if sp_mode == "ring" or sp_mode == "split_gather":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
elif sp_mode == "all_to_all":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
# add hidden states from the last decoder layer
if output_hidden_states:
@ -304,29 +321,9 @@ class LlamaPipelineForwards:
if stage_manager.is_last_stage():
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
new_vocab_size = logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
loss = cross_entropy_1d(
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.model.dtype,
)
else:
shift_logits = shift_logits.view(-1, self.config.vocab_size)
loss = loss_fct(shift_logits, shift_labels)
loss = dist_cross_entropy(
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
)
if not return_dict:
output = (logits,) + outputs[1:]
@ -529,7 +526,6 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None,
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
@ -804,24 +800,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
new_vocab_size = logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
loss = cross_entropy_1d(
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.model.dtype,
)
loss = dist_cross_entropy(
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output

View File

@ -19,7 +19,7 @@ from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.shard import ShardConfig
from ..layer import ColoAttention, cross_entropy_1d
from ..layer import ColoAttention, dist_cross_entropy
logger = logging.get_logger(__name__)
@ -275,29 +275,9 @@ class MistralForwards:
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
new_vocab_size = logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
loss = cross_entropy_1d(
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.model.dtype,
)
else:
shift_logits = shift_logits.view(-1, self.config.vocab_size)
loss = loss_fct(shift_logits, shift_labels)
loss = dist_cross_entropy(
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
)
if not return_dict:
output = (logits,) + outputs[1:]
@ -708,23 +688,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
new_vocab_size = logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
loss = cross_entropy_1d(
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.model.dtype,
)
loss = dist_cross_entropy(
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
)
if not return_dict:
output = (logits,) + outputs[1:]

View File

@ -1,222 +1,108 @@
from functools import partial
from typing import Callable, Dict, List, Optional, Union
from typing import List, Optional
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import CrossEntropyLoss, Module
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed import ProcessGroup
# from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo
from torch.nn import CrossEntropyLoss
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.models.mixtral.modeling_mixtral import (
MixtralDecoderLayer,
MixtralForCausalLM,
MixtralModel,
MixtralSparseMoeBlock,
MoeCausalLMOutputWithPast,
_prepare_4d_causal_attention_mask,
load_balancing_loss_func,
)
from transformers.utils import logging
from transformers.utils import is_flash_attn_2_available, logging
from colossalai.lazy import LazyInitContext
from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
from colossalai.shardformer.shard import ShardConfig
from .mixtral_layer import EPMixtralSparseMoeBlock
__all__ = ["MixtralPolicy", "MixtralForCausalLMPolicy"]
from colossalai.shardformer.shard.utils import set_tensors_to_none
class MixtralPolicy(Policy):
def config_sanity_check(self):
pass
class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
def __init__(self, config):
self.moe_info = None
super().__init__(config)
def preprocess(self):
if self.shard_config.enable_tensor_parallelism:
# Resize embedding
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
def setup_ep(self, ep_group: ProcessGroup):
ep_group = ep_group
self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
assert self.num_experts % self.ep_size == 0
self.ep_group = ep_group
self.num_experts_per_ep = self.num_experts // self.ep_size
self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
set_tensors_to_none(self.experts, exclude=set(held_experts))
for p in self.experts.parameters():
p.ep_group = ep_group
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
@staticmethod
def from_native_module(module: MixtralSparseMoeBlock, *args, **kwargs) -> "EPMixtralSparseMoeBlock":
LazyInitContext.materialize(module)
module.__class__ = EPMixtralSparseMoeBlock
# if "ep_group" in kwargs:
assert "ep_group" in kwargs, "You should pass ep_group in SubModuleReplacementDescription via shard_config!!"
module.setup_ep(kwargs["ep_group"])
return module
return self.model
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
policy = {}
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
raise NotImplementedError(
"Mixtral dosen't support sequence parallelism now, will ignore the sequence parallelism flag."
)
selected_experts = selected_experts.t().reshape(-1)
selected_experts_idx = selected_experts.argsort()
dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx]
input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
output_split_sizes = torch.zeros_like(input_split_sizes)
dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
if self.shard_config.enable_tensor_parallelism:
raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
# expert parallel
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="block_sparse_moe",
target_module=EPMixtralSparseMoeBlock,
)
],
policy=policy,
target_key=MixtralDecoderLayer,
)
# optimization configuration
if self.shard_config.enable_fused_normalization:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=FusedRMSNorm,
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=FusedRMSNorm,
),
],
policy=policy,
target_key=MixtralDecoderLayer,
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="norm",
target_module=FusedRMSNorm,
),
policy=policy,
target_key=MixtralModel,
)
if self.shard_config.enable_flash_attention:
raise NotImplementedError("Flash attention has already been replaced in mixtral.")
return policy
def postprocess(self):
return self.model
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
"""If under pipeline parallel setting, replacing the original forward method of huggingface
to customized forward method, and add this changing to policy."""
if self.pipeline_stage_manager:
stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == "MixtralModel":
module = self.model
input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
# compute expert output
output_states = MoeInGradScaler.apply(output_states, self.ep_size)
if output_states.size(0) > 0:
if self.num_experts_per_ep == 1:
# no need to split
expert = self.experts[self.expert_start_idx]
output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states)
output_states = expert.w2(output_states)
else:
module = self.model.model
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
stage_index = stage_manager.get_stage_index(layers_per_stage)
method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=model_cls
)
return
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
assert self.pipeline_stage_manager is not None
if self.model.__class__.__name__ == "MixtralModel":
module = self.model
else:
module = self.model.model
stage_manager = self.pipeline_stage_manager
held_layers = []
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
if stage_manager.is_first_stage():
held_layers.append(module.embed_tokens)
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.norm)
return held_layers
class MixtralModelPolicy(MixtralPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = super().module_policy()
if self.pipeline_stage_manager:
# set None as default
self.set_pipeline_forward(
model_cls=MixtralModel,
new_forward=MixtralPipelineForwards.mixtral_model_forward,
policy=policy,
)
return policy
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
held_layers = super().get_held_layers()
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in llama model"""
return []
class MixtralForCausalLMPolicy(MixtralPolicy):
def module_policy(self):
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
# add a new item for casual lm
new_item = {
MixtralForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True),
)
]
)
}
policy.update(new_item)
if self.pipeline_stage_manager:
# set None as default
self.set_pipeline_forward(
model_cls=MixtralForCausalLM,
new_forward=MixtralPipelineForwards.mixtral_for_causal_lm_forward,
policy=policy,
)
return policy
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_last_stage():
held_layers.append(self.model.lm_head)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
llama_model = self.model.model
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
if (
id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
and self.pipeline_stage_manager.num_stages > 1
):
# tie weights
return [
{
0: llama_model.embed_tokens.weight,
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
}
]
return []
output_states_splits = output_states.split(output_split_sizes.tolist())
output_states_list = []
for i, split_states in enumerate(output_states_splits):
if split_states.size(0) == 0:
continue
expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states)
split_states = expert.w2(split_states)
output_states_list.append(split_states)
output_states = torch.cat(output_states_list)
output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
recover_experts_idx = torch.empty_like(selected_experts_idx)
recover_experts_idx[selected_experts_idx] = torch.arange(
selected_experts_idx.size(0), device=selected_experts_idx.device
)
dispatch_states = dispatch_states[recover_experts_idx]
k_hidden_states = dispatch_states.chunk(self.top_k)
output_states = k_hidden_states[0] * routing_weights[:, 0, None]
for i in range(1, self.top_k):
output_states += k_hidden_states[i] * routing_weights[:, i, None]
output_states = output_states.reshape(batch_size, sequence_length, hidden_dim)
return output_states, router_logits
class MixtralPipelineForwards:
@ -332,7 +218,7 @@ class MixtralPipelineForwards:
# embed positions, for the first stage, hidden_states is the input embeddings,
# for the other stages, hidden_states is the output of the previous stage
if self._use_flash_attention_2:
if is_flash_attn_2_available():
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:

View File

@ -22,7 +22,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import ColoAttention
from colossalai.shardformer.shard import ShardConfig
from ..layer import cross_entropy_1d
from ..layer import dist_cross_entropy
logger = logging.get_logger(__name__)
@ -221,7 +221,7 @@ class OPTPipelineForwards:
past_key_value = past_key_values[idx] if past_key_values is not None else None
if decoder.gradient_checkpointing and decoder.training:
layer_outputs = self._gradient_checkpointing_func(
layer_outputs = self.decoder._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_attention_mask,
@ -330,30 +330,14 @@ class OPTPipelineForwards:
)
if stage_manager.is_last_stage():
logits = self.lm_head(outputs[0]).contiguous()
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
new_vocab_size = logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
shift_labels = shift_labels.view(-1)
loss = cross_entropy_1d(
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.model.decoder.dtype,
)
else:
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
loss = dist_cross_entropy(
labels,
logits,
shard_config,
self.lm_head.out_features,
self.config.vocab_size,
self.model.decoder.dtype,
)
if not return_dict:
output = (logits,) + outputs[1:]
@ -971,26 +955,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
)
logits = self.lm_head(outputs[0]).contiguous()
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
new_vocab_size = logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
loss = cross_entropy_1d(
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.model.decoder.dtype,
)
loss = dist_cross_entropy(
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.decoder.dtype
)
if not return_dict:
output = (logits,) + outputs[1:]

View File

@ -32,7 +32,7 @@ from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.shard import ShardConfig
from ..layer import ColoAttention, cross_entropy_1d
from ..layer import ColoAttention, dist_cross_entropy
class Qwen2PipelineForwards:
@ -168,13 +168,27 @@ class Qwen2PipelineForwards:
next_decoder_cache = None
start_idx, end_idx = stage_index[0], stage_index[1]
num_ckpt_layers = 0
if self.gradient_checkpointing and self.training:
num_ckpt_layers = end_idx - start_idx
# TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer
if shard_config.gradient_checkpoint_config is not None:
num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
stage=stage_manager.stage,
num_stages=stage_manager.num_stages,
num_layers=end_idx - start_idx,
model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0),
num_model_chunks=stage_manager.num_model_chunks,
)
assert num_ckpt_layers <= end_idx - start_idx
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
if idx - start_idx < num_ckpt_layers:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
@ -198,7 +212,6 @@ class Qwen2PipelineForwards:
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
@ -304,25 +317,9 @@ class Qwen2PipelineForwards:
if stage_manager.is_last_stage():
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
if shard_config.enable_tensor_parallelism:
new_vocab_size = logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
loss = cross_entropy_1d(
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
)
else:
shift_logits = shift_logits.view(-1, self.config.vocab_size)
loss = loss_fct(shift_logits, shift_labels)
loss = dist_cross_entropy(
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, logits.dtype
)
if not return_dict:
output = (logits,) + outputs[1:]
@ -724,26 +721,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
if shard_config.enable_tensor_parallelism:
new_vocab_size = logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
loss = cross_entropy_1d(
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
)
else:
shift_logits = shift_logits.view(-1, self.config.vocab_size)
loss = loss_fct(shift_logits, shift_labels)
loss = dist_cross_entropy(
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, logits.dtype
)
if not return_dict:
output = (logits,) + outputs[1:]

Some files were not shown because too many files have changed in this diff Show More