Skip to content

Commit

Permalink
Add match_all method in paradigm to support CompoundDataset evaluat…
Browse files Browse the repository at this point in the history
…ion with MNE epochs (#473)

* match_all + test

* fix typing

* correct channel montage

* add comments, fixes test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* linting

* fix test

* list are in reverse order. Sort.

* set -0.5 as a default parameter

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* modify whats_new

---------

Co-authored-by: Gregoire Cattan <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Aug 29, 2023
1 parent c45fb3e commit 3da799d
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/source/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ Enhancements
- Add :obj:`moabb.datasets.compound_dataset.utils.compound_dataset_list` (:gh:`455` by `Pierre Guetschel`_)
- Add c-VEP paradigm and Thielen2021 c-VEP dataset (:gh:`463` by `Jordy Thielen`_)
- Add option to plot scores vertically. (:gh:`417` by `Sara Sedlar`_)
- Add match_all method in paradigm to support CompoundDataset evaluation with MNE epochs (:gh:`473` by `Gregoire Cattan`_)

Bugs
~~~~
Expand Down
4 changes: 3 additions & 1 deletion moabb/datasets/fake.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,10 @@ def __init__(
paradigm="imagery",
channels=("C3", "Cz", "C4"),
seed=None,
sfreq=128,
):
self.n_runs = n_runs
self.sfreq = sfreq
event_id = {ev: ii + 1 for ii, ev in enumerate(event_list)}
self.channels = channels
self.seed = seed
Expand Down Expand Up @@ -81,7 +83,7 @@ def _get_single_subject_data(self, subject):

def _generate_raw(self):
montage = make_standard_montage("standard_1005")
sfreq = 128
sfreq = self.sfreq
duration = len(self.event_id) * 60
eeg_data = 2e-5 * np.random.randn(duration * sfreq, len(self.channels))
y = np.zeros((duration * sfreq))
Expand Down
38 changes: 38 additions & 0 deletions moabb/paradigms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from sklearn.pipeline import Pipeline, make_pipeline
from sklearn.preprocessing import FunctionTransformer

from moabb.datasets.base import BaseDataset
from moabb.datasets.bids_interface import StepType
from moabb.datasets.preprocessing import (
EpochsToEvents,
Expand Down Expand Up @@ -420,6 +421,43 @@ def _get_array_pipeline(
return None
return Pipeline(steps)

def match_all(self, datasets: List[BaseDataset], shift=-0.5):
    """Configure this paradigm so that it fits every dataset in ``datasets``.

    For each dataset, the epochs of its first subject are loaded through
    ``self.get_data(..., return_epochs=True)`` and their ``info`` is
    inspected. Two attributes of the paradigm are then overwritten:

    - ``self.resample`` is set to the lowest sampling frequency found
      among the datasets, **plus** ``shift``.
      If the frequency is 128 for example, then MNE can return 128 or 129
      samples depending on the dataset, even if the length of the epochs
      is 1s. The default ``shift=-0.5`` solves this particular issue.
    - ``self.channels`` is set to the channels which are common to all
      datasets, listed in the order in which they appear in the first
      dataset, so the result is reproducible from one run to the next.

    Parameters
    ----------
    datasets: List[BaseDataset]
        The datasets this paradigm has to be compatible with.
        Must contain at least one dataset.
    shift: float
        Offset added to the lowest sampling frequency to obtain
        ``self.resample`` (default: -0.5).

    Raises
    ------
    ValueError
        If ``datasets`` is empty.
    """
    resample = None
    channels = None
    for dataset in datasets:
        # Only the first subject is loaded to read the dataset properties.
        epochs, _, _ = self.get_data(
            dataset, subjects=[dataset.subject_list[0]], return_epochs=True
        )
        info = epochs.info
        sfreq = info["sfreq"]
        ch_names = info["ch_names"]
        # get the minimum sampling frequency between all datasets
        resample = sfreq if resample is None else min(resample, sfreq)
        # get the channels common to all datasets, keeping the order of the
        # first dataset (a plain set would yield an arbitrary order)
        if channels is None:
            channels = list(dict.fromkeys(ch_names))
        else:
            available = set(ch_names)
            channels = [ch for ch in channels if ch in available]
    if resample is None:
        raise ValueError("match_all requires at least one dataset.")
    # If resample=128 for example, then MNE can return 128 or 129 samples
    # depending on the dataset, even if the length of the epochs is 1s
    # `shift=-0.5` solves this particular issue.
    self.resample = resample + shift
    self.channels = channels

@abc.abstractmethod
def _get_events_pipeline(self, dataset):
pass
Expand Down
29 changes: 29 additions & 0 deletions moabb/tests/paradigms.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,35 @@ def used_events(self, dataset):


class Test_P300(unittest.TestCase):
def test_match_all(self):
    """Check that match_all picks the lowest sfreq and the common channels."""
    # Note: the match all property is implemented in the base paradigm.
    # Thus, although it is located in the P300 section, this test stands for all paradigms.
    paradigm = SimpleP300()
    dataset1 = FakeDataset(
        paradigm="p300",
        event_list=["Target", "NonTarget"],
        channels=("C3", "Cz", "Fz"),
        sfreq=64,
    )
    dataset2 = FakeDataset(
        paradigm="p300",
        event_list=["Target", "NonTarget"],
        channels=["C3", "C4", "Cz"],
        sfreq=256,
    )
    dataset3 = FakeDataset(
        paradigm="p300",
        event_list=["Target", "NonTarget"],
        channels=["C3", "Cz", "Fz", "C4"],
        sfreq=512,
    )
    shift = -0.5
    paradigm.match_all([dataset1, dataset2, dataset3], shift=shift)
    # match_all should return the smallest frequency plus `shift` (i.e. minus 0.5 here).
    # See comment inside the match_all method
    self.assertEqual(paradigm.resample, 64 + shift)
    # `list.sort()` sorts in place and returns None, so comparing the results
    # of two `.sort()` calls always passes. Compare sorted copies instead.
    self.assertEqual(sorted(paradigm.channels), ["C3", "Cz"])

def test_BaseP300_paradigm(self):
paradigm = SimpleP300()
dataset = FakeDataset(paradigm="p300", event_list=["Target", "NonTarget"])
Expand Down

0 comments on commit 3da799d

Please sign in to comment.