@@ -36,8 +36,8 @@ def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
     return num_params, num_params_trainable
 
 
-# Test iter input & multiple microbatch
-def test_run_fwd_bwd_iter_input(
+# Test manual v_schedule with multiple microbatch
+def run_fwd_bwd_iter_input(
     rank: int,
     world_size: int,
     port: int,
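For reference, the helper named in the hunk header presumably just tallies total and trainable parameter counts. A minimal sketch of that pattern, assuming the standard loop behind the `return num_params, num_params_trainable` context line above (only the signature and the return statement actually appear in the diff):

from typing import Tuple

import torch


def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
    # Assumed body: count every parameter element, and separately those
    # that require gradients (trainable).
    num_params = 0
    num_params_trainable = 0
    for p in model.parameters():
        num_params += p.numel()
        if p.requires_grad:
            num_params_trainable += p.numel()
    return num_params, num_params_trainable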
@@ -474,8 +474,8 @@ def test_run_fwd_bwd_iter_input(
     assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad)
 
 
-# T
-def test_run_fwd_bwd_with_vschedule(
+# Test v_schedule generated by graph with multiple microbatch
+def run_fwd_bwd_with_vschedule(
     rank: int,
     world_size: int,
     port: int,
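The context line in this hunk checks that the gradients seen by a rank's local model chunk match the corresponding layers of a single-process baseline. A minimal sketch of that comparison pattern with torch.testing.assert_close, using a toy linear stack in place of the real model, pipeline schedule, and .layers attribute:

import copy

import torch
import torch.nn as nn
from torch.testing import assert_close

torch.manual_seed(0)

# Single-process baseline, plus a copy whose last two layers stand in for
# "this rank's" local pipeline chunk.
model_base = nn.Sequential(*[nn.Linear(8, 8, bias=False) for _ in range(4)])
model_pp = copy.deepcopy(model_base)
local_chunk = nn.ModuleList([model_pp[2], model_pp[3]])

x = torch.randn(4, 8)
model_base(x).sum().backward()

# Emulate the chunked run: earlier stages produce the activation, the local
# chunk finishes the forward, then backward flows through it.
h = model_pp[1](model_pp[0](x))
local_chunk[1](local_chunk[0](h)).sum().backward()

# Gradients of the local chunk must match the corresponding baseline layers.
assert_close(local_chunk[1].weight.grad, model_base[3].weight.grad)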
@@ -623,7 +623,7 @@ def test_run_fwd_bwd_with_vschedule(
 @rerun_if_address_is_in_use()
 def test_pp(num_microbatch: int, batch_size: int, num_model_chunk: int):
     spawn(
-        test_run_fwd_bwd_with_vschedule,
+        run_fwd_bwd_with_vschedule,
         nprocs=4,
         num_microbatch=num_microbatch,
         batch_size=batch_size,
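This last hunk only swaps the worker handed to spawn for its renamed counterpart. Dropping the test_ prefix from the spawned workers presumably keeps pytest from collecting them directly (their rank/world_size/port arguments come from spawn, not from fixtures), while test_pp stays the collected entry point. A rough sketch of that shape, assuming the usual colossalai.testing imports; the worker body is elided, and the parametrize values plus the num_model_chunk keyword are assumptions based on the test_pp signature and the kwargs spawn forwards above:

import pytest
from colossalai.testing import rerun_if_address_is_in_use, spawn


def run_fwd_bwd_with_vschedule(rank, world_size, port, num_microbatch, batch_size, num_model_chunk):
    # One copy runs per spawned process; the real body builds the v-schedule
    # pipeline, runs forward/backward, and asserts on gradients (elided here).
    ...


# Example parametrization; the real values are not shown in this diff.
@pytest.mark.parametrize("num_microbatch", [4])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("num_model_chunk", [2])
@rerun_if_address_is_in_use()
def test_pp(num_microbatch: int, batch_size: int, num_model_chunk: int):
    spawn(
        run_fwd_bwd_with_vschedule,
        nprocs=4,
        num_microbatch=num_microbatch,
        batch_size=batch_size,
        num_model_chunk=num_model_chunk,
    )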