# mirror of https://github.com/hpcaitech/ColossalAI
"""
Shardformer Benchmark
"""
import torch
import torch.distributed as dist
import transformers
import triton

import colossalai
from colossalai.shardformer import ShardConfig, ShardFormer


def data_gen(batch_size, seq_length):
    input_ids = torch.randint(0, seq_length, (batch_size, seq_length), dtype=torch.long)
    attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long)
    return dict(input_ids=input_ids, attention_mask=attention_mask)


def data_gen_for_sequence_classification(batch_size, seq_length):
    # sequence-classification data gen: reuse the LM-style inputs from `data_gen`
    # and attach a single class label per sample (no padding is involved)
    data = data_gen(batch_size, seq_length)
    data['labels'] = torch.ones((batch_size,), dtype=torch.long)
    return data
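

# For example, data_gen_for_sequence_classification(4, 128) should yield
#   input_ids      -> LongTensor of shape (4, 128)
#   attention_mask -> LongTensor of shape (4, 128)
#   labels         -> LongTensor of shape (4,)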


MODEL_CONFIG = transformers.LlamaConfig(num_hidden_layers=4,
                                        hidden_size=128,
                                        intermediate_size=256,
                                        num_attention_heads=4,
                                        max_position_embeddings=128,
                                        num_labels=16)
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 8, 4096, 64
model_func = lambda: transformers.LlamaForSequenceClassification(MODEL_CONFIG)

# vary sequence length for a fixed number of heads and batch=4
configs = [
    triton.testing.Benchmark(x_names=['N_CTX'],
                             x_vals=[2**i for i in range(8, 13)],
                             line_arg='provider',
                             line_vals=['org_model', 'shard_model'],
                             line_names=['org_model', 'shard_model'],
                             styles=[('red', '-'), ('blue', '-')],
                             ylabel='ms',
                             plot_name=f'llama_for_sequence_classification-batch-{BATCH}',
                             args={
                                 'BATCH': BATCH,
                                 'dtype': torch.float16,
                                 'model_func': model_func
                             })
]
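
# triton.testing.perf_report should run bench_shardformer once per (N_CTX, provider) pair
# listed above and, given a save_path, write the timings (and a plot named after
# `plot_name`) to that directory.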


def train(model, data):
    output = model(**data)
    loss = output.logits.mean()
    loss.backward()
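

# Each provider below times train() (a single forward + backward pass with the mean of the
# logits as a dummy loss, no optimizer step) through triton.testing.do_bench.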
@triton.testing.perf_report(configs)
def bench_shardformer(BATCH, N_CTX, provider, model_func, dtype=torch.float32, device="cuda"):
    warmup = 10
    rep = 100
    # prepare data
    data = data_gen_for_sequence_classification(BATCH, N_CTX)
    data = {k: v.cuda() for k, v in data.items()}
    model = model_func().to(device)
    model.train()
    if provider == "org_model":
        fn = lambda: train(model, data)
        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
        return ms
    if provider == "shard_model":
        shard_config = ShardConfig(enable_fused_normalization=True, enable_tensor_parallelism=True)
        shard_former = ShardFormer(shard_config=shard_config)
        sharded_model = shard_former.optimize(model).cuda()
        fn = lambda: train(sharded_model, data)
        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
        return ms
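
# With enable_tensor_parallelism=True, ShardFormer is expected to partition the Llama
# attention / MLP weights across the ranks of the default process group initialized by
# colossalai.launch_from_torch below, so the 'shard_model' numbers reflect per-rank work.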


# start benchmark, command:
# torchrun --standalone --nproc_per_node=2 performance_benchmark.py
if __name__ == "__main__":
    colossalai.launch_from_torch({})
    bench_shardformer.run(save_path='.', print_data=dist.get_rank() == 0)
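
# Results land in the current directory (save_path='.'); print_data is gated on
# dist.get_rank() == 0 so only rank 0 prints the timing table.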