Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Retain continuous approximation of Poisson in torch v1.8.0 #1351

Merged
merged 9 commits into from
Mar 5, 2021
19 changes: 16 additions & 3 deletions src/pyhf/tensor/pytorch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,8 @@ def einsum(self, subscripts, *operands):
return torch.einsum(subscripts, operands)

def poisson_logpdf(self, n, lam):
return torch.distributions.Poisson(lam).log_prob(n)
# validate_args=True disallows continuous approximation
return torch.distributions.Poisson(lam, validate_args=False).log_prob(n)

def poisson(self, n, lam):
r"""
Expand Down Expand Up @@ -347,9 +348,16 @@ def poisson(self, n, lam):
Returns:
PyTorch FloatTensor: Value of the continuous approximation to Poisson(n|lam)
"""
return torch.exp(torch.distributions.Poisson(lam).log_prob(n))
# validate_args=True disallows continuous approximation
return torch.exp(
torch.distributions.Poisson(lam, validate_args=False).log_prob(n)
)

def normal_logpdf(self, x, mu, sigma):
x = self.astensor(x)
mu = self.astensor(mu)
sigma = self.astensor(sigma)
matthewfeickert marked this conversation as resolved.
Show resolved Hide resolved

normal = torch.distributions.Normal(mu, sigma)
return normal.log_prob(x)

Expand Down Expand Up @@ -379,6 +387,10 @@ def normal(self, x, mu, sigma):
Returns:
PyTorch FloatTensor: Value of Normal(x|mu, sigma)
"""
x = self.astensor(x)
mu = self.astensor(mu)
sigma = self.astensor(sigma)

normal = torch.distributions.Normal(mu, sigma)
return self.exp(normal.log_prob(x))

Expand Down Expand Up @@ -433,7 +445,8 @@ def poisson_dist(self, rate):
PyTorch Poisson distribution: The Poisson distribution class

"""
return torch.distributions.Poisson(rate)
# validate_args=True disallows continuous approximation
return torch.distributions.Poisson(rate, validate_args=False)

def normal_dist(self, mu, sigma):
r"""
Expand Down
50 changes: 49 additions & 1 deletion tests/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,8 @@ def test_shape(backend):
)


@pytest.mark.skip_pytorch
@pytest.mark.skip_pytorch64
def test_pdf_calculations(backend):
tb = pyhf.tensorlib
assert tb.tolist(tb.normal_cdf(tb.astensor([0.8]))) == pytest.approx(
Expand Down Expand Up @@ -286,7 +288,53 @@ def test_pdf_calculations(backend):

# Ensure continuous approximation is valid
assert tb.tolist(
tb.poisson(tb.astensor([0.5, 1.1, 1.5]), tb.astensor(1.0))
tb.poisson(n=tb.astensor([0.5, 1.1, 1.5]), lam=tb.astensor(1.0))
) == pytest.approx([0.4151074974205947, 0.3515379040027489, 0.2767383316137298])


# validate_args in torch.distributions raises ValueError not nan
@pytest.mark.only_pytorch
@pytest.mark.only_pytorch64
def test_pdf_calculations_pytorch(backend):
tb = pyhf.tensorlib

values = tb.astensor([0, 0, 1, 1])
mus = tb.astensor([0, 1, 0, 1])
sigmas = tb.astensor([0, 0, 0, 0])
for x, mu, sigma in zip(values, mus, sigmas):
with pytest.raises(ValueError):
_ = tb.normal_logpdf(x, mu, sigma)
assert tb.tolist(
tb.normal_logpdf(
tb.astensor([0, 0, 1, 1]),
tb.astensor([0, 1, 0, 1]),
tb.astensor([1, 1, 1, 1]),
)
) == pytest.approx(
[
-0.91893853,
-1.41893853,
-1.41893853,
-0.91893853,
],
)

# poisson(lambda=0) is not defined, should return NaN
assert tb.tolist(
tb.poisson(tb.astensor([0, 0, 1, 1]), tb.astensor([0, 1, 0, 1]))
) == pytest.approx(
[np.nan, 0.3678794503211975, 0.0, 0.3678794503211975], nan_ok=True
)
assert tb.tolist(
tb.poisson_logpdf(tb.astensor([0, 0, 1, 1]), tb.astensor([0, 1, 0, 1]))
) == pytest.approx(
np.log([np.nan, 0.3678794503211975, 0.0, 0.3678794503211975]).tolist(),
nan_ok=True,
)

# Ensure continuous approximation is valid
assert tb.tolist(
tb.poisson(n=tb.astensor([0.5, 1.1, 1.5]), lam=tb.astensor(1.0))
) == pytest.approx([0.4151074974205947, 0.3515379040027489, 0.2767383316137298])


Expand Down