2023-06-08 03:27:05 +00:00
|
|
|
import time
|
|
|
|
|
|
|
|
import torch
|
2023-08-24 01:29:25 +00:00
|
|
|
import tqdm
|
2023-06-08 03:27:05 +00:00
|
|
|
import transformers
|
2023-08-24 01:29:25 +00:00
|
|
|
from args import parse_benchmark_args
|
2023-06-08 03:27:05 +00:00
|
|
|
from transformers import AutoConfig, OPTForCausalLM
|
|
|
|
from transformers.utils.versions import require_version
|
|
|
|
|
|
|
|
import colossalai
|
|
|
|
from colossalai.booster import Booster
|
|
|
|
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
|
|
|
|
from colossalai.cluster import DistCoordinator
|
2023-08-24 01:29:25 +00:00
|
|
|
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
|
|
|
from colossalai.nn.optimizer import HybridAdam
|
2023-06-08 03:27:05 +00:00
|
|
|
|
|
|
|
# Check up front that the installed transformers is >=4.20.0; the second argument is
# the hint reported if the requirement is not met.
require_version("transformers>=4.20.0", "To fix: pip install -r requirements.txt")
|
|
|
|
|
|
|
|
|
|
|
|
def format_num(num: int, bytes=False):
    """Format ``num`` with a metric (or binary, if ``bytes``) unit prefix.

    Examples: ``format_num(1253656, bytes=True) == '1.20 MB'`` and
    ``format_num(1500) == '1.50 K'``.

    Args:
        num: the value to format.
        bytes: if True, scale by 1024 and append ``"B"``; otherwise scale by 1000.

    Returns:
        The scaled value with two decimals followed by its unit. Values beyond the
        largest unit (peta) are expressed in that unit rather than returning ``None``.
    """
    factor = 1024 if bytes else 1000
    suffix = "B" if bytes else ""
    units = ["", " K", " M", " G", " T", " P"]
    for unit in units[:-1]:
        if num < factor:
            return f"{num:.2f}{unit}{suffix}"
        num /= factor
    # Largest unit: no further scaling, so huge values never fall through to None.
    return f"{num:.2f}{units[-1]}{suffix}"
|
|
|
|
|
|
|
|
|
|
|
|
def get_data(batch_size, seq_len, vocab_size, device=None):
    """Build a random language-modelling batch.

    Args:
        batch_size: number of sequences in the batch.
        seq_len: number of tokens per sequence.
        vocab_size: token ids are drawn uniformly from ``[0, vocab_size)``.
        device: where to allocate the tensors. Defaults to the current CUDA device,
            which matches the previous behaviour.

    Returns:
        ``(input_ids, attention_mask)``, both of shape ``(batch_size, seq_len)``.
        The mask is all ones, i.e. no padding.
    """
    if device is None:
        device = torch.cuda.current_device()
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
    attention_mask = torch.ones_like(input_ids)
    return input_ids, attention_mask
|
|
|
|
|
|
|
|
|
|
|
|
def colo_memory_cap(size_in_GB):
    """Restrict this process's GPU memory to ``size_in_GB`` GiB when that is below capacity.

    The cap is applied as a fraction of the current device's total memory. If the
    requested size is not smaller than the device capacity, nothing is changed and
    nothing is printed.
    """
    from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device

    requested_bytes = size_in_GB * (1024**3)
    device_bytes = colo_device_memory_capacity(get_current_device())
    if not (requested_bytes < device_bytes):
        # Requested cap is at or above what the device offers: leave it untouched.
        return

    colo_set_process_memory_fraction(requested_bytes / device_bytes)
    print(f"Limiting GPU memory usage to {size_in_GB} GB")
