mirror of https://github.com/hpcaitech/ColossalAI
[feat] fix func name & ci; add comments;
parent
b5f7b4d228
commit
582ba0d6ff
|
@ -36,8 +36,8 @@ def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
|
||||||
return num_params, num_params_trainable
|
return num_params, num_params_trainable
|
||||||
|
|
||||||
|
|
||||||
# Test iter input & multiple microbatch
|
# Test manual v_schedule with multiple microbatch
|
||||||
def test_run_fwd_bwd_iter_input(
|
def run_fwd_bwd_iter_input(
|
||||||
rank: int,
|
rank: int,
|
||||||
world_size: int,
|
world_size: int,
|
||||||
port: int,
|
port: int,
|
||||||
|
@ -474,8 +474,8 @@ def test_run_fwd_bwd_iter_input(
|
||||||
assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad)
|
assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad)
|
||||||
|
|
||||||
|
|
||||||
# T
|
# Test v_schedule generated by graph with multiple microbatch
|
||||||
def test_run_fwd_bwd_with_vschedule(
|
def run_fwd_bwd_with_vschedule(
|
||||||
rank: int,
|
rank: int,
|
||||||
world_size: int,
|
world_size: int,
|
||||||
port: int,
|
port: int,
|
||||||
|
@ -623,7 +623,7 @@ def test_run_fwd_bwd_with_vschedule(
|
||||||
@rerun_if_address_is_in_use()
|
@rerun_if_address_is_in_use()
|
||||||
def test_pp(num_microbatch: int, batch_size: int, num_model_chunk: int):
|
def test_pp(num_microbatch: int, batch_size: int, num_model_chunk: int):
|
||||||
spawn(
|
spawn(
|
||||||
test_run_fwd_bwd_with_vschedule,
|
run_fwd_bwd_with_vschedule,
|
||||||
nprocs=4,
|
nprocs=4,
|
||||||
num_microbatch=num_microbatch,
|
num_microbatch=num_microbatch,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
|
|
Loading…
Reference in New Issue