[devops] remove post commit ci (#5566)

* [devops] remove post commit ci

* [misc] run pre-commit on all files

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
pull/5578/head
Hongxin Liu 8 months ago committed by GitHub
parent 341263df48
commit 641b1ee71a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -3,6 +3,7 @@
- [ ] I have created an issue for this PR for traceability - [ ] I have created an issue for this PR for traceability
- [ ] The title follows the standard format: `[doc/gemini/tensor/...]: A concise description` - [ ] The title follows the standard format: `[doc/gemini/tensor/...]: A concise description`
- [ ] I have added relevant tags if possible for us to better distinguish different PRs - [ ] I have added relevant tags if possible for us to better distinguish different PRs
- [ ] I have installed pre-commit: `pip install pre-commit && pre-commit install`
## 🚨 Issue number ## 🚨 Issue number

@ -1,97 +0,0 @@
name: post-commit
on:
pull_request:
types:
- closed
jobs:
# this job will run after a PR is merged to run pre-commit on any changed file
# so that the user does not need to learn pre-commit and pre-commit can still
# be auto-executed by the workflow
pre-commit:
runs-on: ubuntu-latest
if: github.event.pull_request.merged == true && github.repository == 'hpcaitech/ColossalAI'
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
ref: ${{ github.event.pull_request.head.sha }}
# the PR branch and the hpcaitech/colossal-ai main branch
# must share a common commit, we need to locate that commit,
# which is the commit checked-out or forked when the PR branch is created
# such that we can look for files changed since that commit
- name: Locate base commit
id: locate-base-sha
run: |
curBranch=$(git rev-parse --abbrev-ref HEAD)
commonCommit=$(git merge-base origin/main $curBranch)
echo $commonCommit
echo "baseSHA=$commonCommit" >> $GITHUB_OUTPUT
- name: Find the changed files
id: find-changed-files
uses: tj-actions/changed-files@v35
with:
base_sha: ${{ steps.locate-base-sha.outputs.baseSHA }}
- name: List all changed files
run: |
for file in ${{ steps.find-changed-files.outputs.all_changed_files }}; do
echo "$file was changed"
done
# check out the main branch
- uses: actions/checkout@v2
with:
ref: 'main'
- uses: actions/setup-python@v3
- name: Cache pre-commit hooks
uses: actions/cache@v3
with:
path: ~/.cache/pre-commit
key: ${{ runner.os }}-pre-commit-hooks
- name: Set up pre-commit
run: |
pip install pre-commit
pre-commit install
# run pre-commit on changed files
- name: Run Pre-commit
run: |
for file in ${{ steps.find-changed-files.outputs.all_changed_files }}; do
pre-commit run --files $file || true
done
# create commit for pre-commit
# when all files are well formatted, there is no need to create a commit
# therefore, this step will produce an error, which should be allowed
- name: Create commits
id: commit
continue-on-error: true
run: |
git config --global user.name 'github-actions'
git config --global user.email 'github-actions@github.com'
git remote set-url origin https://x-access-token:${{ secrets.GITHUB_TOKEN }}@github.com/${{ github.repository }}
git add -A
git commit -am "[format] applied code formatting on changed files in pull request ${{ github.event.pull_request.number }}"
# create pull request
- name: Create Pull Request
if: steps.commit.outcome == 'success'
id: cpr
uses: peter-evans/create-pull-request@v4
with:
branch: pre-commit-${{ github.event.pull_request.number }}
title: "[format] applied code formatting on changed files in PR ${{ github.event.pull_request.number }}"
- name: Enable Auto-merge for the New PR
if: steps.commit.outcome == 'success'
uses: peter-evans/enable-pull-request-automerge@v2
with:
pull-request-number: ${{ steps.cpr.outputs.pull-request-number }}
merge-method: squash

@ -8,11 +8,10 @@ import argparse
import numpy as np import numpy as np
import torch import torch
from transformers import LlamaTokenizer, LlamaForCausalLM from transformers import LlamaForCausalLM, LlamaTokenizer
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
logger = get_dist_logger() logger = get_dist_logger()

@ -10,8 +10,8 @@ import os
from typing import Any, Dict, Tuple, Union from typing import Any, Dict, Tuple, Union
import torch import torch
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import _LRScheduler from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer
from colossalai.booster import Booster from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator from colossalai.cluster import DistCoordinator

@ -1,20 +1,19 @@
from copy import deepcopy from copy import deepcopy
from typing import Optional, List, Dict, Tuple, Callable, Any from typing import Any, Callable, Dict, List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
from transformers.utils import logging
from transformers.generation.utils import GenerationConfig, LogitsProcessorList, StoppingCriteriaList from transformers.generation.utils import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from transformers.utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
def get_prompt_template( def get_prompt_template(
input_query:str, input_query: str,
history:List[Dict]= None, history: List[Dict] = None,
roles:list = ["", "Human", "Assistant"], roles: list = ["", "Human", "Assistant"],
) -> str: ) -> str:
""" """
Generates a prompt template for chat models based on input and history. Generates a prompt template for chat models based on input and history.
@ -48,6 +47,7 @@ def get_prompt_template(
prompt += f"{role}: <s>" prompt += f"{role}: <s>"
return prompt return prompt
@torch.inference_mode() @torch.inference_mode()
def streaming_chat( def streaming_chat(
model: Any, model: Any,
@ -99,14 +99,14 @@ def streaming_chat(
logits_processor = LogitsProcessorList() logits_processor = LogitsProcessorList()
generation_kwargs = { generation_kwargs = {
'temperature': temperature, "temperature": temperature,
'top_p': top_p, "top_p": top_p,
'top_k': top_k, "top_k": top_k,
'do_sample': do_sample, "do_sample": do_sample,
'max_new_tokens': max_new_tokens, "max_new_tokens": max_new_tokens,
'length_penalty': length_penalty, "length_penalty": length_penalty,
'use_cache': True, "use_cache": True,
**kwargs **kwargs,
} }
prompt_str = get_prompt_template(input_query, history=history, roles=roles) prompt_str = get_prompt_template(input_query, history=history, roles=roles)
@ -116,13 +116,18 @@ def streaming_chat(
history.append({"role": roles[1], "message": input_query.strip()}) history.append({"role": roles[1], "message": input_query.strip()})
history.append({"role": roles[2], "message": None}) history.append({"role": roles[2], "message": None})
for outputs in stream_generate(model, **inputs, past_key_values=past_key_values, for outputs in stream_generate(
eos_token_id=eos_token_id, return_past_key_values=return_past_key_values, model,
**generation_kwargs): **inputs,
past_key_values=past_key_values,
eos_token_id=eos_token_id,
return_past_key_values=return_past_key_values,
**generation_kwargs,
):
if return_past_key_values: if return_past_key_values:
outputs, past_key_values = outputs outputs, past_key_values = outputs
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1] outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) : -1]
response = tokenizer.decode(outputs) response = tokenizer.decode(outputs)
history[-1]["message"] = response.strip() history[-1]["message"] = response.strip()

@ -15,7 +15,7 @@ def load_model(model_path, device="cuda", **kwargs):
model.to(device) model.to(device)
try: try:
tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side='left') tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
except OSError: except OSError:
raise ImportError("Tokenizer not found. Please check if the tokenizer exists or the model path is correct.") raise ImportError("Tokenizer not found. Please check if the tokenizer exists or the model path is correct.")

@ -12,4 +12,3 @@ flash-attn>=2.0.0,<=2.0.5
tqdm tqdm
sentencepiece==0.1.99 sentencepiece==0.1.99
protobuf<=3.20.0 protobuf<=3.20.0

@ -1,11 +1,11 @@
import os
import argparse import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
from colossal_llama2.utils.stream_chat_patch import streaming_chat from colossal_llama2.utils.stream_chat_patch import streaming_chat
from transformers import AutoModelForCausalLM, AutoTokenizer
SYSTEM = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." SYSTEM = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
def main(args): def main(args):
model = AutoModelForCausalLM.from_pretrained(args.model_path).cuda().eval() model = AutoModelForCausalLM.from_pretrained(args.model_path).cuda().eval()
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
@ -27,29 +27,34 @@ def main(args):
print(f"\n{roles[2]}: ", end="") print(f"\n{roles[2]}: ", end="")
gen_len = 0 gen_len = 0
for response, history, past_key_values in streaming_chat( for response, history, past_key_values in streaming_chat(
model, tokenizer, input_query, history=history, roles=roles, model,
temperature = args.temperature, tokenizer,
top_p = args.top_p, input_query,
top_k = args.top_k, history=history,
do_sample = args.do_sample, roles=roles,
length_penalty = args.length_penalty, temperature=args.temperature,
max_new_tokens = args.max_new_tokens, top_p=args.top_p,
top_k=args.top_k,
do_sample=args.do_sample,
length_penalty=args.length_penalty,
max_new_tokens=args.max_new_tokens,
past_key_values=past_key_values, past_key_values=past_key_values,
return_past_key_values=True): return_past_key_values=True,
):
output = response[gen_len:] output = response[gen_len:]
print(output, end="", flush=True) print(output, end="", flush=True)
gen_len = len(response) gen_len = len(response)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str, default=None, help="path to chat version model") parser.add_argument("--model_path", type=str, default=None, help="path to chat version model")
parser.add_argument('--tokenizer_path', type=str, default=None, help="path to chat version tokenizer") parser.add_argument("--tokenizer_path", type=str, default=None, help="path to chat version tokenizer")
parser.add_argument('--temperature', type=float, default=0.8, help="set temperature") parser.add_argument("--temperature", type=float, default=0.8, help="set temperature")
parser.add_argument('--top_p', type=float, default=0.95, help="set top p value") parser.add_argument("--top_p", type=float, default=0.95, help="set top p value")
parser.add_argument('--top_k', type=int, default=50, help="set top k value") parser.add_argument("--top_k", type=int, default=50, help="set top k value")
parser.add_argument('--do_sample', type=bool, default=True, help="whether turn on do_sample or not") parser.add_argument("--do_sample", type=bool, default=True, help="whether turn on do_sample or not")
parser.add_argument('--length_penalty', type=float, default=1.2, help="set length penalty") parser.add_argument("--length_penalty", type=float, default=1.2, help="set length penalty")
parser.add_argument('--max_new_tokens', type=int, default=512, help="set max new tokens") parser.add_argument("--max_new_tokens", type=int, default=512, help="set max new tokens")
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)

@ -20,13 +20,13 @@ import colossalai
from colossalai.booster import Booster from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator from colossalai.cluster import DistCoordinator
from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
from colossalai.logging import get_dist_logger
logger = get_dist_logger() logger = get_dist_logger()
def train(args): def train(args):
# check lora compatibility # check lora compatibility
if "gemini" in args.plugin and args.lora_rank > 0: if "gemini" in args.plugin and args.lora_rank > 0:

@ -3,7 +3,6 @@ import copy
import os import os
from typing import Dict, List from typing import Dict, List
import torch
import torch.distributed as dist import torch.distributed as dist
from colossal_eval import dataset, models, utils from colossal_eval import dataset, models, utils

@ -106,6 +106,5 @@ def main():
print(f"[{coordinator.rank}] {outputs}") print(f"[{coordinator.rank}] {outputs}")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

