mirror of https://github.com/hpcaitech/ColossalAI
fix grad shape
parent
ea2872073f
commit
253e54d98a
|
@ -17,8 +17,8 @@ from colossalai.zero.sharded_param import ShardedParamV2
|
|||
from torch.distributed import ProcessGroup
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from ._zero3_utils import (cast_tensor_to_fp32, chunk_and_pad, get_gradient_predivide_factor)
|
||||
from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16)
|
||||
from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad,
|
||||
get_gradient_predivide_factor)
|
||||
|
||||
|
||||
class ShardedModelV2(nn.Module):
|
||||
|
@ -118,11 +118,13 @@ class ShardedModelV2(nn.Module):
|
|||
if not self._require_backward_grad_sync:
|
||||
continue
|
||||
# Write grad back to p.grad and set p.col_attr.grad to None
|
||||
# We have to make sure grad and param have the same shape
|
||||
# If world size > 1, and sharded param, `.view()` may be not needed
|
||||
# If world size == 1, and sharded param, `data` is a flatten tensor
|
||||
# But the shape `grad` is the same as unsharded param
|
||||
p.grad.data = p.col_attr.grad.view(p.col_attr.data.shape)
|
||||
# As sharded optimizer only update a shard of param,
|
||||
# no matter whether we shard param in sharded model
|
||||
# We have to make sure the grad is a flat tensor shard
|
||||
# If world size == 1 and sharded param,
|
||||
# the shape `grad` is the same as unsharded param
|
||||
# So we can just use `view(-1)` to ensure grad is a flat tensor shard
|
||||
p.grad.data = p.col_attr.grad.view(-1)
|
||||
p.col_attr.grad = None
|
||||
|
||||
@torch.no_grad()
|
||||
|
|
Loading…
Reference in New Issue