ColossalAI/colossalai/utils/cuda.py

57 lines
1.4 KiB
Python
Raw Normal View History

2021-10-28 16:21:23 +00:00
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Optional
2021-10-28 16:21:23 +00:00
import torch
import torch.distributed as dist
2021-10-28 16:21:23 +00:00
def set_to_cuda(models):
    """Move a model, or a list of models, to the current device.

    The target device is whatever ``get_current_device()`` returns, so the
    destination may be the CPU when CUDA is not available.

    :param models: nn.module or a list of module
    :return: the result of ``.to(device)`` for a single module; for a list
        with exactly one element, the result for that element (not wrapped
        in a list); for a longer list, a new list with one result per
        element; for an empty list, an empty list
    """
    if not isinstance(models, list):
        return models.to(get_current_device())

    if not models:
        # An empty list used to fall through to ``models[0]`` and raise an
        # IndexError; there is nothing to move, so hand back an empty list.
        return []

    # Look the device up once instead of once per model.
    device = get_current_device()

    if len(models) == 1:
        # Kept for backward compatibility: a one-element list is unwrapped.
        return models[0].to(device)

    return [model.to(device) for model in models]
def get_current_device() -> torch.device:
    """Return the device that is currently selected.

    When CUDA is available this is the CUDA device whose index is reported by
    ``torch.cuda.current_device()``; otherwise it is the CPU device.

    :return: the selected ``torch.device``
    """
    # Handle the CPU-only case first so the CUDA path reads straight through.
    if not torch.cuda.is_available():
        return torch.device("cpu")
    return torch.device(f"cuda:{torch.cuda.current_device()}")
2021-10-28 16:21:23 +00:00
def synchronize():
    """Wait for outstanding CUDA work to finish.

    Delegates to ``torch.cuda.synchronize()``, which waits for all kernels in
    all streams on a CUDA device to complete. Does nothing when CUDA is not
    available.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.synchronize()
def empty_cache():
    """Release cached CUDA memory that is no longer in use.

    Delegates to ``torch.cuda.empty_cache()``, which releases all unoccupied
    cached memory currently held by the caching allocator. Does nothing when
    CUDA is not available.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
def set_device(index: Optional[int] = None) -> None:
    """Select the CUDA device used by the current process.

    Does nothing when CUDA is not available, in line with ``synchronize()``
    and ``empty_cache()`` in this module.

    :param index: index of the CUDA device to select; when ``None`` it is
        derived as ``dist.get_rank() % torch.cuda.device_count()``
    """
    if not torch.cuda.is_available():
        # Without this guard a host with no CUDA device would compute
        # ``rank % 0`` below (assuming device_count() reports 0 there) and
        # fail with a ZeroDivisionError instead of falling back to the CPU.
        return

    if index is None:
        # Wrap the distributed rank onto the locally visible devices.
        index = dist.get_rank() % torch.cuda.device_count()
    torch.cuda.set_device(index)