""" Shardformer Benchmark """ import torch import torch.distributed as dist import transformers import triton import colossalai from colossalai.shardformer import ShardConfig, ShardFormer def data_gen(batch_size, seq_length): input_ids = torch.randint(0, seq_length, (batch_size, seq_length), dtype=torch.long) attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long) return dict(input_ids=input_ids, attention_mask=attention_mask) def data_gen_for_sequence_classification(batch_size, seq_length): # LM data gen # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` data = data_gen(batch_size, seq_length) data['labels'] = torch.ones((batch_size), dtype=torch.long) return data MODEL_CONFIG = transformers.LlamaConfig(num_hidden_layers=4, hidden_size=128, intermediate_size=256, num_attention_heads=4, max_position_embeddings=128, num_labels=16) BATCH, N_HEADS, N_CTX, D_HEAD = 4, 8, 4096, 64 model_func = lambda: transformers.LlamaForSequenceClassification(MODEL_CONFIG) # vary seq length for fixed head and batch=4 configs = [ triton.testing.Benchmark(x_names=['N_CTX'], x_vals=[2**i for i in range(8, 13)], line_arg='provider', line_vals=['org_model', 'shard_model'], line_names=['org_model', 'shard_model'], styles=[('red', '-'), ('blue', '-')], ylabel='ms', plot_name=f'lama_for_sequence_classification-batch-{BATCH}', args={ 'BATCH': BATCH, 'dtype': torch.float16, 'model_func': model_func }) ] def train(model, data): output = model(**data) loss = output.logits.mean() loss.backward() @triton.testing.perf_report(configs) def bench_shardformer(BATCH, N_CTX, provider, model_func, dtype=torch.float32, device="cuda"): warmup = 10 rep = 100 # prepare data data = data_gen_for_sequence_classification(BATCH, N_CTX) data = {k: v.cuda() for k, v in data.items()} model = model_func().to(device) model.train() if provider == "org_model": fn = lambda: train(model, data) ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) return ms if provider == "shard_model": shard_config = ShardConfig(enable_fused_normalization=True, enable_tensor_parallelism=True) shard_former = ShardFormer(shard_config=shard_config) sharded_model = shard_former.optimize(model).cuda() fn = lambda: train(sharded_model, data) ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) return ms # start benchmark, command: # torchrun --standalone --nproc_per_node=2 performance_benchmark.py if __name__ == "__main__": colossalai.launch_from_torch({}) bench_shardformer.run(save_path='.', print_data=dist.get_rank() == 0)