[Dev] Pull Main (#139)

* fix/fix_submodule_err (#61)

* fix/fix_submodule_err

---------

Co-authored-by: ChenQiaoling00 <qiaoling_chen@u.nus.edu>

* fix issue templates (#65)

* fix(tokenizer): refactor tokenizer and update usage in readme (#51)

* update tokenizer example

* fix(readme, requirements): fix typo at Chinese readme and select a lower version of transformers (#73)

* fix a typo in readme

* in order to find InternLMTokenizer, select a lower version of Transformers

---------

Co-authored-by: gouhchangjiang <gouhchangjiang@gmail.com>

* [Doc] Add wechat and discord link in readme (#78)

* Doc:add wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* Doc:update wechat and discord link

* [Docs]: add Japanese README (#43)

* Add Japanese README

* Update README-ja-JP.md

replace message

* Update README-ja-JP.md

* add repetition_penalty in GenerationConfig in web_demo.py (#48)

Co-authored-by: YWMditto <862779238@qq.com>
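
  For context, a minimal sketch of passing `repetition_penalty` through a `transformers` `GenerationConfig` when chatting with the model; the concrete values below are illustrative and not taken from the PR (in `transformers`, `repetition_penalty=1.0` is the neutral setting that applies no penalty):

  ```python
  from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

  tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True)
  model = AutoModelForCausalLM.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True).cuda().eval()

  # Illustrative values: repetition_penalty > 1.0 discourages repeated tokens,
  # while 1.0 leaves the raw sampling distribution untouched.
  generation_config = GenerationConfig(
      max_new_tokens=256,
      do_sample=True,
      temperature=0.8,
      top_p=0.8,
      repetition_penalty=1.02,
  )

  inputs = tokenizer("你好", return_tensors="pt").to(model.device)
  output = model.generate(**inputs, generation_config=generation_config)
  print(tokenizer.decode(output[0], skip_special_tokens=True))
  ```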

* use fp16 in instruction (#80)
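
  A hedged sketch of the kind of fp16 loading the updated instructions point to, using the standard `transformers` `torch_dtype` keyword (not copied from the PR):

  ```python
  import torch
  from transformers import AutoModelForCausalLM, AutoTokenizer

  tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True)
  # Loading the weights in fp16 roughly halves GPU memory use compared to fp32.
  model = AutoModelForCausalLM.from_pretrained(
      "internlm/internlm-chat-7b",
      torch_dtype=torch.float16,
      trust_remote_code=True,
  ).cuda().eval()
  ```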

* [Enhancement] add more options for issue template (#77)

* [Enhancement] add more options for issue template

* update question icon

* fix link

* Use tempfile for convert2hf.py (#23)

Fix https://github.com/InternLM/InternLM/issues/50
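
  A rough illustration of the approach, assuming the conversion can write to a scratch directory first; `convert_checkpoint` and its signature are hypothetical, not the repository's actual function:

  ```python
  import shutil
  import tempfile

  def convert_via_tempdir(convert_checkpoint, tgt_folder: str) -> None:
      # Write intermediate files into an auto-cleaned temporary directory,
      # then copy only the finished checkpoint into the requested target folder.
      with tempfile.TemporaryDirectory() as tmp_dir:
          convert_checkpoint(tmp_dir)
          shutil.copytree(tmp_dir, tgt_folder, dirs_exist_ok=True)
  ```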

* delete torch_dtype of README's example code (#100)

* set the value of repetition_penalty to 1.0 to avoid random outputs (#99)

* Update web_demo.py (#97)

Remove meaningless log.

* [Fix]Fix wrong string cutoff in the script for sft text tokenizing (#106)

* docs(install.md): update dependency package transformers version to >= 4.28.0 (#124)

Co-authored-by: 黄婷 <huangting3@CN0014010744M.local>

* docs(LICENSE): add license (#125)

* add license of colossalai and flash-attn

* fix lint

* modify the name

* fix AutoModel map in convert2hf.py (#116)

* variables are not printed as expected (#114)
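
  The underlying pitfall, shown as a minimal standalone example (not the repository's exact code): when adjacent string literals are concatenated, only the literal prefixed with `f` is interpolated, so placeholders in the second literal are printed verbatim:

  ```python
  train_samples, train_tokens = 100, 200_000

  # Only the first literal is an f-string; "{train_tokens}" is printed literally.
  broken = f"number of train dataset: {train_samples}, " "number of train dataset token: {train_tokens}"
  # A single f-string interpolates both variables as intended.
  fixed = f"number of train dataset: {train_samples}, number of train dataset token: {train_tokens}"

  print(broken)  # ... number of train dataset token: {train_tokens}
  print(fixed)   # ... number of train dataset token: 200000
  ```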

* feat(solver): fix code to adapt to torch2.0 and provide docker images (#128)

* feat(solver): fix code to adapt to torch2.0

* docs(install.md): publish internlm environment image

* docs(install.md): update dependency packages version

* docs(install.md): update default image

---------

Co-authored-by: 黄婷 <huangting3@CN0014010744M.local>

* add demo test (#132)

Co-authored-by: qa-caif-cicd <qa-caif-cicd@pjlab.org.cn>

* fix web_demo cache accelerate (#133)

* fix(hybrid_zero_optim.py): delete math import

* Update embedding.py

---------

Co-authored-by: ChenQiaoling00 <qiaoling_chen@u.nus.edu>
Co-authored-by: Kai Chen <chenkaidev@gmail.com>
Co-authored-by: Yang Gao <Gary1546308416AL@gmail.com>
Co-authored-by: Changjiang GOU <gouchangjiang@gmail.com>
Co-authored-by: gouhchangjiang <gouhchangjiang@gmail.com>
Co-authored-by: vansin <msnode@163.com>
Co-authored-by: Ikko Eltociear Ashimine <eltociear@gmail.com>
Co-authored-by: YWMditto <46778265+YWMditto@users.noreply.github.com>
Co-authored-by: YWMditto <862779238@qq.com>
Co-authored-by: WRH <12756472+wangruohui@users.noreply.github.com>
Co-authored-by: liukuikun <24622904+Harold-lkk@users.noreply.github.com>
Co-authored-by: x54-729 <45304952+x54-729@users.noreply.github.com>
Co-authored-by: Shuo Zhang <zhangshuolove@live.com>
Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>
Co-authored-by: huangting4201 <1538303371@qq.com>
Co-authored-by: 黄婷 <huangting3@CN0014010744M.local>
Co-authored-by: ytxiong <45058324+yingtongxiong@users.noreply.github.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: kkscilife <126147887+kkscilife@users.noreply.github.com>
Co-authored-by: qa-caif-cicd <qa-caif-cicd@pjlab.org.cn>
Co-authored-by: hw <45089338+MorningForest@users.noreply.github.com>
pull/147/head
Sun Peng 2023-07-27 10:20:21 +08:00 committed by GitHub
parent ad10b8e03f
commit fcc3534509
18 changed files with 428 additions and 16 deletions

.github/workflows/demo_in_readme.yaml (new file)

@@ -0,0 +1,68 @@
name: demo-in-readme
on:
  pull_request:
    branches:
      - "main"
      - "develop"
    paths-ignore:
      - "docs/**"
      - "**.md"

jobs:
  dataset-preparation:
    runs-on: [lmtest]
    steps:
      - uses: actions/checkout@v3

      - name: raw-chinese-data
        run: |
          source activate internlm-env-test
          sh ./ci_scripts/data/tokenizer_chinese.sh

      - name: alpaca-data
        run: |
          source activate internlm-env-test
          sh ./ci_scripts/data/tokenizer_alpaca.sh

  train:
    runs-on: [lmtest]
    steps:
      - uses: actions/checkout@v3

      - name: slurm-train
        run: |
          source activate internlm-env-test
          sh ./ci_scripts/train/slurm_train.sh
          rm -rf $GITHUB_WORKSPACE/llm_ckpts

      - name: torchrun-train
        run: |
          source activate internlm-env-test
          sh ./ci_scripts/train/torchrun.sh
          rm -rf $GITHUB_WORKSPACE/llm_ckpts

  convert-model-then-load:
    runs-on: [lmtest]
    steps:
      - uses: actions/checkout@v3

      - name: convert-model-then-load
        run: |
          source activate internlm-env-test
          export PYTHONPATH=$PWD:$PYTHONPATH
          sh ./ci_scripts/model/convert_to_hf.sh
          cd ./hf_ckpt
          srun -p llm2 python ../ci_scripts/model/loaded_as_transformer.py
          cd ..
          rm -rf $GITHUB_WORKSPACE/hf_ckpt

  load-chat-model-in-hf:
    runs-on: [lmtest]
    steps:
      - uses: actions/checkout@v3

      - name: chat-model-in-hf
        run: |
          source activate internlm-env-test
          srun -p llm2 python ./ci_scripts/model/demo_load_7B_chat_model.py

LICENSE

@@ -199,3 +199,51 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
## Some of InternLM's code is derived from other projects, which are subject to the following copyright notices:
Copyright 2021- HPC-AI Technology Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
---------------- LICENSE FOR Flash Attention ----------------
BSD 3-Clause License
Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

ci_scripts/common/basic_func.sh (new file)

@@ -0,0 +1,14 @@
#!/bin/bash
export exit_code=0
function if_exist() {
ls -l $file_path
exit_code_now=$?
exit_code=$(($exit_code + $exit_code_now))
}
function num_files() {
file_num=$(ls -l $file_dir |wc -l)
echo "there are $file_num files in $file_dir"
}

ci_scripts/data/tokenizer_alpaca.sh (new file)

@@ -0,0 +1,22 @@
#!/bin/bash
rm -rf /mnt/petrelfs/qa-caif-cicd/data/lm_data/alpaca_data/result/*
python tools/alpaca_tokenizer.py /mnt/petrelfs/qa-caif-cicd/data/lm_data/alpaca_data/alpaca_data.json /mnt/petrelfs/qa-caif-cicd/data/lm_data/alpaca_data/result tools/V7_sft.model --split_ratio 0.1
file_one="/mnt/petrelfs/qa-caif-cicd/data/lm_data/alpaca_data/result/train/en/dataset.bin"
file_two="/mnt/petrelfs/qa-caif-cicd/data/lm_data/alpaca_data/result/train/en/dataset.bin.meta"
file_three="/mnt/petrelfs/qa-caif-cicd/data/lm_data/alpaca_data/result/valid/en/dataset.bin"
file_four="/mnt/petrelfs/qa-caif-cicd/data/lm_data/alpaca_data/result/valid/en/dataset.bin.meta"
file_list=($file_one $file_two $file_three $file_four)
source ./ci_scripts/common/basic_func.sh
for file_path in ${file_list[@]};
do
if_exist $file_path
done
if [ $exit_code -ne 0 ]
then
exit 1
fi

ci_scripts/data/tokenizer_chinese.sh (new file)

@@ -0,0 +1,19 @@
#!/bin/bash
rm -rf /mnt/petrelfs/qa-caif-cicd/data/lm_data/cn_data/result.*
srun -p llm2 python tools/tokenizer.py --text_input_path /mnt/petrelfs/qa-caif-cicd/data/lm_data/cn_data/raw_data.txt --bin_output_path /mnt/petrelfs/qa-caif-cicd/data/lm_data/cn_data/result.bin
file_one="/mnt/petrelfs/qa-caif-cicd/data/lm_data/cn_data/result.bin"
file_two="/mnt/petrelfs/qa-caif-cicd/data/lm_data/cn_data/result.bin.meta"
file_list=($file_one $file_two)
source ./ci_scripts/common/basic_func.sh
for file_path in ${file_list[@]};
do
if_exist $file_path
done
if [ $exit_code -ne 0 ]
then
exit 1
fi

ci_scripts/model/convert_to_hf.sh (new file)

@@ -0,0 +1,33 @@
#!/bin/bash
rm -rf ./hf_ckpt/*
python ./tools/transformers/convert2hf.py --src_folder /mnt/petrelfs/qa-caif-cicd/data/lm_data/alpaca_data/llm_ckpts/20 --tgt_folder hf_ckpt/ --tokenizer ./tools/V7_sft.model
# check that the expected model files exist
file_one="$GITHUB_WORKSPACE/hf_ckpt/tokenizer.model"
file_two="$GITHUB_WORKSPACE/hf_ckpt/config.json"
file_three="$GITHUB_WORKSPACE/hf_ckpt/modeling_internlm.py"
file_list=($file_one $file_two $file_three)
file_dir="$GITHUB_WORKSPACE/hf_ckpt/*"
source ./ci_scripts/common/basic_func.sh
for file_path in ${file_list[@]};
do
if_exist $file_path
done
num_files ${file_dir}
if [ $file_num -ne 9 ]
then
echo "The num of files is not right"
ls -l $file_dir
exit_code=$(($exit_code + 1))
fi
if [ $exit_code -ne 0 ]
then
exit 1
fi

ci_scripts/model/demo_load_7B_chat_model.py (new file)

@@ -0,0 +1,12 @@
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True).cuda()
model = model.eval()
response, history = model.chat(tokenizer, "你好", history=[])
print(response)
assert len(response) != 0
response, history = model.chat(tokenizer, "请提供三个管理时间的建议。", history=history)
print(response)
assert len(response) != 0

ci_scripts/model/loaded_as_transformer.py (new file)

@@ -0,0 +1,7 @@
from transformers import AutoModel
model = AutoModel.from_pretrained("../hf_ckpt/", trust_remote_code=True).cuda()
print(model)
assert model.config.hidden_size == 2048
assert model.config.num_attention_heads == 16
assert model.config.num_hidden_layers == 16

ci_scripts/train/ci_7B_sft.py (new file)

@@ -0,0 +1,130 @@
JOB_NAME = "7b_train"
SEQ_LEN = 1024
HIDDEN_SIZE = 2048
NUM_ATTENTION_HEAD = 16
MLP_RATIO = 8 / 3
NUM_LAYER = 16
VOCAB_SIZE = 103168
# Ckpt folder format:
# fs: 'local:/mnt/nfs/XXX'
# oss: 'boto3:s3://model_weights/XXX'
MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
#SAVE_CKPT_FOLDER = "local:llm_ckpts"
SAVE_CKPT_FOLDER = "local:llm_ckpts"
#LOAD_CKPT_FOLDER = "local:llm_ckpts/49"
ckpt = dict(
# Path to save training ckpt.
save_ckpt_folder=SAVE_CKPT_FOLDER,
# Path to continue training ckpt (load model weights and scheduler/context states).
# load_ckpt_folder=LOAD_CKPT_FOLDER,
# Path to initialize with given model weights.
# load_model_only_folder=MODEL_ONLY_FOLDER,
checkpoint_every=20,
# Whether to load optimizer states when continuing training.
load_optimizer=True,
)
TRAIN_FOLDER = "/mnt/petrelfs/qa-caif-cicd/data/lm_data/alpaca_data/train/en"
data = dict(
seq_len=SEQ_LEN,
# micro_num means the number of micro_batch contained in one gradient update
micro_num=4,
# packed_length = micro_bsz * SEQ_LEN
micro_bsz=2,
pack_sample_into_one=False,
total_steps=20,
skip_batches="",
rampup_batch_size="",
# Datasets with less than 50 rows will be discarded
min_length=50,
# train_folder=TRAIN_FOLDER,
)
grad_scaler = dict(
fp16=dict(
# the initial loss scale, defaults to 2**16
initial_scale=2**16,
# the minimum loss scale, defaults to None
min_scale=1,
# the number of steps to increase loss scale when no overflow occurs
growth_interval=1000,
),
# the multiplication factor for increasing loss scale, defaults to 2
growth_factor=2,
# the multiplication factor for decreasing loss scale, defaults to 0.5
backoff_factor=0.5,
# the maximum loss scale, defaults to None
max_scale=2**24,
# the number of overflows before decreasing loss scale, defaults to 2
hysteresis=2,
)
hybrid_zero_optimizer = dict(
# Enable low_level_optimizer overlap_communication
zero_overlap_communication=True,
# bucket size for nccl communication params
reduce_bucket_size=512 * 1024 * 1024,
# grad clipping
clip_grad_norm=1.0,
)
loss = dict(
label_smoothing=0,
)
adam = dict(
lr=1e-4,
adam_beta1=0.9,
adam_beta2=0.95,
adam_beta2_c=0,
adam_eps=1e-8,
weight_decay=0.01,
)
lr_scheduler = dict(
total_steps=data["total_steps"],
init_steps=0, # optimizer_warmup_step
warmup_ratio=0.01,
eta_min=1e-5,
last_epoch=-1,
)
beta2_scheduler = dict(
init_beta2=adam["adam_beta2"],
c=adam["adam_beta2_c"],
cur_iter=-1,
)
model = dict(
checkpoint=False,
num_attention_heads=NUM_ATTENTION_HEAD,
embed_split_hidden=True,
vocab_size=VOCAB_SIZE,
embed_grad_scale=1,
parallel_output=True,
hidden_size=HIDDEN_SIZE,
num_layers=NUM_LAYER,
mlp_ratio=MLP_RATIO,
apply_post_layer_norm=False,
dtype="torch.bfloat16",
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
)
"""
zero1 parallel:
1. if zero1 <= 0, The size of the zero process group is equal to the size of the dp process group,
so parameters will be divided within the range of dp.
2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters.
3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size.
For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
pipeline parallel: pipeline parallel size, only 1 is accepted currently.
tensor parallel: tensor parallel size, usually the number of GPUs per node, only 1 is accepted currently.
"""
parallel = dict(
zero1=8,
)
cudnn_deterministic = False
cudnn_benchmark = False

ci_scripts/train/slurm_train.sh (new file)

@@ -0,0 +1,20 @@
#!/bin/bash
rm -rf $GITHUB_WORKSPACE/llm_ckpts/20
srun -p llm2 --quotatype=spot -n 8 --ntasks-per-node=8 --gpus-per-task=1 python train.py --config ./ci_scripts/train/ci_7B_sft.py
file_dir="$GITHUB_WORKSPACE/llm_ckpts/20/*.pt"
source ./ci_scripts/common/basic_func.sh
num_files ${file_dir}
if [ $file_num -ne 21 ]
then
echo "The num of files is not right"
ls -l $file_dir
rm -rf $GITHUB_WORKSPACE/llm_ckpts
exit 1
fi

ci_scripts/train/torchrun.sh (new file)

@@ -0,0 +1,17 @@
#!/bin/bash
rm -rf $GITHUB_WORKSPACE/llm_ckpts/20
srun -p llm2 -N 1 torchrun --nnodes=1 --nproc_per_node=8 --master_port=29501 train.py --config ./ci_scripts/train/ci_7B_sft.py --launcher "torch"
file_dir="$GITHUB_WORKSPACE/llm_ckpts/20/*.pt"
source ./ci_scripts/common/basic_func.sh
num_files ${file_dir}
if [ $file_num -ne 21 ]
then
echo "The num of files is not right"
ls -l $file_dir
rm -rf $GITHUB_WORKSPACE/llm_ckpts
exit 1
fi

install.md (English)

@@ -5,10 +5,10 @@ The required packages and corresponding version are shown as follows:
- Python == 3.10
- GCC == 10.2.0
- MPFR == 4.1.0
-- CUDA == 11.7
-- Pytorch == 1.13.1+cu117
-- Transformers >= 4.25.1
-- Flash-Attention == v1.0.5
+- CUDA >= 11.7
+- Pytorch >= 1.13.1
+- Transformers >= 4.28.0
+- Flash-Attention >= v1.0.5
- Apex == 23.05
- GPU with Ampere or Hopper architecture (such as H100, A100)
- Linux OS
@@ -57,3 +57,14 @@ cd ./third_party/apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
cd ../../
```
### Environment Image
Users can obtain an image with the InternLM runtime environment installed from https://hub.docker.com/r/sunpengsdu/internlm. The commands for pulling the image and starting the container are as follows:
```bash
# pull image
docker pull sunpengsdu/internlm:torch1.13-cuda11.7-flashatten1.0.5-centos
# start container
docker run --gpus all -d -it --shm-size=2gb --name myinternlm sunpengsdu/internlm:torch1.13-cuda11.7-flashatten1.0.5-centos
docker exec -it myinternlm bash
```

install.md (Chinese)

@@ -5,10 +5,10 @@
- Python == 3.10
- GCC == 10.2.0
- MPFR == 4.1.0
-- CUDA == 11.7
-- Pytorch == 1.13.1+cu117
-- Transformers >= 4.25.1
-- Flash-Attention == v1.0.5
+- CUDA >= 11.7
+- Pytorch >= 1.13.1
+- Transformers >= 4.28.0
+- Flash-Attention >= v1.0.5
- Apex == 23.05
- GPU with Ampere or Hopper architecture (such as H100, A100)
- Linux OS
@@ -57,3 +57,13 @@ cd ./third_party/apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
cd ../../
```
### Environment Image
Users can obtain an image with the InternLM runtime environment installed from https://hub.docker.com/r/sunpengsdu/internlm. The commands for pulling the image and starting the container are as follows:
```bash
# pull image
docker pull sunpengsdu/internlm:torch1.13-cuda11.7-flashatten1.0.5-centos
# start container
docker run --gpus all -d -it --shm-size=2gb --name myinternlm sunpengsdu/internlm:torch1.13-cuda11.7-flashatten1.0.5-centos
docker exec -it myinternlm bash
```

internlm/model/embedding.py

@@ -175,7 +175,7 @@ class RotaryEmbedding(torch.nn.Module):
self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
def forward(self, qkv: torch.Tensor, **kwargs):
if kwargs.get("indexes", None) is not None:
return self._forward(qkv, kwargs.pop("indexes"))
@@ -183,7 +183,7 @@ class RotaryEmbedding(torch.nn.Module):
return self._eval_forward(qkv, seqlen_offset=kwargs.get("inference_params", None).sequence_len_offset)
else:
return self._eval_forward(qkv)
def _forward(self, qkv: torch.Tensor, indexes=0) -> Tuple[torch.Tensor, torch.Tensor]:
self._update_cos_sin_cache(qkv, indexes)
if self.scale is None:

internlm/solver/lr_scheduler.py

@@ -27,7 +27,7 @@ class WarmupScheduler(_LRScheduler):
def state_dict(self):
state_dict = {key: value for key, value in self.__dict__.items() if key not in "optimizer"}
-if isinstance(state_dict["after_scheduler"], _LRScheduler):
+if isinstance(state_dict["after_scheduler"], (_LRScheduler, _CosineAnnealingLR)):
state_dict["after_scheduler_type"] = type(state_dict["after_scheduler"]).__name__
state_dict["after_scheduler_dict"] = state_dict["after_scheduler"].state_dict()
del state_dict["after_scheduler"]
@@ -40,7 +40,7 @@ class WarmupScheduler(_LRScheduler):
for key in list(self.__dict__.keys()):
if key in state_dict:
self.__dict__[key] = state_dict[key]
-if isinstance(self.after_scheduler, _LRScheduler):
+if isinstance(self.after_scheduler, (_LRScheduler, _CosineAnnealingLR)):
assert type(self.after_scheduler).__name__ == state_dict["after_scheduler_type"]
# state_dict['after_scheduler_dict'] = state_dict['after_scheduler'].state_dict()
self.after_scheduler.load_state_dict(state_dict["after_scheduler_dict"])

tools/alpaca_tokenizer.py

@@ -160,5 +160,5 @@ if __name__ == "__main__":
train_tokens, valid_tokens, train_samples, valid_samples = dump_bin_meta_bin(
samples, args.output_path, args.split_ratio
)
print(f"number of train dataset: {train_samples}, " "number of train dataset token: {train_tokens}")
print(f"number of validation dataset: {valid_samples}, " "number of validation dataset token: {valid_tokens}")
print(f"number of train dataset: {train_samples}, number of train dataset token: {train_tokens}")
print(f"number of validation dataset: {valid_samples}, number of validation dataset token: {valid_tokens}")

tools/transformers/convert2hf.py

@@ -167,7 +167,7 @@ if __name__ == "__main__":
# TODO There should be a better way to add this.
with open(os.path.join(target_folder, "config.json")) as fp:
config_dict = json.load(fp)
config_dict["auto_map"]["AutoModel"] = "modeling_internlm.InternLMModel"
config_dict["auto_map"]["AutoModel"] = "modeling_internlm.InternLMForCausalLM"
with open(os.path.join(target_folder, "config.json"), "w") as fp:
json.dump(config_dict, fp, indent=2)

web_demo.py

@@ -199,7 +199,7 @@ def combine_history(prompt):
def main():
-torch.cuda.empty_cache()
+#torch.cuda.empty_cache()
print("load model begin.")
model, tokenizer = load_model()
print("load model end.")
@@ -237,6 +237,7 @@ def main():
message_placeholder.markdown(cur_response)
# Add robot response to chat history
st.session_state.messages.append({"role": "robot", "content": cur_response, "avatar": robot_avator})
+torch.cuda.empty_cache()
if __name__ == "__main__":