diff --git a/configs/7B_sft.py b/configs/7B_sft.py index e0b9a8a..7f44533 100644 --- a/configs/7B_sft.py +++ b/configs/7B_sft.py @@ -128,7 +128,7 @@ model = dict( num_layers=NUM_LAYER, mlp_ratio=MLP_RATIO, apply_post_layer_norm=False, - dtype="torch.float16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" + dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" norm_type="rmsnorm", layer_norm_epsilon=1e-5, use_flash_attn=True,