mirror of https://github.com/hpcaitech/ColossalAI
[FP8] unsqueeze scale to make it compatible with torch.compile (#6040)
parent
0d3a85d04f
commit
e96a0761ea
|
@@ -56,7 +56,7 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -
|
|||
scale_inv = 1.0 / scale
|
||||
|
||||
ret = (scale * inp.float()).to(fp8_type)
|
||||
- return ret, scale_inv
|
||||
+ return ret, torch.unsqueeze(scale_inv, dim=0)
|
||||
|
||||
|
||||
def cast_from_fp8(
|
||||
|
|
Loading…
Reference in New Issue