import time
from copy import deepcopy
from functools import partial
from typing import Callable, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torchvision.models as tm
from transformers import GPT2Config, GPT2LMHeadModel

from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
from colossalai.fx import metainfo_trace


def bench(gm: torch.fx.GraphModule,
          criterion: torch.nn.Module,
          data_gen: Callable,
          num_steps: int = 5) -> Tuple[float, float]:
    """Benchmark a given graph module.

    Args:
        gm (torch.fx.GraphModule): The graph module to benchmark.
        criterion (torch.nn.Module): Loss function.
        data_gen (Callable): Data generator returning ``(args, label)``.
        num_steps (int, optional): Number of test steps. Defaults to 5.

    Returns:
        Tuple[float, float]: Peak memory in MB and step time in ms.
    """
    gm.train()
    gm.cuda()
    step_time = float('inf')

    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    # Memory already held before any step runs (parameters, buffers); subtracted
    # below so that peak_mem reflects only the memory added by training steps.
    cached = torch.cuda.max_memory_allocated(device="cuda")
    args = label = output = loss = None
    try:
        for _ in range(num_steps):
            args, label = data_gen()

            torch.cuda.synchronize(device="cuda")
            start = time.time()
            output = gm(*args)
            loss = criterion(output, label)
            loss.backward()
            torch.cuda.synchronize(device="cuda")
            step_time = min(step_time, time.time() - start)

            # Drop gradients between steps so they do not accumulate into the peak.
            for param in gm.parameters():
                param.grad = None
            args = label = output = loss = None
    except Exception:    # typically a CUDA OOM when the memory budget is too tight
        args = label = output = loss = None
    gm.to("cpu")
    torch.cuda.empty_cache()
    peak_mem = (torch.cuda.max_memory_allocated(device="cuda") - cached) / 1024**2
    return peak_mem, step_time * 1.0e3
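# Usage sketch for `bench` (a minimal example, assuming a CUDA device and that
# `symbolic_trace` is available from `colossalai.fx`; the batch size is illustrative):
#
#     from colossalai.fx import symbolic_trace
#     gm = symbolic_trace(tm.resnet18())
#     peak_mem, step_time = bench(gm,
#                                 nn.CrossEntropyLoss(),
#                                 partial(data_gen_resnet, batch_size=8, shape=(3, 224, 224)))
#     print(f'peak memory: {peak_mem:.2f} MB, step time: {step_time:.2f} ms')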


def bench_rotor(gm: torch.fx.GraphModule,
                criterion: torch.nn.Module,
                data_gen: Callable,
                num_steps: int = 5,
                sample_points: int = 20,
                free_memory: Optional[int] = None,
                start_factor: int = 4) -> Tuple[np.ndarray, list, list]:
    """Benchmark the Rotor auto-checkpoint algorithm.

    Benchmarks the Rotor checkpoint solver for a given graph module and data
    over a range of memory budgets.

    Args:
        gm (torch.fx.GraphModule): The graph module to benchmark.
        criterion (torch.nn.Module): Loss function.
        data_gen (Callable): Data generator.
        num_steps (int, optional): Number of test steps. Defaults to 5.
        sample_points (int, optional): Number of sample points. Defaults to 20.
        free_memory (int, optional): Max memory budget in bytes. Defaults to
            ``torch.cuda.mem_get_info()[0]``, i.e. the currently free device memory.
        start_factor (int, optional): Start memory budget factor; the first budget
            tried is ``free_memory / start_factor``. Defaults to 4.

    Returns:
        Tuple[np.ndarray, list, list]: Budget vector (MB), peak memory vector (MB),
        and step time vector (ms).
    """
    if free_memory is None:
        # Query the free device memory lazily so that importing this module
        # does not require an initialized CUDA context.
        free_memory = torch.cuda.mem_get_info()[0]
    peak_hist, step_hist = [], []
    raw_graph = deepcopy(gm.graph)
    budgets = np.linspace(free_memory // start_factor, free_memory, sample_points)
    for budget in budgets:
        # Re-trace to refresh meta information before solving under the new budget.
        gm = metainfo_trace(gm, *data_gen()[0])
        solver = CheckpointSolverRotor(gm.graph, free_memory=budget)
        try:
            gm.graph = solver.solve(verbose=False)
            peak_memory, step_time = bench(gm, criterion, data_gen, num_steps=num_steps)
        except Exception:    # solver infeasible or OOM under this budget
            peak_memory, step_time = budget / 1024**2, float('inf')
        peak_hist.append(peak_memory)
        step_hist.append(step_time)
        # Restore the untouched graph before the next trial.
        gm.graph = deepcopy(raw_graph)
    return budgets / 1024**2, peak_hist, step_hist
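# Usage sketch for `bench_rotor` (same assumptions as the `bench` sketch above;
# `sample_points=10` trades benchmark time for budget resolution):
#
#     from colossalai.fx import symbolic_trace
#     gm = symbolic_trace(tm.resnet50())
#     budgets, peaks, times = bench_rotor(gm,
#                                         nn.CrossEntropyLoss(),
#                                         partial(data_gen_resnet, batch_size=128, shape=(3, 224, 224)),
#                                         sample_points=10)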


class GPTLMModel(nn.Module):
    """GPT-2 language model wrapper that returns only the LM logits."""

    def __init__(self,
                 hidden_size=768,
                 num_layers=12,
                 num_attention_heads=12,
                 max_seq_len=1024,
                 vocab_size=50257,
                 checkpoint=False):
        super().__init__()
        self.checkpoint = checkpoint
        self.model = GPT2LMHeadModel(
            GPT2Config(n_embd=hidden_size,
                       n_layer=num_layers,
                       n_head=num_attention_heads,
                       n_positions=max_seq_len,
                       n_ctx=max_seq_len,
                       vocab_size=vocab_size))
        if checkpoint:
            self.model.gradient_checkpointing_enable()

    def forward(self, input_ids, attention_mask):
        # Only return lm_logits; `use_cache` must be off when gradient
        # checkpointing is enabled, so tie it to the checkpoint flag.
        return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]
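# Shape sketch (illustrative sizes, runs on CPU): for input_ids of shape (B, S)
# the wrapper returns logits of shape (B, S, vocab_size):
#
#     model = GPTLMModel(hidden_size=128, num_layers=2, num_attention_heads=2)
#     ids = torch.randint(0, 50257, (2, 16))
#     logits = model(ids, torch.ones_like(ids))    # torch.Size([2, 16, 50257])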


class GPTLMLoss(nn.Module):
    """Autoregressive cross-entropy loss for next-token prediction."""

    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        # Shift so that tokens < n predict token n.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
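# Usage sketch: `GPTLMLoss` pairs with the logits returned by `GPTLMModel`;
# position i of the shifted logits is scored against token i + 1, i.e. standard
# next-token prediction:
#
#     loss_fn = GPTLMLoss()
#     logits = torch.randn(2, 16, 50257)
#     labels = torch.randint(0, 50257, (2, 16))
#     loss = loss_fn(logits, labels)    # scalar tensor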


def gpt2_medium(checkpoint=False):
    return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint)


def gpt2_xl(checkpoint=False):
    return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32, checkpoint=checkpoint)


def gpt2_6b(checkpoint=False):
    return GPTLMModel(hidden_size=4096, num_layers=30, num_attention_heads=16, checkpoint=checkpoint)
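# The builder names reflect the approximate decoder parameter count,
# params ≈ 12 * num_layers * hidden_size**2 (embeddings excluded):
# e.g. gpt2_6b gives 12 * 30 * 4096**2 ≈ 6.0e9 parameters.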


def data_gen_gpt2(batch_size, seq_len, vocab_size, device='cuda:0'):
    """Generate random data for GPT-2 benchmarking."""
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
    attention_mask = torch.ones_like(input_ids, device=device)
    # `GPTLMLoss` shifts labels internally, so the input ids serve as labels
    # for next-token prediction.
    return (input_ids, attention_mask), input_ids
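# Example: `partial(data_gen_gpt2, batch_size=4, seq_len=1024, vocab_size=50257)`
# yields `(input_ids, attention_mask), labels` with every tensor of shape (4, 1024),
# matching the `data_gen` contract expected by `bench` and `bench_rotor`.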


def data_gen_resnet(batch_size, shape, device='cuda:0'):
    """Generate random data for ResNet benchmarking."""
    # Uninitialized values are fine here: only shapes and dtypes matter for the benchmark.
    data = torch.empty(batch_size, *shape, device=device)
    label = torch.empty(batch_size, dtype=torch.long, device=device).random_(1000)
    return (data,), label
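

# A minimal end-to-end sketch (assumes a CUDA device; `symbolic_trace` from
# `colossalai.fx` and the ResNet-18/batch-size choices are illustrative):
if __name__ == '__main__':
    from colossalai.fx import symbolic_trace

    gm = symbolic_trace(tm.resnet18())
    data_gen = partial(data_gen_resnet, batch_size=8, shape=(3, 224, 224))
    budgets, peaks, times = bench_rotor(gm, nn.CrossEntropyLoss(), data_gen, sample_points=5)
    for budget, peak, step in zip(budgets, peaks, times):
        print(f'budget {budget:.0f} MB -> peak {peak:.2f} MB, step {step:.2f} ms')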