Skip to content

Commit

Permalink
update fixed_poi_fit() to call fit() essentially
Browse files Browse the repository at this point in the history
  • Loading branch information
kratsg committed Sep 5, 2020
1 parent 25541c0 commit d465637
Showing 1 changed file with 8 additions and 24 deletions.
32 changes: 8 additions & 24 deletions src/pyhf/infer/mle.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def twice_nll(pars, data, pdf):

def fit(data, pdf, init_pars=None, par_bounds=None, fixed_params=None, **kwargs):
r"""
Run a unconstrained maximum likelihood fit.
Run a maximum likelihood fit.
This is done by minimizing the objective function :func:`~pyhf.infer.mle.twice_nll`
of the model parameters given the observed data.
This is used to produce the maximal likelihood :math:`L\left(\hat{\mu}, \hat{\boldsymbol{\theta}}\right)`
Expand Down Expand Up @@ -87,6 +87,7 @@ def fit(data, pdf, init_pars=None, par_bounds=None, fixed_params=None, **kwargs)
pdf (~pyhf.pdf.Model): The statistical model adhering to the schema model.json
init_pars (`list`): Values to initialize the model parameters at for the fit
par_bounds (`list` of `list`\s or `tuple`\s): The extrema of values the model parameters are allowed to reach in the fit
fixed_params (`list`): Parameters to be held constant in the fit.
kwargs: Keyword arguments passed through to the optimizer API
Returns:
Expand Down Expand Up @@ -155,6 +156,7 @@ def fixed_poi_fit(
pdf (~pyhf.pdf.Model): The statistical model adhering to the schema model.json
init_pars (`list`): Values to initialize the model parameters at for the fit
par_bounds (`list` of `list`\s or `tuple`\s): The extrema of values the model parameters are allowed to reach in the fit
fixed_params (`list`): Parameters to be held constant in the fit.
kwargs: Keyword arguments passed through to the optimizer API
Returns:
Expand All @@ -166,28 +168,10 @@ def fixed_poi_fit(
'No POI is defined. A POI is required to fit with a fixed POI.'
)

_, opt = get_backend()
init_pars = init_pars or pdf.config.suggested_init()
par_bounds = par_bounds or pdf.config.suggested_bounds()
fixed_params = fixed_params or pdf.config.suggested_fixed()
init_pars = [*(init_pars or pdf.config.suggested_init())]
fixed_params = [*(fixed_params or pdf.config.suggested_fixed())]

# get fixed vals from the model
fixed_vals = [
(index, init)
for index, (init, is_fixed) in enumerate(zip(init_pars, fixed_params))
if is_fixed
]
# add the fixed POI
fixed_vals = fixed_vals + [(pdf.config.poi_index, poi_val)]
# de-dupe and use last-appended result for each index
fixed_vals = list(dict(fixed_vals).items())
init_pars[pdf.config.poi_index] = poi_val
fixed_params[pdf.config.poi_index] = True

return opt.minimize(
twice_nll,
data,
pdf,
init_pars,
par_bounds,
fixed_vals,
**kwargs,
)
return fit(data, pdf, init_pars, par_bounds, fixed_params, **kwargs)

0 comments on commit d465637

Please sign in to comment.