fix grad shape

pull/394/head
ver217 2022-03-09 18:03:39 +08:00 committed by Frank Lee
parent ea2872073f
commit 253e54d98a
1 changed files with 9 additions and 7 deletions

View File

@ -17,8 +17,8 @@ from colossalai.zero.sharded_param import ShardedParamV2
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from ._zero3_utils import (cast_tensor_to_fp32, chunk_and_pad, get_gradient_predivide_factor) from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad,
from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16) get_gradient_predivide_factor)
class ShardedModelV2(nn.Module): class ShardedModelV2(nn.Module):
@ -118,11 +118,13 @@ class ShardedModelV2(nn.Module):
if not self._require_backward_grad_sync: if not self._require_backward_grad_sync:
continue continue
# Write grad back to p.grad and set p.col_attr.grad to None # Write grad back to p.grad and set p.col_attr.grad to None
# We have to make sure grad and param have the same shape # As sharded optimizer only update a shard of param,
# If world size > 1, and sharded param, `.view()` may be not needed # no matter whether we shard param in sharded model
# If world size == 1, and sharded param, `data` is a flatten tensor # We have to make sure the grad is a flat tensor shard
# But the shape `grad` is the same as unsharded param # If world size == 1 and sharded param,
p.grad.data = p.col_attr.grad.view(p.col_attr.data.shape) # the shape `grad` is the same as unsharded param
# So we can just use `view(-1)` to ensure grad is a flat tensor shard
p.grad.data = p.col_attr.grad.view(-1)
p.col_attr.grad = None p.col_attr.grad = None
@torch.no_grad() @torch.no_grad()