Merge branch 'hpcaitech:main' into feature/fp8_comm

pull/5885/head
Hanks 2024-07-04 20:34:41 +08:00 committed by GitHub
commit 6991819a97
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
63 changed files with 265 additions and 177 deletions

View File

@ -1,34 +1,34 @@
repos:
- repo: https://github.com/PyCQA/autoflake
rev: v2.2.1
rev: v2.3.1
hooks:
- id: autoflake
name: autoflake (python)
args: ['--in-place', '--remove-unused-variables', '--remove-all-unused-imports', '--ignore-init-module-imports']
- repo: https://github.com/pycqa/isort
rev: 5.12.0
rev: 5.13.2
hooks:
- id: isort
name: sort all imports (python)
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 23.9.1
rev: 24.4.2
hooks:
- id: black
name: black formatter
args: ['--line-length=120', '--target-version=py37', '--target-version=py38', '--target-version=py39','--target-version=py310']
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v13.0.1
rev: v18.1.8
hooks:
- id: clang-format
name: clang formatter
types_or: [c++, c]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
rev: v4.6.0
hooks:
- id: check-yaml
- id: check-merge-conflict

View File

@ -83,15 +83,19 @@ class DataCollatorForSupervisedDataset(object):
# `List[torch.Tensor]`
batch_input_ids = [
torch.LongTensor(instance["input_ids"][: self.max_length])
if len(instance["input_ids"]) > self.max_length
else torch.LongTensor(instance["input_ids"])
(
torch.LongTensor(instance["input_ids"][: self.max_length])
if len(instance["input_ids"]) > self.max_length
else torch.LongTensor(instance["input_ids"])
)
for instance in instances
]
batch_labels = [
torch.LongTensor(instance["labels"][: self.max_length])
if len(instance["labels"]) > self.max_length
else torch.LongTensor(instance["labels"])
(
torch.LongTensor(instance["labels"][: self.max_length])
if len(instance["labels"]) > self.max_length
else torch.LongTensor(instance["labels"])
)
for instance in instances
]
if self.tokenizer.padding_side == "right":

View File

@ -1,6 +1,7 @@
"""
loss functions
"""
from typing import Optional, Tuple
import torch

View File

@ -1,6 +1,7 @@
"""
reward model
"""
from typing import Optional
import torch

View File

@ -1,6 +1,7 @@
"""
Training utilities for Coati.
"""
from typing import Any
import torch

View File

@ -78,7 +78,9 @@ def get_prompt(line: Dict, dataset_name: str, logger: DistributedLogger) -> Dict
option_string = "ABCDEFG"
count = len(line["options"])
input = "问题:" + line["question"] + " " + "从以下选项中选择:" + " ".join(line["options"]) + "\n" + "答案:"
input = (
"问题:" + line["question"] + " " + "从以下选项中选择:" + " ".join(line["options"]) + "\n" + "答案:"
)
all_classes = list(option_string[0:count])
@ -150,7 +152,15 @@ def combine_prompt(prompt_path, dataset_name, load_explanation=True, chat_mode=F
)
elif dataset_name in chinese_qa_datasets:
question_input = (
"问题:" + passage + " " + question + "\n" + "从以下选项中选择:" + " ".join(options) + "\n" + "答案:{}".format(label)
"问题:"
+ passage
+ " "
+ question
+ "\n"
+ "从以下选项中选择:"
+ " ".join(options)
+ "\n"
+ "答案:{}".format(label)
)
elif dataset_name in english_cloze_datasets:
question_input = "Question: ".format(idx + 1) + question + "\n" + "Answer: {}".format(answer)

View File

@ -57,7 +57,11 @@ ceval_subject_mapping = {
"urban_and_rural_planner": ["Urban and Rural Planner", "注册城乡规划师", "Other"],
"accountant": ["Accountant", "注册会计师", "Other"],
"fire_engineer": ["Fire Engineer", "注册消防工程师", "Other"],
"environmental_impact_assessment_engineer": ["Environmental Impact Assessment Engineer", "环境影响评价工程师", "Other"],
"environmental_impact_assessment_engineer": [
"Environmental Impact Assessment Engineer",
"环境影响评价工程师",
"Other",
],
"tax_accountant": ["Tax Accountant", "税务师", "Other"],
"physician": ["Physician", "医师资格", "Other"],
}

