From 6ff2026254ba6117738c214766f7b89ef4e0b0e6 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 27 May 2024 12:06:24 +0000 Subject: [PATCH] update --- torch_scatter/composite/logsumexp.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/torch_scatter/composite/logsumexp.py b/torch_scatter/composite/logsumexp.py index 8ef77375..69dc90dd 100644 --- a/torch_scatter/composite/logsumexp.py +++ b/torch_scatter/composite/logsumexp.py @@ -5,12 +5,14 @@ from torch_scatter.utils import broadcast -def scatter_logsumexp(src: torch.Tensor, - index: torch.Tensor, - dim: int = -1, - out: Optional[torch.Tensor] = None, - dim_size: Optional[int] = None, - eps: float = 1e-12) -> torch.Tensor: +def scatter_logsumexp( + src: torch.Tensor, + index: torch.Tensor, + dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None, + eps: float = 1e-12, +) -> torch.Tensor: if not torch.is_floating_point(src): raise ValueError('`scatter_logsumexp` can only be computed over ' 'tensors with floating point data types.') @@ -48,6 +50,7 @@ def scatter_logsumexp(src: torch.Tensor, if orig_out is None: return out.nan_to_num_(neginf=0.0) - else: - mask = ~out.isfinite() - out[mask] = orig_out[mask] + + mask = ~out.isfinite() + out[mask] = orig_out[mask] + return out