[test] fix test: test_zero1_2

colossalchat
hxwang 4 months ago committed by Hongxin Liu
parent 74b03de3f9
commit 067e18f7e9

@ -880,7 +880,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
return None return None
grad_flat = torch.empty((grad_store.world_size, *grad.shape), dtype=grad.dtype, device=grad.device) grad_flat = torch.empty((grad_store.world_size, *grad.shape), dtype=grad.dtype, device=grad.device)
dist.all_gather_into_tensor(grad_flat, grad, group=grad_store.torch_pg) dist.all_gather_into_tensor(grad_flat, grad, group=grad_store.torch_pg)
return grad_flat[: working_param.numel()].reshape_as(working_param) return grad_flat.view(-1)[: working_param.numel()].view_as(working_param)
def get_working_grads_by_group_id(self, group_id: int) -> List[Tensor]: def get_working_grads_by_group_id(self, group_id: int) -> List[Tensor]:
working_grads = [] working_grads = []

@ -179,9 +179,9 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist @pytest.mark.dist
@pytest.mark.parametrize("world_size", [4]) @pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_mistral(world_size): def test_deepseek(world_size):
spawn(run_dist, world_size) spawn(run_dist, world_size)
if __name__ == "__main__": if __name__ == "__main__":
test_mistral(world_size=8) test_deepseek(world_size=4)

@ -180,9 +180,9 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist @pytest.mark.dist
@pytest.mark.parametrize("world_size", [4]) @pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_mistral(world_size): def test_mixtral(world_size):
spawn(run_dist, world_size) spawn(run_dist, world_size)
if __name__ == "__main__": if __name__ == "__main__":
test_mistral(world_size=8) test_mixtral(world_size=4)

Loading…
Cancel
Save