pull/6081/head
flybird11111 2 months ago
parent d50c0a1e0a
commit b9ac0a6808

@ -138,10 +138,15 @@ def main():
# ==============================
# Initialize Booster
# ==============================
if args.config in MODEL_CONFIGS:
config = MODEL_CONFIGS[args.config]
else:
config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
scheduler_nodes = None
if args.pp_style == "zbv":
mem_f = 34 * 32 + 5 * 4 * 16
mem_w = -32 * 32
mem_f = 34 * config.hidden_size + 5 * config.num_attention_heads * args.max_length
mem_w = -32 * config.hidden_size
mem_b = -mem_w - mem_f
scheduler_nodes = PipelineGraph(
n_stage=args.pp,
@ -275,10 +280,6 @@ def main():
# ==============================
dp_size = getattr(plugin, "dp_size", coordinator.world_size)
if args.config in MODEL_CONFIGS:
config = MODEL_CONFIGS[args.config]
else:
config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
torch.cuda.manual_seed(42)
dataset = RandomDataset(
num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size

Loading…
Cancel
Save