diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 9871e1184..9088d0e1b 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,34 +1,34 @@
 repos:
   - repo: https://github.com/PyCQA/autoflake
-    rev: v2.2.1
+    rev: v2.3.1
     hooks:
       - id: autoflake
         name: autoflake (python)
         args: ['--in-place', '--remove-unused-variables', '--remove-all-unused-imports', '--ignore-init-module-imports']
 
   - repo: https://github.com/pycqa/isort
-    rev: 5.12.0
+    rev: 5.13.2
     hooks:
       - id: isort
         name: sort all imports (python)
 
   - repo: https://github.com/psf/black-pre-commit-mirror
-    rev: 23.9.1
+    rev: 24.4.2
     hooks:
       - id: black
         name: black formatter
         args: ['--line-length=120', '--target-version=py37', '--target-version=py38', '--target-version=py39','--target-version=py310']
 
   - repo: https://github.com/pre-commit/mirrors-clang-format
-    rev: v13.0.1
+    rev: v18.1.8
     hooks:
       - id: clang-format
         name: clang formatter
         types_or: [c++, c]
 
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.3.0
+    rev: v4.6.0
     hooks:
       - id: check-yaml
       - id: check-merge-conflict
diff --git a/applications/ColossalChat/coati/dataset/loader.py b/applications/ColossalChat/coati/dataset/loader.py
index cea1b2dbb..a0cd17bb4 100755
--- a/applications/ColossalChat/coati/dataset/loader.py
+++ b/applications/ColossalChat/coati/dataset/loader.py
@@ -83,15 +83,19 @@ class DataCollatorForSupervisedDataset(object):
         # `List[torch.Tensor]`
         batch_input_ids = [
-            torch.LongTensor(instance["input_ids"][: self.max_length])
-            if len(instance["input_ids"]) > self.max_length
-            else torch.LongTensor(instance["input_ids"])
+            (
+                torch.LongTensor(instance["input_ids"][: self.max_length])
+                if len(instance["input_ids"]) > self.max_length
+                else torch.LongTensor(instance["input_ids"])
+            )
             for instance in instances
         ]
         batch_labels = [
-            torch.LongTensor(instance["labels"][: self.max_length])
-            if len(instance["labels"]) > self.max_length
-            else torch.LongTensor(instance["labels"])
+            (
+                torch.LongTensor(instance["labels"][: self.max_length])
+                if len(instance["labels"]) > self.max_length
+                else torch.LongTensor(instance["labels"])
+            )
             for instance in instances
         ]
         if self.tokenizer.padding_side == "right":
diff --git a/applications/ColossalChat/coati/models/loss.py b/applications/ColossalChat/coati/models/loss.py
index aaef447a4..e411dded1 100755
--- a/applications/ColossalChat/coati/models/loss.py
+++ b/applications/ColossalChat/coati/models/loss.py
@@ -1,6 +1,7 @@
 """
 loss functions
 """
+
 from typing import Optional, Tuple
 
 import torch
diff --git a/applications/ColossalChat/coati/models/reward_model.py b/applications/ColossalChat/coati/models/reward_model.py
index 394f3ea90..573b9d889 100755
--- a/applications/ColossalChat/coati/models/reward_model.py
+++ b/applications/ColossalChat/coati/models/reward_model.py
@@ -1,6 +1,7 @@
 """
 reward model
 """
+
 from typing import Optional
 
 import torch
