-
Notifications
You must be signed in to change notification settings - Fork 85
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix: Use JAX v0.2.4+ API #1134
fix: Use JAX v0.2.4+ API #1134
Conversation
Codecov Report
@@ Coverage Diff @@
## master #1134 +/- ##
==========================================
+ Coverage 96.99% 97.02% +0.02%
==========================================
Files 62 62
Lines 3597 3598 +1
Branches 519 519
==========================================
+ Hits 3489 3491 +2
+ Misses 67 66 -1
Partials 41 41
Flags with carried forward coverage won't be shown. Click here to find out more.
Continue to review full report at Codecov.
|
I also opened up an Issue on JAX RE: the docs of |
@alexander-held heads up on this. |
2490342
to
6041200
Compare
src/pyhf/tensor/jax_backend.py
Outdated
@@ -17,8 +18,9 @@ def __init__(self, rate): | |||
self.rate = rate | |||
|
|||
def sample(self, sample_shape): | |||
return poisson.osp_stats.poisson(self.rate).rvs( | |||
size=sample_shape + self.rate.shape | |||
tensorlib = jax_backend() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Getting the tensorlib
this way is how we've done thing previously inhere, but not sure if there's a smarter way.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this isn't the best, since they might have some configurations on the jax backend that are different than the default you're relying on here.
Description
With release
v0.2.4
of JAX the underlying stats API has changed as there is no longer ajax.scipy.stats.norm.osp_stats
orjax.scipy.stats.poisson.osp_stats
API, but instead JAX now just uses the SciPy API directly withAs such, this requires a minimum version bump of
jax
andjaxlib
and then also requires changing the imports to use SciPy directly. Given that the continuous approximation to the Poisson is used there is actually no need to importjax.scipy.stats.poisson
anymore either.Checklist Before Requesting Reviewer
Before Merging
For the PR Assignees: