mirror of https://github.com/InternLM/InternLM

test(model): support fp32 with flash_attn (#223)

* support tf32 with flash
* move autocast to attention
* fix lint
* fix lint
* fix lint
* fix lint
* fix some bugs in model
* modify the convert dtype

branch: pull/222/head
parent fd28bcab58
commit eee93b5a68
```diff
@@ -120,7 +120,7 @@ model = dict(
     num_layers=NUM_LAYER,
     mlp_ratio=MLP_RATIO,
     apply_post_layer_norm=False,
-    dtype="torch.bfloat16",
+    dtype="torch.tf32",  # dtype could be in "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
     norm_type="rmsnorm",
     layer_norm_epsilon=1e-5,
     use_flash_attn=True,
```
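For context, `dtype` here is a plain string that the training entry point later converts into a real `torch.dtype` (see the next hunk). Below is a minimal sketch of that mapping; the helper name `resolve_model_dtype` is hypothetical and not part of InternLM.

```python
# Hypothetical helper (not InternLM code): map the config's dtype string
# to a torch dtype, enabling TF32 kernels when "torch.tf32" is requested.
import torch

_DTYPE_STRINGS = ("torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32")

def resolve_model_dtype(dtype_str: str) -> torch.dtype:
    if dtype_str not in _DTYPE_STRINGS:
        raise ValueError(f"unsupported dtype {dtype_str!r}, expected one of {_DTYPE_STRINGS}")
    if dtype_str in ("torch.float16", "torch.half"):
        return torch.float16
    if dtype_str == "torch.bfloat16":
        return torch.bfloat16
    if dtype_str == "torch.tf32":
        # TF32 is not a storage dtype: parameters stay float32, but CUDA
        # matmul/cuDNN kernels are allowed to use TF32 tensor cores.
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cuda.matmul.allow_tf32 = True
    return torch.float32
```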
```diff
@@ -220,10 +220,8 @@ and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'"
         elif gpc.config.model.dtype in ("torch.float16", "torch.half"):
             gpc.config.model.dtype = torch.float16
         elif gpc.config.model.dtype == "torch.float32":
-            assert gpc.config.model.use_flash_attn is False, "when using float32, the use_flash_attn must be False"
             gpc.config.model.dtype = torch.float32
         elif gpc.config.model.dtype == "torch.tf32":
-            assert gpc.config.model.use_flash_attn is False, "when using tf32, the use_flash_attn must be False"
             torch.backends.cudnn.allow_tf32 = True
             torch.backends.cuda.matmul.allow_tf32 = True
             gpc.config.model.dtype = torch.float32
```
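The `torch.tf32` branch only flips the CUDA backend flags; the dtype stored in the config remains `torch.float32`. A standalone sketch (independent of InternLM) illustrating that TF32 changes how float32 matmuls are executed, not the dtype of the tensors themselves:

```python
# Standalone illustration: TF32 affects CUDA matmul/cuDNN execution on
# Ampere+ GPUs; the storage dtype of the tensors stays float32.
import torch

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

a = torch.randn(1024, 1024, dtype=torch.float32)
b = torch.randn(1024, 1024, dtype=torch.float32)
if torch.cuda.is_available():
    a, b = a.cuda(), b.cuda()

c = a @ b
print(c.dtype)  # torch.float32 -- unchanged by the TF32 flags
```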
```diff
@@ -132,7 +132,13 @@ class MHA(nn.Module):
             qkv = self.rotary_emb(qkv, **kwargs)
 
         if inference_params is None:
-            context = self.inner_attn(qkv)
+            if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
+                with torch.cuda.amp.autocast(dtype=torch.float16):
+                    if qkv.dtype not in [torch.float16, torch.bfloat16]:
+                        qkv = qkv.to(torch.bfloat16)
+                    context = self.inner_attn(qkv).to(x.dtype)
+            else:
+                context = self.inner_attn(qkv)
         else:
             q = qkv[:, :, 0]
             assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
```
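FlashAttention kernels accept only fp16/bf16 inputs, which is why the fp32 path above downcasts the packed `qkv`, runs `inner_attn`, and casts the context back to the caller's dtype. Below is a minimal sketch of the same round-trip, using PyTorch's built-in `scaled_dot_product_attention` as a stand-in for the flash kernel; the function `attention_with_downcast` is illustrative and not InternLM code.

```python
# Illustrative sketch: fp32 activations, attention computed in bfloat16,
# result cast back so the caller still sees float32.
import torch
import torch.nn.functional as F

def attention_with_downcast(qkv: torch.Tensor) -> torch.Tensor:
    # qkv: (batch, seqlen, 3, num_heads, head_dim), packed like MHA's qkv
    orig_dtype = qkv.dtype
    if qkv.dtype not in (torch.float16, torch.bfloat16):
        qkv = qkv.to(torch.bfloat16)                          # half precision for the kernel
    q, k, v = (t.transpose(1, 2) for t in qkv.unbind(dim=2))  # -> (b, h, s, d)
    context = F.scaled_dot_product_attention(q, k, v)
    return context.transpose(1, 2).to(orig_dtype)             # -> (b, s, h, d), back to fp32

qkv = torch.randn(2, 16, 3, 4, 32)  # float32 activations
out = attention_with_downcast(qkv)
print(out.dtype)                    # torch.float32
```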
```diff
@@ -164,7 +170,14 @@ class MHA(nn.Module):
         kwargs.pop("indexes")
 
         if inference_params is None:
-            context = self.inner_attn(qkv, **kwargs)
+            if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
+                with torch.cuda.amp.autocast(dtype=torch.float16):
+                    if qkv.dtype not in [torch.float16, torch.bfloat16]:
+                        qkv = qkv.to(torch.bfloat16)
+                    context = self.inner_attn(qkv, **kwargs).to(x.dtype)
+            else:
+                context = self.inner_attn(qkv, **kwargs)
+
         else:
             raise RuntimeError("Not support this right now")
```
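Given the `test(model)` scope of this commit, one way to sanity-check the new path is to compare bf16-computed attention on fp32 inputs against a pure fp32 reference. The check below is an assumption (a sketch, not a test from the repository), and the tolerance values are chosen for illustration.

```python
# Hypothetical numerical check (not from the repository): attention computed
# in bfloat16 on fp32 inputs should stay close to a pure fp32 reference.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))  # (b, h, s, d), float32

ref = F.scaled_dot_product_attention(q, k, v)            # fp32 reference
low = F.scaled_dot_product_attention(
    q.to(torch.bfloat16), k.to(torch.bfloat16), v.to(torch.bfloat16)
).to(torch.float32)                                      # bf16 compute, fp32 output

assert low.dtype == torch.float32
assert torch.allclose(ref, low, atol=2e-2, rtol=2e-2), "bf16 attention drifted too far from fp32"
print("max abs error:", (ref - low).abs().max().item())
```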