[Q][GPU][BF16] torch.mul is lowered to HLO as an f32 multiply

❓ Questions and Help

torch 2.5.1
torch_xla 2.5.1
cuda 12.4
GPU NVIDIA L4

The following example uses torch.mul where both operands are bf16, but in the HLO graph I see an f32 multiply operation (hlo: module_0000.SyncTensorsGraph.16.before_optimizations.txt):
import torch
import torch_xla as xla

device = xla.device(0)

def foo(a, b):
    # Elementwise multiply; both operands are created as bf16 below.
    y = torch.mul(a, b)
    return y

a = torch.ones([5, 9216, 64], dtype=torch.bfloat16, device=device)
b = torch.ones([5, 9216, 64], dtype=torch.bfloat16, device=device)
y = foo(a, b)
print(y)
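For reference, one way to check what the lazy tensor graph will lower to, without digging through dump files, is the internal helper torch_xla._XLAC._get_xla_tensors_hlo. This is a private API that may change between releases, so treat the following as a sketch rather than a supported interface:

import torch
import torch_xla as xla

device = xla.device(0)
a = torch.ones([5, 9216, 64], dtype=torch.bfloat16, device=device)
b = torch.ones([5, 9216, 64], dtype=torch.bfloat16, device=device)
y = torch.mul(a, b)

# Print the pending HLO for y before compilation; the element type of the
# multiply (bf16[...] vs f32[...]) is visible in this text.
print(xla._XLAC._get_xla_tensors_hlo([y]))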
I was able to achieve bf16 multiplication by setting export XLA_USE_BF16=1, but I received the following warning:

XLA_USE_BF16 will be deprecated after the 2.5 release, please convert your model to bf16 directly

I'm not sure how to enable bf16 multiplication in HLO (High-Level Optimizer) the correct way, without using the deprecated flag.
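For context, the non-deprecated routes the warning points at are (a) creating or casting tensors/modules in torch.bfloat16 directly, and (b) autocast. A minimal sketch, assuming this torch/torch_xla combination registers autocast for the "xla" device type:

import torch
import torch_xla as xla

device = xla.device(0)

# (a) Convert the model and inputs to bf16 directly, as the warning suggests.
model = torch.nn.Linear(64, 64).to(torch.bfloat16).to(device)
x = torch.randn(5, 9216, 64, dtype=torch.bfloat16, device=device)
y = model(x)

# (b) Keep the model in f32 and rely on autocast for mixed precision
# (assumption: "xla" autocast is available in this build).
model_f32 = torch.nn.Linear(64, 64).to(device)
x_f32 = torch.randn(5, 9216, 64, device=device)
with torch.autocast("xla", dtype=torch.bfloat16):
    y2 = model_f32(x_f32)

Whether either route keeps the multiply in bf16 in the lowered HLO on XLA:GPU is exactly the open question above.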