diff --git a/colossalai/pipeline/rpc/_pipeline_base.py b/colossalai/pipeline/rpc/_pipeline_base.py index 2a7998c14..4739cdaa9 100644 --- a/colossalai/pipeline/rpc/_pipeline_base.py +++ b/colossalai/pipeline/rpc/_pipeline_base.py @@ -789,6 +789,8 @@ class WorkerBase(ABC): args_kwargs = pyobj_map(args_kwargs, fn=lambda x: x.to(self.device).detach(), process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU + args_kwargs = pyobj_map(args_kwargs, fn=lambda x: self.device, + process_types=torch.device) # change devices from last stage to current device args, kwargs = data_process_func(args_kwargs)