Mirror of https://github.com/hpcaitech/ColossalAI
[pre-commit.ci] auto fixes from pre-commit.com hooks
For more information, see https://pre-commit.ci
Branch: pull/5842/head
Parent: 4c69e2dc91
Commit: df612434c9
@@ -8,11 +8,10 @@ import argparse
import numpy as np
import torch
-from transformers import LlamaTokenizer, LlamaForCausalLM
+from transformers import LlamaForCausalLM, LlamaTokenizer

from colossalai.logging import get_dist_logger


logger = get_dist_logger()
@@ -10,8 +10,8 @@ import os
from typing import Any, Dict, Tuple, Union

import torch
-from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
+from torch.optim.optimizer import Optimizer

from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator
@@ -1,12 +1,11 @@
from copy import deepcopy
-from typing import Optional, List, Dict, Tuple, Callable, Any
+from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from torch import nn

from transformers import PreTrainedTokenizer
-from transformers.utils import logging
from transformers.generation.utils import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
+from transformers.utils import logging

logger = logging.get_logger(__name__)

@@ -48,6 +47,7 @@ def get_prompt_template(
prompt += f"{role}: <s>"
return prompt


@torch.inference_mode()
def streaming_chat(
model: Any,
@@ -99,14 +99,14 @@ def streaming_chat(
logits_processor = LogitsProcessorList()

generation_kwargs = {
-'temperature': temperature,
-'top_p': top_p,
-'top_k': top_k,
-'do_sample': do_sample,
-'max_new_tokens': max_new_tokens,
-'length_penalty': length_penalty,
-'use_cache': True,
-**kwargs
+"temperature": temperature,
+"top_p": top_p,
+"top_k": top_k,
+"do_sample": do_sample,
+"max_new_tokens": max_new_tokens,
+"length_penalty": length_penalty,
+"use_cache": True,
+**kwargs,
}

prompt_str = get_prompt_template(input_query, history=history, roles=roles)
@@ -116,9 +116,14 @@ def streaming_chat(
history.append({"role": roles[1], "message": input_query.strip()})
history.append({"role": roles[2], "message": None})

-for outputs in stream_generate(model, **inputs, past_key_values=past_key_values,
-eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,
-**generation_kwargs):
+for outputs in stream_generate(
+model,
+**inputs,
+past_key_values=past_key_values,
+eos_token_id=eos_token_id,
+return_past_key_values=return_past_key_values,
+**generation_kwargs,
+):
if return_past_key_values:
outputs, past_key_values = outputs
@@ -15,7 +15,7 @@ def load_model(model_path, device="cuda", **kwargs):
model.to(device)

try:
-tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side='left')
+tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
except OSError:
raise ImportError("Tokenizer not found. Please check if the tokenizer exists or the model path is correct.")
@@ -12,4 +12,3 @@ flash-attn>=2.0.0,<=2.0.5
tqdm
sentencepiece==0.1.99
protobuf<=3.20.0
@@ -1,11 +1,11 @@
import os
import argparse

-from transformers import AutoTokenizer, AutoModelForCausalLM
from colossal_llama2.utils.stream_chat_patch import streaming_chat
+from transformers import AutoModelForCausalLM, AutoTokenizer

SYSTEM = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."


def main(args):
model = AutoModelForCausalLM.from_pretrained(args.model_path).cuda().eval()
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
@@ -27,7 +27,11 @@ def main(args):
print(f"\n{roles[2]}: ", end="")
gen_len = 0
for response, history, past_key_values in streaming_chat(
-model, tokenizer, input_query, history=history, roles=roles,
+model,
+tokenizer,
+input_query,
+history=history,
+roles=roles,
temperature=args.temperature,
top_p=args.top_p,
top_k=args.top_k,
@@ -35,21 +39,22 @@ def main(args):
length_penalty=args.length_penalty,
max_new_tokens=args.max_new_tokens,
past_key_values=past_key_values,
-return_past_key_values=True):
+return_past_key_values=True,
+):
output = response[gen_len:]
print(output, end="", flush=True)
gen_len = len(response)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
-parser.add_argument('--model_path', type=str, default=None, help="path to chat version model")
-parser.add_argument('--tokenizer_path', type=str, default=None, help="path to chat version tokenizer")
-parser.add_argument('--temperature', type=float, default=0.8, help="set temperature")
-parser.add_argument('--top_p', type=float, default=0.95, help="set top p value")
-parser.add_argument('--top_k', type=int, default=50, help="set top k value")
-parser.add_argument('--do_sample', type=bool, default=True, help="whether turn on do_sample or not")
-parser.add_argument('--length_penalty', type=float, default=1.2, help="set length penalty")
-parser.add_argument('--max_new_tokens', type=int, default=512, help="set max new tokens")
+parser.add_argument("--model_path", type=str, default=None, help="path to chat version model")
+parser.add_argument("--tokenizer_path", type=str, default=None, help="path to chat version tokenizer")
+parser.add_argument("--temperature", type=float, default=0.8, help="set temperature")
+parser.add_argument("--top_p", type=float, default=0.95, help="set top p value")
+parser.add_argument("--top_k", type=int, default=50, help="set top k value")
+parser.add_argument("--do_sample", type=bool, default=True, help="whether turn on do_sample or not")
+parser.add_argument("--length_penalty", type=float, default=1.2, help="set length penalty")
+parser.add_argument("--max_new_tokens", type=int, default=512, help="set max new tokens")
args = parser.parse_args()
main(args)
@@ -20,13 +20,13 @@ import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator
+from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
-from colossalai.logging import get_dist_logger

logger = get_dist_logger()


def train(args):
# check lora compatibility
if "gemini" in args.plugin and args.lora_rank > 0:
@@ -3,7 +3,6 @@ import copy
import os
from typing import Dict, List

import torch
import torch.distributed as dist
from colossal_eval import dataset, models, utils

Binary file not shown.
@@ -106,6 +106,5 @@ def main():
print(f"[{coordinator.rank}] {outputs}")



if __name__ == "__main__":
main()
@@ -24,6 +24,7 @@ from langchain.pydantic_v1 import Field
from langchain.schema import BaseRetriever, Document
from langchain.schema.language_model import BaseLanguageModel


class CustomBaseRetrievalQA(BaseRetrievalQA):
"""Base class for question-answering chains."""
@@ -98,7 +99,6 @@ class CustomBaseRetrievalQA(BaseRetrievalQA):
for k, v in inputs.items()
if k in ["stop", "temperature", "top_k", "top_p", "max_new_tokens", "doc_prefix"]
}
answers = []
if self.combine_documents_chain.memory is not None:
buffered_history_backup, summarized_history_temp_backup = copy.deepcopy(
self.combine_documents_chain.memory.buffered_history
@@ -117,10 +117,10 @@ class CustomBaseRetrievalQA(BaseRetrievalQA):
) = copy.deepcopy(buffered_history_backup), copy.deepcopy(summarized_history_temp_backup)

# if rejection_trigger_keywords is not given, return the response from LLM directly
-rejection_trigger_keywords = inputs.get('rejection_trigger_keywords', [])
+rejection_trigger_keywords = inputs.get("rejection_trigger_keywords", [])
answer = answer if all([rej not in answer for rej in rejection_trigger_keywords]) else None
if answer is None:
-answer = inputs.get('rejection_answer', "抱歉,根据提供的信息无法回答该问题。")
+answer = inputs.get("rejection_answer", "抱歉,根据提供的信息无法回答该问题。")
if self.combine_documents_chain.memory is not None:
self.combine_documents_chain.memory.save_context({"question": question}, {"output": answer})
@@ -161,10 +161,14 @@ class CustomBaseRetrievalQA(BaseRetrievalQA):
input_documents=docs, question=question, callbacks=_run_manager.get_child(), **kwargs
)
# if rejection_trigger_keywords is not given, return the response from LLM directly
-rejection_trigger_keywords = inputs.get('rejection_trigger_keywords', [])
-answer = answer if all([rej not in answer for rej in rejection_trigger_keywords]) or len(rejection_trigger_keywords)==0 else None
+rejection_trigger_keywords = inputs.get("rejection_trigger_keywords", [])
+answer = (
+answer
+if all([rej not in answer for rej in rejection_trigger_keywords]) or len(rejection_trigger_keywords) == 0
+else None
+)
if answer is None:
-answer = inputs.get('rejection_answer', "抱歉,根据提供的信息无法回答该问题。")
+answer = inputs.get("rejection_answer", "抱歉,根据提供的信息无法回答该问题。")
self.combine_documents_chain.memory.save_context({"question": question}, {"output": answer})

