diff --git a/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py b/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py
index a991e8558..7a47624f7 100644
--- a/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py
+++ b/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py
@@ -101,6 +101,11 @@ def main(args):
     initial_model = deepcopy(actor).cuda().half()
     reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda().half()
 
+    if args.use_kernels:
+        from coati.kernels import convert_to_xformer_model
+        actor, critic, initial_model, reward_model = map(convert_to_xformer_model,
+                                                         (actor, critic, initial_model, reward_model))
+
     actor_numel = get_model_numel(actor, strategy)
     critic_numel = get_model_numel(critic, strategy)
     initial_model_numel = get_model_numel(initial_model, strategy)
@@ -184,5 +189,6 @@ if __name__ == '__main__':
     parser.add_argument('--lora_rank', type=int, default=0)
     parser.add_argument('--cuda_mem_frac', type=float, default=1.0)
     parser.add_argument('--offload_inference_models', action='store_true', default=False)
+    parser.add_argument('--use_kernels', action='store_true', default=False)
     args = parser.parse_args()
     main(args)
diff --git a/applications/Chat/coati/kernels/__init__.py b/applications/Chat/coati/kernels/__init__.py
new file mode 100644
index 000000000..230eedf7e
--- /dev/null
+++ b/applications/Chat/coati/kernels/__init__.py
@@ -0,0 +1,6 @@
+from .wrapper import convert_to_xformer_model, recover_from_xformer_model
+
+__all__ = [
+    'convert_to_xformer_model',
+    'recover_from_xformer_model',
+]
diff --git a/applications/Chat/coati/kernels/opt_attn.py b/applications/Chat/coati/kernels/opt_attn.py
new file mode 100644
index 000000000..c10f341e9
--- /dev/null
+++ b/applications/Chat/coati/kernels/opt_attn.py
@@ -0,0 +1,87 @@
+from typing import Optional, Tuple
+
+import torch
+import xformers.ops as xops
+from torch import Tensor
+from transformers.models.opt.modeling_opt import OPTAttention
+
+
+# This is modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py
+class XOPTAttention(OPTAttention):
+    # def _shape(self, tensor: Tensor, seq_len: int, bsz: int):
+    #     return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous()
+
+    def forward(
+        self,
+        hidden_states: Tensor,
+        key_value_states: Optional[Tensor] = None,
+        past_key_value: Optional[Tensor] = None,
+        attention_mask: Optional[Tensor] = None,
+        layer_head_mask: Optional[Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tuple[Tensor]]]:
+        if not self.training:
+            return super().forward(hidden_states, key_value_states, past_key_value, attention_mask, layer_head_mask,
+                                   output_attentions)
+        """Input shape: Batch x Time x Channel"""
+        assert layer_head_mask is None, 'Xformers attention does not support layer_head_mask'
+        assert not output_attentions, 'Xformers attention does not support output_attentions'
+
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        # for the decoder
+        is_cross_attention = key_value_states is not None
+
+        bsz, tgt_len, _ = hidden_states.size()
+
+        # get query proj
+        query_states = self.q_proj(hidden_states)
+        # get key, value proj
+        if is_cross_attention and past_key_value is not None:
+            # reuse k,v, cross_attentions
+            key_states = past_key_value[0]
+            value_states = past_key_value[1]
+        elif is_cross_attention:
+            # cross_attentions
+            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+        elif past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+        else:
+            # self_attention
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+        if self.is_decoder:
+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_states, value_states)
+
+        query_states = self._shape(query_states, tgt_len, bsz).transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        attn_output = xops.memory_efficient_attention(query_states,
+                                                      key_states,
+                                                      value_states,
+                                                      attn_bias=xops.LowerTriangularMask(),
+                                                      p=self.dropout if self.training else 0.0,
+                                                      scale=self.scaling)
+
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+        attn_output = self.out_proj(attn_output)
+
+        attn_weights_reshaped = None
+
+        return attn_output, attn_weights_reshaped, past_key_value
diff --git a/applications/Chat/coati/kernels/wrapper.py b/applications/Chat/coati/kernels/wrapper.py
new file mode 100644
index 000000000..c55bda600
--- /dev/null
+++ b/applications/Chat/coati/kernels/wrapper.py
@@ -0,0 +1,18 @@
+import torch.nn as nn
+from transformers.models.opt.modeling_opt import OPTAttention
+
+from .opt_attn import XOPTAttention
+
+
+def convert_to_xformer_model(model: nn.Module) -> nn.Module:
+    for module in model.modules():
+        if isinstance(module, OPTAttention):
+            module.__class__ = XOPTAttention
+    return model
+
+
+def recover_from_xformer_model(model: nn.Module) -> nn.Module:
+    for module in model.modules():
+        if isinstance(module, XOPTAttention):
+            module.__class__ = OPTAttention
+    return model
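For reference, here is a minimal usage sketch of the wrappers added above, outside the benchmark script. The `facebook/opt-125m` checkpoint and the toy training step are illustrative assumptions and not part of this patch; only `convert_to_xformer_model` and `recover_from_xformer_model` come from the new `coati.kernels` package.

import torch
from transformers import OPTForCausalLM

from coati.kernels import convert_to_xformer_model, recover_from_xformer_model

# Swap every OPTAttention module's class in-place. In training mode the patched
# XOPTAttention.forward routes through xformers' memory-efficient attention;
# in eval mode it falls back to the stock OPTAttention implementation.
model = OPTForCausalLM.from_pretrained('facebook/opt-125m').cuda().half()
model = convert_to_xformer_model(model)

model.train()
input_ids = torch.randint(0, model.config.vocab_size, (2, 128), device='cuda')
loss = model(input_ids=input_ids, labels=input_ids).loss
loss.backward()

# Undo the class swap, e.g. before exporting or comparing against the baseline.
model = recover_from_xformer_model(model)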