|
|
@ -40,7 +40,9 @@ EXAMPLE_MODELS = [
|
|
|
|
]
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
# bfloat16 cannot represent them exactly
|
|
|
|
# bfloat16 cannot represent them exactly
|
|
|
|
BF16_IGNORED_KEYS = ["masked_bias"]
|
|
|
|
BF16_IGNORED_KEYS = [
|
|
|
|
|
|
|
|
"masked_bias",
|
|
|
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dtype):
|
|
|
|
def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dtype):
|
|
|
@ -71,15 +73,9 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dty
|
|
|
|
@parameterize("model_name", TEST_MODELS)
|
|
|
|
@parameterize("model_name", TEST_MODELS)
|
|
|
|
@parameterize("mixed_precision", [torch.half, torch.bfloat16])
|
|
|
|
@parameterize("mixed_precision", [torch.half, torch.bfloat16])
|
|
|
|
@parameterize("master_weights", [True, False])
|
|
|
|
@parameterize("master_weights", [True, False])
|
|
|
|
@parameterize("max_prefetch", [0, 1, 4])
|
|
|
|
|
|
|
|
@parameterize("enable_async_reduce", [False, True])
|
|
|
|
@parameterize("enable_async_reduce", [False, True])
|
|
|
|
def exam_model_step(
|
|
|
|
def exam_model_step(
|
|
|
|
placement_config,
|
|
|
|
placement_config, model_name: str, mixed_precision: torch.dtype, master_weights: bool, enable_async_reduce=True
|
|
|
|
model_name: str,
|
|
|
|
|
|
|
|
mixed_precision: torch.dtype,
|
|
|
|
|
|
|
|
master_weights: bool,
|
|
|
|
|
|
|
|
max_prefetch: int,
|
|
|
|
|
|
|
|
enable_async_reduce=True,
|
|
|
|
|
|
|
|
):
|
|
|
|
):
|
|
|
|
set_seed(42)
|
|
|
|
set_seed(42)
|
|
|
|
model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
|
|
|
|
model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
|
|
|
@ -108,7 +104,6 @@ def exam_model_step(
|
|
|
|
**placement_config,
|
|
|
|
**placement_config,
|
|
|
|
mixed_precision=mixed_precision,
|
|
|
|
mixed_precision=mixed_precision,
|
|
|
|
master_weights=master_weights,
|
|
|
|
master_weights=master_weights,
|
|
|
|
max_prefetch=max_prefetch,
|
|
|
|
|
|
|
|
enable_async_reduce=enable_async_reduce,
|
|
|
|
enable_async_reduce=enable_async_reduce,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|