@ -24,6 +24,7 @@ from langchain.pydantic_v1 import Field
from langchain.schema import BaseRetriever, Document from langchain.schema import BaseRetriever, Document
from langchain.schema.language_model import BaseLanguageModel from langchain.schema.language_model import BaseLanguageModel
class CustomBaseRetrievalQA(BaseRetrievalQA): class CustomBaseRetrievalQA(BaseRetrievalQA):
"""Base class for question-answering chains.""" """Base class for question-answering chains."""
@ -98,7 +99,6 @@ class CustomBaseRetrievalQA(BaseRetrievalQA):
for k, v in inputs.items() for k, v in inputs.items()
if k in ["stop", "temperature", "top_k", "top_p", "max_new_tokens", "doc_prefix"] if k in ["stop", "temperature", "top_k", "top_p", "max_new_tokens", "doc_prefix"]
} }
answers = []
if self.combine_documents_chain.memory is not None: if self.combine_documents_chain.memory is not None:
buffered_history_backup, summarized_history_temp_backup = copy.deepcopy( buffered_history_backup, summarized_history_temp_backup = copy.deepcopy(
self.combine_documents_chain.memory.buffered_history self.combine_documents_chain.memory.buffered_history
@ -117,10 +117,10 @@ class CustomBaseRetrievalQA(BaseRetrievalQA):
) = copy.deepcopy(buffered_history_backup), copy.deepcopy(summarized_history_temp_backup) ) = copy.deepcopy(buffered_history_backup), copy.deepcopy(summarized_history_temp_backup)
# if rejection_trigger_keywords is not given, return the response from LLM directly # if rejection_trigger_keywords is not given, return the response from LLM directly
rejection_trigger_keywords = inputs.get('rejection_trigger_keywords', []) rejection_trigger_keywords = inputs.get("rejection_trigger_keywords", [])
answer = answer if all([rej not in answer for rej in rejection_trigger_keywords]) else None answer = answer if all([rej not in answer for rej in rejection_trigger_keywords]) else None
if answer is None: if answer is None:
answer = inputs.get('rejection_answer', "抱歉,根据提供的信息无法回答该问题。") answer = inputs.get("rejection_answer", "抱歉,根据提供的信息无法回答该问题。")
if self.combine_documents_chain.memory is not None: if self.combine_documents_chain.memory is not None:
self.combine_documents_chain.memory.save_context({"question": question}, {"output": answer}) self.combine_documents_chain.memory.save_context({"question": question}, {"output": answer})
@ -161,10 +161,14 @@ class CustomBaseRetrievalQA(BaseRetrievalQA):
input_documents=docs, question=question, callbacks=_run_manager.get_child(), **kwargs input_documents=docs, question=question, callbacks=_run_manager.get_child(), **kwargs
) )
# if rejection_trigger_keywords is not given, return the response from LLM directly # if rejection_trigger_keywords is not given, return the response from LLM directly
rejection_trigger_keywords = inputs.get('rejection_trigger_keywords', []) rejection_trigger_keywords = inputs.get("rejection_trigger_keywords", [])
answer = answer if all([rej not in answer for rej in rejection_trigger_keywords]) or len(rejection_trigger_keywords)==0 else None answer = (
answer
if all([rej not in answer for rej in rejection_trigger_keywords]) or len(rejection_trigger_keywords) == 0
else None
)
if answer is None: if answer is None:
answer = inputs.get('rejection_answer', "抱歉,根据提供的信息无法回答该问题。") answer = inputs.get("rejection_answer", "抱歉,根据提供的信息无法回答该问题。")
self.combine_documents_chain.memory.save_context({"question": question}, {"output": answer}) self.combine_documents_chain.memory.save_context({"question": question}, {"output": answer})
if self.return_source_documents: if self.return_source_documents:

