mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] hybridparallelplugin support gradients accumulation. (#5246)
* support gradients acc fix fix fix fix fix fix fix fix fix fix fix fix fix * fix fix * fix fix fix
pull/5278/head^2
parent
2a0558d8ec
commit
46e091651b
|
@ -165,7 +165,6 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
|
|||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
if self.tp_group.size() > 1 and self.shard_config.enable_sequence_parallelism:
|
||||
if grads is not None:
|
||||
# Synchronize provided gradient tensors across the tensor parallelism group.
|
||||
|
@ -487,7 +486,6 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
|
|||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
# Call the superclass backward method to compute gradients.
|
||||
super().backward(loss, *args, **kwargs)
|
||||
|
||||
|
@ -513,7 +511,6 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
|
|||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
# Call the superclass backward method to compute gradients.
|
||||
super().backward_by_grad(tensor, grad)
|
||||
|
||||
|
@ -674,7 +671,6 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
|||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
# Call the superclass `_sync_grad` method to synchronize gradients.
|
||||
super()._sync_grad()
|
||||
|
||||
|
@ -1081,7 +1077,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
return True
|
||||
|
||||
def support_no_sync(self) -> bool:
|
||||
return False
|
||||
return True
|
||||
|
||||
def control_checkpoint_io(self) -> bool:
|
||||
return True
|
||||
|
@ -1175,9 +1171,14 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
model, data_iter, criterion, optimizer, return_loss, return_outputs
|
||||
)
|
||||
|
||||
# run with gradients accumulation
|
||||
if model.require_grad_sync == False or (
|
||||
isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
|
||||
):
|
||||
return outputs
|
||||
|
||||
# Synchronize the grads of shared parameters of the model.
|
||||
model.sync_shared_params()
|
||||
|
||||
# Synchronize sequence parallelism gradients of the model.
|
||||
model.sync_sp_grads()
|
||||
|
||||
|
@ -1241,5 +1242,8 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
def get_checkpoint_io(self) -> CheckpointIO:
|
||||
return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
|
||||
|
||||
def no_sync(self, model: Module) -> Iterator[None]:
|
||||
raise NotImplementedError
|
||||
def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]:
|
||||
assert (
|
||||
self.zero_stage != 2
|
||||
), "ZERO2 is not compatible with no_sync function, please run gradient accumulation with gradient synchronization allowed."
|
||||
return optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
import copy
|
||||
from contextlib import nullcontext
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.testing import assert_close
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
|
@ -11,9 +14,33 @@ from colossalai.fx import is_compatible_with_meta
|
|||
from colossalai.lazy.lazy_init import LazyInitContext
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils import get_current_device, set_seed
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
|
||||
|
||||
class RandomDataset(Dataset):
    """Synthetic language-modelling dataset made of random token ids.

    All samples are generated once in the constructor: ``input_ids`` is a
    ``(num_samples, max_length)`` integer tensor drawn from ``[0, vocab_size)``
    and ``attention_mask`` is an all-ones tensor of the same shape. Each item
    uses its ``input_ids`` row as ``labels`` as well.

    Args:
        num_samples: Number of rows to generate; also the dataset length.
        max_length: Number of token ids per sample.
        vocab_size: Exclusive upper bound of the sampled token ids.
    """

    def __init__(self, num_samples: int = 100, max_length: int = 512, vocab_size: int = 32000):
        self.num_samples = num_samples
        self.max_length = max_length

        # Seed before sampling (presumably so every process draws the same data).
        set_seed(42)
        shape = (num_samples, max_length)
        # Tensors are created directly on the device reported by get_current_device().
        self.input_ids = torch.randint(0, vocab_size, shape, device=get_current_device())
        self.attention_mask = torch.ones_like(self.input_ids)

    def __len__(self):
        """Return the number of samples."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return the ``idx``-th sample as a dict of ``input_ids``, ``attention_mask`` and ``labels``."""
        sample_ids = self.input_ids[idx]
        sample_mask = self.attention_mask[idx]
        # Labels are the input ids themselves.
        sample_labels = self.input_ids[idx]
        return {
            "input_ids": sample_ids,
            "attention_mask": sample_mask,
            "labels": sample_labels,
        }
|
||||
|
||||
|
||||
def move_to_cuda(batch):
    """Return a new dict with every value of ``batch`` replaced by ``value.cuda()``.

    The input mapping itself is not modified; keys are kept unchanged.
    """
    on_gpu = {}
    for name, value in batch.items():
        on_gpu[name] = value.cuda()
    return on_gpu
|
||||
|
||||
|
||||
@clear_cache_before_run()
|
||||
def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
|
||||
try:
|
||||
|
@ -85,10 +112,145 @@ def check_3d_plugin(init_method: str = "none", early_stop: bool = True):
|
|||
assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()])
|
||||
|
||||
|
||||
def _grad_acc_case(batch_size, tp, pp, num_microbatches, zero):
    """Build one ``test_args`` dict; only the five arguments differ between cases."""
    return {
        "batch_size": batch_size,
        "num_steps": 4,
        "tp": tp,
        "pp": pp,
        "pp_style": "1f1b",
        "num_model_chunks": 1,
        "num_microbatches": num_microbatches,
        "zero": zero,
        "precision": "fp16",
        "initial_scale": 1,
        "max_length": 512,
        "gradient_accumulation_step": 2,
    }


