mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish test component gpt code style (#1567)
parent
6159d45417
commit
e615cfc3a8
|
@ -47,8 +47,15 @@ class GPTLMModel(nn.Module):
|
|||
# Only return lm_logits
|
||||
return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]
|
||||
|
||||
|
||||
def gpt2_micro(checkpoint=True):
|
||||
return GPTLMModel(checkpoint=checkpoint, hidden_size=32, num_layers=2, num_attention_heads=4, max_seq_len=64, vocab_size=128)
|
||||
return GPTLMModel(checkpoint=checkpoint,
|
||||
hidden_size=32,
|
||||
num_layers=2,
|
||||
num_attention_heads=4,
|
||||
max_seq_len=64,
|
||||
vocab_size=128)
|
||||
|
||||
|
||||
def gpt2_s(checkpoint=True):
|
||||
return GPTLMModel(checkpoint=checkpoint)
|
||||
|
|
Loading…
Reference in New Issue