import torch
from torch import Tensor

from .base import MixedPrecisionMixin


class BF16MixedPrecisionMixin(MixedPrecisionMixin):
    dtype = torch.bfloat16

    def pre_backward(self, loss: Tensor) -> Tensor:
        return loss

    def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
        return grad

    def should_skip_step(self) -> bool:
        return False

    def pre_zero_grad(self) -> None:
        pass

    def get_grad_div_scale(self) -> float:
        return 1.0