Merge remote-tracking branch 'origin/feature/fp8_comm' into feature/fp8_comm

# Conflicts:
#	colossalai/quantization/fp8.py
pull/5885/head
BurkeHulk 5 months ago
commit 1f1b856354

@@ -1,34 +1,34 @@
repos:
  - repo: https://github.com/PyCQA/autoflake
-    rev: v2.2.1
+    rev: v2.3.1
    hooks:
      - id: autoflake
        name: autoflake (python)
        args: ['--in-place', '--remove-unused-variables', '--remove-all-unused-imports', '--ignore-init-module-imports']

  - repo: https://github.com/pycqa/isort
-    rev: 5.12.0
+    rev: 5.13.2
    hooks:
      - id: isort
        name: sort all imports (python)

  - repo: https://github.com/psf/black-pre-commit-mirror
-    rev: 23.9.1
+    rev: 24.4.2
    hooks:
      - id: black
        name: black formatter
        args: ['--line-length=120', '--target-version=py37', '--target-version=py38', '--target-version=py39','--target-version=py310']

  - repo: https://github.com/pre-commit/mirrors-clang-format
-    rev: v13.0.1
+    rev: v18.1.8
    hooks:
      - id: clang-format
        name: clang formatter
        types_or: [c++, c]

  - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.3.0
+    rev: v4.6.0
    hooks:
      - id: check-yaml
      - id: check-merge-conflict
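
Note: the hunk above only advances the pinned hook revisions. Assuming a standard pre-commit setup, the same kind of bump can be reproduced and applied to the whole tree locally with:

pre-commit autoupdate
pre-commit run --all-files

Most of the Python hunks that follow appear to be mechanical fallout of the black 24.x style change (conditional expressions are parenthesized when split, and a blank line is enforced after module docstrings) rather than hand edits.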

@@ -83,15 +83,19 @@ class DataCollatorForSupervisedDataset(object):
        # `List[torch.Tensor]`
        batch_input_ids = [
+            (
            torch.LongTensor(instance["input_ids"][: self.max_length])
            if len(instance["input_ids"]) > self.max_length
            else torch.LongTensor(instance["input_ids"])
+            )
            for instance in instances
        ]
        batch_labels = [
+            (
            torch.LongTensor(instance["labels"][: self.max_length])
            if len(instance["labels"]) > self.max_length
            else torch.LongTensor(instance["labels"])
+            )
            for instance in instances
        ]
        if self.tokenizer.padding_side == "right":

@ -1,6 +1,7 @@
""" """
loss functions loss functions
""" """
from typing import Optional, Tuple from typing import Optional, Tuple
import torch import torch

@ -1,6 +1,7 @@
""" """
reward model reward model
""" """
from typing import Optional from typing import Optional
import torch import torch

@ -1,6 +1,7 @@
""" """
Training utilities for Coati. Training utilities for Coati.
""" """
from typing import Any from typing import Any
import torch import torch

@@ -78,7 +78,9 @@ def get_prompt(line: Dict, dataset_name: str, logger: DistributedLogger) -> Dict
    option_string = "ABCDEFG"
    count = len(line["options"])
-    input = "问题:" + line["question"] + " " + "从以下选项中选择:" + " ".join(line["options"]) + "\n" + "答案:"
+    input = (
+        "问题:" + line["question"] + " " + "从以下选项中选择:" + " ".join(line["options"]) + "\n" + "答案:"
+    )
    all_classes = list(option_string[0:count])

@@ -150,7 +152,15 @@ def combine_prompt(prompt_path, dataset_name, load_explanation=True, chat_mode=F
        )
    elif dataset_name in chinese_qa_datasets:
        question_input = (
-            "问题:" + passage + " " + question + "\n" + "从以下选项中选择:" + " ".join(options) + "\n" + "答案:{}".format(label)
+            "问题:"
+            + passage
+            + " "
+            + question
+            + "\n"
+            + "从以下选项中选择:"
+            + " ".join(options)
+            + "\n"
+            + "答案:{}".format(label)
        )
    elif dataset_name in english_cloze_datasets:
        question_input = "Question: ".format(idx + 1) + question + "\n" + "Answer: {}".format(answer)

@@ -57,7 +57,11 @@ ceval_subject_mapping = {
    "urban_and_rural_planner": ["Urban and Rural Planner", "注册城乡规划师", "Other"],
    "accountant": ["Accountant", "注册会计师", "Other"],
    "fire_engineer": ["Fire Engineer", "注册消防工程师", "Other"],
-    "environmental_impact_assessment_engineer": ["Environmental Impact Assessment Engineer", "环境影响评价工程师", "Other"],
+    "environmental_impact_assessment_engineer": [
+        "Environmental Impact Assessment Engineer",
+        "环境影响评价工程师",
+        "Other",
+    ],
    "tax_accountant": ["Tax Accountant", "税务师", "Other"],
    "physician": ["Physician", "医师资格", "Other"],
}

@@ -56,9 +56,11 @@ class MTBenchDataset(BaseDataset):
                    "instruction": question["turns"],
                    "input": "",
                    "output": [],
-                    "target": [""] * turn_number
-                    if question["question_id"] not in reference
-                    else reference[question["question_id"]],
+                    "target": (
+                        [""] * turn_number
+                        if question["question_id"] not in reference
+                        else reference[question["question_id"]]
+                    ),
                }
                if category in dataset["test"]:

@@ -77,7 +77,9 @@ class HuggingFaceModel(BaseModel):
                self.indices_for_choices[0].append(
                    self.tokenizer(f"Answer: {choice}", add_special_tokens=False).input_ids[-1]
                )
-                self.indices_for_choices[1].append(self.tokenizer(f"答案:{choice}", add_special_tokens=False).input_ids[-1])
+                self.indices_for_choices[1].append(
+                    self.tokenizer(f"答案:{choice}", add_special_tokens=False).input_ids[-1]
+                )

    def _load_tokenizer(self, path: str, tokenizer_path: Optional[str], tokenizer_kwargs: dict):
        """

@@ -7,6 +7,7 @@ This code is based on LangChain Ai's langchain, which can be found at
https://github.com/langchain-ai/langchain
The original code is licensed under the MIT license.
"""
+
from __future__ import annotations

import copy

@@ -8,6 +8,7 @@ This code is based on LangChain Ai's langchain, which can be found at
https://github.com/langchain-ai/langchain
The original code is licensed under the MIT license.
"""
+
import copy
from typing import Any, Mapping, Optional, Protocol

@@ -7,6 +7,7 @@ This code is based on LangChain Ai's langchain, which can be found at
https://github.com/langchain-ai/langchain
The original code is licensed under the MIT license.
"""
+
import copy
from typing import Any, List

@@ -2,7 +2,6 @@
Class for loading table type data. please refer to Pandas-Input/Output for file format details.
"""
import glob
import os

@@ -12,6 +12,7 @@ TEST_PROMPT_CHATGLM="续写文章:惊蛰一过,春寒加剧。先是料料
logger.info(llm(TEST_PROMPT_CHATGLM, max_new_tokens=100), verbose=True)
"""
+
from typing import Any, List, Mapping, Optional

import torch

@@ -1,6 +1,7 @@
"""
Generation utilities
"""
+
import json
from typing import List

@@ -2,6 +2,7 @@
Implement a memory class for storing conversation history
Support long term and short term memory
"""
+
from typing import Any, Dict, List

from colossalqa.chain.memory.summary import ConversationSummaryMemory

@@ -1,6 +1,7 @@
"""
Class for logging with extra control for debugging
"""
+
import logging

@@ -1,6 +1,7 @@
"""
Script for Chinese retrieval based conversation system backed by ChatGLM
"""
+
from typing import Tuple

from colossalqa.chain.retrieval_qa.base import RetrievalQA

@@ -1,6 +1,7 @@
"""
Multilingual retrieval based conversation system
"""
+
from typing import List

from colossalqa.data_loader.document_loader import DocumentLoader

@@ -1,6 +1,7 @@
"""
Script for Chinese retrieval based conversation system backed by ChatGLM
"""
+
from typing import Tuple

from colossalqa.chain.retrieval_qa.base import RetrievalQA

@@ -1,6 +1,7 @@
"""
Code for custom retriver with incremental update
"""
+
import copy
import hashlib
import os

@@ -1,6 +1,7 @@
"""
Code for Chinese text splitter
"""
+
from typing import Any, List, Optional

from colossalqa.text_splitter.utils import get_cleaned_paragraph

@@ -1,6 +1,7 @@
"""
Script for English retrieval based conversation system backed by LLaMa2
"""
+
import argparse
import os

@@ -1,6 +1,7 @@
"""
Script for English retrieval based conversation system backed by LLaMa2
"""
+
import argparse
import json
import os

@@ -1,6 +1,7 @@
"""
Script for Chinese retrieval based conversation system backed by ChatGLM
"""
+
import argparse
import os

@@ -1,6 +1,7 @@
"""
Script for English retrieval based conversation system backed by LLaMa2
"""
+
import argparse
import os

@@ -107,20 +107,22 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
    # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
    fwd_memory_cost = MemoryCost(
        activation=compute_size_in_bytes([input_tensor, output_tensor]),
-        parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
-        if has_bias
-        else compute_size_in_bytes(weight_tensor),
+        parameter=(
+            compute_size_in_bytes([weight_tensor, bias_tensor]) if has_bias else compute_size_in_bytes(weight_tensor)
+        ),
        temp=0,
        buffer=0,
    )

    bwd_memory_cost = MemoryCost(
-        activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor])
-        if has_bias
-        else compute_size_in_bytes([input_tensor, weight_tensor]),
-        parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
-        if has_bias
-        else compute_size_in_bytes(weight_tensor),
+        activation=(
+            compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor])
+            if has_bias
+            else compute_size_in_bytes([input_tensor, weight_tensor])
+        ),
+        parameter=(
+            compute_size_in_bytes([weight_tensor, bias_tensor]) if has_bias else compute_size_in_bytes(weight_tensor)
+        ),
        temp=0,
        buffer=0,
    )

@@ -1,10 +1,18 @@
from .gemini_plugin import GeminiPlugin
from .hybrid_parallel_plugin import HybridParallelPlugin
from .low_level_zero_plugin import LowLevelZeroPlugin
+from .moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from .plugin_base import Plugin
from .torch_ddp_plugin import TorchDDPPlugin

-__all__ = ["Plugin", "TorchDDPPlugin", "GeminiPlugin", "LowLevelZeroPlugin", "HybridParallelPlugin"]
+__all__ = [
+    "Plugin",
+    "TorchDDPPlugin",
+    "GeminiPlugin",
+    "LowLevelZeroPlugin",
+    "HybridParallelPlugin",
+    "MoeHybridParallelPlugin",
+]

import torch
from packaging import version

@@ -247,16 +247,16 @@ class BatchBucket:
            self._sequences_dict[seq.request_id] = seq
            self._sequences_indexes[seq.request_id] = self._current_batch_size + i
        # TODO external (rename): modify Sequence.sentence_len to seq_len
-        self._sequence_lengths[
-            self._current_batch_size : self._current_batch_size + num_seqs_to_add
-        ] = torch.tensor([seq.sentence_len for seq in seqs[:num_seqs_to_add]], dtype=torch.int32)
+        self._sequence_lengths[self._current_batch_size : self._current_batch_size + num_seqs_to_add] = (
+            torch.tensor([seq.sentence_len for seq in seqs[:num_seqs_to_add]], dtype=torch.int32)
+        )
        # NOTE block tables to be updated by kvcache manager
        block_tables = self._block_tables[self._current_batch_size : self._current_batch_size + num_seqs_to_add]
        if alloc_block_tables is not None:
            # copy block ids from provided block tables
-            self._block_tables[
-                self._current_batch_size : self._current_batch_size + num_seqs_to_add
-            ] = alloc_block_tables
+            self._block_tables[self._current_batch_size : self._current_batch_size + num_seqs_to_add] = (
+                alloc_block_tables
+            )
        elif alloc_block_tables_fn:
            alloc_block_tables_fn(
                block_tables,

@ -1,6 +1,7 @@
""" """
Our config contains various options for inference optimization, it is a unified API that wraps all the configurations for inference. Our config contains various options for inference optimization, it is a unified API that wraps all the configurations for inference.
""" """
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
@ -82,9 +83,9 @@ class InputMetaData(RPC_PARAM):
dtype: torch.dtype = torch.float32 dtype: torch.dtype = torch.float32
use_spec_dec: bool = False use_spec_dec: bool = False
num_tokens_to_verify: int = 0 num_tokens_to_verify: int = 0
batch_token_ids: Optional[ batch_token_ids: Optional[List[List[int]]] = (
List[List[int]] None # for `repetition_penalty`, `no_repeat_ngram_size` in sampler process
] = None # for `repetition_penalty`, `no_repeat_ngram_size` in sampler process )
def to_rpc_param(self) -> Dict[str, any]: def to_rpc_param(self) -> Dict[str, any]:
return { return {
@ -202,9 +203,9 @@ class InferenceConfig(RPC_PARAM):
prompt_template: Optional[str] = None prompt_template: Optional[str] = None
do_sample: bool = False do_sample: bool = False
beam_width: int = 1 # TODO: beam search is not support for now beam_width: int = 1 # TODO: beam search is not support for now
prefill_ratio: Optional[ prefill_ratio: Optional[float] = (
float 1.2 # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
] = 1.2 # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio )
pad_input: bool = False pad_input: bool = False
early_stopping: Optional[bool] = False early_stopping: Optional[bool] = False
top_k: Optional[int] = 50 top_k: Optional[int] = 50
@ -234,7 +235,9 @@ class InferenceConfig(RPC_PARAM):
high_precision: Optional[bool] = False high_precision: Optional[bool] = False
# cuda_graph # cuda_graph
use_cuda_graph: bool = False # NOTE only when we have the graph for specific decoding batch size can we use the cuda graph for inference use_cuda_graph: bool = (
False # NOTE only when we have the graph for specific decoding batch size can we use the cuda graph for inference
)
max_context_len_to_capture: int = 512 max_context_len_to_capture: int = 512
# StreamingLLM (sliding window attention with attention sinks) # StreamingLLM (sliding window attention with attention sinks)

@@ -47,7 +47,6 @@ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
class InferenceEngine:
-
    """
    InferenceEngine which manages the inference process..

@@ -34,7 +34,6 @@ def run_server(host, port, event: mp.Event = None):
class RPCInferenceEngine(InferenceEngine):
-
    """
    InferenceEngine which manages the inference process..

@@ -42,7 +42,6 @@ logger = get_dist_logger(__name__)
class rpcWorkerService(rpyc.Service):
-
    """
    Execute the computation tasks and manage its own kv cache

@@ -279,9 +279,11 @@ class KVCacheManager:
                block.add_ref()
            self._allocate_on_block(
                block,
-                block.block_size
-                if context_lengths[i] % block.block_size == 0
-                else context_lengths[i].item() % block.block_size,
+                (
+                    block.block_size
+                    if context_lengths[i] % block.block_size == 0
+                    else context_lengths[i].item() % block.block_size
+                ),
            )
        for block_id in alloc_block_ids:
            if block_id in alloc_block_ids[last_block_locs]:
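
The expression above sizes the allocation on a sequence's last KV-cache block. A minimal standalone illustration of the same arithmetic (names are illustrative, not the KVCacheManager API):

def last_block_occupancy(context_length: int, block_size: int) -> int:
    # A full final block when the length divides evenly, otherwise the remainder.
    return block_size if context_length % block_size == 0 else context_length % block_size

assert last_block_occupancy(32, 16) == 16  # two completely filled blocks
assert last_block_occupancy(35, 16) == 3   # two full blocks plus 3 tokens in the last one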

@ -1,6 +1,7 @@
""" """
Utils for model inference Utils for model inference
""" """
import math import math
import os import os
import re import re

@@ -138,9 +138,7 @@ class Initializer_2D(ProcessGroupInitializer):
        self.num_group = self.world_size // self.tensor_parallel_size
        self.summa_dim = int(math.sqrt(self.tensor_parallel_size))
-        assert (
-            self.tensor_parallel_size == self.summa_dim**2
-        ), "2D summa dim should equal to tensor parallel size ^ 0.5"
+        assert self.tensor_parallel_size == self.summa_dim**2, "2D summa dim should equal to tensor parallel size ^ 0.5"
        _check_summa_env_var(self.summa_dim)

        self.col_initializer = Initializer_2D_Col(self.num_group, self.summa_dim, *args, **kwargs)

@@ -54,7 +54,6 @@ class RequestTracker:
class Async_Engine:
-
    """
    Use an engine to launch RAY Driver --> RAY Worker --> Async_Manager
    Background loop: inference reqs in waiting list (Listen)

@@ -118,16 +118,16 @@ class Batch:
class BatchTokenIdOut:
    def __init__(self):
-        self.reqs_infs: List[
-            Tuple[str, int, Dict, bool, bool]
-        ] = []  # [req_id, new_token_id, gen_metadata, finished_state, abort_state]
+        self.reqs_infs: List[Tuple[str, int, Dict, bool, bool]] = (
+            []
+        )  # [req_id, new_token_id, gen_metadata, finished_state, abort_state]


class BatchStrOut:
    def __init__(self):
-        self.reqs_infs: List[
-            Tuple[str, str, Dict, bool, bool]
-        ] = []  # [req_id, token_str, gen_metadata, finished_state, abort_state]
+        self.reqs_infs: List[Tuple[str, str, Dict, bool, bool]] = (
+            []
+        )  # [req_id, token_str, gen_metadata, finished_state, abort_state]


class AbortReq:

@ -1,6 +1,7 @@
""" """
Utils for model inference Utils for model inference
""" """
import os import os
import torch import torch

@ -14,6 +14,7 @@ class BatchInferState:
Information to be passed and used for a batch of inputs during Information to be passed and used for a batch of inputs during
a single model forward a single model forward
""" """
batch_size: int batch_size: int
max_len_in_batch: int max_len_in_batch: int

@ -4,6 +4,7 @@ of the ModelTC/lightllm GitHub repository
https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py
we slightly changed it to make it suitable for our colossal-ai shardformer TP-engine design. we slightly changed it to make it suitable for our colossal-ai shardformer TP-engine design.
""" """
import torch import torch
from transformers.utils import logging from transformers.utils import logging

@ -1,6 +1,7 @@
""" """
Utils for model inference Utils for model inference
""" """
import os import os
import torch import torch

@@ -1,17 +1,25 @@
# adapted from Hugging Face accelerate/utils/bnb.py accelerate/utils/modeling.py
+import importlib.metadata
import logging

import torch
import torch.nn as nn

+from packaging.version import Version
+
from .bnb_config import BnbQuantizationConfig

try:
    import bitsandbytes as bnb

-    IS_4BIT_BNB_AVAILABLE = bnb.__version__ >= "0.39.0"
-    IS_8BIT_BNB_AVAILABLE = bnb.__version__ >= "0.37.2"
+    try:
+        # in case lower version of bitsandbytes does not have __version__ attribute
+        BNB_VERSION = Version(bnb.__version__)
+    except AttributeError:
+        BNB_VERSION = Version(importlib.metadata.version("bitsandbytes"))
+
+    IS_4BIT_BNB_AVAILABLE = BNB_VERSION >= Version("0.39.0")
+    IS_8BIT_BNB_AVAILABLE = BNB_VERSION >= Version("0.37.2")
except ImportError:
    pass
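
The old string comparison is lexicographic, so a version such as "0.4.0" would pass a ">= 0.39.0" check even though it is an older release; parsing with packaging.version avoids that. A self-contained sketch of the same pattern (the package name is only an example):

import importlib.metadata

from packaging.version import Version

def at_least(package: str, minimum: str) -> bool:
    # Read the installed version from package metadata and compare it as a PEP 440 version.
    try:
        installed = Version(importlib.metadata.version(package))
    except importlib.metadata.PackageNotFoundError:
        return False
    return installed >= Version(minimum)

# String comparison misorders versions; Version compares release segments numerically.
assert "0.4.0" >= "0.39.0"                   # True as strings, but 0.4.0 is the older release
assert Version("0.4.0") < Version("0.39.0")  # correct ordering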

@@ -33,6 +33,7 @@ This license shall be governed and construed in accordance with the laws of Peop
Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at glm-130b@googlegroups.com.
"""
+
""" PyTorch ChatGLM model. """
import copy

@@ -221,7 +221,7 @@ class OPTPipelineForwards:
            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if decoder.gradient_checkpointing and decoder.training:
-                layer_outputs = self._gradient_checkpointing_func(
+                layer_outputs = self.decoder._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_attention_mask,

@@ -168,13 +168,27 @@ class Qwen2PipelineForwards:
        next_decoder_cache = None

        start_idx, end_idx = stage_index[0], stage_index[1]
+        num_ckpt_layers = 0
+        if self.gradient_checkpointing and self.training:
+            num_ckpt_layers = end_idx - start_idx
+            # TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer
+            if shard_config.gradient_checkpoint_config is not None:
+                num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
+                    stage=stage_manager.stage,
+                    num_stages=stage_manager.num_stages,
+                    num_layers=end_idx - start_idx,
+                    model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0),
+                    num_model_chunks=stage_manager.num_model_chunks,
+                )
+            assert num_ckpt_layers <= end_idx - start_idx
+
        for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

-            if self.gradient_checkpointing and self.training:
+            if idx - start_idx < num_ckpt_layers:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,

@@ -198,7 +212,6 @@ class Qwen2PipelineForwards:
            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
            if output_attentions:
                all_self_attns += (layer_outputs[1],)
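
The hunk above lets a pipeline stage apply activation checkpointing to only its first num_ckpt_layers layers instead of all of them. A minimal standalone sketch of that pattern in plain PyTorch (function and variable names are illustrative, not the Shardformer API):

import torch.nn as nn
from torch.utils.checkpoint import checkpoint

def forward_stage(layers: nn.ModuleList, hidden_states, num_ckpt_layers: int):
    # Recompute activations of the first `num_ckpt_layers` layers during backward;
    # run the remaining layers normally and keep their activations in memory.
    for idx, layer in enumerate(layers):
        if idx < num_ckpt_layers:
            hidden_states = checkpoint(layer, hidden_states, use_reentrant=False)
        else:
            hidden_states = layer(hidden_states)
    return hidden_states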

@@ -40,9 +40,7 @@ class MixtralPolicy(Policy):
        if self.shard_config.enable_tensor_parallelism:
            raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")

-        if getattr(self.shard_config, "ep_group", None) is None:
-            raise ValueError("You must pass in ep_group via shard_config for expert parallel!")
-
+        if getattr(self.shard_config, "ep_group", None) is not None:
            # expert parallel
            self.append_or_create_submodule_replacement(
                description=[

@@ -549,6 +549,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                    working_param = real_working_params[group_id][idx]
                    param_to_gather = master_param.to(device).to(self._dtype)
                    pg = self.param_to_pg[working_param]
+                    if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size:
+                        buffer_tensor = torch.empty_like(
+                            torch.cat([param_to_gather for _ in range(dist.get_world_size(pg))])
+                        )
+                        dist.all_gather_into_tensor(buffer_tensor, param_to_gather, pg)
+                        working_param.data.copy_(buffer_tensor[: working_param.numel()].reshape_as(working_param))
+                        continue
                    try:
                        self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
                    except RuntimeError:
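
The added branch gathers any parameter that is too large for the tensor bucket directly. A self-contained sketch of that all-gather-and-truncate step (function and variable names are illustrative, not the LowLevelZeroOptimizer API):

import torch
import torch.distributed as dist

def gather_oversized_param(shard: torch.Tensor, working_param: torch.Tensor, pg) -> None:
    # Every rank contributes its (possibly padded) shard; the result is world_size * shard.numel() long.
    world_size = dist.get_world_size(pg)
    buffer_tensor = torch.empty(world_size * shard.numel(), dtype=shard.dtype, device=shard.device)
    dist.all_gather_into_tensor(buffer_tensor, shard.contiguous(), group=pg)
    # The tail of the buffer may be padding, so keep only the real parameter elements.
    working_param.data.copy_(buffer_tensor[: working_param.numel()].reshape_as(working_param))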

@@ -87,44 +87,42 @@ optim = DistGaloreAwamW(
## Plugin compatibility

<table>
  <tr>
-    <th nowrap="nowrap">Model/Feature</th>
-    <th nowrap="nowrap" align="center" title="Lamb">Lamb</th>
-    <th nowrap="nowrap" align="center" title="GaLore">GaLore</th>
-    <th nowrap="nowrap" align="center" title="Adafactor">Adafactor</th>
-    <th nowrap="nowrap" align="center" title="CAME">CAME</th>
+    <th nowrap="nowrap">Optimizer/Plugin</th>
+    <th nowrap="nowrap" align="center">Hybrid Parallel Plugin</th>
+    <th nowrap="nowrap" align="center">Low Level Zero Plugin</th>
+    <th nowrap="nowrap" align="center">Torch DDP Plugin</th>
+    <th nowrap="nowrap" align="center">Gemini Plugin</th>
+    <th nowrap="nowrap" align="center">Moe Hybrid Plugin</th>
  </tr>
  <tr>
-    <td nowrap="nowrap">Hybrid Parallel<br />Plugin</td>
-    <td nowrap="nowrap" align="center">✔️</td>
-    <td nowrap="nowrap" align="center">✔️</td>
-    <td nowrap="nowrap" align="center">✔️</td>
-    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center" title="Lamb">Lamb</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">❌</td>
+    <td nowrap="nowrap" align="center">❌</td>
  </tr>
  <tr>
-    <td nowrap="nowrap">Low Level Zero<br />Plugin</td>
-    <td nowrap="nowrap" align="center">✔️</td>
-    <td nowrap="nowrap" align="center">❌</td>
-    <td nowrap="nowrap" align="center">✔️</td>
-    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center" title="GaLore">GaLore</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">❌</td>
+    <td nowrap="nowrap" align="center">❌</td>
  </tr>
  <tr>
-    <td nowrap="nowrap">Torch DDP<br />Plugin</td>
-    <td nowrap="nowrap" align="center">✔️</td>
-    <td nowrap="nowrap" align="center">✔️</td>
-    <td nowrap="nowrap" align="center">✔️</td>
-    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center" title="Adafactor">Adafactor</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">❌</td>
+    <td nowrap="nowrap" align="center">❌</td>
  </tr>
  <tr>
-    <td nowrap="nowrap">Gemini<br />Plugin</td>
-    <td nowrap="nowrap" align="center">❌</td>
-    <td nowrap="nowrap" align="center">❌</td>
-    <td nowrap="nowrap" align="center">❌</td>
-    <td nowrap="nowrap" align="center">❌</td>
-  </tr>
-  <tr>
-    <td nowrap="nowrap">Moe Hybrid<br />Plugin</td>
-    <td nowrap="nowrap" align="center">❌</td>
-    <td nowrap="nowrap" align="center">❌</td>
-    <td nowrap="nowrap" align="center">❌</td>
-    <td nowrap="nowrap" align="center">❌</td>
+    <td nowrap="nowrap" align="center" title="CAME">CAME</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">❌</td>
+    <td nowrap="nowrap" align="center">❌</td>
  </tr>

@@ -55,7 +55,7 @@ Model/Feature Compatibility Matrix:
    <td nowrap="nowrap" align="center">✔️</td>
    <td nowrap="nowrap" align="center">✔️</td>
    <td nowrap="nowrap" align="center">✔️</td>
-    <td nowrap="nowrap" align="center">❌</td>
+    <td nowrap="nowrap" align="center">✔️</td>
    <td nowrap="nowrap" align="center">❌</td>
  </tr>
  <tr>

@@ -84,44 +84,42 @@ optim = DistGaloreAwamW(
## 兼容性

<table>
  <tr>
-    <th nowrap="nowrap">Model/Feature</th>
-    <th nowrap="nowrap" align="center" title="Lamb">Lamb</th>
-    <th nowrap="nowrap" align="center" title="GaLore">GaLore</th>
-    <th nowrap="nowrap" align="center" title="Adafactor">Adafactor</th>
-    <th nowrap="nowrap" align="center" title="CAME">CAME</th>
+    <th nowrap="nowrap">Optimizer/Plugin</th>
+    <th nowrap="nowrap" align="center">Hybrid Parallel Plugin</th>
+    <th nowrap="nowrap" align="center">Low Level Zero Plugin</th>
+    <th nowrap="nowrap" align="center">Torch DDP Plugin</th>
+    <th nowrap="nowrap" align="center">Gemini Plugin</th>
+    <th nowrap="nowrap" align="center">Moe Hybrid Plugin</th>
  </tr>
  <tr>
-    <td nowrap="nowrap">Hybrid Parallel<br />Plugin</td>
-    <td nowrap="nowrap" align="center">✔️</td>
-    <td nowrap="nowrap" align="center">✔️</td>
-    <td nowrap="nowrap" align="center">✔️</td>
-    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center" title="Lamb">Lamb</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">❌</td>
+    <td nowrap="nowrap" align="center">❌</td>
  </tr>
  <tr>
-    <td nowrap="nowrap">Low Level Zero<br />Plugin</td>
-    <td nowrap="nowrap" align="center">✔️</td>
-    <td nowrap="nowrap" align="center">❌</td>
-    <td nowrap="nowrap" align="center">✔️</td>
-    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center" title="GaLore">GaLore</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">❌</td>
+    <td nowrap="nowrap" align="center">❌</td>
  </tr>
  <tr>
-    <td nowrap="nowrap">Torch DDP<br />Plugin</td>
-    <td nowrap="nowrap" align="center">✔️</td>
-    <td nowrap="nowrap" align="center">✔️</td>
-    <td nowrap="nowrap" align="center">✔️</td>
-    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center" title="Adafactor">Adafactor</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">❌</td>
+    <td nowrap="nowrap" align="center">❌</td>
  </tr>
  <tr>
-    <td nowrap="nowrap">Gemini<br />Plugin</td>
-    <td nowrap="nowrap" align="center">❌</td>
-    <td nowrap="nowrap" align="center">❌</td>
-    <td nowrap="nowrap" align="center">❌</td>
-    <td nowrap="nowrap" align="center">❌</td>
-  </tr>
-  <tr>
-    <td nowrap="nowrap">Moe Hybrid<br />Plugin</td>
-    <td nowrap="nowrap" align="center">❌</td>
-    <td nowrap="nowrap" align="center">❌</td>
-    <td nowrap="nowrap" align="center">❌</td>
-    <td nowrap="nowrap" align="center">❌</td>
+    <td nowrap="nowrap" align="center" title="CAME">CAME</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">❌</td>
+    <td nowrap="nowrap" align="center">❌</td>
  </tr>

@@ -130,6 +128,7 @@ optim = DistGaloreAwamW(
  </tr>
</table>
<!-- doc-test-command: colossalai run --nproc_per_node 4 distributed_optimizers.py -->

## API 参考

@@ -51,7 +51,7 @@ Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github.
    <td nowrap="nowrap" align="center">✔️</td>
    <td nowrap="nowrap" align="center">✔️</td>
    <td nowrap="nowrap" align="center">✔️</td>
-    <td nowrap="nowrap" align="center">❌</td>
+    <td nowrap="nowrap" align="center">✔️</td>
    <td nowrap="nowrap" align="center">❌</td>
  </tr>
  <tr>

@@ -52,9 +52,11 @@ class pretraining_dataset(Dataset):
    def __getitem__(self, index):
        [input_ids, input_mask, segment_ids, masked_lm_labels] = [
+            (
                torch.from_numpy(input[index].astype(np.int64))
                if indice < 5
                else torch.from_numpy(np.asarray(input[index].astype(np.int64)))
+            )
            for indice, input in enumerate(self.inputs)
        ]

@@ -229,9 +229,7 @@ class DDPM(pl.LightningModule):
        )
        if self.parameterization == "eps":
-            lvlb_weights = self.betas**2 / (
-                2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)
-            )
+            lvlb_weights = self.betas**2 / (2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
        elif self.parameterization == "x0":
            lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2.0 * 1 - torch.Tensor(alphas_cumprod))
        elif self.parameterization == "v":

@@ -1186,9 +1184,11 @@ class LatentDiffusion(DDPM):
        if cond is not None:
            if isinstance(cond, dict):
                cond = {
-                    key: cond[key][:batch_size]
-                    if not isinstance(cond[key], list)
-                    else list(map(lambda x: x[:batch_size], cond[key]))
+                    key: (
+                        cond[key][:batch_size]
+                        if not isinstance(cond[key], list)
+                        else list(map(lambda x: x[:batch_size], cond[key]))
+                    )
                    for key in cond
                }
            else:

@@ -1321,9 +1321,11 @@ class LatentDiffusion(DDPM):
        if cond is not None:
            if isinstance(cond, dict):
                cond = {
-                    key: cond[key][:batch_size]
-                    if not isinstance(cond[key], list)
-                    else list(map(lambda x: x[:batch_size], cond[key]))
+                    key: (
+                        cond[key][:batch_size]
+                        if not isinstance(cond[key], list)
+                        else list(map(lambda x: x[:batch_size], cond[key]))
+                    )
                    for key in cond
                }
            else:

@ -1,4 +1,5 @@
"""SAMPLING ONLY.""" """SAMPLING ONLY."""
import torch import torch
from .dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper from .dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper

@ -640,6 +640,7 @@ class UNetModel(nn.Module):
use_checkpoint=use_checkpoint, use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm, use_scale_shift_norm=use_scale_shift_norm,
), ),
(
AttentionBlock( AttentionBlock(
ch, ch,
use_checkpoint=use_checkpoint, use_checkpoint=use_checkpoint,
@ -657,6 +658,7 @@ class UNetModel(nn.Module):
disable_self_attn=disable_middle_self_attn, disable_self_attn=disable_middle_self_attn,
use_linear=use_linear_in_transformer, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, use_checkpoint=use_checkpoint,
)
), ),
ResBlock( ResBlock(
ch, ch,

@@ -2,6 +2,7 @@
This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
"""
+
import torch
import torch.nn as nn

@@ -2,6 +2,7 @@
This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
"""
+
import torch
import torch.nn as nn

@@ -1,4 +1,5 @@
"""Utils for monoDepth."""
+
import re
import sys

@@ -34,14 +34,14 @@ def swin_s():
# special output transform fn
-google_net_output_transform_fn = (
-    lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.GoogLeNetOutputs) else dict(output=x)
-)
-swin_s_output_output_transform_fn = (
-    lambda x: {f"output{idx}": val for idx, val in enumerate(x)} if isinstance(x, tuple) else dict(output=x)
-)
-inception_v3_output_transform_fn = (
-    lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.InceptionOutputs) else dict(output=x)
-)
+google_net_output_transform_fn = lambda x: (
+    dict(output=sum(x)) if isinstance(x, torchvision.models.GoogLeNetOutputs) else dict(output=x)
+)
+swin_s_output_output_transform_fn = lambda x: (
+    {f"output{idx}": val for idx, val in enumerate(x)} if isinstance(x, tuple) else dict(output=x)
+)
+inception_v3_output_transform_fn = lambda x: (
+    dict(output=sum(x)) if isinstance(x, torchvision.models.InceptionOutputs) else dict(output=x)
+)

model_zoo.register(
