mirror of https://github.com/hpcaitech/ColossalAI
[doc] update nvme offload documents. (#3850)
parent
ae959a72a5
commit
b0474878bf
|
@ -78,8 +78,9 @@ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
|
|||
|
||||
import colossalai
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.zero import zero_model_wrapper, zero_optim_wrapper
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin
|
||||
```
|
||||
|
||||
Then we define a loss function:
|
||||
|
@ -192,17 +193,23 @@ def train_gemini_cpu(nvme_offload_fraction: float = 0.0):
|
|||
optimizer = HybridAdam(model.parameters(), nvme_offload_fraction=nvme_offload_fraction)
|
||||
print(f'Model numel: {get_model_numel(model) / 1024**3:.3f} B')
|
||||
|
||||
gemini_config = dict(strict_ddp_mode=True, device=torch.cuda.current_device(),
|
||||
placement_policy='cpu', pin_memory=True, hidden_dim=config.n_embd)
|
||||
model = zero_model_wrapper(model, zero_stage=3, gemini_config=gemini_config)
|
||||
optimizer = zero_optim_wrapper(model, optimizer, initial_scale=2**5)
|
||||
plugin = GeminiPlugin(
|
||||
strict_ddp_mode=True,
|
||||
device=torch.cuda.current_device(),
|
||||
placement_policy='cpu',
|
||||
pin_memory=True,
|
||||
hidden_dim=config.n_embd,
|
||||
initial_scale=2**5
|
||||
)
|
||||
booster = Booster(plugin)
|
||||
model, optimizer, criterion, _* = booster.boost(model, optimizer, criterion)
|
||||
|
||||
start = time.time()
|
||||
for step in range(3):
|
||||
data = get_data(4, 128, config.vocab_size)
|
||||
outputs = model(**data)
|
||||
loss = criterion(outputs.logits, data['input_ids'])
|
||||
optimizer.backward(loss)
|
||||
booster.backward(loss, optimizer)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
print(f'[{step}] loss: {loss.item():.3f}')
|
||||
|
|
|
@ -55,7 +55,6 @@ optimizer = HybridAdam(model.parameters(), lr=1e-3, nvme_offload_fraction=1.0, n
|
|||
|
||||
## Examples
|
||||
|
||||
Let's start from two simple examples -- training GPT with different methods. These examples relies on `transformers`.
|
||||
首先让我们从两个简单的例子开始 -- 用不同的方法训练 GPT。这些例子依赖`transformers`。
|
||||
|
||||
我们首先应该安装依赖:
|
||||
|
@ -77,8 +76,9 @@ from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
|||
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
|
||||
import colossalai
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.zero import zero_model_wrapper, zero_optim_wrapper
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin
|
||||
```
|
||||
|
||||
然后我们定义一个损失函数:
|
||||
|
@ -182,16 +182,24 @@ def train_gemini_cpu(nvme_offload_fraction: float = 0.0):
|
|||
criterion = GPTLMLoss()
|
||||
optimizer = HybridAdam(model.parameters(), nvme_offload_fraction=nvme_offload_fraction)
|
||||
print(f'Model numel: {get_model_numel(model) / 1024**3:.3f} B')
|
||||
gemini_config = dict(strict_ddp_mode=True, device=torch.cuda.current_device(),
|
||||
placement_policy='cpu', pin_memory=True, hidden_dim=config.n_embd)
|
||||
model = zero_model_wrapper(model, zero_stage=3, gemini_config=gemini_config)
|
||||
optimizer = zero_optim_wrapper(model, optimizer, initial_scale=2**5)
|
||||
|
||||
plugin = GeminiPlugin(
|
||||
strict_ddp_mode=True,
|
||||
device=torch.cuda.current_device(),
|
||||
placement_policy='cpu',
|
||||
pin_memory=True,
|
||||
hidden_dim=config.n_embd,
|
||||
initial_scale=2**5
|
||||
)
|
||||
booster = Booster(plugin)
|
||||
model, optimizer, criterion, _* = booster.boost(model, optimizer, criterion)
|
||||
|
||||
start = time.time()
|
||||
for step in range(3):
|
||||
data = get_data(4, 128, config.vocab_size)
|
||||
outputs = model(**data)
|
||||
loss = criterion(outputs.logits, data['input_ids'])
|
||||
optimizer.backward(loss)
|
||||
booster.backward(loss, optimizer)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
print(f'[{step}] loss: {loss.item():.3f}')
|
||||
|
|
Loading…
Reference in New Issue