[tutorial] add synthetic dataset for opt (#1924)

pull/1931/head
ver217 2022-11-13 03:26:11 +08:00 committed by GitHub
parent 0486048453
commit b0b7a786b7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 163 additions and 94 deletions

View File

@ -39,6 +39,14 @@ bash ./run_clm.sh <batch-size-per-gpu> <mem-cap> <model> <gpu-num>
the pretrained weights from [OPT weight downloading page](https://github.com/facebookresearch/metaseq/tree/main/projects/OPT).
- gpu-num: the number of GPUs to use, default is 1.
It uses `wikitext` dataset.
To use synthetic dataset:
```bash
bash ./run_clm_synthetic.sh <batch-size-per-gpu> <mem-cap> <model> <gpu-num>
```
## Remarkable Performance
On a single GPU, Colossal-AIs automatic strategy provides remarkable performance gains from the ZeRO Offloading strategy by Microsoft DeepSpeed.
Users can experience up to a 40% speedup, at a variety of model scales. However, when using a traditional deep learning training framework like PyTorch, a single GPU can no longer support the training of models at such a scale.

View File

@ -74,6 +74,7 @@ def get_time_stamp():
def parse_args():
parser = colossalai.get_default_parser()
parser.add_argument("-s", "--synthetic", action="store_true")
parser.add_argument(
"--dataset_name",
type=str,
@ -231,15 +232,16 @@ def parse_args():
args = parser.parse_args()
# Sanity checks
if args.dataset_name is None and args.train_file is None and args.validation_file is None:
raise ValueError("Need either a dataset name or a training/validation file.")
else:
if args.train_file is not None:
extension = args.train_file.split(".")[-1]
assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, json or txt file."
if args.validation_file is not None:
extension = args.validation_file.split(".")[-1]
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."
if not args.synthetic:
if args.dataset_name is None and args.train_file is None and args.validation_file is None:
raise ValueError("Need either a dataset name or a training/validation file.")
else:
if args.train_file is not None:
extension = args.train_file.split(".")[-1]
assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, json or txt file."
if args.validation_file is not None:
extension = args.validation_file.split(".")[-1]
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."
if args.push_to_hub:
assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed."
@ -255,6 +257,34 @@ def colo_memory_cap(size_in_GB):
print("Using {} GB of GPU memory".format(size_in_GB))
class DummyDataloader:
def __init__(self, length, batch_size, seq_len, vocab_size):
self.length = length
self.batch_size = batch_size
self.seq_len = seq_len
self.vocab_size = vocab_size
def generate(self):
input_ids = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len), device=get_current_device())
attention_mask = torch.ones_like(input_ids)
return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": input_ids}
def __iter__(self):
self.step = 0
return self
def __next__(self):
if self.step < self.length:
self.step += 1
return self.generate()
else:
raise StopIteration
def __len__(self):
return self.length
def main():
args = parse_args()
disable_existing_loggers()
@ -292,46 +322,47 @@ def main():
# In distributed training, the load_dataset function guarantee that only one local process can concurrently
# download the dataset.
logger.info("Start preparing dataset", ranks=[0])
if args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name)
if "validation" not in raw_datasets.keys():
raw_datasets["validation"] = load_dataset(
args.dataset_name,
args.dataset_config_name,
split=f"train[:{args.validation_split_percentage}%]",
)
raw_datasets["train"] = load_dataset(
args.dataset_name,
args.dataset_config_name,
split=f"train[{args.validation_split_percentage}%:]",
)
else:
data_files = {}
dataset_args = {}
if args.train_file is not None:
data_files["train"] = args.train_file
if args.validation_file is not None:
data_files["validation"] = args.validation_file
extension = args.train_file.split(".")[-1]
if extension == "txt":
extension = "text"
dataset_args["keep_linebreaks"] = not args.no_keep_linebreaks
raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args)
# If no validation data is there, validation_split_percentage will be used to divide the dataset.
if "validation" not in raw_datasets.keys():
raw_datasets["validation"] = load_dataset(
extension,
data_files=data_files,
split=f"train[:{args.validation_split_percentage}%]",
**dataset_args,
)
raw_datasets["train"] = load_dataset(
extension,
data_files=data_files,
split=f"train[{args.validation_split_percentage}%:]",
**dataset_args,
)
if not args.synthetic:
if args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name)
if "validation" not in raw_datasets.keys():
raw_datasets["validation"] = load_dataset(
args.dataset_name,
args.dataset_config_name,
split=f"train[:{args.validation_split_percentage}%]",
)
raw_datasets["train"] = load_dataset(
args.dataset_name,
args.dataset_config_name,
split=f"train[{args.validation_split_percentage}%:]",
)
else:
data_files = {}
dataset_args = {}
if args.train_file is not None:
data_files["train"] = args.train_file
if args.validation_file is not None:
data_files["validation"] = args.validation_file
extension = args.train_file.split(".")[-1]
if extension == "txt":
extension = "text"
dataset_args["keep_linebreaks"] = not args.no_keep_linebreaks
raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args)
# If no validation data is there, validation_split_percentage will be used to divide the dataset.
if "validation" not in raw_datasets.keys():
raw_datasets["validation"] = load_dataset(
extension,
data_files=data_files,
split=f"train[:{args.validation_split_percentage}%]",
**dataset_args,
)
raw_datasets["train"] = load_dataset(
extension,
data_files=data_files,
split=f"train[{args.validation_split_percentage}%:]",
**dataset_args,
)
logger.info("Dataset is prepared", ranks=[0])
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
@ -399,23 +430,24 @@ def main():
logger.info(f'{model.__class__.__name__} has been created', ranks=[0])
# Preprocessing the datasets.
# First we tokenize all the texts.
column_names = raw_datasets["train"].column_names
text_column_name = "text" if "text" in column_names else column_names[0]
if not args.synthetic:
# Preprocessing the datasets.
# First we tokenize all the texts.
column_names = raw_datasets["train"].column_names
text_column_name = "text" if "text" in column_names else column_names[0]
def tokenize_function(examples):
return tokenizer(examples[text_column_name])
def tokenize_function(examples):
return tokenizer(examples[text_column_name])
with barrier_context(executor_rank=0, parallel_mode=ParallelMode.DATA):
tokenized_datasets = raw_datasets.map(
tokenize_function,
batched=True,
num_proc=args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not args.overwrite_cache,
desc="Running tokenizer on dataset",
)
with barrier_context(executor_rank=0, parallel_mode=ParallelMode.DATA):
tokenized_datasets = raw_datasets.map(
tokenize_function,
batched=True,
num_proc=args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not args.overwrite_cache,
desc="Running tokenizer on dataset",
)
if args.block_size is None:
block_size = tokenizer.model_max_length
@ -447,38 +479,44 @@ def main():
result["labels"] = result["input_ids"].copy()
return result
# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
# for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
# to preprocess.
#
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
if not args.synthetic:
# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
# for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
# to preprocess.
#
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
with barrier_context(executor_rank=0, parallel_mode=ParallelMode.DATA):
lm_datasets = tokenized_datasets.map(
group_texts,
batched=True,
num_proc=args.preprocessing_num_workers,
load_from_cache_file=not args.overwrite_cache,
desc=f"Grouping texts in chunks of {block_size}",
)
with barrier_context(executor_rank=0, parallel_mode=ParallelMode.DATA):
lm_datasets = tokenized_datasets.map(
group_texts,
batched=True,
num_proc=args.preprocessing_num_workers,
load_from_cache_file=not args.overwrite_cache,
desc=f"Grouping texts in chunks of {block_size}",
)
train_dataset = lm_datasets["train"]
eval_dataset = lm_datasets["validation"]
train_dataset = lm_datasets["train"]
eval_dataset = lm_datasets["validation"]
# Log a few random samples from the training set:
# for index in random.sample(range(len(train_dataset)), 3):
# logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
# Log a few random samples from the training set:
# for index in random.sample(range(len(train_dataset)), 3):
# logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
# DataLoaders creation:
train_dataloader = get_dataloader(train_dataset,
shuffle=True,
add_sampler=True,
collate_fn=default_data_collator,
batch_size=args.per_device_train_batch_size)
eval_dataloader = DataLoader(eval_dataset,
collate_fn=default_data_collator,
batch_size=args.per_device_eval_batch_size)
# DataLoaders creation:
train_dataloader = get_dataloader(train_dataset,
shuffle=True,
add_sampler=True,
collate_fn=default_data_collator,
batch_size=args.per_device_train_batch_size)
eval_dataloader = DataLoader(eval_dataset,
collate_fn=default_data_collator,
batch_size=args.per_device_eval_batch_size)
else:
train_dataloader = DummyDataloader(30, args.per_device_train_batch_size, config.max_position_embeddings,
config.vocab_size)
eval_dataloader = DummyDataloader(10, args.per_device_train_batch_size, config.max_position_embeddings,
config.vocab_size)
logger.info("Dataloaders have been created", ranks=[0])
# Optimizer
@ -521,9 +559,11 @@ def main():
# Train!
total_batch_size = args.per_device_train_batch_size * gpc.get_world_size(ParallelMode.DATA)
num_train_samples = len(train_dataset) if not args.synthetic else 30 * total_batch_size
num_eval_samples = len(eval_dataset) if not args.synthetic else 10 * total_batch_size
logger.info("***** Running training *****", ranks=[0])
logger.info(f" Num examples = {len(train_dataset)}", ranks=[0])
logger.info(f" Num examples = {num_train_samples}", ranks=[0])
logger.info(f" Num Epochs = {args.num_train_epochs}", ranks=[0])
logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}", ranks=[0])
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}", ranks=[0])
@ -572,7 +612,7 @@ def main():
losses.append(loss)
losses = torch.cat(losses)
losses = losses[:len(eval_dataset)]
losses = losses[:num_eval_samples]
try:
eval_loss = torch.mean(losses)
perplexity = math.exp(eval_loss)

View File

@ -0,0 +1,21 @@
set -x
export BS=${1:-16}
export MEMCAP=${2:-0}
export MODEL=${3:-"125m"}
export GPUNUM=${4:-1}
# make directory for logs
mkdir -p ./logs
export MODLE_PATH="facebook/opt-${MODEL}"
# HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1
torchrun \
--nproc_per_node ${GPUNUM} \
--master_port 19198 \
run_clm.py \
-s \
--output_dir $PWD \
--mem_cap ${MEMCAP} \
--model_name_or_path ${MODLE_PATH} \
--per_device_train_batch_size ${BS} 2>&1 | tee ./logs/colo_${MODEL}_bs_${BS}_cap_${MEMCAP}_gpu_${GPUNUM}.log