mirror of https://github.com/hpcaitech/ColossalAI
add fused norm (#6038)
parent 4a68efb7da
commit 0d3a85d04f
@@ -65,6 +65,7 @@ def train(args) -> None:
             initial_scale=2**16,
             max_norm=args.grad_clip,
             enable_gradient_accumulation=(args.accumulation_steps > 1),
+            enable_fused_normalization=torch.cuda.is_available(),
             enable_flash_attention=args.use_flash_attn,
         )
     elif args.plugin == "gemini_auto":
@@ -74,6 +75,7 @@ def train(args) -> None:
             initial_scale=2**16,
             max_norm=args.grad_clip,
             enable_gradient_accumulation=(args.accumulation_steps > 1),
+            enable_fused_normalization=torch.cuda.is_available(),
             enable_flash_attention=args.use_flash_attn,
         )
     elif args.plugin == "zero2":
@@ -99,6 +101,7 @@ def train(args) -> None:
             sequence_parallelism_mode=args.sp_mode,
             zero_stage=args.zero_stage,
             enable_flash_attention=args.use_flash_attn,
+            enable_fused_normalization=torch.cuda.is_available(),
             enable_sequence_parallelism=args.enable_sequence_parallelism,
             cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
             parallel_output=False,
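
For context, the sketch below shows how a plugin configuration like the one patched above is typically assembled in the training script. Only the keyword arguments visible in the diff come from the patched code; the plugin class name (GeminiPlugin) and the Booster wiring are assumptions based on ColossalAI's booster API, not part of this commit. The new flag is gated on torch.cuda.is_available() because the fused LayerNorm/RMSNorm kernels only run on a CUDA device.

    import torch
    from colossalai.booster.plugin import GeminiPlugin  # assumed plugin class for the "gemini" branch

    def build_plugin(args):
        # Mirrors the patched call site: fused normalization is requested only when
        # CUDA is present, since the fused norm kernels are GPU-only.
        return GeminiPlugin(
            initial_scale=2**16,
            max_norm=args.grad_clip,
            enable_gradient_accumulation=(args.accumulation_steps > 1),
            enable_fused_normalization=torch.cuda.is_available(),  # the flag this commit adds
            enable_flash_attention=args.use_flash_attn,
        )

    # Typical (hypothetical) usage: the plugin is handed to colossalai.booster.Booster,
    # which then wraps the model, optimizer, and dataloader.
    # booster = Booster(plugin=build_plugin(args))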