feat baichuan2 rmsnorm whose hidden size equals 5120 (#5611)

pull/5623/head
Steve Luo 2024-04-19 15:34:53 +08:00 committed by GitHub
parent e37ee2fb65
commit ccf72797e3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 9 additions and 3 deletions

View File

@@ -35,7 +35,7 @@ configs = [
styles=[("red", "-"), ("blue", "-"), ("yellow", "-"), ("red", "--"), ("blue", "--"), ("yellow", "--")],
ylabel="ms",
plot_name=f"RMSNorm benchmarking results",
-args={"HIDDEN_SIZE": 1024},
+args={"HIDDEN_SIZE": 5120},
)
]

View File

@@ -277,6 +277,9 @@ void rms_layernorm(
case 2:
RMSNORM_LAUNCHER(2, block);
break;
+case 3:
+RMSNORM_LAUNCHER(3, block);
+break;
case 4:
RMSNORM_LAUNCHER(4, block);
break;
@@ -321,6 +324,9 @@ void fused_add_rms_layernorm(
case 2:
FUSED_ADD_RMSNORM_LAUNCHER(2, block);
break;
+case 3:
+FUSED_ADD_RMSNORM_LAUNCHER(3, block);
+break;
case 4:
FUSED_ADD_RMSNORM_LAUNCHER(4, block);
break;

View File

@@ -9,7 +9,7 @@ inference_ops = InferenceOpsLoader().load()
@pytest.mark.parametrize("M", [2, 4, 8, 16])
-@pytest.mark.parametrize("N", [64, 128, 512])
+@pytest.mark.parametrize("N", [64, 128, 512, 5120])
def test_rms_layernorm(M: int, N: int):
torch.manual_seed(123)
torch.cuda.empty_cache()
@@ -48,4 +48,4 @@ def test_rms_layernorm(M: int, N: int):
if __name__ == "__main__":
-test_rms_layernorm(16, 512)
+test_rms_layernorm(16, 5120)