# Source: ColossalAI/colossalai/utils/cuda.py (46 lines, 1.1 KiB)
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
def set_to_cuda(models):
    """Move a model, or each model in a list/tuple, to the current device.

    The target device is whatever ``get_current_device()`` returns: the
    current CUDA device index when CUDA is available, otherwise ``'cpu'``.

    :param models: an ``nn.Module`` (or anything with a ``.to(device)``
        method), or a list/tuple of such objects
    :return: for a sequence of two or more items, a list of the moved items;
        for a one-element sequence, the single moved item (not wrapped in a
        list); for an empty sequence, an empty list; otherwise the moved item
    """
    # Resolve the target device once instead of once per model.
    device = get_current_device()
    if isinstance(models, (list, tuple)):
        if not models:
            # Nothing to move; avoid indexing into an empty sequence.
            return []
        if len(models) == 1:
            # Kept for backward compatibility: a one-element sequence is
            # unwrapped and the bare moved model is returned.
            return models[0].to(device)
        return [model.to(device) for model in models]
    return models.to(device)


def get_current_device():
    """Return the currently selected device.

    :return: the index of the current CUDA device
        (``torch.cuda.current_device()``) when CUDA is available,
        otherwise the string ``'cpu'``
    """
    if not torch.cuda.is_available():
        # No CUDA runtime/device visible: fall back to the CPU.
        return 'cpu'
    return torch.cuda.current_device()


def synchronize():
    """Block until CUDA work has finished, like ``torch.cuda.synchronize()``.

    Waits for all kernels in all streams on a CUDA device to complete.
    Does nothing when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.synchronize()


def empty_cache():
    """Free cached CUDA memory, like ``torch.cuda.empty_cache()``.

    Releases all unoccupied cached memory currently held by the caching
    allocator. Does nothing when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()