mirror of https://github.com/InternLM/InternLM
test(model): support fp32 with flash_attn (#223)
* support tf32 with flash
* move autocast to attention
* fix lint
* fix lint
* fix lint
* fix lint
* fix some bugs in model
* modify the dtype conversion
parent fd28bcab58
commit eee93b5a68
@@ -120,7 +120,7 @@ model = dict(
     num_layers=NUM_LAYER,
     mlp_ratio=MLP_RATIO,
     apply_post_layer_norm=False,
-    dtype="torch.bfloat16",
+    dtype="torch.tf32",  # dtype can be one of "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
     norm_type="rmsnorm",
     layer_norm_epsilon=1e-5,
     use_flash_attn=True,
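For reference, a minimal sketch of how the new dtype option sits in a training config. The NUM_LAYER and MLP_RATIO values below are placeholders, not values taken from this diff:

# Sketch of the model config block touched above; only the dtype line is
# what this commit changes, the placeholder values are illustrative.
NUM_LAYER = 32
MLP_RATIO = 8 / 3

model = dict(
    num_layers=NUM_LAYER,
    mlp_ratio=MLP_RATIO,
    apply_post_layer_norm=False,
    # The string is converted to a real torch dtype at startup; "torch.tf32"
    # keeps parameters in float32 and enables TF32 tensor-core matmuls.
    dtype="torch.tf32",
    norm_type="rmsnorm",
    layer_norm_epsilon=1e-5,
    use_flash_attn=True,
)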
@@ -220,10 +220,8 @@ and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'"
     elif gpc.config.model.dtype in ("torch.float16", "torch.half"):
         gpc.config.model.dtype = torch.float16
     elif gpc.config.model.dtype == "torch.float32":
-        assert gpc.config.model.use_flash_attn is False, "when using float32, the use_flash_attn must be False"
         gpc.config.model.dtype = torch.float32
     elif gpc.config.model.dtype == "torch.tf32":
-        assert gpc.config.model.use_flash_attn is False, "when using tf32, the use_flash_attn must be False"
         torch.backends.cudnn.allow_tf32 = True
         torch.backends.cuda.matmul.allow_tf32 = True
         gpc.config.model.dtype = torch.float32
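The hunk above drops the asserts that previously forbade flash_attn together with float32/tf32. As a standalone illustration of the conversion logic it edits, here is a sketch; the helper name convert_model_dtype is hypothetical and not part of the repo:

import torch

def convert_model_dtype(dtype_str: str) -> torch.dtype:
    """Map a config string such as "torch.tf32" to a real torch dtype
    (hypothetical helper mirroring the logic in the hunk above)."""
    if dtype_str == "torch.bfloat16":
        return torch.bfloat16
    if dtype_str in ("torch.float16", "torch.half"):
        return torch.float16
    if dtype_str == "torch.float32":
        return torch.float32
    if dtype_str == "torch.tf32":
        # TF32 is not a storage dtype: keep float32 tensors but allow
        # TF32 tensor-core kernels for matmuls and cuDNN convolutions.
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cuda.matmul.allow_tf32 = True
        return torch.float32
    raise ValueError(f"unsupported dtype string: {dtype_str}")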
@@ -132,7 +132,13 @@ class MHA(nn.Module):
             qkv = self.rotary_emb(qkv, **kwargs)

         if inference_params is None:
-            context = self.inner_attn(qkv)
+            if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
+                with torch.cuda.amp.autocast(dtype=torch.float16):
+                    if qkv.dtype not in [torch.float16, torch.bfloat16]:
+                        qkv = qkv.to(torch.bfloat16)
+                    context = self.inner_attn(qkv).to(x.dtype)
+            else:
+                context = self.inner_attn(qkv)
         else:
             q = qkv[:, :, 0]
             assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
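The idea in this hunk: flash-attn kernels only accept fp16/bf16 inputs, so when the model itself runs in float32 (or tf32) the attention call is wrapped in an autocast region, qkv is downcast, and the result is cast back to the residual-stream dtype. A minimal standalone sketch of that pattern, where inner_attn stands in for the module's flash-attention callable and out_dtype for x.dtype:

import torch

def fp32_safe_flash_attn(inner_attn, qkv, out_dtype, **kwargs):
    """Run an fp16/bf16-only attention kernel from an fp32 model
    (sketch of the pattern used in the hunk above, not the repo's API)."""
    if qkv.dtype not in (torch.float16, torch.bfloat16):
        with torch.cuda.amp.autocast(dtype=torch.float16):
            qkv = qkv.to(torch.bfloat16)   # matches the commit's choice of bf16
            context = inner_attn(qkv, **kwargs)
        return context.to(out_dtype)       # restore the model's fp32 dtype
    return inner_attn(qkv, **kwargs)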
@@ -164,7 +170,14 @@ class MHA(nn.Module):
             kwargs.pop("indexes")

         if inference_params is None:
-            context = self.inner_attn(qkv, **kwargs)
+            if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
+                with torch.cuda.amp.autocast(dtype=torch.float16):
+                    if qkv.dtype not in [torch.float16, torch.bfloat16]:
+                        qkv = qkv.to(torch.bfloat16)
+                    context = self.inner_attn(qkv, **kwargs).to(x.dtype)
+            else:
+                context = self.inner_attn(qkv, **kwargs)

         else:
             raise RuntimeError("Not support this right now")
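This hunk repeats the same downcast-and-restore logic as the previous one, only with the packed-sequence kwargs passed through to the attention call. One possible cleanup, not part of this commit, would be to route both call sites through a shared helper such as the fp32_safe_flash_attn sketch above:

# Hypothetical refactor of the call site above (not in this commit):
if inference_params is None:
    context = fp32_safe_flash_attn(self.inner_attn, qkv, out_dtype=x.dtype, **kwargs)
else:
    raise RuntimeError("Not support this right now")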