[bug] fixed broken test_found_inf (#725)

pull/724/head^2
Frank Lee 2022-04-11 22:00:27 +08:00 committed by GitHub
parent 193dc8dacb
commit 20ab1f5520
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions

View File

@ -31,7 +31,7 @@ def _run_test_found_inf(cpu_offload, shard_strategy_class, gpu_margin_mem_ratio)
model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
with ZeroInitContext(
target_device=torch.device(f'cpu:0') if cpu_offload else torch.device(f'cuda:{get_current_device()}'),
target_device=torch.device(f'cpu:0') if cpu_offload else torch.device(get_current_device()),
shard_strategy=shard_strategy,
shard_param=True):
zero_model = model_builder(checkpoint=True)