|
|
|
@ -138,10 +138,15 @@ def main():
|
|
|
|
|
# ==============================
|
|
|
|
|
# Initialize Booster
|
|
|
|
|
# ==============================
|
|
|
|
|
if args.config in MODEL_CONFIGS:
|
|
|
|
|
config = MODEL_CONFIGS[args.config]
|
|
|
|
|
else:
|
|
|
|
|
config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
|
|
|
|
|
|
|
|
|
|
scheduler_nodes = None
|
|
|
|
|
if args.pp_style == "zbv":
|
|
|
|
|
mem_f = 34 * 32 + 5 * 4 * 16
|
|
|
|
|
mem_w = -32 * 32
|
|
|
|
|
mem_f = 34 * config.hidden_size + 5 * config.num_attention_heads * args.max_length
|
|
|
|
|
mem_w = -32 * config.hidden_size
|
|
|
|
|
mem_b = -mem_w - mem_f
|
|
|
|
|
scheduler_nodes = PipelineGraph(
|
|
|
|
|
n_stage=args.pp,
|
|
|
|
@ -275,10 +280,6 @@ def main():
|
|
|
|
|
# ==============================
|
|
|
|
|
dp_size = getattr(plugin, "dp_size", coordinator.world_size)
|
|
|
|
|
|
|
|
|
|
if args.config in MODEL_CONFIGS:
|
|
|
|
|
config = MODEL_CONFIGS[args.config]
|
|
|
|
|
else:
|
|
|
|
|
config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
|
|
|
|
|
torch.cuda.manual_seed(42)
|
|
|
|
|
dataset = RandomDataset(
|
|
|
|
|
num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
|
|
|
|
|