import time
from copy import deepcopy
from functools import partial
from typing import Callable, Tuple

import numpy as np
import torch
import torch.nn as nn
import torchvision.models as tm
from transformers import GPT2Config, GPT2LMHeadModel

from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
from colossalai.fx import metainfo_trace


def bench(gm: torch.fx.GraphModule,
          criterion: torch.nn.Module,
          data_gen: Callable,
          num_steps: int = 5) -> Tuple[float, float]:
    """Benchmark a given graph module.

    Args:
        gm (torch.fx.GraphModule): The graph module to benchmark.
        criterion (torch.nn.Module): Loss function.
        data_gen (Callable): Data generator.
        num_steps (int, optional): Number of test steps. Defaults to 5.

    Returns:
        Tuple[float, float]: Peak memory in MB and step time in ms.
    """
    gm.train()
    gm.cuda()
    step_time = float('inf')
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    cached = torch.cuda.max_memory_allocated(device="cuda")
    try:
        for _ in range(num_steps):
            args, label = data_gen()
            output, loss = None, None

            # Time one forward + backward pass, bracketed by device syncs so
            # the host-side clock reflects the actual GPU work.
            torch.cuda.synchronize(device="cuda")
            start = time.time()
            output = gm(*args)
            loss = criterion(output, label)
            loss.backward()
            torch.cuda.synchronize(device="cuda")
            step_time = min(step_time, time.time() - start)

            # Drop gradients between steps so they do not pile up in the
            # peak-memory measurement.
            for child in gm.children():
                for param in child.parameters():
                    param.grad = None
            del args, label, output, loss
    except RuntimeError:
        # Typically a CUDA out-of-memory error; drop any live references so
        # the allocator can reclaim them before the peak stats are read.
        # (Assignment instead of `del` avoids a NameError when the failure
        # happens before the names are bound.)
        args = label = output = loss = None

    gm.to("cpu")
    torch.cuda.empty_cache()
    peak_mem = (torch.cuda.max_memory_allocated(device="cuda") - cached) / 1024**2
    return peak_mem, step_time * 1.0e3
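
# Example usage (illustrative sketch, not part of the original file): timing a
# plain torch.fx trace of ResNet-50 with the random-data generator defined
# further down. Assumes a CUDA device is available.
#
#   import torch.fx
#   gm = torch.fx.symbolic_trace(tm.resnet50())
#   data_gen = partial(data_gen_resnet, 8, (3, 224, 224))
#   peak_mem, step_time = bench(gm, nn.CrossEntropyLoss(), data_gen)
#   print(f'peak memory: {peak_mem:.1f} MB, step time: {step_time:.2f} ms')

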
def bench_rotor(gm: torch.fx.GraphModule,
                criterion: torch.nn.Module,
                data_gen: Callable,
                num_steps: int = 5,
                sample_points: int = 20,
                free_memory: int = torch.cuda.mem_get_info()[0],
                start_factor: int = 4) -> Tuple[np.ndarray, list, list]:
    """Benchmark the auto-checkpoint rotor algorithm.

    Sweeps a range of memory budgets, runs the rotor checkpoint solver on the
    given graph module for each budget, and benchmarks the resulting schedule.

    Args:
        gm (torch.fx.GraphModule): The graph module to benchmark.
        criterion (torch.nn.Module): Loss function.
        data_gen (Callable): Data generator.
        num_steps (int, optional): Number of test steps. Defaults to 5.
        sample_points (int, optional): Number of sample points. Defaults to 20.
        free_memory (int, optional): Maximum memory budget in bytes. Defaults to
            torch.cuda.mem_get_info()[0], which is evaluated once at import time.
        start_factor (int, optional): Start memory budget factor; the sweep begins
            at free_memory / start_factor. Defaults to 4.

    Returns:
        Tuple[np.ndarray, list, list]: Budget vector (MB), peak memory vector (MB),
        and step time vector (ms).
    """
    peak_hist, step_hist = [], []
    raw_graph = deepcopy(gm.graph)
    for budget in np.linspace(free_memory // start_factor, free_memory, sample_points):
        # Re-annotate the graph with tensor meta information, then let the
        # rotor solver place activation checkpoints for this budget.
        gm = metainfo_trace(gm, *data_gen()[0])
        solver = CheckpointSolverRotor(gm.graph, free_memory=budget)
        try:
            gm.graph = solver.solve(verbose=False)
            peak_memory, step_time = bench(gm, criterion, data_gen, num_steps=num_steps)
        except Exception:
            # The solver may fail to find a feasible schedule under a tight
            # budget; record the budget itself and an infinite step time.
            peak_memory, step_time = budget / 1024**2, float('inf')
        peak_hist.append(peak_memory)
        step_hist.append(step_time)
        # Restore the unmodified graph before trying the next budget.
        gm.graph = deepcopy(raw_graph)
    return np.linspace(free_memory // start_factor, free_memory, sample_points) / 1024**2, peak_hist, step_hist
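
# Example usage (hedged sketch, not part of the original file): sweeping
# checkpoint budgets for GPT-2 medium. `symbolic_trace` here is assumed to be
# ColossalAI's tracer (colossalai.fx.symbolic_trace), which, unlike vanilla
# torch.fx, is meant to handle the HuggingFace GPT-2 forward; verify the API
# against your ColossalAI version.
#
#   from colossalai.fx import symbolic_trace
#   model = gpt2_medium().cuda()
#   data_gen = partial(data_gen_gpt2, 8, 1024, 50257)
#   gm = symbolic_trace(model)
#   budgets, peaks, steps = bench_rotor(gm, GPTLMLoss(), data_gen, sample_points=10)

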
class GPTLMModel(nn.Module):
    """GPT-2 language model wrapper around HuggingFace's GPT2LMHeadModel."""

    def __init__(self,
                 hidden_size=768,
                 num_layers=12,
                 num_attention_heads=12,
                 max_seq_len=1024,
                 vocab_size=50257,
                 checkpoint=False):
        super().__init__()
        self.checkpoint = checkpoint
        self.model = GPT2LMHeadModel(
            GPT2Config(n_embd=hidden_size,
                       n_layer=num_layers,
                       n_head=num_attention_heads,
                       n_positions=max_seq_len,
                       n_ctx=max_seq_len,
                       vocab_size=vocab_size))
        if checkpoint:
            self.model.gradient_checkpointing_enable()

    def forward(self, input_ids, attention_mask):
        # Only return lm_logits; use_cache is incompatible with gradient
        # checkpointing, so it is turned off when checkpointing is enabled.
        return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]


class GPTLMLoss(nn.Module):
    """Causal language modeling loss for GPT: cross-entropy between each
    position's prediction and the next token."""

    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        # Shift so that the logits at position i are scored against token i + 1.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
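
# Worked example (illustrative only, not part of the original file): for a
# length-4 sequence [t0, t1, t2, t3], `shift_logits` keeps the predictions
# emitted at positions 0..2 and `shift_labels` keeps tokens t1..t3, so each
# position is scored on predicting its successor.
#
#   logits = torch.randn(1, 4, 50257)            # (batch, seq_len, vocab)
#   labels = torch.randint(0, 50257, (1, 4))     # token ids
#   loss = GPTLMLoss()(logits, labels)           # CE over 3 shifted positions

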
def gpt2_medium(checkpoint=False):
    return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint)


def gpt2_xl(checkpoint=False):
    return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32, checkpoint=checkpoint)


def gpt2_6b(checkpoint=False):
    return GPTLMModel(hidden_size=4096, num_layers=30, num_attention_heads=16, checkpoint=checkpoint)
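
# Back-of-envelope sizes (12 * n_layer * n_embd^2 for the transformer blocks,
# plus roughly vocab_size * n_embd for embeddings; approximate, not exact):
# gpt2_medium ~0.35B parameters, gpt2_xl ~1.6B, gpt2_6b ~6.3B.

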
def data_gen_gpt2(batch_size, seq_len, vocab_size, device='cuda:0'):
    """Generate random data for GPT-2 benchmarking."""
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
    attention_mask = torch.ones_like(input_ids, device=device)
    # The all-ones mask doubles as a placeholder label tensor: the benchmark
    # only measures speed and memory, so the loss value itself is irrelevant.
    return (input_ids, attention_mask), attention_mask


def data_gen_resnet(batch_size, shape, device='cuda:0'):
    """Generate random data for ResNet benchmarking."""
    data = torch.empty(batch_size, *shape, device=device)
    label = torch.empty(batch_size, dtype=torch.long, device=device).random_(1000)
    return (data,), label
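
# End-to-end sketch (hedged, not part of the original file): how these pieces
# might be driven together for a ResNet budget sweep. Assumes a CUDA device
# and that torch.fx.symbolic_trace can trace the torchvision model.
#
#   if __name__ == '__main__':
#       import torch.fx
#       gm = torch.fx.symbolic_trace(tm.resnet152())
#       data_gen = partial(data_gen_resnet, 128, (3, 224, 224))
#       budgets, peaks, steps = bench_rotor(gm, nn.CrossEntropyLoss(), data_gen)
#       for b, p, t in zip(budgets, peaks, steps):
#           print(f'budget {b:8.1f} MB -> peak {p:8.1f} MB, step {t:8.2f} ms')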