View File

@ -56,9 +56,11 @@ class MTBenchDataset(BaseDataset):
"instruction": question["turns"],
"input": "",
"output": [],
"target": [""] * turn_number
if question["question_id"] not in reference
else reference[question["question_id"]],
"target": (
[""] * turn_number
if question["question_id"] not in reference
else reference[question["question_id"]]
),
}
if category in dataset["test"]:

View File

@ -77,7 +77,9 @@ class HuggingFaceModel(BaseModel):
self.indices_for_choices[0].append(
self.tokenizer(f"Answer: {choice}", add_special_tokens=False).input_ids[-1]
)
self.indices_for_choices[1].append(self.tokenizer(f"答案:{choice}", add_special_tokens=False).input_ids[-1])
self.indices_for_choices[1].append(
self.tokenizer(f"答案:{choice}", add_special_tokens=False).input_ids[-1]
)
def _load_tokenizer(self, path: str, tokenizer_path: Optional[str], tokenizer_kwargs: dict):
"""

View File

@ -7,6 +7,7 @@ This code is based on LangChain Ai's langchain, which can be found at
https://github.com/langchain-ai/langchain
The original code is licensed under the MIT license.
"""
from __future__ import annotations
import copy

View File

@ -8,6 +8,7 @@ This code is based on LangChain Ai's langchain, which can be found at
https://github.com/langchain-ai/langchain
The original code is licensed under the MIT license.
"""
import copy
from typing import Any, Mapping, Optional, Protocol

View File

@ -7,6 +7,7 @@ This code is based on LangChain Ai's langchain, which can be found at
https://github.com/langchain-ai/langchain
The original code is licensed under the MIT license.
"""
import copy
from typing import Any, List

View File

@ -2,7 +2,6 @@
Class for loading table type data. please refer to Pandas-Input/Output for file format details.
"""
import glob
import os

View File

@ -12,6 +12,7 @@ TEST_PROMPT_CHATGLM="续写文章:惊蛰一过,春寒加剧。先是料料
logger.info(llm(TEST_PROMPT_CHATGLM, max_new_tokens=100), verbose=True)
"""
from typing import Any, List, Mapping, Optional
import torch

View File

@ -1,6 +1,7 @@
"""
Generation utilities
"""
import json
from typing import List

View File