if self.return_source_documents:
@ -1,32 +1,33 @@
|
|||
'''
|
||||
"""
|
||||
Class for loading table type data. please refer to Pandas-Input/Output for file format details.
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
import os
|
||||
import glob
|
||||
import os
|
||||
|
||||
import pandas as pd
|
||||
from sqlalchemy import create_engine
|
||||
from colossalqa.utils import drop_table
|
||||
from colossalqa.mylogging import get_logger
|
||||
from colossalqa.utils import drop_table
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
SUPPORTED_DATA_FORMAT = ['.csv','.xlsx', '.xls','.json','.html','.h5', '.hdf5','.parquet','.feather','.dta']
|
||||
SUPPORTED_DATA_FORMAT = [".csv", ".xlsx", ".xls", ".json", ".html", ".h5", ".hdf5", ".parquet", ".feather", ".dta"]
|
||||
|
||||
|
||||
class TableLoader:
|
||||
'''
|
||||
"""
|
||||
Load tables from different files and serve a sql database for database operations
|
||||
'''
|
||||
def __init__(self, files: str,
|
||||
sql_path:str='sqlite:///mydatabase.db',
|
||||
verbose=False, **kwargs) -> None:
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(self, files: str, sql_path: str = "sqlite:///mydatabase.db", verbose=False, **kwargs) -> None:
|
||||
"""
|
||||
Args:
|
||||
files: list of files (list[file path, name])
|
||||
sql_path: how to serve the sql database
|
||||
**kwargs: keyword type arguments, useful for certain document types
|
||||
'''
|
||||
"""
|
||||
self.data = {}
|
||||
self.verbose = verbose
|
||||
self.sql_path = sql_path
|
||||
|
@ -49,10 +50,10 @@ class TableLoader:
|
|||
self.to_sql(path, dataset_name)
|
||||
|
||||
def load_data(self, path):
|
||||
'''
|
||||
"""
|
||||
Load data and serve the data as sql database.
|
||||
Data must be in pandas format
|
||||
'''
|
||||
"""
|
||||
files = []
|
||||
# Handle glob expression
|
||||
try:
|
||||
|
@ -67,40 +68,40 @@ class TableLoader:
|
|||
for file in files:
|
||||
self.load_data(file)
|
||||
|
||||
if path.endswith('.csv'):
|
||||
if path.endswith(".csv"):
|
||||
# Load csv
|
||||
self.data[path] = pd.read_csv(path)
|
||||
elif path.endswith('.xlsx') or path.endswith('.xls'):
|
||||
elif path.endswith(".xlsx") or path.endswith(".xls"):
|
||||
# Load excel
|
||||
self.data[path] = pd.read_excel(path) # You can adjust the sheet_name as needed
|
||||
elif path.endswith('.json'):
|
||||
elif path.endswith(".json"):
|
||||
# Load json
|
||||
self.data[path] = pd.read_json(path)
|
||||
elif path.endswith('.html'):
|
||||
elif path.endswith(".html"):
|
||||
# Load html
|
||||
html_tables = pd.read_html(path)
|
||||
# Choose the desired table from the list of DataFrame objects
|
||||
self.data[path] = html_tables[0] # You may need to adjust this index
|
||||
elif path.endswith('.h5') or path.endswith('.hdf5'):
|
||||
elif path.endswith(".h5") or path.endswith(".hdf5"):
|
||||
# Load h5
|
||||
self.data[path] = pd.read_hdf(path, key=self.kwargs.get('key', 'data')) # You can adjust the key as needed
|
||||
elif path.endswith('.parquet'):
|
||||
self.data[path] = pd.read_hdf(path, key=self.kwargs.get("key", "data")) # You can adjust the key as needed
|
||||
elif path.endswith(".parquet"):
|
||||
# Load parquet
|
||||
self.data[path] = pd.read_parquet(path, engine='fastparquet')
|
||||
elif path.endswith('.feather'):
|
||||
self.data[path] = pd.read_parquet(path, engine="fastparquet")
|
||||
elif path.endswith(".feather"):
|
||||
# Load feather
|
||||
self.data[path] = pd.read_feather(path)
|
||||
elif path.endswith('.dta'):
|
||||
elif path.endswith(".dta"):
|
||||
# Load dta
|
||||
self.data[path] = pd.read_stata(path)
|
||||
else:
|
||||
raise ValueError("Unsupported file format")
|
||||
|
||||
def to_sql(self, path, table_name):
|
||||
'''
|
||||
"""
|
||||
Serve the data as sql database.
|
||||
'''
|
||||
self.data[path].to_sql(table_name, con=self.sql_engine, if_exists='replace', index=False)
|
||||
"""
|
||||
self.data[path].to_sql(table_name, con=self.sql_engine, if_exists="replace", index=False)
|
||||
logger.info(f"Loaded to Sqlite3\nPath: {path}", verbose=self.verbose)
|
||||
return self.sql_path
|
||||
|
||||
|
@ -113,7 +114,3 @@ class TableLoader:
|
|||
self.sql_engine.dispose()
|
||||
del self.data
|
||||
del self.sql_engine
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ print(resp) # super-heavyweight awesome-natured yawning Australian creature!
|
|||
|
||||
"""
|
||||
import json
|
||||
from typing import Any, List, Mapping, Optional
|
||||
from typing import Any, Mapping
|
||||
|
||||
import requests
|
||||
from langchain.llms.base import LLM
|
||||
|
@ -33,11 +33,11 @@ class ColossalCloudLLM(LLM):
|
|||
A custom LLM class that integrates LLMs running on the ColossalCloud Platform
|
||||
|
||||
"""
|
||||
|
||||
n: int
|
||||
gen_config: dict = None
|
||||
auth_config: dict = None
|
||||
valid_gen_para: list = ['max_new_tokens', 'top_k',
|
||||
'top_p', 'temperature', 'repetition_penalty']
|
||||
valid_gen_para: list = ["max_new_tokens", "top_k", "top_p", "temperature", "repetition_penalty"]
|
||||
|
||||
def __init__(self, gen_config=None, **kwargs):
|
||||
"""
|
||||
|
@ -63,15 +63,15 @@ class ColossalCloudLLM(LLM):
|
|||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return 'ColossalCloudLLM'
|
||||
return "ColossalCloudLLM"
|
||||
|
||||
def set_auth_config(self, **kwargs):
|
||||
url = get_from_dict_or_env(kwargs, "url", "URL")
|
||||
host = get_from_dict_or_env(kwargs, "host", "HOST")
|
||||
|
||||
auth_config = {}
|
||||
auth_config['endpoint'] = url
|
||||
auth_config['Host'] = host
|
||||
auth_config["endpoint"] = url
|
||||
auth_config["Host"] = host
|
||||
self.auth_config = auth_config
|
||||
|
||||
def _call(self, prompt: str, stop=None, **kwargs: Any) -> str:
|
||||
|
@ -86,7 +86,9 @@ class ColossalCloudLLM(LLM):
|
|||
# Update the generation arguments
|
||||
for key, value in kwargs.items():
|
||||
if key not in self.valid_gen_para:
|
||||
raise KeyError(f"Invalid generation parameter: '{key}'. Valid keys are: {', '.join(self.valid_gen_para)}")
|
||||
raise KeyError(
|
||||
f"Invalid generation parameter: '{key}'. Valid keys are: {', '.join(self.valid_gen_para)}"
|
||||
)
|
||||
if key in self.gen_config:
|
||||
self.gen_config[key] = value
|
||||
|
||||
|
@ -98,26 +100,16 @@ class ColossalCloudLLM(LLM):
|
|||
resp_text = resp_text.split(stopping_words)[0]
|
||||
return resp_text
|
||||
|
||||
|
||||
def text_completion(self, prompt, gen_config, auth_config):
|
||||
# Complusory Parameters
|
||||
endpoint = auth_config.pop('endpoint')
|
||||
max_new_tokens = gen_config.pop('max_new_tokens')
|
||||
endpoint = auth_config.pop("endpoint")
|
||||
max_new_tokens = gen_config.pop("max_new_tokens")
|
||||
# Optional Parameters
|
||||
optional_params = ['top_k', 'top_p', 'temperature', 'repetition_penalty'] # Self.optional
|
||||
optional_params = ["top_k", "top_p", "temperature", "repetition_penalty"] # Self.optional
|
||||
gen_config = {key: gen_config[key] for key in optional_params if key in gen_config}
|
||||
# Define the data payload
|
||||
data = {
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"history": [
|
||||
{"instruction": prompt, "response": ""}
|
||||
],
|
||||
**gen_config
|
||||
}
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
**auth_config # 'Host',
|
||||
}
|
||||
data = {"max_new_tokens": max_new_tokens, "history": [{"instruction": prompt, "response": ""}], **gen_config}
|
||||
headers = {"Content-Type": "application/json", **auth_config} # 'Host',
|
||||
# Make the POST request
|
||||
response = requests.post(endpoint, headers=headers, data=json.dumps(data))
|
||||
response.raise_for_status() # raise error if return code is not 200(success)
|
||||
|
|
|
@@ -193,4 +193,3 @@ class VllmLLM(LLM):
def _identifying_params(self) -> Mapping[str, int]:
"""Get the identifying parameters."""
return {"n": self.n}
@@ -4,7 +4,6 @@ All custom prompt templates are defined here.
from langchain.prompts.prompt import PromptTemplate


# Below are Chinese retrieval qa prompts

