|
|
@ -3,7 +3,7 @@ import time
|
|
|
|
import torch
|
|
|
|
import torch
|
|
|
|
from grok1_policy import Grok1ForCausalLMPolicy
|
|
|
|
from grok1_policy import Grok1ForCausalLMPolicy
|
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
from utils import get_defualt_parser, inference, print_output
|
|
|
|
from utils import get_default_parser, inference, print_output
|
|
|
|
|
|
|
|
|
|
|
|
import colossalai
|
|
|
|
import colossalai
|
|
|
|
from colossalai.booster import Booster
|
|
|
|
from colossalai.booster import Booster
|
|
|
@ -13,7 +13,7 @@ from colossalai.lazy import LazyInitContext
|
|
|
|
from colossalai.utils import get_current_device
|
|
|
|
from colossalai.utils import get_current_device
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
if __name__ == "__main__":
|
|
|
|
parser = get_defualt_parser()
|
|
|
|
parser = get_default_parser()
|
|
|
|
args = parser.parse_args()
|
|
|
|
args = parser.parse_args()
|
|
|
|
start = time.time()
|
|
|
|
start = time.time()
|
|
|
|
colossalai.launch_from_torch({})
|
|
|
|
colossalai.launch_from_torch({})
|
|
|
|