mirror of https://github.com/hpcaitech/ColossalAI
Merge branch 'hpcaitech:main' into feature/fp8_comm
commit 6991819a97
@@ -1,34 +1,34 @@
 repos:

   - repo: https://github.com/PyCQA/autoflake
-    rev: v2.2.1
+    rev: v2.3.1
     hooks:
       - id: autoflake
         name: autoflake (python)
         args: ['--in-place', '--remove-unused-variables', '--remove-all-unused-imports', '--ignore-init-module-imports']

   - repo: https://github.com/pycqa/isort
-    rev: 5.12.0
+    rev: 5.13.2
     hooks:
       - id: isort
        name: sort all imports (python)

   - repo: https://github.com/psf/black-pre-commit-mirror
-    rev: 23.9.1
+    rev: 24.4.2
     hooks:
       - id: black
        name: black formatter
        args: ['--line-length=120', '--target-version=py37', '--target-version=py38', '--target-version=py39','--target-version=py310']

   - repo: https://github.com/pre-commit/mirrors-clang-format
-    rev: v13.0.1
+    rev: v18.1.8
     hooks:
       - id: clang-format
        name: clang formatter
        types_or: [c++, c]

   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.3.0
+    rev: v4.6.0
     hooks:
       - id: check-yaml
       - id: check-merge-conflict
@@ -83,15 +83,19 @@ class DataCollatorForSupervisedDataset(object):

        # `List[torch.Tensor]`
        batch_input_ids = [
-            torch.LongTensor(instance["input_ids"][: self.max_length])
-            if len(instance["input_ids"]) > self.max_length
-            else torch.LongTensor(instance["input_ids"])
+            (
+                torch.LongTensor(instance["input_ids"][: self.max_length])
+                if len(instance["input_ids"]) > self.max_length
+                else torch.LongTensor(instance["input_ids"])
+            )
            for instance in instances
        ]
        batch_labels = [
-            torch.LongTensor(instance["labels"][: self.max_length])
-            if len(instance["labels"]) > self.max_length
-            else torch.LongTensor(instance["labels"])
+            (
+                torch.LongTensor(instance["labels"][: self.max_length])
+                if len(instance["labels"]) > self.max_length
+                else torch.LongTensor(instance["labels"])
+            )
            for instance in instances
        ]
        if self.tokenizer.padding_side == "right":
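The parenthesized conditional introduced above is purely a layout change: each sequence is clipped to max_length before becoming a LongTensor. A minimal standalone sketch of that rule (the helper name is hypothetical, not part of the diff):

import torch

def to_clipped_tensor(ids, max_length):
    # keep at most max_length token ids, then convert to a LongTensor
    return torch.LongTensor(ids[:max_length] if len(ids) > max_length else ids)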
@@ -1,6 +1,7 @@
 """
 loss functions
 """
+
 from typing import Optional, Tuple

 import torch
@@ -1,6 +1,7 @@
 """
 reward model
 """
+
 from typing import Optional

 import torch
@@ -1,6 +1,7 @@
 """
 Training utilities for Coati.
 """
+
 from typing import Any

 import torch
@@ -78,7 +78,9 @@ def get_prompt(line: Dict, dataset_name: str, logger: DistributedLogger) -> Dict
        option_string = "ABCDEFG"
        count = len(line["options"])

-        input = "问题：" + line["question"] + " " + "从以下选项中选择：" + " ".join(line["options"]) + "\n" + "答案："
+        input = (
+            "问题：" + line["question"] + " " + "从以下选项中选择：" + " ".join(line["options"]) + "\n" + "答案："
+        )

        all_classes = list(option_string[0:count])
@@ -150,7 +152,15 @@ def combine_prompt(prompt_path, dataset_name, load_explanation=True, chat_mode=F
            )
        elif dataset_name in chinese_qa_datasets:
            question_input = (
-                "问题：" + passage + " " + question + "\n" + "从以下选项中选择：" + " ".join(options) + "\n" + "答案：{}".format(label)
+                "问题："
+                + passage
+                + " "
+                + question
+                + "\n"
+                + "从以下选项中选择："
+                + " ".join(options)
+                + "\n"
+                + "答案：{}".format(label)
            )
        elif dataset_name in english_cloze_datasets:
            question_input = "Question: ".format(idx + 1) + question + "\n" + "Answer: {}".format(answer)
@@ -57,7 +57,11 @@ ceval_subject_mapping = {
    "urban_and_rural_planner": ["Urban and Rural Planner", "注册城乡规划师", "Other"],
    "accountant": ["Accountant", "注册会计师", "Other"],
    "fire_engineer": ["Fire Engineer", "注册消防工程师", "Other"],
-    "environmental_impact_assessment_engineer": ["Environmental Impact Assessment Engineer", "环境影响评价工程师", "Other"],
+    "environmental_impact_assessment_engineer": [
+        "Environmental Impact Assessment Engineer",
+        "环境影响评价工程师",
+        "Other",
+    ],
    "tax_accountant": ["Tax Accountant", "税务师", "Other"],
    "physician": ["Physician", "医师资格", "Other"],
 }
@@ -56,9 +56,11 @@ class MTBenchDataset(BaseDataset):
                "instruction": question["turns"],
                "input": "",
                "output": [],
-                "target": [""] * turn_number
-                if question["question_id"] not in reference
-                else reference[question["question_id"]],
+                "target": (
+                    [""] * turn_number
+                    if question["question_id"] not in reference
+                    else reference[question["question_id"]]
+                ),
            }

            if category in dataset["test"]:
@@ -77,7 +77,9 @@ class HuggingFaceModel(BaseModel):
                self.indices_for_choices[0].append(
                    self.tokenizer(f"Answer: {choice}", add_special_tokens=False).input_ids[-1]
                )
-                self.indices_for_choices[1].append(self.tokenizer(f"答案：{choice}", add_special_tokens=False).input_ids[-1])
+                self.indices_for_choices[1].append(
+                    self.tokenizer(f"答案：{choice}", add_special_tokens=False).input_ids[-1]
+                )

    def _load_tokenizer(self, path: str, tokenizer_path: Optional[str], tokenizer_kwargs: dict):
        """
@@ -7,6 +7,7 @@ This code is based on LangChain Ai's langchain, which can be found at
 https://github.com/langchain-ai/langchain
 The original code is licensed under the MIT license.
 """
+
 from __future__ import annotations

 import copy
@@ -8,6 +8,7 @@ This code is based on LangChain Ai's langchain, which can be found at
 https://github.com/langchain-ai/langchain
 The original code is licensed under the MIT license.
 """
+
 import copy
 from typing import Any, Mapping, Optional, Protocol

@@ -7,6 +7,7 @@ This code is based on LangChain Ai's langchain, which can be found at
 https://github.com/langchain-ai/langchain
 The original code is licensed under the MIT license.
 """
+
 import copy
 from typing import Any, List

@@ -2,7 +2,6 @@
 Class for loading table type data. please refer to Pandas-Input/Output for file format details.
 """

-
 import glob
 import os

@@ -12,6 +12,7 @@ TEST_PROMPT_CHATGLM="续写文章：惊蛰一过，春寒加剧。先是料料
 logger.info(llm(TEST_PROMPT_CHATGLM, max_new_tokens=100), verbose=True)

 """
+
 from typing import Any, List, Mapping, Optional

 import torch
@@ -1,6 +1,7 @@
 """
 Generation utilities
 """
+
 import json
 from typing import List

@@ -2,6 +2,7 @@
 Implement a memory class for storing conversation history
 Support long term and short term memory
 """
+
 from typing import Any, Dict, List

 from colossalqa.chain.memory.summary import ConversationSummaryMemory
@@ -1,6 +1,7 @@
 """
 Class for logging with extra control for debugging
 """
+
 import logging

@@ -1,6 +1,7 @@
 """
 Script for Chinese retrieval based conversation system backed by ChatGLM
 """
+
 from typing import Tuple

 from colossalqa.chain.retrieval_qa.base import RetrievalQA
@@ -1,6 +1,7 @@
 """
 Multilingual retrieval based conversation system
 """
+
 from typing import List

 from colossalqa.data_loader.document_loader import DocumentLoader
@@ -1,6 +1,7 @@
 """
 Script for Chinese retrieval based conversation system backed by ChatGLM
 """
+
 from typing import Tuple

 from colossalqa.chain.retrieval_qa.base import RetrievalQA
@@ -1,6 +1,7 @@
 """
 Code for custom retriver with incremental update
 """
+
 import copy
 import hashlib
 import os
@@ -1,6 +1,7 @@
 """
 Code for Chinese text splitter
 """
+
 from typing import Any, List, Optional

 from colossalqa.text_splitter.utils import get_cleaned_paragraph
@@ -1,6 +1,7 @@
 """
 Script for English retrieval based conversation system backed by LLaMa2
 """
+
 import argparse
 import os

@@ -1,6 +1,7 @@
 """
 Script for English retrieval based conversation system backed by LLaMa2
 """
+
 import argparse
 import json
 import os
@@ -1,6 +1,7 @@
 """
 Script for Chinese retrieval based conversation system backed by ChatGLM
 """
+
 import argparse
 import os

@@ -1,6 +1,7 @@
 """
 Script for English retrieval based conversation system backed by LLaMa2
 """
+
 import argparse
 import os

@@ -107,20 +107,22 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
    # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
    fwd_memory_cost = MemoryCost(
        activation=compute_size_in_bytes([input_tensor, output_tensor]),
-        parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
-        if has_bias
-        else compute_size_in_bytes(weight_tensor),
+        parameter=(
+            compute_size_in_bytes([weight_tensor, bias_tensor]) if has_bias else compute_size_in_bytes(weight_tensor)
+        ),
        temp=0,
        buffer=0,
    )

    bwd_memory_cost = MemoryCost(
-        activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor])
-        if has_bias
-        else compute_size_in_bytes([input_tensor, weight_tensor]),
-        parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
-        if has_bias
-        else compute_size_in_bytes(weight_tensor),
+        activation=(
+            compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor])
+            if has_bias
+            else compute_size_in_bytes([input_tensor, weight_tensor])
+        ),
+        parameter=(
+            compute_size_in_bytes([weight_tensor, bias_tensor]) if has_bias else compute_size_in_bytes(weight_tensor)
+        ),
        temp=0,
        buffer=0,
    )
@@ -1,10 +1,18 @@
 from .gemini_plugin import GeminiPlugin
 from .hybrid_parallel_plugin import HybridParallelPlugin
 from .low_level_zero_plugin import LowLevelZeroPlugin
+from .moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from .plugin_base import Plugin
 from .torch_ddp_plugin import TorchDDPPlugin

-__all__ = ["Plugin", "TorchDDPPlugin", "GeminiPlugin", "LowLevelZeroPlugin", "HybridParallelPlugin"]
+__all__ = [
+    "Plugin",
+    "TorchDDPPlugin",
+    "GeminiPlugin",
+    "LowLevelZeroPlugin",
+    "HybridParallelPlugin",
+    "MoeHybridParallelPlugin",
+]

 import torch
 from packaging import version
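With the added export, the MoE plugin can be imported next to the existing ones; a hedged usage sketch, assuming this file is colossalai/booster/plugin/__init__.py (the plugin's constructor arguments are not shown in this diff and are left out):

from colossalai.booster import Booster
from colossalai.booster.plugin import MoeHybridParallelPlugin

# plugin = MoeHybridParallelPlugin(...)  # arguments depend on the plugin API, not shown here
# booster = Booster(plugin=plugin)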
@@ -247,16 +247,16 @@ class BatchBucket:
            self._sequences_dict[seq.request_id] = seq
            self._sequences_indexes[seq.request_id] = self._current_batch_size + i
        # TODO external (rename): modify Sequence.sentence_len to seq_len
-        self._sequence_lengths[
-            self._current_batch_size : self._current_batch_size + num_seqs_to_add
-        ] = torch.tensor([seq.sentence_len for seq in seqs[:num_seqs_to_add]], dtype=torch.int32)
+        self._sequence_lengths[self._current_batch_size : self._current_batch_size + num_seqs_to_add] = (
+            torch.tensor([seq.sentence_len for seq in seqs[:num_seqs_to_add]], dtype=torch.int32)
+        )
        # NOTE block tables to be updated by kvcache manager
        block_tables = self._block_tables[self._current_batch_size : self._current_batch_size + num_seqs_to_add]
        if alloc_block_tables is not None:
            # copy block ids from provided block tables
-            self._block_tables[
-                self._current_batch_size : self._current_batch_size + num_seqs_to_add
-            ] = alloc_block_tables
+            self._block_tables[self._current_batch_size : self._current_batch_size + num_seqs_to_add] = (
+                alloc_block_tables
+            )
        elif alloc_block_tables_fn:
            alloc_block_tables_fn(
                block_tables,
@@ -1,6 +1,7 @@
 """
 Our config contains various options for inference optimization, it is a unified API that wraps all the configurations for inference.
 """
+
 import logging
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, fields
@@ -82,9 +83,9 @@ class InputMetaData(RPC_PARAM):
    dtype: torch.dtype = torch.float32
    use_spec_dec: bool = False
    num_tokens_to_verify: int = 0
-    batch_token_ids: Optional[
-        List[List[int]]
-    ] = None  # for `repetition_penalty`, `no_repeat_ngram_size` in sampler process
+    batch_token_ids: Optional[List[List[int]]] = (
+        None  # for `repetition_penalty`, `no_repeat_ngram_size` in sampler process
+    )

    def to_rpc_param(self) -> Dict[str, any]:
        return {
@@ -202,9 +203,9 @@ class InferenceConfig(RPC_PARAM):
    prompt_template: Optional[str] = None
    do_sample: bool = False
    beam_width: int = 1  # TODO: beam search is not support for now
-    prefill_ratio: Optional[
-        float
-    ] = 1.2  # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
+    prefill_ratio: Optional[float] = (
+        1.2  # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
+    )
    pad_input: bool = False
    early_stopping: Optional[bool] = False
    top_k: Optional[int] = 50
@@ -234,7 +235,9 @@ class InferenceConfig(RPC_PARAM):
    high_precision: Optional[bool] = False

    # cuda_graph
-    use_cuda_graph: bool = False  # NOTE only when we have the graph for specific decoding batch size can we use the cuda graph for inference
+    use_cuda_graph: bool = (
+        False  # NOTE only when we have the graph for specific decoding batch size can we use the cuda graph for inference
+    )
    max_context_len_to_capture: int = 512

    # StreamingLLM (sliding window attention with attention sinks)
@@ -47,7 +47,6 @@ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]


 class InferenceEngine:
-
    """
    InferenceEngine which manages the inference process..

@@ -34,7 +34,6 @@ def run_server(host, port, event: mp.Event = None):


 class RPCInferenceEngine(InferenceEngine):
-
    """
    InferenceEngine which manages the inference process..

@@ -42,7 +42,6 @@ logger = get_dist_logger(__name__)


 class rpcWorkerService(rpyc.Service):
-
    """
    Execute the computation tasks and manage its own kv cache

@@ -279,9 +279,11 @@ class KVCacheManager:
            block.add_ref()
            self._allocate_on_block(
                block,
-                block.block_size
-                if context_lengths[i] % block.block_size == 0
-                else context_lengths[i].item() % block.block_size,
+                (
+                    block.block_size
+                    if context_lengths[i] % block.block_size == 0
+                    else context_lengths[i].item() % block.block_size
+                ),
            )
        for block_id in alloc_block_ids:
            if block_id in alloc_block_ids[last_block_locs]:
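The wrapped conditional passed to _allocate_on_block computes how many slots of a sequence's last cache block are in use. The arithmetic in isolation (hypothetical helper for illustration):

def last_block_fill(context_len: int, block_size: int) -> int:
    # a context of length L fills L % block_size slots of its final block,
    # unless L is an exact multiple of block_size, in which case the block is full
    return block_size if context_len % block_size == 0 else context_len % block_size

assert last_block_fill(33, 16) == 1 and last_block_fill(32, 16) == 16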
@@ -1,6 +1,7 @@
 """
 Utils for model inference
 """
+
 import math
 import os
 import re
@@ -138,9 +138,7 @@ class Initializer_2D(ProcessGroupInitializer):
        self.num_group = self.world_size // self.tensor_parallel_size
        self.summa_dim = int(math.sqrt(self.tensor_parallel_size))

-        assert (
-            self.tensor_parallel_size == self.summa_dim**2
-        ), "2D summa dim should equal to tensor parallel size ^ 0.5"
+        assert self.tensor_parallel_size == self.summa_dim**2, "2D summa dim should equal to tensor parallel size ^ 0.5"
        _check_summa_env_var(self.summa_dim)

        self.col_initializer = Initializer_2D_Col(self.num_group, self.summa_dim, *args, **kwargs)
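The collapsed assert keeps the 2D (SUMMA) requirement that the tensor-parallel group forms a square grid; a small numeric check of the same condition:

import math

tensor_parallel_size = 4
summa_dim = int(math.sqrt(tensor_parallel_size))
assert tensor_parallel_size == summa_dim**2  # holds only for square sizes such as 1, 4, 9, 16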
@@ -54,7 +54,6 @@ class RequestTracker:


 class Async_Engine:
-
    """
    Use an engine to launch RAY Driver --> RAY Worker --> Async_Manager
    Background loop: inference reqs in waiting list (Listen)
@@ -118,16 +118,16 @@ class Batch:

 class BatchTokenIdOut:
    def __init__(self):
-        self.reqs_infs: List[
-            Tuple[str, int, Dict, bool, bool]
-        ] = []  # [req_id, new_token_id, gen_metadata, finished_state, abort_state]
+        self.reqs_infs: List[Tuple[str, int, Dict, bool, bool]] = (
+            []
+        )  # [req_id, new_token_id, gen_metadata, finished_state, abort_state]


 class BatchStrOut:
    def __init__(self):
-        self.reqs_infs: List[
-            Tuple[str, str, Dict, bool, bool]
-        ] = []  # [req_id, token_str, gen_metadata, finished_state, abort_state]
+        self.reqs_infs: List[Tuple[str, str, Dict, bool, bool]] = (
+            []
+        )  # [req_id, token_str, gen_metadata, finished_state, abort_state]


 class AbortReq:
@@ -1,6 +1,7 @@
 """
 Utils for model inference
 """
+
 import os

 import torch
@@ -14,6 +14,7 @@ class BatchInferState:
    Information to be passed and used for a batch of inputs during
    a single model forward
    """
+
    batch_size: int
    max_len_in_batch: int

@@ -4,6 +4,7 @@ of the ModelTC/lightllm GitHub repository
 https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py
 we slightly changed it to make it suitable for our colossal-ai shardformer TP-engine design.
 """
+
 import torch
 from transformers.utils import logging

@@ -1,6 +1,7 @@
 """
 Utils for model inference
 """
+
 import os

 import torch
@@ -1,17 +1,25 @@
 # adapted from Hugging Face accelerate/utils/bnb.py accelerate/utils/modeling.py

+import importlib.metadata
 import logging

 import torch
 import torch.nn as nn
+from packaging.version import Version

 from .bnb_config import BnbQuantizationConfig

 try:
     import bitsandbytes as bnb

-    IS_4BIT_BNB_AVAILABLE = bnb.__version__ >= "0.39.0"
-    IS_8BIT_BNB_AVAILABLE = bnb.__version__ >= "0.37.2"
+    try:
+        # in case lower version of bitsandbytes does not have __version__ attribute
+        BNB_VERSION = Version(bnb.__version__)
+    except AttributeError:
+        BNB_VERSION = Version(importlib.metadata.version("bitsandbytes"))
+
+    IS_4BIT_BNB_AVAILABLE = BNB_VERSION >= Version("0.39.0")
+    IS_8BIT_BNB_AVAILABLE = BNB_VERSION >= Version("0.37.2")
 except ImportError:
     pass

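The new probe above reads the bitsandbytes version from package metadata when an older build lacks __version__. A self-contained sketch of the same fallback (the function name is illustrative):

import importlib.metadata
from packaging.version import Version

def bnb_version() -> Version:
    import bitsandbytes as bnb
    try:
        return Version(bnb.__version__)
    except AttributeError:
        # older builds without __version__: ask the installed distribution instead
        return Version(importlib.metadata.version("bitsandbytes"))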
@@ -33,6 +33,7 @@ This license shall be governed and construed in accordance with the laws of Peop

 Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at glm-130b@googlegroups.com.
 """
+
 """ PyTorch ChatGLM model. """

 import copy
@@ -221,7 +221,7 @@ class OPTPipelineForwards:
            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if decoder.gradient_checkpointing and decoder.training:
-                layer_outputs = self._gradient_checkpointing_func(
+                layer_outputs = self.decoder._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_attention_mask,
@@ -168,13 +168,27 @@ class Qwen2PipelineForwards:
        next_decoder_cache = None

        start_idx, end_idx = stage_index[0], stage_index[1]
+        num_ckpt_layers = 0
+        if self.gradient_checkpointing and self.training:
+            num_ckpt_layers = end_idx - start_idx
+            # TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer
+            if shard_config.gradient_checkpoint_config is not None:
+                num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
+                    stage=stage_manager.stage,
+                    num_stages=stage_manager.num_stages,
+                    num_layers=end_idx - start_idx,
+                    model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0),
+                    num_model_chunks=stage_manager.num_model_chunks,
+                )
+            assert num_ckpt_layers <= end_idx - start_idx
+
        for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

-            if self.gradient_checkpointing and self.training:
+            if idx - start_idx < num_ckpt_layers:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
@@ -198,7 +212,6 @@ class Qwen2PipelineForwards:

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
-
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

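The Qwen2 change replaces the all-or-nothing checkpointing test with a per-stage budget: only the first num_ckpt_layers layers of a pipeline stage are recomputed in backward. A hedged standalone sketch of the selection rule, where num_ckpt_layers_override stands in for the value returned by get_num_ckpt_layers (an assumption, not the real API):

def select_checkpointed_layers(start_idx, end_idx, gradient_checkpointing, training, num_ckpt_layers_override=None):
    # returns one flag per layer in the stage; True means that layer is recomputed in backward
    num_layers = end_idx - start_idx
    num_ckpt_layers = 0
    if gradient_checkpointing and training:
        num_ckpt_layers = num_layers if num_ckpt_layers_override is None else num_ckpt_layers_override
    assert num_ckpt_layers <= num_layers
    return [i < num_ckpt_layers for i in range(num_layers)]

# a stage holding layers 8..15 with a budget of 3 recomputes only layers 8, 9 and 10
assert select_checkpointed_layers(8, 16, True, True, num_ckpt_layers_override=3)[:4] == [True, True, True, False]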
@@ -40,21 +40,19 @@ class MixtralPolicy(Policy):

        if self.shard_config.enable_tensor_parallelism:
            raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
-        if getattr(self.shard_config, "ep_group", None) is None:
-            raise ValueError("You must pass in ep_group via shard_config for expert parallel!")
-
-        # expert parallel
-        self.append_or_create_submodule_replacement(
-            description=[
-                SubModuleReplacementDescription(
-                    suffix="block_sparse_moe",
-                    target_module=EPMixtralSparseMoeBlock,
-                    kwargs={"ep_group": self.shard_config.ep_group},
-                )
-            ],
-            policy=policy,
-            target_key=MixtralDecoderLayer,
-        )
+        if getattr(self.shard_config, "ep_group", None) is not None:
+            # expert parallel
+            self.append_or_create_submodule_replacement(
+                description=[
+                    SubModuleReplacementDescription(
+                        suffix="block_sparse_moe",
+                        target_module=EPMixtralSparseMoeBlock,
+                        kwargs={"ep_group": self.shard_config.ep_group},
+                    )
+                ],
+                policy=policy,
+                target_key=MixtralDecoderLayer,
+            )

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
@@ -549,6 +549,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                working_param = real_working_params[group_id][idx]
                param_to_gather = master_param.to(device).to(self._dtype)
                pg = self.param_to_pg[working_param]
+                if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size:
+                    buffer_tensor = torch.empty_like(
+                        torch.cat([param_to_gather for _ in range(dist.get_world_size(pg))])
+                    )
+                    dist.all_gather_into_tensor(buffer_tensor, param_to_gather, pg)
+                    working_param.data.copy_(buffer_tensor[: working_param.numel()].reshape_as(working_param))
+                    continue
                try:
                    self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
                except RuntimeError:
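The added branch handles parameters whose gathered size would overflow the tensor bucket: they are all-gathered directly into a flat buffer and copied back, skipping the bucket. A hedged standalone sketch of that fallback (the bucket and process-group plumbing are assumptions; only the size check and the direct all-gather mirror the diff):

import torch
import torch.distributed as dist

def gather_working_param(param_shard, working_param, pg, bucket, max_bucket_size):
    if param_shard.numel() > max_bucket_size:
        # oversized parameter: all-gather into a flat buffer, then write back the unpadded slice
        buffer = torch.empty(
            param_shard.numel() * dist.get_world_size(pg), dtype=param_shard.dtype, device=param_shard.device
        )
        dist.all_gather_into_tensor(buffer, param_shard, group=pg)
        working_param.data.copy_(buffer[: working_param.numel()].reshape_as(working_param))
    else:
        bucket.add_to_bucket(param_shard, write_back_tensor=working_param)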
@@ -87,44 +87,42 @@ optim = DistGaloreAwamW(
 ## Plugin compatibility
 <table>
   <tr>
-    <th nowrap="nowrap">Model/Feature</th>
-    <th nowrap="nowrap" align="center" title="Lamb">Lamb</th>
-    <th nowrap="nowrap" align="center" title="GaLore">GaLore</th>
-    <th nowrap="nowrap" align="center" title="Adafactor">Adafactor</th>
-    <th nowrap="nowrap" align="center" title="CAME">CAME</th>
-  </tr>
-  <tr>
-    <td nowrap="nowrap">Hybrid Parallel<br />Plugin</td>
-    <td nowrap="nowrap" align="center">✔️</td>
-    <td nowrap="nowrap" align="center">✔️</td>
-    <td nowrap="nowrap" align="center">✔️</td>
-    <td nowrap="nowrap" align="center">✔️</td>
-  </tr>
-  <tr>
-    <td nowrap="nowrap">Low Level Zero<br />Plugin</td>
-    <td nowrap="nowrap" align="center">✔️</td>
-    <td nowrap="nowrap" align="center">❌</td>
-    <td nowrap="nowrap" align="center">✔️</td>
-    <td nowrap="nowrap" align="center">✔️</td>
-  </tr>
-  <tr>
-    <td nowrap="nowrap">Torch DDP<br />Plugin</td>
-    <td nowrap="nowrap" align="center">✔️</td>
-    <td nowrap="nowrap" align="center">✔️</td>
-    <td nowrap="nowrap" align="center">✔️</td>
-    <td nowrap="nowrap" align="center">✔️</td>
-  </tr>
-  <tr>
-    <td nowrap="nowrap">Gemini<br />Plugin</td>
-    <td nowrap="nowrap" align="center">❌</td>
-    <td nowrap="nowrap" align="center">❌</td>
-    <td nowrap="nowrap" align="center">❌</td>
-    <td nowrap="nowrap" align="center">❌</td>
-  </tr>
-  <tr>
-    <td nowrap="nowrap">Moe Hybrid<br />Plugin</td>
-    <td nowrap="nowrap" align="center">❌</td>
-    <td nowrap="nowrap" align="center">❌</td>
-    <td nowrap="nowrap" align="center">❌</td>
-    <td nowrap="nowrap" align="center">❌</td>
-  </tr>
+    <th nowrap="nowrap">Optimizer/Plugin</th>
+    <th nowrap="nowrap" align="center">Hybrid Parallel Plugin</th>
+    <th nowrap="nowrap" align="center">Low Level Zero Plugin</th>
+    <th nowrap="nowrap" align="center">Torch DDP Plugin</th>
+    <th nowrap="nowrap" align="center">Gemini Plugin</th>
+    <th nowrap="nowrap" align="center">Moe Hybrid Plugin</th>
+  </tr>
+  <tr>
+    <td nowrap="nowrap" align="center" title="Lamb">Lamb</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">❌</td>
+    <td nowrap="nowrap" align="center">❌</td>
+  </tr>
+  <tr>
+    <td nowrap="nowrap" align="center" title="GaLore">GaLore</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">❌</td>
+    <td nowrap="nowrap" align="center">❌</td>
+  </tr>
+  <tr>
+    <td nowrap="nowrap" align="center" title="Adafactor">Adafactor</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">❌</td>
+    <td nowrap="nowrap" align="center">❌</td>
+  </tr>
+  <tr>
+    <td nowrap="nowrap" align="center" title="CAME">CAME</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">❌</td>
+    <td nowrap="nowrap" align="center">❌</td>
+  </tr>
@@ -55,7 +55,7 @@ Model/Feature Compatibility Matrix:
    <td nowrap="nowrap" align="center">✔️</td>
    <td nowrap="nowrap" align="center">✔️</td>
    <td nowrap="nowrap" align="center">✔️</td>
-    <td nowrap="nowrap" align="center">❌</td>
+    <td nowrap="nowrap" align="center">✔️</td>
    <td nowrap="nowrap" align="center">❌</td>
  </tr>
  <tr>
@@ -84,44 +84,42 @@ optim = DistGaloreAwamW(
 ## 兼容性
 <table>
   <tr>
-    <th nowrap="nowrap">Model/Feature</th>
-    <th nowrap="nowrap" align="center" title="Lamb">Lamb</th>
-    <th nowrap="nowrap" align="center" title="GaLore">GaLore</th>
-    <th nowrap="nowrap" align="center" title="Adafactor">Adafactor</th>
-    <th nowrap="nowrap" align="center" title="CAME">CAME</th>
-  </tr>
-  <tr>
-    <td nowrap="nowrap">Hybrid Parallel<br />Plugin</td>
-    <td nowrap="nowrap" align="center">✔️</td>
-    <td nowrap="nowrap" align="center">✔️</td>
-    <td nowrap="nowrap" align="center">✔️</td>
-    <td nowrap="nowrap" align="center">✔️</td>
-  </tr>
-  <tr>
-    <td nowrap="nowrap">Low Level Zero<br />Plugin</td>
-    <td nowrap="nowrap" align="center">✔️</td>
-    <td nowrap="nowrap" align="center">❌</td>
-    <td nowrap="nowrap" align="center">✔️</td>
-    <td nowrap="nowrap" align="center">✔️</td>
-  </tr>
-  <tr>
-    <td nowrap="nowrap">Torch DDP<br />Plugin</td>
-    <td nowrap="nowrap" align="center">✔️</td>
-    <td nowrap="nowrap" align="center">✔️</td>
-    <td nowrap="nowrap" align="center">✔️</td>
-    <td nowrap="nowrap" align="center">✔️</td>
-  </tr>
-  <tr>
-    <td nowrap="nowrap">Gemini<br />Plugin</td>
-    <td nowrap="nowrap" align="center">❌</td>
-    <td nowrap="nowrap" align="center">❌</td>
-    <td nowrap="nowrap" align="center">❌</td>
-    <td nowrap="nowrap" align="center">❌</td>
-  </tr>
-  <tr>
-    <td nowrap="nowrap">Moe Hybrid<br />Plugin</td>
-    <td nowrap="nowrap" align="center">❌</td>
-    <td nowrap="nowrap" align="center">❌</td>
-    <td nowrap="nowrap" align="center">❌</td>
-    <td nowrap="nowrap" align="center">❌</td>
-  </tr>
+    <th nowrap="nowrap">Optimizer/Plugin</th>
+    <th nowrap="nowrap" align="center">Hybrid Parallel Plugin</th>
+    <th nowrap="nowrap" align="center">Low Level Zero Plugin</th>
+    <th nowrap="nowrap" align="center">Torch DDP Plugin</th>
+    <th nowrap="nowrap" align="center">Gemini Plugin</th>
+    <th nowrap="nowrap" align="center">Moe Hybrid Plugin</th>
+  </tr>
+  <tr>
+    <td nowrap="nowrap" align="center" title="Lamb">Lamb</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">❌</td>
+    <td nowrap="nowrap" align="center">❌</td>
+  </tr>
+  <tr>
+    <td nowrap="nowrap" align="center" title="GaLore">GaLore</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">❌</td>
+    <td nowrap="nowrap" align="center">❌</td>
+  </tr>
+  <tr>
+    <td nowrap="nowrap" align="center" title="Adafactor">Adafactor</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">❌</td>
+    <td nowrap="nowrap" align="center">❌</td>
+  </tr>
+  <tr>
+    <td nowrap="nowrap" align="center" title="CAME">CAME</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">✔️</td>
+    <td nowrap="nowrap" align="center">❌</td>
+    <td nowrap="nowrap" align="center">❌</td>
+  </tr>
@@ -130,6 +128,7 @@ optim = DistGaloreAwamW(
  </tr>
 </table>

+
 <!-- doc-test-command: colossalai run --nproc_per_node 4 distributed_optimizers.py -->

 ## API 参考
@@ -51,7 +51,7 @@ Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github.
    <td nowrap="nowrap" align="center">✔️</td>
    <td nowrap="nowrap" align="center">✔️</td>
    <td nowrap="nowrap" align="center">✔️</td>
-    <td nowrap="nowrap" align="center">❌</td>
+    <td nowrap="nowrap" align="center">✔️</td>
    <td nowrap="nowrap" align="center">❌</td>
  </tr>
  <tr>
@@ -52,9 +52,11 @@ class pretraining_dataset(Dataset):

    def __getitem__(self, index):
        [input_ids, input_mask, segment_ids, masked_lm_labels] = [
-            torch.from_numpy(input[index].astype(np.int64))
-            if indice < 5
-            else torch.from_numpy(np.asarray(input[index].astype(np.int64)))
+            (
+                torch.from_numpy(input[index].astype(np.int64))
+                if indice < 5
+                else torch.from_numpy(np.asarray(input[index].astype(np.int64)))
+            )
            for indice, input in enumerate(self.inputs)
        ]

@@ -229,9 +229,7 @@ class DDPM(pl.LightningModule):
        )

        if self.parameterization == "eps":
-            lvlb_weights = self.betas**2 / (
-                2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)
-            )
+            lvlb_weights = self.betas**2 / (2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
        elif self.parameterization == "x0":
            lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2.0 * 1 - torch.Tensor(alphas_cumprod))
        elif self.parameterization == "v":
@@ -1186,9 +1184,11 @@ class LatentDiffusion(DDPM):
        if cond is not None:
            if isinstance(cond, dict):
                cond = {
-                    key: cond[key][:batch_size]
-                    if not isinstance(cond[key], list)
-                    else list(map(lambda x: x[:batch_size], cond[key]))
+                    key: (
+                        cond[key][:batch_size]
+                        if not isinstance(cond[key], list)
+                        else list(map(lambda x: x[:batch_size], cond[key]))
+                    )
                    for key in cond
                }
            else:
@@ -1321,9 +1321,11 @@ class LatentDiffusion(DDPM):
        if cond is not None:
            if isinstance(cond, dict):
                cond = {
-                    key: cond[key][:batch_size]
-                    if not isinstance(cond[key], list)
-                    else list(map(lambda x: x[:batch_size], cond[key]))
+                    key: (
+                        cond[key][:batch_size]
+                        if not isinstance(cond[key], list)
+                        else list(map(lambda x: x[:batch_size], cond[key]))
+                    )
                    for key in cond
                }
            else:
@@ -1,4 +1,5 @@
 """SAMPLING ONLY."""
+
 import torch

 from .dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper
@@ -640,23 +640,25 @@ class UNetModel(nn.Module):
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
-            AttentionBlock(
-                ch,
-                use_checkpoint=use_checkpoint,
-                num_heads=num_heads,
-                num_head_channels=dim_head,
-                use_new_attention_order=use_new_attention_order,
-            )
-            if not use_spatial_transformer
-            else SpatialTransformer(  # always uses a self-attn
-                ch,
-                num_heads,
-                dim_head,
-                depth=transformer_depth,
-                context_dim=context_dim,
-                disable_self_attn=disable_middle_self_attn,
-                use_linear=use_linear_in_transformer,
-                use_checkpoint=use_checkpoint,
-            ),
+            (
+                AttentionBlock(
+                    ch,
+                    use_checkpoint=use_checkpoint,
+                    num_heads=num_heads,
+                    num_head_channels=dim_head,
+                    use_new_attention_order=use_new_attention_order,
+                )
+                if not use_spatial_transformer
+                else SpatialTransformer(  # always uses a self-attn
+                    ch,
+                    num_heads,
+                    dim_head,
+                    depth=transformer_depth,
+                    context_dim=context_dim,
+                    disable_self_attn=disable_middle_self_attn,
+                    use_linear=use_linear_in_transformer,
+                    use_checkpoint=use_checkpoint,
+                )
+            ),
            ResBlock(
                ch,
@@ -2,6 +2,7 @@
 This file contains code that is adapted from
 https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
 """
+
 import torch
 import torch.nn as nn

@@ -2,6 +2,7 @@
 This file contains code that is adapted from
 https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
 """
+
 import torch
 import torch.nn as nn

@@ -1,4 +1,5 @@
 """Utils for monoDepth."""
+
 import re
 import sys

@@ -369,9 +369,9 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
          }

-        } // for (auto sent_index=sent_index_first; ...
-      } // if (num_remain_sent > 1) {
-    } // for (int doc=0; doc < num_docs; ++doc) {
-  } // for (int epoch=0; epoch < num_epochs; ++epoch) {
+        }  // for (auto sent_index=sent_index_first; ...
+      }    // if (num_remain_sent > 1) {
+    }      // for (int doc=0; doc < num_docs; ++doc) {
+  }        // for (int epoch=0; epoch < num_epochs; ++epoch) {

  if (!second) {
    if (verbose) {
@@ -606,9 +606,9 @@ py::array build_blocks_mapping_impl(
            num_sent = 0;
          }
-        } // for (auto sent_index=sent_index_first; ...
-      } // if (num_remain_sent > 1) {
-    } // for (int doc=0; doc < num_docs; ++doc) {
-  } // for (int epoch=0; epoch < num_epochs; ++epoch) {
+        }  // for (auto sent_index=sent_index_first; ...
+      }    // if (num_remain_sent > 1) {
+    }      // for (int doc=0; doc < num_docs; ++doc) {
+  }        // for (int epoch=0; epoch < num_epochs; ++epoch) {

  if (!second) {
    if (verbose) {
@@ -4,7 +4,7 @@

 #include <cmath>

-#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
+#define ROUND_DOWN(size, step) ((size) & ~((step) - 1))
 #define TILE (128 * 1024 * 1024)

 #if defined(__aarch64__)
@@ -32,7 +32,7 @@ SOFTWARE
 #include <x86intrin.h>
 #endif

-#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
+#define ROUND_DOWN(size, step) ((size) & ~((step) - 1))
 #define TILE (128 * 1024 * 1024)

 #if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
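The reformatted ROUND_DOWN macro only gains spacing; it rounds size down to a multiple of step, assuming step is a power of two. The same bit trick, written in Python for illustration:

def round_down(size: int, step: int) -> int:
    # clearing the low bits of size rounds it down to a multiple of step (a power of two)
    return size & ~(step - 1)

assert round_down(1000, 128) == 896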
@@ -34,14 +34,14 @@ def swin_s():


 # special output transform fn
-google_net_output_transform_fn = (
-    lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.GoogLeNetOutputs) else dict(output=x)
+google_net_output_transform_fn = lambda x: (
+    dict(output=sum(x)) if isinstance(x, torchvision.models.GoogLeNetOutputs) else dict(output=x)
 )
-swin_s_output_output_transform_fn = (
-    lambda x: {f"output{idx}": val for idx, val in enumerate(x)} if isinstance(x, tuple) else dict(output=x)
+swin_s_output_output_transform_fn = lambda x: (
+    {f"output{idx}": val for idx, val in enumerate(x)} if isinstance(x, tuple) else dict(output=x)
 )
-inception_v3_output_transform_fn = (
-    lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.InceptionOutputs) else dict(output=x)
+inception_v3_output_transform_fn = lambda x: (
+    dict(output=sum(x)) if isinstance(x, torchvision.models.InceptionOutputs) else dict(output=x)
 )

 model_zoo.register(