Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed May 27, 2024
1 parent a8bf8b7 commit 6ff2026
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions torch_scatter/composite/logsumexp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
from torch_scatter.utils import broadcast


def scatter_logsumexp(src: torch.Tensor,
index: torch.Tensor,
dim: int = -1,
out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None,
eps: float = 1e-12) -> torch.Tensor:
def scatter_logsumexp(
src: torch.Tensor,
index: torch.Tensor,
dim: int = -1,
out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None,
eps: float = 1e-12,
) -> torch.Tensor:
if not torch.is_floating_point(src):
raise ValueError('`scatter_logsumexp` can only be computed over '
'tensors with floating point data types.')
Expand Down Expand Up @@ -48,6 +50,7 @@ def scatter_logsumexp(src: torch.Tensor,

if orig_out is None:
return out.nan_to_num_(neginf=0.0)
else:
mask = ~out.isfinite()
out[mask] = orig_out[mask]

mask = ~out.isfinite()
out[mask] = orig_out[mask]
return out

0 comments on commit 6ff2026

Please sign in to comment.