mirror of https://github.com/hpcaitech/ColossalAI
[tutorial] add synthetic dataset for opt (#1924)
parent 0486048453
commit b0b7a786b7
README.md
@@ -39,6 +39,14 @@ bash ./run_clm.sh <batch-size-per-gpu> <mem-cap> <model> <gpu-num>
 the pretrained weights from [OPT weight downloading page](https://github.com/facebookresearch/metaseq/tree/main/projects/OPT).
 - gpu-num: the number of GPUs to use, default is 1.
 
+It uses the `wikitext` dataset.
+
+To use a synthetic dataset, run:
+
+```bash
+bash ./run_clm_synthetic.sh <batch-size-per-gpu> <mem-cap> <model> <gpu-num>
+```
+
 ## Remarkable Performance
 On a single GPU, Colossal-AI’s automatic strategy provides remarkable performance gains from the ZeRO Offloading strategy by Microsoft DeepSpeed.
 Users can experience up to a 40% speedup at a variety of model scales. However, when using a traditional deep learning training framework like PyTorch, a single GPU can no longer support the training of models at such a scale.
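
For example, `bash ./run_clm_synthetic.sh 16 0 125m 1` benchmarks `facebook/opt-125m` on one GPU with a per-GPU batch size of 16 and no memory cap, matching the defaults of the new `run_clm_synthetic.sh` at the end of this diff.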

run_clm.py
@@ -74,6 +74,7 @@ def get_time_stamp():

 def parse_args():
     parser = colossalai.get_default_parser()
+    parser.add_argument("-s", "--synthetic", action="store_true")
     parser.add_argument(
         "--dataset_name",
         type=str,

@@ -231,15 +232,16 @@ def parse_args():
     args = parser.parse_args()

     # Sanity checks
-    if args.dataset_name is None and args.train_file is None and args.validation_file is None:
-        raise ValueError("Need either a dataset name or a training/validation file.")
-    else:
-        if args.train_file is not None:
-            extension = args.train_file.split(".")[-1]
-            assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, json or txt file."
-        if args.validation_file is not None:
-            extension = args.validation_file.split(".")[-1]
-            assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."
+    if not args.synthetic:
+        if args.dataset_name is None and args.train_file is None and args.validation_file is None:
+            raise ValueError("Need either a dataset name or a training/validation file.")
+        else:
+            if args.train_file is not None:
+                extension = args.train_file.split(".")[-1]
+                assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, json or txt file."
+            if args.validation_file is not None:
+                extension = args.validation_file.split(".")[-1]
+                assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."

     if args.push_to_hub:
         assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed."

@@ -255,6 +257,34 @@ def colo_memory_cap(size_in_GB):
     print("Using {} GB of GPU memory".format(size_in_GB))


+class DummyDataloader:
+
+    def __init__(self, length, batch_size, seq_len, vocab_size):
+        self.length = length
+        self.batch_size = batch_size
+        self.seq_len = seq_len
+        self.vocab_size = vocab_size
+
+    def generate(self):
+        input_ids = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len), device=get_current_device())
+        attention_mask = torch.ones_like(input_ids)
+        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": input_ids}
+
+    def __iter__(self):
+        self.step = 0
+        return self
+
+    def __next__(self):
+        if self.step < self.length:
+            self.step += 1
+            return self.generate()
+        else:
+            raise StopIteration
+
+    def __len__(self):
+        return self.length
+
+
 def main():
     args = parse_args()
     disable_existing_loggers()
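
The class above fabricates CLM batches instead of reading a corpus. A minimal standalone sketch of what one synthetic batch looks like (plain CPU tensors stand in for Colossal-AI's `get_current_device()`; 50272 is OPT's vocabulary size):

```python
import torch

# Shape of a single batch produced by DummyDataloader.generate()
batch_size, seq_len, vocab_size = 2, 8, 50272
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
attention_mask = torch.ones_like(input_ids)
# labels == input_ids is valid for causal LM training: HF models shift
# the labels one position internally when computing the loss.
batch = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": input_ids}
print({k: tuple(v.shape) for k, v in batch.items()})
# {'input_ids': (2, 8), 'attention_mask': (2, 8), 'labels': (2, 8)}
```

Because `generate()` runs on every `__next__`, each step sees fresh random tokens, so measured throughput excludes tokenization and disk I/O entirely.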

@@ -292,46 +322,47 @@ def main():
     # In distributed training, the load_dataset function guarantees that only one local process can concurrently
     # download the dataset.
     logger.info("Start preparing dataset", ranks=[0])
-    if args.dataset_name is not None:
-        # Downloading and loading a dataset from the hub.
-        raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name)
-        if "validation" not in raw_datasets.keys():
-            raw_datasets["validation"] = load_dataset(
-                args.dataset_name,
-                args.dataset_config_name,
-                split=f"train[:{args.validation_split_percentage}%]",
-            )
-            raw_datasets["train"] = load_dataset(
-                args.dataset_name,
-                args.dataset_config_name,
-                split=f"train[{args.validation_split_percentage}%:]",
-            )
-    else:
-        data_files = {}
-        dataset_args = {}
-        if args.train_file is not None:
-            data_files["train"] = args.train_file
-        if args.validation_file is not None:
-            data_files["validation"] = args.validation_file
-        extension = args.train_file.split(".")[-1]
-        if extension == "txt":
-            extension = "text"
-            dataset_args["keep_linebreaks"] = not args.no_keep_linebreaks
-        raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args)
-        # If no validation data is there, validation_split_percentage will be used to divide the dataset.
-        if "validation" not in raw_datasets.keys():
-            raw_datasets["validation"] = load_dataset(
-                extension,
-                data_files=data_files,
-                split=f"train[:{args.validation_split_percentage}%]",
-                **dataset_args,
-            )
-            raw_datasets["train"] = load_dataset(
-                extension,
-                data_files=data_files,
-                split=f"train[{args.validation_split_percentage}%:]",
-                **dataset_args,
-            )
+    if not args.synthetic:
+        if args.dataset_name is not None:
+            # Downloading and loading a dataset from the hub.
+            raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name)
+            if "validation" not in raw_datasets.keys():
+                raw_datasets["validation"] = load_dataset(
+                    args.dataset_name,
+                    args.dataset_config_name,
+                    split=f"train[:{args.validation_split_percentage}%]",
+                )
+                raw_datasets["train"] = load_dataset(
+                    args.dataset_name,
+                    args.dataset_config_name,
+                    split=f"train[{args.validation_split_percentage}%:]",
+                )
+        else:
+            data_files = {}
+            dataset_args = {}
+            if args.train_file is not None:
+                data_files["train"] = args.train_file
+            if args.validation_file is not None:
+                data_files["validation"] = args.validation_file
+            extension = args.train_file.split(".")[-1]
+            if extension == "txt":
+                extension = "text"
+                dataset_args["keep_linebreaks"] = not args.no_keep_linebreaks
+            raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args)
+            # If no validation data is there, validation_split_percentage will be used to divide the dataset.
+            if "validation" not in raw_datasets.keys():
+                raw_datasets["validation"] = load_dataset(
+                    extension,
+                    data_files=data_files,
+                    split=f"train[:{args.validation_split_percentage}%]",
+                    **dataset_args,
+                )
+                raw_datasets["train"] = load_dataset(
+                    extension,
+                    data_files=data_files,
+                    split=f"train[{args.validation_split_percentage}%:]",
+                    **dataset_args,
+                )
     logger.info("Dataset is prepared", ranks=[0])

     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
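
In the non-synthetic path, the `datasets` slicing syntax carves the validation set out of the head of the train split when no validation split exists. A quick sketch of that arithmetic, assuming the tutorial's `wikitext` data and a 5% split:

```python
from datasets import load_dataset

pct = 5  # stand-in for args.validation_split_percentage
val = load_dataset("wikitext", "wikitext-2-raw-v1", split=f"train[:{pct}%]")
train = load_dataset("wikitext", "wikitext-2-raw-v1", split=f"train[{pct}%:]")
# The two slices partition the original train split with no overlap.
print(len(val), len(train))
```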

@@ -399,23 +430,24 @@ def main():

     logger.info(f'{model.__class__.__name__} has been created', ranks=[0])

-    # Preprocessing the datasets.
-    # First we tokenize all the texts.
-    column_names = raw_datasets["train"].column_names
-    text_column_name = "text" if "text" in column_names else column_names[0]
+    if not args.synthetic:
+        # Preprocessing the datasets.
+        # First we tokenize all the texts.
+        column_names = raw_datasets["train"].column_names
+        text_column_name = "text" if "text" in column_names else column_names[0]

-    def tokenize_function(examples):
-        return tokenizer(examples[text_column_name])
+        def tokenize_function(examples):
+            return tokenizer(examples[text_column_name])

-    with barrier_context(executor_rank=0, parallel_mode=ParallelMode.DATA):
-        tokenized_datasets = raw_datasets.map(
-            tokenize_function,
-            batched=True,
-            num_proc=args.preprocessing_num_workers,
-            remove_columns=column_names,
-            load_from_cache_file=not args.overwrite_cache,
-            desc="Running tokenizer on dataset",
-        )
+        with barrier_context(executor_rank=0, parallel_mode=ParallelMode.DATA):
+            tokenized_datasets = raw_datasets.map(
+                tokenize_function,
+                batched=True,
+                num_proc=args.preprocessing_num_workers,
+                remove_columns=column_names,
+                load_from_cache_file=not args.overwrite_cache,
+                desc="Running tokenizer on dataset",
+            )

     if args.block_size is None:
         block_size = tokenizer.model_max_length

@@ -447,38 +479,44 @@ def main():
         result["labels"] = result["input_ids"].copy()
         return result

-    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
-    # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
-    # to preprocess.
-    #
-    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
-    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
+    if not args.synthetic:
+        # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
+        # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
+        # to preprocess.
+        #
+        # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
+        # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map

-    with barrier_context(executor_rank=0, parallel_mode=ParallelMode.DATA):
-        lm_datasets = tokenized_datasets.map(
-            group_texts,
-            batched=True,
-            num_proc=args.preprocessing_num_workers,
-            load_from_cache_file=not args.overwrite_cache,
-            desc=f"Grouping texts in chunks of {block_size}",
-        )
+        with barrier_context(executor_rank=0, parallel_mode=ParallelMode.DATA):
+            lm_datasets = tokenized_datasets.map(
+                group_texts,
+                batched=True,
+                num_proc=args.preprocessing_num_workers,
+                load_from_cache_file=not args.overwrite_cache,
+                desc=f"Grouping texts in chunks of {block_size}",
+            )

-    train_dataset = lm_datasets["train"]
-    eval_dataset = lm_datasets["validation"]
+        train_dataset = lm_datasets["train"]
+        eval_dataset = lm_datasets["validation"]

-    # Log a few random samples from the training set:
-    # for index in random.sample(range(len(train_dataset)), 3):
-    #     logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
+        # Log a few random samples from the training set:
+        # for index in random.sample(range(len(train_dataset)), 3):
+        #     logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

-    # DataLoaders creation:
-    train_dataloader = get_dataloader(train_dataset,
-                                      shuffle=True,
-                                      add_sampler=True,
-                                      collate_fn=default_data_collator,
-                                      batch_size=args.per_device_train_batch_size)
-    eval_dataloader = DataLoader(eval_dataset,
-                                 collate_fn=default_data_collator,
-                                 batch_size=args.per_device_eval_batch_size)
+        # DataLoaders creation:
+        train_dataloader = get_dataloader(train_dataset,
+                                          shuffle=True,
+                                          add_sampler=True,
+                                          collate_fn=default_data_collator,
+                                          batch_size=args.per_device_train_batch_size)
+        eval_dataloader = DataLoader(eval_dataset,
+                                     collate_fn=default_data_collator,
+                                     batch_size=args.per_device_eval_batch_size)
+    else:
+        train_dataloader = DummyDataloader(30, args.per_device_train_batch_size, config.max_position_embeddings,
+                                           config.vocab_size)
+        eval_dataloader = DummyDataloader(10, args.per_device_train_batch_size, config.max_position_embeddings,
+                                          config.vocab_size)
     logger.info("Dataloaders have been created", ranks=[0])

     # Optimizer
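
Note that the synthetic branch builds its loaders with `config.max_position_embeddings` as the sequence length, so every batch is full-length. A back-of-envelope sketch, assuming `opt-125m` (whose config sets `max_position_embeddings = 2048`) and the script's default per-GPU batch size of 16:

```python
# Hypothetical per-step token count in the synthetic setting
per_device_batch_size = 16   # default BS in run_clm_synthetic.sh
seq_len = 2048               # opt-125m max_position_embeddings
print(per_device_batch_size * seq_len)  # 32768 tokens per step on each GPU
```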

@@ -521,9 +559,11 @@ def main():

     # Train!
     total_batch_size = args.per_device_train_batch_size * gpc.get_world_size(ParallelMode.DATA)
+    num_train_samples = len(train_dataset) if not args.synthetic else 30 * total_batch_size
+    num_eval_samples = len(eval_dataset) if not args.synthetic else 10 * total_batch_size

     logger.info("***** Running training *****", ranks=[0])
-    logger.info(f" Num examples = {len(train_dataset)}", ranks=[0])
+    logger.info(f" Num examples = {num_train_samples}", ranks=[0])
     logger.info(f" Num Epochs = {args.num_train_epochs}", ranks=[0])
     logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}", ranks=[0])
     logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}", ranks=[0])
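
Because a `DummyDataloader` has a fixed number of batches (30 for training, 10 for eval) rather than an underlying dataset, the sample counts are derived from the data-parallel world size instead of measured. A sketch of the arithmetic, assuming two data-parallel GPUs and the default per-device batch size of 16:

```python
per_device_train_batch_size = 16
world_size = 2  # hypothetical data-parallel degree
total_batch_size = per_device_train_batch_size * world_size
num_train_samples = 30 * total_batch_size  # 960 synthetic training samples
num_eval_samples = 10 * total_batch_size   # 320 synthetic eval samples
print(total_batch_size, num_train_samples, num_eval_samples)  # 32 960 320
```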

@@ -572,7 +612,7 @@ def main():
             losses.append(loss)

         losses = torch.cat(losses)
-        losses = losses[:len(eval_dataset)]
+        losses = losses[:num_eval_samples]
         try:
             eval_loss = torch.mean(losses)
             perplexity = math.exp(eval_loss)
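
Truncating to `num_eval_samples` drops any extra entries gathered from a padded final batch; perplexity is then the exponential of the mean loss. A minimal sketch with made-up loss values:

```python
import math

import torch

losses = [torch.tensor([2.1, 2.3]), torch.tensor([2.2, 2.4])]  # per-batch eval losses
num_eval_samples = 3  # pretend the fourth entry is padding
losses = torch.cat(losses)[:num_eval_samples]
eval_loss = torch.mean(losses)    # (2.1 + 2.3 + 2.2) / 3 = 2.2
perplexity = math.exp(eval_loss)  # exp can overflow for large losses, hence the try/except above
print(round(perplexity, 2))       # 9.03
```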

run_clm_synthetic.sh (new file)
@@ -0,0 +1,21 @@
+set -x
+export BS=${1:-16}
+export MEMCAP=${2:-0}
+export MODEL=${3:-"125m"}
+export GPUNUM=${4:-1}
+
+# make directory for logs
+mkdir -p ./logs
+
+export MODEL_PATH="facebook/opt-${MODEL}"
+
+# HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1
+torchrun \
+  --nproc_per_node ${GPUNUM} \
+  --master_port 19198 \
+  run_clm.py \
+  -s \
+  --output_dir $PWD \
+  --mem_cap ${MEMCAP} \
+  --model_name_or_path ${MODEL_PATH} \
+  --per_device_train_batch_size ${BS} 2>&1 | tee ./logs/colo_${MODEL}_bs_${BS}_cap_${MEMCAP}_gpu_${GPUNUM}.log
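
The script takes only positional arguments, so it is easy to drive from a wrapper. A hypothetical batch-size sweep (this helper is illustrative, not part of the commit):

```python
import subprocess

# Sweep per-GPU batch sizes on one GPU with no memory cap; each run's output
# is tee'd into ./logs/ by run_clm_synthetic.sh itself.
for bs in (8, 16, 32):
    subprocess.run(["bash", "./run_clm_synthetic.sh", str(bs), "0", "125m", "1"], check=True)
```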