@parameterize(
    "test_args",
    [
        _grad_acc_case(batch_size=8, tp=2, pp=2, num_microbatches=4, zero=0),
        _grad_acc_case(batch_size=8, tp=1, pp=2, num_microbatches=4, zero=1),
        _grad_acc_case(batch_size=1, tp=2, pp=1, num_microbatches=1, zero=2),
        _grad_acc_case(batch_size=1, tp=2, pp=1, num_microbatches=1, zero=0),
    ],
)
def run_grad_acc_test(test_args):
    """Check gradient accumulation under ``HybridParallelPlugin`` against a plain reference.

    A model from the ``transformers_gpt_lm`` registry is boosted with the plugin,
    while a deep copy of it is trained without boosting. Both see the same
    batches, both divide the loss by the accumulation step count, and both call
    ``step()`` / ``zero_grad()`` only every ``gradient_accumulation_step``
    iterations. On the other iterations the boosted model runs inside
    ``booster.no_sync`` unless ``zero`` is 2. Finally the reference model is
    boosted too and the parameters of both are compared with ``assert_close``.

    Args:
        test_args: Configuration dict supplied by ``@parameterize``.
    """
    model_fn, *_ = next(iter(model_zoo.get_sub_registry("transformers_gpt_lm").values()))
    model = model_fn()
    optimizer = HybridAdam(model.parameters())

    # Reference copy, taken before the plugin wraps the model.
    origin_model = copy.deepcopy(model).cuda()
    origin_optimizer = HybridAdam(origin_model.parameters())

    plugin = HybridParallelPlugin(
        tp_size=test_args["tp"],
        pp_size=test_args["pp"],
        pp_style=test_args["pp_style"],
        zero_stage=test_args["zero"],
        num_model_chunks=test_args["num_model_chunks"],
        enable_fused_normalization=True,
        num_microbatches=test_args["num_microbatches"],
        precision=test_args["precision"],
    )
    booster = Booster(plugin=plugin)

    total_samples = test_args["batch_size"] * test_args["num_steps"] * plugin.dp_size
    dataset = RandomDataset(
        num_samples=total_samples,
        max_length=test_args["max_length"],
        vocab_size=model.config.vocab_size,
    )
    dataloader = plugin.prepare_dataloader(dataset, batch_size=test_args["batch_size"], shuffle=True, drop_last=True)

    model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)

    accum_steps = test_args["gradient_accumulation_step"]

    def scaled_loss(outputs, inputs):
        # Pipeline criterion: same loss scaling as the non-pipeline branch.
        return outputs[0] / accum_steps

    for step, batch in enumerate(dataloader):
        batch = move_to_cuda(batch)

        # Reference model: plain forward/backward with the scaled loss.
        reference_loss = origin_model(**batch)[0] / accum_steps
        reference_loss.backward()

        is_update_step = (step + 1) % accum_steps == 0

        # Skip gradient sync on pure accumulation steps, except with zero == 2.
        if is_update_step or test_args["zero"] == 2:
            sync_ctx = nullcontext()
        else:
            sync_ctx = booster.no_sync(model, optimizer)

        with sync_ctx:
            if plugin.stage_manager is None:
                sharded_loss = model(**batch)[0] / accum_steps
                booster.backward(sharded_loss, optimizer)
            else:
                # The pipeline schedule consumes an iterator of batches.
                booster.execute_pipeline(
                    iter([batch]),
                    model,
                    criterion=scaled_loss,
                    optimizer=optimizer,
                    return_loss=False,
                )

        if is_update_step:
            # Update the reference weights.
            origin_optimizer.step()
            origin_optimizer.zero_grad()

            # Update the sharded model.
            optimizer.step()
            optimizer.zero_grad()

    # Tricky: shard the origin model too, in order to compare the parameters held in the same stage.
    origin_model, origin_optimizer, _, dataloader, _ = booster.boost(
        origin_model, origin_optimizer, dataloader=dataloader
    )
    sharded_params = model.unwrap().parameters()
    reference_params = origin_model.unwrap().parameters()
    for p1, p2 in zip(sharded_params, reference_params):
        assert_close(p1.to(p2.dtype), p2, atol=1e-2, rtol=1e-2)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port, early_stop: bool = True):
    """Per-process entry point: launch the distributed environment, then run both checks.

    Args:
        rank: Rank passed to ``colossalai.launch``.
        world_size: World size passed to ``colossalai.launch``.
        port: Port passed to ``colossalai.launch`` (host is ``"localhost"``).
        early_stop: Forwarded to ``check_3d_plugin``.
    """
    # Initialise the distributed environment with an empty config.
    colossalai.launch(
        config={},
        rank=rank,
        world_size=world_size,
        port=port,
        host="localhost",
    )
    check_3d_plugin(early_stop=early_stop)
    # Arguments are supplied by the @parameterize decorator.
    run_grad_acc_test()
|
||||
|
||||
|
||||
@rerun_if_address_is_in_use()
|
||||
|
|
Loading…
Reference in New Issue