test(model): support fp32 with flash_attn (#223)

* support tf32 with flash

* move autocast to attention

* fix lint

* fix lint

* fix lint

* fix lint

* fix some bugs in model

* modify the dtype conversion
pull/222/head
ytxiong 2023-08-24 13:54:44 +08:00 committed by GitHub
parent fd28bcab58
commit eee93b5a68
3 changed files with 16 additions and 5 deletions

@@ -120,7 +120,7 @@ model = dict(
     num_layers=NUM_LAYER,
     mlp_ratio=MLP_RATIO,
     apply_post_layer_norm=False,
-    dtype="torch.bfloat16",
+    dtype="torch.tf32",  # dtype can be "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", or "torch.tf32"
     norm_type="rmsnorm",
     layer_norm_epsilon=1e-5,
     use_flash_attn=True,
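
With the assertions removed below, fp32 and tf32 runs can keep flash attention enabled. A minimal sketch of the two user-facing config fields this commit touches (the rest of the model dict is unchanged; the sketch is illustrative, not part of the diff):

# sketch: only the fields relevant to this commit
model = dict(
    dtype="torch.tf32",   # or "torch.float16" / "torch.half" / "torch.bfloat16" / "torch.float32"
    use_flash_attn=True,  # no longer has to be False when dtype is "torch.float32" or "torch.tf32"
)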

@@ -220,10 +220,8 @@ and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'"
     elif gpc.config.model.dtype in ("torch.float16", "torch.half"):
         gpc.config.model.dtype = torch.float16
     elif gpc.config.model.dtype == "torch.float32":
-        assert gpc.config.model.use_flash_attn is False, "when using float32, the use_flash_attn must be False"
         gpc.config.model.dtype = torch.float32
     elif gpc.config.model.dtype == "torch.tf32":
-        assert gpc.config.model.use_flash_attn is False, "when using tf32, the use_flash_attn must be False"
         torch.backends.cudnn.allow_tf32 = True
         torch.backends.cuda.matmul.allow_tf32 = True
         gpc.config.model.dtype = torch.float32
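
Taken together, the branch above maps the config string to a runtime dtype and, for "torch.tf32", flips the TF32 backend flags while keeping float32 storage. A standalone sketch of that mapping (the function name convert_model_dtype is illustrative, not part of the commit):

import torch

def convert_model_dtype(dtype_str: str) -> torch.dtype:
    # Mirrors the post-commit branch in the launcher: fp32 and tf32 no longer
    # require use_flash_attn to be False.
    if dtype_str == "torch.bfloat16":
        return torch.bfloat16
    if dtype_str in ("torch.float16", "torch.half"):
        return torch.float16
    if dtype_str == "torch.float32":
        return torch.float32
    if dtype_str == "torch.tf32":
        # TF32 is not a storage dtype: parameters stay in float32, but matmul
        # and cuDNN kernels are allowed to use TF32 tensor cores.
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cuda.matmul.allow_tf32 = True
        return torch.float32
    raise ValueError(f"unsupported dtype string: {dtype_str}")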

@@ -132,7 +132,13 @@ class MHA(nn.Module):
             qkv = self.rotary_emb(qkv, **kwargs)
         if inference_params is None:
-            context = self.inner_attn(qkv)
+            if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
+                with torch.cuda.amp.autocast(dtype=torch.float16):
+                    if qkv.dtype not in [torch.float16, torch.bfloat16]:
+                        qkv = qkv.to(torch.bfloat16)
+                    context = self.inner_attn(qkv).to(x.dtype)
+            else:
+                context = self.inner_attn(qkv)
         else:
             q = qkv[:, :, 0]
             assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
@@ -164,7 +170,14 @@ class MHA(nn.Module):
             kwargs.pop("indexes")
         if inference_params is None:
-            context = self.inner_attn(qkv, **kwargs)
+            if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
+                with torch.cuda.amp.autocast(dtype=torch.float16):
+                    if qkv.dtype not in [torch.float16, torch.bfloat16]:
+                        qkv = qkv.to(torch.bfloat16)
+                    context = self.inner_attn(qkv, **kwargs).to(x.dtype)
+            else:
+                context = self.inner_attn(qkv, **kwargs)
         else:
             raise RuntimeError("Not support this right now")
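
Both branches above implement the same pattern: flash-attention kernels only accept fp16/bf16 tensors, so when the model runs in float32 (or tf32) the packed qkv is cast down before the kernel call and the output is cast back to the activation dtype, leaving the rest of the fp32 forward pass unchanged. A self-contained sketch of that pattern (the helper name and the inner_attn argument are illustrative stand-ins for the module's flash-attention op, not part of the commit):

import torch

def _attn_with_fp32_activations(inner_attn, qkv: torch.Tensor, out_dtype: torch.dtype, **kwargs) -> torch.Tensor:
    # Flash-attention kernels require fp16/bf16 inputs, so cast fp32 activations
    # down for the kernel and restore the original dtype afterwards.
    # Mirrors the commit: autocast is entered with fp16 while the explicit cast
    # targets bf16; the explicit cast is what the flash kernel actually sees.
    with torch.cuda.amp.autocast(dtype=torch.float16):
        if qkv.dtype not in (torch.float16, torch.bfloat16):
            qkv = qkv.to(torch.bfloat16)
        return inner_attn(qkv, **kwargs).to(out_dtype)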