Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add match_all method in paradigm to support CompoundDataset evaluation with MNE epochs #473

Merged
merged 13 commits into from
Aug 29, 2023
1 change: 1 addition & 0 deletions docs/source/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ Enhancements
- Add :obj:`moabb.datasets.compound_dataset.utils.compound_dataset_list` (:gh:`455` by `Pierre Guetschel`_)
- Add c-VEP paradigm and Thielen2021 c-VEP dataset (:gh:`463` by `Jordy Thielen`_)
- Add option to plot scores vertically. (:gh:`417` by `Sara Sedlar`_)
- Add match_all method in paradigm to support CompoundDataset evaluation with MNE epochs (:gh:`473` by `Gregoire Cattan`_)

Bugs
~~~~
Expand Down
4 changes: 3 additions & 1 deletion moabb/datasets/fake.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,10 @@ def __init__(
paradigm="imagery",
channels=("C3", "Cz", "C4"),
seed=None,
sfreq=128,
):
self.n_runs = n_runs
self.sfreq = sfreq
event_id = {ev: ii + 1 for ii, ev in enumerate(event_list)}
self.channels = channels
self.seed = seed
Expand Down Expand Up @@ -81,7 +83,7 @@ def _get_single_subject_data(self, subject):

def _generate_raw(self):
montage = make_standard_montage("standard_1005")
sfreq = 128
sfreq = self.sfreq
duration = len(self.event_id) * 60
eeg_data = 2e-5 * np.random.randn(duration * sfreq, len(self.channels))
y = np.zeros((duration * sfreq))
Expand Down
38 changes: 38 additions & 0 deletions moabb/paradigms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from sklearn.pipeline import Pipeline, make_pipeline
from sklearn.preprocessing import FunctionTransformer

from moabb.datasets.base import BaseDataset
from moabb.datasets.bids_interface import StepType
from moabb.datasets.preprocessing import (
EpochsToEvents,
Expand Down Expand Up @@ -420,6 +421,43 @@ def _get_array_pipeline(
return None
return Pipeline(steps)

def match_all(self, datasets: List[BaseDataset], shift=-0.5):
"""
Initialize this paradigm to match all datasets in parameter:
- `self.resample` is set to match the minimum frequency in all datasets, minus `shift`.
If the frequency is 128 for example, then MNE can return 128 or 129 samples
depending on the dataset, even if the length of the epochs is 1s
Setting `shift=-0.5` solves this particular issue.
- `self.channels` is initialized with the channels which are common to all datasets.

Parameters
----------
datasets: List[BaseDataset]
A dataset instance.
"""
resample = None
channels = None
for dataset in datasets:
X, _, _ = self.get_data(
dataset, subjects=[dataset.subject_list[0]], return_epochs=True
)
info = X.info
sfreq = info["sfreq"]
ch_names = info["ch_names"]
# get the minimum sampling frequency between all datasets
resample = sfreq if resample is None else min(resample, sfreq)
# get the channels common to all datasets
channels = (
set(ch_names)
if channels is None
else set(channels).intersection(ch_names)
)
# If resample=128 for example, then MNE can returns 128 or 129 samples
# depending on the dataset, even if the length of the epochs is 1s
# `shift=-0.5` solves this particular issue.
self.resample = resample + shift
self.channels = list(channels)

@abc.abstractmethod
def _get_events_pipeline(self, dataset):
pass
Expand Down
29 changes: 29 additions & 0 deletions moabb/tests/paradigms.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,35 @@ def used_events(self, dataset):


class Test_P300(unittest.TestCase):
def test_match_all(self):
# Note: the match all property is implemented in the base paradigm.
# Thus, although it is located in the P300 section, this test stands for all paradigms.
paradigm = SimpleP300()
dataset1 = FakeDataset(
paradigm="p300",
event_list=["Target", "NonTarget"],
channels=("C3", "Cz", "Fz"),
sfreq=64,
)
dataset2 = FakeDataset(
paradigm="p300",
event_list=["Target", "NonTarget"],
channels=["C3", "C4", "Cz"],
sfreq=256,
)
dataset3 = FakeDataset(
paradigm="p300",
event_list=["Target", "NonTarget"],
channels=["C3", "Cz", "Fz", "C4"],
sfreq=512,
)
shift = -0.5
paradigm.match_all([dataset1, dataset2, dataset3], shift=shift)
# match_all should returns the smallest frequency minus 0.5.
# See comment inside the match_all method
self.assertEqual(paradigm.resample, 64 + shift)
self.assertEqual(paradigm.channels.sort(), ["C3", "Cz"].sort())

def test_BaseP300_paradigm(self):
paradigm = SimpleP300()
dataset = FakeDataset(paradigm="p300", event_list=["Target", "NonTarget"])
Expand Down