[shardformer] hybridparallelplugin support gradients accumulation. (#5246)

* support gradients acc

* fix
pull/5278/head^2
flybird11111 2024-01-17 15:22:33 +08:00 committed by GitHub
parent 2a0558d8ec
commit 46e091651b
2 changed files with 174 additions and 8 deletions


@@ -165,7 +165,6 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
        Returns:
            None
        """
        if self.tp_group.size() > 1 and self.shard_config.enable_sequence_parallelism:
            if grads is not None:
                # Synchronize provided gradient tensors across the tensor parallelism group.
@@ -487,7 +486,6 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
        Returns:
            None
        """
        # Call the superclass backward method to compute gradients.
        super().backward(loss, *args, **kwargs)
@@ -513,7 +511,6 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
        Returns:
            None
        """
        # Call the superclass backward method to compute gradients.
        super().backward_by_grad(tensor, grad)
@@ -674,7 +671,6 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
        Returns:
            None
        """
        # Call the superclass `_sync_grad` method to synchronize gradients.
        super()._sync_grad()
@@ -1081,7 +1077,7 @@ class HybridParallelPlugin(PipelinePluginBase):
        return True

    def support_no_sync(self) -> bool:
-        return False
+        return True

    def control_checkpoint_io(self) -> bool:
        return True
@@ -1175,9 +1171,14 @@ class HybridParallelPlugin(PipelinePluginBase):
                model, data_iter, criterion, optimizer, return_loss, return_outputs
            )

+        # run with gradients accumulation
+        if model.require_grad_sync == False or (
+            isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
+        ):
+            return outputs
+
        # Synchronize the grads of shared parameters of the model.
        model.sync_shared_params()

        # Synchronize sequence parallelism gradients of the model.
        model.sync_sp_grads()
@@ -1241,5 +1242,8 @@ class HybridParallelPlugin(PipelinePluginBase):
    def get_checkpoint_io(self) -> CheckpointIO:
        return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)

-    def no_sync(self, model: Module) -> Iterator[None]:
-        raise NotImplementedError
+    def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]:
+        assert (
+            self.zero_stage != 2
+        ), "ZERO2 is not compatible with no_sync function, please run gradient accumulation with gradient synchronization allowed."
+        return optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
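With support_no_sync now returning True and no_sync implemented, gradient accumulation can be driven through Booster.no_sync. The snippet below is a minimal usage sketch rather than code from this commit: build_model(), train_dataset, the plugin arguments, and accumulation_steps are illustrative assumptions, and it presumes colossalai.launch() has already set up the distributed environment. The pipeline-parallel path (booster.execute_pipeline) follows the same pattern, as the test added in the second changed file shows.

from contextlib import nullcontext

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.nn.optimizer import HybridAdam

# Sketch only: build_model() and train_dataset are hypothetical placeholders.
model = build_model()  # HuggingFace-style causal LM: outputs[0] is the loss when labels are passed
optimizer = HybridAdam(model.parameters())

# zero_stage must be 0 or 1 here; the new no_sync() asserts against ZeRO stage 2.
plugin = HybridParallelPlugin(tp_size=2, pp_size=1, zero_stage=1, precision="fp16")
booster = Booster(plugin=plugin)
dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8, shuffle=True, drop_last=True)
model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)

accumulation_steps = 2  # illustrative setting, not part of the commit

for step, batch in enumerate(dataloader):
    sync_step = (step + 1) % accumulation_steps == 0
    # Skip gradient synchronization on non-boundary micro-steps.
    ctx = nullcontext() if sync_step else booster.no_sync(model, optimizer)
    with ctx:
        outputs = model(**batch)
        loss = outputs[0] / accumulation_steps
        booster.backward(loss, optimizer)
    if sync_step:
        optimizer.step()
        optimizer.zero_grad()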


@@ -1,8 +1,11 @@
+import copy
from contextlib import nullcontext
from typing import Optional

import torch
import torch.distributed as dist
+from torch.testing import assert_close
+from torch.utils.data import Dataset

import colossalai
from colossalai.booster import Booster
@@ -11,9 +14,33 @@ from colossalai.fx import is_compatible_with_meta
from colossalai.lazy.lazy_init import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.utils import get_current_device, set_seed
from tests.kit.model_zoo import model_zoo


+class RandomDataset(Dataset):
+    def __init__(self, num_samples: int = 100, max_length: int = 512, vocab_size: int = 32000):
+        self.num_samples = num_samples
+        self.max_length = max_length
+        set_seed(42)
+        self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
+        self.attention_mask = torch.ones_like(self.input_ids)
+
+    def __len__(self):
+        return self.num_samples
+
+    def __getitem__(self, idx):
+        return {
+            "input_ids": self.input_ids[idx],
+            "attention_mask": self.attention_mask[idx],
+            "labels": self.input_ids[idx],
+        }
+
+
+def move_to_cuda(batch):
+    return {k: v.cuda() for k, v in batch.items()}
+
+
@clear_cache_before_run()
def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
    try:
@@ -85,10 +112,145 @@ def check_3d_plugin(init_method: str = "none", early_stop: bool = True):
    assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()])


+@parameterize(
+    "test_args",
+    [
+        {
+            "batch_size": 8,
+            "num_steps": 4,
+            "tp": 2,
+            "pp": 2,
+            "pp_style": "1f1b",
+            "num_model_chunks": 1,
+            "num_microbatches": 4,
+            "zero": 0,
+            "precision": "fp16",
+            "initial_scale": 1,
+            "max_length": 512,
+            "gradient_accumulation_step": 2,
+        },
+        {
+            "batch_size": 8,
+            "num_steps": 4,
+            "tp": 1,
+            "pp": 2,
+            "pp_style": "1f1b",
+            "num_model_chunks": 1,
+            "num_microbatches": 4,
+            "zero": 1,
+            "precision": "fp16",
+            "initial_scale": 1,
+            "max_length": 512,
+            "gradient_accumulation_step": 2,
+        },
+        {
+            "batch_size": 1,
+            "num_steps": 4,
+            "tp": 2,
+            "pp": 1,
+            "pp_style": "1f1b",
+            "num_model_chunks": 1,
+            "num_microbatches": 1,
+            "zero": 2,
+            "precision": "fp16",
+            "initial_scale": 1,
+            "max_length": 512,
+            "gradient_accumulation_step": 2,
+        },
+        {
+            "batch_size": 1,
+            "num_steps": 4,
+            "tp": 2,
+            "pp": 1,
+            "pp_style": "1f1b",
+            "num_model_chunks": 1,
+            "num_microbatches": 1,
+            "zero": 0,
+            "precision": "fp16",
+            "initial_scale": 1,
+            "max_length": 512,
+            "gradient_accumulation_step": 2,
+        },
+    ],
+)
+def run_grad_acc_test(test_args):
+    model_fn, *_ = next(iter(model_zoo.get_sub_registry("transformers_gpt_lm").values()))
+    model = model_fn()
+    optimizer = HybridAdam(model.parameters())
+    origin_model = copy.deepcopy(model).cuda()
+    origin_optimizer = HybridAdam(origin_model.parameters())
+
+    plugin = HybridParallelPlugin(
+        tp_size=test_args["tp"],
+        pp_size=test_args["pp"],
+        pp_style=test_args["pp_style"],
+        zero_stage=test_args["zero"],
+        num_model_chunks=test_args["num_model_chunks"],
+        enable_fused_normalization=True,
+        num_microbatches=test_args["num_microbatches"],
+        precision=test_args["precision"],
+    )
+    booster = Booster(plugin=plugin)
+
+    dataset = RandomDataset(
+        num_samples=test_args["batch_size"] * test_args["num_steps"] * plugin.dp_size,
+        max_length=test_args["max_length"],
+        vocab_size=model.config.vocab_size,
+    )
+    dataloader = plugin.prepare_dataloader(dataset, batch_size=test_args["batch_size"], shuffle=True, drop_last=True)
+
+    model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
+
+    grad_accu_step = test_args["gradient_accumulation_step"]
+    for step, batch in enumerate(dataloader):
+        batch = move_to_cuda(batch)
+        # train origin model
+        origin_output = origin_model(**batch)
+        origin_loss = origin_output[0] / grad_accu_step
+        origin_loss.backward()
+
+        if (step + 1) % grad_accu_step != 0 and test_args["zero"] != 2:
+            ctx = booster.no_sync(model, optimizer)
+        else:
+            ctx = nullcontext()
+
+        with ctx:
+            if plugin.stage_manager is not None:
+                batch = iter([batch])
+                booster.execute_pipeline(
+                    batch,
+                    model,
+                    criterion=lambda outputs, inputs: outputs[0] / grad_accu_step,
+                    optimizer=optimizer,
+                    return_loss=False,
+                )
+            else:
+                outputs = model(**batch)
+                loss = outputs[0] / grad_accu_step
+                booster.backward(loss, optimizer)
+
+        if (step + 1) % grad_accu_step == 0:
+            # update origin model weight
+            origin_optimizer.step()
+            origin_optimizer.zero_grad()
+
+            # update sharded model
+            optimizer.step()
+            optimizer.zero_grad()
+
+    # tricky code here, shard the origin model in order to check the parameters in the same stage.
+    origin_model, origin_optimizer, _, dataloader, _ = booster.boost(
+        origin_model, origin_optimizer, dataloader=dataloader
+    )
+    for p1, p2 in zip(model.unwrap().parameters(), origin_model.unwrap().parameters()):
+        assert_close(p1.to(p2.dtype), p2, atol=1e-2, rtol=1e-2)
+
+
def run_dist(rank, world_size, port, early_stop: bool = True):
    # init dist env
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
    check_3d_plugin(early_stop=early_stop)
+    run_grad_acc_test()


@rerun_if_address_is_in_use()