ColossalAI/tests/test_elixir/utils/opt.py

import torch.nn as nn
from transformers import OPTConfig, OPTForCausalLM

from tests.test_elixir.utils.registry import TEST_MODELS
from .gpt import micro_data_fn


class OPTLMModel(nn.Module):
    """Thin wrapper around OPTForCausalLM that returns the LM loss directly."""

    def __init__(self, config) -> None:
        super().__init__()
        self.config = config
        self.module = OPTForCausalLM(config=config)
        self.enable_gc = False

    def gradient_checkpointing_enable(self):
        self.module.gradient_checkpointing_enable()
        self.enable_gc = True

    def forward(self, input_ids, attention_mask):
        # Reuse the inputs as labels for the causal LM loss. The KV cache is
        # disabled once gradient checkpointing is on, since the two features
        # are incompatible in transformers.
        loss = self.module(
            # pre-commit: do not rearrange
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=input_ids,
            use_cache=(not self.enable_gc))['loss']
        return loss


def opt_micro():
    # A deliberately tiny OPT configuration so the test model builds and
    # runs quickly.
    opt_config = OPTConfig(
        # pre-commit: do not rearrange
        vocab_size=128,
        activation_dropout=0.0,
        dropout=0,
        hidden_size=32,
        num_hidden_layers=4,
        ffn_dim=128,
        num_attention_heads=4,
        word_embed_proj_dim=32,
        output_projection=True)
    return OPTLMModel(opt_config)


# Expose the micro OPT model to the shared test registry with its data function.
TEST_MODELS.register('opt_micro', opt_micro, micro_data_fn)
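
# A minimal smoke-test sketch (an assumption, not part of the original test
# suite): build the micro model, enable activation checkpointing, and run one
# forward/backward pass on dummy inputs. The batch and sequence sizes below
# are illustrative choices only.
if __name__ == '__main__':
    import torch

    model = opt_micro()
    model.gradient_checkpointing_enable()
    input_ids = torch.randint(0, 128, (2, 16))    # vocab_size=128, see above
    attention_mask = torch.ones_like(input_ids)
    loss = model(input_ids=input_ids, attention_mask=attention_mask)
    loss.backward()
    print(f'loss = {loss.item():.4f}')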