From 839847b7d78bce6af5dfe58d27b5ce2c74a3619b Mon Sep 17 00:00:00 2001 From: LuGY <74758262+Gy-Lu@users.noreply.github.com> Date: Fri, 25 Aug 2023 13:44:07 +0800 Subject: [PATCH] [zero]support zero2 with gradient accumulation (#4511) * support gradient accumulation with zero2 * fix type --- .../low_level/bookkeeping/gradient_store.py | 4 +- colossalai/zero/low_level/low_level_optim.py | 13 ++++-- colossalai/zero/low_level/readme.md | 44 +++++++++++++++++-- .../test_zero/test_low_level/test_grad_acc.py | 28 ++++-------- 4 files changed, 61 insertions(+), 28 deletions(-) diff --git a/colossalai/zero/low_level/bookkeeping/gradient_store.py b/colossalai/zero/low_level/bookkeeping/gradient_store.py index 0b86ec8ca..2890b329a 100644 --- a/colossalai/zero/low_level/bookkeeping/gradient_store.py +++ b/colossalai/zero/low_level/bookkeeping/gradient_store.py @@ -57,8 +57,8 @@ class GradientStore(BaseStore): self._grads_of_params[group_id][param_id].append(grad) def add_gradients_by_param_id(self, grad: Tensor, grad_idx: int, group_id: int, param_id: int): - """For old gradient accumulation, not in use now. - Add a gradient slice on an existing slice of the parameter's gradient + """Add a gradient slice on an existing slice of the parameter's gradient + Used when no_sync is not activated. Args: grad (Tensor): The split gradient to append to list diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 64d6a5395..8f2232393 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -277,7 +277,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper): sync_tensor(flat_grads_per_rank[rank], grad_list) for grad in grad_list: param_id = self._bucket_store.get_param_id_of_grad(grad) - self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) + if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, + param_id)) < self._world_size: + self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) + else: + self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id) else: flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size)) @@ -291,7 +295,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper): sync_tensor(recieved_grad, grad_in_bucket_current_rank) for grad in grad_in_bucket_current_rank: param_id = self._bucket_store.get_param_id_of_grad(grad) - self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) + if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < 1: + self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) + else: + self._grad_store.add_gradients_by_param_id(grad, 0, group_id, param_id) self._bucket_store.reset() @@ -315,7 +322,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): def backward(self, loss, retain_graph=False): assert not(self._partition_grads and not self.require_grad_sync), \ - "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible" + "ZeRO2(partition_grads) and no_sync are not compatible" if self.mixed_precision_mixin is not None: loss = self.mixed_precision_mixin.pre_backward(loss) diff --git a/colossalai/zero/low_level/readme.md b/colossalai/zero/low_level/readme.md index aa92159d8..b960a4362 100644 --- a/colossalai/zero/low_level/readme.md +++ b/colossalai/zero/low_level/readme.md @@ -1,5 +1,41 @@ # Low Level ZeRO >Low Level ZeRO == ZeRO-DP stage 1 and 2, we would denote it as ZeRO. +## Examples of ZeRO and gradient accumulation + +The code below only shows a typical gradient accumulation process, and it drops a lot of details, such as the processing of loss. + +```python +# examples of ZeRO1 with gradient accumulation +... +outputs = model(input) +loss = SomeLoss(outputs) +if (idx + 1) % ACCUMULATE_STEP != 0: + with booster.no_sync(model, optimizer): + # under this context, the gradient would not sync when backward, + # left each rank having different gradient. + # It saves the backward time + booster.backward(loss, optimizer) + continue +else: + # need to sync all the accumulated gradient + booster.backward(loss, optimizer): + optimizer.step() + ... +``` + +```python +# example of ZeRO2 with gradient accumulation + +... +outputs = model(input) +loss = SomeLoss(outputs) +# ZeRO2 split the gradients and can NOT accumulate gradient with syncing. +booster.backward(loss, optimizer) +if (idx + 1) % ACCUMULATE_STEP == 0: + optimizer.step() +... +``` + ## Design: ### Notion @@ -25,11 +61,11 @@ The data structure looks like this: ``` After that, the gradients would be flattened by rank, and the data structure looks like this: ``` -# g-0 means flatten([g-00, g-10]) +# g-X0 means flatten([g-00, g-10]) { -0: [g-0], -1: [g-1], -2: [g-2] +0: [g-X0], +1: [g-X1], +2: [g-X2] } ``` For zero1, we iterate the dictionary and do `all_reduce`. For zero2, we can just do `reduce-scatter`. diff --git a/tests/test_zero/test_low_level/test_grad_acc.py b/tests/test_zero/test_low_level/test_grad_acc.py index a1d14f1d5..f170f7cb8 100644 --- a/tests/test_zero/test_low_level/test_grad_acc.py +++ b/tests/test_zero/test_low_level/test_grad_acc.py @@ -58,17 +58,8 @@ def exam_zero_1_2_grad_acc(): assert torch.equal(zero1_output, zero2_output) # zero-dp backward - no_sync = number == 0 - with conditional_context(zero1_optimizer.no_sync(), no_sync): - zero1_optimizer.backward(zero1_output.sum().float()) - with conditional_context(zero2_optimizer.no_sync(), no_sync): - zero2_optimizer.backward(zero2_output.sum().float()) - - if check_flag: - for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()): - if z2p.grad is not None: - # print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad))) - assert torch.equal(z1p.grad, z2p.grad) + zero1_optimizer.backward(zero1_output.sum().float()) + zero2_optimizer.backward(zero2_output.sum().float()) fwd_bwd_func(0, input_data1, True) fwd_bwd_func(1, input_data2, False) @@ -82,7 +73,7 @@ def exam_zero_1_2_grad_acc(): assert torch.equal(z1p.data, z2p.data) -def exam_zero_1_grad_acc(): +def exam_zero_1_grad_acc(sync): local_rank = torch.distributed.get_rank() seed_all(2008) @@ -112,9 +103,8 @@ def exam_zero_1_grad_acc(): input_data1 = torch.randn(32, 128).cuda() input_data2 = torch.randn(32, 128).cuda() - def fwd_bwd_func(number, cur_data, check_flag): + def fwd_bwd_func(no_sync, cur_data, check_flag): - no_sync = number == 0 # zero1 fwd and bwd with conditional_context(zero_optimizer.no_sync(), no_sync): zero_output = zero_model(cur_data) @@ -131,8 +121,8 @@ def exam_zero_1_grad_acc(): for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): assert torch.equal(p.grad, z1p.grad) - fwd_bwd_func(0, input_data1, True) - fwd_bwd_func(1, input_data2, False) + fwd_bwd_func(sync, input_data1, sync) + fwd_bwd_func(False, input_data2, False) zero_optimizer.step() torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0) @@ -147,9 +137,9 @@ def exam_zero_1_grad_acc(): def run_dist(rank, world_size, port): colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') - exam_zero_1_grad_acc() - # gradient accumulation is not compatible with ZeRO-2 - # exam_zero_1_2_grad_acc() + exam_zero_1_grad_acc(sync=True) + exam_zero_1_grad_acc(sync=False) + exam_zero_1_2_grad_acc() @pytest.mark.dist