# Source: mirror of https://github.com/hpcaitech/ColossalAI (13 lines, 390 B, Python)
import torch
|
||
|
from typing import Union, Optional
|
||
|
from colossalai.tensor import ColoTensor
|
||
|
|
||
|
# Scalar operand type accepted alongside tensors.
Number = Union[int, float]

# Any tensor value these helpers accept: either an already-wrapped ColoTensor
# or a plain torch.Tensor that may still need wrapping.
GeneralTensor = Union[ColoTensor, torch.Tensor]


def convert_to_colo_tensor(tensor: Optional[GeneralTensor]) -> Optional[ColoTensor]:
    """Wrap a plain tensor as a ``ColoTensor``; pass ``None`` and ColoTensors through.

    Args:
        tensor: ``None``, an existing ``ColoTensor``, or any other value handed to
            ``ColoTensor.from_torch_tensor`` (per ``GeneralTensor``, presumably a
            ``torch.Tensor``).

    Returns:
        ``None`` when ``tensor`` is ``None``; the same object when it is already a
        ``ColoTensor``; otherwise the result of ``ColoTensor.from_torch_tensor(tensor)``.
    """
    if tensor is None or isinstance(tensor, ColoTensor):
        # Nothing to convert: propagate None and avoid re-wrapping a ColoTensor.
        return tensor
    return ColoTensor.from_torch_tensor(tensor)