mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] hybridparallelplugin support gradients accumulation. (#5246)
* support gradients acc fix fix fix fix fix fix fix fix fix fix fix fix fix * fix fix * fix fix fixpull/5278/head^2
parent
2a0558d8ec
commit
46e091651b
|
@ -165,7 +165,6 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
|
||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self.tp_group.size() > 1 and self.shard_config.enable_sequence_parallelism:
|
if self.tp_group.size() > 1 and self.shard_config.enable_sequence_parallelism:
|
||||||
if grads is not None:
|
if grads is not None:
|
||||||
# Synchronize provided gradient tensors across the tensor parallelism group.
|
# Synchronize provided gradient tensors across the tensor parallelism group.
|
||||||
|
@ -487,7 +486,6 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
|
||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Call the superclass backward method to compute gradients.
|
# Call the superclass backward method to compute gradients.
|
||||||
super().backward(loss, *args, **kwargs)
|
super().backward(loss, *args, **kwargs)
|
||||||
|
|
||||||
|
@ -513,7 +511,6 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
|
||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Call the superclass backward method to compute gradients.
|
# Call the superclass backward method to compute gradients.
|
||||||
super().backward_by_grad(tensor, grad)
|
super().backward_by_grad(tensor, grad)
|
||||||
|
|
||||||
|
@ -674,7 +671,6 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Call the superclass `_sync_grad` method to synchronize gradients.
|
# Call the superclass `_sync_grad` method to synchronize gradients.
|
||||||
super()._sync_grad()
|
super()._sync_grad()
|
||||||
|
|
||||||
|
@ -1081,7 +1077,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def support_no_sync(self) -> bool:
|
def support_no_sync(self) -> bool:
|
||||||
return False
|
return True
|
||||||
|
|
||||||
def control_checkpoint_io(self) -> bool:
|
def control_checkpoint_io(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
@ -1175,9 +1171,14 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||||
model, data_iter, criterion, optimizer, return_loss, return_outputs
|
model, data_iter, criterion, optimizer, return_loss, return_outputs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# run with gradients accumulation
|
||||||
|
if model.require_grad_sync == False or (
|
||||||
|
isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
|
||||||
|
):
|
||||||
|
return outputs
|
||||||
|
|
||||||
# Synchronize the grads of shared parameters of the model.
|
# Synchronize the grads of shared parameters of the model.
|
||||||
model.sync_shared_params()
|
model.sync_shared_params()
|
||||||
|
|
||||||
# Synchronize sequence parallelism gradients of the model.
|
# Synchronize sequence parallelism gradients of the model.
|
||||||
model.sync_sp_grads()
|
model.sync_sp_grads()
|
||||||
|
|
||||||
|
@ -1241,5 +1242,8 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||||
def get_checkpoint_io(self) -> CheckpointIO:
|
def get_checkpoint_io(self) -> CheckpointIO:
|
||||||
return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
|
return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
|
||||||
|
|
||||||
def no_sync(self, model: Module) -> Iterator[None]:
|
def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]:
|
||||||
raise NotImplementedError
|
assert (
|
||||||
|
self.zero_stage != 2
|
||||||
|
), "ZERO2 is not compatible with no_sync function, please run gradient accumulation with gradient synchronization allowed."
|
||||||
|
return optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
|
||||||
|
|
|
@ -1,8 +1,11 @@
|
||||||
|
import copy
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
from torch.testing import assert_close
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.booster import Booster
|
from colossalai.booster import Booster
|
||||||
|
@ -11,9 +14,33 @@ from colossalai.fx import is_compatible_with_meta
|
||||||
from colossalai.lazy.lazy_init import LazyInitContext
|
from colossalai.lazy.lazy_init import LazyInitContext
|
||||||
from colossalai.nn.optimizer import HybridAdam
|
from colossalai.nn.optimizer import HybridAdam
|
||||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||||
|
from colossalai.utils import get_current_device, set_seed
|
||||||
from tests.kit.model_zoo import model_zoo
|
from tests.kit.model_zoo import model_zoo
|
||||||
|
|
||||||
|
|
||||||
|
class RandomDataset(Dataset):
|
||||||
|
def __init__(self, num_samples: int = 100, max_length: int = 512, vocab_size: int = 32000):
|
||||||
|
self.num_samples = num_samples
|
||||||
|
self.max_length = max_length
|
||||||
|
set_seed(42)
|
||||||
|
self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
|
||||||
|
self.attention_mask = torch.ones_like(self.input_ids)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.num_samples
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
return {
|
||||||
|
"input_ids": self.input_ids[idx],
|
||||||
|
"attention_mask": self.attention_mask[idx],
|
||||||
|
"labels": self.input_ids[idx],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def move_to_cuda(batch):
|
||||||
|
return {k: v.cuda() for k, v in batch.items()}
|
||||||
|
|
||||||
|
|
||||||
@clear_cache_before_run()
|
@clear_cache_before_run()
|
||||||
def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
|
def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
|
||||||
try:
|
try:
|
||||||
|
@ -85,10 +112,145 @@ def check_3d_plugin(init_method: str = "none", early_stop: bool = True):
|
||||||
assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()])
|
assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()])
|
||||||
|
|
||||||
|
|
||||||
|
@parameterize(
|
||||||
|
"test_args",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"batch_size": 8,
|
||||||
|
"num_steps": 4,
|
||||||
|
"tp": 2,
|
||||||
|
"pp": 2,
|
||||||
|
"pp_style": "1f1b",
|
||||||
|
"num_model_chunks": 1,
|
||||||
|
"num_microbatches": 4,
|
||||||
|
"zero": 0,
|
||||||
|
"precision": "fp16",
|
||||||
|
"initial_scale": 1,
|
||||||
|
"max_length": 512,
|
||||||
|
"gradient_accumulation_step": 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"batch_size": 8,
|
||||||
|
"num_steps": 4,
|
||||||
|
"tp": 1,
|
||||||
|
"pp": 2,
|
||||||
|
"pp_style": "1f1b",
|
||||||
|
"num_model_chunks": 1,
|
||||||
|
"num_microbatches": 4,
|
||||||
|
"zero": 1,
|
||||||
|
"precision": "fp16",
|
||||||
|
"initial_scale": 1,
|
||||||
|
"max_length": 512,
|
||||||
|
"gradient_accumulation_step": 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"batch_size": 1,
|
||||||
|
"num_steps": 4,
|
||||||
|
"tp": 2,
|
||||||
|
"pp": 1,
|
||||||
|
"pp_style": "1f1b",
|
||||||
|
"num_model_chunks": 1,
|
||||||
|
"num_microbatches": 1,
|
||||||
|
"zero": 2,
|
||||||
|
"precision": "fp16",
|
||||||
|
"initial_scale": 1,
|
||||||
|
"max_length": 512,
|
||||||
|
"gradient_accumulation_step": 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"batch_size": 1,
|
||||||
|
"num_steps": 4,
|
||||||
|
"tp": 2,
|
||||||
|
"pp": 1,
|
||||||
|
"pp_style": "1f1b",
|
||||||
|
"num_model_chunks": 1,
|
||||||
|
"num_microbatches": 1,
|
||||||
|
"zero": 0,
|
||||||
|
"precision": "fp16",
|
||||||
|
"initial_scale": 1,
|
||||||
|
"max_length": 512,
|
||||||
|
"gradient_accumulation_step": 2,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def run_grad_acc_test(test_args):
|
||||||
|
model_fn, *_ = next(iter(model_zoo.get_sub_registry("transformers_gpt_lm").values()))
|
||||||
|
model = model_fn()
|
||||||
|
optimizer = HybridAdam(model.parameters())
|
||||||
|
origin_model = copy.deepcopy(model).cuda()
|
||||||
|
origin_optimizer = HybridAdam(origin_model.parameters())
|
||||||
|
|
||||||
|
plugin = HybridParallelPlugin(
|
||||||
|
tp_size=test_args["tp"],
|
||||||
|
pp_size=test_args["pp"],
|
||||||
|
pp_style=test_args["pp_style"],
|
||||||
|
zero_stage=test_args["zero"],
|
||||||
|
num_model_chunks=test_args["num_model_chunks"],
|
||||||
|
enable_fused_normalization=True,
|
||||||
|
num_microbatches=test_args["num_microbatches"],
|
||||||
|
precision=test_args["precision"],
|
||||||
|
)
|
||||||
|
booster = Booster(plugin=plugin)
|
||||||
|
|
||||||
|
dataset = RandomDataset(
|
||||||
|
num_samples=test_args["batch_size"] * test_args["num_steps"] * plugin.dp_size,
|
||||||
|
max_length=test_args["max_length"],
|
||||||
|
vocab_size=model.config.vocab_size,
|
||||||
|
)
|
||||||
|
dataloader = plugin.prepare_dataloader(dataset, batch_size=test_args["batch_size"], shuffle=True, drop_last=True)
|
||||||
|
|
||||||
|
model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
|
||||||
|
|
||||||
|
grad_accu_step = test_args["gradient_accumulation_step"]
|
||||||
|
for step, batch in enumerate(dataloader):
|
||||||
|
batch = move_to_cuda(batch)
|
||||||
|
# train origin model
|
||||||
|
origin_output = origin_model(**batch)
|
||||||
|
origin_loss = origin_output[0] / grad_accu_step
|
||||||
|
origin_loss.backward()
|
||||||
|
|
||||||
|
if (step + 1) % grad_accu_step != 0 and test_args["zero"] != 2:
|
||||||
|
ctx = booster.no_sync(model, optimizer)
|
||||||
|
else:
|
||||||
|
ctx = nullcontext()
|
||||||
|
|
||||||
|
with ctx:
|
||||||
|
if plugin.stage_manager is not None:
|
||||||
|
batch = iter([batch])
|
||||||
|
booster.execute_pipeline(
|
||||||
|
batch,
|
||||||
|
model,
|
||||||
|
criterion=lambda outputs, inputs: outputs[0] / grad_accu_step,
|
||||||
|
optimizer=optimizer,
|
||||||
|
return_loss=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
outputs = model(**batch)
|
||||||
|
loss = outputs[0] / grad_accu_step
|
||||||
|
booster.backward(loss, optimizer)
|
||||||
|
|
||||||
|
if (step + 1) % grad_accu_step == 0:
|
||||||
|
# update origin model weight
|
||||||
|
origin_optimizer.step()
|
||||||
|
origin_optimizer.zero_grad()
|
||||||
|
|
||||||
|
# update sharded model
|
||||||
|
optimizer.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
# tricky code here, shard the origin model inorder to check the parameters in the same stage.
|
||||||
|
origin_model, origin_optimizer, _, dataloader, _ = booster.boost(
|
||||||
|
origin_model, origin_optimizer, dataloader=dataloader
|
||||||
|
)
|
||||||
|
for p1, p2 in zip(model.unwrap().parameters(), origin_model.unwrap().parameters()):
|
||||||
|
assert_close(p1.to(p2.dtype), p2, atol=1e-2, rtol=1e-2)
|
||||||
|
|
||||||
|
|
||||||
def run_dist(rank, world_size, port, early_stop: bool = True):
|
def run_dist(rank, world_size, port, early_stop: bool = True):
|
||||||
# init dist env
|
# init dist env
|
||||||
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
|
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
|
||||||
check_3d_plugin(early_stop=early_stop)
|
check_3d_plugin(early_stop=early_stop)
|
||||||
|
run_grad_acc_test()
|
||||||
|
|
||||||
|
|
||||||
@rerun_if_address_is_in_use()
|
@rerun_if_address_is_in_use()
|
||||||
|
|
Loading…
Reference in New Issue