mirror of https://github.com/hpcaitech/ColossalAI
feat baichuan2 rmsnorm whose hidden size equals to 5120 (#5611)
parent
e37ee2fb65
commit
ccf72797e3
|
@ -35,7 +35,7 @@ configs = [
|
||||||
styles=[("red", "-"), ("blue", "-"), ("yellow", "-"), ("red", "--"), ("blue", "--"), ("yellow", "--")],
|
styles=[("red", "-"), ("blue", "-"), ("yellow", "-"), ("red", "--"), ("blue", "--"), ("yellow", "--")],
|
||||||
ylabel="ms",
|
ylabel="ms",
|
||||||
plot_name=f"RMSNorm benchmarking results",
|
plot_name=f"RMSNorm benchmarking results",
|
||||||
args={"HIDDEN_SIZE": 1024},
|
args={"HIDDEN_SIZE": 5120},
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
@ -277,6 +277,9 @@ void rms_layernorm(
|
||||||
case 2:
|
case 2:
|
||||||
RMSNORM_LAUNCHER(2, block);
|
RMSNORM_LAUNCHER(2, block);
|
||||||
break;
|
break;
|
||||||
|
case 3:
|
||||||
|
RMSNORM_LAUNCHER(3, block);
|
||||||
|
break;
|
||||||
case 4:
|
case 4:
|
||||||
RMSNORM_LAUNCHER(4, block);
|
RMSNORM_LAUNCHER(4, block);
|
||||||
break;
|
break;
|
||||||
|
@ -321,6 +324,9 @@ void fused_add_rms_layernorm(
|
||||||
case 2:
|
case 2:
|
||||||
FUSED_ADD_RMSNORM_LAUNCHER(2, block);
|
FUSED_ADD_RMSNORM_LAUNCHER(2, block);
|
||||||
break;
|
break;
|
||||||
|
case 3:
|
||||||
|
FUSED_ADD_RMSNORM_LAUNCHER(3, block);
|
||||||
|
break;
|
||||||
case 4:
|
case 4:
|
||||||
FUSED_ADD_RMSNORM_LAUNCHER(4, block);
|
FUSED_ADD_RMSNORM_LAUNCHER(4, block);
|
||||||
break;
|
break;
|
||||||
|
|
|
@ -9,7 +9,7 @@ inference_ops = InferenceOpsLoader().load()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("M", [2, 4, 8, 16])
|
@pytest.mark.parametrize("M", [2, 4, 8, 16])
|
||||||
@pytest.mark.parametrize("N", [64, 128, 512])
|
@pytest.mark.parametrize("N", [64, 128, 512, 5120])
|
||||||
def test_rms_layernorm(M: int, N: int):
|
def test_rms_layernorm(M: int, N: int):
|
||||||
torch.manual_seed(123)
|
torch.manual_seed(123)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
@ -48,4 +48,4 @@ def test_rms_layernorm(M: int, N: int):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_rms_layernorm(16, 512)
|
test_rms_layernorm(16, 5120)
|
||||||
|
|
Loading…
Reference in New Issue