Finetune BERT on GLUE

🚀 Quick Start

This example provides a training script that demonstrates how to finetune BERT on a GLUE task (MRPC by default).

  • Training Arguments (see the parsing sketch after this list)
    • -t, --task: GLUE task to run. Defaults to mrpc.
    • -p, --plugin: Plugin to use. Choices: torch_ddp, torch_ddp_fp16, gemini, low_level_zero. Defaults to torch_ddp.
    • --target_f1: Target F1 score; an exception is raised if the final score falls below it. Defaults to None.
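
For reference, here is a minimal sketch of how these flags could be parsed and how --target_f1 is enforced. It is illustrative only: the real argument handling lives in finetune.py and may differ in names and details.

import argparse

parser = argparse.ArgumentParser(description="Finetune BERT on a GLUE task (illustrative sketch)")
parser.add_argument("-t", "--task", default="mrpc", help="GLUE task to run")
parser.add_argument("-p", "--plugin", default="torch_ddp",
                    choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero"],
                    help="booster plugin to use")
parser.add_argument("--target_f1", type=float, default=None,
                    help="raise an error if the final F1 score is below this value")
args = parser.parse_args()

# After evaluation, the target is enforced roughly like this:
final_f1 = 0.88  # placeholder; the real value comes from evaluating on the task's dev set
if args.target_f1 is not None and final_f1 < args.target_f1:
    raise RuntimeError(f"F1 {final_f1:.3f} did not reach target {args.target_f1:.3f}")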

Install requirements

pip install -r requirements.txt

Train

# train with torch DDP (fp32)
colossalai run --nproc_per_node 4 finetune.py

# train with torch DDP (fp16 mixed precision)
colossalai run --nproc_per_node 4 finetune.py -p torch_ddp_fp16

# train with gemini
colossalai run --nproc_per_node 4 finetune.py -p gemini

# train with low level zero
colossalai run --nproc_per_node 4 finetune.py -p low_level_zero
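
All four commands run the same script; only the booster plugin changes. Below is a condensed, hedged sketch of the booster-based training pattern such a script typically follows. The actual finetune.py wires in its own dataloaders, scheduler, and evaluation, so details will differ.

import colossalai
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin  # or GeminiPlugin, LowLevelZeroPlugin
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import BertForSequenceClassification, BertTokenizerFast

# set up the distributed environment created by `colossalai run`
colossalai.launch_from_torch()

plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin, mixed_precision="fp16")  # drop mixed_precision for the fp32 run

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

# a toy MRPC dataloader; the example's data.py builds the real (distributed) one
dataset = load_dataset("glue", "mrpc", split="train")
def collate(batch):
    enc = tokenizer([b["sentence1"] for b in batch], [b["sentence2"] for b in batch],
                    padding=True, truncation=True, return_tensors="pt")
    enc["labels"] = torch.tensor([b["label"] for b in batch])
    return enc
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate)

# the booster wraps model/optimizer/dataloader for the chosen plugin
model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)

model.train()
for batch in dataloader:
    batch = {k: v.cuda() for k, v in batch.items()}
    loss = model(**batch).loss
    booster.backward(loss, optimizer)
    optimizer.step()
    optimizer.zero_grad()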

The expected F1 scores for bert-base-uncased are:

  • Single-GPU Baseline (FP32): 0.86
  • Booster DDP with FP32: 0.88
  • Booster DDP with FP16: 0.87
  • Booster Gemini: 0.88
  • Booster Low Level Zero: 0.89
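
For context, F1 is the standard GLUE metric for MRPC. One way to compute it is with the Hugging Face evaluate library; this is a generic sketch, not necessarily how finetune.py computes its score.

import evaluate

metric = evaluate.load("glue", "mrpc")
# predictions/references would come from running the finetuned model on the dev set
result = metric.compute(predictions=[1, 0, 1, 1], references=[1, 0, 0, 1])
print(result)  # e.g. {'accuracy': 0.75, 'f1': 0.8}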