mirror of https://github.com/hpcaitech/ColossalAI
[doc] fix docs about booster api usage (#3898)
parent ec9bbc0094
commit c1535ccbba
@@ -25,11 +25,11 @@ class Booster:
Examples:
```python
colossalai.launch(...)
-plugin = GeminiPlugin(stage=3, ...)
+plugin = GeminiPlugin(...)
booster = Booster(precision='fp16', plugin=plugin)

model = GPT2()
-optimizer = Adam(model.parameters())
+optimizer = HybridAdam(model.parameters())
dataloader = Dataloader(Dataset)
lr_scheduler = LinearWarmupScheduler()
criterion = GPTLMLoss()
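The hunk cuts the docstring snippet off before the boost call, so here is a minimal, self-contained sketch of the workflow it is illustrating, assuming the Booster API of this release (`launch_from_torch`, `TorchDDPPlugin`, `booster.boost`, `booster.backward`). The toy model, data, and plugin choice are illustrative stand-ins, not part of the docstring:

```python
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

# Assumes the script is started with torchrun, so launch_from_torch can read
# rank / world size from the environment.
colossalai.launch_from_torch(config={})

# Toy stand-ins for the GPT2 / HybridAdam / GPTLMLoss objects in the docstring.
model = torch.nn.Linear(8, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()
dataloader = [(torch.randn(4, 8), torch.randint(0, 2, (4,))) for _ in range(10)]

booster = Booster(plugin=TorchDDPPlugin())
# boost() hands every object to the plugin and returns drop-in replacements.
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

device = next(model.parameters()).device      # placement is decided by the plugin
for inputs, labels in dataloader:
    inputs, labels = inputs.to(device), labels.to(device)
    loss = criterion(model(inputs), labels)
    booster.backward(loss, optimizer)         # instead of loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```

Swapping `TorchDDPPlugin` for `GeminiPlugin(...)`, as the corrected docstring does, changes only the plugin line; the training loop stays the same.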
@@ -195,7 +195,7 @@ def get_data(batch_size, seq_len, vocab_size):
Finally, we define a model which uses Gemini + ZeRO DDP and define our training loop. As we pre-train GPT in this example, we just use a simple language model loss:

```python
-from torch.optim import Adam
+from colossalai.nn.optimizer import HybridAdam

from colossalai.booster import Booster
from colossalai.zero import ColoInitContext
@@ -211,7 +211,7 @@ def main():

# build criterion
criterion = GPTLMLoss()
-optimizer = Adam(model.parameters(), lr=0.001)
+optimizer = HybridAdam(model.parameters(), lr=0.001)

torch.manual_seed(123)
default_pg = ProcessGroup(tp_degree=args.tp_degree)
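The reason for this swap is worth spelling out: with the Gemini plugin, parameters and optimizer states can sit on GPU or be offloaded to CPU, and `HybridAdam` is the ColossalAI optimizer recommended for that placement, which plain `torch.optim.Adam` is not built for. Below is a condensed sketch of how the corrected pieces fit together in `main()`; `build_gpt_model` is a hypothetical placeholder for the model construction this hunk does not show, and `GPTLMLoss` is the criterion the tutorial defines earlier:

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

model = build_gpt_model()          # hypothetical stand-in for the tutorial's GPT-2 builder
criterion = GPTLMLoss()            # language-model loss defined earlier in the tutorial
optimizer = HybridAdam(model.parameters(), lr=0.001)   # not torch.optim.Adam

booster = Booster(plugin=GeminiPlugin())   # Gemini + ZeRO DDP
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
# The training loop then calls booster.backward(loss, optimizer) followed by
# optimizer.step() / optimizer.zero_grad(), exactly as in the sketch above.
```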
@@ -197,7 +197,7 @@ def get_data(batch_size, seq_len, vocab_size):
Finally, use the booster to inject the Gemini + ZeRO DDP features and define the training loop. Since we pre-train GPT in this example, we just use a simple language-model loss:

```python
-from torch.optim import Adam
+from colossalai.nn.optimizer import HybridAdam

from colossalai.booster import Booster
from colossalai.zero import ColoInitContext
@@ -213,7 +213,7 @@ def main():

# build criterion
criterion = GPTLMLoss()
-optimizer = Adam(model.parameters(), lr=0.001)
+optimizer = HybridAdam(model.parameters(), lr=0.001)

torch.manual_seed(123)
default_pg = ProcessGroup(tp_degree=args.tp_degree)