_CUSTOM_SUMMARIZER_TEMPLATE_ZH = """请递进式地总结所提供的当前对话,将当前对话的摘要内容添加到先前已有的摘要上,返回一个融合了当前对话的新的摘要。
@@ -99,13 +99,7 @@ class CustomRetriever(BaseRetriever):
def clear_documents(self):
"""Clear all document vectors from database"""
for source in self.vector_stores:
-index(
-[],
-self.record_managers[source],
-self.vector_stores[source],
-cleanup="full",
-source_id_key="source"
-)
+index([], self.record_managers[source], self.vector_stores[source], cleanup="full", source_id_key="source")
self.vector_stores = {}
self.sql_index_database = {}
self.record_managers = {}
@ -1,22 +1,27 @@
|
|||
import argparse
|
||||
|
||||
from colossalqa.retrieval_conversation_universal import UniversalRetrievalConversation
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
# Parse arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--en_model_path', type=str, default=None)
|
||||
parser.add_argument('--zh_model_path', type=str, default=None)
|
||||
parser.add_argument('--zh_model_name', type=str, default=None)
|
||||
parser.add_argument('--en_model_name', type=str, default=None)
|
||||
parser.add_argument('--sql_file_path', type=str, default=None, help='path to the a empty folder for storing sql files for indexing')
|
||||
parser.add_argument("--en_model_path", type=str, default=None)
|
||||
parser.add_argument("--zh_model_path", type=str, default=None)
|
||||
parser.add_argument("--zh_model_name", type=str, default=None)
|
||||
parser.add_argument("--en_model_name", type=str, default=None)
|
||||
parser.add_argument(
|
||||
"--sql_file_path", type=str, default=None, help="path to the a empty folder for storing sql files for indexing"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Will ask for documents path in running time
|
||||
session = UniversalRetrievalConversation(files_en=None,
|
||||
session = UniversalRetrievalConversation(
|
||||
files_en=None,
|
||||
files_zh=None,
|
||||
zh_model_path=args.zh_model_path, en_model_path=args.en_model_path,
|
||||
zh_model_name=args.zh_model_name, en_model_name=args.en_model_name,
|
||||
sql_file_path=args.sql_file_path
|
||||
zh_model_path=args.zh_model_path,
|
||||
en_model_path=args.en_model_path,
|
||||
zh_model_name=args.zh_model_name,
|
||||
en_model_name=args.en_model_name,
|
||||
sql_file_path=args.sql_file_path,
|
||||
)
|
||||
session.start_test_session()
|
||||
|
|
@@ -5,13 +5,7 @@ from colossalqa.chain.retrieval_qa.base import RetrievalQA
from colossalqa.data_loader.document_loader import DocumentLoader
from colossalqa.memory import ConversationBufferWithSummary
from colossalqa.mylogging import get_logger
-from colossalqa.prompt.prompt import (
-PROMPT_DISAMBIGUATE_ZH,
-PROMPT_RETRIEVAL_QA_ZH,
-SUMMARY_PROMPT_ZH,
-ZH_RETRIEVAL_QA_REJECTION_ANSWER,
-ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS,
-)
+from colossalqa.prompt.prompt import ZH_RETRIEVAL_QA_REJECTION_ANSWER, ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS
from colossalqa.retriever import CustomRetriever
from langchain import LLMChain
from langchain.embeddings import HuggingFaceEmbeddings
@ -1,58 +1,30 @@
|
|||
from colossalqa.prompt.prompt import (
|
||||
PROMPT_DISAMBIGUATE_ZH,
|
||||
PROMPT_RETRIEVAL_QA_ZH,
|
||||
SUMMARY_PROMPT_ZH,
|
||||
ZH_RETRIEVAL_QA_REJECTION_ANSWER,
|
||||
ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS,
|
||||
)
|
||||
from colossalqa.prompt.prompt import PROMPT_DISAMBIGUATE_ZH, PROMPT_RETRIEVAL_QA_ZH, SUMMARY_PROMPT_ZH
|
||||
from colossalqa.text_splitter import ChineseTextSplitter
|
||||
|
||||
ALL_CONFIG = {
|
||||
"embed": {
|
||||
"embed_name": "m3e", # embedding model name
|
||||
"embed_model_name_or_path": "moka-ai/m3e-base", # path to embedding model, could be a local path or a huggingface path
|
||||
"embed_model_device": {
|
||||
"device": "cpu"
|
||||
}
|
||||
"embed_model_device": {"device": "cpu"},
|
||||
},
|
||||
"model": {
|
||||
"mode": "api", # "local" for loading models, "api" for using model api
|
||||
"model_name": "chatgpt_api", # local model name, "chatgpt_api" or "pangu_api"
|
||||
"model_path": "", # path to the model, could be a local path or a huggingface path. don't need if using an api
|
||||
"device": {
|
||||
"device": "cuda"
|
||||
}
|
||||
},
|
||||
"splitter": {
|
||||
"name": ChineseTextSplitter
|
||||
},
|
||||
"retrieval": {
|
||||
"retri_top_k": 3,
|
||||
"retri_kb_file_path": "./", # path to store database files
|
||||
"verbose": True
|
||||
"device": {"device": "cuda"},
|
||||
},
|
||||
"splitter": {"name": ChineseTextSplitter},
|
||||
"retrieval": {"retri_top_k": 3, "retri_kb_file_path": "./", "verbose": True}, # path to store database files
|
||||
"chain": {
|
||||
"mem_summary_prompt": SUMMARY_PROMPT_ZH, # summary prompt template
|
||||
"mem_human_prefix": "用户",
|
||||
"mem_ai_prefix": "Assistant",
|
||||
"mem_max_tokens": 2000,
|
||||
"mem_llm_kwargs": {
|
||||
"max_new_tokens": 50,
|
||||
"temperature": 1,
|
||||
"do_sample": True
|
||||
},
|
||||
"mem_llm_kwargs": {"max_new_tokens": 50, "temperature": 1, "do_sample": True},
|
||||
"disambig_prompt": PROMPT_DISAMBIGUATE_ZH, # disambiguate prompt template
|
||||
"disambig_llm_kwargs": {
|
||||
"max_new_tokens": 30,
|
||||
"temperature": 1,
|
||||
"do_sample": True
|
||||
},
|
||||
"gen_llm_kwargs": {
|
||||
"max_new_tokens": 100,
|
||||
"temperature": 1,
|
||||
"do_sample": True
|
||||
},
|
||||
"disambig_llm_kwargs": {"max_new_tokens": 30, "temperature": 1, "do_sample": True},
|
||||
"gen_llm_kwargs": {"max_new_tokens": 100, "temperature": 1, "do_sample": True},
|
||||
"gen_qa_prompt": PROMPT_RETRIEVAL_QA_ZH, # generation prompt template
|
||||
"verbose": True
|
||||
}
|
||||
"verbose": True,
|
||||
},
|
||||
}
|
|
@ -1,27 +1,18 @@
|
|||
import argparse
|
||||
import os
|
||||
from typing import List, Union
|
||||
|
||||
|
||||
from colossalqa.local.llm import ColossalAPI, ColossalLLM
|
||||
from colossalqa.data_loader.document_loader import DocumentLoader
|
||||
from colossalqa.mylogging import get_logger
|
||||
from colossalqa.retrieval_conversation_zh import ChineseRetrievalConversation
|
||||
from colossalqa.retriever import CustomRetriever
|
||||
from enum import Enum
|
||||
from fastapi import FastAPI, Request
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from pydantic import BaseModel, Field
|
||||
import uvicorn
|
||||
|
||||
import config
|
||||
import uvicorn
|
||||
from colossalqa.local.llm import ColossalAPI, ColossalLLM
|
||||
from colossalqa.mylogging import get_logger
|
||||
from fastapi import FastAPI, Request
|
||||
from pydantic import BaseModel
|
||||
from RAG_ChatBot import RAG_ChatBot
|
||||
from utils import DocAction
|
||||
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def parseArgs():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--http_host", default="0.0.0.0")
|
||||
|
@ -36,6 +27,7 @@ class DocUpdateReq(BaseModel):
|
|||
doc_files: Union[List[str], str, None] = None
|
||||
action: DocAction = DocAction.ADD
|
||||
|
||||
|
||||
class GenerationTaskReq(BaseModel):
|
||||
user_input: str
|
||||
|
||||
|
@ -84,12 +76,13 @@ if __name__ == "__main__":
|
|||
"user": "User",
|
||||
"max_tokens": all_config["chain"]["disambig_llm_kwargs"]["max_new_tokens"],
|
||||
"temperature": all_config["chain"]["disambig_llm_kwargs"]["temperature"],
|
||||
"n": 1 # the number of responses generated
|
||||
"n": 1, # the number of responses generated
|
||||
}
|
||||
llm = Pangu(gen_config=gen_config)
|
||||
llm.set_auth_config() # verify user's auth info here
|
||||
elif model_name == "chatgpt_api":
|
||||
from langchain.llms import OpenAI
|
||||
|
||||
llm = OpenAI()
|
||||
else:
|
||||
raise ValueError("Unsupported mode.")
|
||||
|
|
|
@ -1,24 +1,26 @@
|
|||
import argparse
|
||||
import json
|
||||
import os
|
||||
import requests
|
||||
|
||||
import gradio as gr
|
||||
|
||||
import requests
|
||||
from utils import DocAction
|
||||
|
||||
|
||||
def parseArgs():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--http_host", default="0.0.0.0")
|
||||
parser.add_argument("--http_port", type=int, default=13666)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def get_response(data, url):
|
||||
headers = {"Content-type": "application/json"}
|
||||
response = requests.post(url, json=data, headers=headers)
|
||||
response = json.loads(response.content)
|
||||
return response
|
||||
|
||||
|
||||
def add_text(history, text):
|
||||
history = history + [(text, None)]
|
||||
return history, gr.update(value=None, interactive=True)
|
||||
|
@ -28,18 +30,14 @@ def add_file(history, files):
|
|||
files_string = "\n".join([os.path.basename(file.name) for file in files])
|
||||
|
||||
doc_files = [file.name for file in files]
|
||||
data = {
|
||||
"doc_files": doc_files,
|
||||
"action": DocAction.ADD
|
||||
}
|
||||
data = {"doc_files": doc_files, "action": DocAction.ADD}
|
||||
response = get_response(data, update_url)["response"]
|
||||
history = history + [(files_string, response)]
|
||||
return history
|
||||
|
||||
|
||||
def bot(history):
|
||||
data = {
|
||||
"user_input": history[-1][0].strip()
|
||||
}
|
||||
data = {"user_input": history[-1][0].strip()}
|
||||
response = get_response(data, gen_url)
|
||||
|
||||
if response["error"] != "":
|
||||
|
@ -51,11 +49,8 @@ def bot(history):
|
|||
|
||||
def restart(chatbot, txt):
|
||||
# Reset the conversation state and clear the chat history
|
||||
data = {
|
||||
"doc_files": "",
|
||||
"action": DocAction.CLEAR
|
||||
}
|
||||
response = get_response(data, update_url)
|
||||
data = {"doc_files": "", "action": DocAction.CLEAR}
|
||||
get_response(data, update_url)
|
||||
|
||||
return gr.update(value=None), gr.update(value=None, interactive=True)
|
||||
|
||||
|
|
|
@@ -1,21 +1,21 @@
import os

from colossalqa.data_loader.document_loader import DocumentLoader


def test_add_document():
-PATH = os.environ.get('TEST_DOCUMENT_LOADER_DATA_PATH')
-files = [[PATH, 'all data']]
+PATH = os.environ.get("TEST_DOCUMENT_LOADER_DATA_PATH")
+files = [[PATH, "all data"]]
document_loader = DocumentLoader(files)
documents = document_loader.all_data
all_files = []
for doc in documents:
assert isinstance(doc.page_content, str) == True
-if doc.metadata['source'] not in all_files:
-all_files.append(doc.metadata['source'])
+if doc.metadata["source"] not in all_files:
+all_files.append(doc.metadata["source"])
print(all_files)
assert len(all_files) == 6


-if __name__=='__main__':
+if __name__ == "__main__":
test_add_document()
@ -4,56 +4,44 @@ from colossalqa.retrieval_conversation_universal import UniversalRetrievalConver
|
|||
|
||||
|
||||
def test_en_retrievalQA():
|
||||
data_path_en = os.environ.get('TEST_DATA_PATH_EN')
|
||||
data_path_zh = os.environ.get('TEST_DATA_PATH_ZH')
|
||||
en_model_path = os.environ.get('EN_MODEL_PATH')
|
||||
zh_model_path = os.environ.get('ZH_MODEL_PATH')
|
||||
zh_model_name = os.environ.get('ZH_MODEL_NAME')
|
||||
en_model_name = os.environ.get('EN_MODEL_NAME')
|
||||
sql_file_path = os.environ.get('SQL_FILE_PATH')
|
||||
qa_session = UniversalRetrievalConversation(files_en=[{
|
||||
'data_path': data_path_en,
|
||||
'name': 'company information',
|
||||
'separator': '\n'
|
||||
}],
|
||||
files_zh=[{
|
||||
'data_path': data_path_zh,
|
||||
'name': 'company information',
|
||||
'separator': '\n'
|
||||
}],
|
||||
data_path_en = os.environ.get("TEST_DATA_PATH_EN")
|
||||
data_path_zh = os.environ.get("TEST_DATA_PATH_ZH")
|
||||
en_model_path = os.environ.get("EN_MODEL_PATH")
|
||||
zh_model_path = os.environ.get("ZH_MODEL_PATH")
|
||||
zh_model_name = os.environ.get("ZH_MODEL_NAME")
|
||||
en_model_name = os.environ.get("EN_MODEL_NAME")
|
||||
sql_file_path = os.environ.get("SQL_FILE_PATH")
|
||||
qa_session = UniversalRetrievalConversation(
|
||||
files_en=[{"data_path": data_path_en, "name": "company information", "separator": "\n"}],
|
||||
files_zh=[{"data_path": data_path_zh, "name": "company information", "separator": "\n"}],
|
||||
zh_model_path=zh_model_path,
|
||||
en_model_path=en_model_path,
|
||||
zh_model_name=zh_model_name,
|
||||
en_model_name=en_model_name,
|
||||
sql_file_path=sql_file_path)
|
||||
ans = qa_session.run("which company runs business in hotel industry?", which_language='en')
|
||||
sql_file_path=sql_file_path,
|
||||
)
|
||||
ans = qa_session.run("which company runs business in hotel industry?", which_language="en")
|
||||
print(ans)
|
||||
|
||||
|
||||
def test_zh_retrievalQA():
|
||||
data_path_en = os.environ.get('TEST_DATA_PATH_EN')
|
||||
data_path_zh = os.environ.get('TEST_DATA_PATH_ZH')
|
||||
en_model_path = os.environ.get('EN_MODEL_PATH')
|
||||
zh_model_path = os.environ.get('ZH_MODEL_PATH')
|
||||
zh_model_name = os.environ.get('ZH_MODEL_NAME')
|
||||
en_model_name = os.environ.get('EN_MODEL_NAME')
|
||||
sql_file_path = os.environ.get('SQL_FILE_PATH')
|
||||
qa_session = UniversalRetrievalConversation(files_en=[{
|
||||
'data_path': data_path_en,
|
||||
'name': 'company information',
|
||||
'separator': '\n'
|
||||
}],
|
||||
files_zh=[{
|
||||
'data_path': data_path_zh,
|
||||
'name': 'company information',
|
||||
'separator': '\n'
|
||||
}],
|
||||
data_path_en = os.environ.get("TEST_DATA_PATH_EN")
|
||||
data_path_zh = os.environ.get("TEST_DATA_PATH_ZH")
|
||||
en_model_path = os.environ.get("EN_MODEL_PATH")
|
||||
zh_model_path = os.environ.get("ZH_MODEL_PATH")
|
||||
zh_model_name = os.environ.get("ZH_MODEL_NAME")
|
||||
en_model_name = os.environ.get("EN_MODEL_NAME")
|
||||
sql_file_path = os.environ.get("SQL_FILE_PATH")
|
||||
qa_session = UniversalRetrievalConversation(
|
||||
files_en=[{"data_path": data_path_en, "name": "company information", "separator": "\n"}],
|
||||
files_zh=[{"data_path": data_path_zh, "name": "company information", "separator": "\n"}],
|
||||
zh_model_path=zh_model_path,
|
||||
en_model_path=en_model_path,
|
||||
zh_model_name=zh_model_name,
|
||||
en_model_name=en_model_name,
|
||||
sql_file_path=sql_file_path)
|
||||
ans = qa_session.run("哪家公司在经营酒店业务?", which_language='zh')
|
||||
sql_file_path=sql_file_path,
|
||||
)
|
||||
ans = qa_session.run("哪家公司在经营酒店业务?", which_language="zh")
|
||||
print(ans)
|
||||
|
||||
|
||||
|
|
|
@@ -1,5 +1,5 @@
-from .initialize import launch, launch_from_openmpi, launch_from_slurm, launch_from_torch
from . import accelerator
+from .initialize import launch, launch_from_openmpi, launch_from_slurm, launch_from_torch

try:
# .version will be created by setup.py
@ -27,7 +27,7 @@ from torch.optim import Optimizer
|
|||
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO, utils, CheckpointIndexFile
|
||||
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO, utils
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
|
||||
|
@ -93,9 +93,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
|
|||
|
||||
Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
|
||||
with FSDP.state_dict_type(
|
||||
model.unwrap(),
|
||||
StateDictType.FULL_STATE_DICT,
|
||||
FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
|
||||
model.unwrap(), StateDictType.FULL_STATE_DICT, FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
|
||||
):
|
||||
state_dict = model.unwrap().state_dict()
|
||||
|
||||
|
@ -172,7 +170,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
|
|||
with FSDP.state_dict_type(
|
||||
optimizer.unwrap_model().unwrap(),
|
||||
StateDictType.FULL_STATE_DICT,
|
||||
FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
|
||||
FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
|
||||
):
|
||||
fsdp_optim_state = FSDP.full_optim_state_dict(
|
||||
optimizer.unwrap_model().unwrap(), optim=optimizer, rank0_only=True
|
||||
|
@ -241,7 +239,6 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
|
|||
)
|
||||
optimizer.load_state_dict(fsdp_state)
|
||||
|
||||
|
||||
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
|
||||
"""
|
||||
Save model to checkpoint but only on master process.
|
||||
|
|
|
@@ -294,6 +294,7 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
# Helper functions for saving state dict
# ======================================


