[hotfix] fix reuse_fp16_shard of sharded model (#756)

* fix reuse_fp16_shard

* disable test stm

* polish code
pull/754/head
ver217 2022-04-14 14:56:46 +08:00 committed by GitHub
parent 8f7ce94b8e
commit a93a7d7364
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 3 additions and 7 deletions

View File

@ -253,9 +253,6 @@ class ShardedModelV2(nn.Module):
with torch.cuda.stream(self.comm_stream):
self.reducer.flush()
torch.cuda.current_stream().wait_stream(self.comm_stream)
if self._cpu_offload:
# Wait for the non-blocking GPU -> CPU grad transfers to finish.
torch.cuda.current_stream().synchronize()
self.reducer.free()
# 3. shard tensors not dealed in the zero hook
@ -338,7 +335,7 @@ class ShardedModelV2(nn.Module):
def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
assert isinstance(reduced_grad,
torch.Tensor), f"_reduce_scatter_callback accept reduced_grad as {type(reduced_grad)}"
reduced_grad.data = reduced_grad.data.view(-1)
reduced_grad.data = reduced_grad.data.contiguous().view(-1)
if self.gradient_postdivide_factor > 1:
# Average grad by world_size for consistency with PyTorch DDP.
reduced_grad.data.div_(self.gradient_postdivide_factor)
@ -362,7 +359,7 @@ class ShardedModelV2(nn.Module):
), 'Gradien accumulation is not supported when reuse_fp16_shard=True'
param.colo_attr.reset_grad_payload(grad)
param.colo_attr.reset_grad_payload(grad) # release the memory of param
param.colo_attr.reset_data_payload(grad) # release the memory of param
if param.colo_attr.is_replicated:
param.colo_attr.sharded_data_tensor.is_sharded = True

View File

@ -70,7 +70,6 @@ class ShardedParamV2(object):
assert type(tensor) is torch.Tensor
assert tensor.requires_grad is False
self.sharded_data_tensor.reset_payload(tensor)
self.set_data_none()
def reset_grad_payload(self, tensor: torch.Tensor):
assert type(tensor) is torch.Tensor

View File

@ -112,7 +112,7 @@ def run_dist(rank, world_size, port):
run_stm()
@pytest.mark.dist
@pytest.mark.skip
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_stateful_tensor_manager(world_size=1):
run_func = partial(run_dist, world_size=world_size, port=free_port())