mirror of https://github.com/InternLM/InternLM
test(model): support fp32 with flash_attn (#223)
* support tf32 with flash
* move autocast to attention
* fix lint
* fix lint
* fix lint
* fix lint
* fix some bugs in model
* modify the convert dtype
parent fd28bcab58
commit eee93b5a68
@@ -120,7 +120,7 @@ model = dict(
     num_layers=NUM_LAYER,
     mlp_ratio=MLP_RATIO,
     apply_post_layer_norm=False,
-    dtype="torch.bfloat16",
+    dtype="torch.tf32",  # dtype could be in "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32",
     norm_type="rmsnorm",
     layer_norm_epsilon=1e-5,
     use_flash_attn=True,
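For reference, a minimal sketch of how the model block of a training config reads with the new option. The NUM_LAYER and MLP_RATIO values are assumptions taken from the repository's example 7B config; the point of this commit is that "torch.tf32" (and "torch.float32") can now be combined with use_flash_attn=True.

NUM_LAYER = 32        # assumed constant from the example config
MLP_RATIO = 8 / 3     # assumed constant from the example config

model = dict(
    num_layers=NUM_LAYER,
    mlp_ratio=MLP_RATIO,
    apply_post_layer_norm=False,
    # accepted strings: "torch.float16", "torch.half", "torch.bfloat16",
    # "torch.float32", "torch.tf32"
    dtype="torch.tf32",
    norm_type="rmsnorm",
    layer_norm_epsilon=1e-5,
    use_flash_attn=True,
)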
@@ -220,10 +220,8 @@ and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'"
    elif gpc.config.model.dtype in ("torch.float16", "torch.half"):
        gpc.config.model.dtype = torch.float16
    elif gpc.config.model.dtype == "torch.float32":
-        assert gpc.config.model.use_flash_attn is False, "when using float32, the use_flash_attn must be False"
        gpc.config.model.dtype = torch.float32
    elif gpc.config.model.dtype == "torch.tf32":
-        assert gpc.config.model.use_flash_attn is False, "when using tf32, the use_flash_attn must be False"
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cuda.matmul.allow_tf32 = True
        gpc.config.model.dtype = torch.float32
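Read on its own, this dtype-string handling amounts to a small mapping. The sketch below mirrors the branch structure above as a self-contained helper; resolve_model_dtype is a hypothetical name, the "torch.bfloat16" case is assumed to be the first branch of the chain (it is not shown in the hunk), and the real code mutates gpc.config.model.dtype in place during initialization.

import torch

def resolve_model_dtype(dtype_str: str) -> torch.dtype:
    # Hypothetical helper mirroring the elif chain above.
    if dtype_str == "torch.bfloat16":          # assumed first branch of the chain
        return torch.bfloat16
    elif dtype_str in ("torch.float16", "torch.half"):
        return torch.float16
    elif dtype_str == "torch.float32":
        # After this commit, fp32 no longer requires use_flash_attn=False;
        # the attention module handles the cast internally (see the MHA hunks below).
        return torch.float32
    elif dtype_str == "torch.tf32":
        # TF32 is not a storage dtype: tensors stay fp32, and these backend
        # flags let matmul/conv kernels use TF32 tensor cores.
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cuda.matmul.allow_tf32 = True
        return torch.float32
    raise ValueError(f"unsupported dtype string: {dtype_str!r}")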
@@ -132,7 +132,13 @@ class MHA(nn.Module):
            qkv = self.rotary_emb(qkv, **kwargs)

        if inference_params is None:
-            context = self.inner_attn(qkv)
+            if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
+                with torch.cuda.amp.autocast(dtype=torch.float16):
+                    if qkv.dtype not in [torch.float16, torch.bfloat16]:
+                        qkv = qkv.to(torch.bfloat16)
+                    context = self.inner_attn(qkv).to(x.dtype)
+            else:
+                context = self.inner_attn(qkv)
        else:
            q = qkv[:, :, 0]
            assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
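Taken out of the module, the new branch is a small cast-and-autocast wrapper: flash-attention kernels only accept fp16/bf16 inputs, so an fp32/tf32 run casts qkv down, evaluates attention under fp16 autocast, and casts the context back to the caller's dtype. A minimal sketch, assuming inner_attn is any attention callable (FlashSelfAttention in the real module) and fp32_safe_flash_attn is a hypothetical name, not part of the repository:

import torch

def fp32_safe_flash_attn(inner_attn, qkv: torch.Tensor, out_dtype: torch.dtype, **kwargs) -> torch.Tensor:
    # Mirrors the added branch: cast to a flash-supported dtype, run the
    # kernel under fp16 autocast, then restore the model's working dtype.
    with torch.cuda.amp.autocast(dtype=torch.float16):
        if qkv.dtype not in (torch.float16, torch.bfloat16):
            qkv = qkv.to(torch.bfloat16)
        context = inner_attn(qkv, **kwargs)
    return context.to(out_dtype)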
@@ -164,7 +170,14 @@ class MHA(nn.Module):
            kwargs.pop("indexes")

        if inference_params is None:
-            context = self.inner_attn(qkv, **kwargs)
+            if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
+                with torch.cuda.amp.autocast(dtype=torch.float16):
+                    if qkv.dtype not in [torch.float16, torch.bfloat16]:
+                        qkv = qkv.to(torch.bfloat16)
+                    context = self.inner_attn(qkv, **kwargs).to(x.dtype)
+            else:
+                context = self.inner_attn(qkv, **kwargs)
+
        else:
            raise RuntimeError("Not support this right now")
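Both forward paths apply the identical pattern; only the extra arguments forwarded to inner_attn differ (the packed path pops "indexes" and passes the remaining kwargs through to the varlen kernel). With the hypothetical helper sketched above, either call site would reduce to a one-liner:

# Hypothetical usage, standing in for the two call sites above.
context = fp32_safe_flash_attn(self.inner_attn, qkv, out_dtype=x.dtype, **kwargs)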