diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py
index 61d912157..4a42e2049 100644
--- a/colossalai/booster/booster.py
+++ b/colossalai/booster/booster.py
@@ -25,11 +25,11 @@ class Booster:
     Examples:
         ```python
         colossalai.launch(...)
-        plugin = GeminiPlugin(stage=3, ...)
+        plugin = GeminiPlugin(...)
         booster = Booster(precision='fp16', plugin=plugin)
 
         model = GPT2()
-        optimizer = Adam(model.parameters())
+        optimizer = HybridAdam(model.parameters())
         dataloader = Dataloader(Dataset)
         lr_scheduler = LinearWarmupScheduler()
         criterion = GPTLMLoss()
diff --git a/docs/source/en/features/zero_with_chunk.md b/docs/source/en/features/zero_with_chunk.md
index d6f6f611a..1b27d64b6 100644
--- a/docs/source/en/features/zero_with_chunk.md
+++ b/docs/source/en/features/zero_with_chunk.md
@@ -195,7 +195,7 @@ def get_data(batch_size, seq_len, vocab_size):
 Finally, we define a model which uses Gemini + ZeRO DDP and define our training loop, As we pre-train GPT in this example, we just use a simple language model loss:
 
 ```python
-from torch.optim import Adam
+from colossalai.nn.optimizer import HybridAdam
 
 from colossalai.booster import Booster
 from colossalai.zero import ColoInitContext
@@ -211,7 +211,7 @@ def main():
     # build criterion
     criterion = GPTLMLoss()
 
-    optimizer = Adam(model.parameters(), lr=0.001)
+    optimizer = HybridAdam(model.parameters(), lr=0.001)
 
     torch.manual_seed(123)
     default_pg = ProcessGroup(tp_degree=args.tp_degree)
diff --git a/docs/source/zh-Hans/features/zero_with_chunk.md b/docs/source/zh-Hans/features/zero_with_chunk.md
index 9030464dd..9fe5601bb 100644
--- a/docs/source/zh-Hans/features/zero_with_chunk.md
+++ b/docs/source/zh-Hans/features/zero_with_chunk.md
@@ -197,7 +197,7 @@ def get_data(batch_size, seq_len, vocab_size):
 最后,使用booster注入 Gemini + ZeRO DDP 特性, 并定义训练循环。由于我们在这个例子中对GPT进行预训练,因此只使用了一个简单的语言模型损失函数:
 
 ```python
-from torch.optim import Adam
+from colossalai.nn.optimizer import HybridAdam
 
 from colossalai.booster import Booster
 from colossalai.zero import ColoInitContext
@@ -213,7 +213,7 @@ def main():
     # build criterion
     criterion = GPTLMLoss()
 
-    optimizer = Adam(model.parameters(), lr=0.001)
+    optimizer = HybridAdam(model.parameters(), lr=0.001)
 
     torch.manual_seed(123)
     default_pg = ProcessGroup(tp_degree=args.tp_degree)