@ -2,6 +2,7 @@
Implement a memory class for storing conversation history
Support long term and short term memory
"""
from typing import Any, Dict, List
from colossalqa.chain.memory.summary import ConversationSummaryMemory

View File

@ -1,6 +1,7 @@
"""
Class for logging with extra control for debugging
"""
import logging

View File

@ -1,6 +1,7 @@
"""
Script for Chinese retrieval based conversation system backed by ChatGLM
"""
from typing import Tuple
from colossalqa.chain.retrieval_qa.base import RetrievalQA

View File

@ -1,6 +1,7 @@
"""
Multilingual retrieval based conversation system
"""
from typing import List
from colossalqa.data_loader.document_loader import DocumentLoader

View File

@ -1,6 +1,7 @@
"""
Script for Chinese retrieval based conversation system backed by ChatGLM
"""
from typing import Tuple
from colossalqa.chain.retrieval_qa.base import RetrievalQA

View File

@ -1,6 +1,7 @@
"""
Code for custom retriver with incremental update
"""
import copy
import hashlib
import os

View File

@ -1,6 +1,7 @@
"""
Code for Chinese text splitter
"""
from typing import Any, List, Optional
from colossalqa.text_splitter.utils import get_cleaned_paragraph

View File

@ -1,6 +1,7 @@
"""
Script for English retrieval based conversation system backed by LLaMa2
"""
import argparse
import os

View File

@ -1,6 +1,7 @@
"""
Script for English retrieval based conversation system backed by LLaMa2
"""
import argparse
import json
import os

View File

@ -1,6 +1,7 @@
"""
Script for Chinese retrieval based conversation system backed by ChatGLM
"""
import argparse
import os

View File

@ -1,6 +1,7 @@
"""
Script for English retrieval based conversation system backed by LLaMa2
"""
import argparse
import os

View File

@ -107,20 +107,22 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_memory_cost = MemoryCost(
activation=compute_size_in_bytes([input_tensor, output_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
if has_bias
else compute_size_in_bytes(weight_tensor),
parameter=(
compute_size_in_bytes([weight_tensor, bias_tensor]) if has_bias else compute_size_in_bytes(weight_tensor)
),
temp=0,
buffer=0,
)
bwd_memory_cost = MemoryCost(
activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor])
if has_bias
else compute_size_in_bytes([input_tensor, weight_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
if has_bias
else compute_size_in_bytes(weight_tensor),
activation=(
compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor])
if has_bias
else compute_size_in_bytes([input_tensor, weight_tensor])
),
parameter=(
compute_size_in_bytes([weight_tensor, bias_tensor]) if has_bias else compute_size_in_bytes(weight_tensor)
),
temp=0,
buffer=0,
)

View File

@ -1,10 +1,18 @@
from .gemini_plugin import GeminiPlugin
from .hybrid_parallel_plugin import HybridParallelPlugin
from .low_level_zero_plugin import LowLevelZeroPlugin
from .moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from .plugin_base import Plugin
from .torch_ddp_plugin import TorchDDPPlugin
__all__ = ["Plugin", "TorchDDPPlugin", "GeminiPlugin", "LowLevelZeroPlugin", "HybridParallelPlugin"]
__all__ = [
"Plugin",
"TorchDDPPlugin",
"GeminiPlugin",
"LowLevelZeroPlugin",
"HybridParallelPlugin",
"MoeHybridParallelPlugin",
]
import torch
from packaging import version

View File

@ -247,16 +247,16 @@ class BatchBucket:
self._sequences_dict[seq.request_id] = seq
self._sequences_indexes[seq.request_id] = self._current_batch_size + i
# TODO external (rename): modify Sequence.sentence_len to seq_len
self._sequence_lengths[
self._current_batch_size : self._current_batch_size + num_seqs_to_add
] = torch.tensor([seq.sentence_len for seq in seqs[:num_seqs_to_add]], dtype=torch.int32)
self._sequence_lengths[self._current_batch_size : self._current_batch_size + num_seqs_to_add] = (
torch.tensor([seq.sentence_len for seq in seqs[:num_seqs_to_add]], dtype=torch.int32)
)
# NOTE block tables to be updated by kvcache manager
block_tables = self._block_tables[self._current_batch_size : self._current_batch_size + num_seqs_to_add]
if alloc_block_tables is not None:
# copy block ids from provided block tables
self._block_tables[
self._current_batch_size : self._current_batch_size + num_seqs_to_add
] = alloc_block_tables
self._block_tables[self._current_batch_size : self._current_batch_size + num_seqs_to_add] = (
alloc_block_tables
)
elif alloc_block_tables_fn:
alloc_block_tables_fn(
block_tables,

View File

@ -1,6 +1,7 @@
"""
Our config contains various options for inference optimization, it is a unified API that wraps all the configurations for inference.
"""
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
@ -82,9 +83,9 @@ class InputMetaData(RPC_PARAM):
dtype: torch.dtype = torch.float32
use_spec_dec: bool = False
num_tokens_to_verify: int = 0
batch_token_ids: Optional[
List[List[int]]
] = None # for `repetition_penalty`, `no_repeat_ngram_size` in sampler process
batch_token_ids: Optional[List[List[int]]] = (
None # for `repetition_penalty`, `no_repeat_ngram_size` in sampler process
)
def to_rpc_param(self) -> Dict[str, any]:
return {
@ -202,9 +203,9 @@ class InferenceConfig(RPC_PARAM):
prompt_template: Optional[str] = None
do_sample: bool = False
beam_width: int = 1 # TODO: beam search is not support for now
prefill_ratio: Optional[
float
] = 1.2 # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
prefill_ratio: Optional[float] = (
1.2 # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
)
pad_input: bool = False
early_stopping: Optional[bool] = False
top_k: Optional[int] = 50
@ -234,7 +235,9 @@ class InferenceConfig(RPC_PARAM):
high_precision: Optional[bool] = False
# cuda_graph
use_cuda_graph: bool = False # NOTE only when we have the graph for specific decoding batch size can we use the cuda graph for inference
use_cuda_graph: bool = (
False # NOTE only when we have the graph for specific decoding batch size can we use the cuda graph for inference
)
max_context_len_to_capture: int = 512
# StreamingLLM (sliding window attention with attention sinks)

View File

@ -47,7 +47,6 @@ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
class InferenceEngine:
"""
InferenceEngine which manages the inference process..

