diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index 0c81befcf..5c0fc15c5 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -138,10 +138,15 @@ def main(): # ============================== # Initialize Booster # ============================== + if args.config in MODEL_CONFIGS: + config = MODEL_CONFIGS[args.config] + else: + config = AutoConfig.from_pretrained(args.config, trust_remote_code=True) + scheduler_nodes = None if args.pp_style == "zbv": - mem_f = 34 * 32 + 5 * 4 * 16 - mem_w = -32 * 32 + mem_f = 34 * config.hidden_size + 5 * config.num_attention_heads * args.max_length + mem_w = -32 * config.hidden_size mem_b = -mem_w - mem_f scheduler_nodes = PipelineGraph( n_stage=args.pp, @@ -275,10 +280,6 @@ def main(): # ============================== dp_size = getattr(plugin, "dp_size", coordinator.world_size) - if args.config in MODEL_CONFIGS: - config = MODEL_CONFIGS[args.config] - else: - config = AutoConfig.from_pretrained(args.config, trust_remote_code=True) torch.cuda.manual_seed(42) dataset = RandomDataset( num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size