
[example] add grok-1 inference (#5485)

* [misc] add submodule

* remove submodule

* [example] support grok-1 tp inference

* [example] add grok-1 inference script

* [example] refactor code

* [example] add grok-1 readme

* [example] add test ci

* [example] update readme
Hongxin Liu committed 848a574c26 (via GitHub), 8 months ago
  1. examples/language/grok-1/README.md (+43)
  2. examples/language/grok-1/grok1_policy.py (+99)
  3. examples/language/grok-1/inference.py (+32)
  4. examples/language/grok-1/inference_tp.py (+50)
  5. examples/language/grok-1/requirements.txt (+4)
  6. examples/language/grok-1/run_inference_fast.sh (+11)
  7. examples/language/grok-1/run_inference_slow.sh (+11)
  8. examples/language/grok-1/test_ci.sh (+1)
  9. examples/language/grok-1/utils.py (+46)

examples/language/grok-1/README.md
@@ -0,0 +1,43 @@
# Grok-1 Inference

## Install

```bash
# Make sure you install colossalai from the latest source code
git clone https://github.com/hpcaitech/ColossalAI.git
cd ColossalAI
pip install .
cd examples/language/grok-1
pip install -r requirements.txt
```
## Tokenizer preparation

You should download the tokenizer from the official grok-1 repository.

```bash
wget https://github.com/xai-org/grok-1/raw/main/tokenizer.model
```
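As a quick sanity check that the download worked, you can round-trip a short string through the tokenizer (a minimal sketch using the `sentencepiece` package from `requirements.txt`):

```python
# Optional sanity check: tokenizer.model is a plain SentencePiece model,
# so encode/decode should round-trip ordinary text.
from sentencepiece import SentencePieceProcessor

sp = SentencePieceProcessor(model_file="tokenizer.model")
ids = sp.encode("Hello, Grok!")
print(ids)             # a short list of token ids
print(sp.decode(ids))  # should print "Hello, Grok!" again
```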
## Inference

You need 8x A100 80GB or equivalent GPUs to run inference; the bf16 weights of the 314B-parameter model alone occupy roughly 630 GB.

We provide two scripts for inference. `run_inference_fast.sh` uses the tensor parallelism provided by ColossalAI and is faster; `run_inference_slow.sh` uses the automatic device mapping provided by transformers and is slower.

Command format:

```bash
./run_inference_fast.sh <model_name_or_path> <tokenizer_path>
./run_inference_slow.sh <model_name_or_path> <tokenizer_path>
```

`model_name_or_path` can be a local path or a model name on the Hugging Face model hub. We provide weights on the model hub under the name `hpcaitech/grok-1`.

Command example:

```bash
./run_inference_fast.sh hpcaitech/grok-1 tokenizer.model
```

Loading the checkpoint takes 5-10 minutes, so don't worry if the script appears to be stuck.
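If you prefer to fetch the checkpoint ahead of time instead of letting `from_pretrained` download it on the first run, a hedged sketch using `huggingface_hub` (installed as a dependency of transformers; the local directory name `grok-1` is just an example) is:

```python
# Optional: pre-download the several-hundred-GB checkpoint to a local folder,
# then pass that folder as <model_name_or_path> to the run scripts.
from huggingface_hub import snapshot_download

snapshot_download(repo_id="hpcaitech/grok-1", local_dir="grok-1")
```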

examples/language/grok-1/grok1_policy.py
@@ -0,0 +1,99 @@
from typing import Dict, Union

import torch.nn as nn

from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription


class Grok1Policy(Policy):
    def config_sanity_check(self):
        pass

    def preprocess(self) -> nn.Module:
        if self.shard_config.enable_tensor_parallelism:
            vocab_size = self.model.config.vocab_size
            world_size = self.shard_config.tensor_parallel_size
            assert vocab_size % world_size == 0, f"vocab_size {vocab_size} must be divisible by world_size {world_size}"
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        policy = {}
        if self.shard_config.enable_tensor_parallelism:
            decoder_attribute_replacement = {
                "attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                "attn.num_key_value_heads": self.model.config.num_key_value_heads
                // self.shard_config.tensor_parallel_size,
            }
            decoder_submodule_replacement = [
                SubModuleReplacementDescription(
                    suffix="attn.q_proj",
                    target_module=Linear1D_Col,
                ),
                SubModuleReplacementDescription(
                    suffix="attn.k_proj",
                    target_module=Linear1D_Col,
                ),
                SubModuleReplacementDescription(
                    suffix="attn.v_proj",
                    target_module=Linear1D_Col,
                ),
                SubModuleReplacementDescription(
                    suffix="attn.o_proj",
                    target_module=Linear1D_Row,
                ),
            ]
            for i in range(self.model.config.num_experts):
                decoder_submodule_replacement.extend(
                    [
                        SubModuleReplacementDescription(
                            suffix=f"moe_block.experts[{i}].linear",
                            target_module=Linear1D_Col,
                        ),
                        SubModuleReplacementDescription(
                            suffix=f"moe_block.experts[{i}].linear_v",
                            target_module=Linear1D_Col,
                        ),
                        SubModuleReplacementDescription(
                            suffix=f"moe_block.experts[{i}].linear_1",
                            target_module=Linear1D_Row,
                        ),
                    ]
                )
            policy["DecoderLayer"] = ModulePolicyDescription(
                attribute_replacement=decoder_attribute_replacement,
                sub_module_replacement=decoder_submodule_replacement,
            )
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="embed_tokens",
                    target_module=VocabParallelEmbedding1D,
                ),
                policy=policy,
                target_key="Grok1Model",
            )
        return policy

    def postprocess(self):
        return self.model


class Grok1ModelPolicy(Grok1Policy):
    pass


class Grok1ForCausalLMPolicy(Grok1Policy):
    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        policy = super().module_policy()
        self.append_or_create_submodule_replacement(
            description=SubModuleReplacementDescription(
                suffix="lm_head",
                target_module=Linear1D_Col,
                kwargs={"gather_output": not self.shard_config.parallel_output},
            ),
            policy=policy,
            target_key="Grok1ModelForCausalLM",
        )
        return policy
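For context on the sharding choices above: the q/k/v projections and the experts' `linear`/`linear_v` are column-parallel while `o_proj` and `linear_1` are row-parallel because a column-split linear feeding a row-split linear reproduces the dense result after a single reduction, so each attention or MoE block needs only one all-reduce in the forward pass. A toy single-process sketch of that identity (plain PyTorch tensors, not ColossalAI's actual `Linear1D_Col`/`Linear1D_Row` modules):

```python
# Toy illustration of column-parallel -> row-parallel sharding: split the first
# weight by output rows and the second by input columns, then sum the partial
# results (the "all-reduce"); this matches the unsharded computation.
import torch

torch.manual_seed(0)
x = torch.randn(2, 8)        # (batch, hidden)
w1 = torch.randn(16, 8)      # first linear:  8 -> 16, sharded column-parallel
w2 = torch.randn(8, 16)      # second linear: 16 -> 8, sharded row-parallel

dense = x @ w1.t() @ w2.t()  # unsharded reference

w1_shards = w1.chunk(2, dim=0)  # each "rank" keeps a slice of the output dim
w2_shards = w2.chunk(2, dim=1)  # and the matching slice of the input dim
partials = [(x @ a.t()) @ b.t() for a, b in zip(w1_shards, w2_shards)]
sharded = sum(partials)         # stands in for the all-reduce across ranks

assert torch.allclose(dense, sharded, atol=1e-4)
```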

examples/language/grok-1/inference.py
@@ -0,0 +1,32 @@
import time

import torch
from sentencepiece import SentencePieceProcessor
from transformers import AutoModelForCausalLM
from utils import get_default_parser, inference, print_output

if __name__ == "__main__":
    parser = get_default_parser()
    args = parser.parse_args()
    start = time.time()
    torch.set_default_dtype(torch.bfloat16)
    model = AutoModelForCausalLM.from_pretrained(
        args.pretrained,
        trust_remote_code=True,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    sp = SentencePieceProcessor(model_file=args.tokenizer)
    for text in args.text:
        output = inference(
            model,
            sp,
            text,
            max_new_tokens=args.max_new_tokens,
            do_sample=args.do_sample,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
        )
        print_output(text, sp.decode(output))
    print(f"Overall time: {time.time() - start} seconds.")

examples/language/grok-1/inference_tp.py
@@ -0,0 +1,50 @@
import time

import torch
from grok1_policy import Grok1ForCausalLMPolicy
from sentencepiece import SentencePieceProcessor
from transformers import AutoModelForCausalLM
from utils import get_default_parser, inference, print_output

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.utils import get_current_device

if __name__ == "__main__":
    parser = get_default_parser()
    args = parser.parse_args()
    start = time.time()
    colossalai.launch_from_torch({})
    coordinator = DistCoordinator()
    plugin = HybridParallelPlugin(
        tp_size=coordinator.world_size,
        pp_size=1,
        precision="bf16",
        parallel_output=False,
        custom_policy=Grok1ForCausalLMPolicy(),
    )
    booster = Booster(plugin=plugin)
    torch.set_default_dtype(torch.bfloat16)
    with LazyInitContext(default_device=get_current_device()):
        model = AutoModelForCausalLM.from_pretrained(
            args.pretrained, trust_remote_code=True, torch_dtype=torch.bfloat16
        )
    model, *_ = booster.boost(model)
    sp = SentencePieceProcessor(model_file=args.tokenizer)
    for text in args.text:
        output = inference(
            model.unwrap(),
            sp,
            text,
            max_new_tokens=args.max_new_tokens,
            do_sample=args.do_sample,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
        )
        if coordinator.is_master():
            print_output(text, sp.decode(output))
    coordinator.print_on_master(f"Overall time: {time.time() - start} seconds.")

examples/language/grok-1/requirements.txt
@@ -0,0 +1,4 @@
torch>=2.1.0,<2.2.0
colossalai>=0.3.6
sentencepiece==0.1.99
transformers==4.35.0

examples/language/grok-1/run_inference_fast.sh
@@ -0,0 +1,11 @@
#!/usr/bin/env bash

PRETRAINED=${1:-"hpcaitech/grok-1"}
TOKENIZER=${2:-"tokenizer.model"}

torchrun --standalone --nproc_per_node 8 inference_tp.py --pretrained "$PRETRAINED" \
    --tokenizer "$TOKENIZER" \
    --max_new_tokens 64 \
    --text "The company's annual conference, featuring keynote speakers and exclusive product launches, will be held at the Los Angeles Convention Center from October 20th to October 23rd, 2021. Extract the date mentioned in the above sentence." \
    "将以下句子翻译成英语。 我喜欢看电影和读书。" \
    "All books have the same weight, 10 books weigh 5kg, what is the weight of 2 books?"

examples/language/grok-1/run_inference_slow.sh
@@ -0,0 +1,11 @@
#!/usr/bin/env bash

PRETRAINED=${1:-"hpcaitech/grok-1"}
TOKENIZER=${2:-"tokenizer.model"}

python3 inference.py --pretrained "$PRETRAINED" \
    --tokenizer "$TOKENIZER" \
    --max_new_tokens 64 \
    --text "The company's annual conference, featuring keynote speakers and exclusive product launches, will be held at the Los Angeles Convention Center from October 20th to October 23rd, 2021. Extract the date mentioned in the above sentence." \
    "将以下句子翻译成英语。 我喜欢看电影和读书。" \
    "All books have the same weight, 10 books weigh 5kg, what is the weight of 2 books?"

examples/language/grok-1/test_ci.sh
@@ -0,0 +1 @@
pip install -r requirements.txt

examples/language/grok-1/utils.py
@@ -0,0 +1,46 @@
import argparse

import torch


class Bcolors:
    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKCYAN = "\033[96m"
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"
    ENDC = "\033[0m"
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"


def print_output(text, output):
    print(f"-----\n{Bcolors.OKBLUE}{text}{Bcolors.ENDC}{output[len(text):]}")


@torch.no_grad()
def inference(model, sp, text, **generate_kwargs):
    input_ids = sp.encode(text)
    input_ids = torch.tensor([input_ids]).cuda()
    attention_mask = torch.ones_like(input_ids)
    inputs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        **generate_kwargs,
    }
    outputs = model.generate(**inputs)
    return outputs[0].tolist()


def get_default_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained", type=str, default="hpcaitech/grok-1")
    parser.add_argument("--tokenizer", type=str, default="tokenizer.model")
    parser.add_argument("--text", type=str, nargs="+", default=["Hi, what's your name?"])
    parser.add_argument("--max_new_tokens", type=int, default=30)
    parser.add_argument("--do_sample", action="store_true", default=False)
    parser.add_argument("--temperature", type=float, default=0.3, help="Set temperature value")
    parser.add_argument("--top_k", type=int, default=50, help="Set top_k value for top-k-filtering")
    parser.add_argument("--top_p", type=float, default=0.95, help="Set top_p value for generation")
    return parser
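A toy illustration of how `print_output` is meant to be used (no model or GPU needed; the strings below are made up): the decoded generation starts with the prompt, so `output[len(text):]` is exactly the newly generated continuation.

```python
# print_output prints the prompt in blue, then whatever follows it in the
# decoded string; a hard-coded string stands in for sp.decode(...) here.
from utils import print_output

prompt = "Hi, what's your name?"
decoded = prompt + " I'm Grok, nice to meet you."
print_output(prompt, decoded)
```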