mirror of https://github.com/hpcaitech/ColossalAI
[zero]support zero2 with gradient accumulation (#4511)
* support gradient accumulation with zero2 * fix typepull/4546/head
parent
c0efc3ebcb
commit
839847b7d7
|
@ -57,8 +57,8 @@ class GradientStore(BaseStore):
|
|||
self._grads_of_params[group_id][param_id].append(grad)
|
||||
|
||||
def add_gradients_by_param_id(self, grad: Tensor, grad_idx: int, group_id: int, param_id: int):
|
||||
"""For old gradient accumulation, not in use now.
|
||||
Add a gradient slice on an existing slice of the parameter's gradient
|
||||
"""Add a gradient slice on an existing slice of the parameter's gradient
|
||||
Used when no_sync is not activated.
|
||||
|
||||
Args:
|
||||
grad (Tensor): The split gradient to append to list
|
||||
|
|
|
@ -277,7 +277,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
sync_tensor(flat_grads_per_rank[rank], grad_list)
|
||||
for grad in grad_list:
|
||||
param_id = self._bucket_store.get_param_id_of_grad(grad)
|
||||
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
|
||||
if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id,
|
||||
param_id)) < self._world_size:
|
||||
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
|
||||
else:
|
||||
self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id)
|
||||
|
||||
else:
|
||||
flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
|
||||
|
@ -291,7 +295,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
sync_tensor(recieved_grad, grad_in_bucket_current_rank)
|
||||
for grad in grad_in_bucket_current_rank:
|
||||
param_id = self._bucket_store.get_param_id_of_grad(grad)
|
||||
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
|
||||
if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < 1:
|
||||
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
|
||||
else:
|
||||
self._grad_store.add_gradients_by_param_id(grad, 0, group_id, param_id)
|
||||
|
||||
self._bucket_store.reset()
|
||||
|
||||
|
@ -315,7 +322,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
|
||||
def backward(self, loss, retain_graph=False):
|
||||
assert not(self._partition_grads and not self.require_grad_sync), \
|
||||
"ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"
|
||||
"ZeRO2(partition_grads) and no_sync are not compatible"
|
||||
|
||||
if self.mixed_precision_mixin is not None:
|
||||
loss = self.mixed_precision_mixin.pre_backward(loss)
|
||||
|
|
|
@ -1,5 +1,41 @@
|
|||
# Low Level ZeRO
|
||||
>Low Level ZeRO == ZeRO-DP stage 1 and 2, we would denote it as ZeRO.
|
||||
## Examples of ZeRO and gradient accumulation
|
||||
|
||||
The code below only shows a typical gradient accumulation process, and it drops a lot of details, such as the processing of loss.
|
||||
|
||||
```python
|
||||
# examples of ZeRO1 with gradient accumulation
|
||||
...
|
||||
outputs = model(input)
|
||||
loss = SomeLoss(outputs)
|
||||
if (idx + 1) % ACCUMULATE_STEP != 0:
|
||||
with booster.no_sync(model, optimizer):
|
||||
# under this context, the gradient would not sync when backward,
|
||||
# left each rank having different gradient.
|
||||
# It saves the backward time
|
||||
booster.backward(loss, optimizer)
|
||||
continue
|
||||
else:
|
||||
# need to sync all the accumulated gradient
|
||||
booster.backward(loss, optimizer):
|
||||
optimizer.step()
|
||||
...
|
||||
```
|
||||
|
||||
```python
|
||||
# example of ZeRO2 with gradient accumulation
|
||||
|
||||
...
|
||||
outputs = model(input)
|
||||
loss = SomeLoss(outputs)
|
||||
# ZeRO2 split the gradients and can NOT accumulate gradient with syncing.
|
||||
booster.backward(loss, optimizer)
|
||||
if (idx + 1) % ACCUMULATE_STEP == 0:
|
||||
optimizer.step()
|
||||
...
|
||||
```
|
||||
|
||||
|
||||
## Design:
|
||||
### Notion
|
||||
|
@ -25,11 +61,11 @@ The data structure looks like this:
|
|||
```
|
||||
After that, the gradients would be flattened by rank, and the data structure looks like this:
|
||||
```
|
||||
# g-0 means flatten([g-00, g-10])
|
||||
# g-X0 means flatten([g-00, g-10])
|
||||
{
|
||||
0: [g-0],
|
||||
1: [g-1],
|
||||
2: [g-2]
|
||||
0: [g-X0],
|
||||
1: [g-X1],
|
||||
2: [g-X2]
|
||||
}
|
||||
```
|
||||
For zero1, we iterate the dictionary and do `all_reduce`. For zero2, we can just do `reduce-scatter`.
|
||||
|
|
|
@ -58,17 +58,8 @@ def exam_zero_1_2_grad_acc():
|
|||
assert torch.equal(zero1_output, zero2_output)
|
||||
|
||||
# zero-dp backward
|
||||
no_sync = number == 0
|
||||
with conditional_context(zero1_optimizer.no_sync(), no_sync):
|
||||
zero1_optimizer.backward(zero1_output.sum().float())
|
||||
with conditional_context(zero2_optimizer.no_sync(), no_sync):
|
||||
zero2_optimizer.backward(zero2_output.sum().float())
|
||||
|
||||
if check_flag:
|
||||
for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
|
||||
if z2p.grad is not None:
|
||||
# print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
|
||||
assert torch.equal(z1p.grad, z2p.grad)
|
||||
zero1_optimizer.backward(zero1_output.sum().float())
|
||||
zero2_optimizer.backward(zero2_output.sum().float())
|
||||
|
||||
fwd_bwd_func(0, input_data1, True)
|
||||
fwd_bwd_func(1, input_data2, False)
|
||||
|
@ -82,7 +73,7 @@ def exam_zero_1_2_grad_acc():
|
|||
assert torch.equal(z1p.data, z2p.data)
|
||||
|
||||
|
||||
def exam_zero_1_grad_acc():
|
||||
def exam_zero_1_grad_acc(sync):
|
||||
local_rank = torch.distributed.get_rank()
|
||||
seed_all(2008)
|
||||
|
||||
|
@ -112,9 +103,8 @@ def exam_zero_1_grad_acc():
|
|||
input_data1 = torch.randn(32, 128).cuda()
|
||||
input_data2 = torch.randn(32, 128).cuda()
|
||||
|
||||
def fwd_bwd_func(number, cur_data, check_flag):
|
||||
def fwd_bwd_func(no_sync, cur_data, check_flag):
|
||||
|
||||
no_sync = number == 0
|
||||
# zero1 fwd and bwd
|
||||
with conditional_context(zero_optimizer.no_sync(), no_sync):
|
||||
zero_output = zero_model(cur_data)
|
||||
|
@ -131,8 +121,8 @@ def exam_zero_1_grad_acc():
|
|||
for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
|
||||
assert torch.equal(p.grad, z1p.grad)
|
||||
|
||||
fwd_bwd_func(0, input_data1, True)
|
||||
fwd_bwd_func(1, input_data2, False)
|
||||
fwd_bwd_func(sync, input_data1, sync)
|
||||
fwd_bwd_func(False, input_data2, False)
|
||||
|
||||
zero_optimizer.step()
|
||||
torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0)
|
||||
|
@ -147,9 +137,9 @@ def exam_zero_1_grad_acc():
|
|||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
|
||||
|
||||
exam_zero_1_grad_acc()
|
||||
# gradient accumulation is not compatible with ZeRO-2
|
||||
# exam_zero_1_2_grad_acc()
|
||||
exam_zero_1_grad_acc(sync=True)
|
||||
exam_zero_1_grad_acc(sync=False)
|
||||
exam_zero_1_2_grad_acc()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
|
|
Loading…
Reference in New Issue