mirror of https://github.com/hpcaitech/ColossalAI
fixed the example docstring for booster (#3795)
parent
788e07dbc5
commit
f5c425c898
|
@ -23,27 +23,28 @@ class Booster:
|
||||||
training with different precision, accelerator, and plugin.
|
training with different precision, accelerator, and plugin.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> colossalai.launch(...)
|
```python
|
||||||
>>> plugin = GeminiPlugin(stage=3, ...)
|
colossalai.launch(...)
|
||||||
>>> booster = Booster(precision='fp16', plugin=plugin)
|
plugin = GeminiPlugin(stage=3, ...)
|
||||||
>>>
|
booster = Booster(precision='fp16', plugin=plugin)
|
||||||
>>> model = GPT2()
|
|
||||||
>>> optimizer = Adam(model.parameters())
|
|
||||||
>>> dataloader = Dataloader(Dataset)
|
|
||||||
>>> lr_scheduler = LinearWarmupScheduler()
|
|
||||||
>>> criterion = GPTLMLoss()
|
|
||||||
>>>
|
|
||||||
>>> model, optimizer, lr_scheduler, dataloader = booster.boost(model, optimizer, lr_scheduler, dataloader)
|
|
||||||
>>>
|
|
||||||
>>> for epoch in range(max_epochs):
|
|
||||||
>>>     for input_ids, attention_mask in dataloader:
|
|
||||||
>>>         outputs = model(input_ids, attention_mask)
|
|
||||||
>>>         loss = criterion(outputs.logits, input_ids)
|
|
||||||
>>>         booster.backward(loss, optimizer)
|
|
||||||
>>>         optimizer.step()
|
|
||||||
>>>         lr_scheduler.step()
|
|
||||||
>>>         optimizer.zero_grad()
|
|
||||||
|
|
||||||
|
model = GPT2()
|
||||||
|
optimizer = Adam(model.parameters())
|
||||||
|
dataloader = Dataloader(Dataset)
|
||||||
|
lr_scheduler = LinearWarmupScheduler()
|
||||||
|
criterion = GPTLMLoss()
|
||||||
|
|
||||||
|
model, optimizer, lr_scheduler, dataloader = booster.boost(model, optimizer, lr_scheduler, dataloader)
|
||||||
|
|
||||||
|
for epoch in range(max_epochs):
|
||||||
|
    for input_ids, attention_mask in dataloader:
|
||||||
|
        outputs = model(input_ids, attention_mask)
|
||||||
|
        loss = criterion(outputs.logits, input_ids)
|
||||||
|
        booster.backward(loss, optimizer)
|
||||||
|
        optimizer.step()
|
||||||
|
        lr_scheduler.step()
|
||||||
|
        optimizer.zero_grad()
|
||||||
|
```
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
device (str or torch.device): The device to run the training. Default: 'cuda'.
|
device (str or torch.device): The device to run the training. Default: 'cuda'.
|
||||||
|
|
Loading…
Reference in New Issue