add fused norm (#6038)

pull/6040/head
Tong Li 2024-08-28 17:12:51 +08:00 committed by GitHub
parent 4a68efb7da
commit 0d3a85d04f
1 changed file with 3 additions and 0 deletions

@@ -65,6 +65,7 @@ def train(args) -> None:
             initial_scale=2**16,
             max_norm=args.grad_clip,
             enable_gradient_accumulation=(args.accumulation_steps > 1),
+            enable_fused_normalization=torch.cuda.is_available(),
             enable_flash_attention=args.use_flash_attn,
         )
     elif args.plugin == "gemini_auto":
@@ -74,6 +75,7 @@ def train(args) -> None:
             initial_scale=2**16,
             max_norm=args.grad_clip,
             enable_gradient_accumulation=(args.accumulation_steps > 1),
+            enable_fused_normalization=torch.cuda.is_available(),
             enable_flash_attention=args.use_flash_attn,
         )
     elif args.plugin == "zero2":
@@ -99,6 +101,7 @@ def train(args) -> None:
             sequence_parallelism_mode=args.sp_mode,
             zero_stage=args.zero_stage,
             enable_flash_attention=args.use_flash_attn,
+            enable_fused_normalization=torch.cuda.is_available(),
             enable_sequence_parallelism=args.enable_sequence_parallelism,
             cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
             parallel_output=False,
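
For context, the added flag switches on ColossalAI's fused normalization kernels, which require a CUDA device; gating it on torch.cuda.is_available() keeps the script usable on machines without a GPU. Below is a minimal sketch of the same pattern outside the training script; the literal values stand in for the script's args.* settings and are placeholders, not part of this commit.

# Minimal sketch (assumed setup, not the full train_sft script): enable fused
# normalization only when a CUDA device is actually present.
import torch

from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

plugin = GeminiPlugin(
    initial_scale=2**16,
    max_norm=1.0,                                           # placeholder for args.grad_clip
    enable_gradient_accumulation=True,                      # placeholder for args.accumulation_steps > 1
    enable_fused_normalization=torch.cuda.is_available(),   # the flag this commit adds
    enable_flash_attention=False,                           # placeholder for args.use_flash_attn
)
booster = Booster(plugin=plugin)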