mirror of https://github.com/hpcaitech/ColossalAI
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
branch: pull/5842/head
parent 4c69e2dc91
commit df612434c9
@@ -8,11 +8,10 @@ import argparse
 
 import numpy as np
 import torch
-from transformers import LlamaTokenizer, LlamaForCausalLM
+from transformers import LlamaForCausalLM, LlamaTokenizer
 
 from colossalai.logging import get_dist_logger
 
 
 logger = get_dist_logger()
 
-
@@ -10,8 +10,8 @@ import os
 from typing import Any, Dict, Tuple, Union
 
 import torch
-from torch.optim.optimizer import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
+from torch.optim.optimizer import Optimizer
 
 from colossalai.booster import Booster
 from colossalai.cluster import DistCoordinator
@@ -1,20 +1,19 @@
 from copy import deepcopy
-from typing import Optional, List, Dict, Tuple, Callable, Any
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import torch
 from torch import nn
 
 from transformers import PreTrainedTokenizer
-from transformers.utils import logging
 from transformers.generation.utils import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
+from transformers.utils import logging
 
 logger = logging.get_logger(__name__)
 
 
 def get_prompt_template(
-    input_query:str,
-    history:List[Dict]= None,
-    roles:list = ["", "Human", "Assistant"],
+    input_query: str,
+    history: List[Dict] = None,
+    roles: list = ["", "Human", "Assistant"],
 ) -> str:
     """
     Generates a prompt template for chat models based on input and history.
@@ -48,6 +47,7 @@ def get_prompt_template(
         prompt += f"{role}: <s>"
     return prompt
 
+
 @torch.inference_mode()
 def streaming_chat(
     model: Any,
@@ -99,14 +99,14 @@ def streaming_chat(
         logits_processor = LogitsProcessorList()
 
     generation_kwargs = {
-        'temperature': temperature,
-        'top_p': top_p,
-        'top_k': top_k,
-        'do_sample': do_sample,
-        'max_new_tokens': max_new_tokens,
-        'length_penalty': length_penalty,
-        'use_cache': True,
-        **kwargs
+        "temperature": temperature,
+        "top_p": top_p,
+        "top_k": top_k,
+        "do_sample": do_sample,
+        "max_new_tokens": max_new_tokens,
+        "length_penalty": length_penalty,
+        "use_cache": True,
+        **kwargs,
     }
 
     prompt_str = get_prompt_template(input_query, history=history, roles=roles)
@@ -116,13 +116,18 @@ def streaming_chat(
     history.append({"role": roles[1], "message": input_query.strip()})
     history.append({"role": roles[2], "message": None})
 
-    for outputs in stream_generate(model, **inputs, past_key_values=past_key_values,
-                                   eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,
-                                   **generation_kwargs):
+    for outputs in stream_generate(
+        model,
+        **inputs,
+        past_key_values=past_key_values,
+        eos_token_id=eos_token_id,
+        return_past_key_values=return_past_key_values,
+        **generation_kwargs,
+    ):
         if return_past_key_values:
             outputs, past_key_values = outputs
 
-        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
+        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) : -1]
         response = tokenizer.decode(outputs)
 
         history[-1]["message"] = response.strip()
@@ -15,7 +15,7 @@ def load_model(model_path, device="cuda", **kwargs):
     model.to(device)
 
     try:
-        tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side='left')
+        tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
     except OSError:
         raise ImportError("Tokenizer not found. Please check if the tokenizer exists or the model path is correct.")
 
@@ -12,4 +12,3 @@ flash-attn>=2.0.0,<=2.0.5
 tqdm
 sentencepiece==0.1.99
 protobuf<=3.20.0
-
@@ -1,11 +1,11 @@
-import os
 import argparse
 
-from transformers import AutoTokenizer, AutoModelForCausalLM
 from colossal_llama2.utils.stream_chat_patch import streaming_chat
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 SYSTEM = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
 
 
 def main(args):
     model = AutoModelForCausalLM.from_pretrained(args.model_path).cuda().eval()
     tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
@@ -27,29 +27,34 @@ def main(args):
         print(f"\n{roles[2]}: ", end="")
         gen_len = 0
         for response, history, past_key_values in streaming_chat(
-            model, tokenizer, input_query, history=history, roles=roles,
-            temperature = args.temperature,
-            top_p = args.top_p,
-            top_k = args.top_k,
-            do_sample = args.do_sample,
-            length_penalty = args.length_penalty,
-            max_new_tokens = args.max_new_tokens,
+            model,
+            tokenizer,
+            input_query,
+            history=history,
+            roles=roles,
+            temperature=args.temperature,
+            top_p=args.top_p,
+            top_k=args.top_k,
+            do_sample=args.do_sample,
+            length_penalty=args.length_penalty,
+            max_new_tokens=args.max_new_tokens,
             past_key_values=past_key_values,
-            return_past_key_values=True):
+            return_past_key_values=True,
+        ):
             output = response[gen_len:]
             print(output, end="", flush=True)
             gen_len = len(response)
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--model_path', type=str, default=None, help="path to chat version model")
-    parser.add_argument('--tokenizer_path', type=str, default=None, help="path to chat version tokenizer")
-    parser.add_argument('--temperature', type=float, default=0.8, help="set temperature")
-    parser.add_argument('--top_p', type=float, default=0.95, help="set top p value")
-    parser.add_argument('--top_k', type=int, default=50, help="set top k value")
-    parser.add_argument('--do_sample', type=bool, default=True, help="whether turn on do_sample or not")
-    parser.add_argument('--length_penalty', type=float, default=1.2, help="set length penalty")
-    parser.add_argument('--max_new_tokens', type=int, default=512, help="set max new tokens")
+    parser.add_argument("--model_path", type=str, default=None, help="path to chat version model")
+    parser.add_argument("--tokenizer_path", type=str, default=None, help="path to chat version tokenizer")
+    parser.add_argument("--temperature", type=float, default=0.8, help="set temperature")
+    parser.add_argument("--top_p", type=float, default=0.95, help="set top p value")
+    parser.add_argument("--top_k", type=int, default=50, help="set top k value")
+    parser.add_argument("--do_sample", type=bool, default=True, help="whether turn on do_sample or not")
+    parser.add_argument("--length_penalty", type=float, default=1.2, help="set length penalty")
+    parser.add_argument("--max_new_tokens", type=int, default=512, help="set max new tokens")
    args = parser.parse_args()
    main(args)
@@ -20,13 +20,13 @@ import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
 from colossalai.cluster import DistCoordinator
+from colossalai.logging import get_dist_logger
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
-from colossalai.logging import get_dist_logger
 
 logger = get_dist_logger()
 
 
 def train(args):
     # check lora compatibility
     if "gemini" in args.plugin and args.lora_rank > 0:
@@ -3,7 +3,6 @@ import copy
 import os
 from typing import Dict, List
 
-import torch
 import torch.distributed as dist
 from colossal_eval import dataset, models, utils
 
Binary file not shown.
@@ -106,6 +106,5 @@ def main():
     print(f"[{coordinator.rank}] {outputs}")
 
 
-
 if __name__ == "__main__":
     main()
@@ -24,6 +24,7 @@ from langchain.pydantic_v1 import Field
 from langchain.schema import BaseRetriever, Document
 from langchain.schema.language_model import BaseLanguageModel
 
+
 class CustomBaseRetrievalQA(BaseRetrievalQA):
     """Base class for question-answering chains."""
 
@@ -98,7 +99,6 @@ class CustomBaseRetrievalQA(BaseRetrievalQA):
             for k, v in inputs.items()
             if k in ["stop", "temperature", "top_k", "top_p", "max_new_tokens", "doc_prefix"]
         }
-        answers = []
         if self.combine_documents_chain.memory is not None:
             buffered_history_backup, summarized_history_temp_backup = copy.deepcopy(
                 self.combine_documents_chain.memory.buffered_history
@@ -117,10 +117,10 @@ class CustomBaseRetrievalQA(BaseRetrievalQA):
            ) = copy.deepcopy(buffered_history_backup), copy.deepcopy(summarized_history_temp_backup)
 
        # if rejection_trigger_keywords is not given, return the response from LLM directly
-        rejection_trigger_keywords = inputs.get('rejection_trigger_keywords', [])
+        rejection_trigger_keywords = inputs.get("rejection_trigger_keywords", [])
        answer = answer if all([rej not in answer for rej in rejection_trigger_keywords]) else None
        if answer is None:
-            answer = inputs.get('rejection_answer', "抱歉,根据提供的信息无法回答该问题。")
+            answer = inputs.get("rejection_answer", "抱歉,根据提供的信息无法回答该问题。")
        if self.combine_documents_chain.memory is not None:
            self.combine_documents_chain.memory.save_context({"question": question}, {"output": answer})
 
@@ -161,10 +161,14 @@ class CustomBaseRetrievalQA(BaseRetrievalQA):
            input_documents=docs, question=question, callbacks=_run_manager.get_child(), **kwargs
        )
        # if rejection_trigger_keywords is not given, return the response from LLM directly
-        rejection_trigger_keywords = inputs.get('rejection_trigger_keywords', [])
-        answer = answer if all([rej not in answer for rej in rejection_trigger_keywords]) or len(rejection_trigger_keywords)==0 else None
+        rejection_trigger_keywords = inputs.get("rejection_trigger_keywords", [])
+        answer = (
+            answer
+            if all([rej not in answer for rej in rejection_trigger_keywords]) or len(rejection_trigger_keywords) == 0
+            else None
+        )
        if answer is None:
-            answer = inputs.get('rejection_answer', "抱歉,根据提供的信息无法回答该问题。")
+            answer = inputs.get("rejection_answer", "抱歉,根据提供的信息无法回答该问题。")
        self.combine_documents_chain.memory.save_context({"question": question}, {"output": answer})
 
        if self.return_source_documents:
@@ -1,32 +1,33 @@
-'''
+"""
 Class for loading table type data. please refer to Pandas-Input/Output for file format details.
-'''
+"""
 
 
-import os
 import glob
+import os
 
 import pandas as pd
-from sqlalchemy import create_engine
-from colossalqa.utils import drop_table
 from colossalqa.mylogging import get_logger
+from colossalqa.utils import drop_table
+from sqlalchemy import create_engine
 
 logger = get_logger()
 
-SUPPORTED_DATA_FORMAT = ['.csv','.xlsx', '.xls','.json','.html','.h5', '.hdf5','.parquet','.feather','.dta']
+SUPPORTED_DATA_FORMAT = [".csv", ".xlsx", ".xls", ".json", ".html", ".h5", ".hdf5", ".parquet", ".feather", ".dta"]
 
 
 class TableLoader:
-    '''
+    """
     Load tables from different files and serve a sql database for database operations
-    '''
-    def __init__(self, files: str,
-                 sql_path:str='sqlite:///mydatabase.db',
-                 verbose=False, **kwargs) -> None:
-        '''
+    """
+
+    def __init__(self, files: str, sql_path: str = "sqlite:///mydatabase.db", verbose=False, **kwargs) -> None:
+        """
         Args:
             files: list of files (list[file path, name])
             sql_path: how to serve the sql database
             **kwargs: keyword type arguments, useful for certain document types
-        '''
+        """
         self.data = {}
         self.verbose = verbose
         self.sql_path = sql_path
@@ -49,58 +50,58 @@ class TableLoader:
             self.to_sql(path, dataset_name)
 
     def load_data(self, path):
-        '''
+        """
         Load data and serve the data as sql database.
         Data must be in pandas format
-        '''
+        """
         files = []
         # Handle glob expression
         try:
             files = glob.glob(path)
         except Exception as e:
             logger.error(e)
-        if len(files)==0:
+        if len(files) == 0:
             raise ValueError("Unsupported file/directory format. For directories, please use glob expression")
-        elif len(files)==1:
+        elif len(files) == 1:
             path = files[0]
         else:
             for file in files:
                 self.load_data(file)
 
-        if path.endswith('.csv'):
+        if path.endswith(".csv"):
             # Load csv
             self.data[path] = pd.read_csv(path)
-        elif path.endswith('.xlsx') or path.endswith('.xls'):
+        elif path.endswith(".xlsx") or path.endswith(".xls"):
             # Load excel
             self.data[path] = pd.read_excel(path) # You can adjust the sheet_name as needed
-        elif path.endswith('.json'):
+        elif path.endswith(".json"):
             # Load json
             self.data[path] = pd.read_json(path)
-        elif path.endswith('.html'):
+        elif path.endswith(".html"):
             # Load html
             html_tables = pd.read_html(path)
             # Choose the desired table from the list of DataFrame objects
             self.data[path] = html_tables[0] # You may need to adjust this index
-        elif path.endswith('.h5') or path.endswith('.hdf5'):
+        elif path.endswith(".h5") or path.endswith(".hdf5"):
             # Load h5
-            self.data[path] = pd.read_hdf(path, key=self.kwargs.get('key', 'data')) # You can adjust the key as needed
-        elif path.endswith('.parquet'):
+            self.data[path] = pd.read_hdf(path, key=self.kwargs.get("key", "data")) # You can adjust the key as needed
+        elif path.endswith(".parquet"):
             # Load parquet
-            self.data[path] = pd.read_parquet(path, engine='fastparquet')
-        elif path.endswith('.feather'):
+            self.data[path] = pd.read_parquet(path, engine="fastparquet")
+        elif path.endswith(".feather"):
             # Load feather
             self.data[path] = pd.read_feather(path)
-        elif path.endswith('.dta'):
+        elif path.endswith(".dta"):
             # Load dta
             self.data[path] = pd.read_stata(path)
         else:
             raise ValueError("Unsupported file format")
 
     def to_sql(self, path, table_name):
-        '''
+        """
         Serve the data as sql database.
-        '''
-        self.data[path].to_sql(table_name, con=self.sql_engine, if_exists='replace', index=False)
+        """
+        self.data[path].to_sql(table_name, con=self.sql_engine, if_exists="replace", index=False)
         logger.info(f"Loaded to Sqlite3\nPath: {path}", verbose=self.verbose)
         return self.sql_path
 
@@ -113,7 +114,3 @@ class TableLoader:
         self.sql_engine.dispose()
         del self.data
         del self.sql_engine
-
-
-
-
@@ -21,7 +21,7 @@ print(resp) # super-heavyweight awesome-natured yawning Australian creature!
 
 """
 import json
-from typing import Any, List, Mapping, Optional
+from typing import Any, Mapping
 
 import requests
 from langchain.llms.base import LLM
@@ -33,11 +33,11 @@ class ColossalCloudLLM(LLM):
     A custom LLM class that integrates LLMs running on the ColossalCloud Platform
 
     """
 
     n: int
     gen_config: dict = None
     auth_config: dict = None
-    valid_gen_para: list = ['max_new_tokens', 'top_k',
-                            'top_p', 'temperature', 'repetition_penalty']
+    valid_gen_para: list = ["max_new_tokens", "top_k", "top_p", "temperature", "repetition_penalty"]
 
     def __init__(self, gen_config=None, **kwargs):
         """
@@ -63,15 +63,15 @@ class ColossalCloudLLM(LLM):
 
     @property
     def _llm_type(self) -> str:
-        return 'ColossalCloudLLM'
+        return "ColossalCloudLLM"
 
     def set_auth_config(self, **kwargs):
         url = get_from_dict_or_env(kwargs, "url", "URL")
         host = get_from_dict_or_env(kwargs, "host", "HOST")
 
         auth_config = {}
-        auth_config['endpoint'] = url
-        auth_config['Host'] = host
+        auth_config["endpoint"] = url
+        auth_config["Host"] = host
         self.auth_config = auth_config
 
     def _call(self, prompt: str, stop=None, **kwargs: Any) -> str:
@@ -86,7 +86,9 @@ class ColossalCloudLLM(LLM):
         # Update the generation arguments
         for key, value in kwargs.items():
             if key not in self.valid_gen_para:
-                raise KeyError(f"Invalid generation parameter: '{key}'. Valid keys are: {', '.join(self.valid_gen_para)}")
+                raise KeyError(
+                    f"Invalid generation parameter: '{key}'. Valid keys are: {', '.join(self.valid_gen_para)}"
+                )
             if key in self.gen_config:
                 self.gen_config[key] = value
 
@@ -98,26 +100,16 @@ class ColossalCloudLLM(LLM):
             resp_text = resp_text.split(stopping_words)[0]
         return resp_text
 
-
     def text_completion(self, prompt, gen_config, auth_config):
         # Complusory Parameters
-        endpoint = auth_config.pop('endpoint')
-        max_new_tokens = gen_config.pop('max_new_tokens')
+        endpoint = auth_config.pop("endpoint")
+        max_new_tokens = gen_config.pop("max_new_tokens")
         # Optional Parameters
-        optional_params = ['top_k', 'top_p', 'temperature', 'repetition_penalty'] # Self.optional
+        optional_params = ["top_k", "top_p", "temperature", "repetition_penalty"] # Self.optional
         gen_config = {key: gen_config[key] for key in optional_params if key in gen_config}
         # Define the data payload
-        data = {
-            "max_new_tokens": max_new_tokens,
-            "history": [
-                {"instruction": prompt, "response": ""}
-            ],
-            **gen_config
-        }
-        headers = {
-            "Content-Type": "application/json",
-            **auth_config # 'Host',
-        }
+        data = {"max_new_tokens": max_new_tokens, "history": [{"instruction": prompt, "response": ""}], **gen_config}
+        headers = {"Content-Type": "application/json", **auth_config} # 'Host',
         # Make the POST request
         response = requests.post(endpoint, headers=headers, data=json.dumps(data))
         response.raise_for_status() # raise error if return code is not 200(success)
@@ -193,4 +193,3 @@ class VllmLLM(LLM):
     def _identifying_params(self) -> Mapping[str, int]:
         """Get the identifying parameters."""
         return {"n": self.n}
-
@@ -4,7 +4,6 @@ All custom prompt templates are defined here.
 
 from langchain.prompts.prompt import PromptTemplate
 
-
 # Below are Chinese retrieval qa prompts
 
 _CUSTOM_SUMMARIZER_TEMPLATE_ZH = """请递进式地总结所提供的当前对话,将当前对话的摘要内容添加到先前已有的摘要上,返回一个融合了当前对话的新的摘要。
@@ -99,13 +99,7 @@ class CustomRetriever(BaseRetriever):
     def clear_documents(self):
         """Clear all document vectors from database"""
         for source in self.vector_stores:
-            index(
-                [],
-                self.record_managers[source],
-                self.vector_stores[source],
-                cleanup="full",
-                source_id_key="source"
-            )
+            index([], self.record_managers[source], self.vector_stores[source], cleanup="full", source_id_key="source")
         self.vector_stores = {}
         self.sql_index_database = {}
         self.record_managers = {}
@@ -1,22 +1,27 @@
 import argparse
 
 from colossalqa.retrieval_conversation_universal import UniversalRetrievalConversation
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     # Parse arguments
     parser = argparse.ArgumentParser()
-    parser.add_argument('--en_model_path', type=str, default=None)
-    parser.add_argument('--zh_model_path', type=str, default=None)
-    parser.add_argument('--zh_model_name', type=str, default=None)
-    parser.add_argument('--en_model_name', type=str, default=None)
-    parser.add_argument('--sql_file_path', type=str, default=None, help='path to the a empty folder for storing sql files for indexing')
+    parser.add_argument("--en_model_path", type=str, default=None)
+    parser.add_argument("--zh_model_path", type=str, default=None)
+    parser.add_argument("--zh_model_name", type=str, default=None)
+    parser.add_argument("--en_model_name", type=str, default=None)
+    parser.add_argument(
+        "--sql_file_path", type=str, default=None, help="path to the a empty folder for storing sql files for indexing"
+    )
     args = parser.parse_args()
 
     # Will ask for documents path in running time
-    session = UniversalRetrievalConversation(files_en=None,
-                                             files_zh=None,
-                                             zh_model_path=args.zh_model_path, en_model_path=args.en_model_path,
-                                             zh_model_name=args.zh_model_name, en_model_name=args.en_model_name,
-                                             sql_file_path=args.sql_file_path
-                                             )
+    session = UniversalRetrievalConversation(
+        files_en=None,
+        files_zh=None,
+        zh_model_path=args.zh_model_path,
+        en_model_path=args.en_model_path,
+        zh_model_name=args.zh_model_name,
+        en_model_name=args.en_model_name,
+        sql_file_path=args.sql_file_path,
+    )
     session.start_test_session()
@@ -5,13 +5,7 @@ from colossalqa.chain.retrieval_qa.base import RetrievalQA
 from colossalqa.data_loader.document_loader import DocumentLoader
 from colossalqa.memory import ConversationBufferWithSummary
 from colossalqa.mylogging import get_logger
-from colossalqa.prompt.prompt import (
-    PROMPT_DISAMBIGUATE_ZH,
-    PROMPT_RETRIEVAL_QA_ZH,
-    SUMMARY_PROMPT_ZH,
-    ZH_RETRIEVAL_QA_REJECTION_ANSWER,
-    ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS,
-)
+from colossalqa.prompt.prompt import ZH_RETRIEVAL_QA_REJECTION_ANSWER, ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS
 from colossalqa.retriever import CustomRetriever
 from langchain import LLMChain
 from langchain.embeddings import HuggingFaceEmbeddings
@@ -1,58 +1,30 @@
-from colossalqa.prompt.prompt import (
-    PROMPT_DISAMBIGUATE_ZH,
-    PROMPT_RETRIEVAL_QA_ZH,
-    SUMMARY_PROMPT_ZH,
-    ZH_RETRIEVAL_QA_REJECTION_ANSWER,
-    ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS,
-)
+from colossalqa.prompt.prompt import PROMPT_DISAMBIGUATE_ZH, PROMPT_RETRIEVAL_QA_ZH, SUMMARY_PROMPT_ZH
 from colossalqa.text_splitter import ChineseTextSplitter
 
 ALL_CONFIG = {
     "embed": {
         "embed_name": "m3e", # embedding model name
         "embed_model_name_or_path": "moka-ai/m3e-base", # path to embedding model, could be a local path or a huggingface path
-        "embed_model_device": {
-            "device": "cpu"
-        }
+        "embed_model_device": {"device": "cpu"},
     },
     "model": {
         "mode": "api", # "local" for loading models, "api" for using model api
         "model_name": "chatgpt_api", # local model name, "chatgpt_api" or "pangu_api"
        "model_path": "", # path to the model, could be a local path or a huggingface path. don't need if using an api
-        "device": {
-            "device": "cuda"
-        }
+        "device": {"device": "cuda"},
     },
-    "splitter": {
-        "name": ChineseTextSplitter
-    },
-    "retrieval": {
-        "retri_top_k": 3,
-        "retri_kb_file_path": "./", # path to store database files
-        "verbose": True
-    },
+    "splitter": {"name": ChineseTextSplitter},
+    "retrieval": {"retri_top_k": 3, "retri_kb_file_path": "./", "verbose": True}, # path to store database files
     "chain": {
         "mem_summary_prompt": SUMMARY_PROMPT_ZH, # summary prompt template
         "mem_human_prefix": "用户",
         "mem_ai_prefix": "Assistant",
         "mem_max_tokens": 2000,
-        "mem_llm_kwargs": {
-            "max_new_tokens": 50,
-            "temperature": 1,
-            "do_sample": True
-        },
+        "mem_llm_kwargs": {"max_new_tokens": 50, "temperature": 1, "do_sample": True},
         "disambig_prompt": PROMPT_DISAMBIGUATE_ZH, # disambiguate prompt template
-        "disambig_llm_kwargs": {
-            "max_new_tokens": 30,
-            "temperature": 1,
-            "do_sample": True
-        },
-        "gen_llm_kwargs": {
-            "max_new_tokens": 100,
-            "temperature": 1,
-            "do_sample": True
-        },
+        "disambig_llm_kwargs": {"max_new_tokens": 30, "temperature": 1, "do_sample": True},
+        "gen_llm_kwargs": {"max_new_tokens": 100, "temperature": 1, "do_sample": True},
         "gen_qa_prompt": PROMPT_RETRIEVAL_QA_ZH, # generation prompt template
-        "verbose": True
-    }
+        "verbose": True,
+    },
 }
@@ -1,27 +1,18 @@
 import argparse
-import os
 from typing import List, Union
 
-from colossalqa.local.llm import ColossalAPI, ColossalLLM
-from colossalqa.data_loader.document_loader import DocumentLoader
-from colossalqa.mylogging import get_logger
-from colossalqa.retrieval_conversation_zh import ChineseRetrievalConversation
-from colossalqa.retriever import CustomRetriever
-from enum import Enum
-from fastapi import FastAPI, Request
-from langchain.embeddings import HuggingFaceEmbeddings
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from pydantic import BaseModel, Field
-import uvicorn
-
 import config
+import uvicorn
+from colossalqa.local.llm import ColossalAPI, ColossalLLM
+from colossalqa.mylogging import get_logger
+from fastapi import FastAPI, Request
+from pydantic import BaseModel
 from RAG_ChatBot import RAG_ChatBot
 from utils import DocAction
 
 logger = get_logger()
 
 
 def parseArgs():
     parser = argparse.ArgumentParser()
     parser.add_argument("--http_host", default="0.0.0.0")
|
||||||
doc_files: Union[List[str], str, None] = None
|
doc_files: Union[List[str], str, None] = None
|
||||||
action: DocAction = DocAction.ADD
|
action: DocAction = DocAction.ADD
|
||||||
|
|
||||||
|
|
||||||
class GenerationTaskReq(BaseModel):
|
class GenerationTaskReq(BaseModel):
|
||||||
user_input: str
|
user_input: str
|
||||||
|
|
||||||
|
@ -45,7 +37,7 @@ def update_docs(data: DocUpdateReq, request: Request):
|
||||||
if data.action == "add":
|
if data.action == "add":
|
||||||
if isinstance(data.doc_files, str):
|
if isinstance(data.doc_files, str):
|
||||||
data.doc_files = [data.doc_files]
|
data.doc_files = [data.doc_files]
|
||||||
chatbot.load_doc_from_files(files = data.doc_files)
|
chatbot.load_doc_from_files(files=data.doc_files)
|
||||||
all_docs = ""
|
all_docs = ""
|
||||||
for doc in chatbot.docs_names:
|
for doc in chatbot.docs_names:
|
||||||
all_docs += f"\t{doc}\n\n"
|
all_docs += f"\t{doc}\n\n"
|
||||||
|
@ -84,12 +76,13 @@ if __name__ == "__main__":
|
||||||
"user": "User",
|
"user": "User",
|
||||||
"max_tokens": all_config["chain"]["disambig_llm_kwargs"]["max_new_tokens"],
|
"max_tokens": all_config["chain"]["disambig_llm_kwargs"]["max_new_tokens"],
|
||||||
"temperature": all_config["chain"]["disambig_llm_kwargs"]["temperature"],
|
"temperature": all_config["chain"]["disambig_llm_kwargs"]["temperature"],
|
||||||
"n": 1 # the number of responses generated
|
"n": 1, # the number of responses generated
|
||||||
}
|
}
|
||||||
llm = Pangu(gen_config=gen_config)
|
llm = Pangu(gen_config=gen_config)
|
||||||
llm.set_auth_config() # verify user's auth info here
|
llm.set_auth_config() # verify user's auth info here
|
||||||
elif model_name == "chatgpt_api":
|
elif model_name == "chatgpt_api":
|
||||||
from langchain.llms import OpenAI
|
from langchain.llms import OpenAI
|
||||||
|
|
||||||
llm = OpenAI()
|
llm = OpenAI()
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unsupported mode.")
|
raise ValueError("Unsupported mode.")
|
||||||
|
|
|
@@ -1,24 +1,26 @@
 import argparse
 import json
 import os
-import requests
 
 import gradio as gr
+import requests
 from utils import DocAction
 
 
 def parseArgs():
     parser = argparse.ArgumentParser()
     parser.add_argument("--http_host", default="0.0.0.0")
     parser.add_argument("--http_port", type=int, default=13666)
     return parser.parse_args()
 
 
 def get_response(data, url):
     headers = {"Content-type": "application/json"}
     response = requests.post(url, json=data, headers=headers)
     response = json.loads(response.content)
     return response
 
 
 def add_text(history, text):
     history = history + [(text, None)]
     return history, gr.update(value=None, interactive=True)
|
||||||
files_string = "\n".join([os.path.basename(file.name) for file in files])
|
files_string = "\n".join([os.path.basename(file.name) for file in files])
|
||||||
|
|
||||||
doc_files = [file.name for file in files]
|
doc_files = [file.name for file in files]
|
||||||
data = {
|
data = {"doc_files": doc_files, "action": DocAction.ADD}
|
||||||
"doc_files": doc_files,
|
|
||||||
"action": DocAction.ADD
|
|
||||||
}
|
|
||||||
response = get_response(data, update_url)["response"]
|
response = get_response(data, update_url)["response"]
|
||||||
history = history + [(files_string, response)]
|
history = history + [(files_string, response)]
|
||||||
return history
|
return history
|
||||||
|
|
||||||
|
|
||||||
def bot(history):
|
def bot(history):
|
||||||
data = {
|
data = {"user_input": history[-1][0].strip()}
|
||||||
"user_input": history[-1][0].strip()
|
|
||||||
}
|
|
||||||
response = get_response(data, gen_url)
|
response = get_response(data, gen_url)
|
||||||
|
|
||||||
if response["error"] != "":
|
if response["error"] != "":
|
||||||
|
@ -51,11 +49,8 @@ def bot(history):
|
||||||
|
|
||||||
def restart(chatbot, txt):
|
def restart(chatbot, txt):
|
||||||
# Reset the conversation state and clear the chat history
|
# Reset the conversation state and clear the chat history
|
||||||
data = {
|
data = {"doc_files": "", "action": DocAction.CLEAR}
|
||||||
"doc_files": "",
|
get_response(data, update_url)
|
||||||
"action": DocAction.CLEAR
|
|
||||||
}
|
|
||||||
response = get_response(data, update_url)
|
|
||||||
|
|
||||||
return gr.update(value=None), gr.update(value=None, interactive=True)
|
return gr.update(value=None), gr.update(value=None, interactive=True)
|
||||||
|
|
||||||
|
|
|
@@ -1,21 +1,21 @@
 import os
 
 from colossalqa.data_loader.document_loader import DocumentLoader
 
 
 def test_add_document():
-    PATH = os.environ.get('TEST_DOCUMENT_LOADER_DATA_PATH')
-    files = [[PATH, 'all data']]
+    PATH = os.environ.get("TEST_DOCUMENT_LOADER_DATA_PATH")
+    files = [[PATH, "all data"]]
     document_loader = DocumentLoader(files)
     documents = document_loader.all_data
     all_files = []
     for doc in documents:
-        assert isinstance(doc.page_content, str)==True
-        if doc.metadata['source'] not in all_files:
-            all_files.append(doc.metadata['source'])
+        assert isinstance(doc.page_content, str) == True
+        if doc.metadata["source"] not in all_files:
+            all_files.append(doc.metadata["source"])
     print(all_files)
     assert len(all_files) == 6
 
 
-if __name__=='__main__':
+if __name__ == "__main__":
     test_add_document()
@@ -4,56 +4,44 @@ from colossalqa.retrieval_conversation_universal import UniversalRetrievalConversation
 
 
 def test_en_retrievalQA():
-    data_path_en = os.environ.get('TEST_DATA_PATH_EN')
-    data_path_zh = os.environ.get('TEST_DATA_PATH_ZH')
-    en_model_path = os.environ.get('EN_MODEL_PATH')
-    zh_model_path = os.environ.get('ZH_MODEL_PATH')
-    zh_model_name = os.environ.get('ZH_MODEL_NAME')
-    en_model_name = os.environ.get('EN_MODEL_NAME')
-    sql_file_path = os.environ.get('SQL_FILE_PATH')
-    qa_session = UniversalRetrievalConversation(files_en=[{
-        'data_path': data_path_en,
-        'name': 'company information',
-        'separator': '\n'
-    }],
-        files_zh=[{
-        'data_path': data_path_zh,
-        'name': 'company information',
-        'separator': '\n'
-    }],
+    data_path_en = os.environ.get("TEST_DATA_PATH_EN")
+    data_path_zh = os.environ.get("TEST_DATA_PATH_ZH")
+    en_model_path = os.environ.get("EN_MODEL_PATH")
+    zh_model_path = os.environ.get("ZH_MODEL_PATH")
+    zh_model_name = os.environ.get("ZH_MODEL_NAME")
+    en_model_name = os.environ.get("EN_MODEL_NAME")
+    sql_file_path = os.environ.get("SQL_FILE_PATH")
+    qa_session = UniversalRetrievalConversation(
+        files_en=[{"data_path": data_path_en, "name": "company information", "separator": "\n"}],
+        files_zh=[{"data_path": data_path_zh, "name": "company information", "separator": "\n"}],
         zh_model_path=zh_model_path,
         en_model_path=en_model_path,
         zh_model_name=zh_model_name,
         en_model_name=en_model_name,
-        sql_file_path=sql_file_path)
-    ans = qa_session.run("which company runs business in hotel industry?", which_language='en')
+        sql_file_path=sql_file_path,
+    )
+    ans = qa_session.run("which company runs business in hotel industry?", which_language="en")
     print(ans)
 
 
 def test_zh_retrievalQA():
-    data_path_en = os.environ.get('TEST_DATA_PATH_EN')
-    data_path_zh = os.environ.get('TEST_DATA_PATH_ZH')
-    en_model_path = os.environ.get('EN_MODEL_PATH')
-    zh_model_path = os.environ.get('ZH_MODEL_PATH')
-    zh_model_name = os.environ.get('ZH_MODEL_NAME')
-    en_model_name = os.environ.get('EN_MODEL_NAME')
-    sql_file_path = os.environ.get('SQL_FILE_PATH')
-    qa_session = UniversalRetrievalConversation(files_en=[{
-        'data_path': data_path_en,
-        'name': 'company information',
-        'separator': '\n'
-    }],
-        files_zh=[{
-        'data_path': data_path_zh,
-        'name': 'company information',
-        'separator': '\n'
-    }],
+    data_path_en = os.environ.get("TEST_DATA_PATH_EN")
+    data_path_zh = os.environ.get("TEST_DATA_PATH_ZH")
+    en_model_path = os.environ.get("EN_MODEL_PATH")
+    zh_model_path = os.environ.get("ZH_MODEL_PATH")
+    zh_model_name = os.environ.get("ZH_MODEL_NAME")
+    en_model_name = os.environ.get("EN_MODEL_NAME")
+    sql_file_path = os.environ.get("SQL_FILE_PATH")
+    qa_session = UniversalRetrievalConversation(
+        files_en=[{"data_path": data_path_en, "name": "company information", "separator": "\n"}],
+        files_zh=[{"data_path": data_path_zh, "name": "company information", "separator": "\n"}],
         zh_model_path=zh_model_path,
         en_model_path=en_model_path,
         zh_model_name=zh_model_name,
         en_model_name=en_model_name,
-        sql_file_path=sql_file_path)
-    ans = qa_session.run("哪家公司在经营酒店业务?", which_language='zh')
+        sql_file_path=sql_file_path,
+    )
+    ans = qa_session.run("哪家公司在经营酒店业务?", which_language="zh")
     print(ans)
 
@@ -1,5 +1,5 @@
-from .initialize import launch, launch_from_openmpi, launch_from_slurm, launch_from_torch
 from . import accelerator
+from .initialize import launch, launch_from_openmpi, launch_from_slurm, launch_from_torch
 
 try:
     # .version will be created by setup.py
@@ -27,7 +27,7 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 
-from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO, utils, CheckpointIndexFile
+from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO, utils
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 
@@ -93,9 +93,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
 
         Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
         with FSDP.state_dict_type(
-            model.unwrap(),
-            StateDictType.FULL_STATE_DICT,
-            FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
+            model.unwrap(), StateDictType.FULL_STATE_DICT, FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
         ):
             state_dict = model.unwrap().state_dict()
 
@@ -172,7 +170,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         with FSDP.state_dict_type(
             optimizer.unwrap_model().unwrap(),
             StateDictType.FULL_STATE_DICT,
-            FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
+            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
         ):
             fsdp_optim_state = FSDP.full_optim_state_dict(
                 optimizer.unwrap_model().unwrap(), optim=optimizer, rank0_only=True
@@ -241,7 +239,6 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         )
         optimizer.load_state_dict(fsdp_state)
 
-
     def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
         """
         Save model to checkpoint but only on master process.
@@ -294,6 +294,7 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
 # Helper functions for saving state dict
 # ======================================
 
+
 def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None:
     """
     Save state dict to checkpoint.
@@ -225,4 +225,3 @@ class ProcessGroupMesh:
             # no need to cache it explicitly, since it will be cached in `create_group_along_axis`
             return self.create_group_along_axis(axis, indices_at_axis, backend=backend)
         return self._ranks_to_group[ranks_in_group]
-
@@ -29,13 +29,17 @@ except:
 
 try:
     from colossalai.kernel.triton.flash_decoding import token_flash_decoding
+
     HAS_TRITON_FLASH_DECODING_KERNEL = True
 except:
-    print("no triton flash decoding support, please install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8")
+    print(
+        "no triton flash decoding support, please install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8"
+    )
     HAS_TRITON_FLASH_DECODING_KERNEL = False
 
 try:
     from flash_attn import flash_attn_with_kvcache
+
     HAS_FLASH_KERNEL = True
 except:
     HAS_FLASH_KERNEL = False
@@ -48,6 +52,7 @@ def rotate_half(x):
     x2 = x[..., x.shape[-1] // 2 :]
     return torch.cat((-x2, x1), dim=-1)
 
+
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
     # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
     cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
@@ -96,15 +101,20 @@ def llama_triton_context_attention(
         infer_state.max_len_in_batch,
     )
 
-def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1, q_head_num = -1, head_dim = -1):
+
+def llama_triton_token_attention(
+    query_states, attn_output, infer_state, num_key_value_groups=1, q_head_num=-1, head_dim=-1
+):
     if HAS_TRITON_FLASH_DECODING_KERNEL and q_head_num != -1 and head_dim != -1:
-        token_flash_decoding(q = query_states,
-                             o_tensor = attn_output,
-                             infer_state = infer_state,
-                             q_head_num = q_head_num,
-                             head_dim = head_dim,
-                             cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
-                             cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id])
+        token_flash_decoding(
+            q=query_states,
+            o_tensor=attn_output,
+            infer_state=infer_state,
+            q_head_num=q_head_num,
+            head_dim=head_dim,
+            cache_k=infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
+            cache_v=infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
+        )
         return
 
     if num_key_value_groups == 1:
@@ -459,14 +469,15 @@ class LlamaInferenceForwards:
         )
 
         if HAS_LIGHTLLM_KERNEL:
 
             attn_output = torch.empty_like(query_states)
-            llama_triton_token_attention(query_states = query_states,
-                                         attn_output = attn_output,
-                                         infer_state = infer_state,
-                                         num_key_value_groups = self.num_key_value_groups,
-                                         q_head_num = q_len * self.num_heads,
-                                         head_dim = self.head_dim)
+            llama_triton_token_attention(
+                query_states=query_states,
+                attn_output=attn_output,
+                infer_state=infer_state,
+                num_key_value_groups=self.num_key_value_groups,
+                q_head_num=q_len * self.num_heads,
+                head_dim=self.head_dim,
+            )
         else:
             self.num_heads // self.num_key_value_heads
             cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id]
@@ -18,15 +18,15 @@ from .gptq_op import CaiGPTQLinearOp
 HAS_GPTQ_CUDA = False
 try:
     from colossalai.kernel.op_builder.gptq import GPTQBuilder
 
     gptq_cuda = GPTQBuilder().load()
     HAS_GPTQ_CUDA = True
 except ImportError:
-    warnings.warn('CUDA gptq is not installed')
+    warnings.warn("CUDA gptq is not installed")
     HAS_GPTQ_CUDA = False
 
 
 class CaiQuantLinear(nn.Module):
 
     def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
         super().__init__()
         if bits not in [2, 4, 8]:
@ -37,23 +37,28 @@ class CaiQuantLinear(nn.Module):
|
||||||
self.maxq = 2**self.bits - 1
|
self.maxq = 2**self.bits - 1
|
||||||
self.groupsize = groupsize if groupsize != -1 else infeatures
|
self.groupsize = groupsize if groupsize != -1 else infeatures
|
||||||
|
|
||||||
self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
|
self.register_buffer("qweight", torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
'qzeros',
|
"qzeros",
|
||||||
torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32))
|
torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32),
|
||||||
self.register_buffer('scales',
|
)
|
||||||
torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16))
|
self.register_buffer(
|
||||||
|
"scales", torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)
|
||||||
|
)
|
||||||
if row_split:
|
if row_split:
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
'g_idx',
|
"g_idx",
|
||||||
torch.tensor([(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)],
|
torch.tensor(
|
||||||
dtype=torch.int32))
|
[(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)], dtype=torch.int32
|
||||||
|
),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.register_buffer('g_idx',
|
self.register_buffer(
|
||||||
torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32))
|
"g_idx", torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)
|
||||||
|
)
|
||||||
|
|
||||||
if bias:
|
if bias:
|
||||||
self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
|
self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float16))
|
||||||
else:
|
else:
|
||||||
self.bias = None
|
self.bias = None
|
||||||
|
|
||||||
|
@ -66,9 +71,11 @@ class CaiQuantLinear(nn.Module):
|
||||||
self.row_split = row_split
|
self.row_split = row_split
|
||||||
|
|
||||||
def pack(self, linear, scales, zeros, g_idx=None):
|
def pack(self, linear, scales, zeros, g_idx=None):
|
||||||
|
g_idx = (
|
||||||
g_idx = g_idx.clone() if g_idx is not None else torch.tensor(
|
g_idx.clone()
|
||||||
[i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32)
|
if g_idx is not None
|
||||||
|
else torch.tensor([i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32)
|
||||||
|
)
|
||||||
|
|
||||||
scales = scales.t().contiguous()
|
scales = scales.t().contiguous()
|
||||||
zeros = zeros.t().contiguous()
|
zeros = zeros.t().contiguous()
|
||||||
|
@ -79,7 +86,6 @@ class CaiQuantLinear(nn.Module):
|
||||||
if linear.bias is not None:
|
if linear.bias is not None:
|
||||||
self.bias = linear.bias.clone().half()
|
self.bias = linear.bias.clone().half()
|
||||||
|
|
||||||
wn = 8
|
|
||||||
pbits = 32
|
pbits = 32
|
||||||
ptype = torch.int32
|
ptype = torch.int32
|
||||||
unsign_type = np.uint32
|
unsign_type = np.uint32
|
||||||
|
@ -88,9 +94,10 @@ class CaiQuantLinear(nn.Module):
|
||||||
intweight = []
|
intweight = []
|
||||||
for idx in range(self.infeatures):
|
for idx in range(self.infeatures):
|
||||||
intweight.append(
|
intweight.append(
|
||||||
torch.round(
|
torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[
|
||||||
(linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[:,
|
:, None
|
||||||
None])
|
]
|
||||||
|
)
|
||||||
intweight = torch.cat(intweight, dim=1)
|
intweight = torch.cat(intweight, dim=1)
|
||||||
intweight = intweight.t().contiguous()
|
intweight = intweight.t().contiguous()
|
||||||
intweight = intweight.numpy().astype(unsign_type)
|
intweight = intweight.numpy().astype(unsign_type)
|
||||||
|
@ -109,7 +116,7 @@ class CaiQuantLinear(nn.Module):
|
||||||
raise NotImplementedError("Only 2,4,8 bits are supported.")
|
raise NotImplementedError("Only 2,4,8 bits are supported.")
|
||||||
qweight = qweight.astype(sign_type)
|
qweight = qweight.astype(sign_type)
|
||||||
qweight1 = torch.from_numpy(qweight)
|
qweight1 = torch.from_numpy(qweight)
|
||||||
qweight1 = qweight1.contiguous() #.to("cuda")
|
qweight1 = qweight1.contiguous() # .to("cuda")
|
||||||
self.qweight.data.copy_(qweight1)
|
self.qweight.data.copy_(qweight1)
|
||||||
|
|
||||||
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // pbits * self.bits), dtype=unsign_type)
|
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // pbits * self.bits), dtype=unsign_type)
|
||||||
|
@@ -144,13 +151,16 @@ class CaiQuantLinear(nn.Module):
             torch.tensor(
                 [(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)],
                 dtype=torch.int32,
-                device=self.g_idx.device)):
+                device=self.g_idx.device,
+            ),
+        ):
             self.g_idx = None
         elif torch.equal(
             self.g_idx,
-            torch.tensor([i // self.groupsize for i in range(self.infeatures)],
-                         dtype=torch.int32,
-                         device=self.g_idx.device)):
+            torch.tensor(
+                [i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32, device=self.g_idx.device
+            ),
+        ):
             self.g_idx = None
 
         if self.g_idx is not None:

@@ -165,7 +175,6 @@ class CaiQuantLinear(nn.Module):
         outshape = x.shape[:-1] + (self.outfeatures,)
 
         if HAS_GPTQ_CUDA and self.bits == 4:
 
             if self.q4 is None:
                 self.init_q4()
 

@@ -191,7 +200,6 @@ class CaiQuantLinear(nn.Module):
 
 
 def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1):
 
     qweights = gptq_linear.qweight.split(gptq_linear.out_features // split_num, dim=-1)
     qzeros = gptq_linear.qzeros.split(gptq_linear.out_features // (32 // cai_linear.bits) // split_num, dim=-1)
     scales = gptq_linear.scales.split(gptq_linear.out_features // split_num, dim=-1)

@@ -203,24 +211,24 @@ def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1
     zero_split_block = cai_linear.outfeatures // (32 // cai_linear.bits) // split_num
 
     for i in range(split_num):
-        cai_linear.qweight[:, i * cai_split_out_features:(i + 1) *
-                           cai_split_out_features] = qweights[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) *
-                                                                 cai_split_out_features]
-        cai_linear.qzeros[:, i * zero_split_block:(i + 1) *
-                          zero_split_block] = qzeros[i][:, tp_rank * zero_split_block:(tp_rank + 1) * zero_split_block]
-        cai_linear.scales[:, i * cai_split_out_features:(i + 1) *
-                          cai_split_out_features] = scales[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) *
-                                                              cai_split_out_features]
+        cai_linear.qweight[:, i * cai_split_out_features : (i + 1) * cai_split_out_features] = qweights[i][
+            :, tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features
+        ]
+        cai_linear.qzeros[:, i * zero_split_block : (i + 1) * zero_split_block] = qzeros[i][
+            :, tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block
+        ]
+        cai_linear.scales[:, i * cai_split_out_features : (i + 1) * cai_split_out_features] = scales[i][
+            :, tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features
+        ]
         if cai_linear.bias is not None:
-            cai_linear.bias[i * cai_split_out_features:(i + 1) *
-                            cai_split_out_features] = bias[i][tp_rank * cai_split_out_features:(tp_rank + 1) *
-                                                              cai_split_out_features]
+            cai_linear.bias[i * cai_split_out_features : (i + 1) * cai_split_out_features] = bias[i][
+                tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features
+            ]
 
     cai_linear.g_idx.copy_(g_idx)
 
 
 def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1):
 
     qweights = gptq_linear.qweight.split(gptq_linear.in_features // split_num, dim=0)
     qzeros = gptq_linear.qzeros.split(gptq_linear.in_features // split_num, dim=0)
     scales = gptq_linear.scales.split(gptq_linear.in_features // split_num, dim=0)
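CaiQuantLinear above packs 32 // bits quantized weights into each int32, which is where the infeatures // 32 * self.bits and outfeatures // 32 * self.bits buffer shapes come from; a quick sketch of the arithmetic (hypothetical sizes, assuming bits=4 and groupsize=128):

bits, groupsize = 4, 128
infeatures, outfeatures = 1024, 1024

weights_per_int32 = 32 // bits                        # 8
qweight_rows = infeatures // 32 * bits                # 128 int32 rows hold all 1024 input rows
qzeros_cols = outfeatures // 32 * bits                # packed zero-points, one per group and column
n_groups = (infeatures + groupsize - 1) // groupsize  # 8 groups -> 8 rows of scales / qzeros
maxq = 2**bits - 1                                    # 15, the largest quantized value

print(weights_per_int32, qweight_rows, qzeros_cols, n_groups, maxq)  # 8 128 128 8 15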
@@ -231,47 +239,40 @@ def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1):
     idx_split_features = cai_linear.infeatures // split_num
 
     for i in range(split_num):
-        cai_linear.qweight[i * cai_split_in_features:(i + 1) *
-                           cai_split_in_features, :] = qweights[i][tp_rank * cai_split_in_features:(tp_rank + 1) *
-                                                                   cai_split_in_features, :]
-        cai_linear.qzeros[i * zero_split_block:(i + 1) *
-                          zero_split_block, :] = qzeros[i][tp_rank * zero_split_block:(tp_rank + 1) *
-                                                           zero_split_block, :]
-        cai_linear.scales[i * zero_split_block:(i + 1) *
-                          zero_split_block, :] = scales[i][tp_rank * zero_split_block:(tp_rank + 1) *
-                                                           zero_split_block, :]
-        cai_linear.g_idx[i * idx_split_features:(i + 1) *
-                         idx_split_features] = g_idxs[i][tp_rank * idx_split_features:(tp_rank + 1) *
-                                                         idx_split_features]
+        cai_linear.qweight[i * cai_split_in_features : (i + 1) * cai_split_in_features, :] = qweights[i][
+            tp_rank * cai_split_in_features : (tp_rank + 1) * cai_split_in_features, :
+        ]
+        cai_linear.qzeros[i * zero_split_block : (i + 1) * zero_split_block, :] = qzeros[i][
+            tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block, :
+        ]
+        cai_linear.scales[i * zero_split_block : (i + 1) * zero_split_block, :] = scales[i][
+            tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block, :
+        ]
+        cai_linear.g_idx[i * idx_split_features : (i + 1) * idx_split_features] = g_idxs[i][
+            tp_rank * idx_split_features : (tp_rank + 1) * idx_split_features
+        ]
     if cai_linear.bias is not None:
         cai_linear.bias.copy_(gptq_linear.bias)
 
 
 class RowCaiQuantLinear(CaiQuantLinear, ParallelModule):
 
     def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
-        super().__init__(bits,
-                         groupsize,
-                         infeatures,
-                         outfeatures,
-                         bias,
-                         tp_size=tp_size,
-                         tp_rank=tp_rank,
-                         row_split=row_split)
+        super().__init__(
+            bits, groupsize, infeatures, outfeatures, bias, tp_size=tp_size, tp_rank=tp_rank, row_split=row_split
+        )
         self.process_group = None
 
     @staticmethod
-    def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
-                           **kwargs) -> ParallelModule:
+    def from_native_module(
+        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+    ) -> ParallelModule:
         LazyInitContext.materialize(module)
         # get the attributes
         in_features = module.in_features
 
         # ensure only one process group is passed
         if isinstance(process_group, (list, tuple)):
-            assert len(process_group) == 1, \
-                f'Expected only one process group, got {len(process_group)}.'
+            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
             process_group = process_group[0]
 
         tp_size = dist.get_world_size(process_group)

@@ -282,15 +283,18 @@ class RowCaiQuantLinear(CaiQuantLinear, ParallelModule):
 
         if in_features % tp_size != 0:
             raise ValueError(
-                f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!")
-        linear_1d = RowCaiQuantLinear(module.bits,
+                f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!"
+            )
+        linear_1d = RowCaiQuantLinear(
+            module.bits,
             module.group_size,
             module.in_features // tp_size,
             module.out_features,
             module.bias is not None,
             tp_size=tp_size,
             tp_rank=tp_rank,
-            row_split=True)
+            row_split=True,
+        )
         linear_1d.process_group = process_group
 
         split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)

@@ -306,30 +310,23 @@ class RowCaiQuantLinear(CaiQuantLinear, ParallelModule):
 
 
 class ColCaiQuantLinear(CaiQuantLinear, ParallelModule):
 
     def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
-        super().__init__(bits,
-                         groupsize,
-                         infeatures,
-                         outfeatures,
-                         bias,
-                         tp_size=tp_size,
-                         tp_rank=tp_rank,
-                         row_split=row_split)
+        super().__init__(
+            bits, groupsize, infeatures, outfeatures, bias, tp_size=tp_size, tp_rank=tp_rank, row_split=row_split
+        )
         self.process_group = None
 
     @staticmethod
-    def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
-                           **kwargs) -> ParallelModule:
+    def from_native_module(
+        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+    ) -> ParallelModule:
         LazyInitContext.materialize(module)
         # get the attributes
         in_features = module.in_features
 
         # ensure only one process group is passed
         if isinstance(process_group, (list, tuple)):
-            assert len(process_group) == 1, \
-                f'Expected only one process group, got {len(process_group)}.'
+            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
             process_group = process_group[0]
 
         tp_size = dist.get_world_size(process_group)

@@ -340,14 +337,17 @@ class ColCaiQuantLinear(CaiQuantLinear, ParallelModule):
 
         if in_features % tp_size != 0:
             raise ValueError(
-                f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!")
-        linear_1d = ColCaiQuantLinear(module.bits,
+                f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!"
+            )
+        linear_1d = ColCaiQuantLinear(
+            module.bits,
             module.group_size,
             module.in_features,
             module.out_features // tp_size,
             module.bias is not None,
             tp_size=tp_size,
-            tp_rank=tp_rank)
+            tp_rank=tp_rank,
+        )
         linear_1d.process_group = process_group
 
         split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
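split_column_copy and split_row_copy above carve the packed GPTQ buffers into per-rank shards for tensor parallelism; the indexing pattern reduces to copying this rank's contiguous block out of each split, for example (a simplified sketch with a plain 1-D tensor, names hypothetical):

import torch

tp_size, tp_rank, split_num = 2, 1, 1
out_features = 8
full = torch.arange(out_features)                 # stands in for one packed buffer dimension
block = out_features // tp_size // split_num      # columns owned by this rank per split

local = torch.empty(out_features // tp_size, dtype=full.dtype)
for i in range(split_num):
    # take this rank's slice of the i-th split, place it in the i-th block of the shard
    local[i * block : (i + 1) * block] = full[tp_rank * block : (tp_rank + 1) * block]

print(local)  # tensor([4, 5, 6, 7]) on tp_rank 1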
@@ -5,6 +5,7 @@ import torch
 try:
     import triton
     import triton.language as tl
 
     HAS_TRITON = True
 except ImportError:
     HAS_TRITON = False

@@ -16,6 +17,7 @@ if HAS_TRITON:
     https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10
     """
     if triton.__version__ < "2.1.0":
 
         @triton.jit
         def _context_flash_attention_kernel(
             Q,

@@ -131,23 +133,41 @@ if HAS_TRITON:
                 m_i = m_i_new
 
             off_o = (
-                (cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od
+                (cur_batch_start_index + offs_m[:, None]) * stride_obs
+                + cur_head * stride_oh
+                + offs_d[None, :] * stride_od
             )
             out_ptrs = Out + off_o
             tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
             return
 
     else:
         # this function is modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L11
         @triton.jit
         def _context_flash_attention_kernel_2(
-            Q, K, V, sm_scale, Alibi, B_Start_Loc, B_Seqlen,
+            Q,
+            K,
+            V,
+            sm_scale,
+            Alibi,
+            B_Start_Loc,
+            B_Seqlen,
             Out,
             kv_group_num,
-            stride_qbs, stride_qh, stride_qd,
-            stride_kbs, stride_kh, stride_kd,
-            stride_vbs, stride_vh, stride_vd,
-            stride_obs, stride_oh, stride_od,
-            BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
+            stride_qbs,
+            stride_qh,
+            stride_qd,
+            stride_kbs,
+            stride_kh,
+            stride_kd,
+            stride_vbs,
+            stride_vh,
+            stride_vd,
+            stride_obs,
+            stride_oh,
+            stride_od,
+            BLOCK_M: tl.constexpr,
+            BLOCK_DMODEL: tl.constexpr,
             BLOCK_N: tl.constexpr,
         ):
             cur_batch = tl.program_id(0)
@@ -166,7 +186,11 @@ if HAS_TRITON:
             offs_n = tl.arange(0, BLOCK_N)
             offs_d = tl.arange(0, BLOCK_DMODEL)
             offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
-            off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd
+            off_q = (
+                (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
+                + cur_head * stride_qh
+                + offs_d[None, :] * stride_qd
+            )
             if kv_group_num is None or kv_group_num == 1:
                 off_k = offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd
                 off_v = offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd

@@ -191,8 +215,11 @@ if HAS_TRITON:
             for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
                 start_n = tl.multiple_of(start_n, BLOCK_N)
                 # -- compute qk ----
-                k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
-                            mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0)
+                k = tl.load(
+                    k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
+                    mask=(start_n + offs_n[None, :]) < cur_batch_seq_len,
+                    other=0.0,
+                )
 
                 qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
                 qk += tl.dot(q, k)

@@ -220,8 +247,11 @@ if HAS_TRITON:
                 acc_scale = l_i / l_i_new * alpha
                 acc = acc * acc_scale[:, None]
                 # update acc
-                v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
-                            mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0)
+                v = tl.load(
+                    v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
+                    mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,
+                    other=0.0,
+                )
 
                 p = p.to(v.dtype)
                 acc += tl.dot(p, v)

@@ -229,7 +259,11 @@ if HAS_TRITON:
                 l_i = l_i_new
                 m_i = m_i_new
             # initialize pointers to output
-            off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od
+            off_o = (
+                (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
+                + cur_head * stride_oh
+                + offs_d[None, :] * stride_od
+            )
             out_ptrs = Out + off_o
             tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
             return

@@ -286,7 +320,13 @@ if HAS_TRITON:
             )
         else:
             _context_flash_attention_kernel_2[grid](
-                q, k, v, sm_scale, alibi, b_start_loc, b_seq_len,
+                q,
+                k,
+                v,
+                sm_scale,
+                alibi,
+                b_start_loc,
+                b_seq_len,
                 o,
                 None,
                 q.stride(0),

@@ -388,6 +428,7 @@ if HAS_TRITON:
             BLOCK_DMODEL=Lk,
             BLOCK_N=BLOCK,
             num_warps=num_warps,
-            num_stages=1,)
+            num_stages=1,
+        )
 
         return
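The off_q / off_o expressions re-wrapped above are flattened pointer offsets built from strides; the same arithmetic in plain PyTorch, for a contiguous [total_tokens, heads, head_dim] tensor (a sketch, not part of the kernel):

import torch

total_tokens, n_heads, head_dim = 16, 4, 8
x = torch.arange(total_tokens * n_heads * head_dim).reshape(total_tokens, n_heads, head_dim)
stride_bs, stride_h, stride_d = x.stride()

token, head, dim = 5, 2, 3
flat_offset = token * stride_bs + head * stride_h + dim * stride_d
assert x.flatten()[flat_offset] == x[token, head, dim]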
@@ -1,8 +1,10 @@
 # adepted from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8/lightllm/models/llama/triton_kernel/flash_decoding.py
 import torch
 
 try:
     from lightllm.models.llama.triton_kernel.flash_decoding_stage1 import flash_decode_stage1
     from lightllm.models.llama.triton_kernel.flash_decoding_stage2 import flash_decode_stage2
 
     HAS_LIGHTLLM_KERNEL = True
 except:
     print("install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8")

@@ -10,31 +12,29 @@ except:
 
 
 if HAS_LIGHTLLM_KERNEL:
 
     def token_flash_decoding(q, o_tensor, infer_state, q_head_num, head_dim, cache_k, cache_v):
         BLOCK_SEQ = 256
         batch_size = infer_state.batch_size
         max_len_in_batch = infer_state.max_len_in_batch
 
         calcu_shape1 = (batch_size, q_head_num, head_dim)
 
-        if getattr(infer_state, 'mid_o', None) is None:
-            infer_state.mid_o = torch.empty([batch_size,
-                                             q_head_num,
-                                             max_len_in_batch // BLOCK_SEQ + 1,
-                                             head_dim],
-                                            dtype=torch.float32,
-                                            device="cuda")
-            infer_state.mid_o_logexpsum = torch.empty([batch_size,
-                                                       q_head_num,
-                                                       max_len_in_batch // BLOCK_SEQ + 1],
-                                                      dtype=torch.float32,
-                                                      device="cuda")
+        if getattr(infer_state, "mid_o", None) is None:
+            infer_state.mid_o = torch.empty(
+                [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim],
+                dtype=torch.float32,
+                device="cuda",
+            )
+            infer_state.mid_o_logexpsum = torch.empty(
+                [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda"
+            )
 
         mid_o = infer_state.mid_o
         mid_o_logexpsum = infer_state.mid_o_logexpsum
 
-        flash_decode_stage1(q.view(calcu_shape1),
+        flash_decode_stage1(
+            q.view(calcu_shape1),
             cache_k,
             cache_v,
             infer_state.block_loc,

@@ -42,9 +42,6 @@ if HAS_LIGHTLLM_KERNEL:
             infer_state.max_len_in_batch,
             mid_o,
             mid_o_logexpsum,
-            BLOCK_SEQ)
-        flash_decode_stage2(mid_o,
-                            mid_o_logexpsum,
-                            infer_state.seq_len,
-                            o_tensor.view(calcu_shape1),
-                            BLOCK_SEQ)
+            BLOCK_SEQ,
+        )
+        flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ)
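token_flash_decoding above runs flash decoding in two stages: stage 1 processes the KV cache in BLOCK_SEQ-sized chunks and stores a partial output plus a log-sum-exp term per chunk, stage 2 reduces them into o_tensor. The intermediate buffer shapes follow directly (sketch with hypothetical sizes):

import torch

BLOCK_SEQ = 256
batch_size, q_head_num, head_dim = 4, 32, 128
max_len_in_batch = 1000

n_chunks = max_len_in_batch // BLOCK_SEQ + 1  # 4 partial results per (sequence, head)
mid_o = torch.empty(batch_size, q_head_num, n_chunks, head_dim, dtype=torch.float32)
mid_o_logexpsum = torch.empty(batch_size, q_head_num, n_chunks, dtype=torch.float32)
# stage 1 fills the two buffers; stage 2 combines them into the final attention output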
@@ -8,6 +8,7 @@ from torch.cuda.amp import custom_bwd, custom_fwd
 try:
     import triton
     import triton.language as tl
 
     HAS_TRITON = True
 except ImportError:
     HAS_TRITON = False

@@ -41,9 +42,9 @@ if HAS_TRITON:
         for off in range(0, N, BLOCK_SIZE):
             cols = off + tl.arange(0, BLOCK_SIZE)
             mask = cols < N
-            x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.)
-            x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.)
-            x_up = tl.load(X_UP + cols, mask=mask, other=0.)
+            x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.0)
+            x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.0)
+            x_up = tl.load(X_UP + cols, mask=mask, other=0.0)
             x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype)
             y = x_gate1 * x_gate2 * x_gate2_sigmoid * x_up
             # Write output

@@ -76,10 +77,10 @@ if HAS_TRITON:
         for off in range(0, N, BLOCK_SIZE):
             cols = off + tl.arange(0, BLOCK_SIZE)
             mask = cols < N
-            x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.)
-            x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.)
-            x_up = tl.load(X_UP + cols, mask=mask, other=0.)
-            y_grad = tl.load(Y_GRAD + cols, mask=mask, other=0.)
+            x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.0)
+            x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.0)
+            x_up = tl.load(X_UP + cols, mask=mask, other=0.0)
+            y_grad = tl.load(Y_GRAD + cols, mask=mask, other=0.0)
 
             # forward: y = x_gate1 * x_gate2 * tl.sigmoid(x_gate2) * x_up
             x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype)

@@ -147,14 +148,9 @@ if HAS_TRITON:
         # restore setting
         ctx.M, ctx.N, ctx.BLOCK_SIZE, ctx.num_warps = M, N, BLOCK_SIZE, num_warps
         # enqueue kernel
-        _llama_act_combine_forward[(M,)](x_gate1,
-                                         x_gate2,
-                                         x_up,
-                                         y,
-                                         x_up.stride(-2),
-                                         N,
-                                         BLOCK_SIZE=BLOCK_SIZE,
-                                         num_warps=num_warps)
+        _llama_act_combine_forward[(M,)](
+            x_gate1, x_gate2, x_up, y, x_up.stride(-2), N, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
+        )
         return y
 
     @staticmethod

@@ -166,11 +162,15 @@ if HAS_TRITON:
 
         # init grad
         y_grad = grad_outputs[0]
-        x_gate1_grad, x_gate2_grad, x_up_grad = torch.empty_like(x_gate1), torch.empty_like(
-            x_gate2), torch.empty_like(x_up)
+        x_gate1_grad, x_gate2_grad, x_up_grad = (
+            torch.empty_like(x_gate1),
+            torch.empty_like(x_gate2),
+            torch.empty_like(x_up),
+        )
 
         # enqueue kernel
-        _llama_act_combine_backward[(M,)](x_gate1,
+        _llama_act_combine_backward[(M,)](
+            x_gate1,
             x_gate2,
             x_up,
             x_gate1_grad,

@@ -180,6 +180,7 @@ if HAS_TRITON:
             x_up.stride(-2),
             N,
             BLOCK_SIZE=BLOCK_SIZE,
-            num_warps=num_warps)
+            num_warps=num_warps,
+        )
         x_gate_grad = torch.cat([x_gate1_grad, x_gate2_grad], dim=-1)
         return x_gate_grad, x_up_grad, None, None
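The fused kernels above compute the LLaMA MLP activation y = x_gate1 * x_gate2 * sigmoid(x_gate2) * x_up (the formula is spelled out in the backward kernel's comment). An unfused reference in plain PyTorch, assuming, as the backward's torch.cat suggests, that the gate tensor carries the two halves concatenated on the last dimension:

import torch

def llama_act_combine_ref(x_gate: torch.Tensor, x_up: torch.Tensor) -> torch.Tensor:
    x_gate1, x_gate2 = x_gate.chunk(2, dim=-1)
    # SwiGLU-style combine: gate2 * sigmoid(gate2) is SiLU, modulated by gate1 and x_up
    x_gate2_sigmoid = torch.sigmoid(x_gate2.float()).to(x_gate2.dtype)
    return x_gate1 * x_gate2 * x_gate2_sigmoid * x_up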
@@ -13,10 +13,18 @@ except ImportError:
     print("please install triton from https://github.com/openai/triton")
 
 try:
-    from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2 as lightllm_llama_token_att_fwd2
-    from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_llama_token_att_fwd
-    from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd as lightllm_llama_token_softmax_fwd
-    from lightllm.models.bloom.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_bloom_token_att_fwd
+    from lightllm.models.bloom.triton_kernel.token_attention_nopad_att1 import (
+        token_att_fwd as lightllm_bloom_token_att_fwd,
+    )
+    from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import (
+        token_att_fwd as lightllm_llama_token_att_fwd,
+    )
+    from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import (
+        token_att_fwd2 as lightllm_llama_token_att_fwd2,
+    )
+    from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import (
+        token_softmax_fwd as lightllm_llama_token_softmax_fwd,
+    )
 
     HAS_TRITON_TOKEN_ATTENTION = True
 except ImportError:

@@ -205,9 +213,7 @@ class Llama2TokenAttentionForwards:
 
         if triton.__version__ == "2.0.0":
             prob = torch.empty_like(att_m_tensor)
-            lightllm_llama_token_softmax_fwd(
-                att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch
-            )
+            lightllm_llama_token_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch)
             att_m_tensor = None
 
             lightllm_llama_token_att_fwd2(

@@ -8,7 +8,9 @@ from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecode
 from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
 from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd
 from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards
 
 from ._utils import copy_kv_to_mem_cache
 
 try:
     from lightllm.models.llama.triton_kernel.context_flashattention_nopad import (
         context_attention_fwd as lightllm_llama_context_attention_fwd,

@@ -44,7 +44,7 @@ class Qwen2PipelineForwards:
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
         shard_config: ShardConfig = None,
-    )-> Union[Tuple, BaseModelOutputWithPast]:
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
         logger = logging.get_logger(__name__)
 
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

@@ -2,13 +2,13 @@ from .api import (
     compute_global_numel,
     customized_distributed_tensor_to_param,
     distribute_tensor,
-    init_as_dtensor,
     distribute_tensor_with_customization,
-    init_tensor_as_customization_distributed,
     get_device_mesh,
     get_global_shape,
     get_layout,
     get_sharding_spec,
+    init_as_dtensor,
+    init_tensor_as_customization_distributed,
     is_customized_distributed_tensor,
     is_distributed_tensor,
     is_sharded,

@@ -128,7 +128,10 @@ def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_sp
 
     return sharded_tensor
 
-def init_as_dtensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size) -> torch.Tensor:
+
+def init_as_dtensor(
+    tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size
+) -> torch.Tensor:
     assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."
     dist_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)
 

@@ -140,6 +143,7 @@ def init_as_dtensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec
 
     return tensor
 
 
 def redistribute(dtensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> None:
     """
     Convert the layout of the tensor from source_spec to target_spec.

@@ -468,7 +472,6 @@ def init_tensor_as_customization_distributed(tensor: torch.Tensor, shard_fn, gat
     assert callable(gather_fn), "The gather_fn must be callable."
     assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."
 
-
     # set the shard_fn and gather_fn as attributes of the distributed tensor
     tensor.shard_fn = shard_fn
     tensor.gather_fn = gather_fn

@@ -119,9 +119,7 @@ def main():
     if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None:
         # run pipeline forward backward
         batch = iter([batch])
-        outputs = booster.execute_pipeline(
-            batch, model, criterion, optimizer, return_loss=True
-        )
+        outputs = booster.execute_pipeline(batch, model, criterion, optimizer, return_loss=True)
     else:
         outputs = model(**batch)
         loss = criterion(outputs, None)

@@ -270,9 +270,7 @@ def main():
         ) as pbar:
             for step in pbar:
                 if use_pipeline:
-                    outputs = booster.execute_pipeline(
-                        dataloader_iter, model, _criterion, optimizer, return_loss=True
-                    )
+                    outputs = booster.execute_pipeline(dataloader_iter, model, _criterion, optimizer, return_loss=True)
                     loss = outputs["loss"]
                 else:
                     batch = next(dataloader_iter)

@@ -285,9 +285,7 @@ def main():
         ) as pbar:
             for step in pbar:
                 if use_pipeline:
-                    outputs = booster.execute_pipeline(
-                        dataloader_iter, model, _criterion, optimizer, return_loss=True
-                    )
+                    outputs = booster.execute_pipeline(dataloader_iter, model, _criterion, optimizer, return_loss=True)
                     loss = outputs["loss"]
                 else:
                     batch = next(dataloader_iter)
@@ -50,7 +50,6 @@ def all_reduce_mean(x: float, world_size: int) -> float:
 
 
 class Timer:
-
     def __init__(self) -> None:
         self.start_time: Optional[float] = None
         self.duration: float = 0.0

@@ -112,7 +111,7 @@ class PerformanceEvaluator:
         batch_size, seq_len = input_ids.shape
 
         self.num_samples += batch_size
-        self.flop += (batch_size * seq_len * self.model_numel * 2 * (3 + int(self.enable_grad_checkpoint)))
+        self.flop += batch_size * seq_len * self.model_numel * 2 * (3 + int(self.enable_grad_checkpoint))
 
     def on_fit_end(self) -> None:
         avg_duration = all_reduce_mean(self.timer.duration, self.world_size)

@@ -122,5 +121,6 @@ class PerformanceEvaluator:
         if dist.get_rank() == 0:
             print(
                 f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop: {self.flop}, avg_duration: {avg_duration}, "
-                f"avg_throughput: {avg_throughput}")
+                f"avg_throughput: {avg_throughput}"
+            )
             print(f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}")
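The flop accumulation above uses the common 2 * parameters FLOPs-per-token estimate, multiplied by 3 for forward plus backward, plus one extra forward when gradient checkpointing recomputes activations; a quick numeric check (hypothetical numbers):

batch_size, seq_len = 4, 2048
model_numel = 7_000_000_000          # e.g. a 7B-parameter model
enable_grad_checkpoint = True

flop_per_step = batch_size * seq_len * model_numel * 2 * (3 + int(enable_grad_checkpoint))
print(f"{flop_per_step / 1e12:.1f} TFLOP per step")  # about 458.8 TFLOP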
@@ -16,17 +16,15 @@ def inference(args):
     tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
     if args.model == "test":
         config = LlamaConfig.from_pretrained("hpcai-tech/openmoe-base")
-        set_openmoe_args(config,
-                         num_experts=config.num_experts,
-                         moe_layer_interval=config.moe_layer_interval,
-                         enable_kernel=True)
+        set_openmoe_args(
+            config, num_experts=config.num_experts, moe_layer_interval=config.moe_layer_interval, enable_kernel=True
+        )
         model = OpenMoeForCausalLM(config)
     else:
         config = LlamaConfig.from_pretrained(f"hpcai-tech/openmoe-{args.model}")
-        set_openmoe_args(config,
-                         num_experts=config.num_experts,
-                         moe_layer_interval=config.moe_layer_interval,
-                         enable_kernel=False)
+        set_openmoe_args(
+            config, num_experts=config.num_experts, moe_layer_interval=config.moe_layer_interval, enable_kernel=False
+        )
         model = OpenMoeForCausalLM.from_pretrained(f"hpcai-tech/openmoe-{args.model}", config=config)
     model = model.eval().bfloat16()
     model = model.to(torch.cuda.current_device())

@@ -172,9 +172,9 @@ def make_state_dict(converted_params):
 def load_t5x_weights_in_t5(model, config, t5x_checkpoint_path):
     """Replaces the params in model witht the T5X converted params."""
     variables = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
-    converted = convert_t5x_to_pytorch(variables,
-                                       num_layers=config.num_hidden_layers,
-                                       moe_interval=config.moe_layer_interval)
+    converted = convert_t5x_to_pytorch(
+        variables, num_layers=config.num_hidden_layers, moe_interval=config.moe_layer_interval
+    )
     state_dict = make_state_dict(converted)
     model.load_state_dict(state_dict, strict=True)
 

@@ -203,11 +203,9 @@ def convert_t5x_checkpoint_to_pytorch(t5x_checkpoint_path, config_file, pytorch_
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Converts a native T5X checkpoint into a PyTorch checkpoint.")
     # Required parameters
-    parser.add_argument("--t5x_checkpoint_path",
-                        default=None,
-                        type=str,
-                        required=True,
-                        help="Path to the T5X checkpoint.")
+    parser.add_argument(
+        "--t5x_checkpoint_path", default=None, type=str, required=True, help="Path to the T5X checkpoint."
+    )
     parser.add_argument(
         "--config_file",
         default=None,

@@ -215,10 +213,8 @@ if __name__ == "__main__":
         required=True,
         help="The config json file corresponding to the pre-trained T5 model.\nThis specifies the model architecture.",
     )
-    parser.add_argument("--pytorch_dump_path",
-                        default=None,
-                        type=str,
-                        required=True,
-                        help="Path to the output PyTorch model.")
+    parser.add_argument(
+        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
+    )
     args = parser.parse_args()
     convert_t5x_checkpoint_to_pytorch(args.t5x_checkpoint_path, args.config_file, args.pytorch_dump_path)

@@ -41,9 +41,7 @@ def train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, dataloader, b
     # Forward pass
     for _ in pbar:
         if use_pipeline:
-            outputs = booster.execute_pipeline(
-                dataloader, model, _criterion, optimizer, return_loss=True
-            )
+            outputs = booster.execute_pipeline(dataloader, model, _criterion, optimizer, return_loss=True)
             # Backward and optimize
             if is_pp_last_stage:
                 loss = outputs["loss"]
@@ -1,5 +1,4 @@
 from .cpu_adam_arm import CpuAdamArmExtension
 from .cpu_adam_x86 import CpuAdamX86Extension
 
-__all__ = ['CpuAdamArmExtension', 'CpuAdamX86Extension']
+__all__ = ["CpuAdamArmExtension", "CpuAdamX86Extension"]
 

@@ -1,3 +1,3 @@
 from .moe_cuda import MoeCudaExtension
 
-__all__ = ['MoeCudaExtension']
+__all__ = ["MoeCudaExtension"]

@@ -1,3 +1,3 @@
 from .fused_optimizer_cuda import FusedOptimizerCudaExtension
 
-__all__ = ['FusedOptimizerCudaExtension']
+__all__ = ["FusedOptimizerCudaExtension"]

@@ -1,4 +1,4 @@
 from .scaled_masked_softmax_cuda import ScaledMaskedSoftmaxCudaExtension
 from .scaled_upper_triangle_masked_softmax_cuda import ScaledUpperTriangleMaskedSoftmaxCudaExtension
 
-__all__ = ['ScaledMaskedSoftmaxCudaExtension', 'ScaledUpperTriangleMaskedSoftmaxCudaExtension']
+__all__ = ["ScaledMaskedSoftmaxCudaExtension", "ScaledUpperTriangleMaskedSoftmaxCudaExtension"]

@@ -1,33 +1,33 @@
 import os
 
 from . import custom, diffusers, timm, torchaudio, torchvision, transformers
 from .executor import run_fwd, run_fwd_bwd
 from .registry import model_zoo
 
 # We pick a subset of models for fast testing in order to reduce the total testing time
 COMMON_MODELS = [
-    'custom_hanging_param_model',
-    'custom_nested_model',
-    'custom_repeated_computed_layers',
-    'custom_simple_net',
-    'diffusers_clip_text_model',
-    'diffusers_auto_encoder_kl',
-    'diffusers_unet2d_model',
-    'timm_densenet',
-    'timm_resnet',
-    'timm_swin_transformer',
-    'torchaudio_wav2vec2_base',
-    'torchaudio_conformer',
-    'transformers_bert_for_masked_lm',
-    'transformers_bloom_for_causal_lm',
-    'transformers_falcon_for_causal_lm',
-    'transformers_chatglm_for_conditional_generation',
-    'transformers_llama_for_casual_lm',
-    'transformers_vit_for_masked_image_modeling',
-    'transformers_mistral_for_casual_lm'
+    "custom_hanging_param_model",
+    "custom_nested_model",
+    "custom_repeated_computed_layers",
+    "custom_simple_net",
+    "diffusers_clip_text_model",
+    "diffusers_auto_encoder_kl",
+    "diffusers_unet2d_model",
+    "timm_densenet",
+    "timm_resnet",
+    "timm_swin_transformer",
+    "torchaudio_wav2vec2_base",
+    "torchaudio_conformer",
+    "transformers_bert_for_masked_lm",
+    "transformers_bloom_for_causal_lm",
+    "transformers_falcon_for_causal_lm",
+    "transformers_chatglm_for_conditional_generation",
+    "transformers_llama_for_casual_lm",
+    "transformers_vit_for_masked_image_modeling",
+    "transformers_mistral_for_casual_lm",
 ]
 
-IS_FAST_TEST = os.environ.get('FAST_TEST', '0') == '1'
+IS_FAST_TEST = os.environ.get("FAST_TEST", "0") == "1"
 
 
-__all__ = ["model_zoo", "run_fwd", "run_fwd_bwd", 'COMMON_MODELS', 'IS_FAST_TEST']
+__all__ = ["model_zoo", "run_fwd", "run_fwd_bwd", "COMMON_MODELS", "IS_FAST_TEST"]
 

@@ -2,6 +2,7 @@ import torch
 
 from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
 from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
 
 from ..registry import ModelAttribute, model_zoo
 
 # ================================

@@ -74,9 +74,7 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
     data = data_gen_fn()
     model.train()
     if booster.plugin.stage_manager is not None:
-        booster.execute_pipeline(
-            _preprocess_data(data), model, _criterion, optimizer, return_loss=True
-        )
+        booster.execute_pipeline(_preprocess_data(data), model, _criterion, optimizer, return_loss=True)
     else:
         output = model(**_preprocess_data(data))
         loss = criterion(output)

@@ -108,9 +106,7 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
     data_for_shard = data_gen_fn()
     data_for_origin = data_gen_fn()
     if booster.plugin.stage_manager is not None:
-        booster.execute_pipeline(
-            _preprocess_data(data_for_shard), model, _criterion, optimizer, return_loss=True
-        )
+        booster.execute_pipeline(_preprocess_data(data_for_shard), model, _criterion, optimizer, return_loss=True)
         booster.execute_pipeline(
             _preprocess_data(data_for_origin),
             new_model,

@@ -113,6 +113,7 @@ def check_torch_fsdp_ckpt():
     full_osd = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer)
 
     import copy
 
     sharded_osd = copy.deepcopy(full_osd)
 
     run_model()
@ -1,16 +1,8 @@
|
||||||
import math
|
|
||||||
import time
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
import transformers
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import triton
|
|
||||||
import triton.language as tl
|
|
||||||
HAS_TRITON = True
|
HAS_TRITON = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
HAS_TRITON = False
|
HAS_TRITON = False
|
||||||
|
@ -22,6 +14,7 @@ try:
|
||||||
from exllama_kernels import prepare_buffers, set_tuning_params
|
from exllama_kernels import prepare_buffers, set_tuning_params
|
||||||
|
|
||||||
from colossalai.inference.quant.gptq import CaiQuantLinear
|
from colossalai.inference.quant.gptq import CaiQuantLinear
|
||||||
|
|
||||||
HAS_AUTO_GPTQ = True
|
HAS_AUTO_GPTQ = True
|
||||||
except:
|
except:
|
||||||
HAS_AUTO_GPTQ = False
|
HAS_AUTO_GPTQ = False
|
||||||
|
@@ -32,13 +25,14 @@ import warnings
 HAS_GPTQ_CUDA = False
 try:
     from colossalai.kernel.op_builder.gptq import GPTQBuilder
+
     gptq_cuda = GPTQBuilder().load()
     HAS_GPTQ_CUDA = True
 except ImportError:
-    warnings.warn('CUDA gptq is not installed')
+    warnings.warn("CUDA gptq is not installed")
     HAS_GPTQ_CUDA = False
 
-TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
+TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
 
 max_inner_outer_dim = 1
 max_input_len = 1
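The flags touched in this hunk follow the optional-dependency pattern used throughout these tests: attempt the import, record a boolean, and combine it with a CUDA version gate when skipping tests. A generic sketch of the same pattern (the warning text and version threshold mirror the hunk; the `torch.version.cuda is not None` guard is an extra assumption for CPU-only builds):

import warnings

import torch
from packaging import version

try:
    from colossalai.kernel.op_builder.gptq import GPTQBuilder

    gptq_cuda = GPTQBuilder().load()
    HAS_GPTQ_CUDA = True
except ImportError:
    warnings.warn("CUDA gptq is not installed")
    HAS_GPTQ_CUDA = False

# Guard against CPU-only torch builds before parsing the CUDA version string.
TRITON_CUDA_SUPPORT = torch.version.cuda is not None and version.parse(torch.version.cuda) > version.parse("11.4")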
@@ -64,9 +58,9 @@ def init_buffer(cai_linear, use_act_order=False):
         max_input_len = 4096
     # The temp_state buffer is required to reorder X in the act-order case.
     # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
-    gptq_temp_state_buffer = torch.zeros((max_input_len, max_inner_outer_dim),
-                                         dtype=torch.float16,
-                                         device=torch.cuda.current_device())
+    gptq_temp_state_buffer = torch.zeros(
+        (max_input_len, max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device()
+    )
     gptq_temp_dq_buffer = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device())
 
     gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), gptq_temp_state_buffer, gptq_temp_dq_buffer)
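For orientation, the reformatted lines above allocate the two fp16 scratch buffers that the CUDA GPTQ kernels need and register them with the extension. A condensed sketch, assuming `gptq_cuda` was loaded via `GPTQBuilder().load()` as earlier in the file and that the three size globals have already been computed from the quantized layers:

import torch

device = torch.cuda.current_device()
# temp_state reorders activations in the act-order case; temp_dq holds dequantized weights.
gptq_temp_state_buffer = torch.zeros((max_input_len, max_inner_outer_dim), dtype=torch.float16, device=device)
gptq_temp_dq_buffer = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device)
gptq_cuda.prepare_buffers(torch.device(device), gptq_temp_state_buffer, gptq_temp_dq_buffer)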
@@ -77,10 +71,11 @@ def init_buffer(cai_linear, use_act_order=False):
     gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
 
 
-@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_AUTO_GPTQ,
-                    reason="triton requires cuda version to be higher than 11.4 or not install auto-gptq")
+@pytest.mark.skipif(
+    not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_AUTO_GPTQ,
+    reason="triton requires cuda version to be higher than 11.4 or not install auto-gptq",
+)
 def test_gptq_linear():
-
     infeature = 1024
     outfeature = 1024
     group_size = 128
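Reflowing the decorator does not change the skip condition: the test still requires triton, auto-gptq, and a CUDA toolkit newer than 11.4. An equivalent, slightly more explicit form (a sketch with a hypothetical helper constant, not the file's actual code):

import pytest

GPTQ_TEST_UNSUPPORTED = not (TRITON_CUDA_SUPPORT and HAS_TRITON and HAS_AUTO_GPTQ)


@pytest.mark.skipif(GPTQ_TEST_UNSUPPORTED, reason="requires triton, auto-gptq and CUDA newer than 11.4")
def test_gptq_linear():
    ...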
@@ -120,7 +115,7 @@ def test_gptq_linear():
     max_input_len = 2048
     buffers = {
         "temp_state": torch.zeros((max_input_len, max_inner_outer_dim), dtype=torch.float16, device=device),
-        "temp_dq": torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device)
+        "temp_dq": torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device),
     }
 
     prepare_buffers(device, buffers["temp_state"], buffers["temp_dq"])
@@ -146,5 +141,4 @@ def test_gptq_linear():
 
 
 if __name__ == "__main__":
-
     test_gptq_linear()
@@ -1,5 +1,5 @@
-import torch
 import pytest
+import torch
 
 from colossalai.nn.optimizer import CPUAdam, HybridAdam
 from colossalai.testing import clear_cache_before_run, parameterize
@@ -17,6 +17,7 @@ def check_params_equal(model, torch_model):
     for p, torch_p in zip(model.parameters(), torch_model.parameters()):
         assert torch.allclose(p, torch_p, atol=1e-3), f"diff: {torch.abs(p - torch_p)}"
 
+
 # TODO Something wrong with ci when running this test.
 @pytest.mark.skip(reason="skip because of something wrong with CI")
 @clear_cache_before_run()
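The assertion kept in this hunk is the heart of the CPUAdam/HybridAdam tests: after identical update steps, every parameter of the ColossalAI-optimized model should match its counterpart in a reference torch model within a loose tolerance. A minimal usage sketch (the two models and their update loops are assumed to exist):

import torch


def check_params_equal(model, torch_model):
    # Compare parameters pairwise; atol=1e-3 allows for small numerical
    # differences between the two optimizer implementations.
    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
        assert torch.allclose(p, torch_p, atol=1e-3), f"diff: {torch.abs(p - torch_p)}"


# e.g. after stepping CPUAdam on `model` and torch.optim.Adam on `torch_model`
# with the same gradients:
# check_params_equal(model, torch_model)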
@@ -103,9 +103,7 @@ def run_pp(
     torch_loss = criterion(torch_output)
     torch_loss.backward()
 
-    pp_ret = schedule.forward_backward_step(
-        sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
-    )
+    pp_ret = schedule.forward_backward_step(sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True)
 
     # check loss
     if stage_manager.is_last_stage(ignore_chunk=True):
@@ -99,9 +99,7 @@ def examine_pp(num_microbatch: int, batch_size: int):
     torch_output = torch_model(input_list[0])
     torch_loss = criterion(torch_output)
     torch_loss.backward()
-    pp_ret = schedule.forward_backward_step(
-        sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
-    )
+    pp_ret = schedule.forward_backward_step(sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True)
 
     # check loss
     if stage_manager.is_last_stage():
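Both of the last two hunks reflow the same call. The surrounding test logic is: run the reference model eagerly, run the sharded model through the pipeline schedule, then compare losses on the last stage. A hedged sketch using only names visible in the hunks; the `pp_ret["loss"]` key is assumed from the schedule's return value and may differ in the actual tests:

import torch

# Assumes the objects from the test: torch_model, sharded_model, schedule,
# stage_manager, criterion, pp_optimizer and a list of microbatches input_list.
torch_output = torch_model(input_list[0])
torch_loss = criterion(torch_output)
torch_loss.backward()

pp_ret = schedule.forward_backward_step(sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True)

if stage_manager.is_last_stage():
    # The schedule accumulates the loss on the final pipeline stage.
    assert torch.allclose(torch_loss, pp_ret["loss"])  # "loss" key assumed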