Browse Source

[FP8] unsqueeze scale to make it compatible with torch.compile (#6040)

pull/6042/head
Guangyao Zhang 3 months ago committed by GitHub
parent
commit
e96a0761ea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 2
      colossalai/quantization/fp8.py

2
colossalai/quantization/fp8.py

@@ -56,7 +56,7 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -
scale_inv = 1.0 / scale
ret = (scale * inp.float()).to(fp8_type)
- return ret, scale_inv
+ return ret, torch.unsqueeze(scale_inv, dim=0)
def cast_from_fp8(

Loading…
Cancel
Save