import time

import torch
import transformers
from args import parse_benchmark_args
from tqdm import tqdm
from transformers import ViTConfig, ViTForImageClassification

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
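
# A usage sketch, not part of the original script: the exact CLI flags are defined
# in args.parse_benchmark_args, and the script/model names below are illustrative
# assumptions only.
#
#   torchrun --standalone --nproc_per_node 4 vit_benchmark.py \
#       --model_name_or_path google/vit-base-patch16-224 \
#       --plugin gemini --batch_size 8 --max_train_steps 20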


def format_num(num: int, bytes=False):
    """Scale a number to a human-readable string, e.g. 1253656 bytes => '1.20 MB'."""
    factor = 1024 if bytes else 1000
    suffix = "B" if bytes else ""
    for unit in ["", " K", " M", " G", " T", " P"]:
        if num < factor:
            return f"{num:.2f}{unit}{suffix}"
        num /= factor
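
# For example: format_num(1253656, bytes=True) returns '1.20 MB', and
# format_num(123456789) returns '123.46 M'.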


def get_data_batch(batch_size, num_labels, num_channels=3, height=224, width=224):
    # Random tensors stand in for real images, so the benchmark has no dataloader overhead.
    pixel_values = torch.randn(batch_size,
                               num_channels,
                               height,
                               width,
                               device=torch.cuda.current_device(),
                               dtype=torch.float)
    labels = torch.randint(0, num_labels, (batch_size,), device=torch.cuda.current_device(), dtype=torch.int64)
    return dict(pixel_values=pixel_values, labels=labels)


def colo_memory_cap(size_in_GB):
    from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device
    cuda_capacity = colo_device_memory_capacity(get_current_device())
    if size_in_GB * (1024**3) < cuda_capacity:
        colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity)
        print(f"Limiting GPU memory usage to {size_in_GB} GB")
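
# Capping the per-process memory fraction makes a large GPU behave like a smaller
# one, which is useful when benchmarking the offloading plugins.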


def main():
    args = parse_benchmark_args()

    # Launch ColossalAI
    colossalai.launch_from_torch(config={}, seed=args.seed)
    coordinator = DistCoordinator()
    world_size = coordinator.world_size
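
    # launch_from_torch reads rank and world size from the environment variables
    # set by torchrun, so no explicit host/port wiring is needed here.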

    # Manage loggers
    disable_existing_loggers()
    logger = get_dist_logger()
    if coordinator.is_master():
        transformers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()

    # Optionally cap usable GPU memory
    if args.mem_cap > 0:
        colo_memory_cap(args.mem_cap)

    # Build ViT model
    config = ViTConfig.from_pretrained(args.model_name_or_path)
    model = ViTForImageClassification(config)
    logger.info(f"Finished loading model from {args.model_name_or_path}", ranks=[0])
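
    # Note: only the config is fetched from args.model_name_or_path; the model's
    # weights are randomly initialized, which is sufficient for speed benchmarking.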

    # Enable gradient checkpointing
    if args.grad_checkpoint:
        model.gradient_checkpointing_enable()

    # Set plugin
    booster_kwargs = {}
    if args.plugin == 'torch_ddp_fp16':
        booster_kwargs['mixed_precision'] = 'fp16'
    if args.plugin.startswith('torch_ddp'):
        plugin = TorchDDPPlugin()
    elif args.plugin == 'gemini':
        plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
    elif args.plugin == 'low_level_zero':
        plugin = LowLevelZeroPlugin(initial_scale=2**5)
    elif args.plugin == 'hybrid_parallel':
        plugin = HybridParallelPlugin(tp_size=2,
                                      pp_size=2,
                                      num_microbatches=None,
                                      microbatch_size=1,
                                      enable_all_optimization=True,
                                      precision='fp16',
                                      initial_scale=1)
    logger.info(f"Set plugin as {args.plugin}", ranks=[0])
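
    # Note: hybrid_parallel as configured (tp_size=2, pp_size=2) requires the
    # launch world size to be a multiple of 4; adjust the sizes to match your cluster.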

    # Set optimizer
    optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size))

    # Set criterion (loss function)
    def criterion(outputs, inputs):
        return outputs.loss
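
    # The HuggingFace ViT model computes its loss internally when labels are
    # supplied, so the criterion just unwraps outputs.loss; the unused `inputs`
    # argument keeps the (outputs, inputs) signature the pipeline schedule expects.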

    # Set booster
    booster = Booster(plugin=plugin, **booster_kwargs)
    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion=criterion)
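
    # boost returns plugin-wrapped versions of the objects it is given; the two
    # discarded values are the dataloader and lr_scheduler slots, unused here.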

    # Start training.
    logger.info("Start testing", ranks=[0])

    torch.cuda.synchronize()
    model.train()
    start_time = time.time()

    with tqdm(range(args.max_train_steps), desc="Training Step", disable=not coordinator.is_master()) as pbar:
        for _ in pbar:
            optimizer.zero_grad()
            batch = get_data_batch(args.batch_size, args.num_labels, 3, 224, 224)

            if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None:
                # Run pipeline forward and backward; execute_pipeline handles the
                # backward pass internally, so no explicit booster.backward is needed.
                batch = iter([batch])
                outputs = booster.execute_pipeline(batch,
                                                   model,
                                                   criterion,
                                                   optimizer,
                                                   return_loss=True,
                                                   return_outputs=True)
            else:
                outputs = model(**batch)
                loss = criterion(outputs, None)
                # Backward
                booster.backward(loss, optimizer)

            optimizer.step()

    torch.cuda.synchronize()

    # Compute Statistics
    end_time = time.time()
    throughput = "{:.4f}".format((world_size * args.max_train_steps * args.batch_size) / (end_time - start_time))
    max_mem = format_num(torch.cuda.max_memory_allocated(device=torch.cuda.current_device()), bytes=True)
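
    # Throughput is samples/sec aggregated across ranks; the timer spans every
    # step, including warm-up, so very short runs understate steady-state speed.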

    logger.info(
        f"Testing finished, "
        f"batch size per gpu: {args.batch_size}, "
        f"plugin: {args.plugin}, "
        f"throughput: {throughput}, "
        f"maximum memory usage per gpu: {max_mem}.",
        ranks=[0])

    torch.cuda.empty_cache()


if __name__ == "__main__":
    main()