View File

@ -34,7 +34,6 @@ def run_server(host, port, event: mp.Event = None):
class RPCInferenceEngine(InferenceEngine):
"""
InferenceEngine which manages the inference process..

View File

@ -42,7 +42,6 @@ logger = get_dist_logger(__name__)
class rpcWorkerService(rpyc.Service):
"""
Execute the computation tasks and manage its own kv cache

View File

@ -279,9 +279,11 @@ class KVCacheManager:
block.add_ref()
self._allocate_on_block(
block,
block.block_size
if context_lengths[i] % block.block_size == 0
else context_lengths[i].item() % block.block_size,
(
block.block_size
if context_lengths[i] % block.block_size == 0
else context_lengths[i].item() % block.block_size
),
)
for block_id in alloc_block_ids:
if block_id in alloc_block_ids[last_block_locs]:

View File

@ -1,6 +1,7 @@
"""
Utils for model inference
"""
import math
import os
import re

View File

@ -138,9 +138,7 @@ class Initializer_2D(ProcessGroupInitializer):
self.num_group = self.world_size // self.tensor_parallel_size
self.summa_dim = int(math.sqrt(self.tensor_parallel_size))
assert (
self.tensor_parallel_size == self.summa_dim**2
), "2D summa dim should equal to tensor parallel size ^ 0.5"
assert self.tensor_parallel_size == self.summa_dim**2, "2D summa dim should equal to tensor parallel size ^ 0.5"
_check_summa_env_var(self.summa_dim)
self.col_initializer = Initializer_2D_Col(self.num_group, self.summa_dim, *args, **kwargs)

View File

@ -54,7 +54,6 @@ class RequestTracker:
class Async_Engine:
"""
Use an engine to launch RAY Driver --> RAY Worker --> Async_Manager
Background loop: inference reqs in waiting list (Listen)

View File

@ -118,16 +118,16 @@ class Batch:
class BatchTokenIdOut:
def __init__(self):
self.reqs_infs: List[
Tuple[str, int, Dict, bool, bool]
] = [] # [req_id, new_token_id, gen_metadata, finished_state, abort_state]
self.reqs_infs: List[Tuple[str, int, Dict, bool, bool]] = (
[]
) # [req_id, new_token_id, gen_metadata, finished_state, abort_state]
class BatchStrOut:
def __init__(self):
self.reqs_infs: List[
Tuple[str, str, Dict, bool, bool]
] = [] # [req_id, token_str, gen_metadata, finished_state, abort_state]
self.reqs_infs: List[Tuple[str, str, Dict, bool, bool]] = (
[]
) # [req_id, token_str, gen_metadata, finished_state, abort_state]
class AbortReq:

View File

@ -1,6 +1,7 @@
"""
Utils for model inference
"""
import os
import torch

View File

@ -14,6 +14,7 @@ class BatchInferState:
Information to be passed and used for a batch of inputs during
a single model forward
"""
batch_size: int
max_len_in_batch: int

View File