def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None:
"""
Save state dict to checkpoint.
@@ -225,4 +225,3 @@ class ProcessGroupMesh:
# no need to cache it explicitly, since it will be cached in `create_group_along_axis`
return self.create_group_along_axis(axis, indices_at_axis, backend=backend)
return self._ranks_to_group[ranks_in_group]
@ -29,13 +29,17 @@ except:
|
|||
|
||||
try:
|
||||
from colossalai.kernel.triton.flash_decoding import token_flash_decoding
|
||||
|
||||
HAS_TRITON_FLASH_DECODING_KERNEL = True
|
||||
except:
|
||||
print("no triton flash decoding support, please install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8")
|
||||
print(
|
||||
"no triton flash decoding support, please install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8"
|
||||
)
|
||||
HAS_TRITON_FLASH_DECODING_KERNEL = False
|
||||
|
||||
try:
|
||||
from flash_attn import flash_attn_with_kvcache
|
||||
|
||||
HAS_FLASH_KERNEL = True
|
||||
except:
|
||||
HAS_FLASH_KERNEL = False
|
||||
|
@ -48,6 +52,7 @@ def rotate_half(x):
|
|||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
||||
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
|
||||
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
|
||||
|
@ -96,15 +101,20 @@ def llama_triton_context_attention(
|
|||
infer_state.max_len_in_batch,
|
||||
)
|
||||
|
||||
def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1, q_head_num = -1, head_dim = -1):
|
||||
|
||||
def llama_triton_token_attention(
|
||||
query_states, attn_output, infer_state, num_key_value_groups=1, q_head_num=-1, head_dim=-1
|
||||
):
|
||||
if HAS_TRITON_FLASH_DECODING_KERNEL and q_head_num != -1 and head_dim != -1:
|
||||
token_flash_decoding(q = query_states,
|
||||
token_flash_decoding(
|
||||
q=query_states,
|
||||
o_tensor=attn_output,
|
||||
infer_state=infer_state,
|
||||
q_head_num=q_head_num,
|
||||
head_dim=head_dim,
|
||||
cache_k=infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
|
||||
cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id])
|
||||
cache_v=infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
|
||||
)
|
||||
return
|
||||
|
||||
if num_key_value_groups == 1:
|
||||
|
@ -459,14 +469,15 @@ class LlamaInferenceForwards:
|
|||
)
|
||||
|
||||
if HAS_LIGHTLLM_KERNEL:
|
||||
|
||||
attn_output = torch.empty_like(query_states)
|
||||
llama_triton_token_attention(query_states = query_states,
|
||||
llama_triton_token_attention(
|
||||
query_states=query_states,
|
||||
attn_output=attn_output,
|
||||
infer_state=infer_state,
|
||||
num_key_value_groups=self.num_key_value_groups,
|
||||
q_head_num=q_len * self.num_heads,
|
||||
head_dim = self.head_dim)
|
||||
head_dim=self.head_dim,
|
||||
)
|
||||
else:
|
||||
self.num_heads // self.num_key_value_heads
|
||||
cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id]
|
||||
|
|
|
@ -18,15 +18,15 @@ from .gptq_op import CaiGPTQLinearOp
|
|||
HAS_GPTQ_CUDA = False
|
||||
try:
|
||||
from colossalai.kernel.op_builder.gptq import GPTQBuilder
|
||||
|
||||
gptq_cuda = GPTQBuilder().load()
|
||||
HAS_GPTQ_CUDA = True
|
||||
except ImportError:
|
||||
warnings.warn('CUDA gptq is not installed')
|
||||
warnings.warn("CUDA gptq is not installed")
|
||||
HAS_GPTQ_CUDA = False
|
||||
|
||||
|
||||
class CaiQuantLinear(nn.Module):
|
||||
|
||||
def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
|
||||
super().__init__()
|
||||
if bits not in [2, 4, 8]:
|
||||
|
@ -37,23 +37,28 @@ class CaiQuantLinear(nn.Module):
|
|||
self.maxq = 2**self.bits - 1
|
||||
self.groupsize = groupsize if groupsize != -1 else infeatures
|
||||
|
||||
self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
|
||||
self.register_buffer("qweight", torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
|
||||
self.register_buffer(
|
||||
'qzeros',
|
||||
torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32))
|
||||
self.register_buffer('scales',
|
||||
torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16))
|
||||
"qzeros",
|
||||
torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32),
|
||||
)
|
||||
self.register_buffer(
|
||||
"scales", torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)
|
||||
)
|
||||
if row_split:
|
||||
self.register_buffer(
|
||||
'g_idx',
|
||||
torch.tensor([(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)],
|
||||
dtype=torch.int32))
|
||||
"g_idx",
|
||||
torch.tensor(
|
||||
[(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)], dtype=torch.int32
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.register_buffer('g_idx',
|
||||
torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32))
|
||||
self.register_buffer(
|
||||
"g_idx", torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)
|
||||
)
|
||||
|
||||
if bias:
|
||||
self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
|
||||
self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float16))
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
|
@ -66,9 +71,11 @@ class CaiQuantLinear(nn.Module):
|
|||
self.row_split = row_split
|
||||
|
||||
def pack(self, linear, scales, zeros, g_idx=None):
|
||||
|
||||
g_idx = g_idx.clone() if g_idx is not None else torch.tensor(
|
||||
[i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32)
|
||||
g_idx = (
|
||||
g_idx.clone()
|
||||
if g_idx is not None
|
||||
else torch.tensor([i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32)
|
||||
)
|
||||
|
||||
scales = scales.t().contiguous()
|
||||
zeros = zeros.t().contiguous()
|
||||
|
@ -79,7 +86,6 @@ class CaiQuantLinear(nn.Module):
|
|||
if linear.bias is not None:
|
||||
self.bias = linear.bias.clone().half()
|
||||
|
||||
wn = 8
|
||||
pbits = 32
|
||||
ptype = torch.int32
|
||||
unsign_type = np.uint32
|
||||
|
@ -88,9 +94,10 @@ class CaiQuantLinear(nn.Module):
|
|||
intweight = []
|
||||
for idx in range(self.infeatures):
|
||||
intweight.append(
|
||||
torch.round(
|
||||
(linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[:,
|
||||
None])
|
||||
torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[
|
||||
:, None
|
||||
]
|
||||
)
|
||||
intweight = torch.cat(intweight, dim=1)
|
||||
intweight = intweight.t().contiguous()
|
||||
intweight = intweight.numpy().astype(unsign_type)
|
||||
|
@ -144,13 +151,16 @@ class CaiQuantLinear(nn.Module):
|
|||
torch.tensor(
|
||||
[(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)],
|
||||
dtype=torch.int32,
|
||||
device=self.g_idx.device)):
|
||||
device=self.g_idx.device,
|
||||
),
|
||||
):
|
||||
self.g_idx = None
|
||||
elif torch.equal(
|
||||
self.g_idx,
|
||||
torch.tensor([i // self.groupsize for i in range(self.infeatures)],
|
||||
dtype=torch.int32,
|
||||
device=self.g_idx.device)):
|
||||
torch.tensor(
|
||||
[i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32, device=self.g_idx.device
|
||||
),
|
||||
):
|
||||
self.g_idx = None
|
||||
|
||||
if self.g_idx is not None:
|
||||
|
@ -165,7 +175,6 @@ class CaiQuantLinear(nn.Module):
|
|||
outshape = x.shape[:-1] + (self.outfeatures,)
|
||||
|
||||
if HAS_GPTQ_CUDA and self.bits == 4:
|
||||
|
||||
if self.q4 is None:
|
||||
self.init_q4()
|
||||
|
||||
|
@ -191,7 +200,6 @@ class CaiQuantLinear(nn.Module):
|
|||
|
||||
|
||||
def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1):
|
||||
|
||||
qweights = gptq_linear.qweight.split(gptq_linear.out_features // split_num, dim=-1)
|
||||
qzeros = gptq_linear.qzeros.split(gptq_linear.out_features // (32 // cai_linear.bits) // split_num, dim=-1)
|
||||
scales = gptq_linear.scales.split(gptq_linear.out_features // split_num, dim=-1)
|
||||
|
@ -203,24 +211,24 @@ def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1
|
|||
zero_split_block = cai_linear.outfeatures // (32 // cai_linear.bits) // split_num
|
||||
|
||||
for i in range(split_num):
|
||||
cai_linear.qweight[:, i * cai_split_out_features:(i + 1) *
|
||||
cai_split_out_features] = qweights[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) *
|
||||
cai_split_out_features]
|
||||
cai_linear.qzeros[:, i * zero_split_block:(i + 1) *
|
||||
zero_split_block] = qzeros[i][:, tp_rank * zero_split_block:(tp_rank + 1) * zero_split_block]
|
||||
cai_linear.scales[:, i * cai_split_out_features:(i + 1) *
|
||||
cai_split_out_features] = scales[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) *
|
||||
cai_split_out_features]
|
||||
cai_linear.qweight[:, i * cai_split_out_features : (i + 1) * cai_split_out_features] = qweights[i][
|
||||
:, tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features
|
||||
]
|
||||
cai_linear.qzeros[:, i * zero_split_block : (i + 1) * zero_split_block] = qzeros[i][
|
||||
:, tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block
|
||||
]
|
||||
cai_linear.scales[:, i * cai_split_out_features : (i + 1) * cai_split_out_features] = scales[i][
|
||||
:, tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features
|
||||
]
|
||||
if cai_linear.bias is not None:
|
||||
cai_linear.bias[i * cai_split_out_features:(i + 1) *
|
||||
cai_split_out_features] = bias[i][tp_rank * cai_split_out_features:(tp_rank + 1) *
|
||||
cai_split_out_features]
|
||||
cai_linear.bias[i * cai_split_out_features : (i + 1) * cai_split_out_features] = bias[i][
|
||||
tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features
|
||||
]
|
||||
|
||||
cai_linear.g_idx.copy_(g_idx)
|
||||
|
||||
|
||||
def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1):
|
||||
|
||||
qweights = gptq_linear.qweight.split(gptq_linear.in_features // split_num, dim=0)
|
||||
qzeros = gptq_linear.qzeros.split(gptq_linear.in_features // split_num, dim=0)
|
||||
scales = gptq_linear.scales.split(gptq_linear.in_features // split_num, dim=0)
|
||||
|
@ -231,47 +239,40 @@ def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1):
|
|||
idx_split_features = cai_linear.infeatures // split_num
|
||||
|
||||
for i in range(split_num):
|
||||
cai_linear.qweight[i * cai_split_in_features:(i + 1) *
|
||||
cai_split_in_features, :] = qweights[i][tp_rank * cai_split_in_features:(tp_rank + 1) *
|
||||
cai_split_in_features, :]
|
||||
cai_linear.qzeros[i * zero_split_block:(i + 1) *
|
||||
zero_split_block, :] = qzeros[i][tp_rank * zero_split_block:(tp_rank + 1) *
|
||||
zero_split_block, :]
|
||||
cai_linear.scales[i * zero_split_block:(i + 1) *
|
||||
zero_split_block, :] = scales[i][tp_rank * zero_split_block:(tp_rank + 1) *
|
||||
zero_split_block, :]
|
||||
cai_linear.g_idx[i * idx_split_features:(i + 1) *
|
||||
idx_split_features] = g_idxs[i][tp_rank * idx_split_features:(tp_rank + 1) *
|
||||
idx_split_features]
|
||||
cai_linear.qweight[i * cai_split_in_features : (i + 1) * cai_split_in_features, :] = qweights[i][
|
||||
tp_rank * cai_split_in_features : (tp_rank + 1) * cai_split_in_features, :
|
||||
]
|
||||
cai_linear.qzeros[i * zero_split_block : (i + 1) * zero_split_block, :] = qzeros[i][
|
||||
tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block, :
|
||||
]
|
||||
cai_linear.scales[i * zero_split_block : (i + 1) * zero_split_block, :] = scales[i][
|
||||
tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block, :
|
||||
]
|
||||
cai_linear.g_idx[i * idx_split_features : (i + 1) * idx_split_features] = g_idxs[i][
|
||||
tp_rank * idx_split_features : (tp_rank + 1) * idx_split_features
|
||||
]
|
||||
if cai_linear.bias is not None:
|
||||
cai_linear.bias.copy_(gptq_linear.bias)
|
||||
|
||||
|
||||
class RowCaiQuantLinear(CaiQuantLinear, ParallelModule):
|
||||
|
||||
def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
|
||||
|
||||
super().__init__(bits,
|
||||
groupsize,
|
||||
infeatures,
|
||||
outfeatures,
|
||||
bias,
|
||||
tp_size=tp_size,
|
||||
tp_rank=tp_rank,
|
||||
row_split=row_split)
|
||||
super().__init__(
|
||||
bits, groupsize, infeatures, outfeatures, bias, tp_size=tp_size, tp_rank=tp_rank, row_split=row_split
|
||||
)
|
||||
self.process_group = None
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
|
||||
**kwargs) -> ParallelModule:
|
||||
def from_native_module(
|
||||
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
|
||||
) -> ParallelModule:
|
||||
LazyInitContext.materialize(module)
|
||||
# get the attributes
|
||||
in_features = module.in_features
|
||||
|
||||
# ensure only one process group is passed
|
||||
if isinstance(process_group, (list, tuple)):
|
||||
assert len(process_group) == 1, \
|
||||
f'Expected only one process group, got {len(process_group)}.'
|
||||
assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
|
||||
process_group = process_group[0]
|
||||
|
||||
tp_size = dist.get_world_size(process_group)
|
||||
|
@ -282,15 +283,18 @@ class RowCaiQuantLinear(CaiQuantLinear, ParallelModule):
|
|||
|
||||
if in_features % tp_size != 0:
|
||||
raise ValueError(
|
||||
f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!")
|
||||
linear_1d = RowCaiQuantLinear(module.bits,
|
||||
f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!"
|
||||
)
|
||||
linear_1d = RowCaiQuantLinear(
|
||||
module.bits,
|
||||
module.group_size,
|
||||
module.in_features // tp_size,
|
||||
module.out_features,
|
||||
module.bias is not None,
|
||||
tp_size=tp_size,
|
||||
tp_rank=tp_rank,
|
||||
row_split=True)
|
||||
row_split=True,
|
||||
)
|
||||
linear_1d.process_group = process_group
|
||||
|
||||
split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
|
||||
|
@ -306,30 +310,23 @@ class RowCaiQuantLinear(CaiQuantLinear, ParallelModule):
|
|||
|
||||
|
||||
class ColCaiQuantLinear(CaiQuantLinear, ParallelModule):
|
||||
|
||||
def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
|
||||
|
||||
super().__init__(bits,
|
||||
groupsize,
|
||||
infeatures,
|
||||
outfeatures,
|
||||
bias,
|
||||
tp_size=tp_size,
|
||||
tp_rank=tp_rank,
|
||||
row_split=row_split)
|
||||
super().__init__(
|
||||
bits, groupsize, infeatures, outfeatures, bias, tp_size=tp_size, tp_rank=tp_rank, row_split=row_split
|
||||
)
|
||||
self.process_group = None
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
|
||||
**kwargs) -> ParallelModule:
|
||||
def from_native_module(
|
||||
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
|
||||
) -> ParallelModule:
|
||||
LazyInitContext.materialize(module)
|
||||
# get the attributes
|
||||
in_features = module.in_features
|
||||
|
||||
# ensure only one process group is passed
|
||||
if isinstance(process_group, (list, tuple)):
|
||||
assert len(process_group) == 1, \
|
||||
f'Expected only one process group, got {len(process_group)}.'
|
||||
assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
|
||||
process_group = process_group[0]
|
||||
|
||||
tp_size = dist.get_world_size(process_group)
|
||||
|
@ -340,14 +337,17 @@ class ColCaiQuantLinear(CaiQuantLinear, ParallelModule):
|
|||
|
||||
if in_features % tp_size != 0:
|
||||
raise ValueError(
|
||||
f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!")
|
||||
linear_1d = ColCaiQuantLinear(module.bits,
|
||||
f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!"
|
||||
)
|
||||
linear_1d = ColCaiQuantLinear(
|
||||
module.bits,
|
||||
module.group_size,
|
||||
module.in_features,
|
||||
module.out_features // tp_size,
|
||||
module.bias is not None,
|
||||
tp_size=tp_size,
|
||||
tp_rank=tp_rank)
|
||||
tp_rank=tp_rank,
|
||||
)
|
||||
linear_1d.process_group = process_group
|
||||
|
||||
split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
|
||||
|
|
|
@ -5,6 +5,7 @@ import torch
|
|||
try:
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
HAS_TRITON = True
|
||||
except ImportError:
|
||||
HAS_TRITON = False
|
||||
|
@ -16,6 +17,7 @@ if HAS_TRITON:
|
|||
https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10
|
||||
"""
|
||||
if triton.__version__ < "2.1.0":
|
||||
|
||||
@triton.jit
|
||||
def _context_flash_attention_kernel(
|
||||
Q,
|
||||
|
@ -131,23 +133,41 @@ if HAS_TRITON:
|
|||
m_i = m_i_new
|
||||
|
||||
off_o = (
|
||||
(cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od
|
||||
(cur_batch_start_index + offs_m[:, None]) * stride_obs
+ cur_head * stride_oh
+ offs_d[None, :] * stride_od
)
out_ptrs = Out + off_o
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
return

else:
# this function is modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L11
@triton.jit
def _context_flash_attention_kernel_2(
Q, K, V, sm_scale, Alibi, B_Start_Loc, B_Seqlen,
Q,
K,
V,
sm_scale,
Alibi,
B_Start_Loc,
B_Seqlen,
Out,
kv_group_num,
stride_qbs, stride_qh, stride_qd,
stride_kbs, stride_kh, stride_kd,
stride_vbs, stride_vh, stride_vd,
stride_obs, stride_oh, stride_od,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
cur_batch = tl.program_id(0)

@ -166,7 +186,11 @@ if HAS_TRITON:
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd
off_q = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
+ cur_head * stride_qh
+ offs_d[None, :] * stride_qd
)
if kv_group_num is None or kv_group_num == 1:
off_k = offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd
off_v = offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd

@ -191,8 +215,11 @@ if HAS_TRITON:
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0)
k = tl.load(
k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :]) < cur_batch_seq_len,
other=0.0,
)

qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)

@ -220,8 +247,11 @@ if HAS_TRITON:
acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0)
v = tl.load(
v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,
other=0.0,
)

p = p.to(v.dtype)
acc += tl.dot(p, v)

@ -229,7 +259,11 @@ if HAS_TRITON:
l_i = l_i_new
m_i = m_i_new
# initialize pointers to output
off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od
off_o = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
+ cur_head * stride_oh
+ offs_d[None, :] * stride_od
)
out_ptrs = Out + off_o
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
return

@ -286,7 +320,13 @@ if HAS_TRITON:
)
else:
_context_flash_attention_kernel_2[grid](
q, k, v, sm_scale, alibi, b_start_loc, b_seq_len,
q,
k,
v,
sm_scale,
alibi,
b_start_loc,
b_seq_len,
o,
None,
q.stride(0),

@ -388,6 +428,7 @@ if HAS_TRITON:
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,)
num_stages=1,
)

return
@ -1,8 +1,10 @@
# adepted from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8/lightllm/models/llama/triton_kernel/flash_decoding.py
import torch

try:
from lightllm.models.llama.triton_kernel.flash_decoding_stage1 import flash_decode_stage1
from lightllm.models.llama.triton_kernel.flash_decoding_stage2 import flash_decode_stage2

HAS_LIGHTLLM_KERNEL = True
except:
print("install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8")

