Merge branch 'hpcaitech:main' into feature/fp8_comm

pull/5885/head
Hanks 2024-07-04 20:34:41 +08:00 committed by GitHub
commit 6991819a97
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
63 changed files with 265 additions and 177 deletions

View File

@ -1,34 +1,34 @@
repos: repos:
- repo: https://github.com/PyCQA/autoflake - repo: https://github.com/PyCQA/autoflake
rev: v2.2.1 rev: v2.3.1
hooks: hooks:
- id: autoflake - id: autoflake
name: autoflake (python) name: autoflake (python)
args: ['--in-place', '--remove-unused-variables', '--remove-all-unused-imports', '--ignore-init-module-imports'] args: ['--in-place', '--remove-unused-variables', '--remove-all-unused-imports', '--ignore-init-module-imports']
- repo: https://github.com/pycqa/isort - repo: https://github.com/pycqa/isort
rev: 5.12.0 rev: 5.13.2
hooks: hooks:
- id: isort - id: isort
name: sort all imports (python) name: sort all imports (python)
- repo: https://github.com/psf/black-pre-commit-mirror - repo: https://github.com/psf/black-pre-commit-mirror
rev: 23.9.1 rev: 24.4.2
hooks: hooks:
- id: black - id: black
name: black formatter name: black formatter
args: ['--line-length=120', '--target-version=py37', '--target-version=py38', '--target-version=py39','--target-version=py310'] args: ['--line-length=120', '--target-version=py37', '--target-version=py38', '--target-version=py39','--target-version=py310']
- repo: https://github.com/pre-commit/mirrors-clang-format - repo: https://github.com/pre-commit/mirrors-clang-format
rev: v13.0.1 rev: v18.1.8
hooks: hooks:
- id: clang-format - id: clang-format
name: clang formatter name: clang formatter
types_or: [c++, c] types_or: [c++, c]
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0 rev: v4.6.0
hooks: hooks:
- id: check-yaml - id: check-yaml
- id: check-merge-conflict - id: check-merge-conflict

View File

@ -83,15 +83,19 @@ class DataCollatorForSupervisedDataset(object):
# `List[torch.Tensor]` # `List[torch.Tensor]`
batch_input_ids = [ batch_input_ids = [
torch.LongTensor(instance["input_ids"][: self.max_length]) (
if len(instance["input_ids"]) > self.max_length torch.LongTensor(instance["input_ids"][: self.max_length])
else torch.LongTensor(instance["input_ids"]) if len(instance["input_ids"]) > self.max_length
else torch.LongTensor(instance["input_ids"])
)
for instance in instances for instance in instances
] ]
batch_labels = [ batch_labels = [
torch.LongTensor(instance["labels"][: self.max_length]) (
if len(instance["labels"]) > self.max_length torch.LongTensor(instance["labels"][: self.max_length])
else torch.LongTensor(instance["labels"]) if len(instance["labels"]) > self.max_length
else torch.LongTensor(instance["labels"])
)
for instance in instances for instance in instances
] ]
if self.tokenizer.padding_side == "right": if self.tokenizer.padding_side == "right":

View File

@ -1,6 +1,7 @@
""" """
loss functions loss functions
""" """
from typing import Optional, Tuple from typing import Optional, Tuple
import torch import torch

View File

@ -1,6 +1,7 @@
""" """
reward model reward model
""" """
from typing import Optional from typing import Optional
import torch import torch

View File

@ -1,6 +1,7 @@
""" """
Training utilities for Coati. Training utilities for Coati.
""" """
from typing import Any from typing import Any
import torch import torch

View File

@ -78,7 +78,9 @@ def get_prompt(line: Dict, dataset_name: str, logger: DistributedLogger) -> Dict
option_string = "ABCDEFG" option_string = "ABCDEFG"
count = len(line["options"]) count = len(line["options"])
input = "问题:" + line["question"] + " " + "从以下选项中选择:" + " ".join(line["options"]) + "\n" + "答案:" input = (
"问题:" + line["question"] + " " + "从以下选项中选择:" + " ".join(line["options"]) + "\n" + "答案:"
)
all_classes = list(option_string[0:count]) all_classes = list(option_string[0:count])
@ -150,7 +152,15 @@ def combine_prompt(prompt_path, dataset_name, load_explanation=True, chat_mode=F
) )
elif dataset_name in chinese_qa_datasets: elif dataset_name in chinese_qa_datasets:
question_input = ( question_input = (
"问题:" + passage + " " + question + "\n" + "从以下选项中选择:" + " ".join(options) + "\n" + "答案:{}".format(label) "问题:"
+ passage
+ " "
+ question
+ "\n"
+ "从以下选项中选择:"
+ " ".join(options)
+ "\n"
+ "答案:{}".format(label)
) )
elif dataset_name in english_cloze_datasets: elif dataset_name in english_cloze_datasets:
question_input = "Question: ".format(idx + 1) + question + "\n" + "Answer: {}".format(answer) question_input = "Question: ".format(idx + 1) + question + "\n" + "Answer: {}".format(answer)

