[example] update Grok-1 inference (#5495)

* revise grok-1 example

* remove unused arg in scripts

* prevent re-installing torch

* update readme

* revert modifying colossalai requirements

* add perf

* trivial

* add tokenizer url
pull/5480/head
Yuanheng Zhao 2024-03-24 20:24:11 +08:00 committed by GitHub
parent 6df844b8c4
commit 5fcd7795cd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 69 additions and 43 deletions

View File

@ -5,7 +5,7 @@ An easy-to-use Python + PyTorch + HuggingFace version of 314B Grok-1.
[[blog]](https://hpc-ai.com/blog/grok-1-of-pytorch-huggingface-version-is-now-available)
[[HuggingFace Grok-1 PyTorch model weights]](https://huggingface.co/hpcai-tech/grok-1)
## Install
## Installation
```bash
# Make sure you install colossalai from the latest source code
@ -16,33 +16,36 @@ cd examples/language/grok-1
pip install -r requirements.txt
```
## Tokenizer preparation
You should download the tokenizer from the official grok-1 repository.
```bash
wget https://github.com/xai-org/grok-1/raw/main/tokenizer.model
```
## Inference
You need 8x A100 80GB or equivalent GPUs to run the inference.
We provide two scripts for inference. `run_inference_fast.sh` uses tensor parallelism provided by ColossalAI, and it is faster. `run_inference_slow.sh` uses auto device provided by transformers, and it is slower.
Command format:
```bash
./run_inference_fast.sh <model_name_or_path> <tokenizer_path>
./run_inference_slow.sh <model_name_or_path> <tokenizer_path>
```
`model_name_or_path` can be a local path or a model name from Hugging Face model hub. We provided weights on model hub, named `hpcaitech/grok-1`.
We provide two scripts for inference. `run_inference_fast.sh` uses tensor parallelism provided by ColossalAI, which is faster for generation, while `run_inference_slow.sh` uses auto device provided by transformers, which is relatively slower.
Command example:
```bash
./run_inference_fast.sh hpcaitech/grok-1 tokenizer.model
./run_inference_fast.sh <MODEL_NAME_OR_PATH>
./run_inference_slow.sh <MODEL_NAME_OR_PATH>
```
It will take 5-10 minutes to load checkpoints. Don't worry, it's not stuck.
`MODEL_NAME_OR_PATH` can be a model name from Hugging Face model hub or a local path to PyTorch-version model checkpoints. We provided weights on model hub, named `hpcaitech/grok-1`. And you could also download the weights in advance using `git`:
```bash
git lfs install
git clone https://huggingface.co/hpcai-tech/grok-1
```
It will take, depending on your Internet speed, several hours to tens of hours to download checkpoints (about 600G!), and 5-10 minutes to load checkpoints when it's ready to launch the inference. Don't worry, it's not stuck.
## Performance
For request of batch size set to 1 and maximum length set to 100:
| Method | Initialization-Duration(sec) | Average-Generation-Latency(sec) |
|-------------------------|------------------------------|---------------------------------|
| ColossalAI | 431.45 | 14.92 |
| HuggingFace Auto-Device | 426.96 | 48.38 |
| JAX | 147.61 | 56.25 |
Tested on 8x80G NVIDIA H800.

View File

@ -1,8 +1,7 @@
import time
import torch
from sentencepiece import SentencePieceProcessor
from transformers import AutoModelForCausalLM
from transformers import AutoModelForCausalLM, LlamaTokenizerFast
from utils import get_defualt_parser, inference, print_output
if __name__ == "__main__":
@ -16,11 +15,17 @@ if __name__ == "__main__":
device_map="auto",
torch_dtype=torch.bfloat16,
)
sp = SentencePieceProcessor(model_file=args.tokenizer)
model.eval()
init_time = time.time() - start
# A transformers-compatible version of the grok-1 tokenizer by Xenova
# https://huggingface.co/Xenova/grok-1-tokenizer
tokenizer = LlamaTokenizerFast.from_pretrained("Xenova/grok-1-tokenizer")
for text in args.text:
output = inference(
model,
sp,
tokenizer,
text,
max_new_tokens=args.max_new_tokens,
do_sample=args.do_sample,
@ -28,5 +33,14 @@ if __name__ == "__main__":
top_k=args.top_k,
top_p=args.top_p,
)
print_output(text, sp.decode(output))
print(f"Overall time: {time.time() - start} seconds.")
print_output(text, tokenizer.decode(output))
overall_time = time.time() - start
gen_latency = overall_time - init_time
avg_gen_latency = gen_latency / len(args.text)
print(
f"Initializing time: {init_time:.2f} seconds.\n"
f"Overall time: {overall_time:.2f} seconds. \n"
f"Generation latency: {gen_latency:.2f} seconds. \n"
f"Average generation latency: {avg_gen_latency:.2f} seconds. \n"
)

View File

@ -2,8 +2,7 @@ import time
import torch
from grok1_policy import Grok1ForCausalLMPolicy
from sentencepiece import SentencePieceProcessor
from transformers import AutoModelForCausalLM
from transformers import AutoModelForCausalLM, LlamaTokenizerFast
from utils import get_defualt_parser, inference, print_output
import colossalai
@ -33,11 +32,17 @@ if __name__ == "__main__":
args.pretrained, trust_remote_code=True, torch_dtype=torch.bfloat16
)
model, *_ = booster.boost(model)
sp = SentencePieceProcessor(model_file=args.tokenizer)
model.eval()
init_time = time.time() - start
# A transformers-compatible version of the grok-1 tokenizer by Xenova
# https://huggingface.co/Xenova/grok-1-tokenizer
tokenizer = LlamaTokenizerFast.from_pretrained("Xenova/grok-1-tokenizer")
for text in args.text:
output = inference(
model.unwrap(),
sp,
tokenizer,
text,
max_new_tokens=args.max_new_tokens,
do_sample=args.do_sample,
@ -46,5 +51,14 @@ if __name__ == "__main__":
top_p=args.top_p,
)
if coordinator.is_master():
print_output(text, sp.decode(output))
coordinator.print_on_master(f"Overall time: {time.time() - start} seconds.")
print_output(text, tokenizer.decode(output))
overall_time = time.time() - start
gen_latency = overall_time - init_time
avg_gen_latency = gen_latency / len(args.text)
coordinator.print_on_master(
f"Initializing time: {init_time:.2f} seconds.\n"
f"Overall time: {overall_time:.2f} seconds. \n"
f"Generation latency: {gen_latency:.2f} seconds. \n"
f"Average generation latency: {avg_gen_latency:.2f} seconds. \n"
)

View File

@ -1,4 +1,3 @@
torch>=2.1.0,<2.2.0
colossalai>=0.3.6
sentencepiece==0.1.99
transformers==4.35.0

View File

@ -1,11 +1,9 @@
#!/usr/bin/env bash
PRETRAINED=${1:-"hpcaitech/grok-1"}
TOKENIZER=${2:-"tokenizer.model"}
torchrun --standalone --nproc_per_node 8 inference_tp.py --pretrained "$PRETRAINED" \
--tokenizer "$TOKENIZER" \
--max_new_tokens 64 \
--max_new_tokens 100 \
--text "The company's annual conference, featuring keynote speakers and exclusive product launches, will be held at the Los Angeles Convention Center from October 20th to October 23rd, 2021. Extract the date mentioned in the above sentence." \
"将以下句子翻译成英语。 我喜欢看电影和读书。" \
"All books have the same weight, 10 books weigh 5kg, what is the weight of 2 books?"

View File

@ -1,11 +1,9 @@
#!/usr/bin/env bash
PRETRAINED=${1:-"hpcaitech/grok-1"}
TOKENIZER=${2:-"tokenizer.model"}
python3 inference.py --pretrained "$PRETRAINED" \
--tokenizer "$TOKENIZER" \
--max_new_tokens 64 \
--max_new_tokens 100 \
--text "The company's annual conference, featuring keynote speakers and exclusive product launches, will be held at the Los Angeles Convention Center from October 20th to October 23rd, 2021. Extract the date mentioned in the above sentence." \
"将以下句子翻译成英语。 我喜欢看电影和读书。" \
"All books have the same weight, 10 books weigh 5kg, what is the weight of 2 books?"

View File

@ -20,9 +20,9 @@ def print_output(text, output):
@torch.no_grad()
def inference(model, sp, text, **generate_kwargs):
input_ids = sp.encode(text)
input_ids = torch.tensor([input_ids]).cuda()
def inference(model, tokenizer, text, **generate_kwargs):
input_ids = tokenizer(text, return_tensors="pt").input_ids
input_ids = input_ids.cuda()
attention_mask = torch.ones_like(input_ids)
inputs = {
"input_ids": input_ids,