@ -10,31 +12,29 @@ except:

if HAS_LIGHTLLM_KERNEL:

def token_flash_decoding(q, o_tensor, infer_state, q_head_num, head_dim, cache_k, cache_v):
BLOCK_SEQ = 256
batch_size = infer_state.batch_size
max_len_in_batch = infer_state.max_len_in_batch

calcu_shape1 = (batch_size, q_head_num, head_dim)

if getattr(infer_state, 'mid_o', None) is None:
infer_state.mid_o = torch.empty([batch_size,
q_head_num,
max_len_in_batch // BLOCK_SEQ + 1,
head_dim],
if getattr(infer_state, "mid_o", None) is None:
infer_state.mid_o = torch.empty(
[batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim],
dtype=torch.float32,
device="cuda")
infer_state.mid_o_logexpsum = torch.empty([batch_size,
q_head_num,
max_len_in_batch // BLOCK_SEQ + 1],
dtype=torch.float32,
device="cuda")
device="cuda",
)
infer_state.mid_o_logexpsum = torch.empty(
[batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda"
)

mid_o = infer_state.mid_o
mid_o_logexpsum = infer_state.mid_o_logexpsum

flash_decode_stage1(q.view(calcu_shape1),
flash_decode_stage1(
q.view(calcu_shape1),
cache_k,
cache_v,
infer_state.block_loc,

@ -42,9 +42,6 @@ if HAS_LIGHTLLM_KERNEL:
infer_state.max_len_in_batch,
mid_o,
mid_o_logexpsum,
BLOCK_SEQ)
flash_decode_stage2(mid_o,
mid_o_logexpsum,
infer_state.seq_len,
o_tensor.view(calcu_shape1),
BLOCK_SEQ)
BLOCK_SEQ,
)
flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ)
@ -8,6 +8,7 @@ from torch.cuda.amp import custom_bwd, custom_fwd
try:
import triton
import triton.language as tl

HAS_TRITON = True
except ImportError:
HAS_TRITON = False

@ -41,9 +42,9 @@ if HAS_TRITON:
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
mask = cols < N
x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.)
x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.)
x_up = tl.load(X_UP + cols, mask=mask, other=0.)
x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.0)
x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.0)
x_up = tl.load(X_UP + cols, mask=mask, other=0.0)
x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype)
y = x_gate1 * x_gate2 * x_gate2_sigmoid * x_up
# Write output

@ -76,10 +77,10 @@ if HAS_TRITON:
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
mask = cols < N
x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.)
x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.)
x_up = tl.load(X_UP + cols, mask=mask, other=0.)
y_grad = tl.load(Y_GRAD + cols, mask=mask, other=0.)
x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.0)
x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.0)
x_up = tl.load(X_UP + cols, mask=mask, other=0.0)
y_grad = tl.load(Y_GRAD + cols, mask=mask, other=0.0)

# forward: y = x_gate1 * x_gate2 * tl.sigmoid(x_gate2) * x_up
x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype)

@ -147,14 +148,9 @@ if HAS_TRITON:
# restore setting
ctx.M, ctx.N, ctx.BLOCK_SIZE, ctx.num_warps = M, N, BLOCK_SIZE, num_warps
# enqueue kernel
_llama_act_combine_forward[(M,)](x_gate1,
x_gate2,
x_up,
y,
x_up.stride(-2),
N,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps)
_llama_act_combine_forward[(M,)](
x_gate1, x_gate2, x_up, y, x_up.stride(-2), N, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
)
return y

@staticmethod

@ -166,11 +162,15 @@ if HAS_TRITON:

# init grad
y_grad = grad_outputs[0]
x_gate1_grad, x_gate2_grad, x_up_grad = torch.empty_like(x_gate1), torch.empty_like(
x_gate2), torch.empty_like(x_up)
x_gate1_grad, x_gate2_grad, x_up_grad = (
torch.empty_like(x_gate1),
torch.empty_like(x_gate2),
torch.empty_like(x_up),
)

# enqueue kernel
_llama_act_combine_backward[(M,)](x_gate1,
_llama_act_combine_backward[(M,)](
x_gate1,
x_gate2,
x_up,
x_gate1_grad,

@ -180,6 +180,7 @@ if HAS_TRITON:
x_up.stride(-2),
N,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps)
num_warps=num_warps,
)
x_gate_grad = torch.cat([x_gate1_grad, x_gate2_grad], dim=-1)
return x_gate_grad, x_up_grad, None, None
@ -13,10 +13,18 @@ except ImportError:
print("please install triton from https://github.com/openai/triton")

try:
from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2 as lightllm_llama_token_att_fwd2
from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_llama_token_att_fwd
from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd as lightllm_llama_token_softmax_fwd
from lightllm.models.bloom.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_bloom_token_att_fwd
from lightllm.models.bloom.triton_kernel.token_attention_nopad_att1 import (
token_att_fwd as lightllm_bloom_token_att_fwd,
)
from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import (
token_att_fwd as lightllm_llama_token_att_fwd,
)
from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import (
token_att_fwd2 as lightllm_llama_token_att_fwd2,
)
from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import (
token_softmax_fwd as lightllm_llama_token_softmax_fwd,
)

HAS_TRITON_TOKEN_ATTENTION = True
except ImportError:

@ -205,9 +213,7 @@ class Llama2TokenAttentionForwards:

if triton.__version__ == "2.0.0":
prob = torch.empty_like(att_m_tensor)
lightllm_llama_token_softmax_fwd(
att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch
)
lightllm_llama_token_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch)
att_m_tensor = None

