[NFC] polish test component gpt code style (#1567)

pull/1574/head
アマデウス 2022-09-08 16:34:09 +08:00 committed by GitHub
parent 6159d45417
commit e615cfc3a8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 8 additions and 1 deletions

View File

@ -47,8 +47,15 @@ class GPTLMModel(nn.Module):
# Only return lm_logits
return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]
def gpt2_micro(checkpoint=True):
return GPTLMModel(checkpoint=checkpoint, hidden_size=32, num_layers=2, num_attention_heads=4, max_seq_len=64, vocab_size=128)
return GPTLMModel(checkpoint=checkpoint,
hidden_size=32,
num_layers=2,
num_attention_heads=4,
max_seq_len=64,
vocab_size=128)
def gpt2_s(checkpoint=True):
return GPTLMModel(checkpoint=checkpoint)