Merge branch 'main' into feature/fp8_comm

pull/6012/head
Hongxin Liu 2024-08-22 15:14:38 +08:00 committed by GitHub
commit caab4a307f
17 changed files with 242 additions and 116 deletions

View File

@@ -31,18 +31,18 @@ jobs:
- name: Install Colossal-AI
run: |
BUILD_EXT=1 pip install -v -e .
BUILD_EXT=1 pip install --no-cache-dir -v -e .
- name: Install ChatGPT
run: |
cd applications/ColossalChat
pip install -v .
pip install --no-cache-dir -v .
export BUILD_EXT=1
pip install -r examples/requirements.txt
pip install --no-cache-dir -r examples/requirements.txt
- name: Install Transformers
run: |
pip install transformers==4.36.2
pip install --no-cache-dir transformers==4.36.2
- name: Execute Examples
run: |

View File

@@ -161,3 +161,9 @@ applications/ColossalChat/sft_data
applications/ColossalChat/prompt_data
applications/ColossalChat/preference_data
applications/ColossalChat/temp
# Testing data
/kto_data/
/preference_data/
/prompt_data/
/sft_data/

View File

@@ -16,7 +16,7 @@ from coati.experience_buffer import NaiveExperienceBuffer
from coati.experience_maker import Experience
from torch.optim import Optimizer
from colossalai.booster import Booster
from colossalai.booster import Booster, Plugin
from .utils import is_rank_0
@@ -38,6 +38,7 @@ class SLTrainer(ABC):
max_epochs: int,
model: nn.Module,
optimizer: Optimizer,
plugin: Plugin,
start_epoch: int = 0,
) -> None:
super().__init__()
@@ -45,6 +46,7 @@ class SLTrainer(ABC):
self.max_epochs = max_epochs
self.model = model
self.optimizer = optimizer
self.plugin = plugin
self.start_epoch = start_epoch
@abstractmethod
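The effect of this change is that every trainer now carries a handle to the booster plugin, which subclasses can inspect to pick a parallelism-aware code path (the SFT trainer further down does exactly this). Below is a minimal sketch of the pattern; the ToyTrainer class is illustrative only, and the coati.trainer.base import path is assumed from the imports shown in this hunk.

import torch.nn as nn
from torch.optim import Optimizer

from coati.trainer.base import SLTrainer  # assumed path for the file diffed above
from colossalai.booster import Booster, Plugin
from colossalai.booster.plugin import HybridParallelPlugin


class ToyTrainer(SLTrainer):
    """Hypothetical subclass showing how the plugin is threaded to the base class."""

    def __init__(self, booster: Booster, model: nn.Module, optim: Optimizer, plugin: Plugin):
        super().__init__(booster, max_epochs=1, model=model, optimizer=optim, plugin=plugin)

    def _train(self, epoch: int):
        # The stored plugin lets the loop detect pipeline parallelism and branch accordingly.
        if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1:
            pass  # drive booster.execute_pipeline(...) as in the SFT trainer diff below
        else:
            pass  # ordinary per-rank forward/backward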

View File

@@ -16,7 +16,7 @@ from torch.utils.data import DataLoader
from tqdm import trange
from transformers import PreTrainedTokenizerBase
from colossalai.booster import Booster
from colossalai.booster import Booster, Plugin
from colossalai.cluster import DistCoordinator
from colossalai.utils import get_current_device
@@ -50,6 +50,7 @@ class DPOTrainer(SLTrainer):
ref_model: Any,
booster: Booster,
actor_optim: Optimizer,
plugin: Plugin,
actor_lr_scheduler: _LRScheduler,
tokenizer: PreTrainedTokenizerBase,
max_epochs: int = 1,
@@ -63,7 +64,9 @@
save_dir: str = None,
coordinator: DistCoordinator = None,
) -> None:
super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch)
super().__init__(
booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch
)
self.ref_model = ref_model
self.actor_scheduler = actor_lr_scheduler
self.tokenizer = tokenizer

View File

@@ -17,7 +17,7 @@ from torch.utils.data import DataLoader
from tqdm import trange
from transformers import PreTrainedTokenizerBase
from colossalai.booster import Booster
from colossalai.booster import Booster, Plugin
from colossalai.cluster import DistCoordinator
from colossalai.utils import get_current_device
@@ -53,6 +53,7 @@ class KTOTrainer(SLTrainer):
ref_model: Any,
booster: Booster,
actor_optim: Optimizer,
plugin: Plugin,
actor_lr_scheduler: _LRScheduler,
tokenizer: PreTrainedTokenizerBase,
max_epochs: int = 1,
@@ -66,7 +67,9 @@
save_dir: str = None,
coordinator: DistCoordinator = None,
) -> None:
super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch)
super().__init__(
booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch
)
self.ref_model = ref_model
self.actor_scheduler = actor_lr_scheduler
self.tokenizer = tokenizer

View File

@@ -16,7 +16,7 @@ from torch.utils.data import DataLoader
from tqdm import trange
from transformers import PreTrainedTokenizerBase
from colossalai.booster import Booster
from colossalai.booster import Booster, Plugin
from colossalai.cluster import DistCoordinator
from colossalai.utils import get_current_device
@@ -48,6 +48,7 @@ class ORPOTrainer(SLTrainer):
actor: Any,
booster: Booster,
actor_optim: Optimizer,
plugin: Plugin,
actor_lr_scheduler: _LRScheduler,
tokenizer: PreTrainedTokenizerBase,
max_epochs: int = 1,
@@ -59,7 +60,9 @@
save_dir: str = None,
coordinator: DistCoordinator = None,
) -> None:
super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch)
super().__init__(
booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch
)
self.actor_scheduler = actor_lr_scheduler
self.tokenizer = tokenizer
self.odds_ratio_loss_fn = OddsRatioLoss()

View File

@@ -15,7 +15,7 @@ from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizerBase
from colossalai.booster import Booster
from colossalai.booster import Booster, Plugin
from colossalai.cluster import DistCoordinator
from colossalai.utils import get_current_device
@@ -48,6 +48,7 @@ class RewardModelTrainer(SLTrainer):
model: Any,
booster: Booster,
optimizer: Optimizer,
plugin: Plugin,
lr_scheduler: _LRScheduler,
tokenizer: PreTrainedTokenizerBase,
loss_fn: Optional[Callable] = None,
@@ -59,7 +60,9 @@
save_dir: str = None,
coordinator: DistCoordinator = None,
) -> None:
super().__init__(booster, max_epochs=max_epochs, model=model, optimizer=optimizer, start_epoch=start_epoch)
super().__init__(
booster, max_epochs=max_epochs, model=model, optimizer=optimizer, plugin=plugin, start_epoch=start_epoch
)
self.actor_scheduler = lr_scheduler
self.tokenizer = tokenizer
self.loss_fn = loss_fn if loss_fn is not None else LogSigLoss(beta=beta)

View File

@@ -6,14 +6,16 @@ import os
from typing import Optional
import torch
import torch.distributed as dist
from coati.trainer.utils import all_reduce_mean
from coati.utils import AccumulativeMeanMeter, save_checkpoint
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from tqdm import trange
from tqdm import tqdm, trange
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin, Plugin
from colossalai.cluster import DistCoordinator
from .base import SLTrainer
@@ -40,6 +42,7 @@ class SFTTrainer(SLTrainer):
optim: Optimizer,
lr_scheduler: _LRScheduler,
max_epochs: int = 2,
plugin: Plugin = None,
accumulation_steps: int = 8,
apply_loss_mask: bool = True,
start_epoch=0,
@@ -47,7 +50,7 @@ class SFTTrainer(SLTrainer):
save_dir: str = None,
coordinator: Optional[DistCoordinator] = None,
) -> None:
super().__init__(booster, max_epochs, model, optim, start_epoch=start_epoch)
super().__init__(booster, max_epochs, model, optim, plugin, start_epoch=start_epoch)
self.accumulation_steps = accumulation_steps
self.scheduler = lr_scheduler
@@ -94,60 +97,85 @@ class SFTTrainer(SLTrainer):
def _train(self, epoch: int):
self.model.train()
step_bar = trange(
len(self.train_dataloader) // self.accumulation_steps,
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
disable=not is_rank_0(),
)
for i, batch in enumerate(self.train_dataloader):
batch = to_device(batch, torch.cuda.current_device())
batch_size = batch["input_ids"].size(0)
outputs = self.model(
batch["input_ids"],
attention_mask=batch["attention_mask"],
labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1:
data_iter = iter(self.train_dataloader)
step_bar = tqdm(
range(len(self.train_dataloader)),
desc="Step",
disable=not (dist.get_rank() == dist.get_world_size() - 1),
)
loss = outputs.loss
for step in step_bar:
outputs = self.booster.execute_pipeline(
data_iter,
self.model,
criterion=lambda outputs, inputs: outputs[0],
optimizer=self.optimizer,
return_loss=True,
)
loss = outputs["loss"]
self.booster.backward(loss=loss, optimizer=self.optimizer)
if self.booster.plugin.stage_manager.is_last_stage():
global_loss = all_reduce_mean(loss, self.plugin)
if dist.get_rank() == dist.get_world_size() - 1:
step_bar.set_postfix({"train/loss": global_loss.item()})
loss_mean = all_reduce_mean(tensor=loss)
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
# Gradient accumulation
if (i + 1) % self.accumulation_steps == 0:
self.optimizer.step()
self.optimizer.zero_grad()
self.scheduler.step()
else:
step_bar = trange(
len(self.train_dataloader) // self.accumulation_steps,
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
disable=not is_rank_0(),
)
for i, batch in enumerate(self.train_dataloader):
batch = to_device(batch, torch.cuda.current_device())
batch_size = batch["input_ids"].size(0)
outputs = self.model(
batch["input_ids"],
attention_mask=batch["attention_mask"],
labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
)
loss = outputs.loss
step_bar.set_postfix({"train/loss": self.accumulative_meter.get("loss")})
if self.writer:
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
self.num_train_step += 1
self.accumulative_meter.reset()
step_bar.update()
self.booster.backward(loss=loss, optimizer=self.optimizer)
# Save checkpoint
if (
self.save_dir is not None
and self.save_interval is not None
and (self.num_train_step + 1) % self.save_interval == 0
):
save_checkpoint(
save_dir=self.save_dir,
booster=self.booster,
model=self.model,
optimizer=self.optimizer,
lr_scheduler=self.scheduler,
epoch=epoch,
step=self.num_train_step + 1,
batch_size=batch_size,
coordinator=self.coordinator,
)
self.coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {self.num_train_step} at folder {self.save_dir}"
)
loss_mean = all_reduce_mean(tensor=loss)
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
# Gradient accumulation
if (i + 1) % self.accumulation_steps == 0:
self.optimizer.step()
self.optimizer.zero_grad()
self.scheduler.step()
step_bar.set_postfix({"train/loss": self.accumulative_meter.get("loss")})
if self.writer:
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
self.num_train_step += 1
self.accumulative_meter.reset()
step_bar.update()
# Save checkpoint
if (
self.save_dir is not None
and self.save_interval is not None
and (self.num_train_step + 1) % self.save_interval == 0
):
save_checkpoint(
save_dir=self.save_dir,
booster=self.booster,
model=self.model,
optimizer=self.optimizer,
lr_scheduler=self.scheduler,
epoch=epoch,
step=self.num_train_step + 1,
batch_size=batch_size,
coordinator=self.coordinator,
)
self.coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {self.num_train_step} at folder {self.save_dir}"
)
step_bar.close()
def _eval(self, epoch: int):
@@ -157,27 +185,64 @@ class SFTTrainer(SLTrainer):
self.accumulative_meter.reset()
self.model.eval()
with torch.no_grad():
step_bar = trange(
len(self.eval_dataloader),
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
disable=not is_rank_0(),
)
for batch in self.eval_dataloader:
batch = to_device(batch, torch.cuda.current_device())
outputs = self.model(
batch["input_ids"],
attention_mask=batch["attention_mask"],
labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1:
data_iter = iter(self.eval_dataloader)
step_bar = tqdm(
range(len(self.eval_dataloader)),
desc="Step",
disable=not (dist.get_rank() == dist.get_world_size() - 1),
)
loss_mean = all_reduce_mean(tensor=outputs.loss)
self.accumulative_meter.add("loss", loss_mean.item(), count_update=batch["input_ids"].size(0))
step_bar.update()
loss_mean = self.accumulative_meter.get("loss")
msg = "Evaluation Result:\n"
for tag in ["loss"]:
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
self.coordinator.print_on_master(msg)
os.makedirs(self.save_dir, exist_ok=True)
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
f.write(msg)
step_bar.close()
for step in step_bar:
outputs = self.booster.execute_pipeline(
data_iter,
self.model,
criterion=lambda outputs, inputs: outputs[0],
optimizer=self.optimizer,
return_loss=True,
)
loss = outputs["loss"]
if self.booster.plugin.stage_manager.is_last_stage():
global_loss = all_reduce_mean(loss, self.plugin)
if dist.get_rank() == dist.get_world_size() - 1:
step_bar.set_postfix({"eval/loss": global_loss.item()})
self.accumulative_meter.add("loss", global_loss.item())
if dist.get_rank() == dist.get_world_size() - 1:
loss_mean = self.accumulative_meter.get("loss")
msg = "Evaluation Result:\n"
for tag in ["loss"]:
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
print(msg)
if self.save_dir is not None:
os.makedirs(self.save_dir, exist_ok=True)
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
f.write(msg)
step_bar.close()
else:
step_bar = trange(
len(self.eval_dataloader),
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
disable=not is_rank_0(),
)
for batch in self.eval_dataloader:
batch = to_device(batch, torch.cuda.current_device())
outputs = self.model(
batch["input_ids"],
attention_mask=batch["attention_mask"],
labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
)
loss_mean = all_reduce_mean(tensor=outputs.loss)
self.accumulative_meter.add("loss", loss_mean.item(), count_update=batch["input_ids"].size(0))
step_bar.update()
loss_mean = self.accumulative_meter.get("loss")
msg = "Evaluation Result:\n"
for tag in ["loss"]:
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
self.coordinator.print_on_master(msg)
if self.save_dir is not None:
os.makedirs(self.save_dir, exist_ok=True)
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
f.write(msg)
step_bar.close()
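Stripped of the progress-bar, accumulation, and checkpoint bookkeeping, the new pipeline-parallel branch reduces to the step below. This is a condensed sketch distilled from the hunk above; the helper function and its signature are illustrative, not part of the commit.

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from coati.trainer.utils import all_reduce_mean  # updated in the utils diff below


def pipeline_sft_step(booster: Booster, plugin: HybridParallelPlugin, model, optimizer, data_iter):
    """One pipeline-parallel training step, distilled from the diff above."""
    outputs = booster.execute_pipeline(
        data_iter,
        model,
        criterion=lambda outputs, inputs: outputs[0],  # first element of the model output is the loss
        optimizer=optimizer,
        return_loss=True,
    )
    loss = outputs["loss"]  # populated on the last pipeline stage
    booster.backward(loss=loss, optimizer=optimizer)
    if booster.plugin.stage_manager.is_last_stage():
        # average across data-parallel replicas only, not the whole world
        return all_reduce_mean(loss, plugin)
    return None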

View File

@@ -9,6 +9,8 @@ import torch.distributed as dist
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader
from colossalai.booster import Plugin
class CycledDataLoader:
"""
@@ -85,7 +87,7 @@ def to_device(x: Any, device: torch.device) -> Any:
return tree_map(_to, x)
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
def all_reduce_mean(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor:
"""
Perform all-reduce operation on the given tensor and compute the mean across all processes.
@@ -95,8 +97,13 @@ def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
Returns:
torch.Tensor: The reduced tensor with mean computed across all processes.
"""
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
tensor.div_(dist.get_world_size())
# All reduce mean across DP group
if plugin is not None:
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM, group=plugin.dp_group)
tensor.div_(plugin.dp_size)
else:
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
tensor.div_(dist.get_world_size())
return tensor
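The motivation for the new argument: under tensor or pipeline parallelism, averaging over the whole world would also average across ranks that see the same data, so the mean should be taken over the data-parallel group only. A usage sketch, assuming a 4-GPU layout launched with `colossalai run --nproc_per_node 4` (the layout and tensor values are illustrative):

import torch
import colossalai
from colossalai.booster.plugin import HybridParallelPlugin
from coati.trainer.utils import all_reduce_mean

colossalai.launch_from_torch()
plugin = HybridParallelPlugin(tp_size=2, pp_size=1)  # 4 ranks -> dp_size == 2

loss = torch.rand(1, device="cuda")

world_mean = all_reduce_mean(loss.clone())       # old behaviour: mean over all 4 ranks
dp_mean = all_reduce_mean(loss.clone(), plugin)  # new path: mean over the 2 data-parallel replicas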

View File

@@ -267,6 +267,7 @@ def train(args):
ref_model=ref_model,
booster=booster,
actor_optim=optim,
plugin=plugin,
actor_lr_scheduler=lr_scheduler,
tokenizer=tokenizer,
max_epochs=args.max_epochs,

View File

@@ -286,6 +286,7 @@ def train(args):
ref_model=ref_model,
booster=booster,
actor_optim=optim,
plugin=plugin,
actor_lr_scheduler=lr_scheduler,
tokenizer=tokenizer,
max_epochs=args.max_epochs,

View File

@@ -250,6 +250,7 @@ def train(args):
actor=model,
booster=booster,
actor_optim=optim,
plugin=plugin,
actor_lr_scheduler=lr_scheduler,
tokenizer=tokenizer,
max_epochs=args.max_epochs,

View File

@@ -262,6 +262,7 @@ def train(args):
model,
booster,
optim,
plugin,
lr_scheduler,
tokenizer,
loss_fn=loss_fn,

View File

@@ -114,7 +114,7 @@ def train(args):
parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
microbatch_size=args.batch_size,
microbatch_size=args.microbatch_size,
)
else:
raise ValueError(f"Unknown plugin {args.plugin}")
@@ -269,6 +269,7 @@ def train(args):
model=model,
booster=booster,
optim=optim,
plugin=plugin,
lr_scheduler=lr_scheduler,
max_epochs=args.max_epochs,
accumulation_steps=args.accumulation_steps,
@@ -344,6 +345,7 @@ if __name__ == "__main__":
parser.add_argument("--use_wandb", default=False, action="store_true")
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
parser.add_argument("--use_flash_attn", default=False, action="store_true")
parser.add_argument("--microbatch_size", type=int, default=1)
args = parser.parse_args()
if args.config_file is not None:
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
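For context on the `microbatch_size` fix: when pipeline parallelism is enabled, the plugin chops every training batch into micro-batches of this size, so it should be configurable independently of `--batch_size` rather than silently set equal to it. A condensed sketch of the relevant "3d" plugin options (only the keywords touched by this change, plus illustrative values; the other options from the script are omitted):

from colossalai.booster.plugin import HybridParallelPlugin

# batch_size=8 with microbatch_size=1 -> each pipeline step is scheduled as 8 micro-batches
plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=2,
    microbatch_size=1,  # now taken from --microbatch_size (default 1) instead of --batch_size
    precision="bf16",
)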

View File

@@ -61,7 +61,7 @@ def test_overfit():
_, predicted = torch.max(outputs.data, 1)
total = labels.size(0)
correct = (predicted == Y).sum().item()
assert (correct / total > 0.95, "The model has not overfitted to the synthesized dataset")
assert correct / total > 0.95
assert (weight_to_compare - model.fc1.weight).sum() < 0.01
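For context on the test change: the old statement wrapped the condition and message in parentheses, which builds a two-element tuple, and asserting a non-empty tuple always succeeds, so the check could never fail (CPython even warns "assertion is always true"). A standalone illustration, not taken from the repository:

correct, total = 96, 100

# Old form: asserts a non-empty tuple, which is always truthy -> never fails.
assert (correct / total > 0.95, "The model has not overfitted to the synthesized dataset")

# Fixed form: the condition is evaluated and the optional message follows a comma.
assert correct / total > 0.95, "The model has not overfitted to the synthesized dataset"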

View File

@@ -15,7 +15,7 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 2
set_n_least_used_CUDA_VISIBLE_DEVICES 4
set -xu
@@ -30,7 +30,7 @@ MODEL_SAVE_PATH=$TEMP_DIR/rlhf_models
MODELS_DIR=$TEMP_DIR/models_config
# Skip those tests due to CI tests timeout
MODELS=('llama')
ADVANCED_PLUGINS=('zero2' 'sp_split_gather' 'sp_ring' 'sp_all_to_all' 'tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu') # pp is still buggy
ADVANCED_PLUGINS=('zero2' 'sp_split_gather' 'sp_ring' 'sp_all_to_all' 'tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu' 'pp' 'tp_pp')
PLUGINS=('zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu')
LORA_RANK=('0') # skip to reduce CI execution time, can pass all locally
LORA_CONFIG_ENABLE="--lora_config $BASE_DIR/examples/training_scripts/lora_config.json"
@@ -91,7 +91,7 @@ SKIPPED_TESTS=(
llama-gemini_auto-20 # gemini_auto plugin doesn't support lora
llama-gemini-20 # gemini doesn't support lora
)
skip_eval=false
GRAD_CKPTS=('--grad_checkpoint')
for lora_rank in ${LORA_RANK[@]}; do
for model in ${MODELS[@]}; do
@@ -129,15 +129,18 @@ for lora_rank in ${LORA_RANK[@]}; do
plugin='3d'
fi
if [[ $plugin == "tp_pp" ]]; then
echo "Here"
tp='2'
bs='8'
pp='2'
plugin='3d'
skip_eval=true
fi
if [[ $plugin == "pp" ]]; then
bs='8'
pp='2'
plugin='3d'
skip_eval=true
fi
if [[ $plugin == "sp_split_gather" ]]; then
enable_sequence_parallelism='--enable_sequence_parallelism'
@@ -175,28 +178,53 @@ for lora_rank in ${LORA_RANK[@]}; do
for split in $(seq -f "%05g" 0 0); do
dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
done
colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \
--pretrain $pretrain \
--tokenizer_dir $tokenizer_dir \
--dataset ${dataset[@]} \
--eval_dataset ${dataset[@]} \
--save_path $MODEL_SAVE_PATH \
--config_file $MODELS_DIR/config.jsonl \
$lora_config \
--plugin $plugin \
--batch_size $bs \
--max_epochs 1 \
--accumulation_steps $grad_accu \
--tp $tp \
--pp $pp \
--zero_stage $zero_stage \
--sp $sp \
--sp_mode $sp_mode \
$enable_sequence_parallelism \
--lr 2e-5 \
$grad_ckpt \
--max_len 400 \
--use_flash_attn
if [[ $skip_eval ]]; then
colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \
--pretrain $pretrain \
--tokenizer_dir $tokenizer_dir \
--dataset ${dataset[@]} \
--save_path $MODEL_SAVE_PATH \
--config_file $MODELS_DIR/config.jsonl \
$lora_config \
--plugin $plugin \
--batch_size $bs \
--max_epochs 1 \
--accumulation_steps $grad_accu \
--tp $tp \
--pp $pp \
--zero_stage $zero_stage \
--sp $sp \
--sp_mode $sp_mode \
$enable_sequence_parallelism \
--lr 2e-5 \
$grad_ckpt \
--max_len 400 \
--use_flash_attn
else
colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \
--pretrain $pretrain \
--tokenizer_dir $tokenizer_dir \
--dataset ${dataset[@]} \
--eval_dataset ${dataset[@]} \
--save_path $MODEL_SAVE_PATH \
--config_file $MODELS_DIR/config.jsonl \
$lora_config \
--plugin $plugin \
--batch_size $bs \
--max_epochs 1 \
--accumulation_steps $grad_accu \
--tp $tp \
--pp $pp \
--zero_stage $zero_stage \
--sp $sp \
--sp_mode $sp_mode \
$enable_sequence_parallelism \
--lr 2e-5 \
$grad_ckpt \
--max_len 400 \
--use_flash_attn
fi
passed=$?
if [ $passed -eq 0 ]; then
rm -rf ${MODEL_SAVE_PATH:?}/*

View File

@@ -349,7 +349,7 @@ class LowLevelZeroPlugin(DPPluginBase):
verbose: bool = False,
cast_inputs: bool = True,
fp8_communication: bool = False,
use_fp8: bool = False,
use_fp8: bool = False
) -> None:
super().__init__()
assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"