fix diff device in some partition

pull/2373/head
Ziyue Jiang 2023-01-06 15:59:06 +08:00
parent 3a15b20421
commit 9ae9e74017
1 changed files with 2 additions and 0 deletions

View File

@ -789,6 +789,8 @@ class WorkerBase(ABC):
args_kwargs = pyobj_map(args_kwargs, fn=lambda x: x.to(self.device).detach(),
process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU
args_kwargs = pyobj_map(args_kwargs, fn=lambda x: self.device,
process_types=torch.device) # change devices from last stage to current device
args, kwargs = data_process_func(args_kwargs)