From f5c425c89874f2500600be71b3c9aadad2da822f Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Mon, 22 May 2023 18:10:06 +0800 Subject: [PATCH] fixed the example docstring for booster (#3795) --- colossalai/booster/booster.py | 41 ++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py index 4055e55df..be9c1c9dc 100644 --- a/colossalai/booster/booster.py +++ b/colossalai/booster/booster.py @@ -23,27 +23,28 @@ class Booster: training with different precision, accelerator, and plugin. Examples: - >>> colossalai.launch(...) - >>> plugin = GeminiPlugin(stage=3, ...) - >>> booster = Booster(precision='fp16', plugin=plugin) - >>> - >>> model = GPT2() - >>> optimizer = Adam(model.parameters()) - >>> dataloader = Dataloader(Dataset) - >>> lr_scheduler = LinearWarmupScheduler() - >>> criterion = GPTLMLoss() - >>> - >>> model, optimizer, lr_scheduler, dataloader = booster.boost(model, optimizer, lr_scheduler, dataloader) - >>> - >>> for epoch in range(max_epochs): - >>> for input_ids, attention_mask in dataloader: - >>> outputs = model(input_ids, attention_mask) - >>> loss = criterion(outputs.logits, input_ids) - >>> booster.backward(loss, optimizer) - >>> optimizer.step() - >>> lr_scheduler.step() - >>> optimizer.zero_grad() + ```python + colossalai.launch(...) + plugin = GeminiPlugin(stage=3, ...) + booster = Booster(precision='fp16', plugin=plugin) + model = GPT2() + optimizer = Adam(model.parameters()) + dataloader = Dataloader(Dataset) + lr_scheduler = LinearWarmupScheduler() + criterion = GPTLMLoss() + + model, optimizer, lr_scheduler, dataloader = booster.boost(model, optimizer, lr_scheduler, dataloader) + + for epoch in range(max_epochs): + for input_ids, attention_mask in dataloader: + outputs = model(input_ids, attention_mask) + loss = criterion(outputs.logits, input_ids) + booster.backward(loss, optimizer) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + ``` Args: device (str or torch.device): The device to run the training. Default: 'cuda'.