From 9ae9e74017c16df1d7686b7a8b276631f92032fe Mon Sep 17 00:00:00 2001 From: Ziyue Jiang Date: Fri, 6 Jan 2023 15:59:06 +0800 Subject: [PATCH] fix diff device in some partition --- colossalai/pipeline/rpc/_pipeline_base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/colossalai/pipeline/rpc/_pipeline_base.py b/colossalai/pipeline/rpc/_pipeline_base.py index 2a7998c14..4739cdaa9 100644 --- a/colossalai/pipeline/rpc/_pipeline_base.py +++ b/colossalai/pipeline/rpc/_pipeline_base.py @@ -789,6 +789,8 @@ class WorkerBase(ABC): args_kwargs = pyobj_map(args_kwargs, fn=lambda x: x.to(self.device).detach(), process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU + args_kwargs = pyobj_map(args_kwargs, fn=lambda x: self.device, + process_types=torch.device) # change devices from last stage to current device args, kwargs = data_process_func(args_kwargs)