From 253e54d98ad1aa8a8cf73f0785270387449d3d26 Mon Sep 17 00:00:00 2001
From: ver217
Date: Wed, 9 Mar 2022 18:03:39 +0800
Subject: [PATCH] fix grad shape

---
 .../zero/sharded_model/sharded_model_v2.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py
index f6c5e10f5..a1172fdaa 100644
--- a/colossalai/zero/sharded_model/sharded_model_v2.py
+++ b/colossalai/zero/sharded_model/sharded_model_v2.py
@@ -17,8 +17,8 @@ from colossalai.zero.sharded_param import ShardedParamV2
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 
-from ._zero3_utils import (cast_tensor_to_fp32, chunk_and_pad, get_gradient_predivide_factor)
-from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16)
+from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad,
+                           get_gradient_predivide_factor)
 
 
 class ShardedModelV2(nn.Module):
@@ -118,11 +118,13 @@ class ShardedModelV2(nn.Module):
             if not self._require_backward_grad_sync:
                 continue
             # Write grad back to p.grad and set p.col_attr.grad to None
-            # We have to make sure grad and param have the same shape
-            # If world size > 1, and sharded param, `.view()` may be not needed
-            # If world size == 1, and sharded param, `data` is a flatten tensor
-            # But the shape `grad` is the same as unsharded param
-            p.grad.data = p.col_attr.grad.view(p.col_attr.data.shape)
+            # As the sharded optimizer only updates a shard of each param,
+            # regardless of whether the param is sharded in the sharded model,
+            # we have to make sure the grad is a flat tensor shard.
+            # If world size == 1 and the param is sharded,
+            # the shape of `grad` is the same as the unsharded param,
+            # so we can just use `view(-1)` to get a flat tensor shard.
+            p.grad.data = p.col_attr.grad.view(-1)
             p.col_attr.grad = None
 
     @torch.no_grad()
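
A minimal standalone sketch (not part of the patch; the shapes and the `flat_shard` / `full_grad` names are made up for illustration) of why `view(-1)` is the safer choice here: the grad buffer may either already be a flat shard (world size > 1) or have the unsharded parameter shape (world size == 1), and flattening works in both cases, while reshaping to a fixed parameter shape does not.

    import torch

    # Case 1: world size > 1 -- the stored grad is already a flat shard
    # with fewer elements than the full (4, 6) parameter.
    flat_shard = torch.randn(12)
    print(flat_shard.view(-1).shape)   # torch.Size([12]), unchanged

    # Reshaping the shard to the full parameter shape fails,
    # because the element counts do not match.
    try:
        flat_shard.view(4, 6)
    except RuntimeError as e:
        print("view(4, 6) fails:", e)

    # Case 2: world size == 1 -- the stored grad has the unsharded shape.
    full_grad = torch.randn(4, 6)
    print(full_grad.view(-1).shape)    # torch.Size([24]), flattened

    # In both cases view(-1) yields the 1-D tensor that a sharded optimizer,
    # which updates flat parameter shards, can consume directly.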