mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix typo s/get_defualt_parser /get_default_parser (#5548)
parent
a799ca343b
commit
341263df48
|
@ -2,10 +2,10 @@ import time
|
|||
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from utils import get_defualt_parser, inference, print_output
|
||||
from utils import get_default_parser, inference, print_output
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = get_defualt_parser()
|
||||
parser = get_default_parser()
|
||||
args = parser.parse_args()
|
||||
start = time.time()
|
||||
torch.set_default_dtype(torch.bfloat16)
|
||||
|
|
|
@ -3,7 +3,7 @@ import time
|
|||
import torch
|
||||
from grok1_policy import Grok1ForCausalLMPolicy
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from utils import get_defualt_parser, inference, print_output
|
||||
from utils import get_default_parser, inference, print_output
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
|
@ -13,7 +13,7 @@ from colossalai.lazy import LazyInitContext
|
|||
from colossalai.utils import get_current_device
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = get_defualt_parser()
|
||||
parser = get_default_parser()
|
||||
args = parser.parse_args()
|
||||
start = time.time()
|
||||
colossalai.launch_from_torch({})
|
||||
|
|
|
@ -33,7 +33,7 @@ def inference(model, tokenizer, text, **generate_kwargs):
|
|||
return outputs[0].tolist()
|
||||
|
||||
|
||||
def get_defualt_parser():
|
||||
def get_default_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--pretrained", type=str, default="hpcaitech/grok-1")
|
||||
parser.add_argument("--tokenizer", type=str, default="tokenizer.model")
|
||||
|
|
Loading…
Reference in New Issue