Feat/interpolate channels (#480)
* push test for channels interpolation

* use mne interpolation

* fix typing

* fix get_paradigm

* typo

* constructor parameter `paradigm` can now be removed

* fix filtered dataset

* set interpolate_missing_channels to False by default

* fix assert paradigm not passing

* parameter incorrectly named

* fix paradigm name

* exclude stim channel

* debug trace

* A-B vs B-A set difference problem

* fix event_list

* fix montage error
fix event list in tests

* fix disabling montage

* pop should be on info not raw object

* typo

* workaround ValueError: lowpass frequency 32.0 must be less than Nyquist (32.0)

* do not forget to pick channels

* remove fmin/fmax

* fix Nyquist error

* remove finally block

* log montage

* use default montage if 1005 is not available

* fix type

* add origin

* fix reference error

* debug

* add new check on epochs length

* dataset missing in get_data

* invalid subject given

* debug string

* missing dereference

* fix destructuring

* some debug trace

* debug string

* debug string

* additional testing

* warn epochs

* debug info

* fix syntax

* inverse test

* raw.copy missing

* resample not accurate. Debug.

* fix interpolate missing channel True when should be False

* get_data directly from datasets

* lint

* Update bi_illiteracy.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* missing modification

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Update whats_new.rst

---------

Co-authored-by: Gregoire Cattan <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Bru <[email protected]>
4 people authored Nov 10, 2023
1 parent bce0f9f commit e7f465c
Showing 11 changed files with 179 additions and 29 deletions.
1 change: 1 addition & 0 deletions docs/source/whats_new.rst
@@ -19,6 +19,7 @@ Enhancements
~~~~~~~~~~~~

- Adding cache option to the evaluation (:gh:`517` by `Bruno Aristimunha`_)
- Option to interpolate channels in paradigms' `match_all` method (:gh:`480` by `Gregoire Cattan`_)

Bugs
~~~~
2 changes: 1 addition & 1 deletion moabb/datasets/compound_dataset/__init__.py
@@ -8,7 +8,7 @@
BI_Il,
Cattan2019_VR_Il,
)
from .utils import _init_compound_dataset_list
from .utils import _init_compound_dataset_list, compound # noqa: F401


_init_compound_dataset_list()
26 changes: 23 additions & 3 deletions moabb/datasets/compound_dataset/base.py
@@ -36,13 +36,12 @@ class CompoundDataset(BaseDataset):
interval: list with 2 entries
See `BaseDataset`.
paradigm: ['p300','imagery', 'ssvep', 'rstate']
Defines what sort of dataset this is
"""

def __init__(self, subjects_list: list, code: str, interval: list, paradigm: str):
def __init__(self, subjects_list: list, code: str, interval: list):
self._set_subjects_list(subjects_list)
dataset, _, _, _ = self.subjects_list[0]
paradigm = self._get_paradigm()
super().__init__(
subjects=list(range(1, self.count + 1)),
sessions_per_subject=self._get_sessions_per_subject(),
@@ -52,6 +51,17 @@ def __init__(self, subjects_list: list, code: str, interval: list, paradigm: str
paradigm=paradigm,
)

@property
def datasets(self):
all_datasets = [entry[0] for entry in self.subjects_list]
found_flags = set()
filtered_dataset = []
for dataset in all_datasets:
if dataset.code not in found_flags:
filtered_dataset.append(dataset)
found_flags.add(dataset.code)
return filtered_dataset
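
For illustration, a minimal hedged sketch of this deduplication behaviour; it relies on the `compound` helper added in `utils.py` below and on MOABB's `FakeDataset`, and the constructor arguments are assumptions:

```python
# Sketch: the same source dataset listed twice is exposed only once,
# because `datasets` deduplicates entries by `dataset.code`.
from moabb.datasets.compound_dataset import compound
from moabb.datasets.fake import FakeDataset

ds = FakeDataset(paradigm="p300", event_list=["Target", "NonTarget"])
merged = compound(ds, ds)          # every subject appears twice...
assert len(merged.datasets) == 1   # ...but the source dataset only once
```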

@property
def count(self):
return len(self.subjects_list)
@@ -78,6 +88,16 @@ def _set_subjects_list(self, subjects_list: list):
for compoundDataset in subjects_list:
self.subjects_list.extend(compoundDataset.subjects_list)

def _get_paradigm(self):
dataset, _, _, _ = self.subjects_list[0]
paradigm = dataset.paradigm
# Check that all datasets have the same paradigm
for i in range(1, len(self.subjects_list)):
entry = self.subjects_list[i]
dataset = entry[0]
assert dataset.paradigm == paradigm
return paradigm

def _with_data_origin(self, data: dict, shopped_subject):
data_origin = self.subjects_list[shopped_subject - 1]

1 change: 0 additions & 1 deletion moabb/datasets/compound_dataset/bi_illiteracy.py
@@ -15,7 +15,6 @@ def __init__(self, subjects_list, dataset=None, code=None):
subjects_list=subjects_list,
code=code,
interval=[0, 1.0],
paradigm="p300",
)


15 changes: 15 additions & 0 deletions moabb/datasets/compound_dataset/utils.py
@@ -1,6 +1,8 @@
import inspect
from typing import List

import moabb.datasets.compound_dataset as db
from moabb.datasets.base import BaseDataset
from moabb.datasets.compound_dataset.base import CompoundDataset


@@ -11,3 +13,16 @@ def _init_compound_dataset_list():
for ds in inspect.getmembers(db, inspect.isclass):
if issubclass(ds[1], CompoundDataset) and not ds[0] == "CompoundDataset":
compound_dataset_list.append(ds[1])


def compound(*datasets: List[BaseDataset], interval=[0, 1.0]):
subjects_list = [
(d, subject, None, None) for d in datasets for subject in d.subject_list
]
code = "".join([d.code for d in datasets])
ret = CompoundDataset(
subjects_list=subjects_list,
code=code,
interval=interval,
)
return ret
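
A hedged usage sketch of the new `compound` helper (the `FakeDataset` parameters are illustrative assumptions, not part of this diff):

```python
# Sketch: build one CompoundDataset from two source datasets.
from moabb.datasets.compound_dataset import compound
from moabb.datasets.fake import FakeDataset

d1 = FakeDataset(paradigm="p300", event_list=["Target", "NonTarget"])
d2 = FakeDataset(paradigm="p300", event_list=["Target", "NonTarget"])

merged = compound(d1, d2, interval=[0, 1.0])
# `code` is the concatenation of the source codes, and `count` is the
# total number of (dataset, subject) entries across both subject lists.
print(merged.code, merged.count)
```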
37 changes: 36 additions & 1 deletion moabb/datasets/preprocessing.py
@@ -2,6 +2,7 @@
from collections import OrderedDict
from operator import methodcaller
from typing import Dict, List, Tuple, Union
from warnings import warn

import mne
import numpy as np
@@ -199,13 +200,15 @@ def __init__(
tmax: float,
baseline: Tuple[float, float],
channels: List[str] = None,
interpolate_missing_channels: bool = False,
):
assert isinstance(event_id, dict) # not None
self.event_id = event_id
self.tmin = tmin
self.tmax = tmax
self.baseline = baseline
self.channels = channels
self.interpolate_missing_channels = interpolate_missing_channels

def transform(self, X, y=None):
raw = X["raw"]
@@ -218,9 +221,40 @@ def transform(self, X, y=None):
if self.channels is None:
picks = mne.pick_types(raw.info, eeg=True, stim=False)
else:
available_channels = raw.info["ch_names"]
if self.interpolate_missing_channels:
missing_channels = list(set(self.channels).difference(available_channels))

# add missing channels (they contain only zeros by default)
try:
raw.add_reference_channels(missing_channels)
except IndexError:
# An IndexError can occur if the channels we add are not part of this epoch's montage
# Then log a warning
montage = raw.info["dig"]
warn(
f"Montage disabled as one of these channels, {missing_channels}, is not part of the montage {montage}"
)
# and disable the montage
raw.info.pop("dig")
# run again with montage disabled
raw.add_reference_channels(missing_channels)

# Trick: mark these channels as bad
raw.info["bads"].extend(missing_channels)
# ...and use MNE's bad-channel interpolation to generate the values of the missing channels
try:
raw.interpolate_bads(origin="auto")
except ValueError:
# use default origin if montage info not available
raw.interpolate_bads(origin=(0, 0, 0.04))
# update the list of available channels
available_channels = self.channels

picks = mne.pick_channels(
raw.info["ch_names"], include=self.channels, ordered=True
available_channels, include=self.channels, ordered=True
)
assert len(picks) == len(self.channels)

epochs = mne.Epochs(
raw,
Expand All @@ -236,6 +270,7 @@ def transform(self, X, y=None):
event_repeated="drop",
on_missing="ignore",
)
warn(f"warnEpochs {epochs}")
return epochs
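
For context, the interpolation trick above can be reproduced standalone with MNE. This is a sketch on synthetic data (channel names and montage are assumptions), not MOABB's exact pipeline:

```python
import mne
import numpy as np

# Synthetic 3-channel EEG recording.
info = mne.create_info(["C3", "Cz", "C4"], sfreq=128.0, ch_types="eeg")
raw = mne.io.RawArray(np.random.randn(3, 1280), info)
raw.set_montage("standard_1020")

missing = ["Fz"]
raw.add_reference_channels(missing)  # added channels are flat (all zeros)
raw.set_montage("standard_1020")     # re-apply montage so "Fz" has a position
raw.info["bads"].extend(missing)     # mark the flat channels as bad...
raw.interpolate_bads(origin="auto")  # ...and let MNE interpolate their values
```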


49 changes: 38 additions & 11 deletions moabb/paradigms/base.py
@@ -83,6 +83,7 @@
self.resample = resample
self.tmin = tmin
self.tmax = tmax
self.interpolate_missing_channels = False

@property
@abc.abstractmethod
@@ -399,6 +400,7 @@ def _get_epochs_pipeline(self, return_epochs, return_raws, dataset):
tmax=bmax,
baseline=baseline,
channels=self.channels,
interpolate_missing_channels=self.interpolate_missing_channels,
),
),
)
@@ -429,7 +431,13 @@ def _get_array_pipeline(
return None
return Pipeline(steps)

def match_all(self, datasets: List[BaseDataset], shift=-0.5):
def match_all(
self,
datasets: List[BaseDataset],
shift=-0.5,
channel_merge_strategy: str = "intersect",
ignore=["stim"],
):
"""
Initialize this paradigm to match all the given datasets:
- `self.resample` is set to match the minimum frequency in all datasets, minus `shift`.
@@ -442,29 +450,48 @@ def match_all(self, datasets: List[BaseDataset], shift=-0.5):
----------
datasets: List[BaseDataset]
A list of dataset instances.
shift: float (default: -0.5)
Shift the matched sampling frequency by this value
E.g.: if sampling=128 and shift=-0.5, then it returns 127.5 Hz
channel_merge_strategy: str (default: 'intersect')
Accepts two values:
- 'intersect': keep only channels common to all datasets
- 'union': keep all channels from all datasets, removing duplicates
ignore: List[str]
A list of channels to ignore
.. versionadded:: 0.6.0
"""
resample = None
channels = None
channels: set = None
for dataset in datasets:
X, _, _ = self.get_data(
dataset, subjects=[dataset.subject_list[0]], return_epochs=True
)
first_subject = dataset.subject_list[0]
data = dataset.get_data(subjects=[first_subject])[first_subject]
first_session = list(data.keys())[0]
session = data[first_session]
first_run = list(session.keys())[0]
X = session[first_run]
info = X.info
sfreq = info["sfreq"]
ch_names = info["ch_names"]
# get the minimum sampling frequency across all datasets
resample = sfreq if resample is None else min(resample, sfreq)
# get the channels common to all datasets
channels = (
set(ch_names)
if channels is None
else set(channels).intersection(ch_names)
)
if channels is None:
channels = set(ch_names)
elif channel_merge_strategy == "intersect":
channels = channels.intersection(ch_names)
self.interpolate_missing_channels = False
else:
channels = channels.union(ch_names)
self.interpolate_missing_channels = True
# If resample=128 for example, then MNE can return 128 or 129 samples
# depending on the dataset, even if the length of the epochs is 1s
# `shift=-0.5` solves this particular issue.
self.resample = resample + shift
self.channels = list(channels)

# exclude ignored channels
self.channels = list(channels.difference(ignore))

@abc.abstractmethod
def _get_events_pipeline(self, dataset):
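
Putting the pieces together, a hedged end-to-end sketch of the new `match_all` option (the `FakeDataset` arguments mirror the test further below and are assumptions):

```python
# Sketch: two datasets with different channel sets and sampling rates.
from moabb.datasets.fake import FakeDataset
from moabb.paradigms import P300

d1 = FakeDataset(paradigm="p300", event_list=["Target", "NonTarget"],
                 channels=["C3", "Cz", "Fz"], sfreq=128)
d2 = FakeDataset(paradigm="p300", event_list=["Target", "NonTarget"],
                 channels=["C3", "C4", "Cz"], sfreq=256)

paradigm = P300()
paradigm.match_all([d1, d2], channel_merge_strategy="union")
# resample becomes 127.5 (minimum sfreq plus shift), channels is the union
# of both channel sets minus the ignored "stim" channel, and the missing
# channels are interpolated when the data are epoched.
print(paradigm.resample, sorted(paradigm.channels))
```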
15 changes: 8 additions & 7 deletions moabb/tests/datasets.py
@@ -351,7 +351,6 @@ def test_fake_dataset(self):
subjects_list,
code="CompoundDataset-test",
interval=[0, 1],
paradigm=self.paradigm,
)

data = compound_data.get_data()
@@ -385,7 +384,6 @@ def test_compound_dataset_composition(self):
subjects_list,
code="CompoundDataset-test",
interval=[0, 1],
paradigm=self.paradigm,
)

# Add it two time to a subjects_list
@@ -394,9 +392,11 @@
subjects_list,
code="CompoundDataset-test",
interval=[0, 1],
paradigm=self.paradigm,
)

# Assert there is only one source dataset in the compound dataset
self.assertEqual(len(compound_data.datasets), 1)

# Assert that the compound dataset has twice as many subjects as the original one.
data = compound_data.get_data()
self.assertEqual(len(data), 2)
@@ -408,7 +408,7 @@ def test_get_sessions_per_subject(self):
n_runs=self.n_runs,
n_subjects=self.n_subjects,
event_list=["Target", "NonTarget"],
paradigm=self.paradigm,
paradigm=self.ds.paradigm,
)

# Add the two datasets to a CompoundDataset
@@ -417,9 +417,11 @@
subjects_list,
code="CompoundDataset",
interval=[0, 1],
paradigm=self.paradigm,
)

# Assert there are two source datasets (ds and ds2) in the compound dataset
self.assertEqual(len(compound_dataset.datasets), 2)

# Test private method _get_sessions_per_subject returns the minimum number of sessions per subject
self.assertEqual(compound_dataset._get_sessions_per_subject(), self.n_sessions)

@@ -430,7 +432,7 @@ def test_event_id_correctly_updated(self):
n_runs=self.n_runs,
n_subjects=self.n_subjects,
event_list=["Target2", "NonTarget2"],
paradigm=self.paradigm,
paradigm=self.ds.paradigm,
)

# Add the two datasets to a CompoundDataset
@@ -440,7 +442,6 @@
subjects_list,
code="CompoundDataset",
interval=[0, 1],
paradigm=self.paradigm,
)

# Check that the event_id of the compound_dataset is the same as that of the first dataset
36 changes: 36 additions & 0 deletions moabb/tests/evaluations.py
@@ -14,6 +14,7 @@
from sklearn.pipeline import FunctionTransformer, Pipeline, make_pipeline

from moabb.analysis.results import get_string_rep
from moabb.datasets.compound_dataset import compound
from moabb.datasets.fake import FakeDataset
from moabb.evaluations import evaluations as ev
from moabb.evaluations.utils import create_save_path, save_model_cv, save_model_list
@@ -82,6 +83,41 @@ def test_eval_results(self):
# We should have 9 columns in the results data frame
self.assertEqual(len(results[0].keys()), 9 if _carbonfootprint else 8)

def test_compound_dataset(self):
ch1 = ["C3", "Cz", "Fz"]
dataset1 = FakeDataset(
paradigm="imagery",
event_list=["left_hand", "right_hand"],
channels=ch1,
sfreq=128,
)
ch2 = ["C3", "C4", "Cz"]
dataset2 = FakeDataset(
paradigm="imagery",
event_list=["left_hand", "right_hand"],
channels=ch2,
sfreq=256,
)
merged_dataset = compound(dataset1, dataset2)

# We want to interpolate the channels that the two datasets do not have in common
self.eval.paradigm.match_all(
merged_dataset.datasets, channel_merge_strategy="union"
)

process_pipeline = self.eval.paradigm.make_process_pipelines(dataset)[0]
results = [
r
for r in self.eval.evaluate(
dataset, pipelines, param_grid=None, process_pipeline=process_pipeline
)
]

# We should get 4 results: 2 sessions x 2 subjects
self.assertEqual(len(results), 4)
# We should have 9 columns in the results data frame
self.assertEqual(len(results[0].keys()), 9 if _carbonfootprint else 8)

def test_eval_grid_search(self):
# Test grid search
param_grid = {"C": {"csp__metric": ["euclid", "riemann"]}}