View File

@ -57,7 +57,11 @@ ceval_subject_mapping = {
"urban_and_rural_planner": ["Urban and Rural Planner", "注册城乡规划师", "Other"], "urban_and_rural_planner": ["Urban and Rural Planner", "注册城乡规划师", "Other"],
"accountant": ["Accountant", "注册会计师", "Other"], "accountant": ["Accountant", "注册会计师", "Other"],
"fire_engineer": ["Fire Engineer", "注册消防工程师", "Other"], "fire_engineer": ["Fire Engineer", "注册消防工程师", "Other"],
"environmental_impact_assessment_engineer": ["Environmental Impact Assessment Engineer", "环境影响评价工程师", "Other"], "environmental_impact_assessment_engineer": [
"Environmental Impact Assessment Engineer",
"环境影响评价工程师",
"Other",
],
"tax_accountant": ["Tax Accountant", "税务师", "Other"], "tax_accountant": ["Tax Accountant", "税务师", "Other"],
"physician": ["Physician", "医师资格", "Other"], "physician": ["Physician", "医师资格", "Other"],
} }

View File

@ -56,9 +56,11 @@ class MTBenchDataset(BaseDataset):
"instruction": question["turns"], "instruction": question["turns"],
"input": "", "input": "",
"output": [], "output": [],
"target": [""] * turn_number "target": (
if question["question_id"] not in reference [""] * turn_number
else reference[question["question_id"]], if question["question_id"] not in reference
else reference[question["question_id"]]
),
} }
if category in dataset["test"]: if category in dataset["test"]:

View File

@ -77,7 +77,9 @@ class HuggingFaceModel(BaseModel):
self.indices_for_choices[0].append( self.indices_for_choices[0].append(
self.tokenizer(f"Answer: {choice}", add_special_tokens=False).input_ids[-1] self.tokenizer(f"Answer: {choice}", add_special_tokens=False).input_ids[-1]
) )
self.indices_for_choices[1].append(self.tokenizer(f"答案:{choice}", add_special_tokens=False).input_ids[-1]) self.indices_for_choices[1].append(
self.tokenizer(f"答案:{choice}", add_special_tokens=False).input_ids[-1]
)
def _load_tokenizer(self, path: str, tokenizer_path: Optional[str], tokenizer_kwargs: dict): def _load_tokenizer(self, path: str, tokenizer_path: Optional[str], tokenizer_kwargs: dict):
""" """

View File

@ -7,6 +7,7 @@ This code is based on LangChain Ai's langchain, which can be found at
https://github.com/langchain-ai/langchain https://github.com/langchain-ai/langchain
The original code is licensed under the MIT license. The original code is licensed under the MIT license.
""" """
from __future__ import annotations from __future__ import annotations
import copy import copy

View File

@ -8,6 +8,7 @@ This code is based on LangChain Ai's langchain, which can be found at
https://github.com/langchain-ai/langchain https://github.com/langchain-ai/langchain
The original code is licensed under the MIT license. The original code is licensed under the MIT license.
""" """
import copy import copy
from typing import Any, Mapping, Optional, Protocol from typing import Any, Mapping, Optional, Protocol

View File

@ -7,6 +7,7 @@ This code is based on LangChain Ai's langchain, which can be found at
https://github.com/langchain-ai/langchain https://github.com/langchain-ai/langchain
The original code is licensed under the MIT license. The original code is licensed under the MIT license.
""" """
import copy import copy
from typing import Any, List from typing import Any, List

View File

@ -2,7 +2,6 @@
Class for loading table type data. please refer to Pandas-Input/Output for file format details. Class for loading table type data. please refer to Pandas-Input/Output for file format details.
""" """
import glob import glob
import os import os

View File

@ -12,6 +12,7 @@ TEST_PROMPT_CHATGLM="续写文章:惊蛰一过,春寒加剧。先是料料
logger.info(llm(TEST_PROMPT_CHATGLM, max_new_tokens=100), verbose=True) logger.info(llm(TEST_PROMPT_CHATGLM, max_new_tokens=100), verbose=True)
""" """
from typing import Any, List, Mapping, Optional from typing import Any, List, Mapping, Optional
import torch import torch

View File

@ -1,6 +1,7 @@
""" """
Generation utilities Generation utilities
""" """
import json import json
from typing import List from typing import List

View File

@ -2,6 +2,7 @@
Implement a memory class for storing conversation history Implement a memory class for storing conversation history
Support long term and short term memory Support long term and short term memory
""" """
from typing import Any, Dict, List from typing import Any, Dict, List
from colossalqa.chain.memory.summary import ConversationSummaryMemory from colossalqa.chain.memory.summary import ConversationSummaryMemory

View File

@ -1,6 +1,7 @@
""" """
Class for logging with extra control for debugging Class for logging with extra control for debugging
""" """
import logging import logging

View File

@ -1,6 +1,7 @@
""" """
Script for Chinese retrieval based conversation system backed by ChatGLM Script for Chinese retrieval based conversation system backed by ChatGLM
""" """
from typing import Tuple from typing import Tuple
from colossalqa.chain.retrieval_qa.base import RetrievalQA from colossalqa.chain.retrieval_qa.base import RetrievalQA

