[feat] fix func name & ci; add comments;

pull/6034/head
duanjunwen 3 months ago
parent b5f7b4d228
commit 582ba0d6ff

@ -36,8 +36,8 @@ def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
return num_params, num_params_trainable
# Test iter input & multiple microbatch
def test_run_fwd_bwd_iter_input(
# Test manual v_schedule with multiple microbatch
def run_fwd_bwd_iter_input(
rank: int,
world_size: int,
port: int,
@ -474,8 +474,8 @@ def test_run_fwd_bwd_iter_input(
assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad)
# T
def test_run_fwd_bwd_with_vschedule(
# Test v_schedule generated by graph with multiple microbatch
def run_fwd_bwd_with_vschedule(
rank: int,
world_size: int,
port: int,
@ -623,7 +623,7 @@ def test_run_fwd_bwd_with_vschedule(
@rerun_if_address_is_in_use()
def test_pp(num_microbatch: int, batch_size: int, num_model_chunk: int):
spawn(
test_run_fwd_bwd_with_vschedule,
run_fwd_bwd_with_vschedule,
nprocs=4,
num_microbatch=num_microbatch,
batch_size=batch_size,

Loading…
Cancel
Save