Skip to content

Commit

Permalink
Update cVEP paradigm according to the changes of NeuroTechX#408
Browse files Browse the repository at this point in the history
  • Loading branch information
PierreGtch committed Aug 18, 2023
1 parent ce2dd52 commit 3b28bbe
Showing 1 changed file with 10 additions and 172 deletions.
182 changes: 10 additions & 172 deletions moabb/paradigms/cvep.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
"""v-VEP Paradigms"""

import abc
import logging

import mne
import numpy as np
import pandas as pd

from moabb.datasets import utils
from moabb.datasets.fake import FakeDataset
from moabb.paradigms.base import BaseParadigm
Expand Down Expand Up @@ -67,19 +62,15 @@ def __init__(
channels=None,
resample=None,
):
super().__init__()
self.filters = filters
self.events = events
self.baseline = baseline
self.channels = channels
self.resample = resample

if tmax is not None:
if tmin >= tmax:
raise (ValueError("tmax must be greater than tmin"))

self.tmin = tmin
self.tmax = tmax
super().__init__(
filters=filters,
events=events,
channels=channels,
baseline=baseline,
resample=resample,
tmin=tmin,
tmax=tmax,
)

def is_valid(self, dataset):
ret = True
Expand All @@ -93,159 +84,6 @@ def is_valid(self, dataset):

return ret

@abc.abstractmethod
def used_events(self, dataset):
pass

def process_raw(self, raw, dataset, return_epochs=False, return_raws=False):
"""
Process one raw data file.
This function applies the preprocessing and eventual epoching on the
individual run, and return the data, labels and a dataframe with
metadata.
metadata is a dataframe with as many row as the length of the data
and labels.
Parameters
----------
raw: mne.Raw instance
the raw EEG data.
dataset : dataset instance
The dataset corresponding to the raw file. Mainly used to access
dataset specific information.
return_epochs: boolean
This flag specifies whether to return only the data array or the
complete processed mne.Epochs
return_raws: boolean
To return raw files and events, to ensure compatibility with braindecode.
Mutually exclusive with return_epochs
returns
-------
X : Union[np.ndarray, mne.Epochs]
the data that will be used as features for the model
Note: if return_epochs=True, this is mne.Epochs
if return_epochs=False, this is np.ndarray
labels: np.ndarray
the labels for training / evaluating the model
metadata: pd.DataFrame
A dataframe containing the metadata
"""

if return_epochs and return_raws:
message = "Select only return_epochs or return_raws, not both"
raise ValueError(message)

# get events id
event_id = self.used_events(dataset)

# find the events, first check stim_channels then annotations
stim_channels = mne.utils._get_stim_channel(None, raw.info, raise_error=False)
if len(stim_channels) > 0:
events = mne.find_events(raw, shortest_event=0, verbose=False)
else:
log.warning(f"No matching stim channels found in {raw.filenames}")

# picks channels
if self.channels is None:
picks = mne.pick_types(raw.info, eeg=True, stim=False)
else:
picks = mne.pick_channels(
raw.info["ch_names"], include=self.channels, ordered=True
)

# pick events, based on event_id
try:
if "Target" in event_id and "NonTarget" in event_id:
if isinstance(event_id["Target"], list) and isinstance(
event_id["NonTarget"], list
):
event_id_new = dict(Target=1, NonTarget=0)
events = mne.merge_events(events, event_id["Target"], 1)
events = mne.merge_events(events, event_id["NonTarget"], 0)
event_id = event_id_new
events = mne.pick_events(events, include=list(event_id.values()))
except RuntimeError:
# skip raw if no event found
return

if return_raws:
raw = raw.pick(picks)
else:
# get interval
tmin = self.tmin + dataset.interval[0]
if self.tmax is None:
tmax = dataset.interval[1]
else:
tmax = self.tmax + dataset.interval[0]

X = []
for bandpass in self.filters:
fmin, fmax = bandpass
# filter data
raw_f = raw.copy().filter(
fmin, fmax, method="iir", picks=picks, verbose=False
)
# epoch data
baseline = self.baseline
if baseline is not None:
baseline = (
self.baseline[0] + dataset.interval[0],
self.baseline[1] + dataset.interval[0],
)
bmin = baseline[0] if baseline[0] < tmin else tmin
bmax = baseline[1] if baseline[1] > tmax else tmax
else:
bmin = tmin
bmax = tmax
epochs = mne.Epochs(
raw_f,
events,
event_id=event_id,
tmin=bmin,
tmax=bmax,
proj=False,
baseline=baseline,
preload=True,
verbose=False,
picks=picks,
event_repeated="drop",
on_missing="ignore",
)
if bmin < tmin or bmax > tmax:
epochs.crop(tmin=tmin, tmax=tmax)
if self.resample is not None:
epochs = epochs.resample(self.resample)
# rescale to work with uV
if return_epochs:
X.append(epochs)
else:
X.append(dataset.unit_factor * epochs.get_data())

# overwrite events in case epochs have been dropped:
# (assuming all filters produce the same number of epochs...)
events = epochs.events

inv_events = {k: v for v, k in event_id.items()}
labels = np.array([inv_events[e] for e in events[:, -1]])

if return_epochs:
X = mne.concatenate_epochs(X)
elif return_raws:
X = raw
elif len(self.filters) == 1:
# if only one band, return a 3D array
X = X[0]
else:
# otherwise return a 4D
X = np.array(X).transpose((1, 2, 3, 0))

metadata = pd.DataFrame(index=range(len(labels)))
return X, labels, metadata

@property
def datasets(self):
if self.tmax is None:
Expand Down Expand Up @@ -348,4 +186,4 @@ def datasets(self):
return [FakeDataset(["Target", "NonTarget"], paradigm="cvep")]

def is_valid(self, dataset):
return True
return dataset.paradigm == "cvep"

0 comments on commit 3b28bbe

Please sign in to comment.