#!/usr/bin/env python # -*- encoding: utf-8 -*- import torch def set_to_cuda(models): """Send model to gpu. :param models: nn.module or a list of module """ if isinstance(models, list) and len(models) > 1: ret = [] for model in models: ret.append(model.to(get_current_device())) return ret elif isinstance(models, list): return models[0].to(get_current_device()) else: return models.to(get_current_device()) def get_current_device(): """Returns the index of a currently selected device (gpu/cpu). """ if torch.cuda.is_available(): return torch.cuda.current_device() else: return 'cpu' def synchronize(): """Similar to cuda.synchronize(). Waits for all kernels in all streams on a CUDA device to complete. """ if torch.cuda.is_available(): torch.cuda.synchronize() def empty_cache(): """Similar to cuda.empty_cache() Releases all unoccupied cached memory currently held by the caching allocator. """ if torch.cuda.is_available(): torch.cuda.empty_cache()