mirror of https://github.com/hpcaitech/ColossalAI

fixed the example docstring for booster (#3795)

parent 788e07dbc5
commit f5c425c898
@@ -23,27 +23,28 @@ class Booster:
     training with different precision, accelerator, and plugin.

-    Examples:
-        >>> colossalai.launch(...)
-        >>> plugin = GeminiPlugin(stage=3, ...)
-        >>> booster = Booster(precision='fp16', plugin=plugin)
-        >>>
-        >>> model = GPT2()
-        >>> optimizer = Adam(model.parameters())
-        >>> dataloader = Dataloader(Dataset)
-        >>> lr_scheduler = LinearWarmupScheduler()
-        >>> criterion = GPTLMLoss()
-        >>>
-        >>> model, optimizer, lr_scheduler, dataloader = booster.boost(model, optimizer, lr_scheduler, dataloader)
-        >>>
-        >>> for epoch in range(max_epochs):
-        >>>     for input_ids, attention_mask in dataloader:
-        >>>         outputs = model(input_ids, attention_mask)
-        >>>         loss = criterion(outputs.logits, input_ids)
-        >>>         booster.backward(loss, optimizer)
-        >>>         optimizer.step()
-        >>>         lr_scheduler.step()
-        >>>         optimizer.zero_grad()
+    ```python
+    colossalai.launch(...)
+    plugin = GeminiPlugin(stage=3, ...)
+    booster = Booster(precision='fp16', plugin=plugin)
+
+    model = GPT2()
+    optimizer = Adam(model.parameters())
+    dataloader = Dataloader(Dataset)
+    lr_scheduler = LinearWarmupScheduler()
+    criterion = GPTLMLoss()
+
+    model, optimizer, lr_scheduler, dataloader = booster.boost(model, optimizer, lr_scheduler, dataloader)
+
+    for epoch in range(max_epochs):
+        for input_ids, attention_mask in dataloader:
+            outputs = model(input_ids, attention_mask)
+            loss = criterion(outputs.logits, input_ids)
+            booster.backward(loss, optimizer)
+            optimizer.step()
+            lr_scheduler.step()
+            optimizer.zero_grad()
+    ```

     Args:
         device (str or torch.device): The device to run the training. Default: 'cuda'.
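For reference, below is a fleshed-out version of the example the new docstring sketches. This is a minimal sketch, not official ColossalAI documentation: the GPT-2 pieces (GPT2, GPTLMLoss, Dataloader(Dataset), LinearWarmupScheduler) are replaced with plain PyTorch stand-ins, the launch call assumes the script is started with `torchrun`, and the `boost()` argument/return order reflects one reading of the Booster API around the time of this commit and may differ across versions.

```python
# Minimal sketch of the docstring example with runnable stand-ins.
# Assumptions: the script is launched via `torchrun`, and booster.boost()
# takes and returns (model, optimizer, criterion, dataloader, lr_scheduler);
# both details may vary across ColossalAI versions.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

import colossalai
from colossalai.booster import Booster

# The `config` argument was required at the time of this commit;
# newer releases drop it.
colossalai.launch_from_torch(config={})

model = nn.Linear(32, 2)                          # stand-in for GPT2()
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()                 # stand-in for GPTLMLoss()
dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
dataloader = DataLoader(dataset, batch_size=8)    # stand-in for Dataloader(Dataset)
lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer)

# Defaults here; pass a plugin (e.g. GeminiPlugin()) to shard the model.
booster = Booster()

model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(
    model, optimizer, criterion, dataloader, lr_scheduler
)

for epoch in range(2):
    for inputs, labels in dataloader:
        # On GPU setups you may need to move batches to the device first,
        # as the library's own example does with input_ids.cuda().
        loss = criterion(model(inputs), labels)
        booster.backward(loss, optimizer)  # used in place of loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
```

The key change the commit documents is that gradients flow through `booster.backward(loss, optimizer)` rather than `loss.backward()`, so the plugin can handle mixed precision and sharding behind the unified interface.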