diff --git a/applications/ColossalChat/coati/trainer/utils.py b/applications/ColossalChat/coati/trainer/utils.py
index 5ce1e9ef0..3c836b4b4 100755
--- a/applications/ColossalChat/coati/trainer/utils.py
+++ b/applications/ColossalChat/coati/trainer/utils.py
@@ -1,6 +1,7 @@
 """
 Training utilities for Coati.
""" + from typing import Any import torch diff --git a/applications/ColossalEval/colossal_eval/dataset/agieval.py b/applications/ColossalEval/colossal_eval/dataset/agieval.py index 32f8544e9..d5f230249 100644 --- a/applications/ColossalEval/colossal_eval/dataset/agieval.py +++ b/applications/ColossalEval/colossal_eval/dataset/agieval.py @@ -78,7 +78,9 @@ def get_prompt(line: Dict, dataset_name: str, logger: DistributedLogger) -> Dict option_string = "ABCDEFG" count = len(line["options"]) - input = "问题:" + line["question"] + " " + "从以下选项中选择:" + " ".join(line["options"]) + "\n" + "答案:" + input = ( + "问题:" + line["question"] + " " + "从以下选项中选择:" + " ".join(line["options"]) + "\n" + "答案:" + ) all_classes = list(option_string[0:count]) @@ -150,7 +152,15 @@ def combine_prompt(prompt_path, dataset_name, load_explanation=True, chat_mode=F ) elif dataset_name in chinese_qa_datasets: question_input = ( - "问题:" + passage + " " + question + "\n" + "从以下选项中选择:" + " ".join(options) + "\n" + "答案:{}".format(label) + "问题:" + + passage + + " " + + question + + "\n" + + "从以下选项中选择:" + + " ".join(options) + + "\n" + + "答案:{}".format(label) ) elif dataset_name in english_cloze_datasets: question_input = "Question: ".format(idx + 1) + question + "\n" + "Answer: {}".format(answer) diff --git a/applications/ColossalEval/colossal_eval/dataset/ceval.py b/applications/ColossalEval/colossal_eval/dataset/ceval.py index 2cf09ec4d..915f4d9b0 100644 --- a/applications/ColossalEval/colossal_eval/dataset/ceval.py +++ b/applications/ColossalEval/colossal_eval/dataset/ceval.py @@ -57,7 +57,11 @@ ceval_subject_mapping = { "urban_and_rural_planner": ["Urban and Rural Planner", "注册城乡规划师", "Other"], "accountant": ["Accountant", "注册会计师", "Other"], "fire_engineer": ["Fire Engineer", "注册消防工程师", "Other"], - "environmental_impact_assessment_engineer": ["Environmental Impact Assessment Engineer", "环境影响评价工程师", "Other"], + "environmental_impact_assessment_engineer": [ + "Environmental Impact Assessment Engineer", + "环境影响评价工程师", + "Other", + ], "tax_accountant": ["Tax Accountant", "税务师", "Other"], "physician": ["Physician", "医师资格", "Other"], } diff --git a/applications/ColossalEval/colossal_eval/dataset/mtbench.py b/applications/ColossalEval/colossal_eval/dataset/mtbench.py index 9e74a4d82..031415567 100644 --- a/applications/ColossalEval/colossal_eval/dataset/mtbench.py +++ b/applications/ColossalEval/colossal_eval/dataset/mtbench.py @@ -56,9 +56,11 @@ class MTBenchDataset(BaseDataset): "instruction": question["turns"], "input": "", "output": [], - "target": [""] * turn_number - if question["question_id"] not in reference - else reference[question["question_id"]], + "target": ( + [""] * turn_number + if question["question_id"] not in reference + else reference[question["question_id"]] + ), } if category in dataset["test"]: diff --git a/applications/ColossalEval/colossal_eval/models/huggingface.py b/applications/ColossalEval/colossal_eval/models/huggingface.py index fff697e21..23c399cce 100644 --- a/applications/ColossalEval/colossal_eval/models/huggingface.py +++ b/applications/ColossalEval/colossal_eval/models/huggingface.py @@ -77,7 +77,9 @@ class HuggingFaceModel(BaseModel): self.indices_for_choices[0].append( self.tokenizer(f"Answer: {choice}", add_special_tokens=False).input_ids[-1] ) - self.indices_for_choices[1].append(self.tokenizer(f"答案:{choice}", add_special_tokens=False).input_ids[-1]) + self.indices_for_choices[1].append( + self.tokenizer(f"答案:{choice}", add_special_tokens=False).input_ids[-1] + ) def _load_tokenizer(self, 
path: str, tokenizer_path: Optional[str], tokenizer_kwargs: dict): """ diff --git a/applications/ColossalQA/colossalqa/chain/retrieval_qa/base.py b/applications/ColossalQA/colossalqa/chain/retrieval_qa/base.py index 80dbf47de..2f9750de3 100644 --- a/applications/ColossalQA/colossalqa/chain/retrieval_qa/base.py +++ b/applications/ColossalQA/colossalqa/chain/retrieval_qa/base.py @@ -7,6 +7,7 @@ This code is based on LangChain Ai's langchain, which can be found at https://github.com/langchain-ai/langchain The original code is licensed under the MIT license. """ + from __future__ import annotations import copy diff --git a/applications/ColossalQA/colossalqa/chain/retrieval_qa/load_chain.py b/applications/ColossalQA/colossalqa/chain/retrieval_qa/load_chain.py index a2b1f81e3..8cb8ef536 100644 --- a/applications/ColossalQA/colossalqa/chain/retrieval_qa/load_chain.py +++ b/applications/ColossalQA/colossalqa/chain/retrieval_qa/load_chain.py @@ -8,6 +8,7 @@ This code is based on LangChain Ai's langchain, which can be found at https://github.com/langchain-ai/langchain The original code is licensed under the MIT license. """ + import copy from typing import Any, Mapping, Optional, Protocol diff --git a/applications/ColossalQA/colossalqa/chain/retrieval_qa/stuff.py b/applications/ColossalQA/colossalqa/chain/retrieval_qa/stuff.py index bf7ad0ffc..64e476438 100644 --- a/applications/ColossalQA/colossalqa/chain/retrieval_qa/stuff.py +++ b/applications/ColossalQA/colossalqa/chain/retrieval_qa/stuff.py @@ -7,6 +7,7 @@ This code is based on LangChain Ai's langchain, which can be found at https://github.com/langchain-ai/langchain The original code is licensed under the MIT license. """ + import copy from typing import Any, List diff --git a/applications/ColossalQA/colossalqa/data_loader/table_dataloader.py b/applications/ColossalQA/colossalqa/data_loader/table_dataloader.py index 29542466f..0ad66f0ad 100644 --- a/applications/ColossalQA/colossalqa/data_loader/table_dataloader.py +++ b/applications/ColossalQA/colossalqa/data_loader/table_dataloader.py @@ -2,7 +2,6 @@ Class for loading table type data. please refer to Pandas-Input/Output for file format details. 
""" - import glob import os diff --git a/applications/ColossalQA/colossalqa/local/llm.py b/applications/ColossalQA/colossalqa/local/llm.py index 30a456c3d..58a4811d9 100644 --- a/applications/ColossalQA/colossalqa/local/llm.py +++ b/applications/ColossalQA/colossalqa/local/llm.py @@ -12,6 +12,7 @@ TEST_PROMPT_CHATGLM="续写文章:惊蛰一过,春寒加剧。先是料料 logger.info(llm(TEST_PROMPT_CHATGLM, max_new_tokens=100), verbose=True) """ + from typing import Any, List, Mapping, Optional import torch diff --git a/applications/ColossalQA/colossalqa/local/utils.py b/applications/ColossalQA/colossalqa/local/utils.py index ed90264ca..2cbd474bd 100644 --- a/applications/ColossalQA/colossalqa/local/utils.py +++ b/applications/ColossalQA/colossalqa/local/utils.py @@ -1,6 +1,7 @@ """ Generation utilities """ + import json from typing import List diff --git a/applications/ColossalQA/colossalqa/memory.py b/applications/ColossalQA/colossalqa/memory.py index 7a5512281..d8de544a5 100644 --- a/applications/ColossalQA/colossalqa/memory.py +++ b/applications/ColossalQA/colossalqa/memory.py @@ -2,6 +2,7 @@ Implement a memory class for storing conversation history Support long term and short term memory """ + from typing import Any, Dict, List from colossalqa.chain.memory.summary import ConversationSummaryMemory diff --git a/applications/ColossalQA/colossalqa/mylogging.py b/applications/ColossalQA/colossalqa/mylogging.py index 574c33b41..67e2a83ed 100644 --- a/applications/ColossalQA/colossalqa/mylogging.py +++ b/applications/ColossalQA/colossalqa/mylogging.py @@ -1,6 +1,7 @@ """ Class for logging with extra control for debugging """ + import logging diff --git a/applications/ColossalQA/colossalqa/retrieval_conversation_en.py b/applications/ColossalQA/colossalqa/retrieval_conversation_en.py index 96bce82b9..cab168075 100644 --- a/applications/ColossalQA/colossalqa/retrieval_conversation_en.py +++ b/applications/ColossalQA/colossalqa/retrieval_conversation_en.py @@ -1,6 +1,7 @@ """ Script for Chinese retrieval based conversation system backed by ChatGLM """ + from typing import Tuple from colossalqa.chain.retrieval_qa.base import RetrievalQA diff --git a/applications/ColossalQA/colossalqa/retrieval_conversation_universal.py b/applications/ColossalQA/colossalqa/retrieval_conversation_universal.py index 6e77bb2ae..a991b202e 100644 --- a/applications/ColossalQA/colossalqa/retrieval_conversation_universal.py +++ b/applications/ColossalQA/colossalqa/retrieval_conversation_universal.py @@ -1,6 +1,7 @@ """ Multilingual retrieval based conversation system """ + from typing import List from colossalqa.data_loader.document_loader import DocumentLoader diff --git a/applications/ColossalQA/colossalqa/retrieval_conversation_zh.py b/applications/ColossalQA/colossalqa/retrieval_conversation_zh.py index 4eef41947..6c9b69117 100644 --- a/applications/ColossalQA/colossalqa/retrieval_conversation_zh.py +++ b/applications/ColossalQA/colossalqa/retrieval_conversation_zh.py @@ -1,6 +1,7 @@ """ Script for Chinese retrieval based conversation system backed by ChatGLM """ + from typing import Tuple from colossalqa.chain.retrieval_qa.base import RetrievalQA diff --git a/applications/ColossalQA/colossalqa/retriever.py b/applications/ColossalQA/colossalqa/retriever.py index 6a0c69859..ec4941ddd 100644 --- a/applications/ColossalQA/colossalqa/retriever.py +++ b/applications/ColossalQA/colossalqa/retriever.py @@ -1,6 +1,7 @@ """ Code for custom retriver with incremental update """ + import copy import hashlib import os diff --git 
a/applications/ColossalQA/colossalqa/text_splitter/chinese_text_splitter.py b/applications/ColossalQA/colossalqa/text_splitter/chinese_text_splitter.py index 3815f5ed2..697af484b 100644 --- a/applications/ColossalQA/colossalqa/text_splitter/chinese_text_splitter.py +++ b/applications/ColossalQA/colossalqa/text_splitter/chinese_text_splitter.py @@ -1,6 +1,7 @@ """ Code for Chinese text splitter """ + from typing import Any, List, Optional from colossalqa.text_splitter.utils import get_cleaned_paragraph diff --git a/applications/ColossalQA/examples/retrieval_conversation_en.py b/applications/ColossalQA/examples/retrieval_conversation_en.py index fe2b9b4db..b7339de93 100644 --- a/applications/ColossalQA/examples/retrieval_conversation_en.py +++ b/applications/ColossalQA/examples/retrieval_conversation_en.py @@ -1,6 +1,7 @@ """ Script for English retrieval based conversation system backed by LLaMa2 """ + import argparse import os diff --git a/applications/ColossalQA/examples/retrieval_conversation_en_customer_service.py b/applications/ColossalQA/examples/retrieval_conversation_en_customer_service.py index d4ba73b94..a0c90e7c5 100644 --- a/applications/ColossalQA/examples/retrieval_conversation_en_customer_service.py +++ b/applications/ColossalQA/examples/retrieval_conversation_en_customer_service.py @@ -1,6 +1,7 @@ """ Script for English retrieval based conversation system backed by LLaMa2 """ + import argparse import json import os diff --git a/applications/ColossalQA/examples/retrieval_conversation_zh.py b/applications/ColossalQA/examples/retrieval_conversation_zh.py index b143b9baa..96641edf5 100644 --- a/applications/ColossalQA/examples/retrieval_conversation_zh.py +++ b/applications/ColossalQA/examples/retrieval_conversation_zh.py @@ -1,6 +1,7 @@ """ Script for Chinese retrieval based conversation system backed by ChatGLM """ + import argparse import os diff --git a/applications/ColossalQA/examples/retrieval_intent_classification_zh_customer_service.py b/applications/ColossalQA/examples/retrieval_intent_classification_zh_customer_service.py index adb654494..865ade5bb 100644 --- a/applications/ColossalQA/examples/retrieval_intent_classification_zh_customer_service.py +++ b/applications/ColossalQA/examples/retrieval_intent_classification_zh_customer_service.py @@ -1,6 +1,7 @@ """ Script for English retrieval based conversation system backed by LLaMa2 """ + import argparse import os diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py b/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py index 2f630995c..b1e32e885 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py @@ -107,20 +107,22 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward fwd_memory_cost = MemoryCost( activation=compute_size_in_bytes([input_tensor, output_tensor]), - parameter=compute_size_in_bytes([weight_tensor, bias_tensor]) - if has_bias - else compute_size_in_bytes(weight_tensor), + parameter=( + compute_size_in_bytes([weight_tensor, bias_tensor]) if has_bias else compute_size_in_bytes(weight_tensor) + ), temp=0, buffer=0, ) bwd_memory_cost = MemoryCost( - activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]) - if has_bias - else compute_size_in_bytes([input_tensor, weight_tensor]), - parameter=compute_size_in_bytes([weight_tensor, bias_tensor]) - if 
has_bias - else compute_size_in_bytes(weight_tensor), + activation=( + compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]) + if has_bias + else compute_size_in_bytes([input_tensor, weight_tensor]) + ), + parameter=( + compute_size_in_bytes([weight_tensor, bias_tensor]) if has_bias else compute_size_in_bytes(weight_tensor) + ), temp=0, buffer=0, ) diff --git a/colossalai/booster/plugin/__init__.py b/colossalai/booster/plugin/__init__.py index 62f3708fc..7e0e6ffdd 100644 --- a/colossalai/booster/plugin/__init__.py +++ b/colossalai/booster/plugin/__init__.py @@ -1,10 +1,18 @@ from .gemini_plugin import GeminiPlugin from .hybrid_parallel_plugin import HybridParallelPlugin from .low_level_zero_plugin import LowLevelZeroPlugin +from .moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from .plugin_base import Plugin from .torch_ddp_plugin import TorchDDPPlugin -__all__ = ["Plugin", "TorchDDPPlugin", "GeminiPlugin", "LowLevelZeroPlugin", "HybridParallelPlugin"] +__all__ = [ + "Plugin", + "TorchDDPPlugin", + "GeminiPlugin", + "LowLevelZeroPlugin", + "HybridParallelPlugin", + "MoeHybridParallelPlugin", +] import torch from packaging import version diff --git a/colossalai/inference/batch_bucket.py b/colossalai/inference/batch_bucket.py index 88bde3a3b..581d114d2 100644 --- a/colossalai/inference/batch_bucket.py +++ b/colossalai/inference/batch_bucket.py @@ -247,16 +247,16 @@ class BatchBucket: self._sequences_dict[seq.request_id] = seq self._sequences_indexes[seq.request_id] = self._current_batch_size + i # TODO external (rename): modify Sequence.sentence_len to seq_len - self._sequence_lengths[ - self._current_batch_size : self._current_batch_size + num_seqs_to_add - ] = torch.tensor([seq.sentence_len for seq in seqs[:num_seqs_to_add]], dtype=torch.int32) + self._sequence_lengths[self._current_batch_size : self._current_batch_size + num_seqs_to_add] = ( + torch.tensor([seq.sentence_len for seq in seqs[:num_seqs_to_add]], dtype=torch.int32) + ) # NOTE block tables to be updated by kvcache manager block_tables = self._block_tables[self._current_batch_size : self._current_batch_size + num_seqs_to_add] if alloc_block_tables is not None: # copy block ids from provided block tables - self._block_tables[ - self._current_batch_size : self._current_batch_size + num_seqs_to_add - ] = alloc_block_tables + self._block_tables[self._current_batch_size : self._current_batch_size + num_seqs_to_add] = ( + alloc_block_tables + ) elif alloc_block_tables_fn: alloc_block_tables_fn( block_tables, diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py index c73ee9df4..e114e8a61 100644 --- a/colossalai/inference/config.py +++ b/colossalai/inference/config.py @@ -1,6 +1,7 @@ """ Our config contains various options for inference optimization, it is a unified API that wraps all the configurations for inference. 
""" + import logging from abc import ABC, abstractmethod from dataclasses import dataclass, fields @@ -82,9 +83,9 @@ class InputMetaData(RPC_PARAM): dtype: torch.dtype = torch.float32 use_spec_dec: bool = False num_tokens_to_verify: int = 0 - batch_token_ids: Optional[ - List[List[int]] - ] = None # for `repetition_penalty`, `no_repeat_ngram_size` in sampler process + batch_token_ids: Optional[List[List[int]]] = ( + None # for `repetition_penalty`, `no_repeat_ngram_size` in sampler process + ) def to_rpc_param(self) -> Dict[str, any]: return { @@ -202,9 +203,9 @@ class InferenceConfig(RPC_PARAM): prompt_template: Optional[str] = None do_sample: bool = False beam_width: int = 1 # TODO: beam search is not support for now - prefill_ratio: Optional[ - float - ] = 1.2 # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio + prefill_ratio: Optional[float] = ( + 1.2 # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio + ) pad_input: bool = False early_stopping: Optional[bool] = False top_k: Optional[int] = 50 @@ -234,7 +235,9 @@ class InferenceConfig(RPC_PARAM): high_precision: Optional[bool] = False # cuda_graph - use_cuda_graph: bool = False # NOTE only when we have the graph for specific decoding batch size can we use the cuda graph for inference + use_cuda_graph: bool = ( + False # NOTE only when we have the graph for specific decoding batch size can we use the cuda graph for inference + ) max_context_len_to_capture: int = 512 # StreamingLLM (sliding window attention with attention sinks) diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py index f0918c88c..8f8aef65e 100644 --- a/colossalai/inference/core/engine.py +++ b/colossalai/inference/core/engine.py @@ -47,7 +47,6 @@ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)] class InferenceEngine: - """ InferenceEngine which manages the inference process.. diff --git a/colossalai/inference/core/rpc_engine.py b/colossalai/inference/core/rpc_engine.py index 87222a744..749360872 100644 --- a/colossalai/inference/core/rpc_engine.py +++ b/colossalai/inference/core/rpc_engine.py @@ -34,7 +34,6 @@ def run_server(host, port, event: mp.Event = None): class RPCInferenceEngine(InferenceEngine): - """ InferenceEngine which manages the inference process.. 
diff --git a/colossalai/inference/executor/rpc_worker.py b/colossalai/inference/executor/rpc_worker.py index a5199cb74..a4fd20a69 100644 --- a/colossalai/inference/executor/rpc_worker.py +++ b/colossalai/inference/executor/rpc_worker.py @@ -42,7 +42,6 @@ logger = get_dist_logger(__name__) class rpcWorkerService(rpyc.Service): - """ Execute the computation tasks and manage its own kv cache diff --git a/colossalai/inference/kv_cache/kvcache_manager.py b/colossalai/inference/kv_cache/kvcache_manager.py index 378eb2ff9..dac5037f4 100644 --- a/colossalai/inference/kv_cache/kvcache_manager.py +++ b/colossalai/inference/kv_cache/kvcache_manager.py @@ -279,9 +279,11 @@ class KVCacheManager: block.add_ref() self._allocate_on_block( block, - block.block_size - if context_lengths[i] % block.block_size == 0 - else context_lengths[i].item() % block.block_size, + ( + block.block_size + if context_lengths[i] % block.block_size == 0 + else context_lengths[i].item() % block.block_size + ), ) for block_id in alloc_block_ids: if block_id in alloc_block_ids[last_block_locs]: diff --git a/colossalai/inference/utils.py b/colossalai/inference/utils.py index 8c155e6ca..332e84d37 100644 --- a/colossalai/inference/utils.py +++ b/colossalai/inference/utils.py @@ -1,6 +1,7 @@ """ Utils for model inference """ + import math import os import re diff --git a/colossalai/legacy/context/process_group_initializer/initializer_2d.py b/colossalai/legacy/context/process_group_initializer/initializer_2d.py index 1c08d4d42..fc51844b6 100644 --- a/colossalai/legacy/context/process_group_initializer/initializer_2d.py +++ b/colossalai/legacy/context/process_group_initializer/initializer_2d.py @@ -138,9 +138,7 @@ class Initializer_2D(ProcessGroupInitializer): self.num_group = self.world_size // self.tensor_parallel_size self.summa_dim = int(math.sqrt(self.tensor_parallel_size)) - assert ( - self.tensor_parallel_size == self.summa_dim**2 - ), "2D summa dim should equal to tensor parallel size ^ 0.5" + assert self.tensor_parallel_size == self.summa_dim**2, "2D summa dim should equal to tensor parallel size ^ 0.5" _check_summa_env_var(self.summa_dim) self.col_initializer = Initializer_2D_Col(self.num_group, self.summa_dim, *args, **kwargs) diff --git a/colossalai/legacy/inference/async_engine.py b/colossalai/legacy/inference/async_engine.py index d0890ba3e..b4c523669 100644 --- a/colossalai/legacy/inference/async_engine.py +++ b/colossalai/legacy/inference/async_engine.py @@ -54,7 +54,6 @@ class RequestTracker: class Async_Engine: - """ Use an engine to launch RAY Driver --> RAY Worker --> Async_Manager Background loop: inference reqs in waiting list (Listen) diff --git a/colossalai/legacy/inference/dynamic_batching/io_struct.py b/colossalai/legacy/inference/dynamic_batching/io_struct.py index fc5ecfe57..abc41cc8e 100644 --- a/colossalai/legacy/inference/dynamic_batching/io_struct.py +++ b/colossalai/legacy/inference/dynamic_batching/io_struct.py @@ -118,16 +118,16 @@ class Batch: class BatchTokenIdOut: def __init__(self): - self.reqs_infs: List[ - Tuple[str, int, Dict, bool, bool] - ] = [] # [req_id, new_token_id, gen_metadata, finished_state, abort_state] + self.reqs_infs: List[Tuple[str, int, Dict, bool, bool]] = ( + [] + ) # [req_id, new_token_id, gen_metadata, finished_state, abort_state] class BatchStrOut: def __init__(self): - self.reqs_infs: List[ - Tuple[str, str, Dict, bool, bool] - ] = [] # [req_id, token_str, gen_metadata, finished_state, abort_state] + self.reqs_infs: List[Tuple[str, str, Dict, bool, bool]] = ( + [] + ) # 
[req_id, token_str, gen_metadata, finished_state, abort_state] class AbortReq: diff --git a/colossalai/legacy/inference/hybridengine/modeling/_utils.py b/colossalai/legacy/inference/hybridengine/modeling/_utils.py index 068b64b4f..46d4222c4 100644 --- a/colossalai/legacy/inference/hybridengine/modeling/_utils.py +++ b/colossalai/legacy/inference/hybridengine/modeling/_utils.py @@ -1,6 +1,7 @@ """ Utils for model inference """ + import os import torch diff --git a/colossalai/legacy/inference/tensor_parallel/batch_infer_state.py b/colossalai/legacy/inference/tensor_parallel/batch_infer_state.py index f707a86df..b72610899 100644 --- a/colossalai/legacy/inference/tensor_parallel/batch_infer_state.py +++ b/colossalai/legacy/inference/tensor_parallel/batch_infer_state.py @@ -14,6 +14,7 @@ class BatchInferState: Information to be passed and used for a batch of inputs during a single model forward """ + batch_size: int max_len_in_batch: int diff --git a/colossalai/legacy/inference/tensor_parallel/kvcache_manager.py b/colossalai/legacy/inference/tensor_parallel/kvcache_manager.py index 91bb96a1f..8c54fda26 100644 --- a/colossalai/legacy/inference/tensor_parallel/kvcache_manager.py +++ b/colossalai/legacy/inference/tensor_parallel/kvcache_manager.py @@ -4,6 +4,7 @@ of the ModelTC/lightllm GitHub repository https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py we slightly changed it to make it suitable for our colossal-ai shardformer TP-engine design. """ + import torch from transformers.utils import logging diff --git a/colossalai/legacy/inference/tensor_parallel/modeling/_utils.py b/colossalai/legacy/inference/tensor_parallel/modeling/_utils.py index 068b64b4f..46d4222c4 100644 --- a/colossalai/legacy/inference/tensor_parallel/modeling/_utils.py +++ b/colossalai/legacy/inference/tensor_parallel/modeling/_utils.py @@ -1,6 +1,7 @@ """ Utils for model inference """ + import os import torch diff --git a/colossalai/quantization/bnb.py b/colossalai/quantization/bnb.py index fa214116a..3601ef62b 100644 --- a/colossalai/quantization/bnb.py +++ b/colossalai/quantization/bnb.py @@ -1,17 +1,25 @@ # adapted from Hugging Face accelerate/utils/bnb.py accelerate/utils/modeling.py +import importlib.metadata import logging import torch import torch.nn as nn +from packaging.version import Version from .bnb_config import BnbQuantizationConfig try: import bitsandbytes as bnb - IS_4BIT_BNB_AVAILABLE = bnb.__version__ >= "0.39.0" - IS_8BIT_BNB_AVAILABLE = bnb.__version__ >= "0.37.2" + try: + # in case lower version of bitsandbytes does not have __version__ attribute + BNB_VERSION = Version(bnb.__version__) + except AttributeError: + BNB_VERSION = Version(importlib.metadata.version("bitsandbytes")) + + IS_4BIT_BNB_AVAILABLE = BNB_VERSION >= Version("0.39.0") + IS_8BIT_BNB_AVAILABLE = BNB_VERSION >= Version("0.37.2") except ImportError: pass diff --git a/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py index bf581300a..6ae4b06e5 100644 --- a/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py +++ b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py @@ -33,6 +33,7 @@ This license shall be governed and construed in accordance with the laws of Peop Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at glm-130b@googlegroups.com. """ + """ PyTorch ChatGLM model. 
""" import copy diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index f10860fef..b250b4976 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -221,7 +221,7 @@ class OPTPipelineForwards: past_key_value = past_key_values[idx] if past_key_values is not None else None if decoder.gradient_checkpointing and decoder.training: - layer_outputs = self._gradient_checkpointing_func( + layer_outputs = self.decoder._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, causal_attention_mask, diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py index e0aa5fba4..11c26822f 100644 --- a/colossalai/shardformer/modeling/qwen2.py +++ b/colossalai/shardformer/modeling/qwen2.py @@ -168,13 +168,27 @@ class Qwen2PipelineForwards: next_decoder_cache = None start_idx, end_idx = stage_index[0], stage_index[1] + num_ckpt_layers = 0 + if self.gradient_checkpointing and self.training: + num_ckpt_layers = end_idx - start_idx + # TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer + if shard_config.gradient_checkpoint_config is not None: + num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers( + stage=stage_manager.stage, + num_stages=stage_manager.num_stages, + num_layers=end_idx - start_idx, + model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0), + num_model_chunks=stage_manager.num_model_chunks, + ) + assert num_ckpt_layers <= end_idx - start_idx + for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): if output_hidden_states: all_hidden_states += (hidden_states,) past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: + if idx - start_idx < num_ckpt_layers: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, @@ -198,7 +212,6 @@ class Qwen2PipelineForwards: if use_cache: next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - if output_attentions: all_self_attns += (layer_outputs[1],) diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index f9721c79e..0fb858d78 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -40,21 +40,19 @@ class MixtralPolicy(Policy): if self.shard_config.enable_tensor_parallelism: raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.") - if getattr(self.shard_config, "ep_group", None) is None: - raise ValueError("You must pass in ep_group via shard_config for expert parallel!") - - # expert parallel - self.append_or_create_submodule_replacement( - description=[ - SubModuleReplacementDescription( - suffix="block_sparse_moe", - target_module=EPMixtralSparseMoeBlock, - kwargs={"ep_group": self.shard_config.ep_group}, - ) - ], - policy=policy, - target_key=MixtralDecoderLayer, - ) + if getattr(self.shard_config, "ep_group", None) is not None: + # expert parallel + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="block_sparse_moe", + target_module=EPMixtralSparseMoeBlock, + kwargs={"ep_group": self.shard_config.ep_group}, + ) + ], + policy=policy, + target_key=MixtralDecoderLayer, + ) # optimization configuration if self.shard_config.enable_fused_normalization: diff --git 
a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index e06cf0581..bdc91b51f 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -549,6 +549,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 working_param = real_working_params[group_id][idx]
                 param_to_gather = master_param.to(device).to(self._dtype)
                 pg = self.param_to_pg[working_param]
+                if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size:
+                    buffer_tensor = torch.empty_like(
+                        torch.cat([param_to_gather for _ in range(dist.get_world_size(pg))])
+                    )
+                    dist.all_gather_into_tensor(buffer_tensor, param_to_gather, pg)
+                    working_param.data.copy_(buffer_tensor[: working_param.numel()].reshape_as(working_param))
+                    continue
                 try:
                     self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
                 except RuntimeError:
diff --git a/docs/source/en/features/distributed_optimizers.md b/docs/source/en/features/distributed_optimizers.md
index f95b23304..279bc8f9d 100644
--- a/docs/source/en/features/distributed_optimizers.md
+++ b/docs/source/en/features/distributed_optimizers.md
@@ -87,44 +87,42 @@ optim = DistGaloreAwamW(
 ## Plugin compatibility
[table diff rendered to plain text; per-cell +/- markers unrecoverable. Recoverable change: the plugin compatibility matrix is transposed, from plugin rows against optimizer columns (Lamb, GaLore, Adafactor, CAME) to optimizer rows against plugin columns (Hybrid Parallel, Low Level Zero, Torch DDP, Gemini, Moe Hybrid).]
diff --git a/docs/source/en/features/shardformer.md b/docs/source/en/features/shardformer.md
index 68d310f5c..40b8954b5 100644
--- a/docs/source/en/features/shardformer.md
+++ b/docs/source/en/features/shardformer.md
@@ -55,7 +55,7 @@ Model/Feature Compatibility Matrix:
[one-cell change in the compatibility matrix; the removed and added cell contents were lost in extraction]
diff --git a/docs/source/zh-Hans/features/distributed_optimizers.md b/docs/source/zh-Hans/features/distributed_optimizers.md
index 7a7068077..5761f8c55 100644
--- a/docs/source/zh-Hans/features/distributed_optimizers.md
+++ b/docs/source/zh-Hans/features/distributed_optimizers.md
@@ -84,44 +84,42 @@ optim = DistGaloreAwamW(
 ## 兼容性
[table diff rendered to plain text; per-cell +/- markers unrecoverable. Recoverable change: the zh-Hans 兼容性 matrix receives the same transposition as the English table above, from plugin rows against optimizer columns (Lamb, GaLore, Adafactor, CAME) to optimizer rows against plugin columns (Hybrid Parallel, Low Level Zero, Torch DDP, Gemini, Moe Hybrid).]
@@ -130,6 +128,7 @@ optim = DistGaloreAwamW(
+ ## API 参考 diff --git a/docs/source/zh-Hans/features/shardformer.md b/docs/source/zh-Hans/features/shardformer.md index 00e1a13d6..02290f3d6 100644 --- a/docs/source/zh-Hans/features/shardformer.md +++ b/docs/source/zh-Hans/features/shardformer.md @@ -51,7 +51,7 @@ Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github. ✔️ ✔️ ✔️ - ❌ + ✔️ ❌ diff --git a/examples/community/roberta/pretraining/nvidia_bert_dataset_provider.py b/examples/community/roberta/pretraining/nvidia_bert_dataset_provider.py index 09677a619..4d08d9941 100644 --- a/examples/community/roberta/pretraining/nvidia_bert_dataset_provider.py +++ b/examples/community/roberta/pretraining/nvidia_bert_dataset_provider.py @@ -52,9 +52,11 @@ class pretraining_dataset(Dataset): def __getitem__(self, index): [input_ids, input_mask, segment_ids, masked_lm_labels] = [ - torch.from_numpy(input[index].astype(np.int64)) - if indice < 5 - else torch.from_numpy(np.asarray(input[index].astype(np.int64))) + ( + torch.from_numpy(input[index].astype(np.int64)) + if indice < 5 + else torch.from_numpy(np.asarray(input[index].astype(np.int64))) + ) for indice, input in enumerate(self.inputs) ] diff --git a/examples/images/diffusion/ldm/models/diffusion/ddpm.py b/examples/images/diffusion/ldm/models/diffusion/ddpm.py index 20e26256e..3cf6aceb5 100644 --- a/examples/images/diffusion/ldm/models/diffusion/ddpm.py +++ b/examples/images/diffusion/ldm/models/diffusion/ddpm.py @@ -229,9 +229,7 @@ class DDPM(pl.LightningModule): ) if self.parameterization == "eps": - lvlb_weights = self.betas**2 / ( - 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod) - ) + lvlb_weights = self.betas**2 / (2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) elif self.parameterization == "x0": lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2.0 * 1 - torch.Tensor(alphas_cumprod)) elif self.parameterization == "v": @@ -1186,9 +1184,11 @@ class LatentDiffusion(DDPM): if cond is not None: if isinstance(cond, dict): cond = { - key: cond[key][:batch_size] - if not isinstance(cond[key], list) - else list(map(lambda x: x[:batch_size], cond[key])) + key: ( + cond[key][:batch_size] + if not isinstance(cond[key], list) + else list(map(lambda x: x[:batch_size], cond[key])) + ) for key in cond } else: @@ -1321,9 +1321,11 @@ class LatentDiffusion(DDPM): if cond is not None: if isinstance(cond, dict): cond = { - key: cond[key][:batch_size] - if not isinstance(cond[key], list) - else list(map(lambda x: x[:batch_size], cond[key])) + key: ( + cond[key][:batch_size] + if not isinstance(cond[key], list) + else list(map(lambda x: x[:batch_size], cond[key])) + ) for key in cond } else: diff --git a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py index 55dac8555..4104fe3b0 100644 --- a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py +++ b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py @@ -1,4 +1,5 @@ """SAMPLING ONLY.""" + import torch from .dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py b/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py index 6c80f3229..afde5dfd4 100644 --- a/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py +++ b/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py @@ -640,23 +640,25 @@ class 
UNetModel(nn.Module): use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ), - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=dim_head, - use_new_attention_order=use_new_attention_order, - ) - if not use_spatial_transformer - else SpatialTransformer( # always uses a self-attn - ch, - num_heads, - dim_head, - depth=transformer_depth, - context_dim=context_dim, - disable_self_attn=disable_middle_self_attn, - use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint, + ( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( # always uses a self-attn + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, + ) ), ResBlock( ch, diff --git a/examples/images/diffusion/ldm/modules/midas/midas/midas_net.py b/examples/images/diffusion/ldm/modules/midas/midas/midas_net.py index 0dd87b596..8c13f39ff 100644 --- a/examples/images/diffusion/ldm/modules/midas/midas/midas_net.py +++ b/examples/images/diffusion/ldm/modules/midas/midas/midas_net.py @@ -2,6 +2,7 @@ This file contains code that is adapted from https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py """ + import torch import torch.nn as nn diff --git a/examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py b/examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py index 4d30744c4..c79581afc 100644 --- a/examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py +++ b/examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py @@ -2,6 +2,7 @@ This file contains code that is adapted from https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py """ + import torch import torch.nn as nn diff --git a/examples/images/diffusion/ldm/modules/midas/utils.py b/examples/images/diffusion/ldm/modules/midas/utils.py index 1428d42b2..f7fc7dcc9 100644 --- a/examples/images/diffusion/ldm/modules/midas/utils.py +++ b/examples/images/diffusion/ldm/modules/midas/utils.py @@ -1,4 +1,5 @@ """Utils for monoDepth.""" + import re import sys diff --git a/examples/tutorial/sequence_parallel/data/datasets/helpers.cpp b/examples/tutorial/sequence_parallel/data/datasets/helpers.cpp index 52977e631..fe9968177 100644 --- a/examples/tutorial/sequence_parallel/data/datasets/helpers.cpp +++ b/examples/tutorial/sequence_parallel/data/datasets/helpers.cpp @@ -369,9 +369,9 @@ py::array build_mapping_impl(const py::array_t& docs_, } } // for (auto sent_index=sent_index_first; ... - } // if (num_remain_sent > 1) { - } // for (int doc=0; doc < num_docs; ++doc) { - } // for (int epoch=0; epoch < num_epochs; ++epoch) { + } // if (num_remain_sent > 1) { + } // for (int doc=0; doc < num_docs; ++doc) { + } // for (int epoch=0; epoch < num_epochs; ++epoch) { if (!second) { if (verbose) { @@ -606,9 +606,9 @@ py::array build_blocks_mapping_impl( num_sent = 0; } } // for (auto sent_index=sent_index_first; ... 
- } // if (num_remain_sent > 1) { - } // for (int doc=0; doc < num_docs; ++doc) { - } // for (int epoch=0; epoch < num_epochs; ++epoch) { + } // if (num_remain_sent > 1) { + } // for (int doc=0; doc < num_docs; ++doc) { + } // for (int epoch=0; epoch < num_epochs; ++epoch) { if (!second) { if (verbose) { diff --git a/extensions/csrc/kernel/arm/cpu_adam_arm.h b/extensions/csrc/kernel/arm/cpu_adam_arm.h index c731850ed..d48968e21 100644 --- a/extensions/csrc/kernel/arm/cpu_adam_arm.h +++ b/extensions/csrc/kernel/arm/cpu_adam_arm.h @@ -4,7 +4,7 @@ #include -#define ROUND_DOWN(size, step) ((size) & ~((step)-1)) +#define ROUND_DOWN(size, step) ((size) & ~((step) - 1)) #define TILE (128 * 1024 * 1024) #if defined(__aarch64__) diff --git a/extensions/csrc/kernel/x86/cpu_adam.h b/extensions/csrc/kernel/x86/cpu_adam.h index db1f26d5f..45e1dde62 100644 --- a/extensions/csrc/kernel/x86/cpu_adam.h +++ b/extensions/csrc/kernel/x86/cpu_adam.h @@ -32,7 +32,7 @@ SOFTWARE #include #endif -#define ROUND_DOWN(size, step) ((size) & ~((step)-1)) +#define ROUND_DOWN(size, step) ((size) & ~((step) - 1)) #define TILE (128 * 1024 * 1024) #if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__) diff --git a/tests/kit/model_zoo/torchvision/torchvision.py b/tests/kit/model_zoo/torchvision/torchvision.py index 57b633e9d..c0524d089 100644 --- a/tests/kit/model_zoo/torchvision/torchvision.py +++ b/tests/kit/model_zoo/torchvision/torchvision.py @@ -34,14 +34,14 @@ def swin_s(): # special output transform fn -google_net_output_transform_fn = ( - lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.GoogLeNetOutputs) else dict(output=x) +google_net_output_transform_fn = lambda x: ( + dict(output=sum(x)) if isinstance(x, torchvision.models.GoogLeNetOutputs) else dict(output=x) ) -swin_s_output_output_transform_fn = ( - lambda x: {f"output{idx}": val for idx, val in enumerate(x)} if isinstance(x, tuple) else dict(output=x) +swin_s_output_output_transform_fn = lambda x: ( + {f"output{idx}": val for idx, val in enumerate(x)} if isinstance(x, tuple) else dict(output=x) ) -inception_v3_output_transform_fn = ( - lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.InceptionOutputs) else dict(output=x) +inception_v3_output_transform_fn = lambda x: ( + dict(output=sum(x)) if isinstance(x, torchvision.models.InceptionOutputs) else dict(output=x) ) model_zoo.register(