mirror of https://github.com/hpcaitech/ColossalAI
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
13 lines
390 B
13 lines
390 B
import torch
|
|
from typing import Union, Optional
|
|
from colossalai.tensor import ColoTensor
|
|
|
|
# Plain Python scalar accepted where an op takes a number rather than a tensor.
Number = Union[int, float]

# Any tensor argument an op may receive: either a ColoTensor or a plain
# torch.Tensor (convert_to_colo_tensor in this module normalizes the latter).
GeneralTensor = Union[ColoTensor, torch.Tensor]
|
|
|
|
|
|
def convert_to_colo_tensor(tensor: Optional[GeneralTensor]) -> Optional[ColoTensor]:
    """Normalize an optional tensor argument to a ColoTensor.

    Args:
        tensor: ``None``, a ``ColoTensor``, or any other tensor (expected to be
            a ``torch.Tensor`` per the ``GeneralTensor`` annotation).

    Returns:
        ``None`` if ``tensor`` is ``None``; the same object if it is already a
        ``ColoTensor``; otherwise the result of
        ``ColoTensor.from_torch_tensor(tensor)``. Whether that result shares
        storage with the input depends on ``from_torch_tensor`` — not visible here.
    """
    # Nothing to convert: pass None and existing ColoTensors straight through.
    if tensor is None or isinstance(tensor, ColoTensor):
        return tensor
    return ColoTensor.from_torch_tensor(tensor)
|