# mirror of https://github.com/hpcaitech/ColossalAI

import torch
import torch.nn as nn

from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Model
from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder

from .gpt_attention import XGPT2Attention, XGPT2Model
from .opt_attention import XOPTAttention, XOPTDecoder


def wrap_attention(model: nn.Module):
    """Swap the stock HF GPT2/OPT classes for their X* counterparts in place.

    Reassigning ``__class__`` replaces the method table without touching any
    parameters or buffers, which assumes each X* class is attribute-compatible
    with the class it replaces.
    """
    for _, module in model.named_modules():
        if isinstance(module, GPT2Model):
            module.__class__ = XGPT2Model
        elif isinstance(module, GPT2Attention):
            module.__class__ = XGPT2Attention
        elif isinstance(module, OPTAttention):
            module.__class__ = XOPTAttention
        elif isinstance(module, OPTDecoder):
            module.__class__ = XOPTDecoder
    return model
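

if __name__ == "__main__":
    # Minimal usage sketch (an assumption, not part of the original module):
    # the "gpt2" checkpoint name is illustrative and loading it requires
    # network access or a local cache.
    model = GPT2Model.from_pretrained("gpt2")
    model = wrap_attention(model)
    print(type(model).__name__)  # expected: XGPT2Model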