mirror of https://github.com/hpcaitech/ColossalAI
[test] fix test: test_zero1_2
parent
c67e553fd3
commit
e31d2ebcf7
|
@ -880,7 +880,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||||
return None
|
return None
|
||||||
grad_flat = torch.empty((grad_store.world_size, *grad.shape), dtype=grad.dtype, device=grad.device)
|
grad_flat = torch.empty((grad_store.world_size, *grad.shape), dtype=grad.dtype, device=grad.device)
|
||||||
dist.all_gather_into_tensor(grad_flat, grad, group=grad_store.torch_pg)
|
dist.all_gather_into_tensor(grad_flat, grad, group=grad_store.torch_pg)
|
||||||
return grad_flat[: working_param.numel()].reshape_as(working_param)
|
return grad_flat.view(-1)[: working_param.numel()].view_as(working_param)
|
||||||
|
|
||||||
def get_working_grads_by_group_id(self, group_id: int) -> List[Tensor]:
|
def get_working_grads_by_group_id(self, group_id: int) -> List[Tensor]:
|
||||||
working_grads = []
|
working_grads = []
|
||||||
|
|
|
@ -179,9 +179,9 @@ def run_dist(rank, world_size, port):
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
@pytest.mark.parametrize("world_size", [4])
|
@pytest.mark.parametrize("world_size", [4])
|
||||||
@rerun_if_address_is_in_use()
|
@rerun_if_address_is_in_use()
|
||||||
def test_mistral(world_size):
|
def test_deepseek(world_size):
|
||||||
spawn(run_dist, world_size)
|
spawn(run_dist, world_size)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_mistral(world_size=8)
|
test_deepseek(world_size=4)
|
||||||
|
|
|
@ -180,9 +180,9 @@ def run_dist(rank, world_size, port):
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
@pytest.mark.parametrize("world_size", [4])
|
@pytest.mark.parametrize("world_size", [4])
|
||||||
@rerun_if_address_is_in_use()
|
@rerun_if_address_is_in_use()
|
||||||
def test_mistral(world_size):
|
def test_mixtral(world_size):
|
||||||
spawn(run_dist, world_size)
|
spawn(run_dist, world_size)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_mistral(world_size=8)
|
test_mixtral(world_size=4)
|
||||||
|
|
Loading…
Reference in New Issue