mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish test component gpt code style (#1567)
parent
6159d45417
commit
e615cfc3a8
|
@ -47,8 +47,15 @@ class GPTLMModel(nn.Module):
|
||||||
# Only return lm_logits
|
# Only return lm_logits
|
||||||
return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]
|
return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]
|
||||||
|
|
||||||
|
|
||||||
def gpt2_micro(checkpoint=True):
|
def gpt2_micro(checkpoint=True):
|
||||||
return GPTLMModel(checkpoint=checkpoint, hidden_size=32, num_layers=2, num_attention_heads=4, max_seq_len=64, vocab_size=128)
|
return GPTLMModel(checkpoint=checkpoint,
|
||||||
|
hidden_size=32,
|
||||||
|
num_layers=2,
|
||||||
|
num_attention_heads=4,
|
||||||
|
max_seq_len=64,
|
||||||
|
vocab_size=128)
|
||||||
|
|
||||||
|
|
||||||
def gpt2_s(checkpoint=True):
|
def gpt2_s(checkpoint=True):
|
||||||
return GPTLMModel(checkpoint=checkpoint)
|
return GPTLMModel(checkpoint=checkpoint)
|
||||||
|
|
Loading…
Reference in New Issue