View File

@ -1,6 +1,7 @@
""" """
Multilingual retrieval based conversation system Multilingual retrieval based conversation system
""" """
from typing import List from typing import List
from colossalqa.data_loader.document_loader import DocumentLoader from colossalqa.data_loader.document_loader import DocumentLoader

View File

@ -1,6 +1,7 @@
""" """
Script for Chinese retrieval based conversation system backed by ChatGLM Script for Chinese retrieval based conversation system backed by ChatGLM
""" """
from typing import Tuple from typing import Tuple
from colossalqa.chain.retrieval_qa.base import RetrievalQA from colossalqa.chain.retrieval_qa.base import RetrievalQA

View File

@ -1,6 +1,7 @@
""" """
Code for custom retriver with incremental update Code for custom retriver with incremental update
""" """
import copy import copy
import hashlib import hashlib
import os import os

View File

@ -1,6 +1,7 @@
""" """
Code for Chinese text splitter Code for Chinese text splitter
""" """
from typing import Any, List, Optional from typing import Any, List, Optional
from colossalqa.text_splitter.utils import get_cleaned_paragraph from colossalqa.text_splitter.utils import get_cleaned_paragraph

View File

@ -1,6 +1,7 @@
""" """
Script for English retrieval based conversation system backed by LLaMa2 Script for English retrieval based conversation system backed by LLaMa2
""" """
import argparse import argparse
import os import os

View File

@ -1,6 +1,7 @@
""" """
Script for English retrieval based conversation system backed by LLaMa2 Script for English retrieval based conversation system backed by LLaMa2
""" """
import argparse import argparse
import json import json
import os import os

View File

@ -1,6 +1,7 @@
""" """
Script for Chinese retrieval based conversation system backed by ChatGLM Script for Chinese retrieval based conversation system backed by ChatGLM
""" """
import argparse import argparse
import os import os

View File

@ -1,6 +1,7 @@
""" """
Script for English retrieval based conversation system backed by LLaMa2 Script for English retrieval based conversation system backed by LLaMa2
""" """
import argparse import argparse
import os import os

View File

@ -107,20 +107,22 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_memory_cost = MemoryCost( fwd_memory_cost = MemoryCost(
activation=compute_size_in_bytes([input_tensor, output_tensor]), activation=compute_size_in_bytes([input_tensor, output_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]) parameter=(
if has_bias compute_size_in_bytes([weight_tensor, bias_tensor]) if has_bias else compute_size_in_bytes(weight_tensor)
else compute_size_in_bytes(weight_tensor), ),
temp=0, temp=0,
buffer=0, buffer=0,
) )
bwd_memory_cost = MemoryCost( bwd_memory_cost = MemoryCost(
activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]) activation=(
if has_bias compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor])
else compute_size_in_bytes([input_tensor, weight_tensor]), if has_bias
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]) else compute_size_in_bytes([input_tensor, weight_tensor])
if has_bias ),
else compute_size_in_bytes(weight_tensor), parameter=(
compute_size_in_bytes([weight_tensor, bias_tensor]) if has_bias else compute_size_in_bytes(weight_tensor)
),
temp=0, temp=0,
buffer=0, buffer=0,
) )

View File

@ -1,10 +1,18 @@
from .gemini_plugin import GeminiPlugin from .gemini_plugin import GeminiPlugin
from .hybrid_parallel_plugin import HybridParallelPlugin from .hybrid_parallel_plugin import HybridParallelPlugin
from .low_level_zero_plugin import LowLevelZeroPlugin from .low_level_zero_plugin import LowLevelZeroPlugin
from .moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from .plugin_base import Plugin from .plugin_base import Plugin
from .torch_ddp_plugin import TorchDDPPlugin from .torch_ddp_plugin import TorchDDPPlugin
__all__ = ["Plugin", "TorchDDPPlugin", "GeminiPlugin", "LowLevelZeroPlugin", "HybridParallelPlugin"] __all__ = [
"Plugin",
"TorchDDPPlugin",
"GeminiPlugin",
"LowLevelZeroPlugin",
"HybridParallelPlugin",
"MoeHybridParallelPlugin",
]
import torch import torch
from packaging import version from packaging import version

View File