@ -4,6 +4,7 @@ of the ModelTC/lightllm GitHub repository
https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py
we slightly changed it to make it suitable for our colossal-ai shardformer TP-engine design.
"""
import torch
from transformers.utils import logging

View File

@ -1,6 +1,7 @@
"""
Utils for model inference
"""
import os
import torch

View File

@ -1,17 +1,25 @@
# adapted from Hugging Face accelerate/utils/bnb.py accelerate/utils/modeling.py
import importlib.metadata
import logging
import torch
import torch.nn as nn
from packaging.version import Version
from .bnb_config import BnbQuantizationConfig
try:
import bitsandbytes as bnb
IS_4BIT_BNB_AVAILABLE = bnb.__version__ >= "0.39.0"
IS_8BIT_BNB_AVAILABLE = bnb.__version__ >= "0.37.2"
try:
# in case lower version of bitsandbytes does not have __version__ attribute
BNB_VERSION = Version(bnb.__version__)
except AttributeError:
BNB_VERSION = Version(importlib.metadata.version("bitsandbytes"))
IS_4BIT_BNB_AVAILABLE = BNB_VERSION >= Version("0.39.0")
IS_8BIT_BNB_AVAILABLE = BNB_VERSION >= Version("0.37.2")
except ImportError:
pass

View File

@ -33,6 +33,7 @@ This license shall be governed and construed in accordance with the laws of Peop
Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at glm-130b@googlegroups.com.
"""
""" PyTorch ChatGLM model. """
import copy

View File

@ -221,7 +221,7 @@ class OPTPipelineForwards:
past_key_value = past_key_values[idx] if past_key_values is not None else None
if decoder.gradient_checkpointing and decoder.training:
layer_outputs = self._gradient_checkpointing_func(
layer_outputs = self.decoder._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_attention_mask,

View File

@ -168,13 +168,27 @@ class Qwen2PipelineForwards:
next_decoder_cache = None
start_idx, end_idx = stage_index[0], stage_index[1]
num_ckpt_layers = 0
if self.gradient_checkpointing and self.training:
num_ckpt_layers = end_idx - start_idx
# TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer
if shard_config.gradient_checkpoint_config is not None:
num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
stage=stage_manager.stage,
num_stages=stage_manager.num_stages,
num_layers=end_idx - start_idx,
model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0),
num_model_chunks=stage_manager.num_model_chunks,
)
assert num_ckpt_layers <= end_idx - start_idx
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
if idx - start_idx < num_ckpt_layers:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
@ -198,7 +212,6 @@ class Qwen2PipelineForwards:
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)

View File

@ -40,21 +40,19 @@ class MixtralPolicy(Policy):
if self.shard_config.enable_tensor_parallelism:
raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
if getattr(self.shard_config, "ep_group", None) is None:
raise ValueError("You must pass in ep_group via shard_config for expert parallel!")
# expert parallel
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="block_sparse_moe",
target_module=EPMixtralSparseMoeBlock,
kwargs={"ep_group": self.shard_config.ep_group},
)
],
policy=policy,
target_key=MixtralDecoderLayer,
)
if getattr(self.shard_config, "ep_group", None) is not None:
# expert parallel
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="block_sparse_moe",
target_module=EPMixtralSparseMoeBlock,
kwargs={"ep_group": self.shard_config.ep_group},
)
],
policy=policy,
target_key=MixtralDecoderLayer,
)
# optimization configuration
if self.shard_config.enable_fused_normalization:

View File

@ -549,6 +549,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
working_param = real_working_params[group_id][idx]
param_to_gather = master_param.to(device).to(self._dtype)
pg = self.param_to_pg[working_param]
if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size:
buffer_tensor = torch.empty_like(
torch.cat([param_to_gather for _ in range(dist.get_world_size(pg))])
)
dist.all_gather_into_tensor(buffer_tensor, param_to_gather, pg)
working_param.data.copy_(buffer_tensor[: working_param.numel()].reshape_as(working_param))
continue
try:
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
except RuntimeError:

View File

