mirror of https://github.com/hpcaitech/ColossalAI
Merge branch 'main' of https://github.com/hpcaitech/ColossalAI into rlhf_SimPO
commit 16f3451fe2
@@ -90,7 +90,7 @@ jobs:
     runs-on: [self-hosted, gpu]
     container:
       image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
-      options: --gpus all --rm -v /dev/shm -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
+      options: --gpus all --rm -v /dev/shm -v /data/scratch:/data/scratch
     timeout-minutes: 90
     defaults:
       run:
@@ -165,6 +165,7 @@ jobs:
         env:
           LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
           LLAMA_PATH: /data/scratch/llama-tiny
+          MOE_TENSOR_PATH: /data/scratch/moe_tensors

       - name: Collate artifact
         env:

@@ -13,7 +13,7 @@ jobs:
     runs-on: [self-hosted, gpu]
     container:
       image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
-      options: --gpus all --rm -v /dev/shm -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
+      options: --gpus all --rm -v /dev/shm -v /data/scratch/:/data/scratch/
     timeout-minutes: 90
     steps:
       - name: Check GPU Availability # ensure all GPUs have enough memory
@@ -69,6 +69,7 @@ jobs:
         env:
           LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
           LLAMA_PATH: /data/scratch/llama-tiny
+          MOE_TENSOR_PATH: /data/scratch/moe_tensors

       - name: Notify Lark
         id: message-preparation

@@ -50,7 +50,7 @@ jobs:
       matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
     container:
       image: ${{ matrix.container }}
-      options: --gpus all --rm -v /dev/shm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
+      options: --gpus all --rm -v /dev/shm -v /data/scratch/:/data/scratch/
     timeout-minutes: 200
     steps:
       - name: Install dependencies
@@ -92,3 +92,4 @@ jobs:
           DATA: /data/scratch/cifar-10
           LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
           LLAMA_PATH: /data/scratch/llama-tiny
+          MOE_TENSOR_PATH: /data/scratch/moe_tensors

@@ -41,7 +41,7 @@ jobs:
       matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
     container:
       image: ${{ matrix.container }}
-      options: --gpus all --rm -v /dev/shm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
+      options: --gpus all --rm -v /dev/shm -v /data/scratch/:/data/scratch/
     timeout-minutes: 200
     concurrency:
       group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-run-test-${{ matrix.container }}
@@ -87,3 +87,4 @@ jobs:
           DATA: /data/scratch/cifar-10
           LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
           LLAMA_PATH: /data/scratch/llama-tiny
+          MOE_TENSOR_PATH: /data/scratch/moe_tensors

@@ -38,7 +38,7 @@ jobs:
       matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
     container:
       image: ${{ matrix.container }}
-      options: --gpus all --rm -v /dev/shm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
+      options: --gpus all --rm -v /dev/shm -v /data/scratch/:/data/scratch/
     timeout-minutes: 200
     steps:
       - name: Install dependencies
@@ -85,6 +85,7 @@ jobs:
           DATA: /data/scratch/cifar-10
           LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
           LLAMA_PATH: /data/scratch/llama-tiny
+          MOE_TENSOR_PATH: /data/scratch/moe_tensors

       - name: Notify Lark
         id: message-preparation

@@ -1,34 +1,34 @@
 repos:

   - repo: https://github.com/PyCQA/autoflake
-    rev: v2.2.1
+    rev: v2.3.1
     hooks:
       - id: autoflake
         name: autoflake (python)
         args: ['--in-place', '--remove-unused-variables', '--remove-all-unused-imports', '--ignore-init-module-imports']

   - repo: https://github.com/pycqa/isort
-    rev: 5.12.0
+    rev: 5.13.2
     hooks:
       - id: isort
         name: sort all imports (python)

   - repo: https://github.com/psf/black-pre-commit-mirror
-    rev: 23.9.1
+    rev: 24.4.2
     hooks:
       - id: black
         name: black formatter
         args: ['--line-length=120', '--target-version=py37', '--target-version=py38', '--target-version=py39','--target-version=py310']

   - repo: https://github.com/pre-commit/mirrors-clang-format
-    rev: v13.0.1
+    rev: v18.1.8
     hooks:
       - id: clang-format
         name: clang formatter
         types_or: [c++, c]

   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.3.0
+    rev: v4.6.0
     hooks:
       - id: check-yaml
       - id: check-merge-conflict

@@ -83,15 +83,19 @@ class DataCollatorForSupervisedDataset(object):

         # `List[torch.Tensor]`
         batch_input_ids = [
-            torch.LongTensor(instance["input_ids"][: self.max_length])
-            if len(instance["input_ids"]) > self.max_length
-            else torch.LongTensor(instance["input_ids"])
+            (
+                torch.LongTensor(instance["input_ids"][: self.max_length])
+                if len(instance["input_ids"]) > self.max_length
+                else torch.LongTensor(instance["input_ids"])
+            )
             for instance in instances
         ]
         batch_labels = [
-            torch.LongTensor(instance["labels"][: self.max_length])
-            if len(instance["labels"]) > self.max_length
-            else torch.LongTensor(instance["labels"])
+            (
+                torch.LongTensor(instance["labels"][: self.max_length])
+                if len(instance["labels"]) > self.max_length
+                else torch.LongTensor(instance["labels"])
+            )
             for instance in instances
         ]
         if self.tokenizer.padding_side == "right":

@@ -1,6 +1,7 @@
 """
 loss functions
 """
+
 from typing import Optional, Tuple

 import torch

@@ -1,6 +1,7 @@
 """
 reward model
 """
+
 from typing import Optional

 import torch

@@ -1,6 +1,7 @@
 """
 Training utilities for Coati.
 """
+
 from typing import Any

 import torch

@@ -78,7 +78,9 @@ def get_prompt(line: Dict, dataset_name: str, logger: DistributedLogger) -> Dict
     option_string = "ABCDEFG"
     count = len(line["options"])

-    input = "问题:" + line["question"] + " " + "从以下选项中选择:" + " ".join(line["options"]) + "\n" + "答案:"
+    input = (
+        "问题:" + line["question"] + " " + "从以下选项中选择:" + " ".join(line["options"]) + "\n" + "答案:"
+    )

     all_classes = list(option_string[0:count])

@@ -150,7 +152,15 @@ def combine_prompt(prompt_path, dataset_name, load_explanation=True, chat_mode=F
             )
         elif dataset_name in chinese_qa_datasets:
             question_input = (
-                "问题:" + passage + " " + question + "\n" + "从以下选项中选择:" + " ".join(options) + "\n" + "答案:{}".format(label)
+                "问题:"
+                + passage
+                + " "
+                + question
+                + "\n"
+                + "从以下选项中选择:"
+                + " ".join(options)
+                + "\n"
+                + "答案:{}".format(label)
             )
         elif dataset_name in english_cloze_datasets:
             question_input = "Question: ".format(idx + 1) + question + "\n" + "Answer: {}".format(answer)

@@ -57,7 +57,11 @@ ceval_subject_mapping = {
     "urban_and_rural_planner": ["Urban and Rural Planner", "注册城乡规划师", "Other"],
     "accountant": ["Accountant", "注册会计师", "Other"],
     "fire_engineer": ["Fire Engineer", "注册消防工程师", "Other"],
-    "environmental_impact_assessment_engineer": ["Environmental Impact Assessment Engineer", "环境影响评价工程师", "Other"],
+    "environmental_impact_assessment_engineer": [
+        "Environmental Impact Assessment Engineer",
+        "环境影响评价工程师",
+        "Other",
+    ],
     "tax_accountant": ["Tax Accountant", "税务师", "Other"],
     "physician": ["Physician", "医师资格", "Other"],
 }

@@ -56,9 +56,11 @@ class MTBenchDataset(BaseDataset):
                     "instruction": question["turns"],
                     "input": "",
                     "output": [],
-                    "target": [""] * turn_number
-                    if question["question_id"] not in reference
-                    else reference[question["question_id"]],
+                    "target": (
+                        [""] * turn_number
+                        if question["question_id"] not in reference
+                        else reference[question["question_id"]]
+                    ),
                 }

                 if category in dataset["test"]:

@@ -77,7 +77,9 @@ class HuggingFaceModel(BaseModel):
                 self.indices_for_choices[0].append(
                     self.tokenizer(f"Answer: {choice}", add_special_tokens=False).input_ids[-1]
                 )
-                self.indices_for_choices[1].append(self.tokenizer(f"答案:{choice}", add_special_tokens=False).input_ids[-1])
+                self.indices_for_choices[1].append(
+                    self.tokenizer(f"答案:{choice}", add_special_tokens=False).input_ids[-1]
+                )

     def _load_tokenizer(self, path: str, tokenizer_path: Optional[str], tokenizer_kwargs: dict):
         """

@@ -1,92 +0,0 @@
-import torch
-import torch.distributed as dist
-import torch.nn.functional as F
-from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
-
-from colossalai.lazy import LazyInitContext
-from colossalai.moe import MOE_MANAGER
-from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
-from colossalai.shardformer.shard.utils import set_tensors_to_none
-from colossalai.tensor.moe_tensor.api import set_moe_tensor_info
-
-
-class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
-    def __init__(self, config):
-        super().__init__(config)
-        self.setup_ep()
-
-    def setup_ep(self):
-        _, moe_info = MOE_MANAGER.get_info(self.num_experts)
-        ep_group = moe_info.ep_group
-        self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
-        self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
-        assert self.num_experts % self.ep_size == 0
-        self.ep_group = ep_group
-        self.num_experts_per_ep = self.num_experts // self.ep_size
-        self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
-        held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
-        set_tensors_to_none(self.experts, exclude=set(held_experts))
-        for p in self.experts.parameters():
-            set_moe_tensor_info(p, moe_info)
-
-    @staticmethod
-    def from_native_module(module: MixtralSparseMoeBlock, *args, **kwargs) -> "EPMixtralSparseMoeBlock":
-        LazyInitContext.materialize(module)
-        module.__class__ = EPMixtralSparseMoeBlock
-        module.setup_ep()
-        return module
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        batch_size, sequence_length, hidden_dim = hidden_states.shape
-        hidden_states = hidden_states.view(-1, hidden_dim)
-        # router_logits: (batch * sequence_length, n_experts)
-        router_logits = self.gate(hidden_states)
-
-        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
-        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-        # we cast back to the input dtype
-        routing_weights = routing_weights.to(hidden_states.dtype)
-
-        selected_experts = selected_experts.t().reshape(-1)
-        selected_experts_idx = selected_experts.argsort()
-        dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx]
-        input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
-        output_split_sizes = torch.zeros_like(input_split_sizes)
-        dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
-
-        input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
-        output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
-        output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
-        # compute expert output
-        output_states = MoeInGradScaler.apply(output_states, self.ep_size)
-        if output_states.size(0) > 0:
-            if self.num_experts_per_ep == 1:
-                # no need to split
-                expert = self.experts[self.expert_start_idx]
-                output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states)
-                output_states = expert.w2(output_states)
-            else:
-                output_states_splits = output_states.split(output_split_sizes.tolist())
-                output_states_list = []
-                for i, split_states in enumerate(output_states_splits):
-                    if split_states.size(0) == 0:
-                        continue
-                    expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
-                    split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states)
-                    split_states = expert.w2(split_states)
-                    output_states_list.append(split_states)
-                output_states = torch.cat(output_states_list)
-        output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
-        dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
-        recover_experts_idx = torch.empty_like(selected_experts_idx)
-        recover_experts_idx[selected_experts_idx] = torch.arange(
-            selected_experts_idx.size(0), device=selected_experts_idx.device
-        )
-        dispatch_states = dispatch_states[recover_experts_idx]
-        k_hidden_states = dispatch_states.chunk(self.top_k)
-        output_states = k_hidden_states[0] * routing_weights[:, 0, None]
-        for i in range(1, self.top_k):
-            output_states += k_hidden_states[i] * routing_weights[:, i, None]
-        output_states = output_states.reshape(batch_size, sequence_length, hidden_dim)
-        return output_states, router_logits

@@ -2,8 +2,6 @@ import argparse

 import torch
 import torch.distributed as dist
-from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
-from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
 from transformers import AutoTokenizer
 from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM

@@ -70,8 +68,6 @@ def main():
         ep_size=ep_size,
         zero_stage=1,
         precision=args.precision,
-        custom_policy=MixtralForCausalLMPolicy(),
-        checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
         enable_fused_normalization=args.use_layernorm_kernel,
         enable_jit_fused=args.use_kernel,
     )

@@ -1,5 +1,6 @@
 NUM_GPU=2
-MODEL="mistralai/Mixtral-8x7B-v0.1"
+# MODEL="mistralai/Mixtral-8x7B-v0.1"
+MODEL="mistralai/Mixtral-8x7B-Instruct-v0.1"

 # ep
 torchrun --standalone --nproc_per_node $NUM_GPU infer.py \

@@ -1,146 +0,0 @@
-from copy import deepcopy
-
-import pytest
-import torch
-import torch.distributed as dist
-from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
-from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
-from torch.optim import Adam
-from transformers.models.mixtral.configuration_mixtral import MixtralConfig
-from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
-
-import colossalai
-from colossalai.booster import Booster
-from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
-from colossalai.testing.utils import spawn
-
-tokens, n_experts = 7, 4
-hidden_size = 8
-top_k = 2
-
-
-def check_model_equal(model1, model2):
-    assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
-    for p1, p2 in zip(model1.parameters(), model2.parameters()):
-        assert torch.equal(p1.half(), p2.half())
-
-
-def get_optimizer_snapshot(optim):
-    state = {id(k): deepcopy(v) for k, v in optim.state.items()}
-    param_groups = []
-    for group in optim.param_groups:
-        params = [id(p) for p in group["params"]]
-        new_group = {"params": params}
-        for k, v in group.items():
-            if k != "params":
-                new_group[k] = v
-        param_groups.append(new_group)
-    return {
-        "state": state,
-        "param_groups": param_groups,
-    }
-
-
-def check_optimizer_snapshot_equal(snapshot1, snapshot2):
-    # check param_groups
-    assert len(snapshot1["param_groups"]) == len(snapshot2["param_groups"])
-    for group1, group2 in zip(snapshot1["param_groups"], snapshot2["param_groups"]):
-        assert set(group1.keys()) == set(group2.keys())
-        for k in group1.keys():
-            assert group1[k] == group2[k]
-    # check state
-    assert set(snapshot1["state"].keys()) == set(
-        snapshot2["state"].keys()
-    ), f"{snapshot1['state'].keys()}, {snapshot2['state'].keys()}"
-    for pid in snapshot1["state"].keys():
-        state1, state2 = snapshot1["state"][pid], snapshot2["state"][pid]
-        assert set(state1.keys()) == set(state2.keys())
-        for k in state1.keys():
-            if isinstance(state1[k], torch.Tensor):
-                assert torch.equal(state1[k], state2[k]), f"{k}, {state1[k]}, {state2[k]}"
-            else:
-                assert state1[k] == state2[k]
-
-
-def check_mixtral_moe_layer():
-    torch.cuda.set_device(dist.get_rank())
-    config = MixtralConfig(
-        hidden_size=hidden_size,
-        intermediate_size=hidden_size * 2,
-        num_local_experts=n_experts,
-        num_experts_per_tok=top_k,
-        num_attention_heads=2,
-        num_key_value_heads=2,
-    )
-    torch.manual_seed(0)
-    input_ids = torch.randint(0, 100, (2, tokens)).cuda()
-    orig_model = MixtralForCausalLM(config).cuda()
-    model = deepcopy(orig_model)
-    optimizer = Adam(model.parameters(), lr=1e-3)
-    plugin = MoeHybridParallelPlugin(
-        tp_size=1,
-        pp_size=2,
-        ep_size=2,
-        custom_policy=MixtralForCausalLMPolicy(),
-        checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
-        microbatch_size=1,
-        zero_stage=1,
-    )
-    booster = Booster(plugin=plugin)
-    model, optimizer, *_ = booster.boost(model=model, optimizer=optimizer)
-    # initialize grads
-    data_iter = iter(
-        [{"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids), "labels": input_ids.clone()}]
-    )
-    booster.execute_pipeline(
-        data_iter,
-        model,
-        lambda outputs, inputs: outputs.loss,
-        optimizer,
-    )
-
-    # check save model
-    booster.save_model(model, "mixtral_model", shard=True)
-    dist.barrier()
-    if dist.get_rank() == 0:
-        saved_model = MixtralForCausalLM.from_pretrained("mixtral_model").cuda()
-        check_model_equal(orig_model, saved_model)
-        saved_model.save_pretrained("mixtral_hf_model")
-    dist.barrier()
-
-    # check load model
-    new_model = MixtralForCausalLM(config).cuda()
-    new_optimizer = Adam(new_model.parameters(), lr=1e-3)
-    new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer)
-    booster.load_model(new_model, "mixtral_hf_model")
-    check_model_equal(model, new_model)
-
-    # check save optimizer
-    optimizer.step()
-    for group in optimizer.param_groups:
-        group["lr"] = 0.1
-    snapshot = get_optimizer_snapshot(optimizer.unwrap())
-    booster.save_optimizer(optimizer, "mixtral_optim", shard=True)
-    dist.barrier()
-    # reset optimizer state
-    for state in optimizer.unwrap().state.values():
-        for v in state.values():
-            if isinstance(v, torch.Tensor):
-                v.zero_()
-    booster.load_optimizer(optimizer, "mixtral_optim")
-    loaded_snapshot = get_optimizer_snapshot(optimizer.unwrap())
-    check_optimizer_snapshot_equal(snapshot, loaded_snapshot)
-
-
-def run_dist(rank: int, world_size: int, port: int):
-    colossalai.launch(rank, world_size, "localhost", port)
-    check_mixtral_moe_layer()
-
-
-@pytest.mark.parametrize("world_size", [4])
-def test_mixtral_moe_layer(world_size: int):
-    spawn(run_dist, world_size)
-
-
-if __name__ == "__main__":
-    test_mixtral_moe_layer(4)

@@ -2,13 +2,11 @@ import argparse

 import torch
 import torch.distributed as dist
-from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
-from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
-from colossal_moe.utils import load_checkpoint, move_to_cuda, save_checkpoint
 from torch.utils.data import Dataset
 from tqdm import tqdm
 from transformers import AutoTokenizer
 from transformers.models.mixtral import MixtralForCausalLM
+from utils import load_checkpoint, move_to_cuda, save_checkpoint

 import colossalai
 from colossalai.booster import Booster
@@ -155,12 +153,10 @@ def main():
             pp_size=args.pp_size,
             ep_size=args.ep_size,
             microbatch_size=args.microbatch_size,
-            custom_policy=MixtralForCausalLMPolicy(),
             enable_fused_normalization=args.use_layernorm_kernel,
             enable_jit_fused=args.use_kernel,
             precision=args.precision,
             zero_stage=args.zero_stage,
-            checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
         )

     else:

@@ -7,6 +7,7 @@ This code is based on LangChain Ai's langchain, which can be found at
 https://github.com/langchain-ai/langchain
 The original code is licensed under the MIT license.
 """
+
 from __future__ import annotations

 import copy

@@ -8,6 +8,7 @@ This code is based on LangChain Ai's langchain, which can be found at
 https://github.com/langchain-ai/langchain
 The original code is licensed under the MIT license.
 """
+
 import copy
 from typing import Any, Mapping, Optional, Protocol

@@ -7,6 +7,7 @@ This code is based on LangChain Ai's langchain, which can be found at
 https://github.com/langchain-ai/langchain
 The original code is licensed under the MIT license.
 """
+
 import copy
 from typing import Any, List

@@ -2,7 +2,6 @@
 Class for loading table type data. please refer to Pandas-Input/Output for file format details.
 """

-
 import glob
 import os

@@ -20,6 +20,7 @@ resp = llm(prompt='What do you call a three-ton kangaroo?', **gen_config)
 print(resp) # super-heavyweight awesome-natured yawning Australian creature!

 """
+
 import json
 from typing import Any, Mapping

@@ -12,6 +12,7 @@ TEST_PROMPT_CHATGLM="续写文章:惊蛰一过,春寒加剧。先是料料
 logger.info(llm(TEST_PROMPT_CHATGLM, max_new_tokens=100), verbose=True)

 """
+
 from typing import Any, List, Mapping, Optional

 import torch

@@ -1,6 +1,7 @@
 """
 Generation utilities
 """
+
 import json
 from typing import List

@@ -2,6 +2,7 @@
 Implement a memory class for storing conversation history
 Support long term and short term memory
 """
+
 from typing import Any, Dict, List

 from colossalqa.chain.memory.summary import ConversationSummaryMemory

@@ -1,6 +1,7 @@
 """
 Class for logging with extra control for debugging
 """
+
 import logging


@@ -1,6 +1,7 @@
 """
 Script for Chinese retrieval based conversation system backed by ChatGLM
 """
+
 from typing import Tuple

 from colossalqa.chain.retrieval_qa.base import RetrievalQA

@@ -1,6 +1,7 @@
 """
 Multilingual retrieval based conversation system
 """
+
 from typing import List

 from colossalqa.data_loader.document_loader import DocumentLoader

@@ -1,6 +1,7 @@
 """
 Script for Chinese retrieval based conversation system backed by ChatGLM
 """
+
 from typing import Tuple

 from colossalqa.chain.retrieval_qa.base import RetrievalQA

@@ -1,6 +1,7 @@
 """
 Code for custom retriver with incremental update
 """
+
 import copy
 import hashlib
 import os

@@ -1,6 +1,7 @@
 """
 Code for Chinese text splitter
 """
+
 from typing import Any, List, Optional

 from colossalqa.text_splitter.utils import get_cleaned_paragraph

@@ -1,6 +1,7 @@
 """
 Script for English retrieval based conversation system backed by LLaMa2
 """
+
 import argparse
 import os

@@ -1,6 +1,7 @@
 """
 Script for English retrieval based conversation system backed by LLaMa2
 """
+
 import argparse
 import json
 import os

@@ -1,6 +1,7 @@
 """
 Script for Chinese retrieval based conversation system backed by ChatGLM
 """
+
 import argparse
 import os

@@ -1,6 +1,7 @@
 """
 Script for English retrieval based conversation system backed by LLaMa2
 """
+
 import argparse
 import os

@@ -107,20 +107,22 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
     # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
     fwd_memory_cost = MemoryCost(
         activation=compute_size_in_bytes([input_tensor, output_tensor]),
-        parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
-        if has_bias
-        else compute_size_in_bytes(weight_tensor),
+        parameter=(
+            compute_size_in_bytes([weight_tensor, bias_tensor]) if has_bias else compute_size_in_bytes(weight_tensor)
+        ),
         temp=0,
         buffer=0,
     )

     bwd_memory_cost = MemoryCost(
-        activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor])
-        if has_bias
-        else compute_size_in_bytes([input_tensor, weight_tensor]),
-        parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
-        if has_bias
-        else compute_size_in_bytes(weight_tensor),
+        activation=(
+            compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor])
+            if has_bias
+            else compute_size_in_bytes([input_tensor, weight_tensor])
+        ),
+        parameter=(
+            compute_size_in_bytes([weight_tensor, bias_tensor]) if has_bias else compute_size_in_bytes(weight_tensor)
+        ),
         temp=0,
         buffer=0,
     )

@@ -1,10 +1,18 @@
 from .gemini_plugin import GeminiPlugin
 from .hybrid_parallel_plugin import HybridParallelPlugin
 from .low_level_zero_plugin import LowLevelZeroPlugin
+from .moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from .plugin_base import Plugin
 from .torch_ddp_plugin import TorchDDPPlugin

-__all__ = ["Plugin", "TorchDDPPlugin", "GeminiPlugin", "LowLevelZeroPlugin", "HybridParallelPlugin"]
+__all__ = [
+    "Plugin",
+    "TorchDDPPlugin",
+    "GeminiPlugin",
+    "LowLevelZeroPlugin",
+    "HybridParallelPlugin",
+    "MoeHybridParallelPlugin",
+]

 import torch
 from packaging import version

@@ -655,7 +655,6 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         self.param_info = param_info
         self.stage_manager = model.stage_manager
         self.shared_params = model.shared_params
-        self.dp_pg = dp_process_group
         self.tp_pg = tp_process_group
         self.pp_pg = pp_process_group
         if use_pipeline:
@@ -718,7 +717,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
             """Retrieve all working gradients from different parameter groups."""
             all_working_grads = []
             for group_id in range(self.num_param_groups):
-                working_grads = self._grad_store.get_working_grads_by_group_id(group_id)
+                working_grads = self.get_working_grads_by_group_id(group_id)
                 all_working_grads.extend(working_grads)
             return all_working_grads

@@ -726,7 +725,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
             """Identify gradients to be synchronized in the sequence parallelism."""
             grads_to_sync = []
             for grad in all_working_grads:
-                param_id_for_grad = self._grad_store.get_param_id_for_grad(grad)
+                param_id_for_grad = self.get_param_id_for_grad(grad)
                 param_for_grad = ctypes.cast(param_id_for_grad, ctypes.py_object).value
                 if SeqParallelUtils.is_sp_partial_derived_param(param_for_grad):
                     grads_to_sync.append(grad)
@@ -739,7 +738,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         # Get all working gradients and gradients to be synchronized.
         all_working_grads = _get_all_working_grads()
         grads_to_sync = _get_grads_to_sync(all_working_grads)
-        if self._grad_store.require_grad_sync and grads_to_sync is not None:
+        if self.require_grad_sync and grads_to_sync is not None:
             # Synchronize sequence parallelism gradients if required.
             SeqParallelUtils.allreduce_partial_data_grad(process_group=self.tp_pg, grads=grads_to_sync)
         else:
@@ -763,7 +762,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         # Call the superclass backward method to compute gradients.
         super().backward(loss, retain_graph)

-        if self._grad_store.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
+        if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
             # If gradient synchronization is required, sync sequence parallelism gradients.
             self._sync_sp_grads()
         else:
@@ -788,14 +787,14 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         # Call the superclass backward_by_grad method to compute gradients.
         super().backward_by_grad(tensor, grad)

-        if self._grad_store.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
+        if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
             # If gradient synchronization is required, sync sequence parallelism gradients.
             self._sync_sp_grads()
         else:
             # If gradient synchronization is is not required, return.
             return

-    def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
+    def _compute_grad_norm(self, dp_pg, gradients: List[Tensor], norm_type: int = 2) -> float:
         r"""
         Compute and return the gradient norm for gradient clipping.

@@ -811,7 +810,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         if len(gradients) == 0:
             return 0.0

-        dp_size = get_world_size(self.dp_pg) if self.dp_pg is not None else 1
+        dp_size = get_world_size(dp_pg) if dp_pg is not None else 1
         tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
         pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
         norm_type = float(norm_type)
@@ -842,7 +841,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
             # However, we still perform the 'all_reduce' operation for the sake of good coding practices.
             # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.'
             if tp_size > 1:
-                param_id_for_grad = self._grad_store.get_param_id_for_grad(grad)
+                param_id_for_grad = self.get_param_id_for_grad(grad)
                 param_for_grad = ctypes.cast(param_id_for_grad, ctypes.py_object).value

                 if not is_distributed_tensor(param_for_grad):
@@ -856,7 +855,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
                 for shared_param in self.shared_params:
                     if self.stage_manager.stage in shared_param:
                         stage_shared_param = shared_param[self.stage_manager.stage]
-                        working_grad = self._grad_store.get_working_grad_by_param_id(id(stage_shared_param))
+                        working_grad = self.get_working_grad_by_param_id(id(stage_shared_param))
                         if grad is working_grad:
                             grad_norm_exponentiated /= len(shared_param)

@@ -867,7 +866,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         )
         if dp_size > 1:
             # compute norm in dp process group
-            dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.dp_pg)
+            dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=dp_pg)
         if tp_size > 1:
             # compute norm in tp process group
             dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
@@ -1206,6 +1205,7 @@ class HybridParallelPlugin(PipelinePluginBase):
             and self.enable_sequence_parallelism
             and self.sequence_parallelism_mode == "all_to_all"
         )
+        # sync gradients across DP * SP ranks
         if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
             dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
         else:
@@ -1309,7 +1309,7 @@ class HybridParallelPlugin(PipelinePluginBase):

         # run with gradients accumulation
         if model.require_grad_sync == False or (
-            isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer._grad_store.require_grad_sync == False
+            isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
         ):
             return outputs

@@ -1,4 +1,5 @@
 import random
+import warnings
 from types import MethodType
 from typing import Callable, Optional, OrderedDict, Tuple

@@ -20,19 +21,19 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
     get_param_info,
     init_pipeline_optimizer,
 )
+from colossalai.checkpoint_io import MoECheckpointIO
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.moe import MOE_MANAGER, MoECheckpointIO
+from colossalai.logging import get_dist_logger
 from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig
 from colossalai.shardformer.policies.base_policy import Policy
+from colossalai.tensor.moe_tensor.api import is_moe_tensor
 from colossalai.zero.low_level import LowLevelZeroOptimizer

-PP_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2
-


-class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
+class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
     def __init__(
         self,
         optimizer: Optimizer,
@@ -67,8 +68,20 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         self.pp_pg = pp_process_group
         if use_pipeline:
             init_pipeline_optimizer(optimizer, model)
+
+        pg_param_list = {
+            dp_process_group: [],
+            moe_extra_dp_process_group: [],
+        }
+        for param in model.parameters():
+            if is_moe_tensor(param):
+                pg_param_list[moe_extra_dp_process_group].append(param)
+            else:
+                pg_param_list[dp_process_group].append(param)
+
         super().__init__(
             optimizer=optimizer,
+            pg_to_param_list=pg_param_list,
             initial_scale=initial_scale,
             min_scale=min_scale,
             growth_factor=growth_factor,
@@ -83,9 +96,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
             overlap_communication=overlap_communication,
             partition_grad=partition_grad,
             cpu_offload=cpu_offload,
-            dp_process_group=dp_process_group,
             forced_dtype=forced_dtype,
-            moe_extra_dp_process_group=moe_extra_dp_process_group,
         )


@@ -107,8 +118,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         >>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)

     Args:
-        tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
         pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
+        tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
         precision (str, optional): Specifies the precision of parameters during training.
             Auto-mixied precision will be used when this argument is set to 'fp16' or 'bf16', otherwise model is trained with 'fp32'.
             Defaults to 'fp16'.
@@ -144,14 +155,15 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False.
         communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
         overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
+        use_ep_inside (bool, Optional): Whether to use ep inside dp (intra-node) for moe params.
     """

     def __init__(
         self,
-        tp_size: int,
         pp_size: int,
         ep_size: int,
-        extra_dp_size: int = 1,
+        tp_size: int = 1,
+        sp_size: int = 1,
         precision: str = "fp16",
         zero_stage: int = 0,
         enable_all_optimization: bool = False,
@@ -184,32 +196,22 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         custom_policy: Policy = None,
         checkpoint_io: Optional[MoECheckpointIO] = None,
     ) -> None:
-        assert (
-            dist.get_world_size() % (tp_size * pp_size) == 0
-        ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
+        world_size = dist.get_world_size()
+        assert tp_size == 1, "Tensor parallel is not supported in MoE yet"
+        assert sp_size == 1 and enable_sequence_parallelism is False, "Sequence parallelism it not supported in MoE yet"

-        if enable_sequence_parallelism:
-            assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism"
         assert (
-            dist.get_world_size() % (tp_size * pp_size) == 0
-        ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
+            world_size % (tp_size * pp_size) == 0
+        ), f"world size {world_size} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
         assert (
-            dist.get_world_size() % (tp_size * pp_size * ep_size) == 0
-        ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size} * ep_size {ep_size}"
-        self.real_dp_size = dist.get_world_size() // (tp_size * pp_size * ep_size)
-        MOE_MANAGER.setup(
-            parallel="EP",
-            mode="fixed",
-            fixed_dp_size=self.real_dp_size,
-            fixed_ep_size=ep_size,
-            fixed_pp_size=pp_size,
-            use_ep_inside=use_ep_inside,
-        )
+            world_size % (tp_size * pp_size * ep_size) == 0
+        ), f"world size {world_size} is not divisible by tp_size {tp_size} * pp_size {pp_size} * ep_size {ep_size}"
+
+        self.dp_size = world_size // (tp_size * pp_size)
         self.tp_size = tp_size
         self.pp_size = pp_size
-        self.dp_size = dist.get_world_size() // (tp_size * pp_size)
         self.ep_size = ep_size
-        self.moe_info = MOE_MANAGER.get_info(0)[1]
+        self.sp_size = sp_size
         self.precision = precision
         self.zero_stage = zero_stage
         self.cpu_offload = cpu_offload
@@ -219,43 +221,57 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         self.enable_jit_fused = enable_jit_fused
         self.enable_sequence_parallelism = enable_sequence_parallelism
         self.checkpoint_io = checkpoint_io
+
+        logger = get_dist_logger()
+
+        # NOTE: Two process meshes: global dp for non-moe param; dp + ep for moe param
+        # See https://hpc-ai.com/blog/enhanced-moe-parallelism-open-source-moe-model-training-can-be-9-times-more-efficient
-        # we change pg mesh to (pp, dp, tp) for better moe performance
-        self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size)
         assert (
             self.ep_size <= self.dp_size
         ), f"Not enough devices({self.dp_size}) for expert parallelism size({self.ep_size})."
+
-        # sync moe in outer dp group, and sync other param in global dp group
-        if extra_dp_size > 1:
-            ep_size = self.dp_size // extra_dp_size
-            if use_ep_inside:
-                self.pg_mesh_moe = ProcessGroupMesh(self.pp_size, extra_dp_size, ep_size)
-                self.moe_extra_dp_group = self.pg_mesh_moe.get_group_along_axis(1)
-                if dist.get_rank() == 0:
-                    print(f"Zero Parallel: pp {self.pp_size}, outer_dp {extra_dp_size}, inner_dp {ep_size}")
-            else:
-                self.pg_mesh_moe = ProcessGroupMesh(self.pp_size, ep_size, extra_dp_size)
-                self.moe_extra_dp_group = self.pg_mesh_moe.get_group_along_axis(2)
-                if dist.get_rank() == 0:
-                    print(f"Zero Parallel: pp {self.pp_size}, outer_dp {ep_size}, inner_dp {extra_dp_size}")
+        self.moe_dp_size = self.dp_size // self.ep_size
+        self.use_ep_inside = use_ep_inside
+        if self.use_ep_inside:
+            logger.info(f"MoE Parallel use ep inside dp.", ranks=[0])
+            self.pp_axis, self.dp_axis, self.ep_axis, self.tp_axis = 0, 1, 2, 3
+            self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, ep_size, tp_size)
         else:
-            self.moe_extra_dp_group = None
+            logger.info(f"MoE Parallel use ep outside dp.", ranks=[0])
+            warnings.warn("Using ep outside dp (cross-node) is strongly discouraged due to communication costs.")
+            self.pp_axis, self.dp_axis, self.ep_axis, self.tp_axis = 0, 2, 1, 3
+            self.pg_mesh = ProcessGroupMesh(self.pp_size, ep_size, self.moe_dp_size, tp_size)
+
+        self.moe_dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
+        self.ep_group = self.pg_mesh.get_group_along_axis(self.ep_axis)
+        logger.info(f"Non-MoE Parameter Parallel: pp {self.pp_size}, dp {self.dp_size}, tp {tp_size}", ranks=[0])
+        logger.info(
+            f"MoE Parallel: pp {self.pp_size}, ep {ep_size}, moe dp {self.moe_dp_size}, tp {tp_size}", ranks=[0]
+        )
+
+        self.tp_group = self.pg_mesh.get_group_along_axis(
+            self.tp_axis
+        )  # TODO: support custom tp size for mixtral lm head
+        self.global_dp_group = self.pg_mesh.get_group_along_axis((self.dp_axis, self.ep_axis))
+        self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis)
+        # TODO: Currently moe only support partially sequence parallel
+        self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
+
+        self.custom_policy = custom_policy
         self.stage_manager = None
         self.schedule = None
         self.custom_policy = custom_policy
+
         assert zero_stage in (0, 1, 2)
         if self.pp_size > 1:
             assert (
                 num_microbatches is not None or microbatch_size is not None
             ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
             assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
-            self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
+            self.stage_manager = PipelineStageManager(self.pg_mesh, self.pp_axis)
             self.schedule = OneForwardOneBackwardSchedule(
                 self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size
             )
-        self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
-        self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
-        self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
-        # TODO: Currently moe only support partially sequence parallel
-        self.sp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
+
         self.shard_config = ShardConfig(
             tensor_parallel_process_group=self.tp_group,
@@ -267,6 +283,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             enable_jit_fused=self.enable_jit_fused,
             enable_sequence_parallelism=enable_sequence_parallelism,
             enable_sequence_overlap=enable_sequence_overlap,
+            ep_group=self.ep_group,
         )
         self.amp_config = dict(
             initial_scale=initial_scale,
@@ -323,7 +340,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         """
         _kwargs = kwargs.copy()
         sampler = DistributedSampler(
-            dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle
+            dataset,
+            num_replicas=self.dp_size,
+            rank=dist.get_rank(self.global_dp_group),
+            shuffle=shuffle,
         )

         # Deterministic dataloader
@@ -346,9 +366,20 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):

     def get_checkpoint_io(self) -> MoECheckpointIO:
         if self.checkpoint_io is None:
-            self.checkpoint_io = MoECheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
+            self.checkpoint_io = MoECheckpointIO(
+                self.global_dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage
+            )
         else:
-            self.checkpoint_io = self.checkpoint_io(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
+            self.checkpoint_io = self.checkpoint_io(
+                self.global_dp_group,
+                self.pp_group,
+                self.tp_group,
+                ep_group=self.ep_group,
+                moe_dp_group=self.moe_dp_group,
+                zero_stage=self.zero_stage,
+            )
+        if hasattr(self.checkpoint_io, "moe_info"):
+            self.checkpoint_io.moe_info = self.moe_info
         return self.checkpoint_io

     def configure(
@@ -366,7 +397,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             module=model,
             precision=self.precision,
             shard_config=self.shard_config,
-            dp_group=self.dp_group,
+            dp_group=self.global_dp_group,
             tp_group=self.tp_group,
             sp_group=self.sp_group,
             use_ddp=use_ddp,
@@ -392,15 +423,15 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         else:
             assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
             assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
-            optimizer = HybridParallelZeroOptimizer(
+            optimizer = MoeHybridParallelZeroOptimizer(
                 optimizer,
                 model,
                 use_pipeline=self.enable_pipeline_parallelism,
                 param_info=param_info,
-                dp_process_group=self.dp_group,
+                dp_process_group=self.global_dp_group,
                 tp_process_group=self.tp_group,
                 pp_process_group=self.pp_group,
-                moe_extra_dp_process_group=self.moe_extra_dp_group,
+                moe_extra_dp_process_group=self.moe_dp_group,
                 verbose=True,
                 clip_grad_norm=self.max_norm,
                 **self.zero_config,

@@ -2,5 +2,12 @@ from .checkpoint_io_base import CheckpointIO
 from .general_checkpoint_io import GeneralCheckpointIO
 from .hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
 from .index_file import CheckpointIndexFile
+from .moe_checkpoint import MoECheckpointIO

-__all__ = ["CheckpointIO", "CheckpointIndexFile", "GeneralCheckpointIO", "HybridParallelCheckpointIO"]
+__all__ = [
+    "CheckpointIO",
+    "CheckpointIndexFile",
+    "GeneralCheckpointIO",
+    "HybridParallelCheckpointIO",
+    "MoECheckpointIO",
+]

@@ -70,13 +70,13 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         verbose: bool = True,
     ) -> None:
         super().__init__()
-        self.dp_group = dp_group
+        self.global_dp_group = dp_group
         self.pp_group = pp_group
         self.tp_group = tp_group
-        self.dp_rank = dist.get_rank(self.dp_group)
+        self.dp_rank = dist.get_rank(self.global_dp_group)
         self.tp_rank = dist.get_rank(self.tp_group)
         self.pp_rank = dist.get_rank(self.pp_group)
-        self.dp_size = dist.get_world_size(dp_group)
+        self.global_dp_size = dist.get_world_size(dp_group)
         self.pp_size = dist.get_world_size(pp_group)
         self.tp_size = dist.get_world_size(tp_group)
         self.use_zero = zero_stage > 0
@@ -433,7 +433,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         state_dict_shard = HybridParallelCheckpointIO._optimizer_sharder(
             optimizer,
             use_zero=self.use_zero,
-            dp_group=self.dp_group,
+            dp_group=self.global_dp_group,
             tp_group=self.tp_group,
             size_per_shard=size_per_shard,
         )
@@ -727,7 +727,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             state,
             working_param,
             original_shape=original_shape,
-            dp_group=self.dp_group,
+            dp_group=self.global_dp_group,
             tp_group=self.tp_group,
             use_zero=self.use_zero,
             inplace=False,
@@ -932,12 +932,12 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):

                 # Shard state along data parallel group when using Zero.
                 if self.use_zero:
-                    padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size
+                    padding_size = (self.global_dp_size - v.numel() % self.global_dp_size) % self.global_dp_size
                     with torch.no_grad():
                         v = v.flatten()
                         if padding_size > 0:
                             v = torch.nn.functional.pad(v, [0, padding_size])
-                        slice_size = v.numel() // self.dp_size
+                        slice_size = v.numel() // self.global_dp_size
                         v = v.split(slice_size, dim=0)[self.dp_rank]

                 state_[k] = v.detach().clone().to(device)

@@ -9,6 +9,7 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed import ProcessGroup
+from torch.distributed.distributed_c10d import get_global_rank

 from colossalai.checkpoint_io import CheckpointIndexFile
 from colossalai.checkpoint_io.hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
@@ -19,15 +20,16 @@ from colossalai.checkpoint_io.utils import (
     get_model_base_filenames,
     get_optimizer_base_filenames,
     load_shard_state_dict,
+    load_state_dict,
     load_states_into_optimizer,
     save_config_file,
     save_param_groups,
+    save_state_dict,
     save_state_dict_shards,
     search_tp_partition_dim,
     sharded_optimizer_loading_epilogue,
 )
 from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.moe import MOE_MANAGER
 from colossalai.tensor.moe_tensor.api import is_moe_tensor

 try:
@@ -36,21 +38,30 @@ except ImportError:
     _EXTRA_STATE_KEY_SUFFIX = "_extra_state"


-class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
+class MoECheckpointIO(HybridParallelCheckpointIO):
     def __init__(
         self,
-        dp_group: ProcessGroup,
+        global_dp_group: ProcessGroup,
         pp_group: ProcessGroup,
         tp_group: ProcessGroup,
+        ep_group: ProcessGroup,
+        moe_dp_group: ProcessGroup,
         zero_stage: int,
         verbose: bool = True,
     ) -> None:
-        super().__init__(dp_group, pp_group, tp_group, zero_stage, verbose)
-        moe_info = MOE_MANAGER.parallel_info_dict[MOE_MANAGER.ep_size]
-        self.ep_group = moe_info.ep_group
-        self.ep_size = moe_info.ep_size
-        self.ep_rank = moe_info.ep_rank
-        self.real_dp_rank = moe_info.dp_rank
+        super().__init__(global_dp_group, pp_group, tp_group, zero_stage, verbose)
+        self.global_dp_group = global_dp_group
+        self.global_dp_rank = dist.get_rank(global_dp_group)
+        self.global_dp_size = dist.get_world_size(global_dp_group)
+        self.pp_group = pp_group
+        self.tp_group = tp_group
+
+        self.moe_dp_group = moe_dp_group
+        self.moe_dp_size = dist.get_world_size(moe_dp_group)
+        self.moe_dp_rank = dist.get_rank(moe_dp_group)
+        self.ep_group = ep_group
+        self.ep_size = dist.get_world_size(ep_group)
+        self.ep_rank = dist.get_rank(ep_group)

     @staticmethod
     def _model_sharder(
@@ -134,7 +145,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):

         Path(checkpoint).mkdir(parents=True, exist_ok=True)

-        if self.real_dp_rank != 0:
+        if self.moe_dp_rank != 0:
             dist.barrier()
             return

@@ -144,7 +155,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):

         # Then collect the sharded parameters & buffers along tp_group.
         # Only devices with tp_rank == 0 are responsible for model saving.
-        state_dict_shard = MixtralMoEHybridParallelCheckpointIO._model_sharder(
+        state_dict_shard = MoECheckpointIO._model_sharder(
             model, size_per_shard=size_per_shard, param_name_pattern=ep_param_pattern
         )
         weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
@@ -234,11 +245,12 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
         state: OrderedDict,
         param: torch.Tensor,
         original_shape: torch.Size,
-        dp_group: ProcessGroup,
+        global_dp_group: ProcessGroup,
         tp_group: ProcessGroup,
         use_zero: bool,
         inplace: bool,
         is_moe_param: bool,
+        moe_dp_group: ProcessGroup = None,
         device: torch.device = torch.device("cpu"),
     ) -> OrderedDict:
         """
@@ -248,7 +260,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
             state (OrderedDict): Optimizer states of given parameter, might be distributed among tp/dp group if using TP/Zero.
             param (torch.Tensor): The given parameter. It should be working_param when using Zero.
             original_shape (torch.Size): The size of parameter before sharding.
-            dp_group (ProcessGroup): The process group of data parallel.
+            global_dp_group (ProcessGroup): The process group of data parallel.
             tp_group (ProcessGroup): The process group of tensor parallel.
             use_zero (bool): Whether Zero is used.
             inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
@ -257,27 +269,47 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
|
|||
Returns:
|
||||
OrderedDict: The complete optimizer state of given parameter.
|
||||
"""
|
||||
dp_size = dist.get_world_size(dp_group)
|
||||
global_dp_size = dist.get_world_size(global_dp_group)
|
||||
tp_size = dist.get_world_size(tp_group)
|
||||
moe_dp_size = dist.get_world_size(moe_dp_group) if moe_dp_group is not None else 1
|
||||
current_shape = param.shape
|
||||
state_ = state if inplace else copy.deepcopy(state)
|
||||
|
||||
for k, v in state_.items():
|
||||
if isinstance(v, torch.Tensor) and k != "step":
|
||||
v = v.cuda()
|
||||
|
||||
# First gather Zero shards.
|
||||
if use_zero and not is_moe_param:
|
||||
v = v.cuda()
|
||||
gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
|
||||
dist.all_gather(gather_tensor, v, group=dp_group)
|
||||
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
|
||||
if use_zero and is_moe_param and moe_dp_size > 1:
|
||||
moe_dp_rank = dist.get_rank(moe_dp_group)
|
||||
dst = get_global_rank(moe_dp_group, 0)
|
||||
if moe_dp_rank == 0:
|
||||
gather_tensor = [torch.zeros_like(v) for _ in range(moe_dp_size)]
|
||||
dist.gather(v, gather_tensor, group=moe_dp_group, dst=dst)
|
||||
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
|
||||
else:
|
||||
dist.gather(v, group=moe_dp_group, dst=dst)
|
||||
|
||||
elif use_zero and not is_moe_param and global_dp_size > 1:
|
||||
dp_rank = dist.get_rank(global_dp_group)
|
||||
dst = get_global_rank(global_dp_group, 0)
|
||||
if dp_rank == 0:
|
||||
gather_tensor = [torch.zeros_like(v) for _ in range(global_dp_size)]
|
||||
dist.gather(v, gather_tensor, group=global_dp_group, dst=dst)
|
||||
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
|
||||
else:
|
||||
dist.gather(v, group=global_dp_group, dst=dst)
|
||||
|
||||
# Then gather TP shards.
|
||||
partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size)
|
||||
if partition_dim is not None:
|
||||
gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)]
|
||||
dist.all_gather(gather_tensor, v, group=tp_group)
|
||||
v = torch.cat(gather_tensor, dim=partition_dim)
|
||||
|
||||
tp_rank = dist.get_rank(tp_group)
|
||||
dst = get_global_rank(tp_group, 0)
|
||||
if tp_rank == 0:
|
||||
gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)]
|
||||
dist.gather(v, gather_tensor, group=tp_group, dst=dst)
|
||||
v = torch.cat(gather_tensor, dim=partition_dim)
|
||||
else:
|
||||
dist.gather(v, group=tp_group, dst=dst)
|
||||
state_[k] = v.detach().clone().to(device)
|
||||
|
||||
return state_
|
||||
|
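The ZeRO branches above gather shards only onto the destination rank: rank 0 of the group materializes the full tensor, while every other rank just feeds its shard into `dist.gather`. A minimal standalone sketch of that pattern, assuming equally sized flat shards and PyTorch's `dist.get_global_rank` (an illustrative helper, not part of MoECheckpointIO):

```python
import torch
import torch.distributed as dist


def gather_zero_shard(shard: torch.Tensor, full_shape: torch.Size, group) -> torch.Tensor:
    """Gather flat ZeRO shards of one optimizer state onto rank 0 of `group` (sketch)."""
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)
    dst = dist.get_global_rank(group, 0)  # dist.gather expects a global rank
    if rank == 0:
        buckets = [torch.zeros_like(shard) for _ in range(world_size)]
        dist.gather(shard, buckets, group=group, dst=dst)
        # Drop the padding added during sharding and restore the original shape.
        return torch.cat(buckets).view(-1)[: full_shape.numel()].reshape(full_shape)
    dist.gather(shard, group=group, dst=dst)
    return shard  # non-destination ranks keep only their local shard
```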
@ -286,8 +318,9 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
|
|||
def _optimizer_sharder(
|
||||
optimizer: OptimizerWrapper,
|
||||
use_zero: bool,
|
||||
dp_group: ProcessGroup,
|
||||
global_dp_group: ProcessGroup,
|
||||
tp_group: ProcessGroup,
|
||||
moe_dp_group: ProcessGroup,
|
||||
size_per_shard: int = 1024,
|
||||
only_moe_param: bool = False,
|
||||
):
|
||||
|
@ -296,7 +329,6 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
|
|||
state_dict_sharder = StateDictSharder(size_per_shard)
|
||||
param_info = optimizer.param_info
|
||||
master_to_working_map = optimizer.get_master_to_working_map()
|
||||
|
||||
for param, state in optimizer.optim.state.items():
|
||||
if param is None:
|
||||
continue
|
||||
|
@ -305,22 +337,23 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
|
|||
working_param = master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
|
||||
param_id = param_info["param2id"][id(working_param)]
|
||||
original_shape = param_info["param2shape"][id(working_param)]
|
||||
state_ = MixtralMoEHybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
|
||||
state_ = MoECheckpointIO.gather_from_sharded_optimizer_state(
|
||||
state,
|
||||
working_param,
|
||||
original_shape=original_shape,
|
||||
dp_group=dp_group,
|
||||
global_dp_group=global_dp_group,
|
||||
moe_dp_group=moe_dp_group,
|
||||
tp_group=tp_group,
|
||||
use_zero=use_zero,
|
||||
inplace=False,
|
||||
is_moe_param=is_moe_tensor(working_param),
|
||||
is_moe_param=is_moe_tensor(working_param), # TODO: Check correctness here
|
||||
)
|
||||
|
||||
if only_moe_param and not is_moe_tensor(working_param):
|
||||
continue
|
||||
|
||||
block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
|
||||
if block is not None:
|
||||
yield block, block_size
|
||||
|
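`_optimizer_sharder` is a generator: it appends each parameter's gathered state to the current shard and yields a `(block, block_size)` pair whenever the size budget is exceeded, which `save_state_dict_shards` later consumes. A stripped-down sketch of that accumulation pattern (a hypothetical helper, not ColossalAI's `StateDictSharder`):

```python
from typing import Dict, Iterator, Tuple

import torch


def shard_states(states: Dict[int, Dict[str, torch.Tensor]],
                 size_per_shard_mb: int = 1024) -> Iterator[Tuple[dict, int]]:
    """Yield (shard, shard_size_in_bytes) whenever the accumulated size exceeds the budget."""
    budget = size_per_shard_mb * 1024 * 1024
    current, current_size = {}, 0
    for param_id, state in states.items():
        state_size = sum(v.numel() * v.element_size()
                         for v in state.values() if isinstance(v, torch.Tensor))
        if current and current_size + state_size > budget:
            yield current, current_size
            current, current_size = {}, 0
        current[param_id] = state
        current_size += state_size
    if current:  # flush the last, possibly smaller, shard
        yield current, current_size
```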
@ -359,25 +392,28 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
|
|||
|
||||
Path(checkpoint).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Devices along the same dp_group share the same copies of states when zero is not used.
|
||||
# In this case only let the device with dp_rank == 0 save the model.
|
||||
if not self.use_zero and self.real_dp_rank != 0:
|
||||
# If optim states are not sharded, other ranks don't need to participate in gather.
|
||||
if not self.use_zero and self.moe_dp_rank != 0:
|
||||
dist.barrier()
|
||||
return
|
||||
|
||||
# Then collect the sharded states along dp_group(if using zero)/tp_group.
|
||||
# Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
|
||||
state_dict_shard = MixtralMoEHybridParallelCheckpointIO._optimizer_sharder(
|
||||
state_dict_shard = MoECheckpointIO._optimizer_sharder(
|
||||
optimizer,
|
||||
use_zero=self.use_zero,
|
||||
dp_group=self.dp_group,
|
||||
global_dp_group=self.global_dp_group,
|
||||
tp_group=self.tp_group,
|
||||
moe_dp_group=self.moe_dp_group,
|
||||
size_per_shard=size_per_shard,
|
||||
only_moe_param=self.ep_rank != 0,
|
||||
)
|
||||
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
|
||||
index_file = CheckpointIndexFile(checkpoint)
|
||||
control_saving = self.real_dp_rank == 0 and self.tp_rank == 0
|
||||
# e.g. dp_size = 4, moe_dp_size = 2, ep_size = 2 and use gather
|
||||
# rank 0 saves moe & non-moe params; rank 1 only saves moe params
|
||||
# ranks 2 & 3 save nothing
|
||||
control_saving = self.tp_rank == 0 and self.moe_dp_rank == 0
|
||||
|
||||
if self.pp_size == 1 and self.ep_size == 1:
|
||||
# When pipeline is not used, save the optimizer shards as in general checkpointIO
|
||||
|
@ -596,7 +632,6 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
|
|||
OrderedDict: The sharded optimizer state of the given parameter.
|
||||
"""
|
||||
state_ = state if inplace else copy.deepcopy(state)
|
||||
|
||||
for k, v in state_.items():
|
||||
if isinstance(v, torch.Tensor) and k != "step":
|
||||
# Shard state along tensor parallel group.
|
||||
|
@ -606,24 +641,218 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
|
|||
v = v.split(slice_size, dim=partition_dim)[self.tp_rank]
|
||||
|
||||
# Shard state along data parallel group when using Zero.
|
||||
if self.use_zero and not is_moe_param:
|
||||
padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size
|
||||
if self.use_zero and not is_moe_param and self.global_dp_size > 1:
|
||||
padding_size = (self.global_dp_size - v.numel() % self.global_dp_size) % self.global_dp_size
|
||||
with torch.no_grad():
|
||||
v = v.flatten()
|
||||
if padding_size > 0:
|
||||
v = torch.nn.functional.pad(v, [0, padding_size])
|
||||
slice_size = v.numel() // self.dp_size
|
||||
v = v.split(slice_size, dim=0)[self.dp_rank]
|
||||
slice_size = v.numel() // self.global_dp_size
|
||||
v = v.split(slice_size, dim=0)[self.global_dp_rank]
|
||||
|
||||
elif self.use_zero and is_moe_param and self.moe_dp_size > 1:
|
||||
# LowLevelZeRO pads by global dp size for now.
|
||||
# TODO: update both to use moe dp size
|
||||
padding_size = (self.global_dp_size - v.numel() % self.global_dp_size) % self.global_dp_size
|
||||
with torch.no_grad():
|
||||
v = v.flatten()
|
||||
if padding_size > 0:
|
||||
v = torch.nn.functional.pad(v, [0, padding_size])
|
||||
slice_size = v.numel() // self.moe_dp_size
|
||||
v = v.split(slice_size, dim=0)[self.moe_dp_rank]
|
||||
|
||||
state_[k] = v.detach().clone().to(device)
|
||||
|
||||
return state_
|
||||
|
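The save-side counterpart of the optimizer gather is this pad-and-slice step used in both ZeRO branches: flatten the state, pad it to a multiple of the data-parallel size, and keep only the local slice. A condensed sketch (an assumed helper, for illustration only):

```python
import torch
import torch.nn.functional as F


def shard_flat_state(v: torch.Tensor, dp_size: int, dp_rank: int) -> torch.Tensor:
    """Flatten, pad to a multiple of dp_size, and return this rank's slice (sketch)."""
    v = v.flatten()
    padding_size = (dp_size - v.numel() % dp_size) % dp_size
    if padding_size > 0:
        v = F.pad(v, [0, padding_size])
    slice_size = v.numel() // dp_size
    return v.split(slice_size, dim=0)[dp_rank]
```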
||||
def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
|
||||
raise NotImplementedError
|
||||
"""Migration from MoEHybridParallelCheckpointIO. These functions mostly deals with unsharded saving,
|
||||
and can be safely deleted since large MoE models are often saved in shards.
|
||||
"""
|
||||
|
||||
# Copied from colossalai.moe
|
||||
def pre_save_model(self, model: nn.Module) -> dict:
|
||||
state_dict = model.state_dict()
|
||||
for name, param in model.named_parameters():
|
||||
if ".experts." in name and is_moe_tensor(param):
|
||||
ep_group = param.ep_group
|
||||
ep_rank = dist.get_rank(ep_group)
|
||||
ep_size = dist.get_world_size(ep_group)
|
||||
# TODO: check correctness here
|
||||
# dp_rank = get_dp_rank(param)
|
||||
dp_rank = dist.get_rank(self.global_dp_group)
|
||||
if dp_rank == 0:
|
||||
param = param.data.cuda()
|
||||
if ep_rank == 0:
|
||||
all_param = [torch.zeros_like(param) for _ in range(ep_size)]
|
||||
else:
|
||||
all_param = None
|
||||
# gather param from every ep rank
|
||||
# dist.all_gather(all_param, param, group=ep_group)
|
||||
dist.gather(param, all_param, group=ep_group)
|
||||
if ep_rank == 0:
|
||||
all_param = torch.cat(all_param, dim=0)
|
||||
state_dict[name] = all_param.cpu()
|
||||
|
||||
if self.pp_size > 1:
|
||||
if self.dp_rank == 0:
|
||||
out = [None for _ in range(self.pp_size)]
|
||||
dist.gather_object(state_dict, out, group=self.pp_group)
|
||||
if self.pp_rank == 0:
|
||||
new_state_dict = {}
|
||||
for o in out:
|
||||
new_state_dict.update(o)
|
||||
state_dict = new_state_dict
|
||||
dist.barrier()
|
||||
return state_dict
|
||||
|
||||
def save_unsharded_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
checkpoint: str,
|
||||
gather_dtensor: bool,
|
||||
use_safetensors: bool,
|
||||
):
|
||||
state_dict = self.pre_save_model(model)
|
||||
if dist.get_rank() == 0:
|
||||
torch.save(state_dict, checkpoint)
|
||||
dist.barrier()
|
||||
|
||||
# Copied from colossalai.moe
|
||||
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
|
||||
raise NotImplementedError
|
||||
"""
|
||||
Save optimizer state dict to a file with given path.
|
||||
|
||||
Args:
|
||||
optimizer (OptimizerWrapper): Optimizer to save sharded state_dict.
|
||||
checkpoint (str): Path to save optimizer state_dict.
|
||||
gather_dtensor (bool): Whether to gather_dtensor, not used.
|
||||
"""
|
||||
if self.coordinator.is_master():
|
||||
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
|
||||
|
||||
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
|
||||
|
||||
# optimizer states of parameters kept by the local device (i.e., its pipeline stage)
|
||||
local_states = dict()
|
||||
|
||||
for param, state in optimizer.optim.state.items():
|
||||
if param is None:
|
||||
continue
|
||||
|
||||
# working param is needed for obtaining correct param_id
|
||||
master_to_working_map = optimizer.get_master_to_working_map()
|
||||
if master_to_working_map is not None and id(param) in master_to_working_map:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
|
||||
# gather complete state from tp shards & dp shards
|
||||
param_id = optimizer.param_info["param2id"][id(working_param)]
|
||||
local_states[param_id] = self.pre_save_optim(
|
||||
state,
|
||||
working_param,
|
||||
inplace=False,
|
||||
device=torch.device("cuda"),
|
||||
)
|
||||
|
||||
if self.pp_size == 1:
|
||||
# When pipeline is not used, let master rank directly save the collected state_dict.
|
||||
state_dict = {"param_groups": optimizer.optim.param_groups, "state": local_states}
|
||||
if self.coordinator.is_master():
|
||||
save_state_dict(state_dict, checkpoint, use_safetensors=False)
|
||||
else:
|
||||
# When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
|
||||
states_list = [None for _ in range(self.pp_size)]
|
||||
dist.barrier(self.pp_group)
|
||||
# dist.all_gather_object(states_list, local_states, self.pp_group)
|
||||
dist.gather_object(local_states, states_list, self.pp_group)
|
||||
|
||||
# Only the master rank do the saving.
|
||||
if self.coordinator.is_master():
|
||||
state_dict = {"param_groups": optimizer.optim.param_groups, "state": dict()}
|
||||
for _states in states_list:
|
||||
state_dict["state"].update(_states)
|
||||
save_state_dict(state_dict, checkpoint, use_safetensors=False)
|
||||
dist.barrier()
|
||||
|
||||
# Copied from colossalai.moe
|
||||
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, strict: bool = False):
|
||||
raise NotImplementedError
|
||||
"""
|
||||
Load optimizer from a file with given path.
|
||||
|
||||
Args:
|
||||
optimizer (OptimizerWrapper): The optimizer to be loaded.
|
||||
checkpoint_index_file (str): Path to the checkpoint file.
|
||||
"""
|
||||
|
||||
def _get_param_id_from_optimizer_param(
|
||||
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
|
||||
):
|
||||
if master_to_working_map is not None and id(param) in master_to_working_map:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
if id(working_param) in optimizer.param_info["param2id"]:
|
||||
return optimizer.param_info["param2id"][id(working_param)]
|
||||
else:
|
||||
return None
|
||||
|
||||
if self.coordinator.is_master():
|
||||
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
|
||||
|
||||
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
|
||||
|
||||
# Complete optimizer state_dict loaded from checkpoint, need to be processed later.
|
||||
state_dict = load_state_dict(checkpoint)
|
||||
|
||||
# Load param_groups.
|
||||
updated_groups = []
|
||||
saved_groups = state_dict["param_groups"]
|
||||
for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
|
||||
new_pg = copy.deepcopy(saved_pg)
|
||||
new_pg["params"] = old_pg["params"] # Only keep the parameters kept by current pipeline stage.
|
||||
updated_groups.append(new_pg)
|
||||
|
||||
# ep extra group
|
||||
# if MOE_MANAGER.parallel == "EP":
|
||||
if self.ep_size > 1:
|
||||
new_pg = copy.deepcopy(saved_pg)
|
||||
new_pg["params"] = optimizer.optim.param_groups[-1][
|
||||
"params"
|
||||
] # Only keep the parameters kept by current pipeline stage.
|
||||
for param in new_pg["params"]:
|
||||
param.data = param.data.to(torch.float32)
|
||||
updated_groups.append(new_pg)
|
||||
optimizer.optim.__dict__.update({"param_groups": updated_groups})
|
||||
|
||||
# Load saved states to optimizer. First discard those states not belonging to current pipeline stage.
|
||||
master_to_working_map = optimizer.get_master_to_working_map()
|
||||
id_map = {}
|
||||
for pg in optimizer.optim.param_groups:
|
||||
for param in pg["params"]:
|
||||
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
|
||||
if param_id is not None:
|
||||
id_map[param_id] = param
|
||||
load_states_into_optimizer(optimizer.optim, state_dict["state"], id_map, strict=True)
|
||||
|
||||
# Then shard the loaded optimizer states if using tp/zero.
|
||||
for param, state in optimizer.optim.state.items():
|
||||
if param is None:
|
||||
continue
|
||||
device = param.device
|
||||
if master_to_working_map is not None and id(param) in master_to_working_map:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
original_shape = optimizer.param_info["param2shape"][id(working_param)]
|
||||
sharded_state = self.pre_load_optim(
|
||||
state,
|
||||
param,
|
||||
current_shape=working_param.shape,
|
||||
original_shape=original_shape,
|
||||
device=device,
|
||||
inplace=True,
|
||||
)
|
||||
optimizer.optim.state[param] = sharded_state
|
||||
sharded_optimizer_loading_epilogue(optimizer.optim)
|
||||
dist.barrier()
|
|
@ -242,6 +242,7 @@ def save_state_dict_shards(
|
|||
shard_filenames = []
|
||||
for idx, shard_pair in enumerate(sharded_state_dict):
|
||||
shard, current_size = shard_pair
|
||||
# Just loop over the sharder and gather to other ranks if not master
|
||||
if not is_master:
|
||||
del shard
|
||||
continue
|
||||
|
|
|
@ -147,7 +147,7 @@ class ProcessGroupMesh:
|
|||
ProcessGroup: The process group with the given ranks.
|
||||
"""
|
||||
ranks_in_group = sorted(ranks_in_group)
|
||||
if tuple(ranks_in_group) not in self._group_to_ranks:
|
||||
if tuple(ranks_in_group) not in self._ranks_to_group:
|
||||
group = dist.new_group(ranks_in_group, backend=backend)
|
||||
self._ranks_to_group[tuple(ranks_in_group)] = group
|
||||
self._group_to_ranks[group] = tuple(ranks_in_group)
|
||||
|
@ -244,19 +244,25 @@ class ProcessGroupMesh:
|
|||
return target_group
|
||||
|
||||
def get_group_along_axis(
|
||||
self, axis: int, indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
|
||||
self, axis: Union[int, List[int]], indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
|
||||
) -> ProcessGroup:
|
||||
"""Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created.
|
||||
|
||||
Args:
|
||||
axis (int): Axis along which the process groups are created.
|
||||
axis (int or list of int): Axes along which the process groups are created.
|
||||
indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None.
|
||||
backend (Optional[str], optional): Backend of the process group. Defaults to None.
|
||||
|
||||
Returns:
|
||||
ProcessGroup: The process group along the given axis which the current process belongs to.
|
||||
"""
|
||||
indices_at_axis = indices_at_axis or list(range(self._shape[axis]))
|
||||
indices_at_axis = indices_at_axis
|
||||
if indices_at_axis is None:
|
||||
if isinstance(axis, (list, tuple)):
|
||||
indices_at_axis = list(list(range(self._shape[ax])) for ax in axis)
|
||||
else:
|
||||
indices_at_axis = list(range(self._shape[axis]))
|
||||
|
||||
coords_in_group = ProcessGroupMesh.get_coords_along_axis(self._coord, axis, indices_at_axis)
|
||||
ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
|
||||
if ranks_in_group not in self._ranks_to_group:
|
||||
|
|
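With the widened signature, `get_group_along_axis` accepts either a single axis or a list of axes, so a caller can obtain one group spanning several mesh dimensions in a single call. A hedged usage sketch on a 2 x 2 (pp, tp) mesh, assuming the distributed environment has already been launched (e.g. via `colossalai.launch`):

```python
from colossalai.cluster import ProcessGroupMesh

PP_AXIS, TP_AXIS = 0, 1

# 4 processes arranged as a 2 x 2 mesh: axis 0 = pipeline, axis 1 = tensor parallel.
pg_mesh = ProcessGroupMesh(2, 2)

tp_group = pg_mesh.get_group_along_axis(TP_AXIS)  # group along a single axis
# With this change, a list of axes yields one group covering both dimensions.
pp_tp_group = pg_mesh.get_group_along_axis([PP_AXIS, TP_AXIS])
```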
|
@ -247,16 +247,16 @@ class BatchBucket:
|
|||
self._sequences_dict[seq.request_id] = seq
|
||||
self._sequences_indexes[seq.request_id] = self._current_batch_size + i
|
||||
# TODO external (rename): modify Sequence.sentence_len to seq_len
|
||||
self._sequence_lengths[
|
||||
self._current_batch_size : self._current_batch_size + num_seqs_to_add
|
||||
] = torch.tensor([seq.sentence_len for seq in seqs[:num_seqs_to_add]], dtype=torch.int32)
|
||||
self._sequence_lengths[self._current_batch_size : self._current_batch_size + num_seqs_to_add] = (
|
||||
torch.tensor([seq.sentence_len for seq in seqs[:num_seqs_to_add]], dtype=torch.int32)
|
||||
)
|
||||
# NOTE block tables to be updated by kvcache manager
|
||||
block_tables = self._block_tables[self._current_batch_size : self._current_batch_size + num_seqs_to_add]
|
||||
if alloc_block_tables is not None:
|
||||
# copy block ids from provided block tables
|
||||
self._block_tables[
|
||||
self._current_batch_size : self._current_batch_size + num_seqs_to_add
|
||||
] = alloc_block_tables
|
||||
self._block_tables[self._current_batch_size : self._current_batch_size + num_seqs_to_add] = (
|
||||
alloc_block_tables
|
||||
)
|
||||
elif alloc_block_tables_fn:
|
||||
alloc_block_tables_fn(
|
||||
block_tables,
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
"""
|
||||
Our config contains various options for inference optimization; it is a unified API that wraps all the configurations for inference.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers.generation import GenerationConfig
|
||||
|
@ -82,9 +83,9 @@ class InputMetaData(RPC_PARAM):
|
|||
dtype: torch.dtype = torch.float32
|
||||
use_spec_dec: bool = False
|
||||
num_tokens_to_verify: int = 0
|
||||
batch_token_ids: Optional[
|
||||
List[List[int]]
|
||||
] = None # for `repetition_penalty`, `no_repeat_ngram_size` in sampler process
|
||||
batch_token_ids: Optional[List[List[int]]] = (
|
||||
None # for `repetition_penalty`, `no_repeat_ngram_size` in sampler process
|
||||
)
|
||||
|
||||
def to_rpc_param(self) -> Dict[str, any]:
|
||||
return {
|
||||
|
@ -202,9 +203,9 @@ class InferenceConfig(RPC_PARAM):
|
|||
prompt_template: Optional[str] = None
|
||||
do_sample: bool = False
|
||||
beam_width: int = 1 # TODO: beam search is not support for now
|
||||
prefill_ratio: Optional[
|
||||
float
|
||||
] = 1.2 # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
|
||||
prefill_ratio: Optional[float] = (
|
||||
1.2 # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
|
||||
)
|
||||
pad_input: bool = False
|
||||
early_stopping: Optional[bool] = False
|
||||
top_k: Optional[int] = 50
|
||||
|
@ -234,7 +235,9 @@ class InferenceConfig(RPC_PARAM):
|
|||
high_precision: Optional[bool] = False
|
||||
|
||||
# cuda_graph
|
||||
use_cuda_graph: bool = False # NOTE only when we have the graph for specific decoding batch size can we use the cuda graph for inference
|
||||
use_cuda_graph: bool = (
|
||||
False # NOTE only when we have the graph for specific decoding batch size can we use the cuda graph for inference
|
||||
)
|
||||
max_context_len_to_capture: int = 512
|
||||
|
||||
# StreamingLLM (sliding window attention with attention sinks)
|
||||
|
@ -393,3 +396,49 @@ class ModelShardInferenceConfig:
|
|||
use_cuda_kernel: bool = False
|
||||
use_spec_dec: bool = False
|
||||
use_flash_attn: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class DiffusionGenerationConfig:
|
||||
"""
|
||||
Param for diffusion model forward
|
||||
"""
|
||||
|
||||
prompt_2: Optional[Union[str, List[str]]] = None
|
||||
prompt_3: Optional[Union[str, List[str]]] = None
|
||||
height: Optional[int] = None
|
||||
width: Optional[int] = None
|
||||
num_inference_steps: int = None
|
||||
timesteps: List[int] = None
|
||||
guidance_scale: float = None
|
||||
negative_prompt: Optional[Union[str, List[str]]] = (
|
||||
None # NOTE(@lry89757) in pixart default to "", in sd3 default to None
|
||||
)
|
||||
negative_prompt_2: Optional[Union[str, List[str]]] = None
|
||||
negative_prompt_3: Optional[Union[str, List[str]]] = None
|
||||
num_images_per_prompt: Optional[int] = None
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None
|
||||
latents: Optional[torch.FloatTensor] = None
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None
|
||||
output_type: Optional[str] = None # "pil"
|
||||
return_dict: bool = None
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None
|
||||
clip_skip: Optional[int] = None
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None
|
||||
callback_on_step_end_tensor_inputs: List[str] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
# NOTE(@lry89757) Only return fields that are not the default value None
|
||||
result = {}
|
||||
for field in fields(self):
|
||||
value = getattr(self, field.name)
|
||||
if value is not None:
|
||||
result[field.name] = value
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_kwargs(cls, **kwargs) -> "DiffusionGenerationConfig":
|
||||
return cls(**kwargs)
|
||||
|
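Because `to_dict` drops every field that is still `None`, only the options a caller explicitly sets are forwarded to the underlying diffusers pipeline. A small usage sketch (the values are illustrative):

```python
from colossalai.inference.config import DiffusionGenerationConfig

gen_config = DiffusionGenerationConfig.from_kwargs(
    height=768,
    width=768,
    num_inference_steps=30,
    guidance_scale=7.0,
)

# Only the four explicitly set fields survive; unset (None) defaults are filtered out.
pipeline_kwargs = gen_config.to_dict()
```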
|
|
@ -0,0 +1,90 @@
|
|||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.inference.config import ModelShardInferenceConfig
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
|
||||
|
||||
class BaseEngine(ABC):
|
||||
@abstractmethod
|
||||
def __init__(self, model_or_path, inference_config=None, verbose=False, model_policy=None):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def init_model(self, model_or_path, model_policy=None, model_shard_infer_config=None):
|
||||
"""
|
||||
Init Model for Engine
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def generate(self, request_ids=None, prompts=None, generation_config=None, **kwargs):
|
||||
"""
|
||||
Generate output for incoming requests
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def add_request(self, prompts, request_ids=None, **kwargs):
|
||||
"""
|
||||
Add new request to Engine
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def step(self):
|
||||
"""
|
||||
Perform one new step forward
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def _verify_args(self):
|
||||
"""
|
||||
Verify the parameters and members of class
|
||||
"""
|
||||
|
||||
@torch.inference_mode()
|
||||
def capture_model(self):
|
||||
"""
|
||||
Use cuda graph to capture model
|
||||
"""
|
||||
raise NotImplementedError("This method should be implemented by subclasses")
|
||||
|
||||
def _shardformer(
|
||||
self,
|
||||
model: nn.Module,
|
||||
model_policy: Policy,
|
||||
model_shard_infer_config: ModelShardInferenceConfig = None,
|
||||
stage_manager: PipelineStageManager = None,
|
||||
tp_group: ProcessGroupMesh = None,
|
||||
**kwargs,
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Initialize ShardConfig and replace the model with shardformer.
|
||||
|
||||
Args:
|
||||
model (nn.Module): Path or nn.Module of this model.
|
||||
model_policy (Policy): The policy to shardformer model which is determined by the model type.
|
||||
stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
|
||||
tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None.
|
||||
|
||||
Returns:
|
||||
nn.Module: The model optimized by Shardformer.
|
||||
"""
|
||||
|
||||
shardconfig = ShardConfig(
|
||||
tensor_parallel_process_group=tp_group,
|
||||
pipeline_stage_manager=stage_manager,
|
||||
enable_tensor_parallelism=(self.inference_config.tp_size > 1),
|
||||
enable_fused_normalization=False,
|
||||
enable_all_optimization=False,
|
||||
enable_flash_attention=False,
|
||||
enable_jit_fused=False,
|
||||
enable_sequence_parallelism=False,
|
||||
extra_kwargs={"model_shard_infer_config": model_shard_infer_config, **kwargs},
|
||||
)
|
||||
shardformer = ShardFormer(shard_config=shardconfig)
|
||||
shard_model, _ = shardformer.optimize(model, model_policy)
|
||||
return shard_model
|
|
@ -0,0 +1,200 @@
|
|||
from itertools import count
|
||||
from typing import List, Tuple, Type, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from torch import distributed as dist
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig, ModelShardInferenceConfig
|
||||
from colossalai.inference.modeling.models.diffusion import DiffusionPipe
|
||||
from colossalai.inference.modeling.policy import model_policy_map
|
||||
from colossalai.inference.struct import DiffusionSequence
|
||||
from colossalai.inference.utils import get_model_size, get_model_type
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
|
||||
from .base_engine import BaseEngine
|
||||
from .request_handler import NaiveRequestHandler
|
||||
|
||||
PP_AXIS, TP_AXIS = 0, 1
|
||||
|
||||
|
||||
class DiffusionEngine(BaseEngine):
|
||||
def __init__(
|
||||
self,
|
||||
model_or_path: DiffusionPipeline | str,
|
||||
inference_config: InferenceConfig = None,
|
||||
verbose: bool = False,
|
||||
model_policy: Policy | type[Policy] = None,
|
||||
) -> None:
|
||||
self.inference_config = inference_config
|
||||
self.dtype = inference_config.dtype
|
||||
self.high_precision = inference_config.high_precision
|
||||
|
||||
self.verbose = verbose
|
||||
self.logger = get_dist_logger(__name__)
|
||||
self.model_shard_infer_config = inference_config.to_model_shard_inference_config()
|
||||
|
||||
self.model_type = get_model_type(model_or_path=model_or_path)
|
||||
|
||||
self.init_model(model_or_path, model_policy, self.model_shard_infer_config)
|
||||
|
||||
self.request_handler = NaiveRequestHandler()
|
||||
|
||||
self.counter = count()
|
||||
|
||||
self._verify_args()
|
||||
|
||||
def _verify_args(self) -> None:
|
||||
assert isinstance(self.model, DiffusionPipe), "model must be DiffusionPipe"
|
||||
|
||||
def init_model(
|
||||
self,
|
||||
model_or_path: Union[str, nn.Module, DiffusionPipeline],
|
||||
model_policy: Union[Policy, Type[Policy]] = None,
|
||||
model_shard_infer_config: ModelShardInferenceConfig = None,
|
||||
):
|
||||
"""
|
||||
Shard the model and/or load weights
|
||||
|
||||
Args:
|
||||
model_or_path Union[nn.Module, DiffusionPipeline, str]: path to the checkpoint or a model in transformers/diffusers format.
|
||||
model_policy (Policy): the policy to replace the model.
|
||||
model_inference_config: the configuration for model initialization during inference.
|
||||
model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference.
|
||||
"""
|
||||
if isinstance(model_or_path, str):
|
||||
model = DiffusionPipeline.from_pretrained(model_or_path, torch_dtype=self.dtype)
|
||||
policy_map_key = model.__class__.__name__
|
||||
model = DiffusionPipe(model)
|
||||
elif isinstance(model_or_path, DiffusionPipeline):
|
||||
policy_map_key = model_or_path.__class__.__name__
|
||||
model = DiffusionPipe(model_or_path)
|
||||
else:
|
||||
self.logger.error(f"model_or_path support only str or DiffusionPipeline currently!")
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
init_gpu_memory = torch.cuda.mem_get_info()[0]
|
||||
|
||||
self.device = get_accelerator().get_current_device()
|
||||
if self.verbose:
|
||||
self.logger.info(f"the device is {self.device}")
|
||||
|
||||
if self.verbose:
|
||||
self.logger.info(
|
||||
f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
|
||||
)
|
||||
|
||||
if model_policy is None:
|
||||
model_policy = model_policy_map.get(policy_map_key)
|
||||
|
||||
if not isinstance(model_policy, Policy):
|
||||
try:
|
||||
model_policy = model_policy()
|
||||
except Exception as e:
|
||||
raise ValueError(f"Unable to instantiate model policy: {e}")
|
||||
|
||||
assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
|
||||
pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
|
||||
tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
|
||||
|
||||
self.model = self._shardformer(
|
||||
model,
|
||||
model_policy,
|
||||
model_shard_infer_config,
|
||||
None,
|
||||
tp_group=tp_group,
|
||||
)
|
||||
|
||||
self.model = model.to(self.device)
|
||||
|
||||
if self.verbose:
|
||||
self.logger.info(
|
||||
f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
|
||||
)
|
||||
|
||||
free_gpu_memory, _ = torch.cuda.mem_get_info()
|
||||
peak_memory = init_gpu_memory - free_gpu_memory
|
||||
if self.verbose:
|
||||
self.logger.info(
|
||||
f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB"
|
||||
)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
request_ids: Union[List[int], int] = None,
|
||||
prompts: Union[List[str], str] = None,
|
||||
generation_config: DiffusionGenerationConfig = None,
|
||||
**kwargs,
|
||||
) -> Union[List[Union[str, List[PIL.Image.Image], np.ndarray]], Tuple[List[str], List[List[int]]]]:
|
||||
""" """
|
||||
gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
|
||||
prompts = [prompts] if isinstance(prompts, str) else prompts
|
||||
request_ids = [request_ids] if isinstance(request_ids, int) else request_ids
|
||||
|
||||
with torch.inference_mode():
|
||||
if prompts is not None:
|
||||
self.add_request(
|
||||
request_ids=request_ids,
|
||||
prompts=prompts,
|
||||
**gen_config_dict,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
output_reqs_list = []
|
||||
|
||||
# intuition: If user provide a generation config, we should replace the existing one.
|
||||
if generation_config is not None:
|
||||
self.generation_config = generation_config
|
||||
self.generation_config_dict = gen_config_dict
|
||||
|
||||
while self.request_handler.check_unfinished_reqs():
|
||||
output_reqs_list += self.step()
|
||||
|
||||
return output_reqs_list
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
prompts: Union[List[str], str],
|
||||
request_ids: Union[List[int], int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if request_ids is not None and not isinstance(request_ids, list):
|
||||
request_ids = [request_ids]
|
||||
|
||||
if not isinstance(prompts, list):
|
||||
prompts = [prompts]
|
||||
|
||||
generation_config = DiffusionGenerationConfig.from_kwargs(**kwargs)
|
||||
prompts_num = len(prompts)
|
||||
for i in range(prompts_num):
|
||||
if request_ids:
|
||||
assert isinstance(
|
||||
request_ids[0], int
|
||||
), f"The request_id type must be int, but got {type(request_ids[0])}"
|
||||
assert len(request_ids) == prompts_num
|
||||
request_id = request_ids[i]
|
||||
else:
|
||||
request_id = next(self.counter)
|
||||
|
||||
seq = DiffusionSequence(request_id=request_id, prompt=prompts[i], generation_config=generation_config)
|
||||
|
||||
self.request_handler.add_sequence(seq)
|
||||
|
||||
def step(self) -> List[PIL.Image.Image]:
|
||||
"""
|
||||
In each step, do the follows:
|
||||
1. Run RequestHandler.schedule() and get the batch used for inference.
|
||||
2. run forward to get List[Image]
|
||||
Returns:
|
||||
List[PIL.Image.Image]: Image Generated by one step.
|
||||
"""
|
||||
|
||||
input = self.request_handler.schedule()
|
||||
ret = self.model(prompt=input.prompt, **input.generation_config.to_dict())
|
||||
return ret
|
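Putting the pieces together, a typical call path builds the unified engine from a diffusers checkpoint and calls `generate`; each `step` schedules one request and runs the pipeline forward. A hedged usage sketch, where the checkpoint name is only an example, the import path is assumed from this repo's `colossalai.inference.core` package, and the distributed environment is assumed to be launched already:

```python
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

# Example diffusers pipeline id; substitute any supported text-to-image checkpoint.
engine = InferenceEngine(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    inference_config=InferenceConfig(),
)
images = engine.generate(
    prompts=["a watercolor painting of a lighthouse"],
    generation_config=DiffusionGenerationConfig.from_kwargs(num_inference_steps=30),
)
```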
|
@ -1,58 +1,24 @@
|
|||
import time
|
||||
from itertools import count
|
||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||
from typing import List, Tuple, Type, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import PIL.Image
|
||||
import torch.nn as nn
|
||||
from torch import distributed as dist
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
GenerationConfig,
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast,
|
||||
)
|
||||
from transformers.models.llama.modeling_llama import LlamaForCausalLM
|
||||
from diffusers import DiffusionPipeline
|
||||
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.inference.batch_bucket import BatchBucket
|
||||
from colossalai.inference.config import InferenceConfig, InputMetaData, ModelShardInferenceConfig
|
||||
from colossalai.inference.graph_runner import CUDAGraphRunner
|
||||
from colossalai.inference.modeling.policy import model_policy_map
|
||||
from colossalai.inference.sampler import search_tokens
|
||||
from colossalai.inference.spec import Drafter, GlideInput
|
||||
from colossalai.inference.struct import Sequence
|
||||
from colossalai.inference.utils import get_model_size, has_index_file
|
||||
from colossalai.interface import ModelWrapper
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||
from colossalai.inference.config import InferenceConfig
|
||||
from colossalai.inference.utils import ModelType, get_model_type
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
|
||||
from .request_handler import RequestHandler
|
||||
|
||||
__all__ = ["InferenceEngine"]
|
||||
|
||||
PP_AXIS, TP_AXIS = 0, 1
|
||||
|
||||
_supported_models = {
|
||||
"LlamaForCausalLM": LlamaForCausalLM,
|
||||
"BaichuanForCausalLM": AutoModelForCausalLM,
|
||||
}
|
||||
|
||||
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
|
||||
|
||||
|
||||
class InferenceEngine:
|
||||
|
||||
"""
|
||||
InferenceEngine which manages the inference process.
|
||||
|
||||
Args:
|
||||
model_or_path (nn.Module or str): Path or nn.Module of this model.
|
||||
model_or_path (nn.Module or DiffusionPipeline or str): Path or nn.Module or DiffusionPipeline of this model.
|
||||
tokenizer Optional[(Union[PreTrainedTokenizer, PreTrainedTokenizerFast])]: Path of the tokenizer to use.
|
||||
inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference.
|
||||
verbose (bool): Determine whether or not to log the generation process.
|
||||
|
@ -61,567 +27,68 @@ class InferenceEngine:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
model_or_path: Union[nn.Module, str],
|
||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||
inference_config: InferenceConfig,
|
||||
model_or_path: Union[nn.Module, str, DiffusionPipeline],
|
||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
|
||||
inference_config: InferenceConfig = None,
|
||||
verbose: bool = False,
|
||||
model_policy: Union[Policy, Type[Policy]] = None,
|
||||
) -> None:
|
||||
self.inference_config = inference_config
|
||||
self.dtype = inference_config.dtype
|
||||
self.high_precision = inference_config.high_precision
|
||||
self.__dict__["_initialized"] = False # use __dict__ directly to avoid calling __setattr__
|
||||
self.model_type = get_model_type(model_or_path=model_or_path)
|
||||
self.engine = None
|
||||
if self.model_type == ModelType.LLM:
|
||||
from .llm_engine import LLMEngine
|
||||
|
||||
self.verbose = verbose
|
||||
self.logger = get_dist_logger(__name__)
|
||||
self.model_shard_infer_config = inference_config.to_model_shard_inference_config()
|
||||
self.engine = LLMEngine(
|
||||
model_or_path=model_or_path,
|
||||
tokenizer=tokenizer,
|
||||
inference_config=inference_config,
|
||||
verbose=verbose,
|
||||
model_policy=model_policy,
|
||||
)
|
||||
elif self.model_type == ModelType.DIFFUSION_MODEL:
|
||||
from .diffusion_engine import DiffusionEngine
|
||||
|
||||
self.init_model(model_or_path, model_policy, self.model_shard_infer_config)
|
||||
|
||||
self.generation_config = inference_config.to_generation_config(self.model_config)
|
||||
self.generation_config_dict = self.generation_config.to_dict()
|
||||
|
||||
self.tokenizer = tokenizer
|
||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
|
||||
self.request_handler = RequestHandler(self.inference_config, self.model_config)
|
||||
self.k_cache, self.v_cache = self.request_handler.get_kvcache()
|
||||
# DISCUSS maybe move this into batch info?
|
||||
|
||||
self.counter = count()
|
||||
|
||||
self.use_cuda_graph = self.inference_config.use_cuda_graph
|
||||
if self.use_cuda_graph:
|
||||
self.graph_runners: Dict[int, CUDAGraphRunner] = {}
|
||||
self.graph_memory_pool = None # Set during graph capture.
|
||||
if verbose:
|
||||
self.logger.info("Colossal AI CUDA Graph Capture on")
|
||||
|
||||
self.capture_model(self.k_cache, self.v_cache)
|
||||
|
||||
# Model and relatable attrs of speculative decoding will be set by `enable_spec_dec`
|
||||
self.use_spec_dec = self.inference_config.use_spec_dec
|
||||
|
||||
self.drafter_model = None
|
||||
self.drafter = None
|
||||
self.use_glide = False
|
||||
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
|
||||
self.engine = DiffusionEngine(
|
||||
model_or_path=model_or_path,
|
||||
inference_config=inference_config,
|
||||
verbose=verbose,
|
||||
model_policy=model_policy,
|
||||
)
|
||||
elif self.model_type == ModelType.UNKNOWN:
|
||||
self.logger.error(f"Model Type either Difffusion or LLM!")
|
||||
|
||||
self._initialized = True
|
||||
self._verify_args()
|
||||
|
||||
def init_model(
|
||||
self,
|
||||
model_or_path: Union[nn.Module, str],
|
||||
model_policy: Union[Policy, Type[Policy]] = None,
|
||||
model_shard_infer_config: ModelShardInferenceConfig = None,
|
||||
):
|
||||
"""
|
||||
Shard the model and/or load weights
|
||||
|
||||
Args:
|
||||
model_or_path Union[nn.Module, str]: path to the checkpoint or a model in transformers format.
|
||||
model_policy (Policy): the policy to replace the model.
|
||||
model_inference_config: the configuration for model initialization during inference.
|
||||
model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference.
|
||||
"""
|
||||
pretrained_path = None
|
||||
if isinstance(model_or_path, str):
|
||||
import colossalai.interface.pretrained as pretrained_utils
|
||||
|
||||
try:
|
||||
hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype)
|
||||
arch = getattr(hf_config, "architectures")[0]
|
||||
if arch in _supported_models.keys():
|
||||
if arch is "BaichuanForCausalLM":
|
||||
self.logger.warning(
|
||||
"Attention ! We use lazy init by default, which could be faster for model loading. For baichuan model, the output maybe have a slight difference with transformers"
|
||||
)
|
||||
ctx = LazyInitContext(default_device="cuda")
|
||||
with ctx:
|
||||
model = _supported_models[arch].from_pretrained(
|
||||
model_or_path, trust_remote_code=True, torch_dtype=self.dtype
|
||||
)
|
||||
pretrained_path = pretrained_utils.get_pretrained_path(model)
|
||||
else:
|
||||
# TODO(char-1ee): if the model not supported, use transformers APIs to load and generate
|
||||
raise ValueError(f"Model {arch} is not supported.")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
|
||||
)
|
||||
else:
|
||||
model = model_or_path
|
||||
|
||||
self.model_config = model.config
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
init_gpu_memory = torch.cuda.mem_get_info()[0]
|
||||
|
||||
self.device = get_accelerator().get_current_device()
|
||||
if self.verbose:
|
||||
self.logger.info(f"the device is {self.device}")
|
||||
|
||||
model = model.to(self.dtype).eval()
|
||||
|
||||
if self.verbose:
|
||||
self.logger.info(
|
||||
f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
|
||||
)
|
||||
|
||||
if model_policy is None:
|
||||
prefix = "nopadding" if not self.inference_config.pad_input else "padding"
|
||||
model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}"
|
||||
model_policy = model_policy_map.get(model_policy_key)
|
||||
|
||||
if not isinstance(model_policy, Policy):
|
||||
try:
|
||||
model_policy = model_policy()
|
||||
except Exception as e:
|
||||
raise ValueError(f"Unable to instantiate model policy: {e}")
|
||||
|
||||
assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
|
||||
pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
|
||||
tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
|
||||
|
||||
self.model = self._shardformer(
|
||||
model,
|
||||
model_policy,
|
||||
model_shard_infer_config,
|
||||
None,
|
||||
tp_group=tp_group,
|
||||
)
|
||||
|
||||
self.model = ModelWrapper(model).to(self.device)
|
||||
|
||||
if self.verbose:
|
||||
self.logger.info(
|
||||
f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
|
||||
)
|
||||
|
||||
if pretrained_path:
|
||||
from colossalai.inference.core.plugin import InferCheckpoint_io
|
||||
|
||||
cpt_io = InferCheckpoint_io()
|
||||
if_has_index_file, model_index_file = has_index_file(pretrained_path)
|
||||
assert if_has_index_file, "the model path is invalid"
|
||||
cpt_io.load_model(self.model, model_index_file)
|
||||
|
||||
free_gpu_memory, _ = torch.cuda.mem_get_info()
|
||||
peak_memory = init_gpu_memory - free_gpu_memory
|
||||
if self.verbose:
|
||||
self.logger.info(
|
||||
f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB"
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]):
|
||||
assert self.use_cuda_graph, "please turn on the cuda graph"
|
||||
|
||||
if self.verbose:
|
||||
self.logger.info("Colossal AI CUDA Graph Capture begin")
|
||||
|
||||
t_capture_begin = time.perf_counter()
|
||||
|
||||
block_size = self.inference_config.block_size
|
||||
head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads
|
||||
|
||||
# Prepare dummy inputs. These will be reused for all batch sizes.
|
||||
max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
|
||||
max_context_len_to_capture = self.inference_config.max_context_len_to_capture
|
||||
max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size
|
||||
input_tokens_ids = torch.zeros(max_batch_size, dtype=torch.long).cuda()
|
||||
# self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
|
||||
self.graph_block_tables = np.full((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), -1, dtype=np.int32)
|
||||
self.graph_block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(_BATCH_SIZES_TO_CAPTURE))
|
||||
self.graph_block_tables[0, :] = np.arange(
|
||||
0, max_num_blocks
|
||||
) # NOTE this is a hack to ensure the cuda graph could capture the fixed cuda kernel grid in flash decoding, to make the first seqlen the max_seq_len
|
||||
block_tables = torch.from_numpy(self.graph_block_tables).cuda()
|
||||
output_tensor = torch.zeros(
|
||||
(max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device
|
||||
)
|
||||
fd_inter_tensor = self.request_handler.running_bb.fd_inter_tensor
|
||||
|
||||
max_num_seqs = self.inference_config.max_batch_size
|
||||
batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs]
|
||||
sequence_lengths = torch.ones(max_batch_size, dtype=torch.int).cuda()
|
||||
# NOTE this is a hack to ensure the cuda graph could capture the fixed cuda kernel grid in flash decoding, to make the first seqlen the max_seq_len
|
||||
sequence_lengths[0] = torch.tensor(
|
||||
self.inference_config.max_context_len_to_capture - 1, dtype=torch.int32
|
||||
).cuda()
|
||||
|
||||
# NOTE: Capturing the largest batch size first may help reduce the
|
||||
# memory usage of CUDA graph.
|
||||
for batch_size in reversed(batch_size_capture_list):
|
||||
if self.verbose:
|
||||
self.logger.info(f"batch size {batch_size} graph capturing")
|
||||
|
||||
input_meta_data = InputMetaData(
|
||||
block_tables=block_tables[:batch_size],
|
||||
sequence_lengths=sequence_lengths[:batch_size],
|
||||
fd_inter_tensor=fd_inter_tensor,
|
||||
batch_size=batch_size,
|
||||
is_prompts=False,
|
||||
use_cuda_graph=True,
|
||||
high_precision=False,
|
||||
kv_seq_len=sequence_lengths[:batch_size].max().item(),
|
||||
head_dim=head_dim,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
graph_runner = CUDAGraphRunner(self.model)
|
||||
graph_runner.capture(
|
||||
input_tokens_ids[:batch_size],
|
||||
output_tensor[:batch_size],
|
||||
input_meta_data,
|
||||
k_caches=k_cache,
|
||||
v_caches=v_cache,
|
||||
memory_pool=self.graph_memory_pool,
|
||||
)
|
||||
self.graph_memory_pool = graph_runner.graph.pool()
|
||||
self.graph_runners[batch_size] = graph_runner
|
||||
|
||||
t_capture_end = time.perf_counter()
|
||||
|
||||
if self.verbose:
|
||||
self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s")
|
||||
|
||||
def _verify_args(self) -> None:
|
||||
"""Verify the input args"""
|
||||
if not isinstance(self.inference_config, InferenceConfig):
|
||||
raise TypeError("Invalid type of inference config provided.")
|
||||
if not isinstance(self.model, nn.Module):
|
||||
raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
|
||||
if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)):
|
||||
raise TypeError(
|
||||
f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
|
||||
)
|
||||
if isinstance(self.model, ModelWrapper):
|
||||
model = self.model.module
|
||||
assert (
|
||||
model.__class__.__name__ in _supported_models.keys()
|
||||
), f"Model {self.model.__class__.__name__} is not supported."
|
||||
|
||||
def _shardformer(
|
||||
self,
|
||||
model: nn.Module,
|
||||
model_policy: Policy,
|
||||
model_shard_infer_config: ModelShardInferenceConfig = None,
|
||||
stage_manager: PipelineStageManager = None,
|
||||
tp_group: ProcessGroupMesh = None,
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Initialize ShardConfig and replace the model with shardformer.
|
||||
|
||||
Args:
|
||||
model (nn.Module): Path or nn.Module of this model.
|
||||
model_policy (Policy): The policy to shardformer model which is determined by the model type.
|
||||
stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
|
||||
tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None.
|
||||
|
||||
Returns:
|
||||
nn.Module: The model optimized by Shardformer.
|
||||
"""
|
||||
|
||||
shardconfig = ShardConfig(
|
||||
tensor_parallel_process_group=tp_group,
|
||||
pipeline_stage_manager=stage_manager,
|
||||
enable_tensor_parallelism=(self.inference_config.tp_size > 1),
|
||||
enable_fused_normalization=False,
|
||||
enable_all_optimization=False,
|
||||
enable_flash_attention=False,
|
||||
enable_jit_fused=False,
|
||||
enable_sequence_parallelism=False,
|
||||
extra_kwargs={"model_shard_infer_config": model_shard_infer_config},
|
||||
)
|
||||
shardformer = ShardFormer(shard_config=shardconfig)
|
||||
shard_model, _ = shardformer.optimize(model, model_policy)
|
||||
return shard_model
|
||||
|
||||
def enable_spec_dec(
|
||||
self,
|
||||
drafter_model: nn.Module = None,
|
||||
n_spec_tokens: int = None,
|
||||
use_glide_drafter: bool = False,
|
||||
) -> None:
|
||||
"""Initialize drafter (if it has not yet), and enable Speculative Decoding for subsequent generations.
|
||||
|
||||
Args:
|
||||
drafter_model (nn.Module): The drafter model (small model) used to speculate tokens.
|
||||
If provided, the previous drafter and drafter model, if exist, will be overwritten.
|
||||
n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying.
|
||||
If not provided, `max_n_spec_tokens` in InferenceConfig will be used.
|
||||
use_glide_drafter (bool): Whether to use glide model for speculative decoding. Defaults to False.
|
||||
If True, the drafter model will be replaced by a glide model.
|
||||
|
||||
```python
|
||||
...
|
||||
engine = InferenceEngine(model, tokenizer, inference_config)
|
||||
|
||||
engine.enable_spec_dec(drafter_model, n_spec_tokens=5)
|
||||
engine.generate(...) # Speculative Decoding
|
||||
|
||||
engine.disable_spec_dec()
|
||||
engine.generate(...) # Normal generation
|
||||
|
||||
engine.enable_spec_dec()
|
||||
engine.generate(...) # Speculative-Decoding using previously set drafter model and number of spec tokens
|
||||
engine.clear_spec_dec()
|
||||
```
|
||||
"""
|
||||
|
||||
if drafter_model is None and self.drafter is None:
|
||||
raise ValueError("Drafter not initialized. Please provide a Drafter Model")
|
||||
if n_spec_tokens is not None:
|
||||
assert 1 < n_spec_tokens <= self.inference_config.max_n_spec_tokens
|
||||
self.n_spec_tokens = n_spec_tokens
|
||||
if drafter_model is not None:
|
||||
assert isinstance(drafter_model, nn.Module)
|
||||
# overwrite the drafter, if exists
|
||||
self.clear_spec_dec()
|
||||
self.drafter_model = drafter_model
|
||||
self.drafter = Drafter(
|
||||
self.drafter_model,
|
||||
self.tokenizer,
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
# check if the provided drafter model is compatible with GLIDE structure
|
||||
# when `use_glide_drafter` is set to True
|
||||
if (
|
||||
use_glide_drafter
|
||||
and hasattr(drafter_model, "model")
|
||||
and hasattr(drafter_model.model, "layers")
|
||||
and hasattr(drafter_model.model.layers[0], "cross_attn")
|
||||
):
|
||||
self.use_glide = use_glide_drafter
|
||||
elif use_glide_drafter:
|
||||
self.logger.warning(
|
||||
f"`use_glide_drafter` is provided as {use_glide_drafter}, "
|
||||
f"but the provided drafter model is not compatible with GLIDE structure."
|
||||
f"Falling back to use the default drafter model (non-GLIDE)."
|
||||
)
|
||||
self.request_handler.set_spec_dec_mode(self.n_spec_tokens)
|
||||
# using speculative decoding for subsequent generations
|
||||
self.use_spec_dec = True
|
||||
|
||||
def disable_spec_dec(self) -> None:
|
||||
"""Disable using speculative decoding for subsequent generations."""
|
||||
self.request_handler.unset_spec_dec_mode()
|
||||
# set back to the maximum number of tokens to speculate
|
||||
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
|
||||
self.use_glide = False
|
||||
self.use_spec_dec = False
|
||||
|
||||
def clear_spec_dec(self) -> None:
|
||||
"""Clear relatable structures of speculative decoding, if exist."""
|
||||
if self.use_spec_dec:
|
||||
self.disable_spec_dec()
|
||||
if self.drafter_model or self.drafter:
|
||||
self.drafter_model = None
|
||||
self.drafter = None
|
||||
torch.cuda.empty_cache()
|
||||
self.use_glide = False
|
||||
self.use_spec_dec = False
|
||||
|
||||
def steps_spec_dec(self) -> List[Sequence]:
|
||||
"""
|
||||
Run Speculative Decoding steps. This is like retrieving a single batch and launch inference
|
||||
with many steps of speculating by a drafter model as well as verifying by a main model.
|
||||
|
||||
Returns:
|
||||
List[Sequence]: finished sequences generated by one step.
|
||||
"""
|
||||
batch = self.request_handler.schedule() # prefill batch
|
||||
assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
|
||||
|
||||
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
|
||||
|
||||
if input_meta_data.use_cuda_graph:
|
||||
model_executable = self.graph_runners[input_meta_data.batch_size]
|
||||
else:
|
||||
model_executable = self.model
|
||||
|
||||
# 1. Prefill small model (Drafter) - fill past kv cache for drafter model
|
||||
# NOTE For glide drafter models, we won't actually apply glide during prefill stage
|
||||
drafter_out = self.drafter.speculate(input_token_ids, 1, None)
|
||||
next_token_ids_spec = drafter_out.next_tokens
|
||||
drafter_past_key_values = drafter_out.past_key_values
|
||||
|
||||
# 2. Prefill main model (Verifier) - fill past kv cache for main model
|
||||
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
|
||||
next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
|
||||
# append new inputs to the batch, temporarily
|
||||
batch.append_batch_tokens(next_tokens)
|
||||
self.request_handler.allocate_batch_spec_dec(batch, 1)
|
||||
already_allocated_kv_len = batch.seq_lengths[0].item()
|
||||
input_token_ids = batch.get_1D_inputs_spec_dec(1)
|
||||
|
||||
finished_sequences = self.request_handler.update()
|
||||
|
||||
while True:
|
||||
# HACK Retrieve the running batch
|
||||
# Using RequestHandler.schedule here will re-allocate same kv cache for the batch
|
||||
batch = self.request_handler.running_bb # running batch
|
||||
assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
|
||||
|
||||
# 3. Decoding - Drafter model speculates `n` tokens
|
||||
glide_input = None
|
||||
if self.use_glide:
|
||||
glide_input = GlideInput(
|
||||
batch.get_block_table_tensor(),
|
||||
self.k_cache[-1], # use kv caches of the last layer
|
||||
self.v_cache[-1],
|
||||
batch.get_sequence_lengths(),
|
||||
n_spec_tokens=self.n_spec_tokens,
|
||||
)
|
||||
|
||||
drafter_out = self.drafter.speculate(
|
||||
input_token_ids,
|
||||
self.n_spec_tokens,
|
||||
drafter_past_key_values,
|
||||
glide_input=glide_input,
|
||||
)
|
||||
next_token_ids_spec = drafter_out.next_tokens
|
||||
drafter_past_key_values = drafter_out.past_key_values
|
||||
drafter_spec_length = drafter_out.speculated_length
|
||||
|
||||
for next_token_id_spec in next_token_ids_spec:
|
||||
self.request_handler.append_next_tokens(next_token_id_spec.unsqueeze(0))
|
||||
cur_length = batch.seq_lengths[0].item()
|
||||
if already_allocated_kv_len < cur_length:
|
||||
self.request_handler.allocate_batch_spec_dec(batch, n=cur_length - already_allocated_kv_len)
|
||||
already_allocated_kv_len = cur_length
|
||||
|
||||
# 4. Decoding - Main model verifies `n` tokens in parallel
|
||||
if drafter_spec_length < batch.num_tokens_to_verify:
|
||||
batch.set_use_spec_dec(num_tokens_to_verify=drafter_spec_length)
|
||||
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
|
||||
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
|
||||
|
||||
next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
|
||||
|
||||
# 5. Compare and process the results
|
||||
diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
|
||||
n_matches = drafter_spec_length if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()
|
||||
|
||||
# revoke appended tokens for each Sequence in the current batch
|
||||
batch.revoke_batch_tokens(drafter_spec_length - n_matches) # revoke drafted tokens
|
||||
|
||||
# append the last correct token generated by the main model
|
||||
self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0))
|
||||
|
||||
# trim past key values of the drafter model
|
||||
drafter_past_key_values = Drafter.trim_kv_cache(
|
||||
drafter_past_key_values, drafter_spec_length - n_matches - 1
|
||||
)
|
||||
|
||||
# prepare inputs for the next round of speculation
|
||||
n = 1 if n_matches < drafter_spec_length else 2
|
||||
input_token_ids = batch.get_1D_inputs_spec_dec(n)
|
||||
|
||||
self.request_handler.update_batch_finished(batch, generation_config=self.generation_config)
|
||||
finished_sequences = self.request_handler.update()
|
||||
if len(finished_sequences) > 0:
|
||||
break
|
||||
|
||||
# Reset back the number of speculated tokens of the batch,
|
||||
# this is used to handle the last round of speculation, in which case the number of speculated tokens
|
||||
# by the drafter is less than the number of speculated tokens set to the engine.
|
||||
batch.set_use_spec_dec(num_tokens_to_verify=self.n_spec_tokens)
|
||||
|
||||
return finished_sequences
|
||||
assert self.engine is not None, "Please init Engine first"
|
||||
assert self._initialized, "Engine must be initialized"
|
||||
|
||||
def generate(
|
||||
self,
|
||||
request_ids: Union[List[int], int] = None,
|
||||
prompts: Union[List[str], str] = None,
|
||||
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
|
||||
return_token_ids: bool = False,
|
||||
generation_config: Optional[GenerationConfig] = None,
|
||||
) -> Union[List[str], Tuple[List[str], List[List[int]]]]:
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Union[List[Union[str, List[PIL.Image.Image], np.ndarray]], Tuple[List[str], List[List[int]]]]:
|
||||
"""
|
||||
Executing the inference step.
|
||||
|
||||
Args:
|
||||
request_ids (List[int], optional): The request ID. Defaults to None.
|
||||
prompts (Union[List[str], str], optional): Input prompts. Defaults to None.
|
||||
prompts_token_ids (Union[List[int], torch.Tensor, np.ndarray], optional): token ids of input prompts. Defaults to None.
|
||||
return_token_ids (bool, optional): Whether to return output token ids. Defaults to False.
|
||||
generation_config (Optional[GenerationConfig], optional): Huggingface GenerationConfig used for inference. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Union[List[str], Tuple[List[str], List[List[int]]]]: Inference result returned by one generation.
|
||||
"""
|
||||
|
||||
gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
|
||||
prompts = [prompts] if isinstance(prompts, str) else prompts
|
||||
request_ids = [request_ids] if isinstance(request_ids, int) else request_ids
|
||||
|
||||
with torch.inference_mode():
|
||||
if prompts is not None or prompts_token_ids is not None:
|
||||
self.add_request(
|
||||
request_ids=request_ids,
|
||||
prompts=prompts,
|
||||
prompts_token_ids=prompts_token_ids,
|
||||
**gen_config_dict,
|
||||
)
|
||||
|
||||
output_seqs_list = []
|
||||
total_tokens_list = []
|
||||
|
||||
# If the user provides a generation config, replace the existing one.
|
||||
if generation_config is not None:
|
||||
self.generation_config = generation_config
|
||||
self.generation_config_dict = gen_config_dict
|
||||
|
||||
if self.use_spec_dec:
|
||||
assert self.drafter is not None, "Drafter Model is not initialized."
|
||||
while self.request_handler.check_unfinished_seqs():
|
||||
output_seqs_list += self.steps_spec_dec()
|
||||
else:
|
||||
while self.request_handler.check_unfinished_seqs():
|
||||
output_seqs_list += self.step()
|
||||
|
||||
output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))
|
||||
|
||||
for seq in output_seqs_list:
|
||||
total_tokens_list.append(seq.input_token_id + seq.output_token_id)
|
||||
|
||||
output_str = self.tokenizer.batch_decode(total_tokens_list, skip_special_tokens=True)
|
||||
|
||||
if return_token_ids:
|
||||
output_tokens_list = [seq.output_token_id for seq in output_seqs_list]
|
||||
return output_str, output_tokens_list
|
||||
else:
|
||||
return output_str
|
||||
|
||||
@property
|
||||
def has_prompt_template(self) -> bool:
|
||||
""" """
|
||||
return self.inference_config.prompt_template is not None
|
||||
|
||||
def format_prompt(self, prompts: Union[List[str], str]) -> Union[List[str], str]:
|
||||
"""
|
||||
This method will format the input prompt according to the prompt template given to the InferenceConfig.
|
||||
"""
|
||||
assert (
|
||||
self.has_prompt_template
|
||||
), "Found the prompt_template is None. Please provide a valid prompt_template in InferenceConfig."
|
||||
|
||||
if isinstance(prompts, (list, tuple)):
|
||||
return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts]
|
||||
elif isinstance(prompts, str):
|
||||
return self.inference_config.prompt_template.format(input_text=prompts)
|
||||
else:
|
||||
raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.")
|
||||
assert self.engine is not None, "Please init Engine first"
|
||||
return self.engine.generate(request_ids=request_ids, prompts=prompts, *args, **kwargs)
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
request_ids: Union[List[int], int] = None,
|
||||
prompts: Union[List[str], str] = None,
|
||||
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
|
@ -631,168 +98,36 @@ class InferenceEngine:
|
|||
request_ids (List[int], optional): The request ID. Defaults to None.
|
||||
prompts (Union[List[str], str], optional): Input prompts. Defaults to None.
|
||||
prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
|
||||
kwargs: for LLM, it could be max_length, max_new_tokens, etc
|
||||
for diffusion, it could be prompt_2, prompt_3, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, negative_prompt_2, negative_prompt_3, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, clip_skip, which aligns with diffusers
|
||||
"""
|
||||
assert self.engine is not None, "Please init Engine first"
|
||||
self.engine.add_request(request_ids=request_ids, prompts=prompts, *args, **kwargs)
|
||||
|
||||
# apply the prompt template to the input prompts
|
||||
def step(self):
|
||||
assert self.engine is not None, "Please init Engine first"
|
||||
return self.engine.step()
|
||||
|
||||
if self.has_prompt_template and prompts is not None:
|
||||
prompts = self.format_prompt(prompts)
|
||||
|
||||
block_size = self.inference_config.block_size
|
||||
|
||||
if request_ids is not None and not isinstance(request_ids, list):
|
||||
request_ids = [request_ids]
|
||||
|
||||
if prompts is not None and not isinstance(prompts, list):
|
||||
prompts = [prompts]
|
||||
|
||||
if prompts_token_ids is None:
|
||||
assert prompts, "When prompts_token_ids is None, the input prompt list must be provided."
|
||||
prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[
|
||||
"input_ids"
|
||||
]
|
||||
|
||||
# list of torch Tensor
|
||||
if isinstance(prompts_token_ids, list):
|
||||
if isinstance(prompts_token_ids[0], torch.Tensor):
|
||||
prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
|
||||
elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
|
||||
prompts_token_ids = prompts_token_ids.tolist()
|
||||
else:
|
||||
raise TypeError(
|
||||
f"The dtype of prompts_token_ids must be one of list, torch.Tensor, np.ndarray, but got {type(prompts_token_ids)}."
|
||||
)
|
||||
|
||||
assert (
|
||||
len(prompts_token_ids[0]) <= self.inference_config.max_input_len
|
||||
), f"The length of input prompts {len(prompts_token_ids[0])} must be less than max_input_len {self.inference_config.max_input_len}."
|
||||
|
||||
prompts_num = len(prompts_token_ids)
|
||||
|
||||
for i in range(prompts_num):
|
||||
if request_ids:
|
||||
assert isinstance(
|
||||
request_ids[0], int
|
||||
), f"The request_id type must be int, but got {type(request_ids[0])}"
|
||||
assert len(request_ids) == prompts_num
|
||||
request_id = request_ids[i]
|
||||
def __getattr__(self, name):
|
||||
"""
|
||||
The Design logic of getattr, setattr:
|
||||
1. Since InferenceEngine is a wrapper around DiffusionEngine/LLMEngine, we want to access every member of DiffusionEngine/LLMEngine as if it were a member of InferenceEngine.
|
||||
2. In InferenceEngine.__init__ we do not want to set attributes via self.__dict__["xxx"] = xxx; we want to keep the usual form self.xxx = xxx.
|
||||
So we set the attribute `_initialized`. After initialization, if a member cannot be found on InferenceEngine, we try to fetch it from self.engine (DiffusionEngine/LLMEngine).
|
||||
"""
|
||||
if self.__dict__.get("_initialized", False):
|
||||
if name in self.__dict__:
|
||||
return self.__dict__[name]
|
||||
else:
|
||||
request_id = next(self.counter)
|
||||
if prompts is None:
|
||||
prompt = None
|
||||
return getattr(self.engine, name)
|
||||
else:
|
||||
return self.__dict__[name]
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if self.__dict__.get("_initialized", False):
|
||||
if name in self.__dict__:
|
||||
self.__dict__[name] = value
|
||||
else:
|
||||
prompt = prompts[i]
|
||||
|
||||
max_length = kwargs.get("max_length", None)
|
||||
max_new_tokens = kwargs.get("max_new_tokens", None)
|
||||
if max_length is None and max_new_tokens is None:
|
||||
max_new_tokens = self.generation_config.max_new_tokens or self.inference_config.max_output_len
|
||||
elif max_length is not None:
|
||||
max_new_tokens = max_length - len(prompts_token_ids[i])
|
||||
|
||||
if not self.inference_config.enable_streamingllm:
|
||||
assert (
|
||||
self.inference_config.max_output_len >= max_new_tokens
|
||||
), f"max_new_tokens={max_new_tokens} must be less than max_output_len={self.inference_config.max_output_len}."
|
||||
|
||||
sequence = Sequence(
|
||||
request_id,
|
||||
prompt,
|
||||
prompts_token_ids[i],
|
||||
block_size,
|
||||
None,
|
||||
self.tokenizer.eos_token_id,
|
||||
self.tokenizer.pad_token_id,
|
||||
max_output_len=max_new_tokens,
|
||||
ignore_eos=self.inference_config.ignore_eos,
|
||||
)
|
||||
self.request_handler.add_sequence(sequence)
|
||||
|
||||
def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]:
|
||||
input_ids = batch.get_1D_inputs()
|
||||
sequence_lengths = batch.get_sequence_lengths()
|
||||
|
||||
if batch.is_prompts:
|
||||
n_tokens = sequence_lengths.sum().item()
|
||||
setattr(self.engine, name, value)
|
||||
else:
|
||||
n_tokens = batch.current_batch_size
|
||||
if batch.use_spec_dec:
|
||||
n_tokens = batch.num_tokens_to_verify + 1
|
||||
assert n_tokens == input_ids.size(0)
|
||||
n_tokens = n_tokens * batch.current_batch_size
|
||||
output_tensor = torch.zeros(
|
||||
(n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
|
||||
)
|
||||
|
||||
batch_token_ids = None
|
||||
if (
|
||||
self.generation_config.repetition_penalty != 1.0
|
||||
or self.generation_config.no_repeat_ngram_size > 0
|
||||
or self.generation_config.forced_eos_token_id is not None
|
||||
):
|
||||
batch_token_ids = batch.batch_token_ids
|
||||
|
||||
# only when we have the graph for specific decoding batch size can we use the cuda graph for inference
|
||||
use_cuda_graph = False
|
||||
if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys():
|
||||
use_cuda_graph = True
|
||||
|
||||
input_meta_data = InputMetaData(
|
||||
block_tables=batch.get_block_table_tensor(),
|
||||
sequence_lengths=sequence_lengths,
|
||||
fd_inter_tensor=batch.fd_inter_tensor,
|
||||
batch_size=batch.current_batch_size,
|
||||
is_prompts=batch.is_prompts,
|
||||
use_cuda_kernel=self.inference_config.use_cuda_kernel,
|
||||
use_cuda_graph=use_cuda_graph,
|
||||
high_precision=self.high_precision,
|
||||
kv_seq_len=sequence_lengths.max().item(),
|
||||
head_dim=batch.head_dim,
|
||||
dtype=batch.dtype,
|
||||
use_spec_dec=batch.use_spec_dec,
|
||||
num_tokens_to_verify=batch.num_tokens_to_verify,
|
||||
batch_token_ids=batch_token_ids,
|
||||
)
|
||||
|
||||
return input_ids, output_tensor, input_meta_data
|
||||
|
||||
def step(self) -> List[str]:
|
||||
"""
|
||||
In each step, do the following:
|
||||
1. Run RequestHandler.schedule() and get the batch used for inference.
|
||||
2. Get the input token ids, input metadata, and output placeholder from the BatchBucket.
|
||||
3. Run model to generate the next token
|
||||
4. Update waiting list and running list in RequestHandler and get finished sequences.
|
||||
5. Decode and return finished sequences.
|
||||
|
||||
Returns:
|
||||
List[str]: Decoded finished sequences generated by one step.
|
||||
"""
|
||||
|
||||
batch = self.request_handler.schedule()
|
||||
|
||||
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
|
||||
|
||||
if input_meta_data.use_cuda_graph:
|
||||
model_executable = self.graph_runners[input_meta_data.batch_size]
|
||||
else:
|
||||
model_executable = self.model
|
||||
|
||||
# TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
|
||||
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
|
||||
if self.inference_config.pad_input:
|
||||
logits = logits[:, -1, :]
|
||||
|
||||
if self.inference_config.enable_streamingllm:
|
||||
updated_block_ids = batch.streamingllm_update_batch(
|
||||
self.inference_config.start_token_size, self.inference_config.generated_token_size
|
||||
)
|
||||
self.request_handler.streamingllm_free_block_tables(updated_block_ids)
|
||||
|
||||
next_tokens = search_tokens(
|
||||
self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
|
||||
)
|
||||
self.request_handler.append_next_tokens(next_tokens)
|
||||
finished_sequences = self.request_handler.update()
|
||||
|
||||
return finished_sequences
|
||||
self.__dict__[name] = value
|
||||
|
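# A minimal standalone sketch of the delegation pattern described in the __getattr__/__setattr__
# docstring above (an assumed simplification, not the engine's actual classes): once the wrapper
# has finished its own __init__, unknown attribute reads and writes fall through to the wrapped engine.
class _EngineWrapper:
    def __init__(self, engine):
        self.__dict__["engine"] = engine
        self.__dict__["_initialized"] = True

    def __getattr__(self, name):
        # only invoked when normal attribute lookup fails on the wrapper itself
        return getattr(self.__dict__["engine"], name)

    def __setattr__(self, name, value):
        if self.__dict__.get("_initialized", False) and name not in self.__dict__:
            setattr(self.__dict__["engine"], name, value)
        else:
            self.__dict__[name] = value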
|
|
@ -0,0 +1,758 @@
|
|||
import time
|
||||
from itertools import count
|
||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import distributed as dist
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
GenerationConfig,
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast,
|
||||
)
|
||||
from transformers.models.llama.modeling_llama import LlamaForCausalLM
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.inference.batch_bucket import BatchBucket
|
||||
from colossalai.inference.config import InferenceConfig, InputMetaData, ModelShardInferenceConfig
|
||||
from colossalai.inference.graph_runner import CUDAGraphRunner
|
||||
from colossalai.inference.modeling.policy import model_policy_map
|
||||
from colossalai.inference.sampler import search_tokens
|
||||
from colossalai.inference.spec import Drafter, GlideInput
|
||||
from colossalai.inference.struct import Sequence
|
||||
from colossalai.inference.utils import get_model_size, has_index_file
|
||||
from colossalai.interface import ModelWrapper
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
|
||||
from .base_engine import BaseEngine
|
||||
from .request_handler import RequestHandler
|
||||
|
||||
PP_AXIS, TP_AXIS = 0, 1
|
||||
|
||||
_supported_models = {
|
||||
"LlamaForCausalLM": LlamaForCausalLM,
|
||||
"BaichuanForCausalLM": AutoModelForCausalLM,
|
||||
}
|
||||
|
||||
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
|
||||
|
||||
|
||||
class LLMEngine(BaseEngine):
|
||||
"""
|
||||
LLMEngine, which manages the LLM inference process.
|
||||
|
||||
Args:
|
||||
model_or_path (nn.Module or str): Path or nn.Module of this model.
|
||||
tokenizer (Optional[Union[PreTrainedTokenizer, PreTrainedTokenizerFast]]): The tokenizer to use for encoding and decoding.
|
||||
inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference.
|
||||
verbose (bool): Determine whether or not to log the generation process.
|
||||
model_policy ("Policy"): the policy to shardformer model. It will be determined by the model type if not provided.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_or_path: Union[nn.Module, str],
|
||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
|
||||
inference_config: InferenceConfig = None,
|
||||
verbose: bool = False,
|
||||
model_policy: Union[Policy, type[Policy]] = None,
|
||||
) -> None:
|
||||
self.inference_config = inference_config
|
||||
self.dtype = inference_config.dtype
|
||||
self.high_precision = inference_config.high_precision
|
||||
|
||||
self.verbose = verbose
|
||||
self.logger = get_dist_logger(__name__)
|
||||
self.model_shard_infer_config = inference_config.to_model_shard_inference_config()
|
||||
|
||||
self.init_model(model_or_path, model_policy, self.model_shard_infer_config)
|
||||
|
||||
self.generation_config = inference_config.to_generation_config(self.model_config)
|
||||
self.generation_config_dict = self.generation_config.to_dict()
|
||||
|
||||
self.tokenizer = tokenizer
|
||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
|
||||
self.request_handler = RequestHandler(self.inference_config, self.model_config)
|
||||
self.k_cache, self.v_cache = self.request_handler.get_kvcache()
|
||||
# DISCUSS maybe move this into batch info?
|
||||
|
||||
self.counter = count()
|
||||
|
||||
self.use_cuda_graph = self.inference_config.use_cuda_graph
|
||||
if self.use_cuda_graph:
|
||||
self.graph_runners: Dict[int, CUDAGraphRunner] = {}
|
||||
self.graph_memory_pool = None # Set during graph capture.
|
||||
if verbose:
|
||||
self.logger.info("Colossal AI CUDA Graph Capture on")
|
||||
|
||||
self.capture_model(self.k_cache, self.v_cache)
|
||||
|
||||
# Model and relatable attrs of speculative decoding will be set by `enable_spec_dec`
|
||||
self.use_spec_dec = self.inference_config.use_spec_dec
|
||||
|
||||
self.drafter_model = None
|
||||
self.drafter = None
|
||||
self.use_glide = False
|
||||
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
|
||||
|
||||
self._verify_args()
|
||||
|
||||
def init_model(
|
||||
self,
|
||||
model_or_path: Union[nn.Module, str],
|
||||
model_policy: Union[Policy, Type[Policy]] = None,
|
||||
model_shard_infer_config: ModelShardInferenceConfig = None,
|
||||
):
|
||||
"""
|
||||
Shard the model and/or load its weights.
|
||||
|
||||
Args:
|
||||
model_or_path (Union[nn.Module, str]): path to the checkpoint, or a model in transformers format.
|
||||
model_policy (Policy): the policy to replace the model.
|
||||
model_inference_config: the configuration for modeling initialization when inference.
|
||||
model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference.
|
||||
"""
|
||||
pretrained_path = None
|
||||
if isinstance(model_or_path, str):
|
||||
import colossalai.interface.pretrained as pretrained_utils
|
||||
|
||||
try:
|
||||
hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype)
|
||||
arch = getattr(hf_config, "architectures")[0]
|
||||
if arch in _supported_models.keys():
|
||||
if arch == "BaichuanForCausalLM":
|
||||
self.logger.warning(
|
||||
"Attention ! We use lazy init by default, which could be faster for model loading. For baichuan model, the output maybe have a slight difference with transformers"
|
||||
)
|
||||
ctx = LazyInitContext(default_device="cuda")
|
||||
with ctx:
|
||||
model = _supported_models[arch].from_pretrained(
|
||||
model_or_path, trust_remote_code=True, torch_dtype=self.dtype
|
||||
)
|
||||
pretrained_path = pretrained_utils.get_pretrained_path(model)
|
||||
else:
|
||||
# TODO(char-1ee): if the model not supported, use transformers APIs to load and generate
|
||||
raise ValueError(f"Model {arch} is not supported.")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
|
||||
)
|
||||
else:
|
||||
model = model_or_path
|
||||
|
||||
self.model_config = model.config
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
init_gpu_memory = torch.cuda.mem_get_info()[0]
|
||||
|
||||
self.device = get_accelerator().get_current_device()
|
||||
if self.verbose:
|
||||
self.logger.info(f"the device is {self.device}")
|
||||
|
||||
model = model.to(self.dtype).eval()
|
||||
|
||||
if self.verbose:
|
||||
self.logger.info(
|
||||
f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
|
||||
)
|
||||
|
||||
if model_policy is None:
|
||||
prefix = "nopadding" if not self.inference_config.pad_input else "padding"
|
||||
model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}"
|
||||
model_policy = model_policy_map.get(model_policy_key)
|
||||
|
||||
if not isinstance(model_policy, Policy):
|
||||
try:
|
||||
model_policy = model_policy()
|
||||
except Exception as e:
|
||||
raise ValueError(f"Unable to instantiate model policy: {e}")
|
||||
|
||||
assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
|
||||
pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
|
||||
tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
|
||||
|
||||
self.model = self._shardformer(
|
||||
model,
|
||||
model_policy,
|
||||
model_shard_infer_config,
|
||||
None,
|
||||
tp_group=tp_group,
|
||||
)
|
||||
|
||||
self.model = ModelWrapper(model).to(self.device)
|
||||
|
||||
if self.verbose:
|
||||
self.logger.info(
|
||||
f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
|
||||
)
|
||||
|
||||
if pretrained_path:
|
||||
from colossalai.inference.core.plugin import InferCheckpoint_io
|
||||
|
||||
cpt_io = InferCheckpoint_io()
|
||||
if_has_index_file, model_index_file = has_index_file(pretrained_path)
|
||||
assert if_has_index_file, "the model path is invalid"
|
||||
cpt_io.load_model(self.model, model_index_file)
|
||||
|
||||
free_gpu_memory, _ = torch.cuda.mem_get_info()
|
||||
peak_memory = init_gpu_memory - free_gpu_memory
|
||||
if self.verbose:
|
||||
self.logger.info(
|
||||
f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB"
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]):
|
||||
assert self.use_cuda_graph, "please turn on the cuda graph"
|
||||
|
||||
if self.verbose:
|
||||
self.logger.info("Colossal AI CUDA Graph Capture begin")
|
||||
|
||||
t_capture_begin = time.perf_counter()
|
||||
|
||||
block_size = self.inference_config.block_size
|
||||
head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads
|
||||
|
||||
# Prepare dummy inputs. These will be reused for all batch sizes.
|
||||
max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
|
||||
max_context_len_to_capture = self.inference_config.max_context_len_to_capture
|
||||
max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size
|
||||
input_tokens_ids = torch.zeros(max_batch_size, dtype=torch.long).cuda()
|
||||
# self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
|
||||
self.graph_block_tables = np.full((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), -1, dtype=np.int32)
|
||||
self.graph_block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(_BATCH_SIZES_TO_CAPTURE))
|
||||
self.graph_block_tables[0, :] = np.arange(
|
||||
0, max_num_blocks
|
||||
)  # NOTE this is a hack to ensure the CUDA graph can capture the fixed CUDA kernel grid used in flash decoding, by making the first seqlen the max_seq_len
|
||||
block_tables = torch.from_numpy(self.graph_block_tables).cuda()
|
||||
output_tensor = torch.zeros(
|
||||
(max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device
|
||||
)
|
||||
fd_inter_tensor = self.request_handler.running_bb.fd_inter_tensor
|
||||
|
||||
max_num_seqs = self.inference_config.max_batch_size
|
||||
batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs]
|
||||
sequence_lengths = torch.ones(max_batch_size, dtype=torch.int).cuda()
|
||||
# NOTE this is a hack to ensure the CUDA graph can capture the fixed CUDA kernel grid used in flash decoding, by making the first seqlen the max_seq_len
|
||||
sequence_lengths[0] = torch.tensor(
|
||||
self.inference_config.max_context_len_to_capture - 1, dtype=torch.int32
|
||||
).cuda()
|
||||
|
||||
# NOTE: Capturing the largest batch size first may help reduce the
|
||||
# memory usage of CUDA graph.
|
||||
for batch_size in reversed(batch_size_capture_list):
|
||||
if self.verbose:
|
||||
self.logger.info(f"batch size {batch_size} graph capturing")
|
||||
|
||||
input_meta_data = InputMetaData(
|
||||
block_tables=block_tables[:batch_size],
|
||||
sequence_lengths=sequence_lengths[:batch_size],
|
||||
fd_inter_tensor=fd_inter_tensor,
|
||||
batch_size=batch_size,
|
||||
is_prompts=False,
|
||||
use_cuda_graph=True,
|
||||
high_precision=False,
|
||||
kv_seq_len=sequence_lengths[:batch_size].max().item(),
|
||||
head_dim=head_dim,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
graph_runner = CUDAGraphRunner(self.model)
|
||||
graph_runner.capture(
|
||||
input_tokens_ids[:batch_size],
|
||||
output_tensor[:batch_size],
|
||||
input_meta_data,
|
||||
k_caches=k_cache,
|
||||
v_caches=v_cache,
|
||||
memory_pool=self.graph_memory_pool,
|
||||
)
|
||||
self.graph_memory_pool = graph_runner.graph.pool()
|
||||
self.graph_runners[batch_size] = graph_runner
|
||||
|
||||
t_capture_end = time.perf_counter()
|
||||
|
||||
if self.verbose:
|
||||
self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s")
|
||||
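# A minimal standalone sketch of the capture/replay mechanism wrapped by CUDAGraphRunner
# (an assumed simplification using plain PyTorch CUDA graph APIs, not the runner's actual code):
# static input/output buffers are captured once; afterwards new data is copied into the same
# buffers and the graph is replayed instead of launching kernels eagerly.
import torch

linear = torch.nn.Linear(16, 16).cuda()
static_in = torch.zeros(8, 16, device="cuda")

# warm up on a side stream before capture, as recommended for CUDA graph capture
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        static_out = linear(static_in)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = linear(static_in)

static_in.copy_(torch.randn(8, 16, device="cuda"))  # write new inputs into the captured buffer
graph.replay()                                       # static_out now holds the new results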
|
||||
def _verify_args(self) -> None:
|
||||
"""Verify the input args"""
|
||||
if not isinstance(self.inference_config, InferenceConfig):
|
||||
raise TypeError("Invalid type of inference config provided.")
|
||||
if not isinstance(self.model, nn.Module):
|
||||
raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
|
||||
if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)):
|
||||
raise TypeError(
|
||||
f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
|
||||
)
|
||||
if isinstance(self.model, ModelWrapper):
|
||||
model = self.model.module
|
||||
assert (
|
||||
model.__class__.__name__ in _supported_models.keys()
|
||||
), f"Model {self.model.__class__.__name__} is not supported."
|
||||
|
||||
def enable_spec_dec(
|
||||
self,
|
||||
drafter_model: nn.Module = None,
|
||||
n_spec_tokens: int = None,
|
||||
use_glide_drafter: bool = False,
|
||||
) -> None:
|
||||
"""Initialize drafter (if it has not yet), and enable Speculative Decoding for subsequent generations.
|
||||
|
||||
Args:
|
||||
drafter_model (nn.Module): The drafter model (small model) used to speculate tokens.
|
||||
If provided, the previous drafter and drafter model, if exist, will be overwritten.
|
||||
n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying.
|
||||
If not provided, `max_n_spec_tokens` in InferenceConfig will be used.
|
||||
use_glide_drafter (bool): Whether to use glide model for speculative decoding. Defaults to False.
|
||||
If True, the drafter model will be replaced by a glide model.
|
||||
|
||||
```python
|
||||
...
|
||||
engine = InferenceEngine(model, tokenizer, inference_config)
|
||||
|
||||
engine.enable_spec_dec(drafter_model, n_spec_tokens=5)
|
||||
engine.generate(...) # Speculative Decoding
|
||||
|
||||
engine.disable_spec_dec()
|
||||
engine.generate(...) # Normal generation
|
||||
|
||||
engine.enable_spec_dec()
|
||||
engine.generate(...) # Speculative-Decoding using previously set drafter model and number of spec tokens
|
||||
engine.clear_spec_dec()
|
||||
```
|
||||
"""
|
||||
|
||||
if drafter_model is None and self.drafter is None:
|
||||
raise ValueError("Drafter not initialized. Please provide a Drafter Model")
|
||||
if n_spec_tokens is not None:
|
||||
assert 1 < n_spec_tokens <= self.inference_config.max_n_spec_tokens
|
||||
self.n_spec_tokens = n_spec_tokens
|
||||
if drafter_model is not None:
|
||||
assert isinstance(drafter_model, nn.Module)
|
||||
# overwrite the drafter, if exists
|
||||
self.clear_spec_dec()
|
||||
self.drafter_model = drafter_model
|
||||
self.drafter = Drafter(
|
||||
self.drafter_model,
|
||||
self.tokenizer,
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
# check if the provided drafter model is compatible with GLIDE structure
|
||||
# when `use_glide_drafter` is set to True
|
||||
if (
|
||||
use_glide_drafter
|
||||
and hasattr(drafter_model, "model")
|
||||
and hasattr(drafter_model.model, "layers")
|
||||
and hasattr(drafter_model.model.layers[0], "cross_attn")
|
||||
):
|
||||
self.use_glide = use_glide_drafter
|
||||
elif use_glide_drafter:
|
||||
self.logger.warning(
|
||||
f"`use_glide_drafter` is provided as {use_glide_drafter}, "
|
||||
f"but the provided drafter model is not compatible with GLIDE structure."
|
||||
f"Falling back to use the default drafter model (non-GLIDE)."
|
||||
)
|
||||
self.request_handler.set_spec_dec_mode(self.n_spec_tokens)
|
||||
# using speculative decoding for subsequent generations
|
||||
self.use_spec_dec = True
|
||||
|
||||
def disable_spec_dec(self) -> None:
|
||||
"""Disable using speculative decoding for subsequent generations."""
|
||||
self.request_handler.unset_spec_dec_mode()
|
||||
# set back to the maximum number of tokens to speculate
|
||||
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
|
||||
self.use_glide = False
|
||||
self.use_spec_dec = False
|
||||
|
||||
def clear_spec_dec(self) -> None:
|
||||
"""Clear relatable structures of speculative decoding, if exist."""
|
||||
if self.use_spec_dec:
|
||||
self.disable_spec_dec()
|
||||
if self.drafter_model or self.drafter:
|
||||
self.drafter_model = None
|
||||
self.drafter = None
|
||||
torch.cuda.empty_cache()
|
||||
self.use_glide = False
|
||||
self.use_spec_dec = False
|
||||
|
||||
def steps_spec_dec(self) -> List[Sequence]:
|
||||
"""
|
||||
Run speculative decoding steps. This retrieves a single batch and launches inference
|
||||
with multiple rounds of speculation by a drafter model and verification by the main model.
|
||||
|
||||
Returns:
|
||||
List[Sequence]: finished sequences generated by one step.
|
||||
"""
|
||||
batch = self.request_handler.schedule() # prefill batch
|
||||
assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
|
||||
|
||||
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
|
||||
|
||||
if input_meta_data.use_cuda_graph:
|
||||
model_executable = self.graph_runners[input_meta_data.batch_size]
|
||||
else:
|
||||
model_executable = self.model
|
||||
|
||||
# 1. Prefill small model (Drafter) - fill past kv cache for drafter model
|
||||
# NOTE For glide drafter models, we won't actually apply glide during prefill stage
|
||||
drafter_out = self.drafter.speculate(input_token_ids, 1, None)
|
||||
next_token_ids_spec = drafter_out.next_tokens
|
||||
drafter_past_key_values = drafter_out.past_key_values
|
||||
|
||||
# 2. Prefill main model (Verifier) - fill past kv cache for main model
|
||||
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
|
||||
next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
|
||||
# append new inputs to the batch, temporarily
|
||||
batch.append_batch_tokens(next_tokens)
|
||||
self.request_handler.allocate_batch_spec_dec(batch, 1)
|
||||
already_allocated_kv_len = batch.seq_lengths[0].item()
|
||||
input_token_ids = batch.get_1D_inputs_spec_dec(1)
|
||||
|
||||
finished_sequences = self.request_handler.update()
|
||||
|
||||
while True:
|
||||
# HACK Retrieve the running batch
|
||||
# Using RequestHandler.schedule here will re-allocate same kv cache for the batch
|
||||
batch = self.request_handler.running_bb # running batch
|
||||
assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
|
||||
|
||||
# 3. Decoding - Drafter model speculates `n` tokens
|
||||
glide_input = None
|
||||
if self.use_glide:
|
||||
glide_input = GlideInput(
|
||||
batch.get_block_table_tensor(),
|
||||
self.k_cache[-1],  # use kv caches of the last layer
|
||||
self.v_cache[-1],
|
||||
batch.get_sequence_lengths(),
|
||||
n_spec_tokens=self.n_spec_tokens,
|
||||
)
|
||||
|
||||
drafter_out = self.drafter.speculate(
|
||||
input_token_ids,
|
||||
self.n_spec_tokens,
|
||||
drafter_past_key_values,
|
||||
glide_input=glide_input,
|
||||
)
|
||||
next_token_ids_spec = drafter_out.next_tokens
|
||||
drafter_past_key_values = drafter_out.past_key_values
|
||||
drafter_spec_length = drafter_out.speculated_length
|
||||
|
||||
for next_token_id_spec in next_token_ids_spec:
|
||||
self.request_handler.append_next_tokens(next_token_id_spec.unsqueeze(0))
|
||||
cur_length = batch.seq_lengths[0].item()
|
||||
if already_allocated_kv_len < cur_length:
|
||||
self.request_handler.allocate_batch_spec_dec(batch, n=cur_length - already_allocated_kv_len)
|
||||
already_allocated_kv_len = cur_length
|
||||
|
||||
# 4. Decoding - Main model verifies `n` tokens in parallel
|
||||
if drafter_spec_length < batch.num_tokens_to_verify:
|
||||
batch.set_use_spec_dec(num_tokens_to_verify=drafter_spec_length)
|
||||
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
|
||||
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
|
||||
|
||||
next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
|
||||
|
||||
# 5. Compare and process the results
|
||||
diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
|
||||
n_matches = drafter_spec_length if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()
|
||||
|
||||
# revoke appended tokens for each Sequence in the current batch
|
||||
batch.revoke_batch_tokens(drafter_spec_length - n_matches) # revoke drafted tokens
|
||||
|
||||
# append the last correct token generated by the main model
|
||||
self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0))
|
||||
|
||||
# trim past key values of the drafter model
|
||||
drafter_past_key_values = Drafter.trim_kv_cache(
|
||||
drafter_past_key_values, drafter_spec_length - n_matches - 1
|
||||
)
|
||||
|
||||
# prepare inputs for the next round of speculation
|
||||
n = 1 if n_matches < drafter_spec_length else 2
|
||||
input_token_ids = batch.get_1D_inputs_spec_dec(n)
|
||||
|
||||
self.request_handler.update_batch_finished(batch, generation_config=self.generation_config)
|
||||
finished_sequences = self.request_handler.update()
|
||||
if len(finished_sequences) > 0:
|
||||
break
|
||||
|
||||
# Reset back the number of speculated tokens of the batch,
|
||||
# this is used to handle the last round of speculation, in which case the number of speculated tokens
|
||||
# by the drafter is less than the number of speculated tokens set to the engine.
|
||||
batch.set_use_spec_dec(num_tokens_to_verify=self.n_spec_tokens)
|
||||
|
||||
return finished_sequences
|
||||
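# Illustrative sketch of the accept/reject arithmetic in step 5 above, with hypothetical token
# values (not taken from a real run): the drafter proposes tokens, the main model re-predicts
# them in parallel, and n_matches is the length of the longest matching prefix.
import torch

next_token_ids_spec = torch.tensor([5, 8, 8, 2])  # drafter proposals (drafter_spec_length = 4)
next_tokens = torch.tensor([5, 8, 9, 7, 3])       # verifier outputs (last one is the bonus token)
diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
n_matches = 4 if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()
assert n_matches == 2  # two drafted tokens accepted; the rest are revoked and next_tokens[2] is appended instead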
|
||||
def generate(
|
||||
self,
|
||||
request_ids: Union[List[int], int] = None,
|
||||
prompts: Union[List[str], str] = None,
|
||||
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
|
||||
return_token_ids: bool = False,
|
||||
generation_config: Optional[GenerationConfig] = None,
|
||||
) -> Union[List[str], Tuple[List[str], List[List[int]]]]:
|
||||
"""
|
||||
Executing the inference step.
|
||||
|
||||
Args:
|
||||
request_ids (List[int], optional): The request ID. Defaults to None.
|
||||
prompts (Union[List[str], str], optional): Input prompts. Defaults to None.
|
||||
prompts_token_ids (Union[List[int], torch.Tensor, np.ndarray], optional): token ids of input prompts. Defaults to None.
|
||||
return_token_ids (bool, optional): Whether to return output token ids. Defaults to False.
|
||||
generation_config (Optional[GenerationConfig], optional): Huggingface GenerationConfig used for inference. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Union[List[str], Tuple[List[str], List[List[int]]]]: Inference result returned by one generation.
|
||||
"""
|
||||
|
||||
gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
|
||||
prompts = [prompts] if isinstance(prompts, str) else prompts
|
||||
request_ids = [request_ids] if isinstance(request_ids, int) else request_ids
|
||||
|
||||
with torch.inference_mode():
|
||||
if prompts is not None or prompts_token_ids is not None:
|
||||
self.add_request(
|
||||
request_ids=request_ids,
|
||||
prompts=prompts,
|
||||
prompts_token_ids=prompts_token_ids,
|
||||
**gen_config_dict,
|
||||
)
|
||||
|
||||
output_seqs_list = []
|
||||
total_tokens_list = []
|
||||
|
||||
# If the user provides a generation config, replace the existing one.
|
||||
if generation_config is not None:
|
||||
self.generation_config = generation_config
|
||||
self.generation_config_dict = gen_config_dict
|
||||
|
||||
if self.use_spec_dec:
|
||||
assert self.drafter is not None, "Drafter Model is not initialized."
|
||||
while self.request_handler.check_unfinished_reqs():
|
||||
output_seqs_list += self.steps_spec_dec()
|
||||
else:
|
||||
while self.request_handler.check_unfinished_reqs():
|
||||
output_seqs_list += self.step()
|
||||
|
||||
output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))
|
||||
|
||||
for seq in output_seqs_list:
|
||||
total_tokens_list.append(seq.input_token_id + seq.output_token_id)
|
||||
|
||||
output_str = self.tokenizer.batch_decode(total_tokens_list, skip_special_tokens=True)
|
||||
|
||||
if return_token_ids:
|
||||
output_tokens_list = [seq.output_token_id for seq in output_seqs_list]
|
||||
return output_str, output_tokens_list
|
||||
else:
|
||||
return output_str
|
||||
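# Hypothetical usage sketch (checkpoint path and config values below are assumptions for
# illustration, not part of this file):
# tokenizer = AutoTokenizer.from_pretrained("path/to/llama-checkpoint")
# engine = LLMEngine(model_or_path="path/to/llama-checkpoint", tokenizer=tokenizer,
#                    inference_config=InferenceConfig(max_batch_size=8))
# outputs = engine.generate(prompts=["Hello, my name is"],
#                           generation_config=GenerationConfig(max_new_tokens=32))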
|
||||
@property
|
||||
def has_prompt_template(self) -> bool:
|
||||
""" """
|
||||
return self.inference_config.prompt_template is not None
|
||||
|
||||
def format_prompt(self, prompts: Union[List[str], str]) -> Union[List[str], str]:
|
||||
"""
|
||||
This method will format the input prompt according to the prompt template given to the InferenceConfig.
|
||||
"""
|
||||
assert (
|
||||
self.has_prompt_template
|
||||
), "Found the prompt_template is None. Please provide a valid prompt_template in InferenceConfig."
|
||||
|
||||
if isinstance(prompts, (list, tuple)):
|
||||
return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts]
|
||||
elif isinstance(prompts, str):
|
||||
return self.inference_config.prompt_template.format(input_text=prompts)
|
||||
else:
|
||||
raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.")
|
||||
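# Illustrative example of what format_prompt produces (the template string below is an
# assumption for demonstration, not a default shipped with InferenceConfig):
# template = "[INST] {input_text} [/INST]"
# template.format(input_text="Who are you?")  ->  "[INST] Who are you? [/INST]"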
|
||||
def add_request(
|
||||
self,
|
||||
request_ids: Union[List[int], int] = None,
|
||||
prompts: Union[List[str], str] = None,
|
||||
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Add requests.
|
||||
|
||||
Args:
|
||||
request_ids (List[int], optional): The request ID. Defaults to None.
|
||||
prompts (Union[List[str], str], optional): Input prompts. Defaults to None.
|
||||
prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
|
||||
"""
|
||||
|
||||
# apply the prompt template to the input prompts
|
||||
|
||||
if self.has_prompt_template and prompts is not None:
|
||||
prompts = self.format_prompt(prompts)
|
||||
|
||||
block_size = self.inference_config.block_size
|
||||
|
||||
if request_ids is not None and not isinstance(request_ids, list):
|
||||
request_ids = [request_ids]
|
||||
|
||||
if prompts is not None and not isinstance(prompts, list):
|
||||
prompts = [prompts]
|
||||
|
||||
if prompts_token_ids is None:
|
||||
assert prompts, "When prompts_token_ids is None, the input prompt list must be provided."
|
||||
prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[
|
||||
"input_ids"
|
||||
]
|
||||
|
||||
# list of torch Tensor
|
||||
if isinstance(prompts_token_ids, list):
|
||||
if isinstance(prompts_token_ids[0], torch.Tensor):
|
||||
prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
|
||||
elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
|
||||
prompts_token_ids = prompts_token_ids.tolist()
|
||||
else:
|
||||
raise TypeError(
|
||||
f"The dtype of prompts_token_ids must be one of list, torch.Tensor, np.ndarray, but got {type(prompts_token_ids)}."
|
||||
)
|
||||
|
||||
assert (
|
||||
len(prompts_token_ids[0]) <= self.inference_config.max_input_len
|
||||
), f"The length of input prompts {len(prompts_token_ids[0])} must be less than max_input_len {self.inference_config.max_input_len}."
|
||||
|
||||
prompts_num = len(prompts_token_ids)
|
||||
|
||||
for i in range(prompts_num):
|
||||
if request_ids:
|
||||
assert isinstance(
|
||||
request_ids[0], int
|
||||
), f"The request_id type must be int, but got {type(request_ids[0])}"
|
||||
assert len(request_ids) == prompts_num
|
||||
request_id = request_ids[i]
|
||||
else:
|
||||
request_id = next(self.counter)
|
||||
if prompts is None:
|
||||
prompt = None
|
||||
else:
|
||||
prompt = prompts[i]
|
||||
|
||||
max_length = kwargs.get("max_length", None)
|
||||
max_new_tokens = kwargs.get("max_new_tokens", None)
|
||||
if max_length is None and max_new_tokens is None:
|
||||
max_new_tokens = self.generation_config.max_new_tokens or self.inference_config.max_output_len
|
||||
elif max_length is not None:
|
||||
max_new_tokens = max_length - len(prompts_token_ids[i])
|
||||
|
||||
if not self.inference_config.enable_streamingllm:
|
||||
assert (
|
||||
self.inference_config.max_output_len >= max_new_tokens
|
||||
), f"max_new_tokens={max_new_tokens} must be less than max_output_len={self.inference_config.max_output_len}."
|
||||
|
||||
sequence = Sequence(
|
||||
request_id,
|
||||
prompt,
|
||||
prompts_token_ids[i],
|
||||
block_size,
|
||||
None,
|
||||
self.tokenizer.eos_token_id,
|
||||
self.tokenizer.pad_token_id,
|
||||
max_output_len=max_new_tokens,
|
||||
ignore_eos=self.inference_config.ignore_eos,
|
||||
)
|
||||
self.request_handler.add_sequence(sequence)
|
||||
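# Illustrative arithmetic for the max_new_tokens resolution in add_request above (hypothetical values):
# - max_length=128 with a 100-token prompt  ->  max_new_tokens = 128 - 100 = 28
# - neither max_length nor max_new_tokens given  ->  fall back to
#   generation_config.max_new_tokens, or inference_config.max_output_len if that is unset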
|
||||
def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]:
|
||||
input_ids = batch.get_1D_inputs()
|
||||
sequence_lengths = batch.get_sequence_lengths()
|
||||
|
||||
if batch.is_prompts:
|
||||
n_tokens = sequence_lengths.sum().item()
|
||||
else:
|
||||
n_tokens = batch.current_batch_size
|
||||
if batch.use_spec_dec:
|
||||
n_tokens = batch.num_tokens_to_verify + 1
|
||||
assert n_tokens == input_ids.size(0)
|
||||
n_tokens = n_tokens * batch.current_batch_size
|
||||
output_tensor = torch.zeros(
|
||||
(n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
|
||||
)
|
||||
|
||||
batch_token_ids = None
|
||||
if (
|
||||
self.generation_config.repetition_penalty != 1.0
|
||||
or self.generation_config.no_repeat_ngram_size > 0
|
||||
or self.generation_config.forced_eos_token_id is not None
|
||||
):
|
||||
batch_token_ids = batch.batch_token_ids
|
||||
|
||||
# only when we have the graph for specific decoding batch size can we use the cuda graph for inference
|
||||
use_cuda_graph = False
|
||||
if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys():
|
||||
use_cuda_graph = True
|
||||
|
||||
input_meta_data = InputMetaData(
|
||||
block_tables=batch.get_block_table_tensor(),
|
||||
sequence_lengths=sequence_lengths,
|
||||
fd_inter_tensor=batch.fd_inter_tensor,
|
||||
batch_size=batch.current_batch_size,
|
||||
is_prompts=batch.is_prompts,
|
||||
use_cuda_kernel=self.inference_config.use_cuda_kernel,
|
||||
use_cuda_graph=use_cuda_graph,
|
||||
high_precision=self.high_precision,
|
||||
kv_seq_len=sequence_lengths.max().item(),
|
||||
head_dim=batch.head_dim,
|
||||
dtype=batch.dtype,
|
||||
use_spec_dec=batch.use_spec_dec,
|
||||
num_tokens_to_verify=batch.num_tokens_to_verify,
|
||||
batch_token_ids=batch_token_ids,
|
||||
)
|
||||
|
||||
return input_ids, output_tensor, input_meta_data
|
||||
|
||||
def step(self) -> List[str]:
|
||||
"""
|
||||
In each step, do the following:
|
||||
1. Run RequestHandler.schedule() and get the batch used for inference.
|
||||
2. Get the input token ids, input metadata, and output placeholder from the BatchBucket.
|
||||
3. Run model to generate the next token
|
||||
4. Update waiting list and running list in RequestHandler and get finished sequences.
|
||||
5. Decode and return finished sequences.
|
||||
|
||||
Returns:
|
||||
List[str]: Decoded finished sequences generated by one step.
|
||||
"""
|
||||
|
||||
batch = self.request_handler.schedule()
|
||||
|
||||
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
|
||||
|
||||
if input_meta_data.use_cuda_graph:
|
||||
model_executable = self.graph_runners[input_meta_data.batch_size]
|
||||
else:
|
||||
model_executable = self.model
|
||||
|
||||
# TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
|
||||
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
|
||||
if self.inference_config.pad_input:
|
||||
logits = logits[:, -1, :]
|
||||
|
||||
if self.inference_config.enable_streamingllm:
|
||||
updated_block_ids = batch.streamingllm_update_batch(
|
||||
self.inference_config.start_token_size, self.inference_config.generated_token_size
|
||||
)
|
||||
self.request_handler.streamingllm_free_block_tables(updated_block_ids)
|
||||
|
||||
next_tokens = search_tokens(
|
||||
self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
|
||||
)
|
||||
self.request_handler.append_next_tokens(next_tokens)
|
||||
finished_sequences = self.request_handler.update()
|
||||
|
||||
return finished_sequences
|
|
@ -8,7 +8,7 @@ from colossalai.inference.batch_bucket import BatchBucket
|
|||
from colossalai.inference.config import InferenceConfig
|
||||
from colossalai.inference.flash_decoding_utils import FDIntermTensors
|
||||
from colossalai.inference.kv_cache import KVCacheManager, RPCKVCacheManager
|
||||
from colossalai.inference.struct import RequestStatus, Sequence
|
||||
from colossalai.inference.struct import DiffusionSequence, RequestStatus, Sequence
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
logger = get_dist_logger(__name__)
|
||||
|
@ -98,7 +98,46 @@ class RunningList:
|
|||
self._decoding[seq_id] = self._prefill.pop(seq_id)
|
||||
|
||||
|
||||
class RequestHandler:
|
||||
class NaiveRequestHandler:
|
||||
def __init__(self) -> None:
|
||||
self.running_list: List[DiffusionSequence] = []
|
||||
self.waiting_list: List[DiffusionSequence] = []
|
||||
|
||||
def _has_waiting(self) -> bool:
|
||||
return any(lst for lst in self.waiting_list)
|
||||
|
||||
def _has_running(self) -> bool:
|
||||
return any(lst for lst in self.running_list)
|
||||
|
||||
def check_unfinished_reqs(self):
|
||||
return self._has_waiting() or self._has_running()
|
||||
|
||||
def add_sequence(self, seq: DiffusionSequence):
|
||||
"""
|
||||
Add the request to waiting list.
|
||||
"""
|
||||
assert not self._find_sequence(seq.request_id), f"Sequence {seq.request_id} already exists."
|
||||
self.waiting_list.append(seq)
|
||||
|
||||
def _find_sequence(self, request_id: int) -> DiffusionSequence:
|
||||
"""
|
||||
Find the request by request_id.
|
||||
"""
|
||||
for seq in self.waiting_list + self.running_list:
|
||||
if seq.request_id == request_id:
|
||||
return seq
|
||||
return None
|
||||
|
||||
def schedule(self):
|
||||
ret = None
|
||||
if self._has_waiting():
|
||||
ret = self.waiting_list[0]
|
||||
self.waiting_list = self.waiting_list[1:]
|
||||
return ret
|
||||
|
||||
|
||||
class RequestHandler(NaiveRequestHandler):
|
||||
"""
|
||||
RequestHandler is the core for handling existing requests and updating current batch.
|
||||
During generation process, we call schedule function each iteration to update current batch.
|
||||
|
@ -176,12 +215,12 @@ class RequestHandler:
|
|||
generated_token_size=inference_config.generated_token_size,
|
||||
)
|
||||
|
||||
def _has_running(self) -> bool:
|
||||
return not self.running_bb.is_empty()
|
||||
|
||||
def _init_cache(self, model_config):
|
||||
self.cache_manager = KVCacheManager(self.inference_config, model_config)
|
||||
|
||||
def _has_waiting(self) -> bool:
|
||||
return any(lst for lst in self.waiting_list)
|
||||
|
||||
def get_kvcache(self):
|
||||
return self.cache_manager.get_kv_cache()
|
||||
|
||||
|
@ -318,7 +357,7 @@ class RequestHandler:
|
|||
if seq.output_token_id[-1] == generation_config.eos_token_id or seq.output_len >= max_new_tokens:
|
||||
seq.mark_finished()
|
||||
|
||||
def check_unfinished_seqs(self) -> bool:
|
||||
def check_unfinished_reqs(self) -> bool:
|
||||
return self._has_waiting() or not self.running_list.is_empty()
|
||||
|
||||
def total_requests_in_batch_bucket(self) -> int:
|
||||
|
|
|
@ -34,7 +34,6 @@ def run_server(host, port, event: mp.Event = None):
|
|||
|
||||
|
||||
class RPCInferenceEngine(InferenceEngine):
|
||||
|
||||
"""
|
||||
InferenceEngine which manages the inference process.
|
||||
|
||||
|
|
|
@ -42,7 +42,6 @@ logger = get_dist_logger(__name__)
|
|||
|
||||
|
||||
class rpcWorkerService(rpyc.Service):
|
||||
|
||||
"""
|
||||
Executes computation tasks and manages its own KV cache.
|
||||
|
||||
|
|
|
@ -279,9 +279,11 @@ class KVCacheManager:
|
|||
block.add_ref()
|
||||
self._allocate_on_block(
|
||||
block,
|
||||
block.block_size
|
||||
if context_lengths[i] % block.block_size == 0
|
||||
else context_lengths[i].item() % block.block_size,
|
||||
(
|
||||
block.block_size
|
||||
if context_lengths[i] % block.block_size == 0
|
||||
else context_lengths[i].item() % block.block_size
|
||||
),
|
||||
)
|
||||
for block_id in alloc_block_ids:
|
||||
if block_id in alloc_block_ids[last_block_locs]:
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
import inspect
|
||||
import types
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class DiffusionPipe(nn.Module):
|
||||
"""
|
||||
This class converts a `DiffusionPipeline` into an `nn.Module` while preserving most of the original attributes, functions, and properties.
|
||||
"""
|
||||
|
||||
def __init__(self, source_obj) -> None:
|
||||
super(DiffusionPipe, self).__init__()
|
||||
|
||||
for k, v in source_obj.__dict__.items():
|
||||
if isinstance(v, nn.Module):
|
||||
self.add_module(k, v)
|
||||
else:
|
||||
setattr(self, k, v)
|
||||
|
||||
skip_list = ["_execution_device", "to", "device"] # this
|
||||
|
||||
for name, member in inspect.getmembers(source_obj.__class__):
|
||||
if name in skip_list:
|
||||
continue
|
||||
if not name.startswith("__") and not name.endswith("__"):
|
||||
if isinstance(member, property):
|
||||
setattr(self.__class__, name, member)
|
||||
elif inspect.isfunction(member) or inspect.ismethod(member):
|
||||
bound_method = types.MethodType(member, self)
|
||||
setattr(self, name, bound_method)
|
||||
elif not callable(member) and not isinstance(member, property):
|
||||
setattr(self, name, member)
|
||||
elif name == "__call__":
|
||||
bound_method = types.MethodType(member, self)
|
||||
setattr(self, "_forward", bound_method)
|
||||
|
||||
@property
|
||||
def _execution_device(self):
|
||||
r"""
|
||||
Returns the device on which the pipeline's models will be executed. After calling
|
||||
[`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from
|
||||
Accelerate's module hooks.
|
||||
"""
|
||||
# return self.device
|
||||
return torch.device("cuda")
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self._forward(*args, **kwargs)
|
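# Hypothetical usage sketch (pipeline class and checkpoint path are assumptions for
# illustration): wrap a diffusers pipeline so the engine can treat it as a plain nn.Module.
# from diffusers import PixArtAlphaPipeline
# pipe = PixArtAlphaPipeline.from_pretrained("path/to/pixart-checkpoint", torch_dtype=torch.float16)
# module = DiffusionPipe(pipe).to("cuda")
# images = module(prompt="an astronaut riding a horse")  # dispatches to the bound __call__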
|
@ -0,0 +1,220 @@
|
|||
# Code adapted from:
|
||||
# https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
|
||||
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import PIL.Image
|
||||
import torch
|
||||
from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha import (
|
||||
ASPECT_RATIO_256_BIN,
|
||||
ASPECT_RATIO_512_BIN,
|
||||
ASPECT_RATIO_1024_BIN,
|
||||
)
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
from .diffusion import DiffusionPipe
|
||||
|
||||
logger = get_dist_logger(__name__)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def pixart_alpha_forward(
|
||||
self: DiffusionPipe,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: str = "",
|
||||
num_inference_steps: int = 20,
|
||||
timesteps: List[int] = None,
|
||||
sigmas: List[float] = None,
|
||||
guidance_scale: float = 4.5,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
clean_caption: bool = True,
|
||||
use_resolution_binning: bool = True,
|
||||
max_sequence_length: int = 120,
|
||||
**kwargs,
|
||||
) -> PIL.Image:
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
height = height or self.transformer.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.transformer.config.sample_size * self.vae_scale_factor
|
||||
if use_resolution_binning:
|
||||
if self.transformer.config.sample_size == 128:
|
||||
aspect_ratio_bin = ASPECT_RATIO_1024_BIN
|
||||
elif self.transformer.config.sample_size == 64:
|
||||
aspect_ratio_bin = ASPECT_RATIO_512_BIN
|
||||
elif self.transformer.config.sample_size == 32:
|
||||
aspect_ratio_bin = ASPECT_RATIO_256_BIN
|
||||
else:
|
||||
raise ValueError("Invalid sample size")
|
||||
orig_height, orig_width = height, width
|
||||
height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
|
||||
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt,
|
||||
callback_steps,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_attention_mask,
|
||||
)
|
||||
|
||||
# 2. Default height and width to transformer
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
(
|
||||
prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_embeds,
|
||||
negative_prompt_attention_mask,
|
||||
) = self.encode_prompt(
|
||||
prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
clean_caption=clean_caption,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, sigmas)
|
||||
|
||||
# 5. Prepare latents.
|
||||
latent_channels = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
latent_channels,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 6.1 Prepare micro-conditions.
|
||||
added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
|
||||
if self.transformer.config.sample_size == 128:
|
||||
resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
|
||||
aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
|
||||
resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
|
||||
aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
resolution = torch.cat([resolution, resolution], dim=0)
|
||||
aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
|
||||
|
||||
added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
|
||||
|
||||
# 7. Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
current_timestep = t
|
||||
if not torch.is_tensor(current_timestep):
|
||||
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
||||
# This would be a good case for the `match` statement (Python 3.10+)
|
||||
is_mps = latent_model_input.device.type == "mps"
|
||||
if isinstance(current_timestep, float):
|
||||
dtype = torch.float32 if is_mps else torch.float64
|
||||
else:
|
||||
dtype = torch.int32 if is_mps else torch.int64
|
||||
current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
|
||||
elif len(current_timestep.shape) == 0:
|
||||
current_timestep = current_timestep[None].to(latent_model_input.device)
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
current_timestep = current_timestep.expand(latent_model_input.shape[0])
|
||||
|
||||
# predict noise model_output
|
||||
noise_pred = self.transformer(
|
||||
latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
encoder_attention_mask=prompt_attention_mask,
|
||||
timestep=current_timestep,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# learned sigma
|
||||
if self.transformer.config.out_channels // 2 == latent_channels:
|
||||
noise_pred = noise_pred.chunk(2, dim=1)[0]
|
||||
else:
|
||||
noise_pred = noise_pred
|
||||
|
||||
# compute previous image: x_t -> x_t-1
|
||||
if num_inference_steps == 1:
|
||||
# For DMD one step sampling: https://arxiv.org/abs/2311.18828
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).pred_original_sample
|
||||
else:
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
output_type = "pil"  # TODO(@lry89757): only PIL image output is supported for now; support more output types later
|
||||
if not output_type == "latent":
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
if use_resolution_binning:
|
||||
image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
|
||||
else:
|
||||
image = latents
|
||||
|
||||
if not output_type == "latent":
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
# self.maybe_free_model_hooks()
|
||||
|
||||
return image
|
|
@ -0,0 +1,178 @@
|
|||
# This code is adapted from huggingface diffusers: https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
|
||||
|
||||
from .diffusion import DiffusionPipe
|
||||
|
||||
|
||||
# TODO(@lry89757): only PIL image output is supported for now; support more output types later
|
||||
@torch.no_grad()
|
||||
def sd3_forward(
|
||||
self: DiffusionPipe,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
prompt_3: Optional[Union[str, List[str]]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 28,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 7.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt_3: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
):
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
prompt_2,
|
||||
prompt_3,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=negative_prompt,
|
||||
negative_prompt_2=negative_prompt_2,
|
||||
negative_prompt_3=negative_prompt_3,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._clip_skip = clip_skip
|
||||
self._joint_attention_kwargs = joint_attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_2=prompt_2,
|
||||
prompt_3=prompt_3,
|
||||
negative_prompt=negative_prompt,
|
||||
negative_prompt_2=negative_prompt_2,
|
||||
negative_prompt_3=negative_prompt_3,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
device=device,
|
||||
clip_skip=self.clip_skip,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0])
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
pooled_projections=pooled_prompt_embeds,
|
||||
joint_attention_kwargs=self.joint_attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
negative_pooled_prompt_embeds = callback_outputs.pop(
|
||||
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
|
||||
)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
|
||||
else:
|
||||
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
||||
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
return image
|
|
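As a side note, here is a minimal toy sketch (not part of this diff) of the classifier-free guidance step shared by both forwards above, where the batch stacks the unconditional and text-conditioned halves:

import torch

guidance_scale = 7.0
noise_pred = torch.randn(2, 4, 8, 8)                      # [uncond, text] stacked along the batch dim
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)  # split the two halves
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(noise_pred.shape)  # torch.Size([1, 4, 8, 8])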
@ -1,16 +1,22 @@
|
|||
from .glide_llama import GlideLlamaModelPolicy
|
||||
from .nopadding_baichuan import NoPaddingBaichuanModelInferPolicy
|
||||
from .nopadding_llama import NoPaddingLlamaModelInferPolicy
|
||||
from .pixart_alpha import PixArtAlphaInferPolicy
|
||||
from .stablediffusion3 import StableDiffusion3InferPolicy
|
||||
|
||||
model_policy_map = {
|
||||
"nopadding_llama": NoPaddingLlamaModelInferPolicy,
|
||||
"nopadding_baichuan": NoPaddingBaichuanModelInferPolicy,
|
||||
"glide_llama": GlideLlamaModelPolicy,
|
||||
"StableDiffusion3Pipeline": StableDiffusion3InferPolicy,
|
||||
"PixArtAlphaPipeline": PixArtAlphaInferPolicy,
|
||||
}
|
||||
|
||||
__all__ = [
|
||||
"NoPaddingLlamaModelInferPolicy",
|
||||
"NoPaddingBaichuanModelInferPolicy",
|
||||
"GlideLlamaModelPolicy",
|
||||
"StableDiffusion3InferPolicy",
|
||||
"PixArtAlphaInferPolicy",
|
"model_policy_map",
|
||||
]
|
||||
|
|
|
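A hypothetical usage sketch of the registry above (the import path is an assumption and is not shown in this diff): resolving an inference policy class by the pipeline class name.

from colossalai.inference.modeling.policy import model_policy_map  # assumed import path

policy_cls = model_policy_map.get("PixArtAlphaPipeline")  # keyed by pipeline class name
if policy_cls is not None:
    policy = policy_cls()  # an instance of PixArtAlphaInferPolicy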
@ -0,0 +1,34 @@
|
|||
from torch import nn
|
||||
|
||||
from colossalai.inference.config import RPC_PARAM
|
||||
from colossalai.inference.modeling.models.diffusion import DiffusionPipe
|
||||
from colossalai.inference.modeling.models.pixart_alpha import pixart_alpha_forward
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
|
||||
|
||||
class PixArtAlphaInferPolicy(Policy, RPC_PARAM):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
policy = {}
|
||||
self.append_or_create_method_replacement(
|
||||
description={"forward": pixart_alpha_forward}, policy=policy, target_key=DiffusionPipe
|
||||
)
|
||||
return policy
|
||||
|
||||
def preprocess(self) -> nn.Module:
|
||||
return self.model
|
||||
|
||||
def postprocess(self):
|
||||
return self.model
|
||||
|
||||
def config_sanity_check(self):
|
||||
pass
|
||||
|
||||
def to_rpc_param(self) -> str:
|
||||
return __class__.__name__
|
||||
|
||||
@staticmethod
|
||||
def from_rpc_param() -> "PixArtAlphaInferPolicy":
|
||||
return PixArtAlphaInferPolicy()
|
|
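A minimal sketch (not part of the diff; the import path is assumed) of the RPC round trip implemented by to_rpc_param/from_rpc_param: the policy travels as its class name and is re-instantiated on the receiving side.

from colossalai.inference.modeling.policy.pixart_alpha import PixArtAlphaInferPolicy  # assumed path

name = PixArtAlphaInferPolicy().to_rpc_param()    # "PixArtAlphaInferPolicy"
policy = PixArtAlphaInferPolicy.from_rpc_param()  # fresh instance on the worker side
assert name == type(policy).__name__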
@ -0,0 +1,34 @@
|
|||
from torch import nn
|
||||
|
||||
from colossalai.inference.config import RPC_PARAM
|
||||
from colossalai.inference.modeling.models.diffusion import DiffusionPipe
|
||||
from colossalai.inference.modeling.models.stablediffusion3 import sd3_forward
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
|
||||
|
||||
class StableDiffusion3InferPolicy(Policy, RPC_PARAM):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
policy = {}
|
||||
self.append_or_create_method_replacement(
|
||||
description={"forward": sd3_forward}, policy=policy, target_key=DiffusionPipe
|
||||
)
|
||||
return policy
|
||||
|
||||
def preprocess(self) -> nn.Module:
|
||||
return self.model
|
||||
|
||||
def postprocess(self):
|
||||
return self.model
|
||||
|
||||
def config_sanity_check(self):
|
||||
pass
|
||||
|
||||
def to_rpc_param(self) -> str:
|
||||
return __class__.__name__
|
||||
|
||||
@staticmethod
|
||||
def from_rpc_param() -> "StableDiffusion3InferPolicy":
|
||||
return StableDiffusion3InferPolicy()
|
|
@ -2,6 +2,7 @@ import enum
|
|||
from dataclasses import dataclass
|
||||
from typing import Any, List
|
||||
|
||||
from colossalai.inference.config import DiffusionGenerationConfig
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
logger = get_dist_logger(__name__)
|
||||
|
@ -46,6 +47,17 @@ class RequestStatus(enum.Enum):
|
|||
return status == RequestStatus.WAITING
|
||||
|
||||
|
||||
@dataclass
|
||||
class DiffusionSequence:
|
"""
Parameters of a diffusion generation request.
"""
|
||||
|
||||
request_id: int
|
||||
prompt: str
|
||||
generation_config: DiffusionGenerationConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class Sequence:
|
||||
"""Store information of input sequence.
|
||||
|
|
|
@ -1,13 +1,16 @@
|
|||
"""
|
||||
Utils for model inference
|
||||
"""
|
||||
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
from torch import nn
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
@ -158,3 +161,36 @@ def can_use_flash_attn2(dtype: torch.dtype) -> bool:
|
|||
except ImportError:
|
||||
logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
|
||||
return False
|
||||
|
||||
|
||||
class ModelType(Enum):
|
||||
DIFFUSION_MODEL = "Diffusion Model"
|
||||
LLM = "Large Language Model (LLM)"
|
||||
UNKNOWN = "Unknown Model Type"
|
||||
|
||||
|
||||
def get_model_type(model_or_path: Union[nn.Module, str, DiffusionPipeline]):
    if isinstance(model_or_path, DiffusionPipeline):
        return ModelType.DIFFUSION_MODEL
    elif isinstance(model_or_path, nn.Module):
        return ModelType.LLM
    elif isinstance(model_or_path, str):
        try:
            from transformers import AutoConfig

            AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
            return ModelType.LLM
        except Exception:
            # The path cannot be loaded as a transformers config, so it is not `ModelType.LLM`.
            pass

        try:
            DiffusionPipeline.load_config(model_or_path)
            return ModelType.DIFFUSION_MODEL
        except Exception:
            # The path cannot be loaded as a diffusers config either, so it is not `ModelType.DIFFUSION_MODEL`.
            pass

        return ModelType.UNKNOWN
    else:
        return ModelType.UNKNOWN
|
||||
|
|
|
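A hypothetical usage sketch of get_model_type (the import path and the model id are assumptions used only for illustration): routing a request based on the detected model type.

from colossalai.inference.utils import ModelType, get_model_type  # assumed path

model_type = get_model_type("some-org/some-diffusion-or-llm-checkpoint")  # hypothetical model id
if model_type is ModelType.DIFFUSION_MODEL:
    print("route to the diffusion inference path")
elif model_type is ModelType.LLM:
    print("route to the LLM inference path")
else:
    print("unsupported model type")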
@ -3,6 +3,12 @@
|
|||
|
||||
import os
|
||||
|
||||
# set CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that when overlapping communication and computation,
|
||||
# the order of kernel launches on GPUs is the same as on the CPU so that comm is launched first.
|
||||
# see https://github.com/NVIDIA/Megatron-LM/issues/533
|
||||
# https://forums.developer.nvidia.com/t/how-many-streams-maximum-number-of-streams/6571/16
|
||||
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
|
||||
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
|
|
|
@ -138,9 +138,7 @@ class Initializer_2D(ProcessGroupInitializer):
|
|||
self.num_group = self.world_size // self.tensor_parallel_size
|
||||
self.summa_dim = int(math.sqrt(self.tensor_parallel_size))
|
||||
|
||||
assert (
|
||||
self.tensor_parallel_size == self.summa_dim**2
|
||||
), "2D summa dim should equal to tensor parallel size ^ 0.5"
|
||||
assert self.tensor_parallel_size == self.summa_dim**2, "2D summa dim should equal to tensor parallel size ^ 0.5"
|
||||
_check_summa_env_var(self.summa_dim)
|
||||
|
||||
self.col_initializer = Initializer_2D_Col(self.num_group, self.summa_dim, *args, **kwargs)
|
||||
|
|
|
@ -54,7 +54,6 @@ class RequestTracker:
|
|||
|
||||
|
||||
class Async_Engine:
|
||||
|
||||
"""
|
||||
Use an engine to launch RAY Driver --> RAY Worker --> Async_Manager
|
||||
Background loop: inference reqs in waiting list (Listen)
|
||||
|
|
|
@ -118,16 +118,16 @@ class Batch:
|
|||
|
||||
class BatchTokenIdOut:
|
||||
def __init__(self):
|
||||
self.reqs_infs: List[
|
||||
Tuple[str, int, Dict, bool, bool]
|
||||
] = [] # [req_id, new_token_id, gen_metadata, finished_state, abort_state]
|
||||
self.reqs_infs: List[Tuple[str, int, Dict, bool, bool]] = (
|
||||
[]
|
||||
) # [req_id, new_token_id, gen_metadata, finished_state, abort_state]
|
||||
|
||||
|
||||
class BatchStrOut:
|
||||
def __init__(self):
|
||||
self.reqs_infs: List[
|
||||
Tuple[str, str, Dict, bool, bool]
|
||||
] = [] # [req_id, token_str, gen_metadata, finished_state, abort_state]
|
||||
self.reqs_infs: List[Tuple[str, str, Dict, bool, bool]] = (
|
||||
[]
|
||||
) # [req_id, token_str, gen_metadata, finished_state, abort_state]
|
||||
|
||||
|
||||
class AbortReq:
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
Utils for model inference
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
|
|
@ -14,6 +14,7 @@ class BatchInferState:
|
|||
Information to be passed and used for a batch of inputs during
|
||||
a single model forward
|
||||
"""
|
||||
|
||||
batch_size: int
|
||||
max_len_in_batch: int
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@ of the ModelTC/lightllm GitHub repository
|
|||
https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py
|
||||
we slightly changed it to make it suitable for our colossal-ai shardformer TP-engine design.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from transformers.utils import logging
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
Utils for model inference
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
|
|
@ -81,7 +81,6 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
|
|||
handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True)
|
||||
# Delay the start of weight gradient computation shortly (3us) to have
|
||||
# all-reduce scheduled first and have GPU resources allocated
|
||||
_ = torch.empty(1, device=grad_output.device) + 1
|
||||
|
||||
grad_weight = grad_output.t().matmul(total_input)
|
||||
grad_bias = grad_output.sum(dim=0) if use_bias else None
|
||||
|
|
|
@ -1,20 +1,5 @@
|
|||
from .checkpoint import MoECheckpointIO
from .experts import MLPExperts
from .layers import SparseMLP, apply_load_balance
from .manager import MOE_MANAGER
from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter
from .utils import NormalNoiseGenerator, UniformNoiseGenerator

__all__ = [
    "MLPExperts",
    "MoeRouter",
    "Top1Router",
    "Top2Router",
    "TopKRouter",
    "NormalNoiseGenerator",
    "UniformNoiseGenerator",
    "SparseMLP",
    "MoECheckpointIO",
    "MOE_MANAGER",
    "apply_load_balance",
]
|
||||
|
|
|
@ -1,792 +0,0 @@
|
|||
import copy
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from shutil import rmtree
|
||||
from typing import Dict, Iterator, Optional, OrderedDict, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.checkpoint_io import CheckpointIndexFile, HybridParallelCheckpointIO
|
||||
from colossalai.checkpoint_io.utils import (
|
||||
StateDictSharder,
|
||||
gather_distributed_param,
|
||||
get_model_base_filenames,
|
||||
get_optimizer_base_filenames,
|
||||
is_safetensors_available,
|
||||
load_shard_state_dict,
|
||||
load_state_dict,
|
||||
load_state_dict_into_model,
|
||||
load_states_into_optimizer,
|
||||
save_config_file,
|
||||
save_param_groups,
|
||||
save_state_dict,
|
||||
save_state_dict_shards,
|
||||
sharded_optimizer_loading_epilogue,
|
||||
)
|
||||
from colossalai.interface import OptimizerWrapper
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.tensor.moe_tensor.api import (
|
||||
get_dp_group,
|
||||
get_dp_rank,
|
||||
get_dp_size,
|
||||
get_ep_group,
|
||||
get_ep_rank,
|
||||
get_ep_size,
|
||||
is_moe_tensor,
|
||||
)
|
||||
|
||||
|
||||
class MoECheckpointIO(HybridParallelCheckpointIO):
|
||||
def __init__(
|
||||
self,
|
||||
dp_group: ProcessGroup,
|
||||
pp_group: ProcessGroup,
|
||||
tp_group: ProcessGroup,
|
||||
zero_stage: int,
|
||||
) -> None:
|
||||
assert zero_stage in [
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
], f"zero_stage should be 0 or 1 or 2, got {zero_stage}"
|
||||
super().__init__(dp_group, pp_group, tp_group, zero_stage)
|
||||
self.parallel = MOE_MANAGER.parallel
|
||||
|
||||
def pre_load_model(self, model: nn.Module, state_dict: dict) -> dict:
|
||||
"""
|
||||
Preprocess state_dict before loading and slice the state_dict of MOE tensors.
|
||||
"""
|
||||
for name, param in state_dict.items():
|
||||
if ".experts." in name:
|
||||
if name in dict(model.named_parameters()):
|
||||
model_param = dict(model.named_parameters())[name]
|
||||
if is_moe_tensor(model_param):
|
||||
ep_rank = get_ep_rank(model_param)
|
||||
ep_size = get_ep_size(model_param)
|
||||
expert_num = param.shape[0] // ep_size
|
||||
assert param.shape[0] % ep_size == 0
|
||||
param = param[ep_rank * expert_num : (ep_rank + 1) * expert_num]
|
||||
state_dict[name] = param
|
||||
dist.barrier()
|
||||
return state_dict
|
||||
|
||||
def _model_sharder(
|
||||
self,
|
||||
state_dict: nn.Module,
|
||||
prefix: str = "",
|
||||
keep_vars: bool = False,
|
||||
size_per_shard: int = 1024,
|
||||
) -> Iterator[Tuple[OrderedDict, int]]:
|
||||
# An internal method that breaks the state_dict of the model into shards within a limited size.
|
||||
state_dict_sharder = StateDictSharder(size_per_shard)
|
||||
|
||||
for name, param in state_dict.items():
|
||||
if param is None:
|
||||
continue
|
||||
# Gather tensor pieces when using tensor parallel.
|
||||
param_ = gather_distributed_param(param, keep_vars=False)
|
||||
block, block_size = state_dict_sharder.append_param(prefix + name, param_)
|
||||
if block is not None:
|
||||
yield block, block_size
|
||||
|
||||
# Return the last block in sharder.
|
||||
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
|
||||
|
||||
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool) -> None:
|
||||
state_dict = torch.load(checkpoint)
|
||||
state_dict = self.pre_load_model(model, state_dict)
|
||||
model.load_state_dict(state_dict, strict=strict if self.pp_size == 1 else False)
|
||||
|
||||
def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
|
||||
"""
|
||||
Load sharded model with the given path to index file of checkpoint folder.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The model to be loaded.
|
||||
checkpoint_index_file (str): Path to the index file of checkpointing folder.
|
||||
strict (bool, optional): For name matching during loading state_dict. Defaults to False.
|
||||
This argument should be manually set to False since params on same device might be stored in different files.
|
||||
"""
|
||||
|
||||
# Check whether the checkpoint uses safetensors.
|
||||
use_safetensors = False
|
||||
if "safetensors" in checkpoint_index_file.name:
|
||||
use_safetensors = True
|
||||
|
||||
if use_safetensors and not is_safetensors_available():
|
||||
raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
|
||||
|
||||
# Read checkpoint index file.
|
||||
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
|
||||
ckpt_root_path = ckpt_index_file.root_path
|
||||
weight_map = ckpt_index_file.weight_map
|
||||
strict = False
|
||||
|
||||
# Load params & buffers to model.
|
||||
# Keep a record of loaded files so that file will not be repeatedly loaded.
|
||||
loaded_file = set()
|
||||
|
||||
def _load(name: str):
|
||||
if name not in weight_map:
|
||||
raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!")
|
||||
filename = weight_map[name]
|
||||
|
||||
# If this param/buffer has been loaded before, directly return.
|
||||
if filename in loaded_file:
|
||||
return
|
||||
|
||||
file_path = os.path.join(ckpt_root_path, filename)
|
||||
state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
|
||||
state_dict = self.pre_load_model(model, state_dict)
|
||||
missing_keys = []
|
||||
|
||||
load_state_dict_into_model(
|
||||
model,
|
||||
state_dict,
|
||||
missing_keys=missing_keys,
|
||||
strict=strict,
|
||||
load_sub_module=True,
|
||||
)
|
||||
loaded_file.add(filename)
|
||||
|
||||
# Load parameters.
|
||||
for name, _ in model.named_parameters():
|
||||
_load(name)
|
||||
|
||||
if self.verbose:
|
||||
logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
|
||||
|
||||
def pre_save_model(self, model: nn.Module) -> dict:
|
||||
state_dict = model.state_dict()
|
||||
for name, param in model.named_parameters():
|
||||
if ".experts." in name and is_moe_tensor(param):
|
||||
ep_group = get_ep_group(param)
|
||||
ep_rank = get_ep_rank(param)
|
||||
ep_size = get_ep_size(param)
|
||||
dp_rank = get_dp_rank(param)
|
||||
if dp_rank == 0:
|
||||
param = param.data.cuda()
|
||||
all_param = [torch.zeros_like(param) for _ in range(ep_size)]
|
||||
# gather param from every ep rank
|
||||
dist.all_gather(all_param, param, group=ep_group)
|
||||
if ep_rank == 0:
|
||||
all_param = torch.cat(all_param, dim=0)
|
||||
state_dict[name] = all_param.cpu()
|
||||
if self.pp_size > 1:
|
||||
if self.dp_rank == 0:
|
||||
out = [None for _ in range(self.pp_size)]
|
||||
dist.all_gather_object(out, state_dict, group=self.pp_group)
|
||||
if self.pp_rank == 0:
|
||||
new_state_dict = {}
|
||||
for o in out:
|
||||
new_state_dict.update(o)
|
||||
state_dict = new_state_dict
|
||||
dist.barrier()
|
||||
return state_dict
|
||||
|
||||
def save_unsharded_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
checkpoint: str,
|
||||
gather_dtensor: bool,
|
||||
use_safetensors: bool,
|
||||
):
|
||||
state_dict = self.pre_save_model(model)
|
||||
if dist.get_rank() == 0:
|
||||
torch.save(state_dict, checkpoint)
|
||||
dist.barrier()
|
||||
|
||||
def save_sharded_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
checkpoint: str,
|
||||
gather_dtensor: bool = True,
|
||||
prefix: Optional[str] = None,
|
||||
size_per_shard: int = 1024,
|
||||
use_safetensors: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Save sharded model checkpoint under the given checkpointing path.
|
||||
The following files will be created under the path:
|
||||
- An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names.
|
||||
- Multiple files that store state tensors of models.
|
||||
The filenames are in the form of "pytorch_model.<prefix>-000XX.bin"
|
||||
|
||||
Args:
|
||||
model (nn.Module): Model on local device to be saved.
|
||||
checkpoint (str): Checkpointing path which should be a directory path.
|
||||
gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True.
|
||||
prefix (str, optional): Prefix of file to save. Defaults to None.
|
||||
size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
|
||||
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
|
||||
"""
|
||||
torch.cuda.empty_cache()
|
||||
if os.path.isfile(checkpoint):
|
||||
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
|
||||
return
|
||||
|
||||
Path(checkpoint).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Then collect the sharded parameters & buffers along tp_group.
|
||||
# Only devices with tp_rank == 0 are responsible for model saving.
|
||||
state_dict = self.pre_save_model(model)
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
state_dict_shard = self._model_sharder(state_dict, size_per_shard=size_per_shard)
|
||||
|
||||
# Devices along the same dp_group share the same copies of model.
|
||||
# So only let the device with dp_rank == 0 save the model.
|
||||
if self.dp_rank != 0:
|
||||
return
|
||||
|
||||
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
|
||||
index_file = CheckpointIndexFile(checkpoint)
|
||||
control_saving = self.tp_rank == 0
|
||||
|
||||
total_size = save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=weights_name,
|
||||
is_master=control_saving,
|
||||
use_safetensors=use_safetensors,
|
||||
)
|
||||
if control_saving:
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
save_config_file(model, checkpoint)
|
||||
if self.verbose:
|
||||
logging.info(
|
f"The model is split into checkpoint shards. "
f"You can find where each parameter has been saved in the "
|
||||
f"index located at {save_index_file}."
|
||||
)
|
||||
dist.barrier()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# ========================================================
|
||||
# Abstract methods for optimizer loading/saving implementation
|
||||
# ========================================================
|
||||
|
||||
def pre_load_optim(
|
||||
self,
|
||||
state: OrderedDict,
|
||||
working_param,
|
||||
current_shape: torch.Size,
|
||||
original_shape: torch.Size,
|
||||
device: torch.device,
|
||||
inplace: bool,
|
||||
) -> OrderedDict:
|
||||
"""
|
||||
With complete optimizer states of a specific parameter loaded from checkpoint,
|
||||
slice out the sharded optimizer states kept by current device.
|
||||
|
||||
Args:
|
||||
state (OrderedDict): Complete optimizer states of a given parameter, loaded from checkpoint.
|
||||
current_shape (torch.Size): The size of parameter after sharding.
|
||||
original_shape (torch.Size): The size of parameter before sharding.
|
||||
device (torch.device): The destination device of loaded optimizer states.
|
||||
inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
|
||||
|
||||
Returns:
|
||||
OrderedDict: The sharded optimizer state of the given parameter.
|
||||
"""
|
||||
state_ = state if inplace else copy.deepcopy(state)
|
||||
is_moe_tensor_flag = is_moe_tensor(working_param)
|
||||
if is_moe_tensor_flag:
|
||||
ep_rank = get_ep_rank(working_param)
|
||||
ep_size = get_ep_size(working_param)
|
||||
|
||||
for k, v in state_.items():
|
||||
if isinstance(v, torch.Tensor) and k != "step":
|
||||
if is_moe_tensor_flag:
|
||||
with torch.no_grad():
|
||||
expert_num = v.shape[0] // ep_size
|
||||
assert v.shape[0] % ep_size == 0
|
||||
v = v[ep_rank * expert_num : (ep_rank + 1) * expert_num]
|
||||
else:
|
||||
# Shard state along data parallel group when using Zero.
|
||||
padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size
|
||||
with torch.no_grad():
|
||||
v = v.flatten()
|
||||
if padding_size > 0:
|
||||
v = torch.nn.functional.pad(v, [0, padding_size])
|
||||
slice_size = v.numel() // self.dp_size
|
||||
v = v.split(slice_size, dim=0)[self.dp_rank]
|
||||
|
||||
state_[k] = v.detach().clone().to(device)
|
||||
|
||||
return state_
|
||||
|
||||
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
|
||||
"""
|
||||
Load sharded optimizer with the given path to index file of checkpoint folder.
|
||||
|
||||
Args:
|
||||
optimizer (OptimizerWrapper): The optimizer to be loaded.
|
||||
checkpoint_index_file (str): Path to the index file of checkpointing folder.
|
||||
prefix (str): Not used.
|
||||
"""
|
||||
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
|
||||
|
||||
def _get_param_id_from_optimizer_param(
|
||||
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None, optimizer=None
|
||||
):
|
||||
if master_to_working_map is not None and id(param) in master_to_working_map:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
|
||||
working_param = optimizer.moe_master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
return optimizer.param_info["param2id"][id(working_param)]
|
||||
|
||||
# id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects.
|
||||
# When Zero is used, the mapped parameter objects should be fp32 master parameters.
|
||||
# IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info.
|
||||
id_map = {}
|
||||
master_to_working_map = optimizer.get_master_to_working_map()
|
||||
for pg in optimizer.optim.param_groups:
|
||||
for param in pg["params"]:
|
||||
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer)
|
||||
id_map[param_id] = param
|
||||
|
||||
# Read checkpoint index file.
|
||||
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
|
||||
ckpt_root_path = ckpt_index_file.root_path
|
||||
weight_map = ckpt_index_file.weight_map
|
||||
weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int
|
||||
|
||||
# Load param_groups
|
||||
param_group_path = ckpt_index_file.get_param_group_filename()
|
||||
if param_group_path is None:
|
||||
raise RuntimeError(
|
||||
f"Invalid index file path {checkpoint_index_file} for an optimizer. \
|
||||
Lacking param group file under current directory."
|
||||
)
|
||||
saved_groups = torch.load(param_group_path)
|
||||
|
||||
updated_groups = []
|
||||
for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
|
||||
# obtain updated param group
|
||||
new_pg = copy.deepcopy(saved_pg)
|
||||
new_pg["params"] = old_pg["params"] # The parameters in the same group shouldn't change.
|
||||
updated_groups.append(new_pg)
|
||||
# ep param group
|
||||
if len(optimizer.optim.param_groups) > len(saved_groups):
|
||||
new_pg = copy.deepcopy(saved_pg)
|
||||
new_pg["params"] = optimizer.optim.param_groups[-1]["params"]
|
||||
updated_groups.append(new_pg)
|
||||
optimizer.optim.__dict__.update({"param_groups": updated_groups})
|
||||
|
||||
# Load saved states to optimizer.
|
||||
# Keep a record of loaded files so that file will not be repeatedly loaded.
|
||||
loaded_file = set()
|
||||
for pg in optimizer.optim.param_groups:
|
||||
for param in pg["params"]:
|
||||
if param is None:
|
||||
continue
|
||||
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer)
|
||||
if param_id not in weight_map:
|
||||
continue
|
||||
filename = weight_map[param_id]
|
||||
|
||||
# If this param's states has been loaded before, directly return.
|
||||
if filename in loaded_file:
|
||||
continue
|
||||
|
||||
file_path = os.path.join(ckpt_root_path, filename)
|
||||
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
|
||||
|
||||
# Then shard the loaded optimizer states if using tp/zero.
|
||||
for pid, state in list(state_dict.items()):
|
||||
if pid in id_map:
|
||||
param = id_map[pid]
|
||||
if master_to_working_map is not None and id(param) in master_to_working_map:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
elif (
|
||||
hasattr(optimizer, "moe_master_to_working_map")
|
||||
and id(param) in optimizer.moe_master_to_working_map
|
||||
):
|
||||
working_param = optimizer.moe_master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
original_shape = optimizer.param_info["param2shape"][id(working_param)]
|
||||
sharded_state = self.pre_load_optim(
|
||||
state,
|
||||
working_param,
|
||||
current_shape=working_param.shape,
|
||||
original_shape=original_shape,
|
||||
device="cpu",
|
||||
inplace=True,
|
||||
)
|
||||
state_dict[pid] = sharded_state
|
||||
|
||||
load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
|
||||
loaded_file.add(filename)
|
||||
|
||||
sharded_optimizer_loading_epilogue(optimizer.optim)
|
||||
if self.verbose and self.coordinator.is_master():
|
||||
logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
|
||||
dist.barrier()
|
||||
|
||||
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
|
||||
"""
|
||||
Load optimizer from a file with given path.
|
||||
|
||||
Args:
|
||||
optimizer (OptimizerWrapper): The optimizer to be loaded.
|
||||
checkpoint_index_file (str): Path to the checkpoint file.
|
||||
"""
|
||||
|
||||
def _get_param_id_from_optimizer_param(
|
||||
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
|
||||
):
|
||||
if master_to_working_map is not None and id(param) in master_to_working_map:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
if id(working_param) in optimizer.param_info["param2id"]:
|
||||
return optimizer.param_info["param2id"][id(working_param)]
|
||||
else:
|
||||
return None
|
||||
|
||||
if self.coordinator.is_master():
|
||||
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
|
||||
|
||||
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
|
||||
|
||||
# Complete optimizer state_dict loaded from checkpoint, need to be processed later.
|
||||
state_dict = load_state_dict(checkpoint)
|
||||
|
||||
# Load param_groups.
|
||||
updated_groups = []
|
||||
saved_groups = state_dict["param_groups"]
|
||||
for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
|
||||
new_pg = copy.deepcopy(saved_pg)
|
||||
new_pg["params"] = old_pg["params"] # Only keep the parameters kept by current pipeline stage.
|
||||
updated_groups.append(new_pg)
|
||||
# ep extra group
|
||||
if MOE_MANAGER.parallel == "EP":
|
||||
new_pg = copy.deepcopy(saved_pg)
|
||||
new_pg["params"] = optimizer.optim.param_groups[-1][
|
||||
"params"
|
||||
] # Only keep the parameters kept by current pipeline stage.
|
||||
for param in new_pg["params"]:
|
||||
param.data = param.data.to(torch.float32)
|
||||
updated_groups.append(new_pg)
|
||||
optimizer.optim.__dict__.update({"param_groups": updated_groups})
|
||||
|
||||
# Load saved states to optimizer. First discard those states not belonging to current pipeline stage.
|
||||
master_to_working_map = optimizer.get_master_to_working_map()
|
||||
id_map = {}
|
||||
for pg in optimizer.optim.param_groups:
|
||||
for param in pg["params"]:
|
||||
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
|
||||
if param_id is not None:
|
||||
id_map[param_id] = param
|
||||
load_states_into_optimizer(optimizer.optim, state_dict["state"], id_map, strict=True)
|
||||
|
||||
# Then shard the loaded optimizer states if using tp/zero.
|
||||
for param, state in optimizer.optim.state.items():
|
||||
if param is None:
|
||||
continue
|
||||
device = param.device
|
||||
if master_to_working_map is not None and id(param) in master_to_working_map:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
original_shape = optimizer.param_info["param2shape"][id(working_param)]
|
||||
sharded_state = self.pre_load_optim(
|
||||
state,
|
||||
param,
|
||||
current_shape=working_param.shape,
|
||||
original_shape=original_shape,
|
||||
device=device,
|
||||
inplace=True,
|
||||
)
|
||||
optimizer.optim.state[param] = sharded_state
|
||||
sharded_optimizer_loading_epilogue(optimizer.optim)
|
||||
dist.barrier()
|
||||
|
||||
def pre_save_optim(
|
||||
self,
|
||||
state: OrderedDict,
|
||||
param: torch.Tensor,
|
||||
inplace: bool,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
) -> OrderedDict:
|
||||
"""
|
||||
With given parameter and its optimizer states, gather the complete optimizer state for saving.
|
||||
|
||||
Args:
|
||||
state (OrderedDict): Optimizer states of given parameter, might be distributed among tp/dp group if using TP/Zero.
|
||||
param (torch.Tensor): The given parameter. It should be working_param when using Zero.
|
||||
original_shape (torch.Size): The size of parameter before sharding.
|
||||
dp_group (ProcessGroup): The process group of data parallel.
|
||||
tp_group (ProcessGroup): The process group of tensor parallel.
|
||||
use_zero (bool): Whether Zero is used.
|
||||
inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
|
||||
device (torch.device): The destination device of loaded optimizer states. Defaults to torch.device('cpu').
|
||||
|
||||
Returns:
|
||||
OrderedDict: The complete optimizer state of given parameter.
|
||||
"""
|
||||
if is_moe_tensor(param):
|
||||
moe_dp_group = get_dp_group(param)
|
||||
moe_dp_size = get_dp_size(param)
|
||||
moe_ep_group = get_ep_group(param)
|
||||
moe_ep_size = get_ep_size(param)
|
||||
state_ = state if inplace else copy.deepcopy(state)
|
||||
|
||||
for k, v in state_.items():
|
||||
if isinstance(v, torch.Tensor) and k != "step":
|
||||
# moe param
|
||||
if is_moe_tensor(param):
|
||||
# dp gather
|
||||
v = v.cuda()
|
||||
gather_tensor = [torch.zeros_like(v) for _ in range(moe_dp_size)]
|
||||
dist.all_gather(gather_tensor, v, group=moe_dp_group)
|
||||
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
|
||||
# ep gather
|
||||
gather_tensor = [torch.zeros_like(v) for _ in range(moe_ep_size)]
|
||||
dist.all_gather(gather_tensor, v, group=moe_ep_group)
|
||||
v = torch.cat(gather_tensor, dim=0)
|
||||
else:
|
||||
# global dp
|
||||
v = v.cuda()
|
||||
gather_tensor = [torch.zeros_like(v) for _ in range(dist.get_world_size(self.dp_group))]
|
||||
dist.all_gather(gather_tensor, v, group=self.dp_group)
|
||||
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
|
||||
|
||||
state_[k] = v.detach().clone().to(device)
|
||||
|
||||
return state_
|
||||
|
||||
def _optimizer_sharder(
|
||||
self,
|
||||
optimizer: OptimizerWrapper,
|
||||
size_per_shard: int = 1024,
|
||||
):
|
||||
# An internal method that breaks the state_dict of the optimizer into shards within a limited size.
|
||||
|
||||
state_dict_sharder = StateDictSharder(size_per_shard)
|
||||
param_info = optimizer.param_info
|
||||
master_to_working_map = optimizer.get_master_to_working_map()
|
||||
|
||||
for param, state in optimizer.optim.state.items():
|
||||
if param is None:
|
||||
continue
|
||||
|
||||
if master_to_working_map is not None and id(param) in master_to_working_map:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
|
||||
working_param = optimizer.moe_master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
|
||||
param_id = param_info["param2id"][id(working_param)]
|
||||
state_ = self.pre_save_optim(
|
||||
state,
|
||||
working_param,
|
||||
inplace=False,
|
||||
device=torch.device("cuda"),
|
||||
)
|
||||
|
||||
block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
|
||||
if block is not None:
|
||||
yield block, block_size
|
||||
|
||||
# Return the last block in sharder.
|
||||
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
|
||||
|
||||
def save_sharded_optimizer(
|
||||
self,
|
||||
optimizer: OptimizerWrapper,
|
||||
checkpoint: str,
|
||||
gather_dtensor: bool = True,
|
||||
prefix: Optional[str] = None,
|
||||
size_per_shard: int = 1024,
|
||||
):
|
||||
"""
|
||||
Save sharded optimizer checkpoint under the given checkpointing path.
|
||||
The following files will be created under the path:
|
||||
- An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
|
||||
- A group file (pytorch_optim_group.bin) recording information of param_groups
|
||||
- Multiple files that store state tensors of optimizers.
|
||||
If pipeline parallelism is used, the filenames are in the form of "pytorch_optim.<prefix>-stage-000XX-shard-000XX.bin".
|
||||
If pipeline parallelism is not used, "pytorch_optim.<prefix>-000XX.bin"
|
||||
|
||||
Args:
|
||||
optimizer (OptimizerWrapper): Optimizer to save sharded state_dict
|
||||
checkpoint (str): Path to save optimizer state_dict
|
||||
gather_dtensor (bool): Whether to gather_dtensor, not used
|
||||
prefix (str): Prefix of file to save
|
||||
size_per_shard (int): Max file size of each file shard that store state tensors
|
||||
"""
|
||||
torch.cuda.empty_cache()
|
||||
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
|
||||
if os.path.isfile(checkpoint):
|
||||
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
|
||||
return
|
||||
|
||||
Path(checkpoint).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Devices along the same dp_group share the same copies of states when zero is not used.
|
||||
# In this case only let the device with dp_rank == 0 save the model.
|
||||
if not self.use_zero and self.dp_rank != 0:
|
||||
return
|
||||
|
||||
# Then collect the sharded states along dp_group(if using zero)/tp_group.
|
||||
# Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
|
||||
state_dict_shard = self._optimizer_sharder(
|
||||
optimizer,
|
||||
size_per_shard=size_per_shard,
|
||||
)
|
||||
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
|
||||
index_file = CheckpointIndexFile(checkpoint)
|
||||
control_saving = self.dp_rank == 0 and self.tp_rank == 0
|
||||
if self.pp_size == 1:
|
||||
# When pipeline is not used, save the optimizer shards as in general checkpointIO
|
||||
total_size = save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=states_name,
|
||||
is_master=control_saving,
|
||||
)
|
||||
|
||||
if control_saving:
|
||||
# Store param groups.
|
||||
index_file.append_meta_data("param_groups", param_group_file)
|
||||
group_file_path = os.path.join(checkpoint, param_group_file)
|
||||
save_param_groups(optimizer.param_info, group_file_path)
|
||||
# Store index file.
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
if self.verbose and self.coordinator.is_master():
|
||||
logging.info(
|
f"The optimizer is going to be split into checkpoint shards. "
f"You can find where each parameter has been saved in the "
|
||||
f"index located at {save_index_file}."
|
||||
)
|
||||
|
||||
else:
|
||||
# When pipeline is used, each stage produces its own shard files and index files.
|
||||
# Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
|
||||
# After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
|
||||
|
||||
final_index_file_path = copy.deepcopy(save_index_file)
|
||||
tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
|
||||
Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Manage filenames of sharded weights and index file for each pipeline stage.
|
||||
states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
|
||||
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
|
||||
save_index_file = os.path.join("tmp_index_files", save_index_file)
|
||||
|
||||
total_size = save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=states_name,
|
||||
is_master=control_saving,
|
||||
use_pp_format=True,
|
||||
)
|
||||
|
||||
if control_saving:
|
||||
assert (
|
||||
self.dp_rank == 0 and self.tp_rank == 0
|
||||
), "The saving process should have both dp_rank and tp_rank as 0."
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
else:
|
||||
return
|
||||
|
||||
dist.barrier(self.pp_group)
|
||||
|
||||
# The global master rank integrates the index files and clean the folder.
|
||||
if self.pp_rank == 0:
|
||||
final_index_file = CheckpointIndexFile(checkpoint)
|
||||
final_index_file.append_meta_data("total_size", 0)
|
||||
|
||||
for filename in os.listdir(tmp_index_file_folder):
|
||||
stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename))
|
||||
final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"]
|
||||
for param_id, state_filename in stage_index_file.weight_map.items():
|
||||
final_index_file.append_weight_map(param_id, state_filename)
|
||||
|
||||
# Store param groups.
|
||||
final_index_file.append_meta_data("param_groups", param_group_file)
|
||||
group_file_path = os.path.join(checkpoint, param_group_file)
|
||||
save_param_groups(optimizer.param_info, group_file_path)
|
||||
|
||||
final_index_file.write_index_file(final_index_file_path)
|
||||
rmtree(tmp_index_file_folder)
|
||||
|
||||
if self.verbose and self.coordinator.is_master():
|
||||
logging.info(
|
f"The optimizer is split into checkpoint shards. "
f"You can find where each parameter has been saved in the "
|
||||
f"index located at {final_index_file_path}."
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
|
||||
"""
|
||||
Save optimizer state dict to a file with given path.
|
||||
|
||||
Args:
|
||||
optimizer (OptimizerWrapper): Optimizer to save sharded state_dict.
|
||||
checkpoint (str): Path to save optimizer state_dict.
|
||||
gather_dtensor (bool): Whether to gather_dtensor, not used.
|
||||
"""
|
||||
if self.coordinator.is_master():
|
||||
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
|
||||
|
||||
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
|
||||
|
||||
# optimizer states of parameters kept by local device('s pipeline stage)
|
||||
local_states = dict()
|
||||
|
||||
for param, state in optimizer.optim.state.items():
|
||||
if param is None:
|
||||
continue
|
||||
|
||||
# working param is needed for obtaining correct param_id
|
||||
master_to_working_map = optimizer.get_master_to_working_map()
|
||||
if master_to_working_map is not None and id(param) in master_to_working_map:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
|
||||
# gather complete state from tp shards & dp shards
|
||||
param_id = optimizer.param_info["param2id"][id(working_param)]
|
||||
local_states[param_id] = self.pre_save_optim(
|
||||
state,
|
||||
working_param,
|
||||
inplace=False,
|
||||
device=torch.device("cuda"),
|
||||
)
|
||||
|
||||
if self.pp_size == 1:
|
||||
# When pipeline is not used, let master rank directly save the collected state_dict.
|
||||
state_dict = {"param_groups": optimizer.optim.param_groups, "state": local_states}
|
||||
if self.coordinator.is_master():
|
||||
save_state_dict(state_dict, checkpoint, use_safetensors=False)
|
||||
else:
|
||||
# When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
|
||||
states_list = [None for _ in range(self.pp_size)]
|
||||
dist.barrier(self.pp_group)
|
||||
dist.all_gather_object(states_list, local_states, self.pp_group)
|
||||
|
||||
# Only the master rank do the saving.
|
||||
if self.coordinator.is_master():
|
||||
state_dict = {"param_groups": optimizer.optim.param_groups, "state": dict()}
|
||||
for _states in states_list:
|
||||
state_dict["state"].update(_states)
|
||||
save_state_dict(state_dict, checkpoint, use_safetensors=False)
|
||||
dist.barrier()
|
|
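A toy illustration (not part of the diff) of the expert slicing performed in pre_load_model: with 8 experts stored in the checkpoint and an expert-parallel size of 4, each EP rank keeps a contiguous block of 2 experts.

import torch

full_param = torch.arange(8 * 3).reshape(8, 3)  # 8 experts, toy hidden size 3
ep_size, ep_rank = 4, 1                         # assumed expert-parallel layout for the example
expert_num = full_param.shape[0] // ep_size     # experts kept per rank
assert full_param.shape[0] % ep_size == 0
local_param = full_param[ep_rank * expert_num : (ep_rank + 1) * expert_num]
print(local_param.shape)  # torch.Size([2, 3]) -> this rank holds experts 2 and 3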
@ -7,8 +7,8 @@ from torch import Tensor, nn
|
|||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.moe.experts import MLPExperts
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.shardformer.layer.moe import MLPExperts
|
||||
from colossalai.zero.low_level import LowLevelZeroOptimizer
|
||||
|
||||
|
||||
|
@ -292,7 +292,7 @@ class LoadBalancer:
|
|||
exp_avg_ptr = optim.optim.state[working_weight_ptr]["exp_avg"]
|
||||
exp_avg_sq_ptr = optim.optim.state[working_weight_ptr]["exp_avg_sq"]
|
||||
else:
|
||||
master_weight_ptr = optim._param_store.working_to_master_param[id(weight)]
|
||||
master_weight_ptr = optim.working_to_master_param[id(weight)]
|
||||
working_weight_ptr = weight
|
||||
exp_avg_ptr = optim.optim.state[master_weight_ptr]["exp_avg"]
|
||||
exp_avg_sq_ptr = optim.optim.state[master_weight_ptr]["exp_avg_sq"]
|
||||
|
@ -344,7 +344,7 @@ class LoadBalancer:
|
|||
# gate optim should be obtained first
|
||||
gate_shape = self.gate.shape
|
||||
# get master weight and optim
|
||||
master_gate_weight = optim._param_store.working_to_master_param[id(self.gate)]
|
||||
master_gate_weight = optim.working_to_master_param[id(self.gate)]
|
||||
gate_exp_avg = optim.optim.state[master_gate_weight]["exp_avg"]
|
||||
gate_exp_avg_sq = optim.optim.state[master_gate_weight]["exp_avg_sq"]
|
||||
# gather
|
||||
|
|
|
@ -1,78 +0,0 @@
|
|||
import torch.nn as nn
|
||||
from torch.nn.modules.loss import _Loss
|
||||
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
|
||||
|
||||
class MoeCrossEntropyLoss(_Loss):
|
||||
r"""torch.nn.CrossEntropyLoss added with auxiliary loss.
|
||||
|
||||
Args:
|
||||
input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
|
||||
target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
|
||||
aux_weight (float, optional): Weight of the auxiliary loss in the total loss. Defaults to 0.01.
|
||||
|
||||
The ``args`` and ``kwargs`` should include parameters below:
|
||||
::
|
||||
|
||||
weight (Tensor, optional)
|
||||
size_average (bool, optional)
|
||||
ignore_index (int, optional)
|
||||
reduce (bool, optional)
|
||||
reduction (str, optional)
|
||||
label_smoothing (float, optional)
|
||||
|
||||
More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` can be found in
|
||||
`Cross_entropy <https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html#torch.nn.functional.cross_entropy>`_.
|
||||
"""
|
||||
|
||||
def __init__(self, aux_weight: float = 0.01, *args, **kwargs):
|
||||
super().__init__()
|
||||
self.loss = nn.CrossEntropyLoss(*args, **kwargs)
|
||||
self.aux_weight = aux_weight
|
||||
|
||||
def forward(self, *args):
|
||||
"""
|
||||
The ``args`` should at least include parameters below:
|
||||
::
|
||||
|
||||
input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
|
||||
target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
|
||||
|
||||
More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` can be found in
|
||||
`Cross_entropy <https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html#torch.nn.functional.cross_entropy>`_.
|
||||
"""
|
||||
main_loss = self.loss(*args)
|
||||
aux_loss = MOE_MANAGER.get_loss()
|
||||
return main_loss + self.aux_weight * aux_loss
|
||||
|
||||
|
||||
class MoeLoss(_Loss):
|
||||
"""A wrapper class for any loss module to add with auxiliary loss.
|
||||
|
||||
Args:
|
||||
aux_weight (float): Weight of auxiliary loss in total loss.
|
||||
loss_fn (``Callable``): Loss function.
|
||||
args (list): Args in loss function.
|
||||
kwargs (dict): Kwargs in loss function
|
||||
"""
|
||||
|
||||
def __init__(self, aux_weight: float, loss_fn, *args, **kwargs):
|
||||
super().__init__()
|
||||
self.loss_fn = loss_fn(*args, **kwargs)
|
||||
self.aux_weight = aux_weight
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
"""
|
||||
The ``args`` and ``kwargs`` should at least include parameters below:
|
||||
::
|
||||
|
||||
input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
|
||||
target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
|
||||
|
||||
Note:
|
||||
The ``args`` and ``kwargs`` may vary depending on the loss function used.
|
||||
"""
|
||||
main_loss = self.loss_fn(*args, **kwargs)
|
||||
aux_loss = MOE_MANAGER.get_loss()
|
||||
return main_loss + self.aux_weight * aux_loss
|
|
@ -1,466 +0,0 @@
|
|||
import math
|
||||
from abc import ABC
|
||||
from typing import Callable, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.moe._operation import moe_cumsum
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
|
||||
|
||||
class MoeRouter(nn.Module, ABC):
|
||||
"""Base class for all MoE routers.
|
||||
Args:
|
||||
k_value (int): The value of top_k.
|
||||
capacity_factor_train (float): Capacity factor in routing of training.
|
||||
capacity_factor_eval (float): Capacity factor in routing of evaluation.
|
||||
min_capacity (int): The minimum capacity of each expert.
|
||||
noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
|
||||
drop_tks (bool, optional): Whether to drop tokens during evaluation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
k_value: int,
|
||||
capacity_factor_train: float,
|
||||
capacity_factor_eval: float,
|
||||
min_capacity: int,
|
||||
noisy_func: Optional[Callable] = None,
|
||||
drop_tks: bool = True,
|
||||
use_kernel: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.k_value = k_value
|
||||
self.capacity_factor_train = capacity_factor_train
|
||||
self.capacity_factor_eval = capacity_factor_eval
|
||||
self.min_capacity = min_capacity
|
||||
self.noisy_func = noisy_func
|
||||
self.drop_tks = drop_tks
|
||||
self._aux_loss = None
|
||||
self._z_loss = None
|
||||
self.use_kernel = use_kernel
|
||||
|
||||
def get_capacity(self, num_tokens, num_experts, ep_group=None):
|
||||
if ep_group is not None:
|
||||
num_tokens_tensor = torch.tensor(num_tokens, device=get_accelerator().get_current_device())
|
||||
dist.all_reduce(num_tokens_tensor, group=ep_group)
|
||||
num_tokens = num_tokens_tensor.item() // dist.get_world_size(ep_group)
|
||||
capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
|
||||
capacity = math.floor(self.k_value * capacity_factor * num_tokens / num_experts)
|
||||
capacity += capacity % 2
|
||||
capacity = max(capacity, self.min_capacity)
|
||||
assert capacity > 0
|
||||
return int(capacity)
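# --- Illustrative sketch (not part of the diff) ---
# How get_capacity behaves for a concrete configuration, assuming top-2 routing
# with a training capacity factor of 1.25 and 1024 tokens over 8 experts.
import math

k_value, capacity_factor, num_tokens, num_experts, min_capacity = 2, 1.25, 1024, 8, 4
capacity = math.floor(k_value * capacity_factor * num_tokens / num_experts)  # 320
capacity += capacity % 2                    # keep the capacity even
capacity = max(capacity, min_capacity)      # never below min_capacity
assert capacity == 320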
|
||||
|
||||
def set_aux_loss(self, router_probs: torch.Tensor, expert_indices: torch.Tensor, num_experts: int) -> None:
|
||||
"""Computes auxiliary load balancing loss as in Switch Transformer.
|
||||
|
||||
See Switch Transformer (https://arxiv.org/abs/2101.03961). This function
|
||||
implements the loss function presented in equations (4) - (6). It aims to
|
||||
penalize those cases where the routing between experts is unbalanced.
|
||||
|
||||
Args:
|
||||
router_probs: Probability assigned to each expert per token. Shape:
|
||||
<float32>[num_groups, tokens_per_group, num_experts].
|
||||
expert_indices: <int>[num_groups, tokens_per_group, num_selected_experts]
|
||||
indices identifying the top num_selected_experts for a given token.
|
||||
"""
|
||||
assert self._aux_loss is None
|
||||
if router_probs.dim() == expert_indices.dim() == 2:
|
||||
router_probs = router_probs.unsqueeze(0)
|
||||
expert_indices = expert_indices.unsqueeze(0)
|
||||
assert (
|
||||
router_probs.dim() == expert_indices.dim() == 3
|
||||
), "router_probs must be 3D tensor and expert_indices must be 4D tensor"
|
||||
|
||||
# Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts].
|
||||
expert_mask = F.one_hot(expert_indices, num_experts)
|
||||
# For a given token, determine if it was routed to a given expert.
|
||||
# Shape: [num_groups, tokens_per_group, num_experts]
|
||||
expert_mask = expert_mask.max(dim=-2)[0]
|
||||
|
||||
tokens_per_group_and_expert = torch.mean(expert_mask.float(), dim=-2)
|
||||
router_prob_per_group_and_expert = torch.mean(router_probs.float(), dim=-2)
|
||||
aux_loss = num_experts**2 * torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert)
|
||||
self._aux_loss = aux_loss
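# --- Illustrative sketch (not part of the diff) ---
# The Switch Transformer auxiliary loss above reduces to
# num_experts**2 * mean(f_e * P_e), where f_e is the fraction of tokens routed
# to expert e and P_e is the mean router probability of expert e.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_tokens, num_experts = 16, 4
router_probs = F.softmax(torch.randn(1, num_tokens, num_experts), dim=-1)
expert_indices = router_probs.argmax(dim=-1, keepdim=True)            # top-1 choice per token
expert_mask = F.one_hot(expert_indices, num_experts).max(dim=-2)[0]   # was the token routed to e?
f_e = expert_mask.float().mean(dim=-2)                                # token fraction per expert
p_e = router_probs.mean(dim=-2)                                       # mean probability per expert
aux_loss = num_experts**2 * torch.mean(f_e * p_e)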
|
||||
|
||||
def set_z_loss(self, router_logits: torch.Tensor):
|
||||
"""Compute router z-loss.
|
||||
|
||||
The router z-loss was introduced in Designing Effective Sparse Expert Models
|
||||
(https://arxiv.org/abs/2202.08906). It encourages router logits to remain
|
||||
small in an effort to improve stability.
|
||||
|
||||
Args:
|
||||
router_logits: <float>[num_groups, tokens_per_group, num_experts] router logits.
|
||||
"""
|
||||
assert self._z_loss is None
|
||||
if router_logits.dim() == 2:
|
||||
router_logits = router_logits.unsqueeze(0)
|
||||
assert router_logits.dim() == 3, "router_logits must be 3D tensor"
|
||||
num_groups, tokens_per_group, _ = router_logits.shape
|
||||
log_z = torch.logsumexp(router_logits, dim=-1)
|
||||
z_loss = torch.sum(log_z**2, dtype=torch.float32) / (num_groups * tokens_per_group)
|
||||
self._z_loss = z_loss
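# --- Illustrative sketch (not part of the diff) ---
# The router z-loss above averages (logsumexp over experts)**2 across tokens,
# which pushes router logits towards smaller magnitudes.
import torch

torch.manual_seed(0)
num_groups, tokens_per_group, num_experts = 2, 8, 4
router_logits = torch.randn(num_groups, tokens_per_group, num_experts)
log_z = torch.logsumexp(router_logits, dim=-1)
z_loss = torch.sum(log_z**2, dtype=torch.float32) / (num_groups * tokens_per_group)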
|
||||
|
||||
def pop_router_loss(self) -> torch.Tensor:
|
||||
assert self._aux_loss is not None
|
||||
MOE_MANAGER.add_loss(self._aux_loss, self._z_loss)
|
||||
self._aux_loss = None
|
||||
self._z_loss = None
|
||||
|
||||
|
||||
class Top1Router(MoeRouter):
|
||||
"""Top1 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity)
|
||||
and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More details
|
||||
can be found in Google's Switch Transformer paper.
|
||||
|
||||
Args:
|
||||
capacity_factor_train (float, optional): Capacity factor in routing of training.
|
||||
capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
|
||||
min_capacity (int, optional): The minimum capacity of each expert.
|
||||
select_policy (str, optional): The token selection policy.
|
||||
noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
|
||||
drop_tks (bool, optional): Whether to drop tokens during evaluation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
capacity_factor_train: float = 1.25,
|
||||
capacity_factor_eval: float = 2.0,
|
||||
min_capacity: int = 4,
|
||||
select_policy: str = "first",
|
||||
noisy_func: Optional[Callable] = None,
|
||||
drop_tks: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
k_value=1,
|
||||
capacity_factor_train=capacity_factor_train,
|
||||
capacity_factor_eval=capacity_factor_eval,
|
||||
min_capacity=min_capacity,
|
||||
noisy_func=noisy_func,
|
||||
drop_tks=drop_tks,
|
||||
)
|
||||
self.select_policy = select_policy
|
||||
assert select_policy in {"first", "random"}
|
||||
if select_policy == "random":
|
||||
self.uniform = torch.distributions.uniform.Uniform(
|
||||
low=torch.tensor(0.0, device=get_accelerator().get_current_device()),
|
||||
high=torch.tensor(1.0, device=get_accelerator().get_current_device()),
|
||||
).rsample
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs: torch.Tensor,
|
||||
use_kernel: bool = False,
|
||||
ep_group: Optional[ProcessGroup] = None,
|
||||
use_loss: bool = False,
|
||||
use_norm: bool = False,
|
||||
) -> Tuple:
|
||||
"""
|
||||
Args:
|
||||
inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
|
||||
|
||||
Returns:
|
||||
1. use_kernel is False:
|
||||
The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity).
|
||||
The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity).
|
||||
2. use_kernel is True:
|
||||
...
|
||||
"""
|
||||
if self.noisy_func is not None and self.training:
|
||||
inputs = self.noisy_func(inputs)
|
||||
|
||||
assert inputs.dtype == torch.float
|
||||
probs = F.softmax(inputs, dim=-1)
|
||||
num_experts = probs.size(-1)
|
||||
num_tokens = inputs.size(0)
|
||||
capacity = self.get_capacity(num_tokens, num_experts, ep_group)
|
||||
|
||||
top1_idx = torch.argmax(inputs, dim=-1)
|
||||
mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
|
||||
|
||||
# calculate router loss
|
||||
self.set_aux_loss(probs, top1_idx.unsqueeze(-1), num_experts)
|
||||
self.set_z_loss(inputs)
|
||||
self.pop_router_loss()
|
||||
|
||||
if not self.training and not self.drop_tks and ep_group is not None:
|
||||
max_num = torch.max(torch.sum(mask, dim=0))
|
||||
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
|
||||
capacity = max_num.item()
|
||||
|
||||
if self.select_policy == "random":
|
||||
rand_mask = mask * self.uniform(mask.shape)
|
||||
_, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0)
|
||||
mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1)
|
||||
ranks = moe_cumsum(mask, use_kernel=self.use_kernel)
|
||||
elif self.select_policy == "first":
|
||||
ranks = moe_cumsum(mask, use_kernel=self.use_kernel)
|
||||
mask = mask * torch.lt(ranks, capacity)
|
||||
else:
|
||||
raise NotImplementedError("Not support such select policy yet.")
|
||||
|
||||
ranks = torch.sum(mask * ranks, dim=-1)
|
||||
used_capacity = mask.sum(dim=0)
|
||||
|
||||
if use_kernel:
|
||||
mask = torch.sum(mask, dim=-1)
|
||||
mask = torch.stack([mask], dim=0).to(torch.int32)
|
||||
dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32)
|
||||
return used_capacity, probs, mask, dest_idx, num_experts * capacity
|
||||
else:
|
||||
ranks = F.one_hot(ranks, num_classes=capacity)
|
||||
weight = mask * probs.type_as(inputs)
|
||||
combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
|
||||
sec_mask = combine_weights.bool()
|
||||
return used_capacity, combine_weights, sec_mask, probs
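# --- Illustrative sketch (not part of the diff) ---
# The "first" select policy above keeps at most `capacity` tokens per expert:
# a cumulative sum assigns each routed token a rank within its expert, and
# tokens whose rank reaches the capacity are dropped.
import torch

mask = torch.tensor([[1, 0], [1, 0], [0, 1], [1, 0]], dtype=torch.int32)  # 4 tokens, 2 experts
capacity = 2
ranks = torch.cumsum(mask, dim=0) - 1   # zero-based rank of each routed token within its expert
kept = mask * (ranks < capacity)        # the third token routed to expert 0 is dropped
assert kept.tolist() == [[1, 0], [1, 0], [0, 1], [0, 0]]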
|
||||
|
||||
|
||||
class Top2Router(MoeRouter):
|
||||
"""Top2 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity)
|
||||
and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More details
|
||||
can be found in the ViT-MoE paper.
|
||||
|
||||
Args:
|
||||
capacity_factor_train (float, optional): Capacity factor in routing of training.
|
||||
capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
|
||||
min_capacity (int, optional): The minimum capacity of each expert.
|
||||
noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
|
||||
drop_tks (bool, optional): Whether to drop tokens during evaluation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
capacity_factor_train: float = 1.25,
|
||||
capacity_factor_eval: float = 2.0,
|
||||
min_capacity: int = 4,
|
||||
noisy_func: Optional[Callable] = None,
|
||||
drop_tks: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
k_value=2,
|
||||
capacity_factor_train=capacity_factor_train,
|
||||
capacity_factor_eval=capacity_factor_eval,
|
||||
min_capacity=min_capacity,
|
||||
noisy_func=noisy_func,
|
||||
drop_tks=drop_tks,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs: torch.Tensor,
|
||||
use_kernel: bool = False,
|
||||
ep_group: Optional[ProcessGroup] = None,
|
||||
use_norm: bool = False,
|
||||
use_loss: bool = True,
|
||||
) -> Tuple:
|
||||
"""
|
||||
Args:
|
||||
inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
|
||||
|
||||
Returns:
|
||||
1. use_kernel is False:
|
||||
The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity).
|
||||
The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity).
|
||||
2. use_kernel is True:
|
||||
...
|
||||
"""
|
||||
if self.noisy_func is not None and self.training:
|
||||
inputs = self.noisy_func(inputs)
|
||||
|
||||
assert inputs.dtype == torch.float
|
||||
probs = F.softmax(inputs, dim=-1)
|
||||
if use_norm:
|
||||
routing_weights, _ = torch.topk(probs, 2, dim=-1)
|
||||
probs = probs / routing_weights.sum(dim=-1, keepdim=True)
|
||||
|
||||
num_experts = probs.size(-1)
|
||||
num_tokens = inputs.size(0)
|
||||
capacity = self.get_capacity(num_tokens, num_experts, ep_group)
|
||||
|
||||
top1_idx = torch.argmax(probs, dim=-1)
|
||||
mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
|
||||
logits_except1 = probs.masked_fill(mask1.bool(), float("-inf"))
|
||||
top2_idx = torch.argmax(logits_except1, dim=-1)
|
||||
mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)
|
||||
|
||||
cmask = mask1 + mask2 # loss: [s, e]
|
||||
cmask = cmask.float() / 2.0 # div 2 to normalize it to 1
|
||||
|
||||
# calculate loss
|
||||
if use_loss:
|
||||
expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
|
||||
self.set_aux_loss(probs, expert_indices, num_experts)
|
||||
self.set_z_loss(inputs)
|
||||
self.pop_router_loss()
|
||||
|
||||
if not self.training and not self.drop_tks and ep_group is not None:
|
||||
max_num = torch.max(torch.sum(cmask, dim=0))
|
||||
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
|
||||
capacity = max_num.item()
|
||||
|
||||
rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel) # rank1: [s, e]
|
||||
rank2 = moe_cumsum(mask2, use_kernel=self.use_kernel)
|
||||
rank2 += torch.sum(mask1, dim=-2, keepdim=True)
|
||||
|
||||
mask1 *= torch.lt(rank1, capacity)
|
||||
mask2 *= torch.lt(rank2, capacity)
|
||||
used_capacity = mask1.sum(dim=0) + mask2.sum(dim=0)
|
||||
|
||||
rank1 = torch.sum(mask1 * rank1, dim=-1)
|
||||
rank2 = torch.sum(mask2 * rank2, dim=-1)
|
||||
|
||||
if use_kernel:
|
||||
mask1 = torch.sum(mask1, dim=-1)
|
||||
mask2 = torch.sum(mask2, dim=-1)
|
||||
|
||||
mask = torch.stack([mask1, mask2], dim=0).to(torch.int32)
|
||||
dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32)
|
||||
|
||||
return used_capacity, probs, mask, dest_idx, num_experts * capacity
|
||||
else:
|
||||
"""
|
||||
The following code is equivalent to:
|
||||
|
||||
```
|
||||
weight1 = mask1 * probs.type_as(inputs)
|
||||
weight2 = mask2 * probs.type_as(inputs)
|
||||
rank1_sc = F.one_hot(rank1, num_classes=capacity)
|
||||
rank2_sc = F.one_hot(rank2, num_classes=capacity)
|
||||
|
||||
cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1)
|
||||
cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1)
|
||||
cb_weight = cb_weight1 + cb_weight2
|
||||
sec_mask = cb_weight.bool()
|
||||
```
|
||||
"""
|
||||
|
||||
weight1 = mask1 * probs.type_as(inputs)
|
||||
weight2 = mask2 * probs.type_as(inputs)
|
||||
|
||||
cb_weight = torch.zeros(inputs.shape + (capacity,), device=inputs.device)
|
||||
sec_mask = torch.zeros_like(cb_weight, dtype=torch.bool)
|
||||
indices = torch.arange(0, inputs.shape[0], device=inputs.device)
|
||||
cb_weight[indices, top1_idx[indices], rank1[indices]] += weight1[indices, top1_idx[indices]]
|
||||
cb_weight[indices, top2_idx[indices], rank2[indices]] += weight2[indices, top2_idx[indices]]
|
||||
sec_mask[indices, top1_idx[indices], rank1[indices]] |= mask1.bool()[indices, top1_idx[indices]]
|
||||
sec_mask[indices, top2_idx[indices], rank2[indices]] |= mask2.bool()[indices, top2_idx[indices]]
|
||||
|
||||
return used_capacity, cb_weight, sec_mask
|
||||
|
||||
|
||||
class TopKRouter(MoeRouter):
|
||||
"""Masked matmul router using tokens choose top-k experts assignment.
|
||||
|
||||
NOTE: this is modified from flaxformer.
|
||||
This router uses the same mechanism as in Switch Transformer
|
||||
(https://arxiv.org/abs/2101.03961) and V-MoE
|
||||
(https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are
|
||||
sorted by router_probs and then routed to their choice of expert until the
|
||||
expert's expert_capacity is reached. There is no guarantee that each token is
|
||||
processed by an expert, or that each expert receives at least one token.
|
||||
|
||||
Attributes:
|
||||
num_selected_experts: Maximum number of experts to which each token is
|
||||
routed. Tokens may be routed to fewer experts if particular experts are
|
||||
oversubscribed / reach capacity.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_selected_experts: int,
|
||||
capacity_factor_train: float = 1.25,
|
||||
capacity_factor_eval: float = 2.0,
|
||||
min_capacity: int = 4,
|
||||
noisy_func: Optional[Callable] = None,
|
||||
drop_tks: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func, drop_tks
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
router_probs: torch.Tensor,
|
||||
expert_capacity: int,
|
||||
) -> Tuple:
|
||||
"""Computes masks for the top-k experts per token.
|
||||
|
||||
Args:
|
||||
router_probs: <float32>[num_groups, tokens_per_group, num_experts]
|
||||
probabilities used to determine the routing of tokens to the experts.
|
||||
|
||||
Returns:
|
||||
Dispatch and combine arrays for routing with masked matmuls.
|
||||
"""
|
||||
# TODO: FIXME: add parallel group
|
||||
num_groups, _, num_experts = router_probs.shape
|
||||
|
||||
# Top-k router probability and corresponding expert indices for each token.
|
||||
# Shape: [num_groups, tokens_per_group, num_selected_experts].
|
||||
expert_gate, expert_index = torch.topk(router_probs, self.k_value)
|
||||
|
||||
self.set_aux_loss(router_probs, expert_index, num_experts)
|
||||
self.pop_router_loss()
|
||||
|
||||
# Make num_selected_experts the leading axis to ensure that top-1 choices
|
||||
# have priority over top-2 choices, which have priority over top-3 choices,
|
||||
# etc.
|
||||
expert_index = torch.transpose(expert_index, 1, 2)
|
||||
# Shape: [num_groups, num_selected_experts * tokens_per_group]
|
||||
expert_index = expert_index.reshape(num_groups, -1)
|
||||
|
||||
# Create mask out of indices.
|
||||
# Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts].
|
||||
expert_mask = F.one_hot(expert_index, num_experts).to(torch.int32)
|
||||
|
||||
# Experts have a fixed capacity that we cannot exceed. A token's priority
|
||||
# within the expert's buffer is given by the masked, cumulative capacity of
|
||||
# its target expert.
|
||||
# Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts].
|
||||
token_priority = torch.cumsum(expert_mask, dim=1) * expert_mask - 1
|
||||
# Shape: [num_groups, num_selected_experts, tokens_per_group, num_experts].
|
||||
token_priority = token_priority.reshape((num_groups, self.k_value, -1, num_experts))
|
||||
# Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts].
|
||||
token_priority = torch.transpose(token_priority, 1, 2)
|
||||
# For each token, across all selected experts, select the only non-negative
|
||||
# (unmasked) priority. Now, for group G routing to expert E, token T has
|
||||
# non-negative priority (i.e. token_priority[G,T,E] >= 0) if and only if E
|
||||
# is its targeted expert.
|
||||
# Shape: [num_groups, tokens_per_group, num_experts].
|
||||
token_priority = torch.max(token_priority, dim=2)[0]
|
||||
|
||||
# Token T can only be routed to expert E if its priority is positive and
|
||||
# less than the expert capacity. One-hot matrix will ignore indices outside
|
||||
# the range [0, expert_capacity).
|
||||
# Shape: [num_groups, tokens_per_group, num_experts, expert_capacity].
|
||||
valid_mask = torch.logical_and(token_priority >= 0, token_priority < expert_capacity)
|
||||
token_priority = torch.masked_fill(token_priority, ~valid_mask, 0)
|
||||
dispatch_mask = F.one_hot(token_priority, expert_capacity).to(torch.bool)
|
||||
valid_mask = valid_mask.unsqueeze(-1).expand(-1, -1, -1, expert_capacity)
|
||||
dispatch_mask = torch.masked_fill(dispatch_mask, ~valid_mask, 0)
|
||||
|
||||
# The combine array will be used for combining expert outputs, scaled by the
|
||||
# router probabilities. Shape: [num_groups, tokens_per_group, num_experts,
|
||||
# expert_capacity].
|
||||
combine_array = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask)
|
||||
|
||||
return combine_array, dispatch_mask
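# --- Illustrative sketch (not part of the diff) ---
# The masked cumulative sum used above gives each token a zero-based priority
# slot inside the buffer of the expert it chose.
import torch
import torch.nn.functional as F

expert_index = torch.tensor([[0, 1, 0, 0]])                    # 4 tokens, chosen expert ids
expert_mask = F.one_hot(expert_index, num_classes=2)           # shape [1, 4, 2]
token_priority = torch.cumsum(expert_mask, dim=1) * expert_mask - 1
# expert 0 receives its tokens at priorities 0, 1 and 2; expert 1 at priority 0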
|
||||
|
||||
|
||||
def get_router_cls(top_k: int, grouped: bool = False) -> Type[MoeRouter]:
|
||||
if not grouped:
|
||||
if top_k == 1:
|
||||
return Top1Router
|
||||
elif top_k == 2:
|
||||
return Top2Router
|
||||
else:
|
||||
raise NotImplementedError("top_k > 2 is not supported yet")
|
||||
else:
|
||||
return TopKRouter
|
|
@ -6,10 +6,11 @@ import torch
|
|||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.distributed.distributed_c10d import get_process_group_ranks
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.tensor.moe_tensor.api import get_dp_group, get_dp_group_ranks, get_ep_size, is_moe_tensor
|
||||
from colossalai.tensor.moe_tensor.api import is_moe_tensor
|
||||
|
||||
|
||||
class ForceFP32Parameter(torch.nn.Parameter):
|
||||
|
@ -145,7 +146,7 @@ def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]
|
|||
if not is_moe_tensor(param):
|
||||
ep_size = 1 # set ep_size to 1 for dp parameters
|
||||
else:
|
||||
ep_size = get_ep_size(param)
|
||||
ep_size = dist.get_world_size(param.ep_group)
|
||||
if ep_size not in epsize_param_dict:
|
||||
epsize_param_dict[ep_size] = []
|
||||
epsize_param_dict[ep_size].append(param)
|
||||
|
@ -170,8 +171,8 @@ def sync_moe_model_param(model: nn.Module):
|
|||
# When ep_size = world_size, communication is not needed
|
||||
if ep_size != 1 and ep_size != MOE_MANAGER.world_size:
|
||||
for param in param_dict[ep_size]:
|
||||
src_rank = get_dp_group_ranks(param)[0]
|
||||
dist.broadcast(param, src=src_rank, group=get_dp_group(param))
|
||||
src_rank = get_process_group_ranks(param.dp_group)[0]
|
||||
dist.broadcast(param, src=src_rank, group=param.dp_group)
|
||||
|
||||
|
||||
def set_moe_args(config: Any, args: dict):
|
||||
|
|
|
@ -1,17 +1,25 @@
|
|||
# adapted from Hugging Face accelerate/utils/bnb.py accelerate/utils/modeling.py
|
||||
|
||||
import importlib.metadata
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from packaging.version import Version
|
||||
|
||||
from .bnb_config import BnbQuantizationConfig
|
||||
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
|
||||
IS_4BIT_BNB_AVAILABLE = bnb.__version__ >= "0.39.0"
|
||||
IS_8BIT_BNB_AVAILABLE = bnb.__version__ >= "0.37.2"
|
||||
try:
|
||||
# in case lower version of bitsandbytes does not have __version__ attribute
|
||||
BNB_VERSION = Version(bnb.__version__)
|
||||
except AttributeError:
|
||||
BNB_VERSION = Version(importlib.metadata.version("bitsandbytes"))
|
||||
|
||||
IS_4BIT_BNB_AVAILABLE = BNB_VERSION >= Version("0.39.0")
|
||||
IS_8BIT_BNB_AVAILABLE = BNB_VERSION >= Version("0.37.2")
|
||||
except ImportError:
|
||||
pass
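# --- Illustrative sketch (not part of the diff) ---
# Comparing versions through packaging.version.Version avoids the pitfalls of
# the old plain-string comparison.
from packaging.version import Version

assert "0.9.0" > "0.39.0"                    # lexicographic comparison gets this wrong
assert Version("0.9.0") < Version("0.39.0")  # semantic comparison is correct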
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ from .attn import AttnMaskType, ColoAttention
|
|||
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
|
||||
from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
|
||||
from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
|
||||
from .loss import cross_entropy_1d
|
||||
from .loss import cross_entropy_1d, dist_cross_entropy
|
||||
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
|
||||
from .parallel_module import ParallelModule
|
||||
from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
|
||||
|
@ -18,6 +18,7 @@ __all__ = [
|
|||
"DropoutForParallelInput",
|
||||
"DropoutForReplicatedInput",
|
||||
"cross_entropy_1d",
|
||||
"dist_cross_entropy",
|
||||
"BaseLayerNorm",
|
||||
"LayerNorm",
|
||||
"RMSNorm",
|
||||
|
|
|
@ -2,8 +2,11 @@ import torch
|
|||
import torch.distributed as dist
|
||||
from torch.autograd import Function
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
__all__ = ["DistCrossEntropy", "cross_entropy_1d"]
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
|
||||
__all__ = ["DistCrossEntropy", "cross_entropy_1d", "dist_cross_entropy"]
|
||||
|
||||
|
||||
class DistCrossEntropy(Function):
|
||||
|
@ -132,3 +135,43 @@ def cross_entropy_1d(
|
|||
dtype: torch.dtype = None,
|
||||
) -> torch.Tensor:
|
||||
return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype)
|
||||
|
||||
|
||||
def dist_cross_entropy(
|
||||
labels: torch.Tensor,
|
||||
logits: torch.Tensor,
|
||||
shard_config: ShardConfig,
|
||||
out_features: int,
|
||||
vocab_size: int,
|
||||
dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Helper to compute cross entropy loss for most shardformer models,
|
||||
compatible with PP, TP and SP.
|
||||
"""
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_labels = shift_labels.view(-1)
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
|
||||
# Cross entropy with all-reduce for TP
|
||||
new_vocab_size = logits.shape[-1]
|
||||
shift_logits = shift_logits.view(-1, new_vocab_size)
|
||||
loss = cross_entropy_1d(
|
||||
shift_logits,
|
||||
shift_labels,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
vocab_size=out_features,
|
||||
dtype=dtype,
|
||||
)
|
||||
else:
|
||||
# NOTE if use TP and not parallel_output, the output is gathered.
|
||||
# see VocabParallelLMHead1D
|
||||
shift_logits = shift_logits.view(-1, vocab_size)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
return loss
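# --- Illustrative sketch (not part of the diff) ---
# Without tensor parallelism, dist_cross_entropy reduces to the usual causal-LM
# shift-by-one cross entropy; the tensors below are hypothetical.
import torch
from torch.nn import CrossEntropyLoss

batch, seq_len, vocab_size = 2, 5, 11
logits = torch.randn(batch, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch, seq_len))

shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
shift_labels = labels[..., 1:].contiguous().view(-1)
loss = CrossEntropyLoss()(shift_logits, shift_labels)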
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
from .experts import *
|
||||
from .layers import *
|
||||
from .routers import *
|
|
@ -9,7 +9,7 @@ from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler
|
|||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.moe.utils import get_activation
|
||||
from colossalai.shardformer.layer.utils import Randomizer
|
||||
from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size, set_moe_tensor_info
|
||||
from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size
|
||||
|
||||
if HAS_TRITON:
|
||||
from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine
|
||||
|
@ -35,7 +35,7 @@ class MLPExperts(nn.Module):
|
|||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
expert_parallel: Optional[str] = None,
|
||||
expert_parallel: Optional[str] = "EP",
|
||||
activation: Optional[Callable] = None,
|
||||
drop_rate: Optional[float] = 0,
|
||||
gated: Optional[bool] = False,
|
|
@ -8,11 +8,9 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
|
||||
from colossalai.moe._operation import AllGather, AllToAll, HierarchicalAllToAll, MoeCombine, MoeDispatch, ReduceScatter
|
||||
from colossalai.moe.experts import MLPExperts
|
||||
from colossalai.moe.load_balance import LoadBalancer
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.moe.routers import MoeRouter, get_router_cls
|
||||
from colossalai.moe.utils import create_ep_hierarchical_group, get_noise_generator
|
||||
from colossalai.shardformer.layer.moe import MLPExperts
|
||||
from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_group_ranks, get_ep_size
|
||||
|
||||
|
||||
|
@ -23,6 +21,7 @@ class SparseMLP(nn.Module):
|
|||
dim_model (int): Hidden dimension of training model
|
||||
num_experts (int): The number of experts
|
||||
top_k (int, optional): The number of experts each token is dispatched to
|
||||
parallel (str): parallel mode. Should be "EP", "TP" or None
|
||||
capacity_factor_train (float, optional): Capacity factor in routing during training
|
||||
capacity_factor_eval (float, optional): Capacity factor in routing during evaluation
|
||||
min_capacity (int, optional): The minimum capacity of each expert
|
||||
|
@ -51,6 +50,7 @@ class SparseMLP(nn.Module):
|
|||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
router_top_k: int = 1,
|
||||
parallel: str = "EP",
|
||||
router_loss: bool = True,
|
||||
router_norm: bool = False,
|
||||
router_capacity_factor_train: float = 1.25,
|
||||
|
@ -66,7 +66,7 @@ class SparseMLP(nn.Module):
|
|||
load_balance_group_swap_factor: float = 0.4,
|
||||
enable_kernel: bool = False,
|
||||
enable_comm_overlap: bool = False,
|
||||
enable_hierarchical_comm: bool = False,
|
||||
enable_hierarchical_comm: bool = True,
|
||||
return_gate_logits: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
@ -77,7 +77,9 @@ class SparseMLP(nn.Module):
|
|||
self.return_gate_logits = return_gate_logits
|
||||
self.enable_kernel = enable_kernel
|
||||
self.enable_comm_overlap = enable_comm_overlap
|
||||
self.expert_parallel = MOE_MANAGER.get_parallel()
|
||||
# self.expert_parallel = MOE_MANAGER.get_parallel()
|
||||
assert parallel in ["EP", "TP", None], "parallel mode must be EP, TP or None"
|
||||
self.parallel = parallel
|
||||
self.router_loss = router_loss
|
||||
self.router_norm = router_norm
|
||||
|
||||
|
@ -99,7 +101,7 @@ class SparseMLP(nn.Module):
|
|||
# moe experts
|
||||
self.experts = MLPExperts(
|
||||
num_experts=self.num_experts,
|
||||
expert_parallel=self.expert_parallel,
|
||||
expert_parallel=self.parallel,
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=self.intermediate_size,
|
||||
activation=mlp_activation,
|
||||
|
@ -108,11 +110,12 @@ class SparseMLP(nn.Module):
|
|||
)
|
||||
|
||||
# get parallel settings
|
||||
if self.expert_parallel is not None:
|
||||
if self.parallel is not None:
|
||||
self.ep_group = get_ep_group(self.experts)
|
||||
self.ep_size = get_ep_size(self.experts)
|
||||
self.ep_hierarchical_group = None
|
||||
if enable_hierarchical_comm:
|
||||
# TODO: move to plugin
|
||||
self.ep_intra_src_rank, *self.ep_hierarchical_group = create_ep_hierarchical_group(
|
||||
get_ep_group_ranks(self.experts)
|
||||
)
|
||||
|
@ -186,11 +189,11 @@ class SparseMLP(nn.Module):
|
|||
dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
|
||||
|
||||
# expert_output: (num_groups, num_experts, capacity, hidden_size)
|
||||
if self.expert_parallel == "EP":
|
||||
if self.parallel == "EP":
|
||||
expert_output = self._ep_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
|
||||
elif self.expert_parallel == "TP":
|
||||
elif self.parallel == "TP":
|
||||
expert_output = self._tp_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
|
||||
elif self.expert_parallel is None:
|
||||
elif self.parallel is None:
|
||||
expert_output = self._local_process(dispatch_data)
|
||||
else:
|
||||
raise NotImplementedError(
|
|
@ -0,0 +1,161 @@
|
|||
import math
|
||||
from typing import Callable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
|
||||
from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.moe.utils import get_activation
|
||||
from colossalai.shardformer.layer.utils import Randomizer
|
||||
from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size, set_moe_tensor_info
|
||||
|
||||
if HAS_TRITON:
|
||||
from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine
|
||||
|
||||
|
||||
class MLPExperts(nn.Module):
|
||||
"""
|
||||
MLPExperts holds the expert MLPs used by SparseMLP, with optional expert parallelism.
|
||||
|
||||
Args:
|
||||
num_experts (int): The number of experts
|
||||
hidden_size (int): The hidden size of MLP
|
||||
intermediate_size (int): The intermediate size of MLP
|
||||
expert_parallel (str, optional): The expert parallelism mode. Supported values are "EP", "TP" and None.
|
||||
activation (optional): The activation function of MLP
|
||||
drop_rate (float, optional): The drop rate of MLP
|
||||
gated (bool, optional): Whether to use gated MLP
|
||||
use_kernel (bool, optional): Whether to use kernel optimization
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
expert_parallel: Optional[str] = "EP",
|
||||
activation: Optional[Callable] = None,
|
||||
drop_rate: Optional[float] = 0,
|
||||
gated: Optional[bool] = False,
|
||||
use_kernel: Optional[bool] = False,
|
||||
):
|
||||
super().__init__()
|
||||
assert expert_parallel in ["EP", "TP", None]
|
||||
self.expert_parallel = expert_parallel
|
||||
self.num_total_experts = num_experts
|
||||
self.gated = gated
|
||||
self.use_kernel = use_kernel
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
|
||||
# get expert parallel info
|
||||
if expert_parallel is not None:
|
||||
self.num_local_experts, self.moe_info = MOE_MANAGER.get_info(
|
||||
num_experts, use_tp=True if expert_parallel == "TP" else False
|
||||
)
|
||||
# get settings for different parallel
|
||||
self.ep_size = get_ep_size(self)
|
||||
if expert_parallel == "TP":
|
||||
intermediate_size = intermediate_size // self.ep_size
|
||||
num_experts = self.num_total_experts
|
||||
else:
|
||||
num_experts = self.num_local_experts
|
||||
else:
|
||||
self.num_local_experts = self.num_total_experts
|
||||
self.ep_size = 1
|
||||
|
||||
if gated:
|
||||
self.wi_gate = nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts, hidden_size, intermediate_size * 2 if activation == "swiglu" else intermediate_size
|
||||
)
|
||||
)
|
||||
self.wi_up = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
|
||||
else:
|
||||
self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
|
||||
self.wo = nn.Parameter(torch.empty(num_experts, intermediate_size, hidden_size))
|
||||
|
||||
self.act_name = activation
|
||||
self.act = get_activation(activation)
|
||||
self.drop = nn.Dropout(p=drop_rate)
|
||||
|
||||
if expert_parallel is not None:
|
||||
for param in self.parameters():
|
||||
set_moe_tensor_info(param, self.moe_info)
|
||||
|
||||
# init param
|
||||
self.reset_parameters()
|
||||
|
||||
@torch.no_grad()
|
||||
def reset_parameters(self):
|
||||
# expert params should be initialized differently on each EP rank
|
||||
if self.expert_parallel is not None:
|
||||
seed_ctx = Randomizer(get_ep_rank(self)).fork_rng(enable_cpu=True)
|
||||
else:
|
||||
seed_ctx = Randomizer(42).fork_rng(enable_cpu=True)
|
||||
with seed_ctx:
|
||||
if self.gated:
|
||||
torch.nn.init.normal_(self.wi_gate, std=math.sqrt(0.1 / self.hidden_size))
|
||||
torch.nn.init.normal_(self.wi_up, std=math.sqrt(0.1 / self.hidden_size))
|
||||
else:
|
||||
torch.nn.init.normal_(self.wi, std=math.sqrt(0.1 / self.hidden_size))
|
||||
torch.nn.init.normal_(self.wo, std=math.sqrt(0.1 / self.intermediate_size))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
param_slice: Tuple[slice] = (slice(None),),
|
||||
use_sparse: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
forward: hidden_size --> intermediate_size --> hidden_size
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor of shape (num_groups, num_experts, capacity, hidden_size)
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The output tensor of shape (num_groups, num_experts, capacity, hidden_size)
|
||||
"""
|
||||
x = MoeInGradScaler.apply(x, self.ep_size)
|
||||
|
||||
e = x.size(1)
|
||||
h = x.size(-1)
|
||||
|
||||
x = x.transpose(0, 1)
|
||||
inshape = x.shape
|
||||
x = x.reshape(e, -1, h)
|
||||
|
||||
if self.use_kernel and use_sparse:
|
||||
seq_len = x.shape[1]
|
||||
with torch.no_grad():
|
||||
mask = x[:, :, 0] != 0.0
|
||||
mask = torch.sum(mask, dim=-1)
|
||||
x_list = []
|
||||
for i in range(e):
|
||||
x_list.append(x[i, : mask[i]])
|
||||
x = x_list
|
||||
|
||||
if self.gated:
|
||||
x_gate = [torch.mm(x[i], self.wi_gate[param_slice][i]) for i in range(e)]
|
||||
x_up = [torch.mm(x[i], self.wi_up[param_slice][i]) for i in range(e)]
|
||||
if self.use_kernel and HAS_TRITON and self.act_name == "swiglu":
|
||||
x = [LlamaActCombine.apply(x_gate[i], x_up[i]) for i in range(e)]
|
||||
else:
|
||||
x = [self.act(x_gate[i]) * x_up[i] for i in range(e)]
|
||||
else:
|
||||
x = [torch.mm(x[i], self.wi[param_slice][i]) for i in range(e)]
|
||||
x = [self.act(x[i]) for i in range(e)]
|
||||
x = [self.drop(x[i]) for i in range(e)]
|
||||
x = [torch.mm(x[i], self.wo[param_slice][i]) for i in range(e)]
|
||||
|
||||
if self.use_kernel and use_sparse:
|
||||
for i in range(e):
|
||||
x[i] = torch.nn.functional.pad(x[i], (0, 0, 0, seq_len - x[i].shape[0]), mode="constant", value=0)
|
||||
|
||||
x = torch.cat([x[i].unsqueeze(0) for i in range(e)], dim=0)
|
||||
x = x.reshape(inshape)
|
||||
x = x.transpose(0, 1).contiguous()
|
||||
x = MoeOutGradScaler.apply(x, self.ep_size)
|
||||
return x
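# --- Illustrative sketch (not part of the diff) ---
# The dense (non-sparse) path above is a per-expert matmul over a 3-D weight of
# shape (num_experts, in_dim, out_dim), equivalent to a batched matmul.
import torch

num_experts, capacity, hidden, inter = 4, 8, 16, 32
x = torch.randn(num_experts, capacity, hidden)                  # tokens grouped by expert
wi = torch.randn(num_experts, hidden, inter)
h = torch.stack([x[e] @ wi[e] for e in range(num_experts)])     # per-expert matmul
assert torch.allclose(h, torch.bmm(x, wi), atol=1e-5)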
|
|
@ -1,4 +1,3 @@
|
|||
import math
|
||||
import warnings
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
|
@ -1005,115 +1004,6 @@ class BertPipelineForwards:
|
|||
return {"hidden_states": hidden_states}
|
||||
|
||||
|
||||
def get_bert_flash_attention_forward():
|
||||
try:
|
||||
from xformers.ops import memory_efficient_attention as me_attention
|
||||
except:
|
||||
raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
|
||||
from transformers.models.bert.modeling_bert import BertAttention
|
||||
|
||||
def forward(
|
||||
self: BertAttention,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Tuple[torch.Tensor]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
# If this is instantiated as a cross-attention module, the keys
|
||||
# and values come from an encoder; the attention mask needs to be
|
||||
# such that the encoder's padding tokens are not attended to.
|
||||
is_cross_attention = encoder_hidden_states is not None
|
||||
|
||||
if is_cross_attention and past_key_value is not None:
|
||||
# reuse k,v, cross_attentions
|
||||
key_layer = past_key_value[0]
|
||||
value_layer = past_key_value[1]
|
||||
attention_mask = encoder_attention_mask
|
||||
elif is_cross_attention:
|
||||
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
||||
attention_mask = encoder_attention_mask
|
||||
elif past_key_value is not None:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
||||
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
||||
else:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
|
||||
use_cache = past_key_value is not None
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_layer, value_layer)
|
||||
|
||||
final_attention_mask = None
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
|
||||
if use_cache:
|
||||
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
||||
else:
|
||||
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
||||
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
||||
distance = position_ids_l - position_ids_r
|
||||
|
||||
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
||||
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
||||
|
||||
if self.position_embedding_type == "relative_key":
|
||||
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||
final_attention_mask = relative_position_scores
|
||||
elif self.position_embedding_type == "relative_key_query":
|
||||
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
||||
final_attention_mask = relative_position_scores_query + relative_position_scores_key
|
||||
|
||||
scale = 1 / math.sqrt(self.attention_head_size)
|
||||
if attention_mask is not None:
|
||||
if final_attention_mask != None:
|
||||
final_attention_mask = final_attention_mask * scale + attention_mask
|
||||
else:
|
||||
final_attention_mask = attention_mask
|
||||
|
||||
if final_attention_mask is not None:
|
||||
batch_size, src_len = query_layer.size()[0], query_layer.size()[2]
|
||||
tgt_len = key_layer.size()[2]
|
||||
final_attention_mask = final_attention_mask.expand(
|
||||
batch_size, self.num_attention_heads, src_len, tgt_len
|
||||
).contiguous()
|
||||
|
||||
query_layer = query_layer.permute(0, 2, 1, 3).contiguous()
|
||||
key_layer = key_layer.permute(0, 2, 1, 3).contiguous()
|
||||
value_layer = value_layer.permute(0, 2, 1, 3).contiguous()
|
||||
|
||||
context_layer = me_attention(
|
||||
query_layer, key_layer, value_layer, attn_bias=final_attention_mask, p=self.dropout.p, scale=scale
|
||||
)
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, None)
|
||||
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (past_key_value,)
|
||||
return outputs
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_jit_fused_bert_self_output_forward():
|
||||
from transformers.models.bert.modeling_bert import BertSelfOutput
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
|
|||
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
|
||||
from ..layer import cross_entropy_1d
|
||||
from ..layer import dist_cross_entropy
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
@ -359,30 +359,14 @@ class BloomPipelineForwards:
|
|||
hidden_states = transformer_outputs[0]
|
||||
lm_logits = self.lm_head(hidden_states).contiguous()
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# move labels to correct device to enable model parallelism
|
||||
labels = labels.to(lm_logits.device)
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = lm_logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
batch_size, seq_length, vocab_size = shift_logits.shape
|
||||
# Flatten the tokens
|
||||
if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
|
||||
new_vocab_size = lm_logits.shape[-1]
|
||||
shift_logits = shift_logits.view(-1, new_vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
loss = cross_entropy_1d(
|
||||
shift_logits,
|
||||
shift_labels,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
vocab_size=self.lm_head.out_features,
|
||||
dtype=self.transformer.dtype,
|
||||
)
|
||||
else:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
loss = loss_fct(shift_logits, shift_labels.view(-1))
|
||||
loss = dist_cross_entropy(
|
||||
labels,
|
||||
lm_logits,
|
||||
shard_config,
|
||||
self.lm_head.out_features,
|
||||
self.config.vocab_size,
|
||||
self.transformer.dtype,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (lm_logits,) + transformer_outputs[1:]
|
||||
|
@ -714,93 +698,6 @@ class BloomPipelineForwards:
|
|||
return {"hidden_states": hidden_states}
|
||||
|
||||
|
||||
def get_bloom_flash_attention_forward(enable_jit_fused=False):
|
||||
try:
|
||||
from xformers.ops import memory_efficient_attention as me_attention
|
||||
except:
|
||||
raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
|
||||
from transformers.models.bloom.modeling_bloom import BloomAttention
|
||||
|
||||
def forward(
|
||||
self: BloomAttention,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
alibi: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
fused_qkv = self.query_key_value(hidden_states)
|
||||
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
|
||||
batch_size, tgt_len, _, _ = query_layer.size()
|
||||
|
||||
_, kv_length, _, _ = key_layer.size()
|
||||
|
||||
proj_shape = (batch_size, tgt_len, self.num_heads, self.head_dim)
|
||||
query_layer = query_layer.contiguous().view(*proj_shape)
|
||||
key_layer = key_layer.contiguous().view(*proj_shape)
|
||||
value_layer = value_layer.contiguous().view(*proj_shape)
|
||||
|
||||
if layer_past is not None:
|
||||
past_key, past_value = layer_past
|
||||
# concatenate along seq_length dimension:
|
||||
# - key: [batch_size * self.num_heads, head_dim, kv_length]
|
||||
# - value: [batch_size * self.num_heads, kv_length, head_dim]
|
||||
key_layer = torch.cat((past_key, key_layer), dim=1)
|
||||
value_layer = torch.cat((past_value, value_layer), dim=1)
|
||||
|
||||
if use_cache is True:
|
||||
present = (key_layer, value_layer)
|
||||
else:
|
||||
present = None
|
||||
|
||||
tgt_len = key_layer.size()[1]
|
||||
|
||||
attention_numerical_mask = torch.zeros(
|
||||
(batch_size, self.num_heads, tgt_len, kv_length),
|
||||
dtype=torch.float32,
|
||||
device=query_layer.device,
|
||||
requires_grad=True,
|
||||
)
|
||||
attention_numerical_mask = (
|
||||
attention_numerical_mask + alibi.view(batch_size, self.num_heads, 1, kv_length) * self.beta
|
||||
)
|
||||
attention_numerical_mask = torch.masked_fill(
|
||||
attention_numerical_mask, attention_mask, torch.finfo(torch.float32).min
|
||||
)
|
||||
attention_numerical_mask = attention_numerical_mask.to(query_layer.dtype)
|
||||
|
||||
context_layer = me_attention(
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
attn_bias=attention_numerical_mask,
|
||||
scale=self.inv_norm_factor,
|
||||
p=self.attention_dropout.p,
|
||||
)
|
||||
context_layer = context_layer.reshape(-1, kv_length, self.hidden_size)
|
||||
if self.pretraining_tp > 1 and self.slow_but_exact:
|
||||
slices = self.hidden_size / self.pretraining_tp
|
||||
output_tensor = torch.zeros_like(context_layer)
|
||||
for i in range(self.pretraining_tp):
|
||||
output_tensor = output_tensor + F.linear(
|
||||
context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
|
||||
self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
|
||||
)
|
||||
else:
|
||||
output_tensor = self.dense(context_layer)
|
||||
|
||||
# TODO to replace with the bias_dropout_add function in jit
|
||||
output_tensor = self.dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
|
||||
outputs = (output_tensor, present, None)
|
||||
|
||||
return outputs
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_jit_fused_bloom_attention_forward():
|
||||
from transformers.models.bloom.modeling_bloom import BloomAttention
|
||||
|
||||
|
@ -1127,24 +1024,10 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
|||
hidden_states = transformer_outputs[0]
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# move labels to correct device to enable model parallelism
|
||||
labels = labels.to(lm_logits.device)
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = lm_logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
new_vocab_size = lm_logits.shape[-1]
|
||||
shift_logits = shift_logits.view(-1, new_vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
loss = cross_entropy_1d(
|
||||
shift_logits,
|
||||
shift_labels,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
vocab_size=self.lm_head.out_features,
|
||||
dtype=self.transformer.dtype,
|
||||
)
|
||||
loss = dist_cross_entropy(
|
||||
labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (lm_logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
|
|
@ -33,6 +33,7 @@ This license shall be governed and construed in accordance with the laws of Peop
|
|||
|
||||
Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at glm-130b@googlegroups.com.
|
||||
"""
|
||||
|
||||
""" PyTorch ChatGLM model. """
|
||||
|
||||
import copy
|
||||
|
|
|
@ -5,7 +5,6 @@ from typing import List, Optional, Tuple, Union
|
|||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers.cache_utils import Cache, DynamicCache
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from transformers.models.cohere.modeling_cohere import (
|
||||
|
@ -25,7 +24,7 @@ from colossalai.shardformer.layer._operation import (
|
|||
)
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
|
||||
from ..layer import ColoAttention, cross_entropy_1d
|
||||
from ..layer import ColoAttention, dist_cross_entropy
|
||||
|
||||
|
||||
class CommandPipelineForwards:
|
||||
|
@ -300,29 +299,9 @@ class CommandPipelineForwards:
|
|||
logits = self.lm_head(hidden_states)
|
||||
logits = logits * self.logit_scale
|
||||
logits = logits.float()
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
|
||||
new_vocab_size = logits.shape[-1]
|
||||
shift_logits = shift_logits.view(-1, new_vocab_size)
|
||||
loss = cross_entropy_1d(
|
||||
shift_logits,
|
||||
shift_labels,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
vocab_size=self.lm_head.out_features,
|
||||
dtype=self.model.dtype,
|
||||
)
|
||||
else:
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
loss = dist_cross_entropy(
|
||||
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
|
@ -658,24 +637,14 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
|||
logits = self.lm_head(hidden_states)
|
||||
logits = logits * self.logit_scale
|
||||
logits = logits.float()
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
new_vocab_size = logits.shape[-1]
|
||||
shift_logits = shift_logits.view(-1, new_vocab_size)
|
||||
loss = cross_entropy_1d(
|
||||
shift_logits,
|
||||
shift_labels,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
vocab_size=self.lm_head.out_features,
|
||||
dtype=self.model.dtype,
|
||||
)
|
||||
loss = dist_cross_entropy(
|
||||
labels,
|
||||
logits,
|
||||
shard_config,
|
||||
self.lm_head.out_features,
|
||||
self.config.vocab_size,
|
||||
self.model.dtype,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
|
|
|
@ -0,0 +1,429 @@
|
|||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
# from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.utils import is_flash_attn_2_available, logging
|
||||
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
from colossalai.shardformer.shard.utils import set_tensors_to_none
|
||||
|
||||
|
||||
# copied from modeling_deepseek.py
|
||||
class AddAuxiliaryLoss(torch.autograd.Function):
|
||||
"""
|
||||
The trick function of adding auxiliary (aux) loss,
|
||||
which includes the gradient of the aux loss during backpropagation.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x, loss):
|
||||
assert loss.numel() == 1
|
||||
ctx.dtype = loss.dtype
|
||||
ctx.required_aux_loss = loss.requires_grad
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
grad_loss = None
|
||||
if ctx.required_aux_loss:
|
||||
grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
|
||||
return grad_output, grad_loss
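A hedged usage sketch (it assumes the `AddAuxiliaryLoss` class above; the tensors are toy values): the forward value passes through untouched, while backward feeds a unit gradient into the auxiliary loss.

```python
import torch

x = torch.randn(4, requires_grad=True)
aux_loss = torch.full((1,), 0.3, requires_grad=True)  # numel() == 1, as the forward asserts

out = AddAuxiliaryLoss.apply(x, aux_loss)
assert torch.equal(out, x)            # forward is an identity on x

out.sum().backward()
print(aux_loss.grad)                  # tensor([1.]): the aux loss joins backprop
print(x.grad)                         # same gradient x would receive without the hook
```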
|
||||
|
||||
|
||||
class EPDeepseekMoE(nn.Module):
|
||||
def __init__(self):
|
||||
super(EPDeepseekMoE, self).__init__()
|
||||
|
||||
def setup_ep(self, ep_group: ProcessGroup):
|
||||
ep_group = ep_group
|
||||
self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
|
||||
self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
|
||||
self.num_experts = self.config.n_routed_experts
|
||||
assert self.num_experts % self.ep_size == 0
|
||||
self.ep_group = ep_group
|
||||
self.num_experts_per_ep = self.num_experts // self.ep_size
|
||||
self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
|
||||
held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
|
||||
set_tensors_to_none(self.experts, exclude=set(held_experts))
|
||||
for p in self.experts.parameters():
|
||||
p.ep_group = ep_group
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(module: Union["DeepseekMoE", "DeepseekMLP"], *args, **kwargs) -> "EPDeepseekMoE":
|
||||
LazyInitContext.materialize(module)
|
||||
if module.__class__.__name__ == "DeepseekMLP":
|
||||
return module
|
||||
module.__class__ = EPDeepseekMoE
|
||||
assert "ep_group" in kwargs, "You should pass ep_group in SubModuleReplacementDescription via shard_config!!"
|
||||
module.setup_ep(kwargs["ep_group"])
|
||||
return module
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
identity = hidden_states
|
||||
orig_shape = hidden_states.shape
|
||||
|
||||
topk_experts_idx, topk_experts_weight, aux_loss = self.gate(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) # [t0, t1, t2 ...]
|
||||
hidden_states = hidden_states.repeat_interleave(
|
||||
self.num_experts_per_tok, dim=0
|
||||
) # after repeat_interleave: [t0 t0 t1 t1 t2 t2 ... ]
|
||||
|
||||
flat_topk_experts_idx = topk_experts_idx.view(-1) # [e0 e1 e2 ...]
|
||||
# The elements of flat_topk_token_idx are token ids, which are arranged in ascending order of expert ids.
|
||||
flat_topk_token_idx = flat_topk_experts_idx.argsort()
|
||||
|
||||
# Now we adjust the order of the hidden states, also in ascending order of expert id
|
||||
dispatch_states = hidden_states[flat_topk_token_idx]
|
||||
input_split_sizes = flat_topk_experts_idx.bincount(minlength=self.num_experts) # [n0, n1, n2, n3]
|
||||
output_split_sizes = torch.zeros_like(input_split_sizes)
|
||||
|
||||
# [n0, n1, n2, n3] [m0, m1, m2, m3] -> [n0, n1, m0, m1] [n2, n3, m2, m3]
|
||||
dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
|
||||
|
||||
input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
|
||||
output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
|
||||
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
|
||||
output_states = MoeInGradScaler.apply(output_states, self.ep_size)
|
||||
|
||||
if output_states.size(0) > 0:
|
||||
if self.num_experts_per_ep == 1:
|
||||
expert = self.experts[self.expert_start_idx]
|
||||
output_states = expert(output_states)
|
||||
else:
|
||||
output_states_splits = output_states.split(output_split_sizes.tolist())
|
||||
output_states_list = []
|
||||
for i, split_states in enumerate(output_states_splits):
|
||||
if split_states.size(0) == 0:  # no tokens routed to this expert
|
||||
continue
|
||||
expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
|
||||
split_states = expert(split_states)
|
||||
output_states_list.append(split_states)
|
||||
output_states = torch.cat(output_states_list)
|
||||
output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
|
||||
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
|
||||
recover_token_idx = torch.empty_like(flat_topk_token_idx)
|
||||
recover_token_idx[flat_topk_token_idx] = torch.arange(
|
||||
flat_topk_token_idx.size(0), device=flat_topk_token_idx.device
|
||||
)
|
||||
|
||||
output_hidden_states = dispatch_states[recover_token_idx] # t0 t0 t1 t1 t2 t2
|
||||
output_hidden_states = output_hidden_states.view(-1, self.num_experts_per_tok, orig_shape[-1])
|
||||
output_hidden_states = (output_hidden_states * topk_experts_weight[:, :, None]).sum(dim=-2) # (B*S, h)
|
||||
output_hidden_states = output_hidden_states.view(*orig_shape)
|
||||
output_hidden_states = AddAuxiliaryLoss.apply(output_hidden_states, aux_loss)
|
||||
if self.config.n_shared_experts is not None:
|
||||
output_hidden_states = output_hidden_states + self.shared_experts(identity)
|
||||
return output_hidden_states
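A minimal single-process sketch (illustrative names, no distributed group) of the dispatch bookkeeping above: token copies are repeated `num_experts_per_tok` times, sorted by expert id, counted per expert, and the permutation is inverted afterwards to restore token order.

```python
import torch

num_experts, top_k = 4, 2
tokens = torch.arange(6, dtype=torch.float32).view(3, 2)   # 3 tokens, hidden size 2
topk_idx = torch.tensor([[0, 3], [1, 0], [2, 2]])          # expert ids chosen per token

flat_idx = topk_idx.view(-1)                               # [e0 e1 e2 ...]
order = flat_idx.argsort()                                 # slots sorted by expert id
dispatch = tokens.repeat_interleave(top_k, dim=0)[order]   # token copies grouped by expert
per_expert = flat_idx.bincount(minlength=num_experts)      # all_to_all input split sizes
assert per_expert.tolist() == [2, 1, 2, 1]

recover = torch.empty_like(order)
recover[order] = torch.arange(order.numel())               # inverse permutation
assert torch.equal(dispatch[recover], tokens.repeat_interleave(top_k, dim=0))
```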
|
||||
|
||||
|
||||
class DeepseekPipelineForwards:
|
||||
"""
|
||||
This class serves as a micro library for forward function substitution of Deepseek models
|
||||
under pipeline setting.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def deepseek_model_forward(
|
||||
self: "DeepseekModel",
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
>>> model = AutoModelForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||
|
||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||
```"""
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if stage_manager.is_first_stage():
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
batch_size, seq_length = input_shape
|
||||
device = hidden_states.device
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||
if output_attentions:
|
||||
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
|
||||
output_hidden_states = False
|
||||
if use_cache:
|
||||
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
|
||||
use_cache = False
|
||||
|
||||
if past_key_values is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
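Toy illustration (arbitrary values) of how the position ids above continue from the cached prefix:

```python
import torch

past_key_values_length, seq_length = 3, 4
position_ids = torch.arange(
    past_key_values_length, past_key_values_length + seq_length
).unsqueeze(0)
assert position_ids.tolist() == [[3, 4, 5, 6]]
```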
|
||||
|
||||
# embed positions, for the first stage, hidden_states is the input embeddings,
|
||||
# for the other stages, hidden_states is the output of the previous stage
|
||||
if is_flash_attn_2_available():
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
else:
|
||||
# 4d mask is passed through the layers
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
hidden_states,
|
||||
past_key_values_length,
|
||||
sliding_window=self.config.sliding_window,
|
||||
)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = None
|
||||
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
None,
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache = (layer_outputs[2 if output_attentions else 1],)
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
# always return dict for intermediate stages
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def deepseek_for_causal_lm_forward(
|
||||
self: "DeepseekForCausalLM",
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, MixtralForCausalLM
|
||||
|
||||
>>> model = DeepseekForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||
|
||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||
```"""
|
||||
logger = logging.get_logger(__name__)
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||
if output_attentions:
|
||||
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
|
||||
output_hidden_states = False
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = DeepseekPipelineForwards.deepseek_model_forward(
|
||||
self.model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
)
|
||||
past_key_values = None
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
logits = logits.float()
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=None,
|
||||
hidden_states=outputs[0],
|
||||
attentions=None,
|
||||
)
|
||||
else:
|
||||
out = {}
|
||||
hidden_states = outputs.get("hidden_states")
|
||||
out["hidden_states"] = hidden_states
|
||||
return out
|
|
@ -25,7 +25,7 @@ from colossalai.shardformer.layer import ColoAttention
|
|||
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
|
||||
from ..layer import cross_entropy_1d
|
||||
from ..layer import dist_cross_entropy
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
@ -372,27 +372,9 @@ class GPT2PipelineForwards:
|
|||
|
||||
hidden_states = outputs[0]
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# move labels to correct device to enable model parallelism
|
||||
labels = labels.to(lm_logits.device)
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = lm_logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
|
||||
shift_labels = shift_labels.view(-1)
|
||||
if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
|
||||
loss = cross_entropy_1d(
|
||||
shift_logits,
|
||||
shift_labels,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
vocab_size=self.lm_head.out_features,
|
||||
dtype=self.transformer.dtype,
|
||||
)
|
||||
else:
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
loss = dist_cross_entropy(
|
||||
labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (lm_logits,) + outputs[1:]
|
||||
|
@ -1282,24 +1264,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
|||
hidden_states = transformer_outputs[0]
|
||||
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# move labels to correct device to enable model parallelism
|
||||
labels = labels.to(lm_logits.device)
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = lm_logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
|
||||
shift_labels = shift_labels.view(-1)
|
||||
loss = cross_entropy_1d(
|
||||
shift_logits,
|
||||
shift_labels,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
vocab_size=self.lm_head.out_features,
|
||||
dtype=self.transformer.dtype,
|
||||
)
|
||||
loss = dist_cross_entropy(
|
||||
labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (lm_logits,) + transformer_outputs[1:]
|
||||
|
|
|
@ -31,7 +31,7 @@ from colossalai.shardformer.layer._operation import (
|
|||
)
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
|
||||
from ..layer import ColoAttention, cross_entropy_1d
|
||||
from ..layer import ColoAttention, dist_cross_entropy
|
||||
|
||||
|
||||
class LlamaPipelineForwards:
|
||||
|
@ -86,13 +86,20 @@ class LlamaPipelineForwards:
|
|||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
batch_size, seq_length = input_shape
|
||||
device = hidden_states.device
|
||||
|
||||
# Support SP + PP
|
||||
sp_mode = shard_config.sequence_parallelism_mode
|
||||
sp_group = shard_config.sequence_parallel_process_group
|
||||
sp_size = shard_config.sequence_parallel_size
|
||||
if sp_mode == "all_to_all" and not stage_manager.is_first_stage():
|
||||
# For correct position ids. The states will be gathered along the seq dim in the attention layer later.
|
||||
seq_length *= sp_size
|
||||
|
||||
past_seen_tokens = 0
|
||||
if use_cache: # kept for BC (cache positions)
|
||||
if not isinstance(past_key_values, StaticCache):
|
||||
|
@ -101,7 +108,7 @@ class LlamaPipelineForwards:
|
|||
if cache_position is None:
|
||||
if isinstance(past_key_values, StaticCache):
|
||||
raise ValueError("cache_position is a required argument when using StaticCache.")
|
||||
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=device)
|
||||
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=device)
|
||||
|
||||
seq_length_with_past = seq_length + past_seen_tokens
|
||||
|
||||
|
@ -118,7 +125,6 @@ class LlamaPipelineForwards:
|
|||
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
# embed positions, for the first stage, hidden_states is the input embeddings,
|
||||
# for the other stages, hidden_states is the output of the previous stage
|
||||
if shard_config.enable_flash_attention:
|
||||
|
@ -134,6 +140,13 @@ class LlamaPipelineForwards:
|
|||
else:
|
||||
attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position)
|
||||
|
||||
# Support SP + PP
|
||||
if stage_manager.is_first_stage():
|
||||
if sp_mode in ["ring", "split_gather"]:
|
||||
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group)
|
||||
elif sp_mode == "all_to_all":
|
||||
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size)
|
||||
|
||||
if self.gradient_checkpointing and self.training and use_cache:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
|
@ -196,6 +209,10 @@ class LlamaPipelineForwards:
|
|||
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = self.norm(hidden_states)
|
||||
if sp_mode == "ring" or sp_mode == "split_gather":
|
||||
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
|
||||
elif sp_mode == "all_to_all":
|
||||
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
|
@ -304,29 +321,9 @@ class LlamaPipelineForwards:
|
|||
if stage_manager.is_last_stage():
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
|
||||
new_vocab_size = logits.shape[-1]
|
||||
shift_logits = shift_logits.view(-1, new_vocab_size)
|
||||
loss = cross_entropy_1d(
|
||||
shift_logits,
|
||||
shift_labels,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
vocab_size=self.lm_head.out_features,
|
||||
dtype=self.model.dtype,
|
||||
)
|
||||
else:
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
loss = dist_cross_entropy(
|
||||
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
|
@ -529,7 +526,6 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None,
|
|||
)
|
||||
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
|
@ -804,24 +800,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
|||
logits = self.lm_head(hidden_states)
|
||||
logits = logits.float()
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
new_vocab_size = logits.shape[-1]
|
||||
shift_logits = shift_logits.view(-1, new_vocab_size)
|
||||
loss = cross_entropy_1d(
|
||||
shift_logits,
|
||||
shift_labels,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
vocab_size=self.lm_head.out_features,
|
||||
dtype=self.model.dtype,
|
||||
)
|
||||
|
||||
loss = dist_cross_entropy(
|
||||
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
|
||||
)
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
|
|
@ -19,7 +19,7 @@ from transformers.utils import logging
|
|||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
|
||||
from ..layer import ColoAttention, cross_entropy_1d
|
||||
from ..layer import ColoAttention, dist_cross_entropy
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
@ -275,29 +275,9 @@ class MistralForwards:
|
|||
logits = self.lm_head(hidden_states)
|
||||
logits = logits.float()
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
|
||||
new_vocab_size = logits.shape[-1]
|
||||
shift_logits = shift_logits.view(-1, new_vocab_size)
|
||||
loss = cross_entropy_1d(
|
||||
shift_logits,
|
||||
shift_labels,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
vocab_size=self.lm_head.out_features,
|
||||
dtype=self.model.dtype,
|
||||
)
|
||||
else:
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
loss = dist_cross_entropy(
|
||||
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
|
@ -708,23 +688,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
|||
logits = self.lm_head(hidden_states)
|
||||
logits = logits.float()
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
new_vocab_size = logits.shape[-1]
|
||||
shift_logits = shift_logits.view(-1, new_vocab_size)
|
||||
loss = cross_entropy_1d(
|
||||
shift_logits,
|
||||
shift_labels,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
vocab_size=self.lm_head.out_features,
|
||||
dtype=self.model.dtype,
|
||||
)
|
||||
loss = dist_cross_entropy(
|
||||
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
|
|
|
@ -1,222 +1,108 @@
|
|||
from functools import partial
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.nn import CrossEntropyLoss, Module
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
# from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
||||
from transformers.models.mixtral.modeling_mixtral import (
|
||||
MixtralDecoderLayer,
|
||||
MixtralForCausalLM,
|
||||
MixtralModel,
|
||||
MixtralSparseMoeBlock,
|
||||
MoeCausalLMOutputWithPast,
|
||||
_prepare_4d_causal_attention_mask,
|
||||
load_balancing_loss_func,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
from transformers.utils import is_flash_attn_2_available, logging
|
||||
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
|
||||
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
|
||||
from .mixtral_layer import EPMixtralSparseMoeBlock
|
||||
|
||||
__all__ = ["MixtralPolicy", "MixtralForCausalLMPolicy"]
|
||||
from colossalai.shardformer.shard.utils import set_tensors_to_none
|
||||
|
||||
|
||||
class MixtralPolicy(Policy):
|
||||
def config_sanity_check(self):
|
||||
pass
|
||||
class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
|
||||
def __init__(self, config):
|
||||
self.moe_info = None
|
||||
super().__init__(config)
|
||||
|
||||
def preprocess(self):
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
# Resize embedding
|
||||
vocab_size = self.model.config.vocab_size
|
||||
world_size = self.shard_config.tensor_parallel_size
|
||||
def setup_ep(self, ep_group: ProcessGroup):
|
||||
ep_group = ep_group
|
||||
self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
|
||||
self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
|
||||
assert self.num_experts % self.ep_size == 0
|
||||
self.ep_group = ep_group
|
||||
self.num_experts_per_ep = self.num_experts // self.ep_size
|
||||
self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
|
||||
held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
|
||||
set_tensors_to_none(self.experts, exclude=set(held_experts))
|
||||
for p in self.experts.parameters():
|
||||
p.ep_group = ep_group
|
||||
|
||||
if vocab_size % world_size != 0:
|
||||
new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
||||
self.model.resize_token_embeddings(new_vocab_size)
|
||||
@staticmethod
|
||||
def from_native_module(module: MixtralSparseMoeBlock, *args, **kwargs) -> "EPMixtralSparseMoeBlock":
|
||||
LazyInitContext.materialize(module)
|
||||
module.__class__ = EPMixtralSparseMoeBlock
|
||||
# if "ep_group" in kwargs:
|
||||
assert "ep_group" in kwargs, "You should pass ep_group in SubModuleReplacementDescription via shard_config!!"
|
||||
module.setup_ep(kwargs["ep_group"])
|
||||
return module
|
||||
|
||||
return self.model
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
# router_logits: (batch * sequence_length, n_experts)
|
||||
router_logits = self.gate(hidden_states)
|
||||
|
||||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
policy = {}
|
||||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
|
||||
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
|
||||
# we cast back to the input dtype
|
||||
routing_weights = routing_weights.to(hidden_states.dtype)
|
||||
|
||||
if self.shard_config.enable_sequence_parallelism:
|
||||
self.shard_config.enable_sequence_parallelism = False
|
||||
raise NotImplementedError(
|
||||
"Mixtral dosen't support sequence parallelism now, will ignore the sequence parallelism flag."
|
||||
)
|
||||
selected_experts = selected_experts.t().reshape(-1)
|
||||
selected_experts_idx = selected_experts.argsort()
|
||||
dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx]
|
||||
input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
|
||||
output_split_sizes = torch.zeros_like(input_split_sizes)
|
||||
dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
|
||||
|
||||
# expert parallel
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="block_sparse_moe",
|
||||
target_module=EPMixtralSparseMoeBlock,
|
||||
)
|
||||
],
|
||||
policy=policy,
|
||||
target_key=MixtralDecoderLayer,
|
||||
)
|
||||
|
||||
# optimization configuration
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="input_layernorm",
|
||||
target_module=FusedRMSNorm,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="post_attention_layernorm",
|
||||
target_module=FusedRMSNorm,
|
||||
),
|
||||
],
|
||||
policy=policy,
|
||||
target_key=MixtralDecoderLayer,
|
||||
)
|
||||
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="norm",
|
||||
target_module=FusedRMSNorm,
|
||||
),
|
||||
policy=policy,
|
||||
target_key=MixtralModel,
|
||||
)
|
||||
|
||||
if self.shard_config.enable_flash_attention:
|
||||
raise NotImplementedError("Flash attention has already been replaced in mixtral.")
|
||||
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
return self.model
|
||||
|
||||
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
|
||||
"""If under pipeline parallel setting, replacing the original forward method of huggingface
|
||||
to customized forward method, and add this changing to policy."""
|
||||
if self.pipeline_stage_manager:
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if self.model.__class__.__name__ == "MixtralModel":
|
||||
module = self.model
|
||||
input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
|
||||
output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
|
||||
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
|
||||
# compute expert output
|
||||
output_states = MoeInGradScaler.apply(output_states, self.ep_size)
|
||||
if output_states.size(0) > 0:
|
||||
if self.num_experts_per_ep == 1:
|
||||
# no need to split
|
||||
expert = self.experts[self.expert_start_idx]
|
||||
output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states)
|
||||
output_states = expert.w2(output_states)
|
||||
else:
|
||||
module = self.model.model
|
||||
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
stage_index = stage_manager.get_stage_index(layers_per_stage)
|
||||
method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
|
||||
self.append_or_create_method_replacement(
|
||||
description=method_replacement, policy=policy, target_key=model_cls
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
assert self.pipeline_stage_manager is not None
|
||||
|
||||
if self.model.__class__.__name__ == "MixtralModel":
|
||||
module = self.model
|
||||
else:
|
||||
module = self.model.model
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
|
||||
held_layers = []
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.embed_tokens)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||
held_layers.extend(module.layers[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.norm)
|
||||
|
||||
return held_layers
|
||||
|
||||
|
||||
class MixtralModelPolicy(MixtralPolicy):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
if self.pipeline_stage_manager:
|
||||
# set None as default
|
||||
self.set_pipeline_forward(
|
||||
model_cls=MixtralModel,
|
||||
new_forward=MixtralPipelineForwards.mixtral_model_forward,
|
||||
policy=policy,
|
||||
)
|
||||
return policy
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
held_layers = super().get_held_layers()
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
"""No shared params in llama model"""
|
||||
return []
|
||||
|
||||
|
||||
class MixtralForCausalLMPolicy(MixtralPolicy):
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
# add a new item for causal lm
|
||||
new_item = {
|
||||
MixtralForCausalLM: ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(gather_output=True),
|
||||
)
|
||||
]
|
||||
)
|
||||
}
|
||||
policy.update(new_item)
|
||||
|
||||
if self.pipeline_stage_manager:
|
||||
# set None as default
|
||||
self.set_pipeline_forward(
|
||||
model_cls=MixtralForCausalLM,
|
||||
new_forward=MixtralPipelineForwards.mixtral_for_causal_lm_forward,
|
||||
policy=policy,
|
||||
)
|
||||
|
||||
return policy
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
held_layers = super().get_held_layers()
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.lm_head)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
llama_model = self.model.model
|
||||
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
|
||||
if (
|
||||
id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
|
||||
and self.pipeline_stage_manager.num_stages > 1
|
||||
):
|
||||
# tie weights
|
||||
return [
|
||||
{
|
||||
0: llama_model.embed_tokens.weight,
|
||||
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
|
||||
}
|
||||
]
|
||||
return []
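The tied-weight check above relies on object identity; a small illustration (hypothetical module sizes, not from this diff):

```python
import torch.nn as nn

emb = nn.Embedding(10, 4)
head = nn.Linear(4, 10, bias=False)
head.weight = emb.weight                      # weight tying
assert id(emb.weight) == id(head.weight)      # what the policy checks before sharing
```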
|
||||
output_states_splits = output_states.split(output_split_sizes.tolist())
|
||||
output_states_list = []
|
||||
for i, split_states in enumerate(output_states_splits):
|
||||
if split_states.size(0) == 0:
|
||||
continue
|
||||
expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
|
||||
split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states)
|
||||
split_states = expert.w2(split_states)
|
||||
output_states_list.append(split_states)
|
||||
output_states = torch.cat(output_states_list)
|
||||
output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
|
||||
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
|
||||
recover_experts_idx = torch.empty_like(selected_experts_idx)
|
||||
recover_experts_idx[selected_experts_idx] = torch.arange(
|
||||
selected_experts_idx.size(0), device=selected_experts_idx.device
|
||||
)
|
||||
dispatch_states = dispatch_states[recover_experts_idx]
|
||||
k_hidden_states = dispatch_states.chunk(self.top_k)
|
||||
output_states = k_hidden_states[0] * routing_weights[:, 0, None]
|
||||
for i in range(1, self.top_k):
|
||||
output_states += k_hidden_states[i] * routing_weights[:, i, None]
|
||||
output_states = output_states.reshape(batch_size, sequence_length, hidden_dim)
|
||||
return output_states, router_logits
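A minimal sketch (toy shapes, single process) of the top-k routing step used at the start of this forward: softmax over experts, keep the k largest weights, then renormalize them per token.

```python
import torch
import torch.nn.functional as F

router_logits = torch.randn(5, 8)                        # 5 tokens, 8 experts
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, k=2, dim=-1)
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
assert torch.allclose(routing_weights.sum(dim=-1), torch.ones(5))
```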
|
||||
|
||||
|
||||
class MixtralPipelineForwards:
|
||||
|
@ -332,7 +218,7 @@ class MixtralPipelineForwards:
|
|||
|
||||
# embed positions, for the first stage, hidden_states is the input embeddings,
|
||||
# for the other stages, hidden_states is the output of the previous stage
|
||||
if self._use_flash_attention_2:
|
||||
if is_flash_attn_2_available():
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
else:
|
|
@ -22,7 +22,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
|
|||
from colossalai.shardformer.layer import ColoAttention
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
|
||||
from ..layer import cross_entropy_1d
|
||||
from ..layer import dist_cross_entropy
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
@ -221,7 +221,7 @@ class OPTPipelineForwards:
|
|||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if decoder.gradient_checkpointing and decoder.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_outputs = self.decoder._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
causal_attention_mask,
|
||||
|
@ -330,30 +330,14 @@ class OPTPipelineForwards:
|
|||
)
|
||||
if stage_manager.is_last_stage():
|
||||
logits = self.lm_head(outputs[0]).contiguous()
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# move labels to correct device to enable model parallelism
|
||||
labels = labels.to(logits.device)
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
|
||||
if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
|
||||
new_vocab_size = logits.shape[-1]
|
||||
shift_logits = shift_logits.view(-1, new_vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
loss = cross_entropy_1d(
|
||||
shift_logits,
|
||||
shift_labels,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
vocab_size=self.lm_head.out_features,
|
||||
dtype=self.model.decoder.dtype,
|
||||
)
|
||||
else:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
|
||||
loss = dist_cross_entropy(
|
||||
labels,
|
||||
logits,
|
||||
shard_config,
|
||||
self.lm_head.out_features,
|
||||
self.config.vocab_size,
|
||||
self.model.decoder.dtype,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
|
@ -971,26 +955,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
|||
)
|
||||
|
||||
logits = self.lm_head(outputs[0]).contiguous()
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# move labels to correct device to enable model parallelism
|
||||
labels = labels.to(logits.device)
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
new_vocab_size = logits.shape[-1]
|
||||
shift_logits = shift_logits.view(-1, new_vocab_size)
|
||||
loss = cross_entropy_1d(
|
||||
shift_logits,
|
||||
shift_labels,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
vocab_size=self.lm_head.out_features,
|
||||
dtype=self.model.decoder.dtype,
|
||||
)
|
||||
loss = dist_cross_entropy(
|
||||
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.decoder.dtype
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
|
|
|
@ -32,7 +32,7 @@ from transformers.utils import logging
|
|||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
|
||||
from ..layer import ColoAttention, cross_entropy_1d
|
||||
from ..layer import ColoAttention, dist_cross_entropy
|
||||
|
||||
|
||||
class Qwen2PipelineForwards:
|
||||
|
@ -168,13 +168,27 @@ class Qwen2PipelineForwards:
|
|||
next_decoder_cache = None
|
||||
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
num_ckpt_layers = 0
|
||||
if self.gradient_checkpointing and self.training:
|
||||
num_ckpt_layers = end_idx - start_idx
|
||||
# TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer
|
||||
if shard_config.gradient_checkpoint_config is not None:
|
||||
num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
|
||||
stage=stage_manager.stage,
|
||||
num_stages=stage_manager.num_stages,
|
||||
num_layers=end_idx - start_idx,
|
||||
model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0),
|
||||
num_model_chunks=stage_manager.num_model_chunks,
|
||||
)
|
||||
assert num_ckpt_layers <= end_idx - start_idx
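Toy illustration (made-up stage bounds) of the per-layer decision that follows: only the first `num_ckpt_layers` layers held by this stage are recomputed under gradient checkpointing.

```python
start_idx, end_idx, num_ckpt_layers = 4, 8, 2
use_ckpt = [idx - start_idx < num_ckpt_layers for idx in range(start_idx, end_idx)]
assert use_ckpt == [True, True, False, False]
```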
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if idx - start_idx < num_ckpt_layers:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
|
@ -198,7 +212,6 @@ class Qwen2PipelineForwards:
|
|||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
|
@ -304,25 +317,9 @@ class Qwen2PipelineForwards:
|
|||
if stage_manager.is_last_stage():
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
if shard_config.enable_tensor_parallelism:
|
||||
new_vocab_size = logits.shape[-1]
|
||||
shift_logits = shift_logits.view(-1, new_vocab_size)
|
||||
loss = cross_entropy_1d(
|
||||
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
|
||||
)
|
||||
else:
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
loss = dist_cross_entropy(
|
||||
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, logits.dtype
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
|
@ -724,26 +721,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
|||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
logits = logits.float()
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
if shard_config.enable_tensor_parallelism:
|
||||
new_vocab_size = logits.shape[-1]
|
||||
shift_logits = shift_logits.view(-1, new_vocab_size)
|
||||
loss = cross_entropy_1d(
|
||||
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
|
||||
)
|
||||
else:
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
loss = dist_cross_entropy(
|
||||
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, logits.dtype
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
|
|