diff --git a/colossalai/utils/memory_tracer/async_memtracer.py b/colossalai/utils/memory_tracer/async_memtracer.py
index fe65651ae..842aafbdd 100644
--- a/colossalai/utils/memory_tracer/async_memtracer.py
+++ b/colossalai/utils/memory_tracer/async_memtracer.py
@@ -52,7 +52,12 @@ class AsyncMemoryMonitor:
 
     def __init__(self, power: int = 10):
         self.keep_measuring = False
-        self.executor = ThreadPoolExecutor(max_workers=1)
+
+        current_device = get_current_device()
+        def _set_cuda_device():
+            torch.cuda.set_device(current_device)
+
+        self.executor = ThreadPoolExecutor(max_workers=1, initializer=_set_cuda_device)
         self.monitor_thread = None
         self.interval = 1 / (10**power)
         self.time_stamps = []