@ -247,16 +247,16 @@ class BatchBucket:
self._sequences_dict[seq.request_id] = seq self._sequences_dict[seq.request_id] = seq
self._sequences_indexes[seq.request_id] = self._current_batch_size + i self._sequences_indexes[seq.request_id] = self._current_batch_size + i
# TODO external (rename): modify Sequence.sentence_len to seq_len # TODO external (rename): modify Sequence.sentence_len to seq_len
self._sequence_lengths[ self._sequence_lengths[self._current_batch_size : self._current_batch_size + num_seqs_to_add] = (
self._current_batch_size : self._current_batch_size + num_seqs_to_add torch.tensor([seq.sentence_len for seq in seqs[:num_seqs_to_add]], dtype=torch.int32)
] = torch.tensor([seq.sentence_len for seq in seqs[:num_seqs_to_add]], dtype=torch.int32) )
# NOTE block tables to be updated by kvcache manager # NOTE block tables to be updated by kvcache manager
block_tables = self._block_tables[self._current_batch_size : self._current_batch_size + num_seqs_to_add] block_tables = self._block_tables[self._current_batch_size : self._current_batch_size + num_seqs_to_add]
if alloc_block_tables is not None: if alloc_block_tables is not None:
# copy block ids from provided block tables # copy block ids from provided block tables
self._block_tables[ self._block_tables[self._current_batch_size : self._current_batch_size + num_seqs_to_add] = (
self._current_batch_size : self._current_batch_size + num_seqs_to_add alloc_block_tables
] = alloc_block_tables )
elif alloc_block_tables_fn: elif alloc_block_tables_fn:
alloc_block_tables_fn( alloc_block_tables_fn(
block_tables, block_tables,

View File

@ -1,6 +1,7 @@
""" """
Our config contains various options for inference optimization, it is a unified API that wraps all the configurations for inference. Our config contains various options for inference optimization, it is a unified API that wraps all the configurations for inference.
""" """
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
@ -82,9 +83,9 @@ class InputMetaData(RPC_PARAM):
dtype: torch.dtype = torch.float32 dtype: torch.dtype = torch.float32
use_spec_dec: bool = False use_spec_dec: bool = False
num_tokens_to_verify: int = 0 num_tokens_to_verify: int = 0
batch_token_ids: Optional[ batch_token_ids: Optional[List[List[int]]] = (
List[List[int]] None # for `repetition_penalty`, `no_repeat_ngram_size` in sampler process
] = None # for `repetition_penalty`, `no_repeat_ngram_size` in sampler process )
def to_rpc_param(self) -> Dict[str, any]: def to_rpc_param(self) -> Dict[str, any]:
return { return {
@ -202,9 +203,9 @@ class InferenceConfig(RPC_PARAM):
prompt_template: Optional[str] = None prompt_template: Optional[str] = None
do_sample: bool = False do_sample: bool = False
beam_width: int = 1 # TODO: beam search is not support for now beam_width: int = 1 # TODO: beam search is not support for now
prefill_ratio: Optional[ prefill_ratio: Optional[float] = (
float 1.2 # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
] = 1.2 # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio )
pad_input: bool = False pad_input: bool = False
early_stopping: Optional[bool] = False early_stopping: Optional[bool] = False
top_k: Optional[int] = 50 top_k: Optional[int] = 50
@ -234,7 +235,9 @@ class InferenceConfig(RPC_PARAM):
high_precision: Optional[bool] = False high_precision: Optional[bool] = False
# cuda_graph # cuda_graph
use_cuda_graph: bool = False # NOTE only when we have the graph for specific decoding batch size can we use the cuda graph for inference use_cuda_graph: bool = (
False # NOTE only when we have the graph for specific decoding batch size can we use the cuda graph for inference
)
max_context_len_to_capture: int = 512 max_context_len_to_capture: int = 512
# StreamingLLM (sliding window attention with attention sinks) # StreamingLLM (sliding window attention with attention sinks)

View File

@ -47,7 +47,6 @@ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
class InferenceEngine: class InferenceEngine:
""" """
InferenceEngine which manages the inference process.. InferenceEngine which manages the inference process..

View File

@ -34,7 +34,6 @@ def run_server(host, port, event: mp.Event = None):
class RPCInferenceEngine(InferenceEngine): class RPCInferenceEngine(InferenceEngine):
""" """
InferenceEngine which manages the inference process.. InferenceEngine which manages the inference process..

View File

@ -42,7 +42,6 @@ logger = get_dist_logger(__name__)
class rpcWorkerService(rpyc.Service): class rpcWorkerService(rpyc.Service):
""" """
Execute the computation tasks and manage its own kv cache Execute the computation tasks and manage its own kv cache

View File

@ -279,9 +279,11 @@ class KVCacheManager:
block.add_ref() block.add_ref()
self._allocate_on_block( self._allocate_on_block(
block, block,
block.block_size (
if context_lengths[i] % block.block_size == 0 block.block_size
else context_lengths[i].item() % block.block_size, if context_lengths[i] % block.block_size == 0
else context_lengths[i].item() % block.block_size
),
) )
for block_id in alloc_block_ids: for block_id in alloc_block_ids:
if block_id in alloc_block_ids[last_block_locs]: if block_id in alloc_block_ids[last_block_locs]:

View File

@ -1,6 +1,7 @@
""" """
Utils for model inference Utils for model inference
""" """
import math import math
import os import os
import re import re

View File

@ -138,9 +138,7 @@ class Initializer_2D(ProcessGroupInitializer):
self.num_group = self.world_size // self.tensor_parallel_size self.num_group = self.world_size // self.tensor_parallel_size
self.summa_dim = int(math.sqrt(self.tensor_parallel_size)) self.summa_dim = int(math.sqrt(self.tensor_parallel_size))
assert ( assert self.tensor_parallel_size == self.summa_dim**2, "2D summa dim should equal to tensor parallel size ^ 0.5"
self.tensor_parallel_size == self.summa_dim**2
), "2D summa dim should equal to tensor parallel size ^ 0.5"
_check_summa_env_var(self.summa_dim) _check_summa_env_var(self.summa_dim)
self.col_initializer = Initializer_2D_Col(self.num_group, self.summa_dim, *args, **kwargs) self.col_initializer = Initializer_2D_Col(self.num_group, self.summa_dim, *args, **kwargs)

View File

@ -54,7 +54,6 @@ class RequestTracker:
class Async_Engine: class Async_Engine:
""" """
Use an engine to launch RAY Driver --> RAY Worker --> Async_Manager Use an engine to launch RAY Driver --> RAY Worker --> Async_Manager
Background loop: inference reqs in waiting list (Listen) Background loop: inference reqs in waiting list (Listen)

View File

@ -118,16 +118,16 @@ class Batch:
class BatchTokenIdOut: class BatchTokenIdOut:
def __init__(self): def __init__(self):
self.reqs_infs: List[ self.reqs_infs: List[Tuple[str, int, Dict, bool, bool]] = (
Tuple[str, int, Dict, bool, bool] []
] = [] # [req_id, new_token_id, gen_metadata, finished_state, abort_state] ) # [req_id, new_token_id, gen_metadata, finished_state, abort_state]
class BatchStrOut: class BatchStrOut:
def __init__(self): def __init__(self):
self.reqs_infs: List[ self.reqs_infs: List[Tuple[str, str, Dict, bool, bool]] = (
Tuple[str, str, Dict, bool, bool] []
] = [] # [req_id, token_str, gen_metadata, finished_state, abort_state] ) # [req_id, token_str, gen_metadata, finished_state, abort_state]
class AbortReq: class AbortReq:

View File

@ -1,6 +1,7 @@
""" """
Utils for model inference Utils for model inference
""" """
import os import os
import torch import torch

View File

@ -14,6 +14,7 @@ class BatchInferState:
Information to be passed and used for a batch of inputs during Information to be passed and used for a batch of inputs during
a single model forward a single model forward
""" """
batch_size: int batch_size: int
max_len_in_batch: int max_len_in_batch: int

View File

@ -4,6 +4,7 @@ of the ModelTC/lightllm GitHub repository
https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py
we slightly changed it to make it suitable for our colossal-ai shardformer TP-engine design. we slightly changed it to make it suitable for our colossal-ai shardformer TP-engine design.
""" """
import torch import torch
from transformers.utils import logging from transformers.utils import logging

View File

@ -1,6 +1,7 @@
""" """
Utils for model inference Utils for model inference
""" """
import os import os
import torch import torch

View File

@ -1,17 +1,25 @@
# adapted from Hugging Face accelerate/utils/bnb.py accelerate/utils/modeling.py # adapted from Hugging Face accelerate/utils/bnb.py accelerate/utils/modeling.py
import importlib.metadata
import logging import logging
import torch import torch
import torch.nn as nn import torch.nn as nn
from packaging.version import Version
from .bnb_config import BnbQuantizationConfig from .bnb_config import BnbQuantizationConfig
try: try:
import bitsandbytes as bnb import bitsandbytes as bnb
IS_4BIT_BNB_AVAILABLE = bnb.__version__ >= "0.39.0" try:
IS_8BIT_BNB_AVAILABLE = bnb.__version__ >= "0.37.2" # in case lower version of bitsandbytes does not have __version__ attribute
BNB_VERSION = Version(bnb.__version__)
except AttributeError:
BNB_VERSION = Version(importlib.metadata.version("bitsandbytes"))
IS_4BIT_BNB_AVAILABLE = BNB_VERSION >= Version("0.39.0")
IS_8BIT_BNB_AVAILABLE = BNB_VERSION >= Version("0.37.2")
except ImportError: except ImportError:
pass pass

View File

@ -33,6 +33,7 @@ This license shall be governed and construed in accordance with the laws of Peop
Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at glm-130b@googlegroups.com. Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at glm-130b@googlegroups.com.
""" """
""" PyTorch ChatGLM model. """ """ PyTorch ChatGLM model. """
import copy import copy

View File

@ -221,7 +221,7 @@ class OPTPipelineForwards:
past_key_value = past_key_values[idx] if past_key_values is not None else None past_key_value = past_key_values[idx] if past_key_values is not None else None
if decoder.gradient_checkpointing and decoder.training: if decoder.gradient_checkpointing and decoder.training:
layer_outputs = self._gradient_checkpointing_func( layer_outputs = self.decoder._gradient_checkpointing_func(
decoder_layer.__call__, decoder_layer.__call__,
hidden_states, hidden_states,
causal_attention_mask, causal_attention_mask,

View File

@ -168,13 +168,27 @@ class Qwen2PipelineForwards:
next_decoder_cache = None next_decoder_cache = None
start_idx, end_idx = stage_index[0], stage_index[1] start_idx, end_idx = stage_index[0], stage_index[1]
num_ckpt_layers = 0
if self.gradient_checkpointing and self.training:
num_ckpt_layers = end_idx - start_idx
# TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer
if shard_config.gradient_checkpoint_config is not None:
num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
stage=stage_manager.stage,
num_stages=stage_manager.num_stages,
num_layers=end_idx - start_idx,
model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0),
num_model_chunks=stage_manager.num_model_chunks,
)
assert num_ckpt_layers <= end_idx - start_idx
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
if output_hidden_states: if output_hidden_states:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training: if idx - start_idx < num_ckpt_layers:
layer_outputs = self._gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__, decoder_layer.__call__,
hidden_states, hidden_states,
@ -198,7 +212,6 @@ class Qwen2PipelineForwards:
if use_cache: if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions: if output_attentions:
all_self_attns += (layer_outputs[1],) all_self_attns += (layer_outputs[1],)

View File

@ -40,21 +40,19 @@ class MixtralPolicy(Policy):
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.") raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
if getattr(self.shard_config, "ep_group", None) is None: if getattr(self.shard_config, "ep_group", None) is not None:
raise ValueError("You must pass in ep_group via shard_config for expert parallel!") # expert parallel
self.append_or_create_submodule_replacement(
# expert parallel description=[
self.append_or_create_submodule_replacement( SubModuleReplacementDescription(
description=[ suffix="block_sparse_moe",
SubModuleReplacementDescription( target_module=EPMixtralSparseMoeBlock,
suffix="block_sparse_moe", kwargs={"ep_group": self.shard_config.ep_group},
target_module=EPMixtralSparseMoeBlock, )
kwargs={"ep_group": self.shard_config.ep_group}, ],
) policy=policy,
], target_key=MixtralDecoderLayer,
policy=policy, )
target_key=MixtralDecoderLayer,
)
# optimization configuration # optimization configuration
if self.shard_config.enable_fused_normalization: if self.shard_config.enable_fused_normalization:

View File

@ -549,6 +549,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
working_param = real_working_params[group_id][idx] working_param = real_working_params[group_id][idx]
param_to_gather = master_param.to(device).to(self._dtype) param_to_gather = master_param.to(device).to(self._dtype)
pg = self.param_to_pg[working_param] pg = self.param_to_pg[working_param]
if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size:
buffer_tensor = torch.empty_like(
torch.cat([param_to_gather for _ in range(dist.get_world_size(pg))])
)
dist.all_gather_into_tensor(buffer_tensor, param_to_gather, pg)
working_param.data.copy_(buffer_tensor[: working_param.numel()].reshape_as(working_param))
continue
try: try:
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param) self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
except RuntimeError: except RuntimeError:

View File

@ -87,44 +87,42 @@ optim = DistGaloreAwamW(
## Plugin compatibility ## Plugin compatibility
<table> <table>
<tr> <tr>
<th nowrap="nowrap">Model/Feature</th> <th nowrap="nowrap">Optimizer/Plugin</th>
<th nowrap="nowrap" align="center" title="Lamb">Lamb</th> <th nowrap="nowrap" align="center">Hybrid Parallel Plugin</th>
<th nowrap="nowrap" align="center" title="GaLore">GaLore</th> <th nowrap="nowrap" align="center">Low Level Zero Plugin</th>
<th nowrap="nowrap" align="center" title="Adafactor">Adafactor</th> <th nowrap="nowrap" align="center">Torch DDP Plugin</th>
<th nowrap="nowrap" align="center" title="CAME">CAME</th> <th nowrap="nowrap" align="center">Gemini Plugin</th>
<th nowrap="nowrap" align="center">Moe Hybrid Plugin</th>
</tr> </tr>
<tr> <tr>
<td nowrap="nowrap">Hybrid Parallel<br />Plugin</td> <td nowrap="nowrap" align="center" title="Lamb">Lamb</td>
<td nowrap="nowrap" align="center">✔️</td> <td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td> <td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td> <td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
</tr>
<tr>
<td nowrap="nowrap">Low Level Zero<br />Plugin</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
</tr>
<tr>
<td nowrap="nowrap">Torch DDP<br />Plugin</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
</tr>
<tr>
<td nowrap="nowrap">Gemini<br />Plugin</td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td> <td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td> <td nowrap="nowrap" align="center"></td>
</tr> </tr>
<tr> <tr>
<td nowrap="nowrap">Moe Hybrid<br />Plugin</td> <td nowrap="nowrap" align="center" title="GaLore">GaLore</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center"></td> <td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td> <td nowrap="nowrap" align="center"></td>
</tr>
<tr>
<td nowrap="nowrap" align="center" title="Adafactor">Adafactor</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td>
</tr>
<tr>
<td nowrap="nowrap" align="center" title="CAME">CAME</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center"></td> <td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td> <td nowrap="nowrap" align="center"></td>
</tr> </tr>

View File

@ -55,7 +55,7 @@ Model/Feature Compatibility Matrix:
<td nowrap="nowrap" align="center">✔️</td> <td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td> <td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td> <td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center"></td> <td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center"></td> <td nowrap="nowrap" align="center"></td>
</tr> </tr>
<tr> <tr>

View File

@ -84,44 +84,42 @@ optim = DistGaloreAwamW(
## 兼容性 ## 兼容性
<table> <table>
<tr> <tr>
<th nowrap="nowrap">Model/Feature</th> <th nowrap="nowrap">Optimizer/Plugin</th>
<th nowrap="nowrap" align="center" title="Lamb">Lamb</th> <th nowrap="nowrap" align="center">Hybrid Parallel Plugin</th>
<th nowrap="nowrap" align="center" title="GaLore">GaLore</th> <th nowrap="nowrap" align="center">Low Level Zero Plugin</th>
<th nowrap="nowrap" align="center" title="Adafactor">Adafactor</th> <th nowrap="nowrap" align="center">Torch DDP Plugin</th>
<th nowrap="nowrap" align="center" title="CAME">CAME</th> <th nowrap="nowrap" align="center">Gemini Plugin</th>
<th nowrap="nowrap" align="center">Moe Hybrid Plugin</th>
</tr> </tr>
<tr> <tr>
<td nowrap="nowrap">Hybrid Parallel<br />Plugin</td> <td nowrap="nowrap" align="center" title="Lamb">Lamb</td>
<td nowrap="nowrap" align="center">✔️</td> <td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td> <td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td> <td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
</tr>
<tr>
<td nowrap="nowrap">Low Level Zero<br />Plugin</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
</tr>
<tr>
<td nowrap="nowrap">Torch DDP<br />Plugin</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
</tr>
<tr>
<td nowrap="nowrap">Gemini<br />Plugin</td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td> <td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td> <td nowrap="nowrap" align="center"></td>
</tr> </tr>
<tr> <tr>
<td nowrap="nowrap">Moe Hybrid<br />Plugin</td> <td nowrap="nowrap" align="center" title="GaLore">GaLore</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center"></td> <td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td> <td nowrap="nowrap" align="center"></td>
</tr>
<tr>
<td nowrap="nowrap" align="center" title="Adafactor">Adafactor</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td>
</tr>
<tr>
<td nowrap="nowrap" align="center" title="CAME">CAME</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center"></td> <td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td> <td nowrap="nowrap" align="center"></td>
</tr> </tr>
@ -130,6 +128,7 @@ optim = DistGaloreAwamW(
</tr> </tr>
</table> </table>
<!-- doc-test-command: colossalai run --nproc_per_node 4 distributed_optimizers.py --> <!-- doc-test-command: colossalai run --nproc_per_node 4 distributed_optimizers.py -->
## API 参考 ## API 参考

View File

@ -51,7 +51,7 @@ Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github.
<td nowrap="nowrap" align="center">✔️</td> <td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td> <td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td> <td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center"></td> <td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center"></td> <td nowrap="nowrap" align="center"></td>
</tr> </tr>
<tr> <tr>

View File

@ -52,9 +52,11 @@ class pretraining_dataset(Dataset):
def __getitem__(self, index): def __getitem__(self, index):
[input_ids, input_mask, segment_ids, masked_lm_labels] = [ [input_ids, input_mask, segment_ids, masked_lm_labels] = [
torch.from_numpy(input[index].astype(np.int64)) (
if indice < 5 torch.from_numpy(input[index].astype(np.int64))
else torch.from_numpy(np.asarray(input[index].astype(np.int64))) if indice < 5
else torch.from_numpy(np.asarray(input[index].astype(np.int64)))
)
for indice, input in enumerate(self.inputs) for indice, input in enumerate(self.inputs)
] ]

View File

@ -229,9 +229,7 @@ class DDPM(pl.LightningModule):
) )
if self.parameterization == "eps": if self.parameterization == "eps":
lvlb_weights = self.betas**2 / ( lvlb_weights = self.betas**2 / (2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)
)
elif self.parameterization == "x0": elif self.parameterization == "x0":
lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2.0 * 1 - torch.Tensor(alphas_cumprod)) lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2.0 * 1 - torch.Tensor(alphas_cumprod))
elif self.parameterization == "v": elif self.parameterization == "v":
@ -1186,9 +1184,11 @@ class LatentDiffusion(DDPM):
if cond is not None: if cond is not None:
if isinstance(cond, dict): if isinstance(cond, dict):
cond = { cond = {
key: cond[key][:batch_size] key: (
if not isinstance(cond[key], list) cond[key][:batch_size]
else list(map(lambda x: x[:batch_size], cond[key])) if not isinstance(cond[key], list)
else list(map(lambda x: x[:batch_size], cond[key]))
)
for key in cond for key in cond
} }
else: else:
@ -1321,9 +1321,11 @@ class LatentDiffusion(DDPM):
if cond is not None: if cond is not None:
if isinstance(cond, dict): if isinstance(cond, dict):
cond = { cond = {
key: cond[key][:batch_size] key: (
if not isinstance(cond[key], list) cond[key][:batch_size]
else list(map(lambda x: x[:batch_size], cond[key])) if not isinstance(cond[key], list)
else list(map(lambda x: x[:batch_size], cond[key]))
)
for key in cond for key in cond
} }
else: else:

View File

@ -1,4 +1,5 @@
"""SAMPLING ONLY.""" """SAMPLING ONLY."""
import torch import torch
from .dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper from .dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper

View File

@ -640,23 +640,25 @@ class UNetModel(nn.Module):
use_checkpoint=use_checkpoint, use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm, use_scale_shift_norm=use_scale_shift_norm,
), ),
AttentionBlock( (
ch, AttentionBlock(
use_checkpoint=use_checkpoint, ch,
num_heads=num_heads, use_checkpoint=use_checkpoint,
num_head_channels=dim_head, num_heads=num_heads,
use_new_attention_order=use_new_attention_order, num_head_channels=dim_head,
) use_new_attention_order=use_new_attention_order,
if not use_spatial_transformer )
else SpatialTransformer( # always uses a self-attn if not use_spatial_transformer
ch, else SpatialTransformer( # always uses a self-attn
num_heads, ch,
dim_head, num_heads,
depth=transformer_depth, dim_head,
context_dim=context_dim, depth=transformer_depth,
disable_self_attn=disable_middle_self_attn, context_dim=context_dim,
use_linear=use_linear_in_transformer, disable_self_attn=disable_middle_self_attn,
use_checkpoint=use_checkpoint, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint,
)
), ),
ResBlock( ResBlock(
ch, ch,

View File

@ -2,6 +2,7 @@
This file contains code that is adapted from This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
""" """
import torch import torch
import torch.nn as nn import torch.nn as nn

View File

@ -2,6 +2,7 @@
This file contains code that is adapted from This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
""" """
import torch import torch
import torch.nn as nn import torch.nn as nn

View File

@ -1,4 +1,5 @@
"""Utils for monoDepth.""" """Utils for monoDepth."""
import re import re
import sys import sys

View File

@ -369,9 +369,9 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
} }
} // for (auto sent_index=sent_index_first; ... } // for (auto sent_index=sent_index_first; ...
} // if (num_remain_sent > 1) { } // if (num_remain_sent > 1) {
} // for (int doc=0; doc < num_docs; ++doc) { } // for (int doc=0; doc < num_docs; ++doc) {
} // for (int epoch=0; epoch < num_epochs; ++epoch) { } // for (int epoch=0; epoch < num_epochs; ++epoch) {
if (!second) { if (!second) {
if (verbose) { if (verbose) {
@ -606,9 +606,9 @@ py::array build_blocks_mapping_impl(
num_sent = 0; num_sent = 0;
} }
} // for (auto sent_index=sent_index_first; ... } // for (auto sent_index=sent_index_first; ...
} // if (num_remain_sent > 1) { } // if (num_remain_sent > 1) {
} // for (int doc=0; doc < num_docs; ++doc) { } // for (int doc=0; doc < num_docs; ++doc) {
} // for (int epoch=0; epoch < num_epochs; ++epoch) { } // for (int epoch=0; epoch < num_epochs; ++epoch) {
if (!second) { if (!second) {
if (verbose) { if (verbose) {

View File

@ -4,7 +4,7 @@
#include <cmath> #include <cmath>
#define ROUND_DOWN(size, step) ((size) & ~((step)-1)) #define ROUND_DOWN(size, step) ((size) & ~((step) - 1))
#define TILE (128 * 1024 * 1024) #define TILE (128 * 1024 * 1024)
#if defined(__aarch64__) #if defined(__aarch64__)

View File

@ -32,7 +32,7 @@ SOFTWARE
#include <x86intrin.h> #include <x86intrin.h>
#endif #endif
#define ROUND_DOWN(size, step) ((size) & ~((step)-1)) #define ROUND_DOWN(size, step) ((size) & ~((step) - 1))
#define TILE (128 * 1024 * 1024) #define TILE (128 * 1024 * 1024)
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__) #if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)

View File

@ -34,14 +34,14 @@ def swin_s():
# special output transform fn # special output transform fn
google_net_output_transform_fn = ( google_net_output_transform_fn = lambda x: (
lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.GoogLeNetOutputs) else dict(output=x) dict(output=sum(x)) if isinstance(x, torchvision.models.GoogLeNetOutputs) else dict(output=x)
) )
swin_s_output_output_transform_fn = ( swin_s_output_output_transform_fn = lambda x: (
lambda x: {f"output{idx}": val for idx, val in enumerate(x)} if isinstance(x, tuple) else dict(output=x) {f"output{idx}": val for idx, val in enumerate(x)} if isinstance(x, tuple) else dict(output=x)
) )
inception_v3_output_transform_fn = ( inception_v3_output_transform_fn = lambda x: (
lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.InceptionOutputs) else dict(output=x) dict(output=sum(x)) if isinstance(x, torchvision.models.InceptionOutputs) else dict(output=x)
) )
model_zoo.register( model_zoo.register(