2021-10-28 16:21:23 +00:00
|
|
|
#!/usr/bin/env python
|
|
|
|
# -*- encoding: utf-8 -*-
|
|
|
|
|
2023-09-18 08:31:06 +00:00
|
|
|
from typing import Optional
|
|
|
|
|
2021-10-28 16:21:23 +00:00
|
|
|
import torch
|
2023-09-18 08:31:06 +00:00
|
|
|
import torch.distributed as dist
|
2021-10-28 16:21:23 +00:00
|
|
|
|
|
|
|
|
|
|
|
def set_to_cuda(models):
    """Move a module, or a list of modules, to the current device.

    The target device is whatever ``get_current_device()`` returns: the
    currently selected CUDA device when CUDA is available, otherwise the CPU.

    :param models: nn.module or a list of module
    :return: for a list of two or more modules, a list of the results of
        ``.to(device)``; for a single-element list, the result for that one
        element (not wrapped in a list); for an empty list, an empty list;
        for a non-list argument, ``models.to(device)``.
    """
    # Resolve the target once rather than once per module.
    device = get_current_device()
    if isinstance(models, list):
        if not models:
            # Return an empty list instead of failing on models[0].
            return []
        if len(models) == 1:
            # Kept for backward compatibility: a single-element list is unwrapped.
            return models[0].to(device)
        return [model.to(device) for model in models]
    return models.to(device)
|
|
|
|
|
|
|
|
|
2022-04-11 08:47:57 +00:00
|
|
|
def get_current_device() -> torch.device:
    """
    Return the device that is currently selected for computation.

    Falls back to the CPU when CUDA is not available; otherwise returns the
    CUDA device reported by ``torch.cuda.current_device()``.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")
    cuda_index = torch.cuda.current_device()
    return torch.device(f"cuda:{cuda_index}")
|
2021-10-28 16:21:23 +00:00
|
|
|
|
|
|
|
|
|
|
|
def synchronize():
    """Similar to cuda.synchronize().

    Blocks until all kernels in all streams on the current CUDA device have
    completed. Does nothing when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.synchronize()
|
|
|
|
|
|
|
|
|
|
|
|
def empty_cache():
    """Similar to cuda.empty_cache()

    Releases the unoccupied cached memory held by the CUDA caching allocator.
    Does nothing when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
|
2023-09-18 08:31:06 +00:00
|
|
|
|
|
|
|
|
|
|
|
def set_device(index: Optional[int] = None) -> None:
    """Select the CUDA device used by this process.

    :param index: CUDA device index to select. When ``None``, the index is
        derived from the global distributed rank as
        ``dist.get_rank() % torch.cuda.device_count()``, which assumes the
        default process group has already been initialized.
    :raises RuntimeError: if ``index`` is ``None`` and no CUDA device is
        visible (the modulo would otherwise fail with ZeroDivisionError).
    """
    if index is None:
        device_count = torch.cuda.device_count()
        if device_count == 0:
            raise RuntimeError(
                "set_device() requires at least one visible CUDA device "
                "when index is None"
            )
        index = dist.get_rank() % device_count
    torch.cuda.set_device(index)
|