fix(configs/7B_sft.py): model dtype float16 to bfloat16 (#302)

pull/309/head
huangting4201 2023-09-11 20:06:22 +08:00 committed by GitHub
parent 8420115b5e
commit e354410bd2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions

View File

@ -127,7 +127,7 @@ model = dict(
num_layers=NUM_LAYER,
mlp_ratio=MLP_RATIO,
apply_post_layer_norm=False,
dtype="torch.float16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
use_flash_attn=True,