mirror of https://github.com/hpcaitech/ColossalAI
[tutorial] add synthetic dataset for opt (#1924)
parent
0486048453
commit
b0b7a786b7
|
@ -39,6 +39,14 @@ bash ./run_clm.sh <batch-size-per-gpu> <mem-cap> <model> <gpu-num>
|
||||||
the pretrained weights from [OPT weight downloading page](https://github.com/facebookresearch/metaseq/tree/main/projects/OPT).
|
the pretrained weights from [OPT weight downloading page](https://github.com/facebookresearch/metaseq/tree/main/projects/OPT).
|
||||||
- gpu-num: the number of GPUs to use, default is 1.
|
- gpu-num: the number of GPUs to use, default is 1.
|
||||||
|
|
||||||
|
It uses `wikitext` dataset.
|
||||||
|
|
||||||
|
To use synthetic dataset:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash ./run_clm_synthetic.sh <batch-size-per-gpu> <mem-cap> <model> <gpu-num>
|
||||||
|
```
|
||||||
|
|
||||||
## Remarkable Performance
|
## Remarkable Performance
|
||||||
On a single GPU, Colossal-AI’s automatic strategy provides remarkable performance gains from the ZeRO Offloading strategy by Microsoft DeepSpeed.
|
On a single GPU, Colossal-AI’s automatic strategy provides remarkable performance gains from the ZeRO Offloading strategy by Microsoft DeepSpeed.
|
||||||
Users can experience up to a 40% speedup, at a variety of model scales. However, when using a traditional deep learning training framework like PyTorch, a single GPU can no longer support the training of models at such a scale.
|
Users can experience up to a 40% speedup, at a variety of model scales. However, when using a traditional deep learning training framework like PyTorch, a single GPU can no longer support the training of models at such a scale.
|
||||||
|
|
|
@ -74,6 +74,7 @@ def get_time_stamp():
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = colossalai.get_default_parser()
|
parser = colossalai.get_default_parser()
|
||||||
|
parser.add_argument("-s", "--synthetic", action="store_true")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--dataset_name",
|
"--dataset_name",
|
||||||
type=str,
|
type=str,
|
||||||
|
@ -231,6 +232,7 @@ def parse_args():
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Sanity checks
|
# Sanity checks
|
||||||
|
if not args.synthetic:
|
||||||
if args.dataset_name is None and args.train_file is None and args.validation_file is None:
|
if args.dataset_name is None and args.train_file is None and args.validation_file is None:
|
||||||
raise ValueError("Need either a dataset name or a training/validation file.")
|
raise ValueError("Need either a dataset name or a training/validation file.")
|
||||||
else:
|
else:
|
||||||
|
@ -255,6 +257,34 @@ def colo_memory_cap(size_in_GB):
|
||||||
print("Using {} GB of GPU memory".format(size_in_GB))
|
print("Using {} GB of GPU memory".format(size_in_GB))
|
||||||
|
|
||||||
|
|
||||||
|
class DummyDataloader:
|
||||||
|
|
||||||
|
def __init__(self, length, batch_size, seq_len, vocab_size):
|
||||||
|
self.length = length
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.seq_len = seq_len
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
|
||||||
|
def generate(self):
|
||||||
|
input_ids = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len), device=get_current_device())
|
||||||
|
attention_mask = torch.ones_like(input_ids)
|
||||||
|
return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": input_ids}
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
self.step = 0
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
if self.step < self.length:
|
||||||
|
self.step += 1
|
||||||
|
return self.generate()
|
||||||
|
else:
|
||||||
|
raise StopIteration
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.length
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
disable_existing_loggers()
|
disable_existing_loggers()
|
||||||
|
@ -292,6 +322,7 @@ def main():
|
||||||
# In distributed training, the load_dataset function guarantee that only one local process can concurrently
|
# In distributed training, the load_dataset function guarantee that only one local process can concurrently
|
||||||
# download the dataset.
|
# download the dataset.
|
||||||
logger.info("Start preparing dataset", ranks=[0])
|
logger.info("Start preparing dataset", ranks=[0])
|
||||||
|
if not args.synthetic:
|
||||||
if args.dataset_name is not None:
|
if args.dataset_name is not None:
|
||||||
# Downloading and loading a dataset from the hub.
|
# Downloading and loading a dataset from the hub.
|
||||||
raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name)
|
raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name)
|
||||||
|
@ -399,6 +430,7 @@ def main():
|
||||||
|
|
||||||
logger.info(f'{model.__class__.__name__} has been created', ranks=[0])
|
logger.info(f'{model.__class__.__name__} has been created', ranks=[0])
|
||||||
|
|
||||||
|
if not args.synthetic:
|
||||||
# Preprocessing the datasets.
|
# Preprocessing the datasets.
|
||||||
# First we tokenize all the texts.
|
# First we tokenize all the texts.
|
||||||
column_names = raw_datasets["train"].column_names
|
column_names = raw_datasets["train"].column_names
|
||||||
|
@ -447,6 +479,7 @@ def main():
|
||||||
result["labels"] = result["input_ids"].copy()
|
result["labels"] = result["input_ids"].copy()
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
if not args.synthetic:
|
||||||
# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
|
# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
|
||||||
# for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
|
# for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
|
||||||
# to preprocess.
|
# to preprocess.
|
||||||
|
@ -479,6 +512,11 @@ def main():
|
||||||
eval_dataloader = DataLoader(eval_dataset,
|
eval_dataloader = DataLoader(eval_dataset,
|
||||||
collate_fn=default_data_collator,
|
collate_fn=default_data_collator,
|
||||||
batch_size=args.per_device_eval_batch_size)
|
batch_size=args.per_device_eval_batch_size)
|
||||||
|
else:
|
||||||
|
train_dataloader = DummyDataloader(30, args.per_device_train_batch_size, config.max_position_embeddings,
|
||||||
|
config.vocab_size)
|
||||||
|
eval_dataloader = DummyDataloader(10, args.per_device_train_batch_size, config.max_position_embeddings,
|
||||||
|
config.vocab_size)
|
||||||
logger.info("Dataloaders have been created", ranks=[0])
|
logger.info("Dataloaders have been created", ranks=[0])
|
||||||
|
|
||||||
# Optimizer
|
# Optimizer
|
||||||
|
@ -521,9 +559,11 @@ def main():
|
||||||
|
|
||||||
# Train!
|
# Train!
|
||||||
total_batch_size = args.per_device_train_batch_size * gpc.get_world_size(ParallelMode.DATA)
|
total_batch_size = args.per_device_train_batch_size * gpc.get_world_size(ParallelMode.DATA)
|
||||||
|
num_train_samples = len(train_dataset) if not args.synthetic else 30 * total_batch_size
|
||||||
|
num_eval_samples = len(eval_dataset) if not args.synthetic else 10 * total_batch_size
|
||||||
|
|
||||||
logger.info("***** Running training *****", ranks=[0])
|
logger.info("***** Running training *****", ranks=[0])
|
||||||
logger.info(f" Num examples = {len(train_dataset)}", ranks=[0])
|
logger.info(f" Num examples = {num_train_samples}", ranks=[0])
|
||||||
logger.info(f" Num Epochs = {args.num_train_epochs}", ranks=[0])
|
logger.info(f" Num Epochs = {args.num_train_epochs}", ranks=[0])
|
||||||
logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}", ranks=[0])
|
logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}", ranks=[0])
|
||||||
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}", ranks=[0])
|
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}", ranks=[0])
|
||||||
|
@ -572,7 +612,7 @@ def main():
|
||||||
losses.append(loss)
|
losses.append(loss)
|
||||||
|
|
||||||
losses = torch.cat(losses)
|
losses = torch.cat(losses)
|
||||||
losses = losses[:len(eval_dataset)]
|
losses = losses[:num_eval_samples]
|
||||||
try:
|
try:
|
||||||
eval_loss = torch.mean(losses)
|
eval_loss = torch.mean(losses)
|
||||||
perplexity = math.exp(eval_loss)
|
perplexity = math.exp(eval_loss)
|
||||||
|
|
|
@ -0,0 +1,21 @@
|
||||||
|
set -x
|
||||||
|
export BS=${1:-16}
|
||||||
|
export MEMCAP=${2:-0}
|
||||||
|
export MODEL=${3:-"125m"}
|
||||||
|
export GPUNUM=${4:-1}
|
||||||
|
|
||||||
|
# make directory for logs
|
||||||
|
mkdir -p ./logs
|
||||||
|
|
||||||
|
export MODLE_PATH="facebook/opt-${MODEL}"
|
||||||
|
|
||||||
|
# HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1
|
||||||
|
torchrun \
|
||||||
|
--nproc_per_node ${GPUNUM} \
|
||||||
|
--master_port 19198 \
|
||||||
|
run_clm.py \
|
||||||
|
-s \
|
||||||
|
--output_dir $PWD \
|
||||||
|
--mem_cap ${MEMCAP} \
|
||||||
|
--model_name_or_path ${MODLE_PATH} \
|
||||||
|
--per_device_train_batch_size ${BS} 2>&1 | tee ./logs/colo_${MODEL}_bs_${BS}_cap_${MEMCAP}_gpu_${GPUNUM}.log
|
Loading…
Reference in New Issue