-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* adding yield tables compatible with tabulate package
* yield tables integrated in visualize.data_MC
* adding tabulate as core dependency
* model_utils.calculate_stdev returns list instead of awkward array
- Loading branch information
1 parent
fc984e4
commit 36035e7
Showing
8 changed files
with
187 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,6 +30,7 @@ install_requires = | |
click | ||
awkward1 | ||
scipy | ||
tabulate | ||
|
||
[options.packages.find] | ||
where = src | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
import logging | ||
from typing import Any, Dict, List | ||
|
||
import numpy as np | ||
import pyhf | ||
import tabulate | ||
|
||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
||
def _header_name(channel_name: str, i_bin: int) -> str: | ||
"""Constructs the header name for a column in a yield table. | ||
Args: | ||
channel_name (str): name of the channel (phase space region) | ||
i_bin (int): index of bin in channel | ||
Returns: | ||
str: the header name to be used for the column | ||
""" | ||
if i_bin == 0: | ||
header_name = f"{channel_name}\nbin {i_bin+1}" | ||
else: | ||
header_name = f"\nbin {i_bin+1}" | ||
return header_name | ||
|
||
|
||
def _yields(
    model: pyhf.pdf.Model,
    model_yields: List[np.ndarray],
    total_stdev_model: List[List[float]],
    data: List[np.ndarray],
) -> List[Dict[str, Any]]:
    """Outputs and returns a yield table with predicted and observed yields per bin.

    The table has one row per sample, followed by a row for the total model
    prediction (with its standard deviation) and a row for data. Each row is a
    dictionary with a "sample" entry and one entry per channel and bin. All yields
    are formatted as strings with two decimal places.

    Column names are obtained from ``_header_name``. The abbreviated names it
    returns for bins beyond the first do not contain the channel name, so they
    repeat when several channels have more than one bin. In that case the full
    name including the channel is used for the repeated column, since columns of
    different channels would otherwise overwrite each other.

    Args:
        model (pyhf.pdf.Model): the model which the table corresponds to
        model_yields (List[np.ndarray]): yields per channel, sample, and bin
        total_stdev_model (List[List[float]]): total model standard deviation per
            channel and bin
        data (List[np.ndarray]): data yield per channel and bin

    Returns:
        List[Dict[str, Any]]: yield table for use with the ``tabulate`` package
    """
    # determine the column name for every channel and bin once, in table order
    headers: List[List[str]] = []
    used_headers = set()
    for channel_name in model.config.channels:
        channel_headers = []
        for i_bin in range(model.config.channel_nbins[channel_name]):
            header = _header_name(channel_name, i_bin)
            if header in used_headers:
                # an earlier channel already uses this abbreviated name, fall back
                # to a name including the channel to avoid overwriting its column
                header = f"{channel_name}\nbin {i_bin+1}"
            used_headers.add(header)
            channel_headers.append(header)
        headers.append(channel_headers)

    table: List[Dict[str, Any]] = []  # table containing all yields

    # rows for each individual sample
    for i_sam, sample_name in enumerate(model.config.samples):
        sample_dict = {"sample": sample_name}  # one dict per sample
        for i_chan, channel_headers in enumerate(headers):
            for i_bin, header in enumerate(channel_headers):
                sample_dict[header] = f"{model_yields[i_chan][i_sam][i_bin]:.2f}"
        table.append(sample_dict)

    # dicts for total model prediction and data
    total_dict = {"sample": "total"}
    data_dict = {"sample": "data"}
    for i_chan, channel_headers in enumerate(headers):
        total_model = np.sum(model_yields[i_chan], axis=0)  # sum over samples
        for i_bin, header in enumerate(channel_headers):
            total_dict[header] = (
                f"{total_model[i_bin]:.2f} "
                f"\u00B1 {total_stdev_model[i_chan][i_bin]:.2f}"
            )
            data_dict[header] = f"{data[i_chan][i_bin]:.2f}"
    table += [total_dict, data_dict]

    log.info(
        "yield table:\n"
        + tabulate.tabulate(table, headers="keys", tablefmt="fancy_grid")
    )

    return table
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import numpy as np | ||
import pyhf | ||
import pytest | ||
|
||
from cabinetry import tabulate | ||
|
||
|
||
@pytest.mark.parametrize(
    "test_input, expected",
    [
        (("abc", 0), "abc\nbin 1"),
        (("abc", 2), "\nbin 3"),
    ],
)
def test__header_name(test_input, expected):
    """Checks the header name of the first bin and of a later bin."""
    channel_name, i_bin = test_input
    header = tabulate._header_name(channel_name, i_bin)
    assert header == expected
|
||
|
||
def test__yields(example_spec_multibin, example_spec_with_background):
    """Checks the yield table for a multi-channel and for a multi-sample model."""
    # multiple channels: the inputs hold two bins for the first channel and one
    # bin for the second channel, with a single sample each
    multibin_model = pyhf.Workspace(example_spec_multibin).model()
    multibin_table = tabulate._yields(
        multibin_model,
        [np.asarray([[25.0, 5.0]]), np.asarray([[8.0]])],
        [[5.0, 2.0], [1.0]],
        [np.asarray([35, 8]), np.asarray([10])],
    )
    expected_multibin = [
        {
            "sample": "Signal",
            "region_1\nbin 1": "25.00",
            "\nbin 2": "5.00",
            "region_2\nbin 1": "8.00",
        },
        {
            "sample": "total",
            "region_1\nbin 1": "25.00 \u00B1 5.00",
            "\nbin 2": "5.00 \u00B1 2.00",
            "region_2\nbin 1": "8.00 \u00B1 1.00",
        },
        {
            "sample": "data",
            "region_1\nbin 1": "35.00",
            "\nbin 2": "8.00",
            "region_2\nbin 1": "10.00",
        },
    ]
    assert multibin_table == expected_multibin

    # multiple samples: the inputs hold two samples in a single one-bin channel
    background_model = pyhf.Workspace(example_spec_with_background).model()
    background_table = tabulate._yields(
        background_model,
        [np.asarray([[150.0], [50.0]])],
        [[8.60]],
        [np.asarray([160])],
    )
    expected_background = [
        {"sample": "Background", "Signal Region\nbin 1": "150.00"},
        {"sample": "Signal", "Signal Region\nbin 1": "50.00"},
        {"sample": "total", "Signal Region\nbin 1": "200.00 \u00B1 8.60"},
        {"sample": "data", "Signal Region\nbin 1": "160.00"},
    ]
    assert background_table == expected_background
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters