[feat] Dist Loader for Eval (#5950)

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix tp error

* remove unused parameters

* remove unused

* update inference

* update docs

* update inference

---------

Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Tong Li 2024-08-02 10:06:25 +08:00 committed by GitHub
parent 62cdac6b7b
commit 19d1510ea2
15 changed files with 93 additions and 77 deletions


@@ -197,9 +197,7 @@ class AGIEvalDataset(BaseDataset):
     """

     @staticmethod
-    def load(
-        path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool
-    ) -> List[Dict]:
+    def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]:
         dataset = {"test": {}}
         files = glob.glob(os.path.join(path, "*.jsonl"))
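
The same signature change repeats across the dataset classes below: flags that only some datasets use (forward_only, load_train, load_reference) are no longer spelled out positionally but absorbed by *args/**kwargs, so callers can keep passing them and datasets that do not need them simply ignore them. A minimal sketch of the idea, with an illustrative class name not taken from the patch:

from typing import Dict, List


class ToyDataset:
    """Illustrative only: a dataset that ignores flags it does not need."""

    @staticmethod
    def load(path: str, logger, few_shot: bool, *args, **kwargs) -> List[Dict]:
        # forward_only / load_train / load_reference may still arrive here,
        # either positionally (in *args) or by keyword (in **kwargs); they are
        # simply ignored by datasets that have no use for them.
        return [{"path": path, "few_shot": few_shot}]


# Both call styles keep working after the refactor.
ToyDataset.load("data/", None, True, False, False, False)
ToyDataset.load("data/", None, few_shot=True, load_reference=False)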


@@ -1,6 +1,9 @@
 from abc import abstractstaticmethod

 from colossal_eval.utils import jdump
+from torch.utils.data import Dataset
+
+from colossalai.logging import DistributedLogger


 class BaseDataset:
@@ -12,13 +15,24 @@ class BaseDataset:
         logger: Logger for the dataset.
     """

-    def __init__(self, path, logger, few_shot, forward_only=False, load_train=False, load_reference=False):
-        self.dataset = self.load(path, logger, few_shot, forward_only, load_train, load_reference)
+    def __init__(self, path, logger, *args, **kwargs):
+        self.dataset = self.load(path, logger, *args, **kwargs)

     def save(self, save_path):
         """Save the converted dataset"""
         jdump(self.dataset, save_path)

     @abstractstaticmethod
-    def load(path, logger):
+    def load(path, logger: DistributedLogger, *args, **kwargs):
         """Load the original dataset and convert it into the inference dataset"""
+
+
+class DistributedDataset(Dataset):
+    def __init__(self, data):
+        self.data = data
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        return self.data[idx]
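
DistributedDataset is a thin torch Dataset wrapper over the already-converted list of samples, so a DistributedSampler can shard that list across data-parallel ranks. A minimal sketch of the intended use, with toy samples and explicit rank/world-size values standing in for the real process-group query:

from torch.utils.data import DataLoader, DistributedSampler

from colossal_eval.dataset.base import DistributedDataset

samples = [{"instruction": f"q{i}", "output": ""} for i in range(10)]  # toy data
dist_dataset = DistributedDataset(samples)

# Each data-parallel rank sees its own slice of the samples.
sampler = DistributedSampler(dist_dataset, num_replicas=2, rank=0, shuffle=False)
loader = DataLoader(
    dist_dataset,
    batch_size=4,
    sampler=sampler,
    collate_fn=lambda batch: batch,  # keep each batch as a plain list of dicts
)
for batch in loader:
    print([x["instruction"] for x in batch])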


@@ -90,9 +90,7 @@ class CEvalDataset(BaseDataset):
     """

     @staticmethod
-    def load(
-        path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool
-    ) -> List[Dict]:
+    def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]:
         dataset = {"dev": {}, "test": {}}
         for split in ["dev", "test"]:
             files = os.listdir(os.path.join(path, split))


@@ -101,9 +101,7 @@ class CMMLUDataset(BaseDataset):
     """

     @staticmethod
-    def load(
-        path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool
-    ) -> List[Dict]:
+    def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]:
         dataset = {"dev": {}, "test": {}}
         for split in ["dev", "test"]:
             files = os.listdir(os.path.join(path, split))


@@ -37,7 +37,7 @@ class ColossalDataset(BaseDataset):
     """

     @staticmethod
-    def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
+    def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]:
         dataset = {"test": {}}
         data = jload(path)
         data_per_category = get_data_per_category(data)


@@ -28,7 +28,7 @@ class CValuesDataset(BaseDataset):
     """

     @staticmethod
-    def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
+    def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]:
         dataset = {"test": {}}
         file_path = os.path.join(path, "cvalues_responsibility_mc.jsonl")
         data_list = []


@@ -69,9 +69,7 @@ class GaoKaoBenchDataset(BaseDataset):
     """

     @staticmethod
-    def load(
-        path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool
-    ) -> List[Dict]:
+    def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]:
         dataset = {"test": {}}
         for category in ["Fill-in-the-blank_Questions", "Multiple-choice_Questions", "Open-ended_Questions"]:
             files = os.listdir(os.path.join(path, "data", category))


@@ -77,7 +77,7 @@ class LongBenchDataset(BaseDataset):
     """

     @staticmethod
-    def load(path: str, logger: DistributedLogger) -> List[Dict]:
+    def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]:
         dataset = {"test": {}}
         files = os.listdir(path)


@@ -31,9 +31,7 @@ class MMLUDataset(BaseDataset):
     """

     @staticmethod
-    def load(
-        path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool
-    ) -> List[Dict]:
+    def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]:
         dataset = {"dev": {}, "test": {}}
         for split in ["dev", "test"]:
             files = os.listdir(os.path.join(path, split))


@@ -27,12 +27,12 @@ class MTBenchDataset(BaseDataset):
     This dataset class will convert the original dataset into the inference dataset.
     """

-    def __init__(self, path, logger, few_shot):
+    def __init__(self, path, logger: DistributedLogger, *args, **kwargs):
         self.multiturn = True
-        self.dataset = self.load(path, logger, few_shot)
+        self.dataset = self.load(path, logger, *args, **kwargs)

     @staticmethod
-    def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
+    def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]:
         dataset = {"test": defaultdict(dict)}
         file_path = os.path.join(path, "question.jsonl")


@@ -130,7 +130,7 @@ class SafetyBenchENDataset(BaseDataset):
     """

     @staticmethod
-    def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
+    def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]:
         dataset = {"dev": {}, "test": {}}
         data_files = [os.path.join(path, file_name) for file_name in FILES]
         for file_path in data_files:


@@ -130,7 +130,7 @@ class SafetyBenchZHDataset(BaseDataset):
     """

     @staticmethod
-    def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
+    def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]:
         dataset = {"dev": {}, "test": {}}
         data_files = [os.path.join(path, file_name) for file_name in FILES]
         for file_path in data_files:


@@ -1,11 +1,11 @@
 import copy
-import math
 from typing import Any, Dict, List, Optional, Tuple

 import numpy as np
 import torch
 from colossal_eval.utils import Conversation, get_batch_prompt, is_rank_0
 from peft import PeftModel
+from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer
@@ -130,7 +130,7 @@ class HuggingFaceModel(BaseModel):
         if shard_config is not None:
             self.model = AutoModel.from_pretrained(path, **model_kwargs)
             shard_former = ShardFormer(shard_config)
-            self.model, sharded_parameters = shard_former.optimize(self.model)
+            self.model, _ = shard_former.optimize(self.model)
             self.model.to(get_current_device())

             if peft_path is not None:
@@ -325,7 +325,7 @@ class HuggingFaceModel(BaseModel):
         return input_ids_list, labels_list, None

-    def inference(self, data: List[Dict], inference_kwargs: Dict[str, Any], debug: bool = False) -> List[Dict]:
+    def inference(self, data_loader: DataLoader, inference_kwargs: Dict[str, Any], debug: bool = False) -> List[Dict]:
         """
         Infer the given data.
         This function will call self.generate() to get model outputs and also self.model() to get logits.
@@ -359,26 +359,23 @@ class HuggingFaceModel(BaseModel):
         self.str_label_map = {choice: idx for idx, choice in enumerate(self.choices)}

-        turn = 0 if not isinstance(data[0]["output"], list) else len(data[0]["output"]) + 1
-        turn_desc = "" if turn == 0 else f"-turn{turn}"
-
         bar = tqdm(
-            range(math.ceil(len(data) / self.batch_size)),
-            desc=f"{data[0]['dataset']}-{data[0]['category']}{turn_desc} Inference steps",
+            range(len(data_loader)),
+            desc=f"{inference_kwargs['dataset']}-{inference_kwargs['category']} Inference steps",
             disable=not is_rank_0(),
         )
         loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

-        answers = copy.deepcopy(data)
-        for i in range(0, len(data), self.batch_size):
-            batch = data[i : i + self.batch_size]
+        answers = []
+
+        for i, batch in enumerate(data_loader):
             batch_prompt, batch_target = get_batch_prompt(
-                self.prompt_template, batch, few_shot_data, self.tokenizer, language, self.model_max_length
+                self.prompt_template, batch, few_shot_data, self.tokenizer, self.model_max_length
             )

             if is_rank_0() and debug and i == 0:
                 self.logger.info(
-                    f"Inference arguments for dataset {data[0]['dataset']} category {data[0]['category']} is:\n{inference_kwargs}"
+                    f"Inference arguments for dataset {batch[0]['dataset']} category {batch[0]['category']} is:\n{inference_kwargs}"
                 )
                 self.logger.info("-" * 120)
                 self.logger.info("An example prompt and prompt with target is:")
@@ -402,7 +399,7 @@ class HuggingFaceModel(BaseModel):
             # Otherwise this will violate the single-choice setting.

             if calculate_loss:
-                labels = [self.str_label_map[answers[i + j]["target"]] for j in range(len(batch_decodes))]
+                labels = [self.str_label_map[batch[j]["target"]] for j in range(len(batch))]

                 loss_over_choices = loss_fct(scores, torch.tensor(labels, dtype=torch.long)).numpy().tolist()
@@ -411,29 +408,30 @@ class HuggingFaceModel(BaseModel):
                 {choice: probs[i][self.str_label_map[choice]] for choice in self.choices} for i in range(len(probs))
             ]

-            for j in range(len(batch_prompt)):
+            for j in range(len(batch)):
                 if not pretrain:
-                    if isinstance(answers[i + j]["output"], list):
-                        answers[i + j]["output"].append(batch_decodes[j].strip())
+                    if isinstance(batch[j]["output"], list):
+                        batch[j]["output"].append(batch_decodes[j].strip())
                     else:
-                        answers[i + j]["output"] = batch_decodes[j].strip()
+                        batch[j]["output"] = batch_decodes[j].strip()

                     if isinstance(scores, torch.Tensor):
-                        answers[i + j]["logits_over_choices"] = probs[j]
+                        batch[j]["logits_over_choices"] = probs[j]

                         if calculate_loss:
-                            answers[i + j]["loss_over_choices"] = loss_over_choices[j]
+                            batch[j]["loss_over_choices"] = loss_over_choices[j]

                 if calculate_loss:
-                    answers[i + j]["loss"] = (np.array(batch_losses[j]) / np.array(batch_target_token_nums[j])).tolist()
+                    batch[j]["loss"] = (np.array(batch_losses[j]) / np.array(batch_target_token_nums[j])).tolist()

                     # loss_sum is specially used for pertrain dataset for calculating per-byte-perplexity.
                     # However, loss (which is per sample loss) suffices for most cases.
-                    answers[i + j]["loss_sum"] = batch_losses[j]
-                    answers[i + j]["token_num"] = batch_target_token_nums[j]
+                    batch[j]["loss_sum"] = batch_losses[j]
+                    batch[j]["token_num"] = batch_target_token_nums[j]

                     if batch_bytes_nums:
-                        answers[i + j]["byte_num"] = batch_bytes_nums[j]
+                        batch[j]["byte_num"] = batch_bytes_nums[j]
+
+            answers.extend(batch)

             bar.update()
@@ -600,7 +598,7 @@ class HuggingFaceCausalLM(HuggingFaceModel):
         if shard_config is not None:
             self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
             shard_former = ShardFormer(shard_config)
-            self.model, sharded_parameters = shard_former.optimize(self.model)
+            self.model, _ = shard_former.optimize(self.model)
             self.model.to(get_current_device())

             if peft_path is not None:
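
The rewritten inference loop consumes batches that are plain lists of sample dicts and annotates them in place before extending answers, which is why the DataLoader built later in this diff passes collate_fn=lambda x: x; PyTorch's default collation would instead merge the dict fields into batched values. A small sketch of the difference on toy samples:

from torch.utils.data import DataLoader, default_collate

samples = [{"dataset": "toy", "category": "demo", "target": "A"} for _ in range(4)]

# Identity collate: each batch stays a list of dicts, matching what
# get_batch_prompt() and the per-sample bookkeeping in inference() expect.
loader = DataLoader(samples, batch_size=2, collate_fn=lambda batch: batch)
print(next(iter(loader)))            # [{'dataset': 'toy', ...}, {'dataset': 'toy', ...}]

# Default collate would zip the fields together instead.
print(default_collate(samples[:2]))  # {'dataset': ['toy', 'toy'], ...}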


@@ -123,15 +123,13 @@
     }


-def get_few_shot_prefix(
-    conv: Conversation, few_shot_data: List[str], tokenizer: Optional[AutoTokenizer], language: str, max_tokens: int
-) -> str:
+def get_few_shot_prefix(few_shot_data: List[str], tokenizer: Optional[AutoTokenizer], max_tokens: int) -> str:
     """
     Get few shot prefix.

     Args:
-        conv: Conversation template.
-        few_shot_examples: Few shot examples to generate few shot prompt prefix.
+        few_shot_data: Few shot examples to generate few shot prompt prefix.
+        tokenizer: tokenizer used to tokenize data.

     Returns:
         Few shot prompt prefix.
@@ -157,7 +155,6 @@ def get_batch_prompt(
     batch: List[Dict],
     few_shot_data: List[str],
     tokenizer: Optional[AutoTokenizer],
-    language: Optional[str],
     model_max_length: Optional[int],
 ) -> Tuple[List[Dict], List[Dict]]:
     """
@@ -167,6 +164,7 @@ def get_batch_prompt(
         conv: Conversation template.
         batch: Batch data to generate prompt from.
         few_shot_data: Few shot data to generate few shot prompt prefix.
+        tokenizer: tokenizer used to tokenize data.

     Returns:
         Tuple containg batch prompt and target.
@@ -192,7 +190,7 @@ def get_batch_prompt(
             else:
                 raise Exception("When using few-shot, target answer should be a string.")

-            few_shot_prefix = get_few_shot_prefix(conv, few_shot_data, tokenizer, language, max_tokens)
+            few_shot_prefix = get_few_shot_prefix(few_shot_data, tokenizer, max_tokens)

             conv.append_message(conv.roles[0], few_shot_prefix + query_text)
             conv.append_message(conv.roles[1], None)


@@ -5,6 +5,8 @@ from typing import Dict, List

 import torch.distributed as dist
 from colossal_eval import dataset, models, utils
+from colossal_eval.dataset.base import DistributedDataset
+from torch.utils.data import DataLoader, DistributedSampler

 import colossalai
 from colossalai.accelerator import get_accelerator
@@ -13,6 +15,7 @@ from colossalai.logging import get_dist_logger
 from colossalai.shardformer import ShardConfig

 logger = get_dist_logger()
+os.environ["TOKENIZERS_PARALLELISM"] = "false"


 def rm_and_merge(
@@ -54,7 +57,8 @@ def rm_and_merge(
                     )
                 else:
                     rank_answers = utils.jload(directory)
-                    answers["data"].extend(rank_answers["data"])
+                    deduplidate_answers = [x for x in rank_answers["data"] if x not in answers["data"]]
+                    answers["data"].extend(deduplidate_answers)
                     answers["inference_kwargs"] = rank_answers["inference_kwargs"]

             for r in range(dp_size):
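
The new dedup filter in rm_and_merge compensates for DistributedSampler padding: when the sample count is not divisible by the data-parallel world size, the sampler repeats a few samples so every rank processes the same number, and those repeats would otherwise appear twice in the merged answer file. A small sketch of the padding behaviour on toy indices:

from torch.utils.data import DistributedSampler

data = list(range(10))  # 10 samples split across 4 data-parallel ranks
for rank in range(4):
    sampler = DistributedSampler(data, num_replicas=4, rank=rank, shuffle=False)
    print(rank, list(sampler))
# Every rank gets ceil(10 / 4) = 3 indices, so indices 0 and 1 show up on two
# ranks; the merge step above drops such duplicates before computing metrics.
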
@@ -65,7 +69,7 @@ def rm_and_merge(
                     os.remove(directory)
                 except Exception as e:
                     print(e)
-
+            print(len(answers["data"]))
             all_answers[category] = answers

     all_answers_with_dataset_class["inference_results"] = all_answers
@@ -108,7 +112,12 @@ def main(args):
     tp_rank = coordinates[TP_AXIS]

     shard_config = (
-        ShardConfig(tensor_parallel_process_group=tp_group, enable_tensor_parallelism=args.tp_size > 1)
+        ShardConfig(
+            tensor_parallel_process_group=tp_group,
+            enable_tensor_parallelism=args.tp_size > 1,
+            parallel_output=False,
+            enable_all_optimization=True,
+        )
         if args.tp_size > 1
         else None
     )
@@ -183,6 +192,7 @@ def main(args):
         model_name = model_parameter["name"]
         model_class = eval(f"models.{model_parameter['model_class']}")
         paramerters = model_parameter["parameters"]
+        batch_size = paramerters["batch_size"]
         paramerters.update({"logger": logger})
         paramerters.update({"prompt_template": utils.prompt_templates[paramerters["prompt_template"]]})
         paramerters.update({"shard_config": shard_config})
@@ -192,7 +202,6 @@ def main(args):
             raise ValueError(f"Model class {model_parameter['model_class']} is not a subclass of BaseModel.")

         for dataset_name, split_data in inference_data.items():
-            start = 0
             prev_questions = None
             for category, category_data in split_data.items():
                 num_turn = category_data["inference_kwargs"].get("turns", 1)
@@ -201,26 +210,33 @@ def main(args):
                     raise Exception(f"Dataset {dataset_name} doesn't have few-shot data for category {category}!")

                 answers_to_dump = copy.deepcopy(category_data)
-                partition_size = len(category_data["data"]) // dp_size
-                redundant = len(category_data["data"]) % dp_size
-
-                # Ensure that the amount of data for inference is as consistent as possible across different processes.
-                lengths = [partition_size for _ in range(dp_size)]
-                for j in range(redundant):
-                    lengths[(j + start) % dp_size] += 1
-
-                start = (start + redundant) % dp_size

                 for turn in range(num_turn):
                     if turn == 0:
-                        questions = category_data["data"][
-                            sum(lengths[0:dp_rank]) : sum(lengths[0:dp_rank]) + lengths[dp_rank]
-                        ]
+                        dist_dataset = DistributedDataset(category_data["data"])
                     else:
-                        questions = prev_questions
+                        dist_dataset = DistributedDataset(prev_questions)
+
+                    sampler = DistributedSampler(
+                        dist_dataset,
+                        num_replicas=pg_mesh.size(DP_AXIS),
+                        rank=pg_mesh.coordinate(DP_AXIS),
+                        shuffle=False,
+                    )
+                    questions_loader = DataLoader(
+                        dist_dataset,
+                        batch_size=batch_size,
+                        sampler=sampler,
+                        num_workers=8,
+                        pin_memory=True,
+                        collate_fn=lambda x: x,
+                    )
+                    category_data["inference_kwargs"]["dataset"] = dataset_name
+                    category_data["inference_kwargs"]["category"] = category

                     answers_per_rank = model_.inference(
-                        questions, inference_kwargs=category_data["inference_kwargs"], debug=debug_args[dataset_name]
+                        data_loader=questions_loader,
+                        inference_kwargs=category_data["inference_kwargs"],
+                        debug=debug_args[dataset_name],
                     )
                     prev_questions = answers_per_rank
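
Putting the pieces together, the per-category flow after this change is roughly: wrap the converted samples in DistributedDataset, shard them with a DistributedSampler, and hand the resulting DataLoader to the model's inference method, which returns this rank's annotated answers for later merging. The sketch below runs single-process (data-parallel size 1) and replaces the real model inference with a trivial stand-in; only DistributedDataset is taken from the patch itself.

from torch.utils.data import DataLoader, DistributedSampler

from colossal_eval.dataset.base import DistributedDataset


def inference_stub(data_loader, inference_kwargs, debug=False):
    # Stand-in for the real model inference: annotate each sample, collect all.
    answers = []
    for batch in data_loader:
        for sample in batch:
            sample["output"] = f"answer to {sample['instruction']}"
        answers.extend(batch)
    return answers


category_data = {"data": [{"instruction": f"q{i}"} for i in range(5)], "inference_kwargs": {}}
dist_dataset = DistributedDataset(category_data["data"])
sampler = DistributedSampler(dist_dataset, num_replicas=1, rank=0, shuffle=False)
loader = DataLoader(dist_dataset, batch_size=2, sampler=sampler, collate_fn=lambda x: x)

category_data["inference_kwargs"]["dataset"] = "demo"
category_data["inference_kwargs"]["category"] = "toy"
answers_per_rank = inference_stub(loader, category_data["inference_kwargs"])
print(len(answers_per_rank))  # 5 answers on this rank; each rank later dumps its own file for merging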