|
|
|
|
|
|
|
|
|
|
|
|
def _build_plugin(plugin_name):
    """Create the booster plugin named ``plugin_name`` plus extra ``Booster`` kwargs.

    Returns:
        ``(plugin, booster_kwargs)``. ``booster_kwargs`` holds
        ``mixed_precision="fp16"`` for ``torch_ddp_fp16`` and is empty otherwise.

    Raises:
        ValueError: if ``plugin_name`` is not a supported plugin.
    """
    booster_kwargs = {}
    if plugin_name == "torch_ddp_fp16":
        booster_kwargs["mixed_precision"] = "fp16"
    if plugin_name.startswith("torch_ddp"):
        plugin = TorchDDPPlugin()
    elif plugin_name == "gemini":
        plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
    elif plugin_name == "low_level_zero":
        plugin = LowLevelZeroPlugin(initial_scale=2**5)
    else:
        # Without this branch an unknown name left `plugin` unbound and the script
        # crashed later with an UnboundLocalError.
        raise ValueError(f"Unsupported plugin: {plugin_name!r}")
    return plugin, booster_kwargs


def main():
    """Benchmark OPT causal-LM training throughput and peak memory under ColossalAI.

    Builds a randomly initialised ``OPTForCausalLM`` from the config at
    ``args.model_name_or_path``. It then trains on random token ids for
    ``args.max_train_steps`` steps with the selected booster plugin. Finally it logs,
    on rank 0, throughput (samples/s summed over all ranks) and the peak CUDA
    memory allocated on this rank.
    """
    args = parse_benchmark_args()

    # Launch ColossalAI
    colossalai.launch_from_torch(config={}, seed=args.seed)
    coordinator = DistCoordinator()
    world_size = coordinator.world_size

    # Manage loggers: only the master rank emits transformers info-level messages.
    disable_existing_loggers()
    logger = get_dist_logger()
    if coordinator.is_master():
        transformers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()

    # Whether to set limit of memory capacity
    if args.mem_cap > 0:
        colo_memory_cap(args.mem_cap)

    # Build OPT model from its config only: weights are randomly initialised, not loaded.
    config = AutoConfig.from_pretrained(args.model_name_or_path)
    model = OPTForCausalLM(config=config)
    logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])

    # Enable gradient checkpointing
    model.gradient_checkpointing_enable()

    # Set plugin
    plugin, booster_kwargs = _build_plugin(args.plugin)
    logger.info(f"Set plugin as {args.plugin}", ranks=[0])

    # Set optimizer
    optimizer = HybridAdam(model.parameters(), lr=args.learning_rate)

    # Set booster
    booster = Booster(plugin=plugin, **booster_kwargs)
    model, optimizer, _, _, _ = booster.boost(model, optimizer)

    SEQ_LEN = 1024
    # Draw token ids from the model's own vocabulary. A hard-coded size could exceed
    # the embedding table of a config with a smaller vocabulary.
    vocab_size = config.vocab_size

    # Start training.
    logger.info("Start testing", ranks=[0])
    progress_bar = tqdm.tqdm(total=args.max_train_steps, desc="Training Step", disable=not coordinator.is_master())

    torch.cuda.synchronize()
    model.train()
    # perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
    start_time = time.perf_counter()

    for _ in range(args.max_train_steps):
        input_ids, attn_mask = get_data(args.batch_size, SEQ_LEN, vocab_size)
        optimizer.zero_grad()
        outputs = model(input_ids=input_ids, attention_mask=attn_mask, labels=input_ids, use_cache=False)
        loss = outputs["loss"]
        booster.backward(loss, optimizer)
        optimizer.step()

        # Wait for queued GPU work so the measured time covers the whole step.
        torch.cuda.synchronize()
        progress_bar.update(1)
    progress_bar.close()

    # Compute Statistics
    elapsed = time.perf_counter() - start_time
    num_samples = world_size * args.max_train_steps * args.batch_size
    throughput = "{:.4f}".format(num_samples / elapsed if elapsed > 0 else 0.0)
    max_mem = format_num(torch.cuda.max_memory_allocated(device=torch.cuda.current_device()), bytes=True)

    logger.info(
        f"Testing finished, "
        f"batch size per gpu: {args.batch_size}, "
        f"plugin: {args.plugin}, "
        f"throughput: {throughput}, "
        f"maximum memory usage per gpu: {max_mem}.",
        ranks=[0],
    )
|
2023-06-08 03:27:05 +00:00
|
|
|
|
|
|
|
|
|
|
|
# Script entry point: run the benchmark only when this file is executed directly.
if __name__ == "__main__":
    main()
|