Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Use JAX v0.2.4+ API #1134

Merged
merged 7 commits into from
Oct 20, 2020
Merged

fix: Use JAX v0.2.4+ API #1134

merged 7 commits into from
Oct 20, 2020

Conversation

matthewfeickert
Copy link
Member

@matthewfeickert matthewfeickert commented Oct 20, 2020

Description

With release v0.2.4 of JAX the underlying stats API has changed as there is no longer a jax.scipy.stats.norm.osp_stats or jax.scipy.stats.poisson.osp_stats API, but instead JAX now just uses the SciPy API directly with

import scipy.stats as osp_stats

As such, this requires a minimum version bump of jax and jaxlib and then also requires changing the imports to use SciPy directly. Given that the continuous approximation to the Poisson is used there is actually no need to import jax.scipy.stats.poisson anymore either.

Checklist Before Requesting Reviewer

  • Tests are passing
  • "WIP" removed from the title of the pull request
  • Selected an Assignee for the PR to be responsible for the log summary

Before Merging

For the PR Assignees:

  • Summarize commit messages into a comprehensive review of the PR
* Use JAX v0.2.4 API with regards to using scipy.stats directly
* Update to jax v0.2.4+ and jaxlib v0.1.56+
   - Restrict to jax v0.2.X and jaxlib v0.1.X to ensure stability
* Add tests for Poisson and Normal sample shape

@matthewfeickert matthewfeickert added the fix A bug fix label Oct 20, 2020
@matthewfeickert matthewfeickert self-assigned this Oct 20, 2020
@codecov
Copy link

codecov bot commented Oct 20, 2020

Codecov Report

Merging #1134 into master will increase coverage by 0.02%.
The diff coverage is 100.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #1134      +/-   ##
==========================================
+ Coverage   96.99%   97.02%   +0.02%     
==========================================
  Files          62       62              
  Lines        3597     3598       +1     
  Branches      519      519              
==========================================
+ Hits         3489     3491       +2     
+ Misses         67       66       -1     
  Partials       41       41              
Flag Coverage Δ
#unittests 97.02% <100.00%> (+0.02%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
src/pyhf/tensor/jax_backend.py 96.74% <100.00%> (+0.84%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update bf0b1d2...a2a0258. Read the comment docs.

@matthewfeickert matthewfeickert marked this pull request as ready for review October 20, 2020 04:44
@matthewfeickert matthewfeickert added the build Changes that affect the build system or external dependencies label Oct 20, 2020
@matthewfeickert
Copy link
Member Author

I also opened up an Issue on JAX RE: the docs of jax.scipy.stats.poisson: jax-ml/jax#4648

@matthewfeickert
Copy link
Member Author

@alexander-held heads up on this.

@matthewfeickert matthewfeickert force-pushed the fix/update-to-jax-v0.2.4-API branch from 2490342 to 6041200 Compare October 20, 2020 14:37
@@ -17,8 +18,9 @@ def __init__(self, rate):
self.rate = rate

def sample(self, sample_shape):
return poisson.osp_stats.poisson(self.rate).rvs(
size=sample_shape + self.rate.shape
tensorlib = jax_backend()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Getting the tensorlib this way is how we've done thing previously inhere, but not sure if there's a smarter way.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this isn't the best, since they might have some configurations on the jax backend that are different than the default you're relying on here.

@kratsg kratsg merged commit 4d64ef1 into master Oct 20, 2020
@kratsg kratsg deleted the fix/update-to-jax-v0.2.4-API branch October 20, 2020 17:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
build Changes that affect the build system or external dependencies fix A bug fix tests pytest
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants