ColossalAI/colossalai/shardformer/examples/performance_benchmark.py

"""
Shardformer Benchmark
"""
import torch
import torch.distributed as dist
import transformers
import triton

import colossalai
from colossalai.shardformer import ShardConfig, ShardFormer


def data_gen(batch_size, seq_length):
    # token ids are sampled from [0, seq_length), which stays well below the
    # default LLaMA vocab size; the attention mask is all ones (no padding)
    input_ids = torch.randint(0, seq_length, (batch_size, seq_length), dtype=torch.long)
    attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long)
    return dict(input_ids=input_ids, attention_mask=attention_mask)


def data_gen_for_sequence_classification(batch_size, seq_length):
    # sequence-classification data gen: reuse the LM-style inputs and add a
    # single class label per sequence, so `labels` has shape (batch_size,)
    data = data_gen(batch_size, seq_length)
    data["labels"] = torch.ones((batch_size,), dtype=torch.long)
    return data
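
# For illustration only (not executed by the benchmark): with the defaults
# used below, data_gen_for_sequence_classification(4, 256) yields
#   input_ids:      LongTensor of shape (4, 256)
#   attention_mask: LongTensor of shape (4, 256), all ones
#   labels:         LongTensor of shape (4,)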

MODEL_CONFIG = transformers.LlamaConfig(
    num_hidden_layers=4,
    hidden_size=128,
    intermediate_size=256,
    num_attention_heads=4,
    max_position_embeddings=128,
    num_labels=16,
    pad_token_id=2,
)
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 8, 4096, 64
model_func = lambda: transformers.LlamaForSequenceClassification(MODEL_CONFIG)
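
# The config above is a deliberately tiny LLaMA (4 layers, hidden size 128),
# so the benchmark runs quickly on small GPUs; model_func builds a fresh,
# randomly initialized model for each provider/sequence-length combination.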

# vary the sequence length for a fixed head count and a batch size of 4
configs = [
    triton.testing.Benchmark(
        x_names=["N_CTX"],
        x_vals=[2**i for i in range(8, 13)],
        line_arg="provider",
        line_vals=["org_model", "shard_model"],
        line_names=["org_model", "shard_model"],
        styles=[("red", "-"), ("blue", "-")],
        ylabel="ms",
        plot_name=f"llama_for_sequence_classification-batch-{BATCH}",
        args={"BATCH": BATCH, "dtype": torch.float16, "model_func": model_func},
    )
]
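
# How the sweep works: `triton.testing.perf_report` calls the decorated
# function once per x_vals entry (N_CTX = 256, 512, ..., 4096) and per
# line_vals entry ("org_model" vs. "shard_model"), and plots the latency
# that each call returns against N_CTX.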


def train(model, data):
    # a single forward/backward step with a dummy scalar loss (the mean of
    # the logits), which is enough to exercise the full training graph
    output = model(**data)
    loss = output.logits.mean()
    loss.backward()


@triton.testing.perf_report(configs)
def bench_shardformer(BATCH, N_CTX, provider, model_func, dtype=torch.float32, device="cuda"):
    # `warmup` and `rep` are time budgets in ms for triton.testing.do_bench
    warmup = 10
    rep = 100
    # prepare data
    data = data_gen_for_sequence_classification(BATCH, N_CTX)
    data = {k: v.to(device) for k, v in data.items()}
    # move the model to the benchmark device and dtype (fp16 is supplied via `args`)
    model = model_func().to(device=device, dtype=dtype)
    model.train()
    if provider == "org_model":
        fn = lambda: train(model, data)
        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
        return ms
    if provider == "shard_model":
        shard_config = ShardConfig(enable_fused_normalization=True, enable_tensor_parallelism=True)
        shard_former = ShardFormer(shard_config=shard_config)
        sharded_model, _ = shard_former.optimize(model)
        sharded_model = sharded_model.to(device)
        fn = lambda: train(sharded_model, data)
        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
        return ms
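
# What ShardFormer changes (a brief sketch, not exhaustive): with
# enable_tensor_parallelism=True it partitions the model's linear layers and
# attention heads across the ranks launched by torchrun, so each GPU holds
# and computes only a slice of the model; enable_fused_normalization swaps
# the normalization layers for ColossalAI's fused kernel implementation.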

# to start the benchmark, run:
#   torchrun --standalone --nproc_per_node=2 performance_benchmark.py
if __name__ == "__main__":
    colossalai.launch_from_torch()
    bench_shardformer.run(save_path=".", print_data=dist.get_rank() == 0)
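# With save_path=".", perf_report should write its results (a CSV and a plot
# named after plot_name) into the current directory; the print_data flag
# restricts the printed table to rank 0 so the output is not duplicated
# across the torchrun processes.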