diff --git a/.github/workflows/e2e_test.yaml b/.github/workflows/e2e_test.yaml new file mode 100644 index 0000000..155c8ad --- /dev/null +++ b/.github/workflows/e2e_test.yaml @@ -0,0 +1,56 @@ +name: e2e-tests +on: + pull_request: + branches: + - "main" + - "develop" + paths-ignore: + - "doc/**" + - "**.md" +env: + WORKSPACE_PREFIX: $(echo $GITHUB_WORKSPACE |cut -d '/' -f 1-4) + SLURM_PARTITION: llm + +jobs: + check-requirements: + runs-on: [lmtest] + steps: + - name: mask env + run: | + echo "::add-mask::${{env.WORKSPACE_PREFIX}}" + - uses: actions/checkout@v3 + with: + fetch-depth: 2 + - name: check-requirements + run: | + source activate internlm-env-test + changed_files=$(git diff --name-only -r HEAD^1 HEAD) + echo $changed_files + if [[ $changed_files =~ "runtime.txt" ]]; then + pip install -r requirements/runtime.txt + fi + + if [[ $changed_files =~ "torch.txt" ]]; then + pip install -r requirements/torch.txt + fi + + + e2e_tests: + if: ${{ always() }} + needs: check-requirements + runs-on: [lmtest] + timeout-minutes: 30 + steps: + - name: mask env + run: | + echo "::add-mask::${{env.WORKSPACE_PREFIX}}" + - uses: actions/checkout@v3 + + - name: e2e-test + run: | + source activate internlm-env-test + srun -p ${SLURM_PARTITION} --job-name=${GITHUB_RUN_ID}-${GITHUB_JOB} -n8 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_8GPU" ./tests/test_training + srun -p ${SLURM_PARTITION} --job-name=${GITHUB_RUN_ID}-${GITHUB_JOB} -n16 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_16GPU_8DP2TP" ./tests/test_training + srun -p ${SLURM_PARTITION} --job-name=${GITHUB_RUN_ID}-${GITHUB_JOB} -n16 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_16GPU_8DP2TPSP" ./tests/test_training + srun -p ${SLURM_PARTITION} --job-name=${GITHUB_RUN_ID}-${GITHUB_JOB} -n16 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_16GPU_8DP2PP" ./tests/test_training + srun -p ${SLURM_PARTITION} --job-name=${GITHUB_RUN_ID}-${GITHUB_JOB} -n16 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_16GPU_8DP2PP_InterleavedOverlap" ./tests/test_training diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index fa27a2d..a060f47 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -263,6 +263,12 @@ def args_sanity_check(): gpc.config.parallel.sequence_parallel is True and gpc.config.model.use_flash_attn is False ), "sequence parallel does not support use_flash_attn=False" + # currently only interleaved pipeline scheduler with overlap can guarantee loss accuracy + if hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1: + assert ( + gpc.config.parallel["pipeline"].get("interleaved_overlap", False) is True + ), "only support interleaved pipeline scheduler with overlap" + # monitoring default config monitor_default_config = { "alert_address": None, # compatible with old alert config diff --git a/internlm/model/linear.py b/internlm/model/linear.py index 5a3a4eb..d18308a 100644 --- a/internlm/model/linear.py +++ b/internlm/model/linear.py @@ -4,14 +4,13 @@ from typing import Optional import torch -import torch.nn.functional as F from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear from flash_attn.utils.distributed import all_reduce, reduce_scatter from torch import nn from internlm.core.context import ParallelMode from 
internlm.core.context import global_context as gpc -from internlm.model.utils import fused_dense_func_torch +from internlm.model.utils import Silu, fused_dense_func_torch class ScaleColumnParallelLinear(nn.Linear): @@ -197,5 +196,7 @@ class FeedForward(nn.Module): ) def forward(self, x): - out = self.w3(F.silu(self.w1(x)) * self.w2(x)) + w1_o = self.w1(x) + w2_o = self.w2(x) + out = self.w3(Silu(w1_o, w2_o)) return out diff --git a/internlm/model/utils.py b/internlm/model/utils.py index bb887c3..570a86f 100644 --- a/internlm/model/utils.py +++ b/internlm/model/utils.py @@ -225,3 +225,10 @@ def is_norm_param(param: torch.Tensor) -> bool: if hasattr(param, "is_norm") and param.is_norm: return True return False + + +def Silu(w1_o, w2_o): + return F.silu(w1_o) * w2_o + + +Silu = torch.jit.script(Silu) diff --git a/tests/test_training/__init__.py b/tests/test_training/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_training/test_loss.py b/tests/test_training/test_loss.py new file mode 100644 index 0000000..6c9d828 --- /dev/null +++ b/tests/test_training/test_loss.py @@ -0,0 +1,390 @@ +import math +import subprocess + +import pytest +import torch +import torch.distributed as dist + +import internlm +from internlm.core.context import ParallelMode +from internlm.core.context import global_context as gpc +from internlm.core.scheduler import SchedulerMetricHook +from internlm.core.trainer import TrainState +from internlm.initialize import initialize_distributed_env +from internlm.model.loss import FlashGPTLMLoss +from internlm.model.metrics import AccPerplex +from internlm.train import ( + get_train_data_loader, + initialize_model, + initialize_optimizer, + load_new_batch, +) +from internlm.utils.common import BatchSkipper, launch_time +from internlm.utils.gputest import empty_cache_and_diag +from internlm.utils.megatron_timers import megatron_timer as timer +from internlm.utils.model_checkpoint import CheckpointManager + +CONFIG_FILE_PATH = "./configs/7B_sft.py" +TOTAL_STEPS = 10 +LOSS_SPIKE_LIMIT = 1.5 +LOSS_DEVIATION_LIMIT = 0.2 +BASELINE_LOSS_LIST = [ + 11.64188003540039, + 7.9205322265625, + 6.944362163543701, + 6.147305488586426, + 6.060564994812012, + 5.660439491271973, + 5.19430685043335, + 5.157323837280273, + 4.769168376922607, + 4.449280738830566, +] +cur_loss_list = [] + + +def train(): + # initialize distributed environment + initialize_distributed_env(config=CONFIG_FILE_PATH) + assert hasattr(gpc, "config") and gpc.config is not None + + # init setting + gpc.config.data.total_steps = TOTAL_STEPS + gpc.config.lr_scheduler.total_steps = TOTAL_STEPS + total_steps = gpc.config.data.total_steps + skip_batches = gpc.config.data.skip_batches + label_smoothing = gpc.config.loss.label_smoothing + + # get and broadcast current time + current_time = launch_time() + objs = [current_time] + dist.broadcast_object_list(objs, src=0) + current_time = objs[0] + + # initialize model + model = initialize_model() + + # initialize loss function + criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing) + + # initialize the train data loader + train_dl, dataset_types = get_train_data_loader(num_worker=4) + + # initialize and resume train state + train_state = TrainState(gpc.config, train_dl.batch_sampler) + + optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model) + + with open(CONFIG_FILE_PATH, "r") as f: + config_lines = f.readlines() + ckpt_manager = CheckpointManager( + ckpt_config=gpc.config.ckpt, + model=model, + 
optimizer=optimizer, + lr_scheduler=lr_scheduler, + train_dl=train_dl, + model_config=gpc.config.model, + model_config_file="".join(config_lines), + feishu_address=gpc.config.monitor.alert.feishu_alert_address, + ) + + # Loading other persistent training states. + ckpt_manager.try_resume_training(train_state, current_time) + + # initialize metric for calculating accuracy and perplexity + metric = AccPerplex( + device=torch.cuda.current_device(), + tp_pg=gpc.get_group(ParallelMode.TENSOR), + dp_pg=gpc.get_group(ParallelMode.DATA), + dataset_types=dataset_types, + ) + + # initialize trainer + scheduler_hooks = [ + SchedulerMetricHook( + metric=metric, + skip=( + gpc.is_using_pp() + and hasattr(gpc.config.model, "num_chunks") + and gpc.config.model.num_chunks > 1 + and gpc.config.parallel["pipeline"].get("interleaved_overlap", False) + ), + ), + ] + + trainer, train_dl, _, _ = internlm.initialize_trainer( + model=model, + optimizer=optimizer, + criterion=criterion, + train_dataloader=train_dl, + lr_scheduler=lr_scheduler, + beta2_scheduler=beta2_scheduler, + scheduler_hooks=scheduler_hooks, + ) + + # initialize the batch skipper + batch_skipper = BatchSkipper(skip_batches) + + trainer.train() + + # transfer the train data loader into train data iterator + train_iter = iter(train_dl) + + # start iterating the train data and begin training + for batch_count in range(train_state.batch_count, total_steps): + empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval) + timer("one-batch").start() + + # load batch data + batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state) + + # record the consumed samples in training + train_state.batch_count = batch_count + train_state.num_consumed_samples_in_epoch += len(batch[1]) + if batch_skipper(batch_count): # skip this batch + if gpc.is_rank_for_log(): + print(f"Skip batch count:`{batch_count}`...") + timer("one-batch").stop() + continue + + # zero the grads of parameters + trainer.zero_grad() + # process data + if batch[0].get("type_ids", None) is not None: + metric.set_current_type_ids(type_ids=batch[0].pop("type_ids", None)) + + # do forward and backward + timer("fwd-bwd").start() + + _, _, loss = trainer.execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=False) + if gpc.is_rank_for_log(): + assert loss is not None and not math.isnan(loss.item()) + global cur_loss_list + cur_loss_list.append(loss.item()) + timer("fwd-bwd").stop() + + # update parameters, and returns (success_update, grad_norm) + trainer_result = trainer.step() + assert trainer_result is not None + + success_update, _ = trainer_result + assert success_update, "Error: grad norm inf or nan occurs!" + if success_update: # update parameters successfully + train_state.step_count += 1 + else: + train_state.inf_nan_skip_batches += 1 # record the amount of updating parameters unsuccessfully. + + timer("one-batch").stop() + + +def check_loss_spike(): + if gpc.is_rank_for_log(): + for step in range(1, TOTAL_STEPS): + assert ( + cur_loss_list[step] < cur_loss_list[step - 1] * LOSS_SPIKE_LIMIT + ), f"The loss spike occurs, {cur_loss_list[step - 1]}->{cur_loss_list[step]}, please check it!" + + +def check_loss_accuracy(): + if gpc.is_rank_for_log(): + for cur, target in zip(cur_loss_list, BASELINE_LOSS_LIST): + assert ( + abs(cur - target) < LOSS_DEVIATION_LIMIT + ), f"The loss accuracy is abnormal, {target}->{cur}, please check it!" 
+ + +class TestCaseTrain8GPU: + """ + Test cases for Model Training with 8 GPUs. + Parallel Config: + data parallel size = 8. + """ + + @staticmethod + def setup_class(): + # model training + train() + + # print loss value + print(f"cur_loss_list: {cur_loss_list}", flush=True) + + @staticmethod + @pytest.mark.training_8GPU + def test_loss_spike_with_dp8(): + check_loss_spike() + + @staticmethod + @pytest.mark.training_8GPU + def test_loss_accuracy_with_dp8(): + check_loss_accuracy() + + +class TestCaseTrain16GPUWith8DP2TP: + """ + Test cases for Model Training with 16 GPUs. + Parallel Config: + data parallel size = 8. + tensor parallel size = 2. + """ + + @staticmethod + def setup_class(): + # update config tensor parallel size + command = f"sed -i 's/^.*tensor=.*/ tensor=2,/' {CONFIG_FILE_PATH}" + subprocess.run(command, shell=True, check=True) + + # model training + train() + + # print loss value + print(f"cur_loss_list: {cur_loss_list}", flush=True) + + @staticmethod + @pytest.mark.training_16GPU_8DP2TP + def test_loss_spike_with_dp8_tp2(): + check_loss_spike() + + @staticmethod + @pytest.mark.training_16GPU_8DP2TP + def test_loss_accuracy_with_dp8_tp2(): + check_loss_accuracy() + + +class TestCaseTrain16GPUWith8DP2TPSP: + """ + Test cases for Model Training with 16 GPUs. + Parallel Config: + data parallel size = 8. + tensor parallel size = 2. + sequence parallel = True. + """ + + @staticmethod + def setup_class(): + # update config tensor parallel size and sequence parallel + command = f"sed -i 's/^.*tensor=.*/ tensor=2,/' {CONFIG_FILE_PATH}" + subprocess.run(command, shell=True, check=True) + command = f"sed -i 's/^.*sequence_parallel=.*/ sequence_parallel=True,/' {CONFIG_FILE_PATH}" + subprocess.run(command, shell=True, check=True) + + # model training + train() + + # print loss value + print(f"cur_loss_list: {cur_loss_list}", flush=True) + + @staticmethod + @pytest.mark.training_16GPU_8DP2TPSP + def test_loss_spike_with_dp8_tp2_sp(): + check_loss_spike() + + @staticmethod + @pytest.mark.training_16GPU_8DP2TPSP + def test_loss_accuracy_with_dp8_tp2_sp(): + check_loss_accuracy() + + +class TestCaseTrain16GPUWith8DP2PP: + """ + Test cases for Model Training with 16 GPUs. + Parallel Config: + data parallel size = 8. + pipeline parallel size = 2. + """ + + @staticmethod + def setup_class(): + # update config pipeline parallel size + command = f"sed -i 's/^.*pipeline=.*/ pipeline=dict(size=2),/' {CONFIG_FILE_PATH}" + subprocess.run(command, shell=True, check=True) + command = f"sed -i 's/^.*tensor=.*/ tensor=1,/' {CONFIG_FILE_PATH}" + subprocess.run(command, shell=True, check=True) + + # model training + train() + + # print loss value + print(f"cur_loss_list: {cur_loss_list}", flush=True) + + @staticmethod + @pytest.mark.training_16GPU_8DP2PP + def test_loss_spike_with_dp8_pp2(): + check_loss_spike() + + @staticmethod + @pytest.mark.training_16GPU_8DP2PP + def test_loss_accuracy_with_dp8_pp2(): + check_loss_accuracy() + + +class TestCaseTrain16GPUWith8DP2PPInterleaved: + """ + Test cases for Model Training with 16 GPUs. + Parallel Config: + data parallel size = 8. + pipeline parallel size = 2. + interleaved scheduler = True. 
+ """ + + @staticmethod + def setup_class(): + # update config pipeline parallel size + command = f"sed -i 's/^.*pipeline=.*/ pipeline=dict(size=2),/' {CONFIG_FILE_PATH}" + subprocess.run(command, shell=True, check=True) + command = f"sed -i 's/^.*num_chunks=.*/ num_chunks=2,/' {CONFIG_FILE_PATH}" + subprocess.run(command, shell=True, check=True) + command = f"sed -i 's/^.*tensor=.*/ tensor=1,/' {CONFIG_FILE_PATH}" + subprocess.run(command, shell=True, check=True) + + # model training + train() + + # print loss value + print(f"cur_loss_list: {cur_loss_list}", flush=True) + + @staticmethod + @pytest.mark.training_16GPU_8DP2PP_Interleaved + def test_loss_spike_with_dp8_pp2_interleaved(): + check_loss_spike() + + @staticmethod + @pytest.mark.training_16GPU_8DP2PP_Interleaved + def test_loss_accuracy_with_dp8_pp2_interleaved(): + check_loss_accuracy() + + +class TestCaseTrain16GPUWith8DP2PPInterleavedOverlap: + """ + Test cases for Model Training with 16 GPUs. + Parallel Config: + data parallel size = 8. + pipeline parallel size = 2. + interleaved scheduler = True. + interleaved overlap = True. + """ + + @staticmethod + def setup_class(): + # update config pipeline parallel size + command = f"sed -i 's/^.*pipeline=.*/ pipeline=dict(size=2, interleaved_overlap=True),/' {CONFIG_FILE_PATH}" + subprocess.run(command, shell=True, check=True) + command = f"sed -i 's/^.*num_chunks=.*/ num_chunks=2,/' {CONFIG_FILE_PATH}" + subprocess.run(command, shell=True, check=True) + command = f"sed -i 's/^.*tensor=.*/ tensor=1,/' {CONFIG_FILE_PATH}" + subprocess.run(command, shell=True, check=True) + + # model training + train() + + # print loss value + print(f"cur_loss_list: {cur_loss_list}", flush=True) + + @staticmethod + @pytest.mark.training_16GPU_8DP2PP_InterleavedOverlap + def test_loss_spike_with_dp8_pp2_interleaved_overlap(): + check_loss_spike() + + @staticmethod + @pytest.mark.training_16GPU_8DP2PP_InterleavedOverlap + def test_loss_accuracy_with_dp8_pp2_interleaved_overlap(): + check_loss_accuracy() diff --git a/tools/README.md b/tools/README.md index 0c78a56..8e42e78 100644 --- a/tools/README.md +++ b/tools/README.md @@ -109,3 +109,29 @@ InternLM 在 GSM8K 数据集中带工具和不带工具的性能表现: | -------- | -------------------- | | w/o tool | 34.5 | | w tool | 39.2 | + +# openai_api.py + +使用 OpenAI 接口实现的流式部署，可以应用于基于 ChatGPT 的应用的后端。部署的命令为： + +```bash +python openai_api.py +``` + +然后可以通过下面代码调用部署好的 api： + +```python +import openai +if __name__ == "__main__": + openai.api_base = "http://localhost:8000/internlm" + openai.api_key = "none" + for chunk in openai.ChatCompletion.create( + model="internlm-chat-7b", + messages=[ + {"role": "user", "content": "你好"}, + ], + stream=True + ): + if hasattr(chunk.choices[0].delta, "content"): + print(chunk.choices[0].delta.content, end="", flush=True) +``` \ No newline at end of file diff --git a/tools/README_EN.md b/tools/README_EN.md index 3105146..8c7e005 100644 --- a/tools/README_EN.md +++ b/tools/README_EN.md @@ -107,3 +107,29 @@ InternLM performance in the GSM8K dataset with and without tools: | -------- | -------------------- | | w/o tool | 34.5 | | w tool | 39.2 | + +# openai_api.py + +`openai_api.py` implements a streaming deployment using the OpenAI API, which can be used as the backend for any application based on ChatGPT.
Below is the command to deploy `internlm`: + +```bash +python openai_api.py +``` + +Then the deployed API can be called with the following Python code: + +```python +import openai +if __name__ == "__main__": + openai.api_base = "http://localhost:8000/internlm" + openai.api_key = "none" + for chunk in openai.ChatCompletion.create( + model="internlm-chat-7b", + messages=[ + {"role": "user", "content": "Hello!"}, + ], + stream=True + ): + if hasattr(chunk.choices[0].delta, "content"): + print(chunk.choices[0].delta.content, end="", flush=True) +``` diff --git a/tools/openai_api.py b/tools/openai_api.py new file mode 100644 index 0000000..f853329 --- /dev/null +++ b/tools/openai_api.py @@ -0,0 +1,157 @@ +import time +from contextlib import asynccontextmanager +from typing import List, Literal, Optional, Union + +import torch +import uvicorn +from fastapi import FastAPI, HTTPException +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel, Field +from sse_starlette.sse import EventSourceResponse +from transformers import AutoModelForCausalLM, AutoTokenizer + + +@asynccontextmanager +async def lifespan(app: FastAPI): # collects GPU memory + yield + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + + +app = FastAPI(lifespan=lifespan) + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +class ModelCard(BaseModel): + id: str + object: str = "model" + created: int = Field(default_factory=lambda: int(time.time())) + owned_by: str = "owner" + root: Optional[str] = None + parent: Optional[str] = None + permission: Optional[list] = None + + +class ModelList(BaseModel): + object: str = "list" + data: List[ModelCard] = [] + + +class ChatMessage(BaseModel): + role: Literal["user", "assistant", "system"] + content: str + + +class DeltaMessage(BaseModel): + role: Optional[Literal["user", "assistant", "system"]] = None + content: Optional[str] = None + + +class ChatCompletionRequest(BaseModel): + model: str + messages: List[ChatMessage] + temperature: Optional[float] = None + top_p: Optional[float] = None + max_length: Optional[int] = None + stream: Optional[bool] = False + + +class ChatCompletionResponseChoice(BaseModel): + index: int + message: ChatMessage + finish_reason: Literal["stop", "length"] + + +class ChatCompletionResponseStreamChoice(BaseModel): + index: int + delta: DeltaMessage + finish_reason: Optional[Literal["stop", "length"]] + + +class ChatCompletionResponse(BaseModel): + model: str + object: Literal["chat.completion", "chat.completion.chunk"] + choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]] + created: Optional[int] = Field(default_factory=lambda: int(time.time())) + + +@app.get("/internlm/models", response_model=ModelList) +async def list_models(): + model_card = ModelCard(id="internlm") + return ModelList(data=[model_card]) + + +@app.post("/internlm/chat/completions", response_model=ChatCompletionResponse) +async def create_chat_completion(request: ChatCompletionRequest): + global model, tokenizer + + if request.messages[-1].role != "user": + raise HTTPException(status_code=400, detail="Invalid request") + query = request.messages[-1].content + + prev_messages = request.messages[:-1] + if len(prev_messages) > 0 and prev_messages[0].role == "system": + query = prev_messages.pop(0).content + query + + history = [] + if len(prev_messages) % 2 == 0: + for i in range(0, len(prev_messages), 2): + 
if prev_messages[i].role == "user" and prev_messages[i + 1].role == "assistant": + history.append([prev_messages[i].content, prev_messages[i + 1].content]) + + if request.stream: + generate = predict(query, history, request.model) + return EventSourceResponse(generate, media_type="text/event-stream") + + response, _ = model.chat(tokenizer, query, history=history) + choice_data = ChatCompletionResponseChoice( + index=0, message=ChatMessage(role="assistant", content=response), finish_reason="stop" + ) + + return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion") + + +async def predict(query: str, history: List[List[str]], model_id: str): + global model, tokenizer + + choice_data = ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage(role="assistant"), finish_reason=None) + chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") + yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False)) + + current_length = 0 + + for new_response, _ in model.stream_chat(tokenizer, query, history): + if len(new_response) == current_length: + continue + + new_text = new_response[current_length:] + + current_length = len(new_response) + + choice_data = ChatCompletionResponseStreamChoice( + index=0, delta=DeltaMessage(content=new_text), finish_reason=None + ) + chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") + yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False)) + + choice_data = ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage(), finish_reason="stop") + chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") + yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False)) + yield "[DONE]" + + +if __name__ == "__main__": + model_name = "internlm/internlm-chat-7b" + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True) + model.eval() + + uvicorn.run(app, host="0.0.0.0", port=8000, workers=1) diff --git a/tools/transformers/modeling_internlm.py b/tools/transformers/modeling_internlm.py index da7aaa0..5439ba7 100644 --- a/tools/transformers/modeling_internlm.py +++ b/tools/transformers/modeling_internlm.py @@ -869,7 +869,7 @@ class InternLMForCausalLM(InternLMPreTrainedModel): producer.start() while True: res = response_queue.get() - if res is not None: + if res is None: return yield res
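
Reviewer note, not part of the patch: the new branch in `args_sanity_check` only accepts `num_chunks > 1` when the pipeline section enables `interleaved_overlap`. Below is a minimal sketch of a configuration that passes the check; only the fields the e2e tests edit via `sed` (`tensor`, `pipeline`, `num_chunks`, `sequence_parallel`) are shown, and the surrounding structure is assumed to follow `configs/7B_sft.py`.

```python
# Hypothetical excerpt of a training config (e.g. configs/7B_sft.py) that satisfies the
# added assertion: with more than one pipeline chunk, the interleaved scheduler must
# run with communication/computation overlap enabled.
model = dict(
    num_chunks=2,  # > 1 selects the interleaved pipeline scheduler
    # ... remaining model fields unchanged ...
)

parallel = dict(
    tensor=1,
    pipeline=dict(size=2, interleaved_overlap=True),  # required when num_chunks > 1
    sequence_parallel=False,
)
```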
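
Likewise, a small standalone sketch (plain `torch`, toy shapes chosen only for illustration) of the SwiGLU refactor: it checks that the jit-scripted `Silu` helper added to `internlm/model/utils.py` reproduces the `F.silu(self.w1(x)) * self.w2(x)` expression that `FeedForward.forward` computed before this change.

```python
import torch
import torch.nn.functional as F


def Silu(w1_o, w2_o):
    # same definition as the helper added in internlm/model/utils.py
    return F.silu(w1_o) * w2_o


Silu = torch.jit.script(Silu)

if __name__ == "__main__":
    torch.manual_seed(0)
    w1 = torch.nn.Linear(16, 64, bias=False)  # toy shapes, illustration only
    w2 = torch.nn.Linear(16, 64, bias=False)
    x = torch.randn(4, 16)

    old = F.silu(w1(x)) * w2(x)  # original FeedForward.forward expression
    new = Silu(w1(x), w2(x))     # refactored path exercised by this patch
    assert torch.allclose(old, new, atol=1e-6)
    print("jit-scripted Silu matches the original formulation")
```

Scripting the element-wise SiLU-and-multiply presumably lets TorchScript fuse the two ops into a single kernel, which appears to be the motivation for factoring it out of `FeedForward.forward`.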