mirror of https://github.com/hpcaitech/ColossalAI
fix diff device in some partition
parent
3a15b20421
commit
9ae9e74017
|
@ -789,6 +789,8 @@ class WorkerBase(ABC):
|
|||
|
||||
args_kwargs = pyobj_map(args_kwargs, fn=lambda x: x.to(self.device).detach(),
|
||||
process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU
|
||||
args_kwargs = pyobj_map(args_kwargs, fn=lambda x: self.device,
|
||||
process_types=torch.device) # change devices from last stage to current device
|
||||
|
||||
args, kwargs = data_process_func(args_kwargs)
|
||||
|
||||
|
|
Loading…
Reference in New Issue