mirror of https://github.com/hpcaitech/ColossalAI
Merge branch 'hpcaitech:main' into feature/fp8_comm
commit
6991819a97
|
@ -1,34 +1,34 @@
|
|||
repos:
|
||||
|
||||
- repo: https://github.com/PyCQA/autoflake
|
||||
rev: v2.2.1
|
||||
rev: v2.3.1
|
||||
hooks:
|
||||
- id: autoflake
|
||||
name: autoflake (python)
|
||||
args: ['--in-place', '--remove-unused-variables', '--remove-all-unused-imports', '--ignore-init-module-imports']
|
||||
|
||||
- repo: https://github.com/pycqa/isort
|
||||
rev: 5.12.0
|
||||
rev: 5.13.2
|
||||
hooks:
|
||||
- id: isort
|
||||
name: sort all imports (python)
|
||||
|
||||
- repo: https://github.com/psf/black-pre-commit-mirror
|
||||
rev: 23.9.1
|
||||
rev: 24.4.2
|
||||
hooks:
|
||||
- id: black
|
||||
name: black formatter
|
||||
args: ['--line-length=120', '--target-version=py37', '--target-version=py38', '--target-version=py39','--target-version=py310']
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-clang-format
|
||||
rev: v13.0.1
|
||||
rev: v18.1.8
|
||||
hooks:
|
||||
- id: clang-format
|
||||
name: clang formatter
|
||||
types_or: [c++, c]
|
||||
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.3.0
|
||||
rev: v4.6.0
|
||||
hooks:
|
||||
- id: check-yaml
|
||||
- id: check-merge-conflict
|
||||
|
|
|
@ -83,15 +83,19 @@ class DataCollatorForSupervisedDataset(object):
|
|||
|
||||
# `List[torch.Tensor]`
|
||||
batch_input_ids = [
|
||||
torch.LongTensor(instance["input_ids"][: self.max_length])
|
||||
if len(instance["input_ids"]) > self.max_length
|
||||
else torch.LongTensor(instance["input_ids"])
|
||||
(
|
||||
torch.LongTensor(instance["input_ids"][: self.max_length])
|
||||
if len(instance["input_ids"]) > self.max_length
|
||||
else torch.LongTensor(instance["input_ids"])
|
||||
)
|
||||
for instance in instances
|
||||
]
|
||||
batch_labels = [
|
||||
torch.LongTensor(instance["labels"][: self.max_length])
|
||||
if len(instance["labels"]) > self.max_length
|
||||
else torch.LongTensor(instance["labels"])
|
||||
(
|
||||
torch.LongTensor(instance["labels"][: self.max_length])
|
||||
if len(instance["labels"]) > self.max_length
|
||||
else torch.LongTensor(instance["labels"])
|
||||
)
|
||||
for instance in instances
|
||||
]
|
||||
if self.tokenizer.padding_side == "right":
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
loss functions
|
||||
"""
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
reward model
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
Training utilities for Coati.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
|
|
@ -78,7 +78,9 @@ def get_prompt(line: Dict, dataset_name: str, logger: DistributedLogger) -> Dict
|
|||
option_string = "ABCDEFG"
|
||||
count = len(line["options"])
|
||||
|
||||
input = "问题:" + line["question"] + " " + "从以下选项中选择:" + " ".join(line["options"]) + "\n" + "答案:"
|
||||
input = (
|
||||
"问题:" + line["question"] + " " + "从以下选项中选择:" + " ".join(line["options"]) + "\n" + "答案:"
|
||||
)
|
||||
|
||||
all_classes = list(option_string[0:count])
|
||||
|
||||
|
@ -150,7 +152,15 @@ def combine_prompt(prompt_path, dataset_name, load_explanation=True, chat_mode=F
|
|||
)
|
||||
elif dataset_name in chinese_qa_datasets:
|
||||
question_input = (
|
||||
"问题:" + passage + " " + question + "\n" + "从以下选项中选择:" + " ".join(options) + "\n" + "答案:{}".format(label)
|
||||
"问题:"
|
||||
+ passage
|
||||
+ " "
|
||||
+ question
|
||||
+ "\n"
|
||||
+ "从以下选项中选择:"
|
||||
+ " ".join(options)
|
||||
+ "\n"
|
||||
+ "答案:{}".format(label)
|
||||
)
|
||||
elif dataset_name in english_cloze_datasets:
|
||||
question_input = "Question: ".format(idx + 1) + question + "\n" + "Answer: {}".format(answer)
|
||||
|
|
|
@ -57,7 +57,11 @@ ceval_subject_mapping = {
|
|||
"urban_and_rural_planner": ["Urban and Rural Planner", "注册城乡规划师", "Other"],
|
||||
"accountant": ["Accountant", "注册会计师", "Other"],
|
||||
"fire_engineer": ["Fire Engineer", "注册消防工程师", "Other"],
|
||||
"environmental_impact_assessment_engineer": ["Environmental Impact Assessment Engineer", "环境影响评价工程师", "Other"],
|
||||
"environmental_impact_assessment_engineer": [
|
||||
"Environmental Impact Assessment Engineer",
|
||||
"环境影响评价工程师",
|
||||
"Other",
|
||||
],
|
||||
"tax_accountant": ["Tax Accountant", "税务师", "Other"],
|
||||
"physician": ["Physician", "医师资格", "Other"],
|
||||
}
|
||||
|
|
|
@ -56,9 +56,11 @@ class MTBenchDataset(BaseDataset):
|
|||
"instruction": question["turns"],
|
||||
"input": "",
|
||||
"output": [],
|
||||
"target": [""] * turn_number
|
||||
if question["question_id"] not in reference
|
||||
else reference[question["question_id"]],
|
||||
"target": (
|
||||
[""] * turn_number
|
||||
if question["question_id"] not in reference
|
||||
else reference[question["question_id"]]
|
||||
),
|
||||
}
|
||||
|
||||
if category in dataset["test"]:
|
||||
|
|
|
@ -77,7 +77,9 @@ class HuggingFaceModel(BaseModel):
|
|||
self.indices_for_choices[0].append(
|
||||
self.tokenizer(f"Answer: {choice}", add_special_tokens=False).input_ids[-1]
|
||||
)
|
||||
self.indices_for_choices[1].append(self.tokenizer(f"答案:{choice}", add_special_tokens=False).input_ids[-1])
|
||||
self.indices_for_choices[1].append(
|
||||
self.tokenizer(f"答案:{choice}", add_special_tokens=False).input_ids[-1]
|
||||
)
|
||||
|
||||
def _load_tokenizer(self, path: str, tokenizer_path: Optional[str], tokenizer_kwargs: dict):
|
||||
"""
|
||||
|
|
|
@ -7,6 +7,7 @@ This code is based on LangChain Ai's langchain, which can be found at
|
|||
https://github.com/langchain-ai/langchain
|
||||
The original code is licensed under the MIT license.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
|
|
|
@ -8,6 +8,7 @@ This code is based on LangChain Ai's langchain, which can be found at
|
|||
https://github.com/langchain-ai/langchain
|
||||
The original code is licensed under the MIT license.
|
||||
"""
|
||||
|
||||
import copy
|
||||
from typing import Any, Mapping, Optional, Protocol
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@ This code is based on LangChain Ai's langchain, which can be found at
|
|||
https://github.com/langchain-ai/langchain
|
||||
The original code is licensed under the MIT license.
|
||||
"""
|
||||
|
||||
import copy
|
||||
from typing import Any, List
|
||||
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
Class for loading table type data. please refer to Pandas-Input/Output for file format details.
|
||||
"""
|
||||
|
||||
|
||||
import glob
|
||||
import os
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@ TEST_PROMPT_CHATGLM="续写文章:惊蛰一过,春寒加剧。先是料料
|
|||
logger.info(llm(TEST_PROMPT_CHATGLM, max_new_tokens=100), verbose=True)
|
||||
|
||||
"""
|
||||
|
||||
from typing import Any, List, Mapping, Optional
|
||||
|
||||
import torch
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
Generation utilities
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import List
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
Implement a memory class for storing conversation history
|
||||
Support long term and short term memory
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from colossalqa.chain.memory.summary import ConversationSummaryMemory
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
Class for logging with extra control for debugging
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
Script for Chinese retrieval based conversation system backed by ChatGLM
|
||||
"""
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
from colossalqa.chain.retrieval_qa.base import RetrievalQA
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
Multilingual retrieval based conversation system
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
|
||||
from colossalqa.data_loader.document_loader import DocumentLoader
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
Script for Chinese retrieval based conversation system backed by ChatGLM
|
||||
"""
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
from colossalqa.chain.retrieval_qa.base import RetrievalQA
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
Code for custom retriver with incremental update
|
||||
"""
|
||||
|
||||
import copy
|
||||
import hashlib
|
||||
import os
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
Code for Chinese text splitter
|
||||
"""
|
||||
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from colossalqa.text_splitter.utils import get_cleaned_paragraph
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
Script for English retrieval based conversation system backed by LLaMa2
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
Script for English retrieval based conversation system backed by LLaMa2
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
Script for Chinese retrieval based conversation system backed by ChatGLM
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
Script for English retrieval based conversation system backed by LLaMa2
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
|
|
|
@ -107,20 +107,22 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
|
|||
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
|
||||
fwd_memory_cost = MemoryCost(
|
||||
activation=compute_size_in_bytes([input_tensor, output_tensor]),
|
||||
parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
|
||||
if has_bias
|
||||
else compute_size_in_bytes(weight_tensor),
|
||||
parameter=(
|
||||
compute_size_in_bytes([weight_tensor, bias_tensor]) if has_bias else compute_size_in_bytes(weight_tensor)
|
||||
),
|
||||
temp=0,
|
||||
buffer=0,
|
||||
)
|
||||
|
||||
bwd_memory_cost = MemoryCost(
|
||||
activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor])
|
||||
if has_bias
|
||||
else compute_size_in_bytes([input_tensor, weight_tensor]),
|
||||
parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
|
||||
if has_bias
|
||||
else compute_size_in_bytes(weight_tensor),
|
||||
activation=(
|
||||
compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor])
|
||||
if has_bias
|
||||
else compute_size_in_bytes([input_tensor, weight_tensor])
|
||||
),
|
||||
parameter=(
|
||||
compute_size_in_bytes([weight_tensor, bias_tensor]) if has_bias else compute_size_in_bytes(weight_tensor)
|
||||
),
|
||||
temp=0,
|
||||
buffer=0,
|
||||
)
|
||||
|
|
|
@ -1,10 +1,18 @@
|
|||
from .gemini_plugin import GeminiPlugin
|
||||
from .hybrid_parallel_plugin import HybridParallelPlugin
|
||||
from .low_level_zero_plugin import LowLevelZeroPlugin
|
||||
from .moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
|
||||
from .plugin_base import Plugin
|
||||
from .torch_ddp_plugin import TorchDDPPlugin
|
||||
|
||||
__all__ = ["Plugin", "TorchDDPPlugin", "GeminiPlugin", "LowLevelZeroPlugin", "HybridParallelPlugin"]
|
||||
__all__ = [
|
||||
"Plugin",
|
||||
"TorchDDPPlugin",
|
||||
"GeminiPlugin",
|
||||
"LowLevelZeroPlugin",
|
||||
"HybridParallelPlugin",
|
||||
"MoeHybridParallelPlugin",
|
||||
]
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
|
|
|
@ -247,16 +247,16 @@ class BatchBucket:
|
|||
self._sequences_dict[seq.request_id] = seq
|
||||
self._sequences_indexes[seq.request_id] = self._current_batch_size + i
|
||||
# TODO external (rename): modify Sequence.sentence_len to seq_len
|
||||
self._sequence_lengths[
|
||||
self._current_batch_size : self._current_batch_size + num_seqs_to_add
|
||||
] = torch.tensor([seq.sentence_len for seq in seqs[:num_seqs_to_add]], dtype=torch.int32)
|
||||
self._sequence_lengths[self._current_batch_size : self._current_batch_size + num_seqs_to_add] = (
|
||||
torch.tensor([seq.sentence_len for seq in seqs[:num_seqs_to_add]], dtype=torch.int32)
|
||||
)
|
||||
# NOTE block tables to be updated by kvcache manager
|
||||
block_tables = self._block_tables[self._current_batch_size : self._current_batch_size + num_seqs_to_add]
|
||||
if alloc_block_tables is not None:
|
||||
# copy block ids from provided block tables
|
||||
self._block_tables[
|
||||
self._current_batch_size : self._current_batch_size + num_seqs_to_add
|
||||
] = alloc_block_tables
|
||||
self._block_tables[self._current_batch_size : self._current_batch_size + num_seqs_to_add] = (
|
||||
alloc_block_tables
|
||||
)
|
||||
elif alloc_block_tables_fn:
|
||||
alloc_block_tables_fn(
|
||||
block_tables,
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
Our config contains various options for inference optimization, it is a unified API that wraps all the configurations for inference.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, fields
|
||||
|
@ -82,9 +83,9 @@ class InputMetaData(RPC_PARAM):
|
|||
dtype: torch.dtype = torch.float32
|
||||
use_spec_dec: bool = False
|
||||
num_tokens_to_verify: int = 0
|
||||
batch_token_ids: Optional[
|
||||
List[List[int]]
|
||||
] = None # for `repetition_penalty`, `no_repeat_ngram_size` in sampler process
|
||||
batch_token_ids: Optional[List[List[int]]] = (
|
||||
None # for `repetition_penalty`, `no_repeat_ngram_size` in sampler process
|
||||
)
|
||||
|
||||
def to_rpc_param(self) -> Dict[str, any]:
|
||||
return {
|
||||
|
@ -202,9 +203,9 @@ class InferenceConfig(RPC_PARAM):
|
|||
prompt_template: Optional[str] = None
|
||||
do_sample: bool = False
|
||||
beam_width: int = 1 # TODO: beam search is not support for now
|
||||
prefill_ratio: Optional[
|
||||
float
|
||||
] = 1.2 # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
|
||||
prefill_ratio: Optional[float] = (
|
||||
1.2 # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
|
||||
)
|
||||
pad_input: bool = False
|
||||
early_stopping: Optional[bool] = False
|
||||
top_k: Optional[int] = 50
|
||||
|
@ -234,7 +235,9 @@ class InferenceConfig(RPC_PARAM):
|
|||
high_precision: Optional[bool] = False
|
||||
|
||||
# cuda_graph
|
||||
use_cuda_graph: bool = False # NOTE only when we have the graph for specific decoding batch size can we use the cuda graph for inference
|
||||
use_cuda_graph: bool = (
|
||||
False # NOTE only when we have the graph for specific decoding batch size can we use the cuda graph for inference
|
||||
)
|
||||
max_context_len_to_capture: int = 512
|
||||
|
||||
# StreamingLLM (sliding window attention with attention sinks)
|
||||
|
|
|
@ -47,7 +47,6 @@ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
|
|||
|
||||
|
||||
class InferenceEngine:
|
||||
|
||||
"""
|
||||
InferenceEngine which manages the inference process..
|
||||
|
||||
|
|
|
@ -34,7 +34,6 @@ def run_server(host, port, event: mp.Event = None):
|
|||
|
||||
|
||||
class RPCInferenceEngine(InferenceEngine):
|
||||
|
||||
"""
|
||||
InferenceEngine which manages the inference process..
|
||||
|
||||
|
|
|
@ -42,7 +42,6 @@ logger = get_dist_logger(__name__)
|
|||
|
||||
|
||||
class rpcWorkerService(rpyc.Service):
|
||||
|
||||
"""
|
||||
Execute the computation tasks and manage its own kv cache
|
||||
|
||||
|
|
|
@ -279,9 +279,11 @@ class KVCacheManager:
|
|||
block.add_ref()
|
||||
self._allocate_on_block(
|
||||
block,
|
||||
block.block_size
|
||||
if context_lengths[i] % block.block_size == 0
|
||||
else context_lengths[i].item() % block.block_size,
|
||||
(
|
||||
block.block_size
|
||||
if context_lengths[i] % block.block_size == 0
|
||||
else context_lengths[i].item() % block.block_size
|
||||
),
|
||||
)
|
||||
for block_id in alloc_block_ids:
|
||||
if block_id in alloc_block_ids[last_block_locs]:
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
Utils for model inference
|
||||
"""
|
||||
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
|
|
|
@ -138,9 +138,7 @@ class Initializer_2D(ProcessGroupInitializer):
|
|||
self.num_group = self.world_size // self.tensor_parallel_size
|
||||
self.summa_dim = int(math.sqrt(self.tensor_parallel_size))
|
||||
|
||||
assert (
|
||||
self.tensor_parallel_size == self.summa_dim**2
|
||||
), "2D summa dim should equal to tensor parallel size ^ 0.5"
|
||||
assert self.tensor_parallel_size == self.summa_dim**2, "2D summa dim should equal to tensor parallel size ^ 0.5"
|
||||
_check_summa_env_var(self.summa_dim)
|
||||
|
||||
self.col_initializer = Initializer_2D_Col(self.num_group, self.summa_dim, *args, **kwargs)
|
||||
|
|
|
@ -54,7 +54,6 @@ class RequestTracker:
|
|||
|
||||
|
||||
class Async_Engine:
|
||||
|
||||
"""
|
||||
Use an engine to launch RAY Driver --> RAY Worker --> Async_Manager
|
||||
Background loop: inference reqs in waiting list (Listen)
|
||||
|
|
|
@ -118,16 +118,16 @@ class Batch:
|
|||
|
||||
class BatchTokenIdOut:
|
||||
def __init__(self):
|
||||
self.reqs_infs: List[
|
||||
Tuple[str, int, Dict, bool, bool]
|
||||
] = [] # [req_id, new_token_id, gen_metadata, finished_state, abort_state]
|
||||
self.reqs_infs: List[Tuple[str, int, Dict, bool, bool]] = (
|
||||
[]
|
||||
) # [req_id, new_token_id, gen_metadata, finished_state, abort_state]
|
||||
|
||||
|
||||
class BatchStrOut:
|
||||
def __init__(self):
|
||||
self.reqs_infs: List[
|
||||
Tuple[str, str, Dict, bool, bool]
|
||||
] = [] # [req_id, token_str, gen_metadata, finished_state, abort_state]
|
||||
self.reqs_infs: List[Tuple[str, str, Dict, bool, bool]] = (
|
||||
[]
|
||||
) # [req_id, token_str, gen_metadata, finished_state, abort_state]
|
||||
|
||||
|
||||
class AbortReq:
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
Utils for model inference
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
|
|
@ -14,6 +14,7 @@ class BatchInferState:
|
|||
Information to be passed and used for a batch of inputs during
|
||||
a single model forward
|
||||
"""
|
||||
|
||||
batch_size: int
|
||||
max_len_in_batch: int
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@ of the ModelTC/lightllm GitHub repository
|
|||
https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py
|
||||
we slightly changed it to make it suitable for our colossal-ai shardformer TP-engine design.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from transformers.utils import logging
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
Utils for model inference
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
|
|
@ -1,17 +1,25 @@
|
|||
# adapted from Hugging Face accelerate/utils/bnb.py accelerate/utils/modeling.py
|
||||
|
||||
import importlib.metadata
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from packaging.version import Version
|
||||
|
||||
from .bnb_config import BnbQuantizationConfig
|
||||
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
|
||||
IS_4BIT_BNB_AVAILABLE = bnb.__version__ >= "0.39.0"
|
||||
IS_8BIT_BNB_AVAILABLE = bnb.__version__ >= "0.37.2"
|
||||
try:
|
||||
# in case lower version of bitsandbytes does not have __version__ attribute
|
||||
BNB_VERSION = Version(bnb.__version__)
|
||||
except AttributeError:
|
||||
BNB_VERSION = Version(importlib.metadata.version("bitsandbytes"))
|
||||
|
||||
IS_4BIT_BNB_AVAILABLE = BNB_VERSION >= Version("0.39.0")
|
||||
IS_8BIT_BNB_AVAILABLE = BNB_VERSION >= Version("0.37.2")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
|
|
@ -33,6 +33,7 @@ This license shall be governed and construed in accordance with the laws of Peop
|
|||
|
||||
Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at glm-130b@googlegroups.com.
|
||||
"""
|
||||
|
||||
""" PyTorch ChatGLM model. """
|
||||
|
||||
import copy
|
||||
|
|
|
@ -221,7 +221,7 @@ class OPTPipelineForwards:
|
|||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if decoder.gradient_checkpointing and decoder.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_outputs = self.decoder._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
causal_attention_mask,
|
||||
|
|
|
@ -168,13 +168,27 @@ class Qwen2PipelineForwards:
|
|||
next_decoder_cache = None
|
||||
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
num_ckpt_layers = 0
|
||||
if self.gradient_checkpointing and self.training:
|
||||
num_ckpt_layers = end_idx - start_idx
|
||||
# TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer
|
||||
if shard_config.gradient_checkpoint_config is not None:
|
||||
num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
|
||||
stage=stage_manager.stage,
|
||||
num_stages=stage_manager.num_stages,
|
||||
num_layers=end_idx - start_idx,
|
||||
model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0),
|
||||
num_model_chunks=stage_manager.num_model_chunks,
|
||||
)
|
||||
assert num_ckpt_layers <= end_idx - start_idx
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if idx - start_idx < num_ckpt_layers:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
|
@ -198,7 +212,6 @@ class Qwen2PipelineForwards:
|
|||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
|
|
|
@ -40,21 +40,19 @@ class MixtralPolicy(Policy):
|
|||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
|
||||
if getattr(self.shard_config, "ep_group", None) is None:
|
||||
raise ValueError("You must pass in ep_group via shard_config for expert parallel!")
|
||||
|
||||
# expert parallel
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="block_sparse_moe",
|
||||
target_module=EPMixtralSparseMoeBlock,
|
||||
kwargs={"ep_group": self.shard_config.ep_group},
|
||||
)
|
||||
],
|
||||
policy=policy,
|
||||
target_key=MixtralDecoderLayer,
|
||||
)
|
||||
if getattr(self.shard_config, "ep_group", None) is not None:
|
||||
# expert parallel
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="block_sparse_moe",
|
||||
target_module=EPMixtralSparseMoeBlock,
|
||||
kwargs={"ep_group": self.shard_config.ep_group},
|
||||
)
|
||||
],
|
||||
policy=policy,
|
||||
target_key=MixtralDecoderLayer,
|
||||
)
|
||||
|
||||
# optimization configuration
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
|
|
|
@ -549,6 +549,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
working_param = real_working_params[group_id][idx]
|
||||
param_to_gather = master_param.to(device).to(self._dtype)
|
||||
pg = self.param_to_pg[working_param]
|
||||
if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size:
|
||||
buffer_tensor = torch.empty_like(
|
||||
torch.cat([param_to_gather for _ in range(dist.get_world_size(pg))])
|
||||
)
|
||||
dist.all_gather_into_tensor(buffer_tensor, param_to_gather, pg)
|
||||
working_param.data.copy_(buffer_tensor[: working_param.numel()].reshape_as(working_param))
|
||||
continue
|
||||
try:
|
||||
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
|
||||
except RuntimeError:
|
||||
|
|
|
@ -87,44 +87,42 @@ optim = DistGaloreAwamW(
|
|||
## Plugin compatibility
|
||||
<table>
|
||||
<tr>
|
||||
<th nowrap="nowrap">Model/Feature</th>
|
||||
<th nowrap="nowrap" align="center" title="Lamb">Lamb</th>
|
||||
<th nowrap="nowrap" align="center" title="GaLore">GaLore</th>
|
||||
<th nowrap="nowrap" align="center" title="Adafactor">Adafactor</th>
|
||||
<th nowrap="nowrap" align="center" title="CAME">CAME</th>
|
||||
<th nowrap="nowrap">Optimizer/Plugin</th>
|
||||
<th nowrap="nowrap" align="center">Hybrid Parallel Plugin</th>
|
||||
<th nowrap="nowrap" align="center">Low Level Zero Plugin</th>
|
||||
<th nowrap="nowrap" align="center">Torch DDP Plugin</th>
|
||||
<th nowrap="nowrap" align="center">Gemini Plugin</th>
|
||||
<th nowrap="nowrap" align="center">Moe Hybrid Plugin</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap">Hybrid Parallel<br />Plugin</td>
|
||||
<td nowrap="nowrap" align="center" title="Lamb">Lamb</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap">Low Level Zero<br />Plugin</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap">Torch DDP<br />Plugin</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap">Gemini<br />Plugin</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap">Moe Hybrid<br />Plugin</td>
|
||||
<td nowrap="nowrap" align="center" title="GaLore">GaLore</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap" align="center" title="Adafactor">Adafactor</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap" align="center" title="CAME">CAME</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
</tr>
|
||||
|
|
|
@ -55,7 +55,7 @@ Model/Feature Compatibility Matrix:
|
|||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
</tr>
|
||||
<tr>
|
||||
|
|
|
@ -84,44 +84,42 @@ optim = DistGaloreAwamW(
|
|||
## 兼容性
|
||||
<table>
|
||||
<tr>
|
||||
<th nowrap="nowrap">Model/Feature</th>
|
||||
<th nowrap="nowrap" align="center" title="Lamb">Lamb</th>
|
||||
<th nowrap="nowrap" align="center" title="GaLore">GaLore</th>
|
||||
<th nowrap="nowrap" align="center" title="Adafactor">Adafactor</th>
|
||||
<th nowrap="nowrap" align="center" title="CAME">CAME</th>
|
||||
<th nowrap="nowrap">Optimizer/Plugin</th>
|
||||
<th nowrap="nowrap" align="center">Hybrid Parallel Plugin</th>
|
||||
<th nowrap="nowrap" align="center">Low Level Zero Plugin</th>
|
||||
<th nowrap="nowrap" align="center">Torch DDP Plugin</th>
|
||||
<th nowrap="nowrap" align="center">Gemini Plugin</th>
|
||||
<th nowrap="nowrap" align="center">Moe Hybrid Plugin</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap">Hybrid Parallel<br />Plugin</td>
|
||||
<td nowrap="nowrap" align="center" title="Lamb">Lamb</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap">Low Level Zero<br />Plugin</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap">Torch DDP<br />Plugin</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap">Gemini<br />Plugin</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap">Moe Hybrid<br />Plugin</td>
|
||||
<td nowrap="nowrap" align="center" title="GaLore">GaLore</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap" align="center" title="Adafactor">Adafactor</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap" align="center" title="CAME">CAME</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
</tr>
|
||||
|
@ -130,6 +128,7 @@ optim = DistGaloreAwamW(
|
|||
</tr>
|
||||
</table>
|
||||
|
||||
|
||||
<!-- doc-test-command: colossalai run --nproc_per_node 4 distributed_optimizers.py -->
|
||||
|
||||
## API 参考
|
||||
|
|
|
@ -51,7 +51,7 @@ Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github.
|
|||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">❌</td>
|
||||
</tr>
|
||||
<tr>
|
||||
|
|
|
@ -52,9 +52,11 @@ class pretraining_dataset(Dataset):
|
|||
|
||||
def __getitem__(self, index):
|
||||
[input_ids, input_mask, segment_ids, masked_lm_labels] = [
|
||||
torch.from_numpy(input[index].astype(np.int64))
|
||||
if indice < 5
|
||||
else torch.from_numpy(np.asarray(input[index].astype(np.int64)))
|
||||
(
|
||||
torch.from_numpy(input[index].astype(np.int64))
|
||||
if indice < 5
|
||||
else torch.from_numpy(np.asarray(input[index].astype(np.int64)))
|
||||
)
|
||||
for indice, input in enumerate(self.inputs)
|
||||
]
|
||||
|
||||
|
|
|
@ -229,9 +229,7 @@ class DDPM(pl.LightningModule):
|
|||
)
|
||||
|
||||
if self.parameterization == "eps":
|
||||
lvlb_weights = self.betas**2 / (
|
||||
2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)
|
||||
)
|
||||
lvlb_weights = self.betas**2 / (2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
|
||||
elif self.parameterization == "x0":
|
||||
lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2.0 * 1 - torch.Tensor(alphas_cumprod))
|
||||
elif self.parameterization == "v":
|
||||
|
@ -1186,9 +1184,11 @@ class LatentDiffusion(DDPM):
|
|||
if cond is not None:
|
||||
if isinstance(cond, dict):
|
||||
cond = {
|
||||
key: cond[key][:batch_size]
|
||||
if not isinstance(cond[key], list)
|
||||
else list(map(lambda x: x[:batch_size], cond[key]))
|
||||
key: (
|
||||
cond[key][:batch_size]
|
||||
if not isinstance(cond[key], list)
|
||||
else list(map(lambda x: x[:batch_size], cond[key]))
|
||||
)
|
||||
for key in cond
|
||||
}
|
||||
else:
|
||||
|
@ -1321,9 +1321,11 @@ class LatentDiffusion(DDPM):
|
|||
if cond is not None:
|
||||
if isinstance(cond, dict):
|
||||
cond = {
|
||||
key: cond[key][:batch_size]
|
||||
if not isinstance(cond[key], list)
|
||||
else list(map(lambda x: x[:batch_size], cond[key]))
|
||||
key: (
|
||||
cond[key][:batch_size]
|
||||
if not isinstance(cond[key], list)
|
||||
else list(map(lambda x: x[:batch_size], cond[key]))
|
||||
)
|
||||
for key in cond
|
||||
}
|
||||
else:
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
"""SAMPLING ONLY."""
|
||||
|
||||
import torch
|
||||
|
||||
from .dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper
|
||||
|
|
|
@ -640,23 +640,25 @@ class UNetModel(nn.Module):
|
|||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
),
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=dim_head,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
)
|
||||
if not use_spatial_transformer
|
||||
else SpatialTransformer( # always uses a self-attn
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=transformer_depth,
|
||||
context_dim=context_dim,
|
||||
disable_self_attn=disable_middle_self_attn,
|
||||
use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint,
|
||||
(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=dim_head,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
)
|
||||
if not use_spatial_transformer
|
||||
else SpatialTransformer( # always uses a self-attn
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=transformer_depth,
|
||||
context_dim=context_dim,
|
||||
disable_self_attn=disable_middle_self_attn,
|
||||
use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint,
|
||||
)
|
||||
),
|
||||
ResBlock(
|
||||
ch,
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
This file contains code that is adapted from
|
||||
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
This file contains code that is adapted from
|
||||
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
"""Utils for monoDepth."""
|
||||
|
||||
import re
|
||||
import sys
|
||||
|
||||
|
|
|
@ -369,9 +369,9 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
|
|||
}
|
||||
|
||||
} // for (auto sent_index=sent_index_first; ...
|
||||
} // if (num_remain_sent > 1) {
|
||||
} // for (int doc=0; doc < num_docs; ++doc) {
|
||||
} // for (int epoch=0; epoch < num_epochs; ++epoch) {
|
||||
} // if (num_remain_sent > 1) {
|
||||
} // for (int doc=0; doc < num_docs; ++doc) {
|
||||
} // for (int epoch=0; epoch < num_epochs; ++epoch) {
|
||||
|
||||
if (!second) {
|
||||
if (verbose) {
|
||||
|
@ -606,9 +606,9 @@ py::array build_blocks_mapping_impl(
|
|||
num_sent = 0;
|
||||
}
|
||||
} // for (auto sent_index=sent_index_first; ...
|
||||
} // if (num_remain_sent > 1) {
|
||||
} // for (int doc=0; doc < num_docs; ++doc) {
|
||||
} // for (int epoch=0; epoch < num_epochs; ++epoch) {
|
||||
} // if (num_remain_sent > 1) {
|
||||
} // for (int doc=0; doc < num_docs; ++doc) {
|
||||
} // for (int epoch=0; epoch < num_epochs; ++epoch) {
|
||||
|
||||
if (!second) {
|
||||
if (verbose) {
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
|
||||
#include <cmath>
|
||||
|
||||
#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
|
||||
#define ROUND_DOWN(size, step) ((size) & ~((step) - 1))
|
||||
#define TILE (128 * 1024 * 1024)
|
||||
|
||||
#if defined(__aarch64__)
|
||||
|
|
|
@ -32,7 +32,7 @@ SOFTWARE
|
|||
#include <x86intrin.h>
|
||||
#endif
|
||||
|
||||
#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
|
||||
#define ROUND_DOWN(size, step) ((size) & ~((step) - 1))
|
||||
#define TILE (128 * 1024 * 1024)
|
||||
|
||||
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
|
||||
|
|
|
@ -34,14 +34,14 @@ def swin_s():
|
|||
|
||||
|
||||
# special output transform fn
|
||||
google_net_output_transform_fn = (
|
||||
lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.GoogLeNetOutputs) else dict(output=x)
|
||||
google_net_output_transform_fn = lambda x: (
|
||||
dict(output=sum(x)) if isinstance(x, torchvision.models.GoogLeNetOutputs) else dict(output=x)
|
||||
)
|
||||
swin_s_output_output_transform_fn = (
|
||||
lambda x: {f"output{idx}": val for idx, val in enumerate(x)} if isinstance(x, tuple) else dict(output=x)
|
||||
swin_s_output_output_transform_fn = lambda x: (
|
||||
{f"output{idx}": val for idx, val in enumerate(x)} if isinstance(x, tuple) else dict(output=x)
|
||||
)
|
||||
inception_v3_output_transform_fn = (
|
||||
lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.InceptionOutputs) else dict(output=x)
|
||||
inception_v3_output_transform_fn = lambda x: (
|
||||
dict(output=sum(x)) if isinstance(x, torchvision.models.InceptionOutputs) else dict(output=x)
|
||||
)
|
||||
|
||||
model_zoo.register(
|
||||
|
|
Loading…
Reference in New Issue