# ColossalAI/colossalai/elixir/kernels/attn_wrapper.py

import torch
import torch.nn as nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Model
from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder

from .gpt_attention import XGPT2Attention, XGPT2Model
from .opt_attention import XOPTAttention, XOPTDecoder


def wrap_attention(model: nn.Module) -> nn.Module:
    """Swap supported Hugging Face attention modules for their Elixir variants.

    The patch reassigns ``__class__`` on each matching submodule in place, so
    no weights are copied and the original parameter tensors are untouched.
    """
    for module in model.modules():
        if isinstance(module, GPT2Model):
            module.__class__ = XGPT2Model
        elif isinstance(module, GPT2Attention):
            module.__class__ = XGPT2Attention
        elif isinstance(module, OPTAttention):
            module.__class__ = XOPTAttention
        elif isinstance(module, OPTDecoder):
            module.__class__ = XOPTDecoder
    return model
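

# A minimal usage sketch, assuming a Hugging Face checkpoint is available.
# ``GPT2LMHeadModel`` and the 'gpt2' checkpoint name are illustrative choices,
# not part of this module; any model containing GPT2Model / GPT2Attention /
# OPTAttention / OPTDecoder submodules is patched the same way.
if __name__ == '__main__':
    from transformers import GPT2LMHeadModel

    model = GPT2LMHeadModel.from_pretrained('gpt2')
    model = wrap_attention(model)
    # The inner GPT2Model has been re-classed in place.
    assert model.transformer.__class__ is XGPT2Model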