lightllm_llama_token_att_fwd2(

@ -8,7 +8,9 @@ from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecode
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd
from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards

from ._utils import copy_kv_to_mem_cache

try:
from lightllm.models.llama.triton_kernel.context_flashattention_nopad import (
context_attention_fwd as lightllm_llama_context_attention_fwd,
@ -2,13 +2,13 @@ from .api import (
compute_global_numel,
customized_distributed_tensor_to_param,
distribute_tensor,
init_as_dtensor,
distribute_tensor_with_customization,
init_tensor_as_customization_distributed,
get_device_mesh,
get_global_shape,
get_layout,
get_sharding_spec,
init_as_dtensor,
init_tensor_as_customization_distributed,
is_customized_distributed_tensor,
is_distributed_tensor,
is_sharded,

@ -128,7 +128,10 @@ def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_sp

return sharded_tensor

def init_as_dtensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size) -> torch.Tensor:

def init_as_dtensor(
tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size
) -> torch.Tensor:
assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."
dist_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)

@ -140,6 +143,7 @@ def init_as_dtensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec

return tensor


def redistribute(dtensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> None:
"""
Convert the layout of the tensor from source_spec to target_spec.

@ -468,7 +472,6 @@ def init_tensor_as_customization_distributed(tensor: torch.Tensor, shard_fn, gat
assert callable(gather_fn), "The gather_fn must be callable."
assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."


# set the shard_fn and gather_fn as attributes of the distributed tensor
tensor.shard_fn = shard_fn
tensor.gather_fn = gather_fn
@ -119,9 +119,7 @@ def main():
if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None:
# run pipeline forward backward
batch = iter([batch])
outputs = booster.execute_pipeline(
batch, model, criterion, optimizer, return_loss=True
)
outputs = booster.execute_pipeline(batch, model, criterion, optimizer, return_loss=True)
else:
outputs = model(**batch)
loss = criterion(outputs, None)

@ -270,9 +270,7 @@ def main():
) as pbar:
for step in pbar:
if use_pipeline:
outputs = booster.execute_pipeline(
dataloader_iter, model, _criterion, optimizer, return_loss=True
)
outputs = booster.execute_pipeline(dataloader_iter, model, _criterion, optimizer, return_loss=True)
loss = outputs["loss"]
else:
batch = next(dataloader_iter)

@ -285,9 +285,7 @@ def main():
) as pbar:
for step in pbar:
if use_pipeline:
outputs = booster.execute_pipeline(
dataloader_iter, model, _criterion, optimizer, return_loss=True
)
outputs = booster.execute_pipeline(dataloader_iter, model, _criterion, optimizer, return_loss=True)
loss = outputs["loss"]
else:
batch = next(dataloader_iter)
@ -50,7 +50,6 @@ def all_reduce_mean(x: float, world_size: int) -> float:


class Timer:

def __init__(self) -> None:
self.start_time: Optional[float] = None
self.duration: float = 0.0

@ -112,7 +111,7 @@ class PerformanceEvaluator:
batch_size, seq_len = input_ids.shape

self.num_samples += batch_size
self.flop += (batch_size * seq_len * self.model_numel * 2 * (3 + int(self.enable_grad_checkpoint)))
self.flop += batch_size * seq_len * self.model_numel * 2 * (3 + int(self.enable_grad_checkpoint))

def on_fit_end(self) -> None:
avg_duration = all_reduce_mean(self.timer.duration, self.world_size)

@ -122,5 +121,6 @@ class PerformanceEvaluator:
if dist.get_rank() == 0:
print(
f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop: {self.flop}, avg_duration: {avg_duration}, "
f"avg_throughput: {avg_throughput}")
f"avg_throughput: {avg_throughput}"
)
print(f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}")
@ -16,17 +16,15 @@ def inference(args):
tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
if args.model == "test":
config = LlamaConfig.from_pretrained("hpcai-tech/openmoe-base")
set_openmoe_args(config,
num_experts=config.num_experts,
moe_layer_interval=config.moe_layer_interval,
enable_kernel=True)
set_openmoe_args(
config, num_experts=config.num_experts, moe_layer_interval=config.moe_layer_interval, enable_kernel=True
)
model = OpenMoeForCausalLM(config)
else:
config = LlamaConfig.from_pretrained(f"hpcai-tech/openmoe-{args.model}")
set_openmoe_args(config,
num_experts=config.num_experts,
moe_layer_interval=config.moe_layer_interval,
enable_kernel=False)
set_openmoe_args(
config, num_experts=config.num_experts, moe_layer_interval=config.moe_layer_interval, enable_kernel=False
)
model = OpenMoeForCausalLM.from_pretrained(f"hpcai-tech/openmoe-{args.model}", config=config)
model = model.eval().bfloat16()
model = model.to(torch.cuda.current_device())
@ -172,9 +172,9 @@ def make_state_dict(converted_params):
def load_t5x_weights_in_t5(model, config, t5x_checkpoint_path):
"""Replaces the params in model witht the T5X converted params."""
variables = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
converted = convert_t5x_to_pytorch(variables,
num_layers=config.num_hidden_layers,
moe_interval=config.moe_layer_interval)
converted = convert_t5x_to_pytorch(
variables, num_layers=config.num_hidden_layers, moe_interval=config.moe_layer_interval
)
state_dict = make_state_dict(converted)
model.load_state_dict(state_dict, strict=True)

@ -203,11 +203,9 @@ def convert_t5x_checkpoint_to_pytorch(t5x_checkpoint_path, config_file, pytorch_
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Converts a native T5X checkpoint into a PyTorch checkpoint.")
# Required parameters
parser.add_argument("--t5x_checkpoint_path",
default=None,
type=str,
required=True,
help="Path to the T5X checkpoint.")
parser.add_argument(
"--t5x_checkpoint_path", default=None, type=str, required=True, help="Path to the T5X checkpoint."
)
parser.add_argument(
"--config_file",
default=None,

@ -215,10 +213,8 @@ if __name__ == "__main__":
required=True,
help="The config json file corresponding to the pre-trained T5 model.\nThis specifies the model architecture.",
)
parser.add_argument("--pytorch_dump_path",
default=None,
type=str,
required=True,
help="Path to the output PyTorch model.")
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
)
args = parser.parse_args()
convert_t5x_checkpoint_to_pytorch(args.t5x_checkpoint_path, args.config_file, args.pytorch_dump_path)
@ -41,9 +41,7 @@ def train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, dataloader, b
# Forward pass
for _ in pbar:
if use_pipeline:
outputs = booster.execute_pipeline(
dataloader, model, _criterion, optimizer, return_loss=True
)
outputs = booster.execute_pipeline(dataloader, model, _criterion, optimizer, return_loss=True)
# Backward and optimize
if is_pp_last_stage:
loss = outputs["loss"]

@ -1,5 +1,4 @@
from .cpu_adam_arm import CpuAdamArmExtension
from .cpu_adam_x86 import CpuAdamX86Extension

__all__ = ['CpuAdamArmExtension', 'CpuAdamX86Extension']

__all__ = ["CpuAdamArmExtension", "CpuAdamX86Extension"]

@ -1,3 +1,3 @@
from .moe_cuda import MoeCudaExtension

__all__ = ['MoeCudaExtension']
__all__ = ["MoeCudaExtension"]

@ -1,3 +1,3 @@
from .fused_optimizer_cuda import FusedOptimizerCudaExtension

__all__ = ['FusedOptimizerCudaExtension']
__all__ = ["FusedOptimizerCudaExtension"]

@ -1,4 +1,4 @@
from .scaled_masked_softmax_cuda import ScaledMaskedSoftmaxCudaExtension
from .scaled_upper_triangle_masked_softmax_cuda import ScaledUpperTriangleMaskedSoftmaxCudaExtension

__all__ = ['ScaledMaskedSoftmaxCudaExtension', 'ScaledUpperTriangleMaskedSoftmaxCudaExtension']
__all__ = ["ScaledMaskedSoftmaxCudaExtension", "ScaledUpperTriangleMaskedSoftmaxCudaExtension"]
@ -1,33 +1,33 @@
import os

from . import custom, diffusers, timm, torchaudio, torchvision, transformers
from .executor import run_fwd, run_fwd_bwd
from .registry import model_zoo

# We pick a subset of models for fast testing in order to reduce the total testing time
COMMON_MODELS = [
'custom_hanging_param_model',
'custom_nested_model',
'custom_repeated_computed_layers',
'custom_simple_net',
'diffusers_clip_text_model',
'diffusers_auto_encoder_kl',
'diffusers_unet2d_model',
'timm_densenet',
'timm_resnet',
'timm_swin_transformer',
'torchaudio_wav2vec2_base',
'torchaudio_conformer',
'transformers_bert_for_masked_lm',
'transformers_bloom_for_causal_lm',
'transformers_falcon_for_causal_lm',
'transformers_chatglm_for_conditional_generation',
'transformers_llama_for_casual_lm',
'transformers_vit_for_masked_image_modeling',
'transformers_mistral_for_casual_lm'
"custom_hanging_param_model",
"custom_nested_model",
"custom_repeated_computed_layers",
"custom_simple_net",
"diffusers_clip_text_model",
"diffusers_auto_encoder_kl",
"diffusers_unet2d_model",
"timm_densenet",
"timm_resnet",
"timm_swin_transformer",
"torchaudio_wav2vec2_base",
"torchaudio_conformer",
"transformers_bert_for_masked_lm",
"transformers_bloom_for_causal_lm",
"transformers_falcon_for_causal_lm",
"transformers_chatglm_for_conditional_generation",
"transformers_llama_for_casual_lm",
"transformers_vit_for_masked_image_modeling",
"transformers_mistral_for_casual_lm",
]

IS_FAST_TEST = os.environ.get('FAST_TEST', '0') == '1'
IS_FAST_TEST = os.environ.get("FAST_TEST", "0") == "1"


__all__ = ["model_zoo", "run_fwd", "run_fwd_bwd", 'COMMON_MODELS', 'IS_FAST_TEST']

__all__ = ["model_zoo", "run_fwd", "run_fwd_bwd", "COMMON_MODELS", "IS_FAST_TEST"]
@ -2,6 +2,7 @@ import torch

from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel

from ..registry import ModelAttribute, model_zoo

# ================================

@ -74,9 +74,7 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
data = data_gen_fn()
model.train()
if booster.plugin.stage_manager is not None:
booster.execute_pipeline(
_preprocess_data(data), model, _criterion, optimizer, return_loss=True
)
booster.execute_pipeline(_preprocess_data(data), model, _criterion, optimizer, return_loss=True)
else:
output = model(**_preprocess_data(data))
loss = criterion(output)

@ -108,9 +106,7 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
data_for_shard = data_gen_fn()
data_for_origin = data_gen_fn()
if booster.plugin.stage_manager is not None:
booster.execute_pipeline(
_preprocess_data(data_for_shard), model, _criterion, optimizer, return_loss=True
)
booster.execute_pipeline(_preprocess_data(data_for_shard), model, _criterion, optimizer, return_loss=True)
booster.execute_pipeline(
_preprocess_data(data_for_origin),
new_model,

@ -113,6 +113,7 @@ def check_torch_fsdp_ckpt():
full_osd = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer)

import copy

sharded_osd = copy.deepcopy(full_osd)

run_model()
@ -1,16 +1,8 @@
import math
import time

import numpy as np
import pytest
import torch
import torch.nn as nn
import transformers
from packaging import version

try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False

@ -22,6 +14,7 @@ try:
from exllama_kernels import prepare_buffers, set_tuning_params

from colossalai.inference.quant.gptq import CaiQuantLinear

HAS_AUTO_GPTQ = True
except:
HAS_AUTO_GPTQ = False

@ -32,13 +25,14 @@ import warnings
HAS_GPTQ_CUDA = False
try:
from colossalai.kernel.op_builder.gptq import GPTQBuilder

gptq_cuda = GPTQBuilder().load()
HAS_GPTQ_CUDA = True
except ImportError:
warnings.warn('CUDA gptq is not installed')
warnings.warn("CUDA gptq is not installed")
HAS_GPTQ_CUDA = False

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")

max_inner_outer_dim = 1
max_input_len = 1

@ -64,9 +58,9 @@ def init_buffer(cai_linear, use_act_order=False):
max_input_len = 4096
# The temp_state buffer is required to reorder X in the act-order case.
# The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
gptq_temp_state_buffer = torch.zeros((max_input_len, max_inner_outer_dim),
dtype=torch.float16,
device=torch.cuda.current_device())
gptq_temp_state_buffer = torch.zeros(
(max_input_len, max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device()
)
gptq_temp_dq_buffer = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device())

gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), gptq_temp_state_buffer, gptq_temp_dq_buffer)

@ -77,10 +71,11 @@ def init_buffer(cai_linear, use_act_order=False):
gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)


@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_AUTO_GPTQ,
reason="triton requires cuda version to be higher than 11.4 or not install auto-gptq")
@pytest.mark.skipif(
not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_AUTO_GPTQ,
reason="triton requires cuda version to be higher than 11.4 or not install auto-gptq",
)
def test_gptq_linear():

infeature = 1024
outfeature = 1024
group_size = 128

@ -120,7 +115,7 @@ def test_gptq_linear():
max_input_len = 2048
buffers = {
"temp_state": torch.zeros((max_input_len, max_inner_outer_dim), dtype=torch.float16, device=device),
"temp_dq": torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device)
"temp_dq": torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device),
}

prepare_buffers(device, buffers["temp_state"], buffers["temp_dq"])

@ -146,5 +141,4 @@ def test_gptq_linear():


if __name__ == "__main__":

test_gptq_linear()
@ -1,5 +1,5 @@
import torch
import pytest
import torch

from colossalai.nn.optimizer import CPUAdam, HybridAdam
from colossalai.testing import clear_cache_before_run, parameterize

@ -17,6 +17,7 @@ def check_params_equal(model, torch_model):
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
assert torch.allclose(p, torch_p, atol=1e-3), f"diff: {torch.abs(p - torch_p)}"


# TODO Something wrong with ci when running this test.
@pytest.mark.skip(reason="skip because of something wrong with CI")
@clear_cache_before_run()

@ -103,9 +103,7 @@ def run_pp(
torch_loss = criterion(torch_output)
torch_loss.backward()

pp_ret = schedule.forward_backward_step(
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
)
pp_ret = schedule.forward_backward_step(sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True)

# check loss
if stage_manager.is_last_stage(ignore_chunk=True):

@ -99,9 +99,7 @@ def examine_pp(num_microbatch: int, batch_size: int):
torch_output = torch_model(input_list[0])
torch_loss = criterion(torch_output)
torch_loss.backward()
pp_ret = schedule.forward_backward_step(
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
)
pp_ret = schedule.forward_backward_step(sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True)

# check loss
if stage_manager.is_last_stage():