mirror of https://github.com/hpcaitech/ColossalAI
fix grad shape
parent ea2872073f
commit 253e54d98a
@@ -17,8 +17,8 @@ from colossalai.zero.sharded_param import ShardedParamV2
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 
-from ._zero3_utils import (cast_tensor_to_fp32, chunk_and_pad, get_gradient_predivide_factor)
-from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16)
+from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad,
+                           get_gradient_predivide_factor)
 
 
 class ShardedModelV2(nn.Module):
@@ -118,11 +118,13 @@ class ShardedModelV2(nn.Module):
             if not self._require_backward_grad_sync:
                 continue
             # Write grad back to p.grad and set p.col_attr.grad to None
-            # We have to make sure grad and param have the same shape
-            # If world size > 1, and sharded param, `.view()` may be not needed
-            # If world size == 1, and sharded param, `data` is a flatten tensor
-            # But the shape `grad` is the same as unsharded param
-            p.grad.data = p.col_attr.grad.view(p.col_attr.data.shape)
+            # As sharded optimizer only update a shard of param,
+            # no matter whether we shard param in sharded model
+            # We have to make sure the grad is a flat tensor shard
+            # If world size == 1 and sharded param,
+            # the shape `grad` is the same as unsharded param
+            # So we can just use `view(-1)` to ensure grad is a flat tensor shard
+            p.grad.data = p.col_attr.grad.view(-1)
             p.col_attr.grad = None
 
     @torch.no_grad()
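The reasoning in the new comments can be illustrated with a small, self-contained sketch. This is not ColossalAI code: `full_param`, `world_size`, `rank`, and `shard_size` below are hypothetical stand-ins for the unsharded parameter and the flat slice a ZeRO-style sharded optimizer would own. It only shows why the gradient written back to `p.grad` needs to be a flat 1-D tensor shard, which is what switching to `view(-1)` guarantees.

import torch

# Hypothetical unsharded parameter and its gradient: 4 x 6 = 24 elements.
full_param = torch.randn(4, 6)
full_grad = torch.randn(4, 6)

world_size = 2
rank = 0
shard_size = full_param.numel() // world_size  # 12 elements per rank

# A ZeRO-style sharded optimizer owns only a flat slice of the parameter,
# so the gradient it consumes must be the matching flat slice.
param_shard = full_param.view(-1)[rank * shard_size:(rank + 1) * shard_size].clone()
grad_shard = full_grad.view(-1)[rank * shard_size:(rank + 1) * shard_size].clone()

# The update is elementwise on the flat shard, so the shapes must match exactly.
assert grad_shard.shape == param_shard.shape == (shard_size,)
param_shard -= 0.1 * grad_shard  # toy SGD step on the owned shard

# Reshaping the 12-element shard back to the unsharded (4, 6) shape cannot work:
try:
    grad_shard.view(full_param.shape)
except RuntimeError as err:
    print("cannot view a 12-element shard as (4, 6):", err)

# Flattening with .view(-1) is always valid and yields the 1-D layout the
# shard-wise optimizer expects, which is what the diff above switches to.
print(grad_shard.view(-1).shape)  # torch.Size([12])

When the world size is 1, the shard is the whole parameter, so the stored gradient still carries the unsharded 2-D shape; `view(-1)` then simply flattens it into the same 1-D layout, matching the case called out in the added comments.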