@@ -60,6 +60,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
         zero_stage=stage,
         enable_sequence_parallelism=sp_size > 1,
         sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
+        enable_flash_attention=sp_size > 1,
         overlap_communication=False,
         initial_scale=1,
         precision=precision,
@@ -161,7 +162,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
     assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
 
     # use checkpoint to load sharded zero model
-    model_dir = "./test_mixtral"
+    model_dir = "./test_deepseek"
     if rank == world_size - 1:
         os.makedirs(model_dir, exist_ok=True)
 
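For reference, a minimal sketch of how the keyword arguments touched by the first hunk fit together, assuming ColossalAI's MoeHybridParallelPlugin and Booster APIs. The parallel-group sizes, ZeRO stage, and precision below are illustrative placeholders, not values taken from this diff, and the snippet assumes the distributed backend has already been initialized (e.g. via colossalai.launch).

# Sketch only: plugin construction with the flash-attention flag added in the first hunk.
# tp/pp/ep/sp sizes, stage, and precision are hypothetical; the real test derives them from `config`.
from colossalai.booster import Booster
from colossalai.booster.plugin import MoeHybridParallelPlugin

sp_size = 2         # hypothetical sequence-parallel group size
stage = 1           # hypothetical ZeRO stage
precision = "bf16"  # hypothetical mixed-precision mode

plugin = MoeHybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    ep_size=2,
    sp_size=sp_size,
    zero_stage=stage,
    enable_sequence_parallelism=sp_size > 1,
    sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
    enable_flash_attention=sp_size > 1,  # flag introduced by this change
    overlap_communication=False,
    initial_scale=1,
    precision=precision,
)
booster = Booster(plugin=plugin)
# model, optimizer, _, _, _ = booster.boost(model, optimizer) would then wrap the test model
# before the parallel forward/backward that feeds assert_loose_close in the second hunk.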