@ -1,32 +1,33 @@
''' """
Class for loading table type data. please refer to Pandas-Input/Output for file format details. Class for loading table type data. please refer to Pandas-Input/Output for file format details.
''' """
import os
import glob import glob
import os
import pandas as pd import pandas as pd
from sqlalchemy import create_engine
from colossalqa.utils import drop_table
from colossalqa.mylogging import get_logger from colossalqa.mylogging import get_logger
from colossalqa.utils import drop_table
from sqlalchemy import create_engine
logger = get_logger() logger = get_logger()
SUPPORTED_DATA_FORMAT = ['.csv','.xlsx', '.xls','.json','.html','.h5', '.hdf5','.parquet','.feather','.dta'] SUPPORTED_DATA_FORMAT = [".csv", ".xlsx", ".xls", ".json", ".html", ".h5", ".hdf5", ".parquet", ".feather", ".dta"]
class TableLoader: class TableLoader:
''' """
Load tables from different files and serve a sql database for database operations Load tables from different files and serve a sql database for database operations
''' """
def __init__(self, files: str,
sql_path:str='sqlite:///mydatabase.db', def __init__(self, files: str, sql_path: str = "sqlite:///mydatabase.db", verbose=False, **kwargs) -> None:
verbose=False, **kwargs) -> None: """
'''
Args: Args:
files: list of files (list[file path, name]) files: list of files (list[file path, name])
sql_path: how to serve the sql database sql_path: how to serve the sql database
**kwargs: keyword type arguments, useful for certain document types **kwargs: keyword type arguments, useful for certain document types
''' """
self.data = {} self.data = {}
self.verbose = verbose self.verbose = verbose
self.sql_path = sql_path self.sql_path = sql_path
@ -49,58 +50,58 @@ class TableLoader:
self.to_sql(path, dataset_name) self.to_sql(path, dataset_name)
def load_data(self, path): def load_data(self, path):
''' """
Load data and serve the data as sql database. Load data and serve the data as sql database.
Data must be in pandas format Data must be in pandas format
''' """
files = [] files = []
# Handle glob expression # Handle glob expression
try: try:
files = glob.glob(path) files = glob.glob(path)
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
if len(files)==0: if len(files) == 0:
raise ValueError("Unsupported file/directory format. For directories, please use glob expression") raise ValueError("Unsupported file/directory format. For directories, please use glob expression")
elif len(files)==1: elif len(files) == 1:
path = files[0] path = files[0]
else: else:
for file in files: for file in files:
self.load_data(file) self.load_data(file)
if path.endswith('.csv'): if path.endswith(".csv"):
# Load csv # Load csv
self.data[path] = pd.read_csv(path) self.data[path] = pd.read_csv(path)
elif path.endswith('.xlsx') or path.endswith('.xls'): elif path.endswith(".xlsx") or path.endswith(".xls"):
# Load excel # Load excel
self.data[path] = pd.read_excel(path) # You can adjust the sheet_name as needed self.data[path] = pd.read_excel(path) # You can adjust the sheet_name as needed
elif path.endswith('.json'): elif path.endswith(".json"):
# Load json # Load json
self.data[path] = pd.read_json(path) self.data[path] = pd.read_json(path)
elif path.endswith('.html'): elif path.endswith(".html"):
# Load html # Load html
html_tables = pd.read_html(path) html_tables = pd.read_html(path)
# Choose the desired table from the list of DataFrame objects # Choose the desired table from the list of DataFrame objects
self.data[path] = html_tables[0] # You may need to adjust this index self.data[path] = html_tables[0] # You may need to adjust this index
elif path.endswith('.h5') or path.endswith('.hdf5'): elif path.endswith(".h5") or path.endswith(".hdf5"):
# Load h5 # Load h5
self.data[path] = pd.read_hdf(path, key=self.kwargs.get('key', 'data')) # You can adjust the key as needed self.data[path] = pd.read_hdf(path, key=self.kwargs.get("key", "data")) # You can adjust the key as needed
elif path.endswith('.parquet'): elif path.endswith(".parquet"):
# Load parquet # Load parquet
self.data[path] = pd.read_parquet(path, engine='fastparquet') self.data[path] = pd.read_parquet(path, engine="fastparquet")
elif path.endswith('.feather'): elif path.endswith(".feather"):
# Load feather # Load feather
self.data[path] = pd.read_feather(path) self.data[path] = pd.read_feather(path)
elif path.endswith('.dta'): elif path.endswith(".dta"):
# Load dta # Load dta
self.data[path] = pd.read_stata(path) self.data[path] = pd.read_stata(path)
else: else:
raise ValueError("Unsupported file format") raise ValueError("Unsupported file format")
def to_sql(self, path, table_name): def to_sql(self, path, table_name):
''' """
Serve the data as sql database. Serve the data as sql database.
''' """
self.data[path].to_sql(table_name, con=self.sql_engine, if_exists='replace', index=False) self.data[path].to_sql(table_name, con=self.sql_engine, if_exists="replace", index=False)
logger.info(f"Loaded to Sqlite3\nPath: {path}", verbose=self.verbose) logger.info(f"Loaded to Sqlite3\nPath: {path}", verbose=self.verbose)
return self.sql_path return self.sql_path
@ -113,7 +114,3 @@ class TableLoader:
self.sql_engine.dispose() self.sql_engine.dispose()
del self.data del self.data
del self.sql_engine del self.sql_engine

@ -21,7 +21,7 @@ print(resp) # super-heavyweight awesome-natured yawning Australian creature!
""" """
import json import json
from typing import Any, List, Mapping, Optional from typing import Any, Mapping
import requests import requests
from langchain.llms.base import LLM from langchain.llms.base import LLM
@ -33,11 +33,11 @@ class ColossalCloudLLM(LLM):
A custom LLM class that integrates LLMs running on the ColossalCloud Platform A custom LLM class that integrates LLMs running on the ColossalCloud Platform
""" """
n: int n: int
gen_config: dict = None gen_config: dict = None
auth_config: dict = None auth_config: dict = None
valid_gen_para: list = ['max_new_tokens', 'top_k', valid_gen_para: list = ["max_new_tokens", "top_k", "top_p", "temperature", "repetition_penalty"]
'top_p', 'temperature', 'repetition_penalty']
def __init__(self, gen_config=None, **kwargs): def __init__(self, gen_config=None, **kwargs):
""" """
@ -63,15 +63,15 @@ class ColossalCloudLLM(LLM):
@property @property
def _llm_type(self) -> str: def _llm_type(self) -> str:
return 'ColossalCloudLLM' return "ColossalCloudLLM"
def set_auth_config(self, **kwargs): def set_auth_config(self, **kwargs):
url = get_from_dict_or_env(kwargs, "url", "URL") url = get_from_dict_or_env(kwargs, "url", "URL")
host = get_from_dict_or_env(kwargs, "host", "HOST") host = get_from_dict_or_env(kwargs, "host", "HOST")
auth_config = {} auth_config = {}
auth_config['endpoint'] = url auth_config["endpoint"] = url
auth_config['Host'] = host auth_config["Host"] = host
self.auth_config = auth_config self.auth_config = auth_config
def _call(self, prompt: str, stop=None, **kwargs: Any) -> str: def _call(self, prompt: str, stop=None, **kwargs: Any) -> str:
@ -86,7 +86,9 @@ class ColossalCloudLLM(LLM):
# Update the generation arguments # Update the generation arguments
for key, value in kwargs.items(): for key, value in kwargs.items():
if key not in self.valid_gen_para: if key not in self.valid_gen_para:
raise KeyError(f"Invalid generation parameter: '{key}'. Valid keys are: {', '.join(self.valid_gen_para)}") raise KeyError(
f"Invalid generation parameter: '{key}'. Valid keys are: {', '.join(self.valid_gen_para)}"
)
if key in self.gen_config: if key in self.gen_config:
self.gen_config[key] = value self.gen_config[key] = value
@ -98,26 +100,16 @@ class ColossalCloudLLM(LLM):
resp_text = resp_text.split(stopping_words)[0] resp_text = resp_text.split(stopping_words)[0]
return resp_text return resp_text
def text_completion(self, prompt, gen_config, auth_config): def text_completion(self, prompt, gen_config, auth_config):
# Required Parameters # Required Parameters
endpoint = auth_config.pop('endpoint') endpoint = auth_config.pop("endpoint")
max_new_tokens = gen_config.pop('max_new_tokens') max_new_tokens = gen_config.pop("max_new_tokens")
# Optional Parameters # Optional Parameters
optional_params = ['top_k', 'top_p', 'temperature', 'repetition_penalty'] # Self.optional optional_params = ["top_k", "top_p", "temperature", "repetition_penalty"] # Self.optional
gen_config = {key: gen_config[key] for key in optional_params if key in gen_config} gen_config = {key: gen_config[key] for key in optional_params if key in gen_config}
# Define the data payload # Define the data payload
data = { data = {"max_new_tokens": max_new_tokens, "history": [{"instruction": prompt, "response": ""}], **gen_config}
"max_new_tokens": max_new_tokens, headers = {"Content-Type": "application/json", **auth_config} # 'Host',
"history": [
{"instruction": prompt, "response": ""}
],
**gen_config
}
headers = {
"Content-Type": "application/json",
**auth_config # 'Host',
}
# Make the POST request # Make the POST request
response = requests.post(endpoint, headers=headers, data=json.dumps(data)) response = requests.post(endpoint, headers=headers, data=json.dumps(data))
response.raise_for_status() # raise error if return code is not 200(success) response.raise_for_status() # raise error if return code is not 200(success)

@ -193,4 +193,3 @@ class VllmLLM(LLM):
def _identifying_params(self) -> Mapping[str, int]: def _identifying_params(self) -> Mapping[str, int]:
"""Get the identifying parameters.""" """Get the identifying parameters."""
return {"n": self.n} return {"n": self.n}

@ -4,7 +4,6 @@ All custom prompt templates are defined here.
from langchain.prompts.prompt import PromptTemplate from langchain.prompts.prompt import PromptTemplate
# Below are Chinese retrieval qa prompts # Below are Chinese retrieval qa prompts
_CUSTOM_SUMMARIZER_TEMPLATE_ZH = """请递进式地总结所提供的当前对话,将当前对话的摘要内容添加到先前已有的摘要上,返回一个融合了当前对话的新的摘要。 _CUSTOM_SUMMARIZER_TEMPLATE_ZH = """请递进式地总结所提供的当前对话,将当前对话的摘要内容添加到先前已有的摘要上,返回一个融合了当前对话的新的摘要。

@ -99,13 +99,7 @@ class CustomRetriever(BaseRetriever):
def clear_documents(self): def clear_documents(self):
"""Clear all document vectors from database""" """Clear all document vectors from database"""
for source in self.vector_stores: for source in self.vector_stores:
index( index([], self.record_managers[source], self.vector_stores[source], cleanup="full", source_id_key="source")
[],
self.record_managers[source],
self.vector_stores[source],
cleanup="full",
source_id_key="source"
)
self.vector_stores = {} self.vector_stores = {}
self.sql_index_database = {} self.sql_index_database = {}
self.record_managers = {} self.record_managers = {}

@ -1,22 +1,27 @@
import argparse import argparse
from colossalqa.retrieval_conversation_universal import UniversalRetrievalConversation from colossalqa.retrieval_conversation_universal import UniversalRetrievalConversation
if __name__ == '__main__': if __name__ == "__main__":
# Parse arguments # Parse arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--en_model_path', type=str, default=None) parser.add_argument("--en_model_path", type=str, default=None)
parser.add_argument('--zh_model_path', type=str, default=None) parser.add_argument("--zh_model_path", type=str, default=None)
parser.add_argument('--zh_model_name', type=str, default=None) parser.add_argument("--zh_model_name", type=str, default=None)
parser.add_argument('--en_model_name', type=str, default=None) parser.add_argument("--en_model_name", type=str, default=None)
parser.add_argument('--sql_file_path', type=str, default=None, help='path to the a empty folder for storing sql files for indexing') parser.add_argument(
"--sql_file_path", type=str, default=None, help="path to the a empty folder for storing sql files for indexing"
)
args = parser.parse_args() args = parser.parse_args()
# Will ask for documents path in running time # Will ask for documents path in running time
session = UniversalRetrievalConversation(files_en=None, session = UniversalRetrievalConversation(
files_en=None,
files_zh=None, files_zh=None,
zh_model_path=args.zh_model_path, en_model_path=args.en_model_path, zh_model_path=args.zh_model_path,
zh_model_name=args.zh_model_name, en_model_name=args.en_model_name, en_model_path=args.en_model_path,
sql_file_path=args.sql_file_path zh_model_name=args.zh_model_name,
en_model_name=args.en_model_name,
sql_file_path=args.sql_file_path,
) )
session.start_test_session() session.start_test_session()

@ -5,13 +5,7 @@ from colossalqa.chain.retrieval_qa.base import RetrievalQA
from colossalqa.data_loader.document_loader import DocumentLoader from colossalqa.data_loader.document_loader import DocumentLoader
from colossalqa.memory import ConversationBufferWithSummary from colossalqa.memory import ConversationBufferWithSummary
from colossalqa.mylogging import get_logger from colossalqa.mylogging import get_logger
from colossalqa.prompt.prompt import ( from colossalqa.prompt.prompt import ZH_RETRIEVAL_QA_REJECTION_ANSWER, ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS
PROMPT_DISAMBIGUATE_ZH,
PROMPT_RETRIEVAL_QA_ZH,
SUMMARY_PROMPT_ZH,
ZH_RETRIEVAL_QA_REJECTION_ANSWER,
ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS,
)
from colossalqa.retriever import CustomRetriever from colossalqa.retriever import CustomRetriever
from langchain import LLMChain from langchain import LLMChain
from langchain.embeddings import HuggingFaceEmbeddings from langchain.embeddings import HuggingFaceEmbeddings

@ -1,58 +1,30 @@
from colossalqa.prompt.prompt import ( from colossalqa.prompt.prompt import PROMPT_DISAMBIGUATE_ZH, PROMPT_RETRIEVAL_QA_ZH, SUMMARY_PROMPT_ZH
PROMPT_DISAMBIGUATE_ZH,
PROMPT_RETRIEVAL_QA_ZH,
SUMMARY_PROMPT_ZH,
ZH_RETRIEVAL_QA_REJECTION_ANSWER,
ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS,
)
from colossalqa.text_splitter import ChineseTextSplitter from colossalqa.text_splitter import ChineseTextSplitter
ALL_CONFIG = { ALL_CONFIG = {
"embed": { "embed": {
"embed_name": "m3e", # embedding model name "embed_name": "m3e", # embedding model name
"embed_model_name_or_path": "moka-ai/m3e-base", # path to embedding model, could be a local path or a huggingface path "embed_model_name_or_path": "moka-ai/m3e-base", # path to embedding model, could be a local path or a huggingface path
"embed_model_device": { "embed_model_device": {"device": "cpu"},
"device": "cpu"
}
}, },
"model": { "model": {
"mode": "api", # "local" for loading models, "api" for using model api "mode": "api", # "local" for loading models, "api" for using model api
"model_name": "chatgpt_api", # local model name, "chatgpt_api" or "pangu_api" "model_name": "chatgpt_api", # local model name, "chatgpt_api" or "pangu_api"
"model_path": "", # path to the model, could be a local path or a huggingface path. don't need if using an api "model_path": "", # path to the model, could be a local path or a huggingface path. don't need if using an api
"device": { "device": {"device": "cuda"},
"device": "cuda"
}
},
"splitter": {
"name": ChineseTextSplitter
},
"retrieval": {
"retri_top_k": 3,
"retri_kb_file_path": "./", # path to store database files
"verbose": True
}, },
"splitter": {"name": ChineseTextSplitter},
"retrieval": {"retri_top_k": 3, "retri_kb_file_path": "./", "verbose": True}, # path to store database files
"chain": { "chain": {
"mem_summary_prompt": SUMMARY_PROMPT_ZH, # summary prompt template "mem_summary_prompt": SUMMARY_PROMPT_ZH, # summary prompt template
"mem_human_prefix": "用户", "mem_human_prefix": "用户",
"mem_ai_prefix": "Assistant", "mem_ai_prefix": "Assistant",
"mem_max_tokens": 2000, "mem_max_tokens": 2000,
"mem_llm_kwargs": { "mem_llm_kwargs": {"max_new_tokens": 50, "temperature": 1, "do_sample": True},
"max_new_tokens": 50,
"temperature": 1,
"do_sample": True
},
"disambig_prompt": PROMPT_DISAMBIGUATE_ZH, # disambiguate prompt template "disambig_prompt": PROMPT_DISAMBIGUATE_ZH, # disambiguate prompt template
"disambig_llm_kwargs": { "disambig_llm_kwargs": {"max_new_tokens": 30, "temperature": 1, "do_sample": True},
"max_new_tokens": 30, "gen_llm_kwargs": {"max_new_tokens": 100, "temperature": 1, "do_sample": True},
"temperature": 1,
"do_sample": True
},
"gen_llm_kwargs": {
"max_new_tokens": 100,
"temperature": 1,
"do_sample": True
},
"gen_qa_prompt": PROMPT_RETRIEVAL_QA_ZH, # generation prompt template "gen_qa_prompt": PROMPT_RETRIEVAL_QA_ZH, # generation prompt template
"verbose": True "verbose": True,
} },
} }

@ -1,27 +1,18 @@
import argparse import argparse
import os
from typing import List, Union from typing import List, Union
import config
import uvicorn
from colossalqa.local.llm import ColossalAPI, ColossalLLM from colossalqa.local.llm import ColossalAPI, ColossalLLM
from colossalqa.data_loader.document_loader import DocumentLoader
from colossalqa.mylogging import get_logger from colossalqa.mylogging import get_logger
from colossalqa.retrieval_conversation_zh import ChineseRetrievalConversation
from colossalqa.retriever import CustomRetriever
from enum import Enum
from fastapi import FastAPI, Request from fastapi import FastAPI, Request
from langchain.embeddings import HuggingFaceEmbeddings from pydantic import BaseModel
from langchain.text_splitter import RecursiveCharacterTextSplitter
from pydantic import BaseModel, Field
import uvicorn
import config
from RAG_ChatBot import RAG_ChatBot from RAG_ChatBot import RAG_ChatBot
from utils import DocAction from utils import DocAction
logger = get_logger() logger = get_logger()
def parseArgs(): def parseArgs():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--http_host", default="0.0.0.0") parser.add_argument("--http_host", default="0.0.0.0")
@ -36,6 +27,7 @@ class DocUpdateReq(BaseModel):
doc_files: Union[List[str], str, None] = None doc_files: Union[List[str], str, None] = None
action: DocAction = DocAction.ADD action: DocAction = DocAction.ADD
class GenerationTaskReq(BaseModel): class GenerationTaskReq(BaseModel):
user_input: str user_input: str
@ -45,7 +37,7 @@ def update_docs(data: DocUpdateReq, request: Request):
if data.action == "add": if data.action == "add":
if isinstance(data.doc_files, str): if isinstance(data.doc_files, str):
data.doc_files = [data.doc_files] data.doc_files = [data.doc_files]
chatbot.load_doc_from_files(files = data.doc_files) chatbot.load_doc_from_files(files=data.doc_files)
all_docs = "" all_docs = ""
for doc in chatbot.docs_names: for doc in chatbot.docs_names:
all_docs += f"\t{doc}\n\n" all_docs += f"\t{doc}\n\n"
@ -84,12 +76,13 @@ if __name__ == "__main__":
"user": "User", "user": "User",
"max_tokens": all_config["chain"]["disambig_llm_kwargs"]["max_new_tokens"], "max_tokens": all_config["chain"]["disambig_llm_kwargs"]["max_new_tokens"],
"temperature": all_config["chain"]["disambig_llm_kwargs"]["temperature"], "temperature": all_config["chain"]["disambig_llm_kwargs"]["temperature"],
"n": 1 # the number of responses generated "n": 1, # the number of responses generated
} }
llm = Pangu(gen_config=gen_config) llm = Pangu(gen_config=gen_config)
llm.set_auth_config() # verify user's auth info here llm.set_auth_config() # verify user's auth info here
elif model_name == "chatgpt_api": elif model_name == "chatgpt_api":
from langchain.llms import OpenAI from langchain.llms import OpenAI
llm = OpenAI() llm = OpenAI()
else: else:
raise ValueError("Unsupported mode.") raise ValueError("Unsupported mode.")

@ -1,24 +1,26 @@
import argparse import argparse
import json import json
import os import os
import requests
import gradio as gr import gradio as gr
import requests
from utils import DocAction from utils import DocAction
def parseArgs(): def parseArgs():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--http_host", default="0.0.0.0") parser.add_argument("--http_host", default="0.0.0.0")
parser.add_argument("--http_port", type=int, default=13666) parser.add_argument("--http_port", type=int, default=13666)
return parser.parse_args() return parser.parse_args()
def get_response(data, url): def get_response(data, url):
headers = {"Content-type": "application/json"} headers = {"Content-type": "application/json"}
response = requests.post(url, json=data, headers=headers) response = requests.post(url, json=data, headers=headers)
response = json.loads(response.content) response = json.loads(response.content)
return response return response
def add_text(history, text): def add_text(history, text):
history = history + [(text, None)] history = history + [(text, None)]
return history, gr.update(value=None, interactive=True) return history, gr.update(value=None, interactive=True)
@ -28,18 +30,14 @@ def add_file(history, files):
files_string = "\n".join([os.path.basename(file.name) for file in files]) files_string = "\n".join([os.path.basename(file.name) for file in files])
doc_files = [file.name for file in files] doc_files = [file.name for file in files]
data = { data = {"doc_files": doc_files, "action": DocAction.ADD}
"doc_files": doc_files,
"action": DocAction.ADD
}
response = get_response(data, update_url)["response"] response = get_response(data, update_url)["response"]
history = history + [(files_string, response)] history = history + [(files_string, response)]
return history return history
def bot(history): def bot(history):
data = { data = {"user_input": history[-1][0].strip()}
"user_input": history[-1][0].strip()
}
response = get_response(data, gen_url) response = get_response(data, gen_url)
if response["error"] != "": if response["error"] != "":
@ -51,11 +49,8 @@ def bot(history):
def restart(chatbot, txt): def restart(chatbot, txt):
# Reset the conversation state and clear the chat history # Reset the conversation state and clear the chat history
data = { data = {"doc_files": "", "action": DocAction.CLEAR}
"doc_files": "", get_response(data, update_url)
"action": DocAction.CLEAR
}
response = get_response(data, update_url)
return gr.update(value=None), gr.update(value=None, interactive=True) return gr.update(value=None), gr.update(value=None, interactive=True)

@ -1,21 +1,21 @@
import os import os
from colossalqa.data_loader.document_loader import DocumentLoader from colossalqa.data_loader.document_loader import DocumentLoader
def test_add_document(): def test_add_document():
PATH = os.environ.get('TEST_DOCUMENT_LOADER_DATA_PATH') PATH = os.environ.get("TEST_DOCUMENT_LOADER_DATA_PATH")
files = [[PATH, 'all data']] files = [[PATH, "all data"]]
document_loader = DocumentLoader(files) document_loader = DocumentLoader(files)
documents = document_loader.all_data documents = document_loader.all_data
all_files = [] all_files = []
for doc in documents: for doc in documents:
assert isinstance(doc.page_content, str)==True assert isinstance(doc.page_content, str) == True
if doc.metadata['source'] not in all_files: if doc.metadata["source"] not in all_files:
all_files.append(doc.metadata['source']) all_files.append(doc.metadata["source"])
print(all_files) print(all_files)
assert len(all_files) == 6 assert len(all_files) == 6
if __name__=='__main__': if __name__ == "__main__":
test_add_document() test_add_document()

@ -4,56 +4,44 @@ from colossalqa.retrieval_conversation_universal import UniversalRetrievalConver
def test_en_retrievalQA(): def test_en_retrievalQA():
data_path_en = os.environ.get('TEST_DATA_PATH_EN') data_path_en = os.environ.get("TEST_DATA_PATH_EN")
data_path_zh = os.environ.get('TEST_DATA_PATH_ZH') data_path_zh = os.environ.get("TEST_DATA_PATH_ZH")
en_model_path = os.environ.get('EN_MODEL_PATH') en_model_path = os.environ.get("EN_MODEL_PATH")
zh_model_path = os.environ.get('ZH_MODEL_PATH') zh_model_path = os.environ.get("ZH_MODEL_PATH")
zh_model_name = os.environ.get('ZH_MODEL_NAME') zh_model_name = os.environ.get("ZH_MODEL_NAME")
en_model_name = os.environ.get('EN_MODEL_NAME') en_model_name = os.environ.get("EN_MODEL_NAME")
sql_file_path = os.environ.get('SQL_FILE_PATH') sql_file_path = os.environ.get("SQL_FILE_PATH")
qa_session = UniversalRetrievalConversation(files_en=[{ qa_session = UniversalRetrievalConversation(
'data_path': data_path_en, files_en=[{"data_path": data_path_en, "name": "company information", "separator": "\n"}],
'name': 'company information', files_zh=[{"data_path": data_path_zh, "name": "company information", "separator": "\n"}],
'separator': '\n'
}],
files_zh=[{
'data_path': data_path_zh,
'name': 'company information',
'separator': '\n'
}],
zh_model_path=zh_model_path, zh_model_path=zh_model_path,
en_model_path=en_model_path, en_model_path=en_model_path,
zh_model_name=zh_model_name, zh_model_name=zh_model_name,
en_model_name=en_model_name, en_model_name=en_model_name,
sql_file_path=sql_file_path) sql_file_path=sql_file_path,
ans = qa_session.run("which company runs business in hotel industry?", which_language='en') )
ans = qa_session.run("which company runs business in hotel industry?", which_language="en")
print(ans) print(ans)
def test_zh_retrievalQA(): def test_zh_retrievalQA():
data_path_en = os.environ.get('TEST_DATA_PATH_EN') data_path_en = os.environ.get("TEST_DATA_PATH_EN")
data_path_zh = os.environ.get('TEST_DATA_PATH_ZH') data_path_zh = os.environ.get("TEST_DATA_PATH_ZH")
en_model_path = os.environ.get('EN_MODEL_PATH') en_model_path = os.environ.get("EN_MODEL_PATH")
zh_model_path = os.environ.get('ZH_MODEL_PATH') zh_model_path = os.environ.get("ZH_MODEL_PATH")
zh_model_name = os.environ.get('ZH_MODEL_NAME') zh_model_name = os.environ.get("ZH_MODEL_NAME")
en_model_name = os.environ.get('EN_MODEL_NAME') en_model_name = os.environ.get("EN_MODEL_NAME")
sql_file_path = os.environ.get('SQL_FILE_PATH') sql_file_path = os.environ.get("SQL_FILE_PATH")
qa_session = UniversalRetrievalConversation(files_en=[{ qa_session = UniversalRetrievalConversation(
'data_path': data_path_en, files_en=[{"data_path": data_path_en, "name": "company information", "separator": "\n"}],
'name': 'company information', files_zh=[{"data_path": data_path_zh, "name": "company information", "separator": "\n"}],
'separator': '\n'
}],
files_zh=[{
'data_path': data_path_zh,
'name': 'company information',
'separator': '\n'
}],
zh_model_path=zh_model_path, zh_model_path=zh_model_path,
en_model_path=en_model_path, en_model_path=en_model_path,
zh_model_name=zh_model_name, zh_model_name=zh_model_name,
en_model_name=en_model_name, en_model_name=en_model_name,
sql_file_path=sql_file_path) sql_file_path=sql_file_path,
ans = qa_session.run("哪家公司在经营酒店业务?", which_language='zh') )
ans = qa_session.run("哪家公司在经营酒店业务?", which_language="zh")
print(ans) print(ans)

@ -1,5 +1,5 @@
from .initialize import launch, launch_from_openmpi, launch_from_slurm, launch_from_torch
from . import accelerator from . import accelerator
from .initialize import launch, launch_from_openmpi, launch_from_slurm, launch_from_torch
try: try:
# .version will be created by setup.py # .version will be created by setup.py

@ -27,7 +27,7 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO, utils, CheckpointIndexFile from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO, utils
from colossalai.cluster import DistCoordinator from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.interface import ModelWrapper, OptimizerWrapper
@ -93,9 +93,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
Path(checkpoint_path).mkdir(parents=True, exist_ok=True) Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
with FSDP.state_dict_type( with FSDP.state_dict_type(
model.unwrap(), model.unwrap(), StateDictType.FULL_STATE_DICT, FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
StateDictType.FULL_STATE_DICT,
FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
): ):
state_dict = model.unwrap().state_dict() state_dict = model.unwrap().state_dict()
@ -172,7 +170,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
with FSDP.state_dict_type( with FSDP.state_dict_type(
optimizer.unwrap_model().unwrap(), optimizer.unwrap_model().unwrap(),
StateDictType.FULL_STATE_DICT, StateDictType.FULL_STATE_DICT,
FullStateDictConfig(offload_to_cpu=True, rank0_only=True) FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
): ):
fsdp_optim_state = FSDP.full_optim_state_dict( fsdp_optim_state = FSDP.full_optim_state_dict(
optimizer.unwrap_model().unwrap(), optim=optimizer, rank0_only=True optimizer.unwrap_model().unwrap(), optim=optimizer, rank0_only=True
@ -241,7 +239,6 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
) )
optimizer.load_state_dict(fsdp_state) optimizer.load_state_dict(fsdp_state)
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
""" """
Save model to checkpoint but only on master process. Save model to checkpoint but only on master process.

@ -294,6 +294,7 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
# Helper functions for saving state dict # Helper functions for saving state dict
# ====================================== # ======================================
def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None: def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None:
""" """
Save state dict to checkpoint. Save state dict to checkpoint.

@ -174,9 +174,13 @@ class ProcessGroupMesh:
List[Tuple[int, ...]]: Coordinates along the axis. List[Tuple[int, ...]]: Coordinates along the axis.
""" """
if isinstance(axis, int): if isinstance(axis, int):
axis = [axis,] axis = [
axis,
]
assert isinstance(indices_at_axis[0], int) assert isinstance(indices_at_axis[0], int)
indices_at_axis = [indices_at_axis,] indices_at_axis = [
indices_at_axis,
]
def add_index(base_coord, axis, indices_at_axis): def add_index(base_coord, axis, indices_at_axis):
coords_in_group = [] coords_in_group = []
@ -194,7 +198,10 @@ class ProcessGroupMesh:
return coords_in_group return coords_in_group
def create_group_along_axis( def create_group_along_axis(
self, axis: Union[int, List[int]], indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None, backend: Optional[str] = None self,
axis: Union[int, List[int]],
indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None,
backend: Optional[str] = None,
) -> ProcessGroup: ) -> ProcessGroup:
"""Create all process groups along the given axis, and return the one which the current process belongs to. """Create all process groups along the given axis, and return the one which the current process belongs to.
@ -207,10 +214,14 @@ class ProcessGroupMesh:
ProcessGroup: The process group along the given axis which the current process belongs to. ProcessGroup: The process group along the given axis which the current process belongs to.
""" """
if isinstance(axis, int): if isinstance(axis, int):
axis = [axis,] axis = [
axis,
]
if indices_at_axis is not None: if indices_at_axis is not None:
assert isinstance(indices_at_axis[0], int) assert isinstance(indices_at_axis[0], int)
indices_at_axis = [indices_at_axis,] indices_at_axis = [
indices_at_axis,
]
indices_at_axis = indices_at_axis or [list(range(self._shape[ax])) for ax in axis] indices_at_axis = indices_at_axis or [list(range(self._shape[ax])) for ax in axis]
reduced_shape = list(self._shape) reduced_shape = list(self._shape)

@ -29,13 +29,17 @@ except:
try: try:
from colossalai.kernel.triton.flash_decoding import token_flash_decoding from colossalai.kernel.triton.flash_decoding import token_flash_decoding
HAS_TRITON_FLASH_DECODING_KERNEL = True HAS_TRITON_FLASH_DECODING_KERNEL = True
except: except:
print("no triton flash decoding support, please install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8") print(
"no triton flash decoding support, please install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8"
)
HAS_TRITON_FLASH_DECODING_KERNEL = False HAS_TRITON_FLASH_DECODING_KERNEL = False
try: try:
from flash_attn import flash_attn_with_kvcache from flash_attn import flash_attn_with_kvcache
HAS_FLASH_KERNEL = True HAS_FLASH_KERNEL = True
except: except:
HAS_FLASH_KERNEL = False HAS_FLASH_KERNEL = False
@ -48,6 +52,7 @@ def rotate_half(x):
x2 = x[..., x.shape[-1] // 2 :] x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1) return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids): def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them. # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
@ -96,15 +101,20 @@ def llama_triton_context_attention(
infer_state.max_len_in_batch, infer_state.max_len_in_batch,
) )
def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1, q_head_num = -1, head_dim = -1):
def llama_triton_token_attention(
query_states, attn_output, infer_state, num_key_value_groups=1, q_head_num=-1, head_dim=-1
):
if HAS_TRITON_FLASH_DECODING_KERNEL and q_head_num != -1 and head_dim != -1: if HAS_TRITON_FLASH_DECODING_KERNEL and q_head_num != -1 and head_dim != -1:
token_flash_decoding(q = query_states, token_flash_decoding(
o_tensor = attn_output, q=query_states,
infer_state = infer_state, o_tensor=attn_output,
q_head_num = q_head_num, infer_state=infer_state,
head_dim = head_dim, q_head_num=q_head_num,
cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], head_dim=head_dim,
cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id]) cache_k=infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
cache_v=infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
)
return return
if num_key_value_groups == 1: if num_key_value_groups == 1:
@ -459,14 +469,15 @@ class LlamaInferenceForwards:
) )
if HAS_LIGHTLLM_KERNEL: if HAS_LIGHTLLM_KERNEL:
attn_output = torch.empty_like(query_states) attn_output = torch.empty_like(query_states)
llama_triton_token_attention(query_states = query_states, llama_triton_token_attention(
attn_output = attn_output, query_states=query_states,
infer_state = infer_state, attn_output=attn_output,
num_key_value_groups = self.num_key_value_groups, infer_state=infer_state,
q_head_num = q_len * self.num_heads, num_key_value_groups=self.num_key_value_groups,
head_dim = self.head_dim) q_head_num=q_len * self.num_heads,
head_dim=self.head_dim,
)
else: else:
self.num_heads // self.num_key_value_heads self.num_heads // self.num_key_value_heads
cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id] cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id]

@ -18,15 +18,15 @@ from .gptq_op import CaiGPTQLinearOp
HAS_GPTQ_CUDA = False HAS_GPTQ_CUDA = False
try: try:
from colossalai.kernel.op_builder.gptq import GPTQBuilder from colossalai.kernel.op_builder.gptq import GPTQBuilder
gptq_cuda = GPTQBuilder().load() gptq_cuda = GPTQBuilder().load()
HAS_GPTQ_CUDA = True HAS_GPTQ_CUDA = True
except ImportError: except ImportError:
warnings.warn('CUDA gptq is not installed') warnings.warn("CUDA gptq is not installed")
HAS_GPTQ_CUDA = False HAS_GPTQ_CUDA = False
class CaiQuantLinear(nn.Module): class CaiQuantLinear(nn.Module):
def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False): def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
super().__init__() super().__init__()
if bits not in [2, 4, 8]: if bits not in [2, 4, 8]:
@ -37,23 +37,28 @@ class CaiQuantLinear(nn.Module):
self.maxq = 2**self.bits - 1 self.maxq = 2**self.bits - 1
self.groupsize = groupsize if groupsize != -1 else infeatures self.groupsize = groupsize if groupsize != -1 else infeatures
self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)) self.register_buffer("qweight", torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
self.register_buffer(
"qzeros",
torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32),
)
self.register_buffer( self.register_buffer(
'qzeros', "scales", torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)
torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32)) )
self.register_buffer('scales',
torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16))
if row_split: if row_split:
self.register_buffer( self.register_buffer(
'g_idx', "g_idx",
torch.tensor([(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)], torch.tensor(
dtype=torch.int32)) [(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)], dtype=torch.int32
),
)
else: else:
self.register_buffer('g_idx', self.register_buffer(
torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)) "g_idx", torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)
)
if bias: if bias:
self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16)) self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float16))
else: else:
self.bias = None self.bias = None
@ -66,9 +71,11 @@ class CaiQuantLinear(nn.Module):
self.row_split = row_split self.row_split = row_split
def pack(self, linear, scales, zeros, g_idx=None): def pack(self, linear, scales, zeros, g_idx=None):
g_idx = (
g_idx = g_idx.clone() if g_idx is not None else torch.tensor( g_idx.clone()
[i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32) if g_idx is not None
else torch.tensor([i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32)
)
scales = scales.t().contiguous() scales = scales.t().contiguous()
zeros = zeros.t().contiguous() zeros = zeros.t().contiguous()
@ -79,7 +86,6 @@ class CaiQuantLinear(nn.Module):
if linear.bias is not None: if linear.bias is not None:
self.bias = linear.bias.clone().half() self.bias = linear.bias.clone().half()
wn = 8
pbits = 32 pbits = 32
ptype = torch.int32 ptype = torch.int32
unsign_type = np.uint32 unsign_type = np.uint32
@ -88,9 +94,10 @@ class CaiQuantLinear(nn.Module):
intweight = [] intweight = []
for idx in range(self.infeatures): for idx in range(self.infeatures):
intweight.append( intweight.append(
torch.round( torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[
(linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[:, :, None
None]) ]
)
intweight = torch.cat(intweight, dim=1) intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous() intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(unsign_type) intweight = intweight.numpy().astype(unsign_type)
@ -109,7 +116,7 @@ class CaiQuantLinear(nn.Module):
raise NotImplementedError("Only 2,4,8 bits are supported.") raise NotImplementedError("Only 2,4,8 bits are supported.")
qweight = qweight.astype(sign_type) qweight = qweight.astype(sign_type)
qweight1 = torch.from_numpy(qweight) qweight1 = torch.from_numpy(qweight)
qweight1 = qweight1.contiguous() #.to("cuda") qweight1 = qweight1.contiguous() # .to("cuda")
self.qweight.data.copy_(qweight1) self.qweight.data.copy_(qweight1)
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // pbits * self.bits), dtype=unsign_type) qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // pbits * self.bits), dtype=unsign_type)
@ -144,13 +151,16 @@ class CaiQuantLinear(nn.Module):
torch.tensor( torch.tensor(
[(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)], [(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)],
dtype=torch.int32, dtype=torch.int32,
device=self.g_idx.device)): device=self.g_idx.device,
),
):
self.g_idx = None self.g_idx = None
elif torch.equal( elif torch.equal(
self.g_idx, self.g_idx,
torch.tensor([i // self.groupsize for i in range(self.infeatures)], torch.tensor(
dtype=torch.int32, [i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32, device=self.g_idx.device
device=self.g_idx.device)): ),
):
self.g_idx = None self.g_idx = None
if self.g_idx is not None: if self.g_idx is not None:
@ -165,7 +175,6 @@ class CaiQuantLinear(nn.Module):
outshape = x.shape[:-1] + (self.outfeatures,) outshape = x.shape[:-1] + (self.outfeatures,)
if HAS_GPTQ_CUDA and self.bits == 4: if HAS_GPTQ_CUDA and self.bits == 4:
if self.q4 is None: if self.q4 is None:
self.init_q4() self.init_q4()
@ -191,7 +200,6 @@ class CaiQuantLinear(nn.Module):
def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1): def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1):
qweights = gptq_linear.qweight.split(gptq_linear.out_features // split_num, dim=-1) qweights = gptq_linear.qweight.split(gptq_linear.out_features // split_num, dim=-1)
qzeros = gptq_linear.qzeros.split(gptq_linear.out_features // (32 // cai_linear.bits) // split_num, dim=-1) qzeros = gptq_linear.qzeros.split(gptq_linear.out_features // (32 // cai_linear.bits) // split_num, dim=-1)
scales = gptq_linear.scales.split(gptq_linear.out_features // split_num, dim=-1) scales = gptq_linear.scales.split(gptq_linear.out_features // split_num, dim=-1)
@ -203,24 +211,24 @@ def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1
zero_split_block = cai_linear.outfeatures // (32 // cai_linear.bits) // split_num zero_split_block = cai_linear.outfeatures // (32 // cai_linear.bits) // split_num
for i in range(split_num): for i in range(split_num):
cai_linear.qweight[:, i * cai_split_out_features:(i + 1) * cai_linear.qweight[:, i * cai_split_out_features : (i + 1) * cai_split_out_features] = qweights[i][
cai_split_out_features] = qweights[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) * :, tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features
cai_split_out_features] ]
cai_linear.qzeros[:, i * zero_split_block:(i + 1) * cai_linear.qzeros[:, i * zero_split_block : (i + 1) * zero_split_block] = qzeros[i][
zero_split_block] = qzeros[i][:, tp_rank * zero_split_block:(tp_rank + 1) * zero_split_block] :, tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block
cai_linear.scales[:, i * cai_split_out_features:(i + 1) * ]
cai_split_out_features] = scales[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) * cai_linear.scales[:, i * cai_split_out_features : (i + 1) * cai_split_out_features] = scales[i][
cai_split_out_features] :, tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features
]
if cai_linear.bias is not None: if cai_linear.bias is not None:
cai_linear.bias[i * cai_split_out_features:(i + 1) * cai_linear.bias[i * cai_split_out_features : (i + 1) * cai_split_out_features] = bias[i][
cai_split_out_features] = bias[i][tp_rank * cai_split_out_features:(tp_rank + 1) * tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features
cai_split_out_features] ]
cai_linear.g_idx.copy_(g_idx) cai_linear.g_idx.copy_(g_idx)
def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1): def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1):
qweights = gptq_linear.qweight.split(gptq_linear.in_features // split_num, dim=0) qweights = gptq_linear.qweight.split(gptq_linear.in_features // split_num, dim=0)
qzeros = gptq_linear.qzeros.split(gptq_linear.in_features // split_num, dim=0) qzeros = gptq_linear.qzeros.split(gptq_linear.in_features // split_num, dim=0)
scales = gptq_linear.scales.split(gptq_linear.in_features // split_num, dim=0) scales = gptq_linear.scales.split(gptq_linear.in_features // split_num, dim=0)
@ -231,47 +239,40 @@ def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1):
idx_split_features = cai_linear.infeatures // split_num idx_split_features = cai_linear.infeatures // split_num
for i in range(split_num): for i in range(split_num):
cai_linear.qweight[i * cai_split_in_features:(i + 1) * cai_linear.qweight[i * cai_split_in_features : (i + 1) * cai_split_in_features, :] = qweights[i][
cai_split_in_features, :] = qweights[i][tp_rank * cai_split_in_features:(tp_rank + 1) * tp_rank * cai_split_in_features : (tp_rank + 1) * cai_split_in_features, :
cai_split_in_features, :] ]
cai_linear.qzeros[i * zero_split_block:(i + 1) * cai_linear.qzeros[i * zero_split_block : (i + 1) * zero_split_block, :] = qzeros[i][
zero_split_block, :] = qzeros[i][tp_rank * zero_split_block:(tp_rank + 1) * tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block, :
zero_split_block, :] ]
cai_linear.scales[i * zero_split_block:(i + 1) * cai_linear.scales[i * zero_split_block : (i + 1) * zero_split_block, :] = scales[i][
zero_split_block, :] = scales[i][tp_rank * zero_split_block:(tp_rank + 1) * tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block, :
zero_split_block, :] ]
cai_linear.g_idx[i * idx_split_features:(i + 1) * cai_linear.g_idx[i * idx_split_features : (i + 1) * idx_split_features] = g_idxs[i][
idx_split_features] = g_idxs[i][tp_rank * idx_split_features:(tp_rank + 1) * tp_rank * idx_split_features : (tp_rank + 1) * idx_split_features
idx_split_features] ]
if cai_linear.bias is not None: if cai_linear.bias is not None:
cai_linear.bias.copy_(gptq_linear.bias) cai_linear.bias.copy_(gptq_linear.bias)
class RowCaiQuantLinear(CaiQuantLinear, ParallelModule): class RowCaiQuantLinear(CaiQuantLinear, ParallelModule):
def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False): def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
super().__init__(
super().__init__(bits, bits, groupsize, infeatures, outfeatures, bias, tp_size=tp_size, tp_rank=tp_rank, row_split=row_split
groupsize, )
infeatures,
outfeatures,
bias,
tp_size=tp_size,
tp_rank=tp_rank,
row_split=row_split)
self.process_group = None self.process_group = None
@staticmethod @staticmethod
def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, def from_native_module(
**kwargs) -> ParallelModule: module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> ParallelModule:
LazyInitContext.materialize(module) LazyInitContext.materialize(module)
# get the attributes # get the attributes
in_features = module.in_features in_features = module.in_features
# ensure only one process group is passed # ensure only one process group is passed
if isinstance(process_group, (list, tuple)): if isinstance(process_group, (list, tuple)):
assert len(process_group) == 1, \ assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
f'Expected only one process group, got {len(process_group)}.'
process_group = process_group[0] process_group = process_group[0]
tp_size = dist.get_world_size(process_group) tp_size = dist.get_world_size(process_group)
@ -282,15 +283,18 @@ class RowCaiQuantLinear(CaiQuantLinear, ParallelModule):
if in_features % tp_size != 0: if in_features % tp_size != 0:
raise ValueError( raise ValueError(
f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!") f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!"
linear_1d = RowCaiQuantLinear(module.bits, )
linear_1d = RowCaiQuantLinear(
module.bits,
module.group_size, module.group_size,
module.in_features // tp_size, module.in_features // tp_size,
module.out_features, module.out_features,
module.bias is not None, module.bias is not None,
tp_size=tp_size, tp_size=tp_size,
tp_rank=tp_rank, tp_rank=tp_rank,
row_split=True) row_split=True,
)
linear_1d.process_group = process_group linear_1d.process_group = process_group
split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs) split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
@ -306,30 +310,23 @@ class RowCaiQuantLinear(CaiQuantLinear, ParallelModule):
class ColCaiQuantLinear(CaiQuantLinear, ParallelModule): class ColCaiQuantLinear(CaiQuantLinear, ParallelModule):
def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False): def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
super().__init__(
super().__init__(bits, bits, groupsize, infeatures, outfeatures, bias, tp_size=tp_size, tp_rank=tp_rank, row_split=row_split
groupsize, )
infeatures,
outfeatures,
bias,
tp_size=tp_size,
tp_rank=tp_rank,
row_split=row_split)
self.process_group = None self.process_group = None
@staticmethod @staticmethod
def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, def from_native_module(
**kwargs) -> ParallelModule: module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> ParallelModule:
LazyInitContext.materialize(module) LazyInitContext.materialize(module)
# get the attributes # get the attributes
in_features = module.in_features in_features = module.in_features
# ensure only one process group is passed # ensure only one process group is passed
if isinstance(process_group, (list, tuple)): if isinstance(process_group, (list, tuple)):
assert len(process_group) == 1, \ assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
f'Expected only one process group, got {len(process_group)}.'
process_group = process_group[0] process_group = process_group[0]
tp_size = dist.get_world_size(process_group) tp_size = dist.get_world_size(process_group)
@ -340,14 +337,17 @@ class ColCaiQuantLinear(CaiQuantLinear, ParallelModule):
if in_features % tp_size != 0: if in_features % tp_size != 0:
raise ValueError( raise ValueError(
f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!") f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!"
linear_1d = ColCaiQuantLinear(module.bits, )
linear_1d = ColCaiQuantLinear(
module.bits,
module.group_size, module.group_size,
module.in_features, module.in_features,
module.out_features // tp_size, module.out_features // tp_size,
module.bias is not None, module.bias is not None,
tp_size=tp_size, tp_size=tp_size,
tp_rank=tp_rank) tp_rank=tp_rank,
)
linear_1d.process_group = process_group linear_1d.process_group = process_group
split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs) split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)

@ -5,6 +5,7 @@ import torch
try: try:
import triton import triton
import triton.language as tl import triton.language as tl
HAS_TRITON = True HAS_TRITON = True
except ImportError: except ImportError:
HAS_TRITON = False HAS_TRITON = False
@ -16,6 +17,7 @@ if HAS_TRITON:
https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10 https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10
""" """
if triton.__version__ < "2.1.0": if triton.__version__ < "2.1.0":
@triton.jit @triton.jit
def _context_flash_attention_kernel( def _context_flash_attention_kernel(
Q, Q,
@ -131,23 +133,41 @@ if HAS_TRITON:
m_i = m_i_new m_i = m_i_new
off_o = ( off_o = (
(cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od (cur_batch_start_index + offs_m[:, None]) * stride_obs
+ cur_head * stride_oh
+ offs_d[None, :] * stride_od
) )
out_ptrs = Out + off_o out_ptrs = Out + off_o
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
return return
else: else:
# this function is modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L11 # this function is modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L11
@triton.jit @triton.jit
def _context_flash_attention_kernel_2( def _context_flash_attention_kernel_2(
Q, K, V, sm_scale, Alibi, B_Start_Loc, B_Seqlen, Q,
K,
V,
sm_scale,
Alibi,
B_Start_Loc,
B_Seqlen,
Out, Out,
kv_group_num, kv_group_num,
stride_qbs, stride_qh, stride_qd, stride_qbs,
stride_kbs, stride_kh, stride_kd, stride_qh,
stride_vbs, stride_vh, stride_vd, stride_qd,
stride_obs, stride_oh, stride_od, stride_kbs,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_N: tl.constexpr,
): ):
cur_batch = tl.program_id(0) cur_batch = tl.program_id(0)
@ -166,7 +186,11 @@ if HAS_TRITON:
offs_n = tl.arange(0, BLOCK_N) offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL) offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd off_q = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
+ cur_head * stride_qh
+ offs_d[None, :] * stride_qd
)
if kv_group_num is None or kv_group_num == 1: if kv_group_num is None or kv_group_num == 1:
off_k = offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd off_k = offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd
off_v = offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd off_v = offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd
@ -191,8 +215,11 @@ if HAS_TRITON:
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N) start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ---- # -- compute qk ----
k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, k = tl.load(
mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0) k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :]) < cur_batch_seq_len,
other=0.0,
)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k) qk += tl.dot(q, k)
@ -220,8 +247,11 @@ if HAS_TRITON:
acc_scale = l_i / l_i_new * alpha acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None] acc = acc * acc_scale[:, None]
# update acc # update acc
v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, v = tl.load(
mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0) v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,
other=0.0,
)
p = p.to(v.dtype) p = p.to(v.dtype)
acc += tl.dot(p, v) acc += tl.dot(p, v)
@ -229,7 +259,11 @@ if HAS_TRITON:
l_i = l_i_new l_i = l_i_new
m_i = m_i_new m_i = m_i_new
# initialize pointers to output # initialize pointers to output
off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od off_o = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
+ cur_head * stride_oh
+ offs_d[None, :] * stride_od
)
out_ptrs = Out + off_o out_ptrs = Out + off_o
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
return return
@ -286,7 +320,13 @@ if HAS_TRITON:
) )
else: else:
_context_flash_attention_kernel_2[grid]( _context_flash_attention_kernel_2[grid](
q, k, v, sm_scale, alibi, b_start_loc, b_seq_len, q,
k,
v,
sm_scale,
alibi,
b_start_loc,
b_seq_len,
o, o,
None, None,
q.stride(0), q.stride(0),
@ -388,6 +428,7 @@ if HAS_TRITON:
BLOCK_DMODEL=Lk, BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK, BLOCK_N=BLOCK,
num_warps=num_warps, num_warps=num_warps,
num_stages=1,) num_stages=1,
)
return return

@ -1,8 +1,10 @@
# adepted from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8/lightllm/models/llama/triton_kernel/flash_decoding.py # adepted from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8/lightllm/models/llama/triton_kernel/flash_decoding.py
import torch import torch
try: try:
from lightllm.models.llama.triton_kernel.flash_decoding_stage1 import flash_decode_stage1 from lightllm.models.llama.triton_kernel.flash_decoding_stage1 import flash_decode_stage1
from lightllm.models.llama.triton_kernel.flash_decoding_stage2 import flash_decode_stage2 from lightllm.models.llama.triton_kernel.flash_decoding_stage2 import flash_decode_stage2
HAS_LIGHTLLM_KERNEL = True HAS_LIGHTLLM_KERNEL = True
except: except:
print("install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8") print("install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8")
@ -10,31 +12,29 @@ except:
if HAS_LIGHTLLM_KERNEL: if HAS_LIGHTLLM_KERNEL:
def token_flash_decoding(q, o_tensor, infer_state, q_head_num, head_dim, cache_k, cache_v): def token_flash_decoding(q, o_tensor, infer_state, q_head_num, head_dim, cache_k, cache_v):
BLOCK_SEQ = 256 BLOCK_SEQ = 256
batch_size = infer_state.batch_size batch_size = infer_state.batch_size
max_len_in_batch = infer_state.max_len_in_batch max_len_in_batch = infer_state.max_len_in_batch
calcu_shape1 = (batch_size, q_head_num, head_dim) calcu_shape1 = (batch_size, q_head_num, head_dim)
if getattr(infer_state, 'mid_o', None) is None: if getattr(infer_state, "mid_o", None) is None:
infer_state.mid_o = torch.empty([batch_size, infer_state.mid_o = torch.empty(
q_head_num, [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim],
max_len_in_batch // BLOCK_SEQ + 1,
head_dim],
dtype=torch.float32, dtype=torch.float32,
device="cuda") device="cuda",
infer_state.mid_o_logexpsum = torch.empty([batch_size, )
q_head_num, infer_state.mid_o_logexpsum = torch.empty(
max_len_in_batch // BLOCK_SEQ + 1], [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda"
dtype=torch.float32, )
device="cuda")
mid_o = infer_state.mid_o mid_o = infer_state.mid_o
mid_o_logexpsum = infer_state.mid_o_logexpsum mid_o_logexpsum = infer_state.mid_o_logexpsum
flash_decode_stage1(q.view(calcu_shape1), flash_decode_stage1(
q.view(calcu_shape1),
cache_k, cache_k,
cache_v, cache_v,
infer_state.block_loc, infer_state.block_loc,
@ -42,9 +42,6 @@ if HAS_LIGHTLLM_KERNEL:
infer_state.max_len_in_batch, infer_state.max_len_in_batch,
mid_o, mid_o,
mid_o_logexpsum, mid_o_logexpsum,
BLOCK_SEQ) BLOCK_SEQ,
flash_decode_stage2(mid_o, )
mid_o_logexpsum, flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ)
infer_state.seq_len,
o_tensor.view(calcu_shape1),
BLOCK_SEQ)

@ -8,6 +8,7 @@ from torch.cuda.amp import custom_bwd, custom_fwd
try: try:
import triton import triton
import triton.language as tl import triton.language as tl
HAS_TRITON = True HAS_TRITON = True
except ImportError: except ImportError:
HAS_TRITON = False HAS_TRITON = False
@ -41,9 +42,9 @@ if HAS_TRITON:
for off in range(0, N, BLOCK_SIZE): for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE) cols = off + tl.arange(0, BLOCK_SIZE)
mask = cols < N mask = cols < N
x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.) x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.0)
x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.) x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.0)
x_up = tl.load(X_UP + cols, mask=mask, other=0.) x_up = tl.load(X_UP + cols, mask=mask, other=0.0)
x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype) x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype)
y = x_gate1 * x_gate2 * x_gate2_sigmoid * x_up y = x_gate1 * x_gate2 * x_gate2_sigmoid * x_up
# Write output # Write output
@ -76,10 +77,10 @@ if HAS_TRITON:
for off in range(0, N, BLOCK_SIZE): for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE) cols = off + tl.arange(0, BLOCK_SIZE)
mask = cols < N mask = cols < N
x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.) x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.0)
x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.) x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.0)
x_up = tl.load(X_UP + cols, mask=mask, other=0.) x_up = tl.load(X_UP + cols, mask=mask, other=0.0)
y_grad = tl.load(Y_GRAD + cols, mask=mask, other=0.) y_grad = tl.load(Y_GRAD + cols, mask=mask, other=0.0)
# forward: y = x_gate1 * x_gate2 * tl.sigmoid(x_gate2) * x_up # forward: y = x_gate1 * x_gate2 * tl.sigmoid(x_gate2) * x_up
x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype) x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype)
@ -147,14 +148,9 @@ if HAS_TRITON:
# restore setting # restore setting
ctx.M, ctx.N, ctx.BLOCK_SIZE, ctx.num_warps = M, N, BLOCK_SIZE, num_warps ctx.M, ctx.N, ctx.BLOCK_SIZE, ctx.num_warps = M, N, BLOCK_SIZE, num_warps
# enqueue kernel # enqueue kernel
_llama_act_combine_forward[(M,)](x_gate1, _llama_act_combine_forward[(M,)](
x_gate2, x_gate1, x_gate2, x_up, y, x_up.stride(-2), N, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
x_up, )
y,
x_up.stride(-2),
N,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps)
return y return y
@staticmethod @staticmethod
@ -166,11 +162,15 @@ if HAS_TRITON:
# init grad # init grad
y_grad = grad_outputs[0] y_grad = grad_outputs[0]
x_gate1_grad, x_gate2_grad, x_up_grad = torch.empty_like(x_gate1), torch.empty_like( x_gate1_grad, x_gate2_grad, x_up_grad = (
x_gate2), torch.empty_like(x_up) torch.empty_like(x_gate1),
torch.empty_like(x_gate2),
torch.empty_like(x_up),
)
# enqueue kernel # enqueue kernel
_llama_act_combine_backward[(M,)](x_gate1, _llama_act_combine_backward[(M,)](
x_gate1,
x_gate2, x_gate2,
x_up, x_up,
x_gate1_grad, x_gate1_grad,
@ -180,6 +180,7 @@ if HAS_TRITON:
x_up.stride(-2), x_up.stride(-2),
N, N,
BLOCK_SIZE=BLOCK_SIZE, BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps) num_warps=num_warps,
)
x_gate_grad = torch.cat([x_gate1_grad, x_gate2_grad], dim=-1) x_gate_grad = torch.cat([x_gate1_grad, x_gate2_grad], dim=-1)
return x_gate_grad, x_up_grad, None, None return x_gate_grad, x_up_grad, None, None

@ -13,10 +13,18 @@ except ImportError:
print("please install triton from https://github.com/openai/triton") print("please install triton from https://github.com/openai/triton")
try: try:
from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2 as lightllm_llama_token_att_fwd2 from lightllm.models.bloom.triton_kernel.token_attention_nopad_att1 import (
from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_llama_token_att_fwd token_att_fwd as lightllm_bloom_token_att_fwd,
from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd as lightllm_llama_token_softmax_fwd )
from lightllm.models.bloom.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_bloom_token_att_fwd from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import (
token_att_fwd as lightllm_llama_token_att_fwd,
)
from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import (
token_att_fwd2 as lightllm_llama_token_att_fwd2,
)
from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import (
token_softmax_fwd as lightllm_llama_token_softmax_fwd,
)
HAS_TRITON_TOKEN_ATTENTION = True HAS_TRITON_TOKEN_ATTENTION = True
except ImportError: except ImportError:
@ -205,9 +213,7 @@ class Llama2TokenAttentionForwards:
if triton.__version__ == "2.0.0": if triton.__version__ == "2.0.0":
prob = torch.empty_like(att_m_tensor) prob = torch.empty_like(att_m_tensor)
lightllm_llama_token_softmax_fwd( lightllm_llama_token_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch)
att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch
)
att_m_tensor = None att_m_tensor = None
lightllm_llama_token_att_fwd2( lightllm_llama_token_att_fwd2(

@ -8,7 +8,9 @@ from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecode
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd
from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards
from ._utils import copy_kv_to_mem_cache from ._utils import copy_kv_to_mem_cache
try: try:
from lightllm.models.llama.triton_kernel.context_flashattention_nopad import ( from lightllm.models.llama.triton_kernel.context_flashattention_nopad import (
context_attention_fwd as lightllm_llama_context_attention_fwd, context_attention_fwd as lightllm_llama_context_attention_fwd,

@ -1,5 +1,5 @@
from .attn import AttnMaskType, ColoAttention
from ._operation import all_to_all_comm from ._operation import all_to_all_comm
from .attn import AttnMaskType, ColoAttention
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, VocabParallelEmbedding1D from .embedding import Embedding1D, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row from .linear import Linear1D_Col, Linear1D_Row

@ -2,13 +2,13 @@ from .api import (
compute_global_numel, compute_global_numel,
customized_distributed_tensor_to_param, customized_distributed_tensor_to_param,
distribute_tensor, distribute_tensor,
init_as_dtensor,
distribute_tensor_with_customization, distribute_tensor_with_customization,
init_tensor_as_customization_distributed,
get_device_mesh, get_device_mesh,
get_global_shape, get_global_shape,
get_layout, get_layout,
get_sharding_spec, get_sharding_spec,
init_as_dtensor,
init_tensor_as_customization_distributed,
is_customized_distributed_tensor, is_customized_distributed_tensor,
is_distributed_tensor, is_distributed_tensor,
is_sharded, is_sharded,

@ -128,7 +128,10 @@ def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_sp
return sharded_tensor return sharded_tensor
def init_as_dtensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size) -> torch.Tensor:
def init_as_dtensor(
tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size
) -> torch.Tensor:
assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor." assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."
dist_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape) dist_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)
@ -140,6 +143,7 @@ def init_as_dtensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec
return tensor return tensor
def redistribute(dtensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> None: def redistribute(dtensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> None:
""" """
Convert the layout of the tensor from source_spec to target_spec. Convert the layout of the tensor from source_spec to target_spec.
@ -468,7 +472,6 @@ def init_tensor_as_customization_distributed(tensor: torch.Tensor, shard_fn, gat
assert callable(gather_fn), "The gather_fn must be callable." assert callable(gather_fn), "The gather_fn must be callable."
assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor." assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."
# set the shard_fn and gather_fn as attributes of the distributed tensor # set the shard_fn and gather_fn as attributes of the distributed tensor
tensor.shard_fn = shard_fn tensor.shard_fn = shard_fn
tensor.gather_fn = gather_fn tensor.gather_fn = gather_fn

@ -190,6 +190,7 @@ def calculate_global_norm_from_list(norm_list):
total_norm += norm**2.0 total_norm += norm**2.0
return math.sqrt(total_norm) return math.sqrt(total_norm)
def sync_tensor(flat_tensor, tensor_list): def sync_tensor(flat_tensor, tensor_list):
""" """
Synchronize the flattened tensor and unflattened tensor list. When Synchronize the flattened tensor and unflattened tensor list. When

@ -119,9 +119,7 @@ def main():
if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None: if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None:
# run pipeline forward backward # run pipeline forward backward
batch = iter([batch]) batch = iter([batch])
outputs = booster.execute_pipeline( outputs = booster.execute_pipeline(batch, model, criterion, optimizer, return_loss=True)
batch, model, criterion, optimizer, return_loss=True
)
else: else:
outputs = model(**batch) outputs = model(**batch)
loss = criterion(outputs, None) loss = criterion(outputs, None)

@ -270,9 +270,7 @@ def main():
) as pbar: ) as pbar:
for step in pbar: for step in pbar:
if use_pipeline: if use_pipeline:
outputs = booster.execute_pipeline( outputs = booster.execute_pipeline(dataloader_iter, model, _criterion, optimizer, return_loss=True)
dataloader_iter, model, _criterion, optimizer, return_loss=True
)
loss = outputs["loss"] loss = outputs["loss"]
else: else:
batch = next(dataloader_iter) batch = next(dataloader_iter)

@ -285,9 +285,7 @@ def main():
) as pbar: ) as pbar:
for step in pbar: for step in pbar:
if use_pipeline: if use_pipeline:
outputs = booster.execute_pipeline( outputs = booster.execute_pipeline(dataloader_iter, model, _criterion, optimizer, return_loss=True)
dataloader_iter, model, _criterion, optimizer, return_loss=True
)
loss = outputs["loss"] loss = outputs["loss"]
else: else:
batch = next(dataloader_iter) batch = next(dataloader_iter)

@ -50,7 +50,6 @@ def all_reduce_mean(x: float, world_size: int) -> float:
class Timer: class Timer:
def __init__(self) -> None: def __init__(self) -> None:
self.start_time: Optional[float] = None self.start_time: Optional[float] = None
self.duration: float = 0.0 self.duration: float = 0.0
@ -112,7 +111,7 @@ class PerformanceEvaluator:
batch_size, seq_len = input_ids.shape batch_size, seq_len = input_ids.shape
self.num_samples += batch_size self.num_samples += batch_size
self.flop += (batch_size * seq_len * self.model_numel * 2 * (3 + int(self.enable_grad_checkpoint))) self.flop += batch_size * seq_len * self.model_numel * 2 * (3 + int(self.enable_grad_checkpoint))
def on_fit_end(self) -> None: def on_fit_end(self) -> None:
avg_duration = all_reduce_mean(self.timer.duration, self.world_size) avg_duration = all_reduce_mean(self.timer.duration, self.world_size)
@ -122,5 +121,6 @@ class PerformanceEvaluator:
if dist.get_rank() == 0: if dist.get_rank() == 0:
print( print(
f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop: {self.flop}, avg_duration: {avg_duration}, " f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop: {self.flop}, avg_duration: {avg_duration}, "
f"avg_throughput: {avg_throughput}") f"avg_throughput: {avg_throughput}"
)
print(f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}") print(f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}")

@ -16,17 +16,15 @@ def inference(args):
tokenizer = T5Tokenizer.from_pretrained("google/umt5-small") tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
if args.model == "test": if args.model == "test":
config = LlamaConfig.from_pretrained("hpcai-tech/openmoe-base") config = LlamaConfig.from_pretrained("hpcai-tech/openmoe-base")
set_openmoe_args(config, set_openmoe_args(
num_experts=config.num_experts, config, num_experts=config.num_experts, moe_layer_interval=config.moe_layer_interval, enable_kernel=True
moe_layer_interval=config.moe_layer_interval, )
enable_kernel=True)
model = OpenMoeForCausalLM(config) model = OpenMoeForCausalLM(config)
else: else:
config = LlamaConfig.from_pretrained(f"hpcai-tech/openmoe-{args.model}") config = LlamaConfig.from_pretrained(f"hpcai-tech/openmoe-{args.model}")
set_openmoe_args(config, set_openmoe_args(
num_experts=config.num_experts, config, num_experts=config.num_experts, moe_layer_interval=config.moe_layer_interval, enable_kernel=False
moe_layer_interval=config.moe_layer_interval, )
enable_kernel=False)
model = OpenMoeForCausalLM.from_pretrained(f"hpcai-tech/openmoe-{args.model}", config=config) model = OpenMoeForCausalLM.from_pretrained(f"hpcai-tech/openmoe-{args.model}", config=config)
model = model.eval().bfloat16() model = model.eval().bfloat16()
model = model.to(torch.cuda.current_device()) model = model.to(torch.cuda.current_device())

@ -172,9 +172,9 @@ def make_state_dict(converted_params):
def load_t5x_weights_in_t5(model, config, t5x_checkpoint_path): def load_t5x_weights_in_t5(model, config, t5x_checkpoint_path):
"""Replaces the params in model witht the T5X converted params.""" """Replaces the params in model witht the T5X converted params."""
variables = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path) variables = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
converted = convert_t5x_to_pytorch(variables, converted = convert_t5x_to_pytorch(
num_layers=config.num_hidden_layers, variables, num_layers=config.num_hidden_layers, moe_interval=config.moe_layer_interval
moe_interval=config.moe_layer_interval) )
state_dict = make_state_dict(converted) state_dict = make_state_dict(converted)
model.load_state_dict(state_dict, strict=True) model.load_state_dict(state_dict, strict=True)
@ -203,11 +203,9 @@ def convert_t5x_checkpoint_to_pytorch(t5x_checkpoint_path, config_file, pytorch_
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Converts a native T5X checkpoint into a PyTorch checkpoint.") parser = argparse.ArgumentParser(description="Converts a native T5X checkpoint into a PyTorch checkpoint.")
# Required parameters # Required parameters
parser.add_argument("--t5x_checkpoint_path", parser.add_argument(
default=None, "--t5x_checkpoint_path", default=None, type=str, required=True, help="Path to the T5X checkpoint."
type=str, )
required=True,
help="Path to the T5X checkpoint.")
parser.add_argument( parser.add_argument(
"--config_file", "--config_file",
default=None, default=None,
@ -215,10 +213,8 @@ if __name__ == "__main__":
required=True, required=True,
help="The config json file corresponding to the pre-trained T5 model.\nThis specifies the model architecture.", help="The config json file corresponding to the pre-trained T5 model.\nThis specifies the model architecture.",
) )
parser.add_argument("--pytorch_dump_path", parser.add_argument(
default=None, "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
type=str, )
required=True,
help="Path to the output PyTorch model.")
args = parser.parse_args() args = parser.parse_args()
convert_t5x_checkpoint_to_pytorch(args.t5x_checkpoint_path, args.config_file, args.pytorch_dump_path) convert_t5x_checkpoint_to_pytorch(args.t5x_checkpoint_path, args.config_file, args.pytorch_dump_path)

@ -41,9 +41,7 @@ def train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, dataloader, b
# Forward pass # Forward pass
for _ in pbar: for _ in pbar:
if use_pipeline: if use_pipeline:
outputs = booster.execute_pipeline( outputs = booster.execute_pipeline(dataloader, model, _criterion, optimizer, return_loss=True)
dataloader, model, _criterion, optimizer, return_loss=True
)
# Backward and optimize # Backward and optimize
if is_pp_last_stage: if is_pp_last_stage:
loss = outputs["loss"] loss = outputs["loss"]

@ -1,5 +1,4 @@
from .cpu_adam_arm import CpuAdamArmExtension from .cpu_adam_arm import CpuAdamArmExtension
from .cpu_adam_x86 import CpuAdamX86Extension from .cpu_adam_x86 import CpuAdamX86Extension
__all__ = ['CpuAdamArmExtension', 'CpuAdamX86Extension'] __all__ = ["CpuAdamArmExtension", "CpuAdamX86Extension"]

@ -1,3 +1,3 @@
from .moe_cuda import MoeCudaExtension from .moe_cuda import MoeCudaExtension
__all__ = ['MoeCudaExtension'] __all__ = ["MoeCudaExtension"]

@ -1,3 +1,3 @@
from .fused_optimizer_cuda import FusedOptimizerCudaExtension from .fused_optimizer_cuda import FusedOptimizerCudaExtension
__all__ = ['FusedOptimizerCudaExtension'] __all__ = ["FusedOptimizerCudaExtension"]

@ -1,4 +1,4 @@
from .scaled_masked_softmax_cuda import ScaledMaskedSoftmaxCudaExtension from .scaled_masked_softmax_cuda import ScaledMaskedSoftmaxCudaExtension
from .scaled_upper_triangle_masked_softmax_cuda import ScaledUpperTriangleMaskedSoftmaxCudaExtension from .scaled_upper_triangle_masked_softmax_cuda import ScaledUpperTriangleMaskedSoftmaxCudaExtension
__all__ = ['ScaledMaskedSoftmaxCudaExtension', 'ScaledUpperTriangleMaskedSoftmaxCudaExtension'] __all__ = ["ScaledMaskedSoftmaxCudaExtension", "ScaledUpperTriangleMaskedSoftmaxCudaExtension"]

@ -1,33 +1,33 @@
import os import os
from . import custom, diffusers, timm, torchaudio, torchvision, transformers from . import custom, diffusers, timm, torchaudio, torchvision, transformers
from .executor import run_fwd, run_fwd_bwd from .executor import run_fwd, run_fwd_bwd
from .registry import model_zoo from .registry import model_zoo
# We pick a subset of models for fast testing in order to reduce the total testing time # We pick a subset of models for fast testing in order to reduce the total testing time
COMMON_MODELS = [ COMMON_MODELS = [
'custom_hanging_param_model', "custom_hanging_param_model",
'custom_nested_model', "custom_nested_model",
'custom_repeated_computed_layers', "custom_repeated_computed_layers",
'custom_simple_net', "custom_simple_net",
'diffusers_clip_text_model', "diffusers_clip_text_model",
'diffusers_auto_encoder_kl', "diffusers_auto_encoder_kl",
'diffusers_unet2d_model', "diffusers_unet2d_model",
'timm_densenet', "timm_densenet",
'timm_resnet', "timm_resnet",
'timm_swin_transformer', "timm_swin_transformer",
'torchaudio_wav2vec2_base', "torchaudio_wav2vec2_base",
'torchaudio_conformer', "torchaudio_conformer",
'transformers_bert_for_masked_lm', "transformers_bert_for_masked_lm",
'transformers_bloom_for_causal_lm', "transformers_bloom_for_causal_lm",
'transformers_falcon_for_causal_lm', "transformers_falcon_for_causal_lm",
'transformers_chatglm_for_conditional_generation', "transformers_chatglm_for_conditional_generation",
'transformers_llama_for_casual_lm', "transformers_llama_for_casual_lm",
'transformers_vit_for_masked_image_modeling', "transformers_vit_for_masked_image_modeling",
'transformers_mistral_for_casual_lm' "transformers_mistral_for_casual_lm",
] ]
IS_FAST_TEST = os.environ.get('FAST_TEST', '0') == '1' IS_FAST_TEST = os.environ.get("FAST_TEST", "0") == "1"
__all__ = ["model_zoo", "run_fwd", "run_fwd_bwd", 'COMMON_MODELS', 'IS_FAST_TEST']
__all__ = ["model_zoo", "run_fwd", "run_fwd_bwd", "COMMON_MODELS", "IS_FAST_TEST"]

@ -2,6 +2,7 @@ import torch
from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
from ..registry import ModelAttribute, model_zoo from ..registry import ModelAttribute, model_zoo
# ================================ # ================================

@ -74,9 +74,7 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
data = data_gen_fn() data = data_gen_fn()
model.train() model.train()
if booster.plugin.stage_manager is not None: if booster.plugin.stage_manager is not None:
booster.execute_pipeline( booster.execute_pipeline(_preprocess_data(data), model, _criterion, optimizer, return_loss=True)
_preprocess_data(data), model, _criterion, optimizer, return_loss=True
)
else: else:
output = model(**_preprocess_data(data)) output = model(**_preprocess_data(data))
loss = criterion(output) loss = criterion(output)
@ -108,9 +106,7 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
data_for_shard = data_gen_fn() data_for_shard = data_gen_fn()
data_for_origin = data_gen_fn() data_for_origin = data_gen_fn()
if booster.plugin.stage_manager is not None: if booster.plugin.stage_manager is not None:
booster.execute_pipeline( booster.execute_pipeline(_preprocess_data(data_for_shard), model, _criterion, optimizer, return_loss=True)
_preprocess_data(data_for_shard), model, _criterion, optimizer, return_loss=True
)
booster.execute_pipeline( booster.execute_pipeline(
_preprocess_data(data_for_origin), _preprocess_data(data_for_origin),
new_model, new_model,

@ -113,6 +113,7 @@ def check_torch_fsdp_ckpt():
full_osd = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer) full_osd = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer)
import copy import copy
sharded_osd = copy.deepcopy(full_osd) sharded_osd = copy.deepcopy(full_osd)
run_model() run_model()

@ -1,16 +1,8 @@
import math
import time
import numpy as np
import pytest import pytest
import torch import torch
import torch.nn as nn
import transformers
from packaging import version from packaging import version
try: try:
import triton
import triton.language as tl
HAS_TRITON = True HAS_TRITON = True
except ImportError: except ImportError:
HAS_TRITON = False HAS_TRITON = False
@ -22,6 +14,7 @@ try:
from exllama_kernels import prepare_buffers, set_tuning_params from exllama_kernels import prepare_buffers, set_tuning_params
from colossalai.inference.quant.gptq import CaiQuantLinear from colossalai.inference.quant.gptq import CaiQuantLinear
HAS_AUTO_GPTQ = True HAS_AUTO_GPTQ = True
except: except:
HAS_AUTO_GPTQ = False HAS_AUTO_GPTQ = False
@ -32,13 +25,14 @@ import warnings
HAS_GPTQ_CUDA = False HAS_GPTQ_CUDA = False
try: try:
from colossalai.kernel.op_builder.gptq import GPTQBuilder from colossalai.kernel.op_builder.gptq import GPTQBuilder
gptq_cuda = GPTQBuilder().load() gptq_cuda = GPTQBuilder().load()
HAS_GPTQ_CUDA = True HAS_GPTQ_CUDA = True
except ImportError: except ImportError:
warnings.warn('CUDA gptq is not installed') warnings.warn("CUDA gptq is not installed")
HAS_GPTQ_CUDA = False HAS_GPTQ_CUDA = False
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
max_inner_outer_dim = 1 max_inner_outer_dim = 1
max_input_len = 1 max_input_len = 1
@ -64,9 +58,9 @@ def init_buffer(cai_linear, use_act_order=False):
max_input_len = 4096 max_input_len = 4096
# The temp_state buffer is required to reorder X in the act-order case. # The temp_state buffer is required to reorder X in the act-order case.
# The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill. # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
gptq_temp_state_buffer = torch.zeros((max_input_len, max_inner_outer_dim), gptq_temp_state_buffer = torch.zeros(
dtype=torch.float16, (max_input_len, max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device()
device=torch.cuda.current_device()) )
gptq_temp_dq_buffer = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device()) gptq_temp_dq_buffer = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device())
gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), gptq_temp_state_buffer, gptq_temp_dq_buffer) gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), gptq_temp_state_buffer, gptq_temp_dq_buffer)
@ -77,10 +71,11 @@ def init_buffer(cai_linear, use_act_order=False):
gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2) gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_AUTO_GPTQ, @pytest.mark.skipif(
reason="triton requires cuda version to be higher than 11.4 or not install auto-gptq") not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_AUTO_GPTQ,
reason="triton requires cuda version to be higher than 11.4 or not install auto-gptq",
)
def test_gptq_linear(): def test_gptq_linear():
infeature = 1024 infeature = 1024
outfeature = 1024 outfeature = 1024
group_size = 128 group_size = 128
@ -120,7 +115,7 @@ def test_gptq_linear():
max_input_len = 2048 max_input_len = 2048
buffers = { buffers = {
"temp_state": torch.zeros((max_input_len, max_inner_outer_dim), dtype=torch.float16, device=device), "temp_state": torch.zeros((max_input_len, max_inner_outer_dim), dtype=torch.float16, device=device),
"temp_dq": torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device) "temp_dq": torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device),
} }
prepare_buffers(device, buffers["temp_state"], buffers["temp_dq"]) prepare_buffers(device, buffers["temp_state"], buffers["temp_dq"])
@ -146,5 +141,4 @@ def test_gptq_linear():
if __name__ == "__main__": if __name__ == "__main__":
test_gptq_linear() test_gptq_linear()

@ -1,5 +1,5 @@
import torch
import pytest import pytest
import torch
from colossalai.nn.optimizer import CPUAdam, HybridAdam from colossalai.nn.optimizer import CPUAdam, HybridAdam
from colossalai.testing import clear_cache_before_run, parameterize from colossalai.testing import clear_cache_before_run, parameterize
@ -17,6 +17,7 @@ def check_params_equal(model, torch_model):
for p, torch_p in zip(model.parameters(), torch_model.parameters()): for p, torch_p in zip(model.parameters(), torch_model.parameters()):
assert torch.allclose(p, torch_p, atol=1e-3), f"diff: {torch.abs(p - torch_p)}" assert torch.allclose(p, torch_p, atol=1e-3), f"diff: {torch.abs(p - torch_p)}"
# TODO Something wrong with ci when running this test. # TODO Something wrong with ci when running this test.
@pytest.mark.skip(reason="skip because of something wrong with CI") @pytest.mark.skip(reason="skip because of something wrong with CI")
@clear_cache_before_run() @clear_cache_before_run()

@ -103,9 +103,7 @@ def run_pp(
torch_loss = criterion(torch_output) torch_loss = criterion(torch_output)
torch_loss.backward() torch_loss.backward()
pp_ret = schedule.forward_backward_step( pp_ret = schedule.forward_backward_step(sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True)
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
)
# check loss # check loss
if stage_manager.is_last_stage(ignore_chunk=True): if stage_manager.is_last_stage(ignore_chunk=True):

@ -99,9 +99,7 @@ def examine_pp(num_microbatch: int, batch_size: int):
torch_output = torch_model(input_list[0]) torch_output = torch_model(input_list[0])
torch_loss = criterion(torch_output) torch_loss = criterion(torch_output)
torch_loss.backward() torch_loss.backward()
pp_ret = schedule.forward_backward_step( pp_ret = schedule.forward_backward_step(sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True)
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
)
# check loss # check loss
if stage_manager.is_last_stage(): if stage_manager.is_last_stage():

Loading…
Cancel
Save