@ -87,44 +87,42 @@ optim = DistGaloreAwamW(
## Plugin compatibility
<table>
<tr>
<th nowrap="nowrap">Model/Feature</th>
<th nowrap="nowrap" align="center" title="Lamb">Lamb</th>
<th nowrap="nowrap" align="center" title="GaLore">GaLore</th>
<th nowrap="nowrap" align="center" title="Adafactor">Adafactor</th>
<th nowrap="nowrap" align="center" title="CAME">CAME</th>
<th nowrap="nowrap">Optimizer/Plugin</th>
<th nowrap="nowrap" align="center">Hybrid Parallel Plugin</th>
<th nowrap="nowrap" align="center">Low Level Zero Plugin</th>
<th nowrap="nowrap" align="center">Torch DDP Plugin</th>
<th nowrap="nowrap" align="center">Gemini Plugin</th>
<th nowrap="nowrap" align="center">Moe Hybrid Plugin</th>
</tr>
<tr>
<td nowrap="nowrap">Hybrid Parallel<br />Plugin</td>
<td nowrap="nowrap" align="center" title="Lamb">Lamb</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
</tr>
<tr>
<td nowrap="nowrap">Low Level Zero<br />Plugin</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
</tr>
<tr>
<td nowrap="nowrap">Torch DDP<br />Plugin</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
</tr>
<tr>
<td nowrap="nowrap">Gemini<br />Plugin</td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td>
</tr>
<tr>
<td nowrap="nowrap">Moe Hybrid<br />Plugin</td>
<td nowrap="nowrap" align="center" title="GaLore">GaLore</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td>
</tr>
<tr>
<td nowrap="nowrap" align="center" title="Adafactor">Adafactor</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td>
</tr>
<tr>
<td nowrap="nowrap" align="center" title="CAME">CAME</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td>
</tr>

View File

@ -55,7 +55,7 @@ Model/Feature Compatibility Matrix:
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center"></td>
</tr>
<tr>

View File

@ -84,44 +84,42 @@ optim = DistGaloreAwamW(
## 兼容性
<table>
<tr>
<th nowrap="nowrap">Model/Feature</th>
<th nowrap="nowrap" align="center" title="Lamb">Lamb</th>
<th nowrap="nowrap" align="center" title="GaLore">GaLore</th>
<th nowrap="nowrap" align="center" title="Adafactor">Adafactor</th>
<th nowrap="nowrap" align="center" title="CAME">CAME</th>
<th nowrap="nowrap">Optimizer/Plugin</th>
<th nowrap="nowrap" align="center">Hybrid Parallel Plugin</th>
<th nowrap="nowrap" align="center">Low Level Zero Plugin</th>
<th nowrap="nowrap" align="center">Torch DDP Plugin</th>
<th nowrap="nowrap" align="center">Gemini Plugin</th>
<th nowrap="nowrap" align="center">Moe Hybrid Plugin</th>
</tr>
<tr>
<td nowrap="nowrap">Hybrid Parallel<br />Plugin</td>
<td nowrap="nowrap" align="center" title="Lamb">Lamb</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
</tr>
<tr>
<td nowrap="nowrap">Low Level Zero<br />Plugin</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
</tr>
<tr>
<td nowrap="nowrap">Torch DDP<br />Plugin</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
</tr>
<tr>
<td nowrap="nowrap">Gemini<br />Plugin</td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td>
</tr>
<tr>
<td nowrap="nowrap">Moe Hybrid<br />Plugin</td>
<td nowrap="nowrap" align="center" title="GaLore">GaLore</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td>
</tr>
<tr>
<td nowrap="nowrap" align="center" title="Adafactor">Adafactor</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td>
</tr>
<tr>
<td nowrap="nowrap" align="center" title="CAME">CAME</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td>
</tr>
@ -130,6 +128,7 @@ optim = DistGaloreAwamW(
</tr>
</table>
<!-- doc-test-command: colossalai run --nproc_per_node 4 distributed_optimizers.py -->
## API 参考

View File

@ -51,7 +51,7 @@ Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github.
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center"></td>
</tr>
<tr>

View File

@ -52,9 +52,11 @@ class pretraining_dataset(Dataset):
def __getitem__(self, index):
[input_ids, input_mask, segment_ids, masked_lm_labels] = [
torch.from_numpy(input[index].astype(np.int64))
if indice < 5
else torch.from_numpy(np.asarray(input[index].astype(np.int64)))
(
torch.from_numpy(input[index].astype(np.int64))
if indice < 5
else torch.from_numpy(np.asarray(input[index].astype(np.int64)))
)
for indice, input in enumerate(self.inputs)
]

View File

@ -229,9 +229,7 @@ class DDPM(pl.LightningModule):
)
if self.parameterization == "eps":
lvlb_weights = self.betas**2 / (
2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)
)
lvlb_weights = self.betas**2 / (2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
elif self.parameterization == "x0":
lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2.0 * 1 - torch.Tensor(alphas_cumprod))
elif self.parameterization == "v":
@ -1186,9 +1184,11 @@ class LatentDiffusion(DDPM):
if cond is not None:
if isinstance(cond, dict):
cond = {
key: cond[key][:batch_size]
if not isinstance(cond[key], list)
else list(map(lambda x: x[:batch_size], cond[key]))
key: (
cond[key][:batch_size]
if not isinstance(cond[key], list)
else list(map(lambda x: x[:batch_size], cond[key]))
)
for key in cond
}
else:
@ -1321,9 +1321,11 @@ class LatentDiffusion(DDPM):
if cond is not None:
if isinstance(cond, dict):
cond = {
key: cond[key][:batch_size]
if not isinstance(cond[key], list)
else list(map(lambda x: x[:batch_size], cond[key]))
key: (
cond[key][:batch_size]
if not isinstance(cond[key], list)
else list(map(lambda x: x[:batch_size], cond[key]))
)
for key in cond
}
else:

View File

@ -1,4 +1,5 @@
"""SAMPLING ONLY."""
import torch
from .dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper

View File

@ -640,23 +640,25 @@ class UNetModel(nn.Module):
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
)
if not use_spatial_transformer
else SpatialTransformer( # always uses a self-attn
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
disable_self_attn=disable_middle_self_attn,
use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint,
(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
)
if not use_spatial_transformer
else SpatialTransformer( # always uses a self-attn
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
disable_self_attn=disable_middle_self_attn,
use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint,
)
),
ResBlock(
ch,

View File

@ -2,6 +2,7 @@
This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
"""
import torch
import torch.nn as nn

View File

@ -2,6 +2,7 @@
This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
"""
import torch
import torch.nn as nn

View File

@ -1,4 +1,5 @@
"""Utils for monoDepth."""
import re
import sys

View File

@ -369,9 +369,9 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
}
} // for (auto sent_index=sent_index_first; ...
} // if (num_remain_sent > 1) {
} // for (int doc=0; doc < num_docs; ++doc) {
} // for (int epoch=0; epoch < num_epochs; ++epoch) {
} // if (num_remain_sent > 1) {
} // for (int doc=0; doc < num_docs; ++doc) {
} // for (int epoch=0; epoch < num_epochs; ++epoch) {
if (!second) {
if (verbose) {
@ -606,9 +606,9 @@ py::array build_blocks_mapping_impl(
num_sent = 0;
}
} // for (auto sent_index=sent_index_first; ...
} // if (num_remain_sent > 1) {
} // for (int doc=0; doc < num_docs; ++doc) {
} // for (int epoch=0; epoch < num_epochs; ++epoch) {
} // if (num_remain_sent > 1) {
} // for (int doc=0; doc < num_docs; ++doc) {
} // for (int epoch=0; epoch < num_epochs; ++epoch) {
if (!second) {
if (verbose) {

View File

@ -4,7 +4,7 @@
#include <cmath>
#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
#define ROUND_DOWN(size, step) ((size) & ~((step) - 1))
#define TILE (128 * 1024 * 1024)
#if defined(__aarch64__)

View File

@ -32,7 +32,7 @@ SOFTWARE
#include <x86intrin.h>
#endif
#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
#define ROUND_DOWN(size, step) ((size) & ~((step) - 1))
#define TILE (128 * 1024 * 1024)
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)

View File

@ -34,14 +34,14 @@ def swin_s():
# special output transform fn
google_net_output_transform_fn = (
lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.GoogLeNetOutputs) else dict(output=x)
google_net_output_transform_fn = lambda x: (
dict(output=sum(x)) if isinstance(x, torchvision.models.GoogLeNetOutputs) else dict(output=x)
)
swin_s_output_output_transform_fn = (
lambda x: {f"output{idx}": val for idx, val in enumerate(x)} if isinstance(x, tuple) else dict(output=x)
swin_s_output_output_transform_fn = lambda x: (
{f"output{idx}": val for idx, val in enumerate(x)} if isinstance(x, tuple) else dict(output=x)
)
inception_v3_output_transform_fn = (
lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.InceptionOutputs) else dict(output=x)
inception_v3_output_transform_fn = lambda x: (
dict(output=sum(x)) if isinstance(x, torchvision.models.InceptionOutputs) else dict(output=x)
)
model_zoo.register(