Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding Braindecode pipeline #328

Merged
merged 54 commits into from
Apr 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
4c59cfd
Adding braindecode object as a pipeline.
bruAristimunha Feb 7, 2023
4d0b169
Adding in init file and changing the name.
bruAristimunha Feb 7, 2023
2037230
Adding new pipeline.
bruAristimunha Feb 7, 2023
3756a76
Adding new dependence, braindecode
bruAristimunha Feb 7, 2023
2eb5a4c
Adding new dependence, torch
bruAristimunha Feb 7, 2023
e5a0534
Merge branch 'develop' into braindecode
bruAristimunha Feb 7, 2023
42f8189
Merge branch 'develop' into braindecode
bruAristimunha Feb 8, 2023
b42657c
Updating the dependencies, set as optional.
bruAristimunha Feb 9, 2023
187954a
Setting the Valid Split to the new pipeline.
bruAristimunha Feb 9, 2023
0706b48
Merge remote-tracking branch 'origin/braindecode' into braindecode
bruAristimunha Feb 9, 2023
3334dfd
restoring the file
bruAristimunha Feb 12, 2023
6fb5b45
Merge branch 'develop' into braindecode
bruAristimunha Feb 13, 2023
eb3b460
Merge branch 'develop' into braindecode
bruAristimunha Feb 14, 2023
4d22b40
Merge branch 'develop' into braindecode
bruAristimunha Mar 3, 2023
1eaf054
Merge branch 'develop' into braindecode
bruAristimunha Mar 3, 2023
0697459
Updating __init__
bruAristimunha Mar 11, 2023
bea1d6b
Removing the BraindecodeClassifierModel
bruAristimunha Mar 11, 2023
555bf0c
Updating EEGClassifier to use the max_epochs
bruAristimunha Mar 11, 2023
76a0752
Adding braindecode as depedencies
bruAristimunha Mar 11, 2023
014162c
Moving the file to other nome
bruAristimunha Mar 11, 2023
df77fda
Merge branch 'develop' into braindecode
bruAristimunha Mar 11, 2023
039c73d
Adding support ot braindecode classifier
bruAristimunha Mar 11, 2023
a801d67
Adding y as None value
bruAristimunha Mar 11, 2023
bcbeed1
first iteration to use the ShallowNet from braindecode with yaml file
bruAristimunha Mar 11, 2023
ff470cb
Adding as example
bruAristimunha Mar 11, 2023
c9ff708
To discuss
bruAristimunha Mar 11, 2023
d70e1fe
Merge branch 'develop' into braindecode
bruAristimunha Mar 19, 2023
c557d98
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2023
eafc342
Removing braindecode file
bruAristimunha Mar 19, 2023
2fe6cd2
Adding yaml files
bruAristimunha Mar 19, 2023
47bd4d5
Removing ShallowFBCSPNet, duplicate
bruAristimunha Mar 19, 2023
846de60
Updating Braindecode_ShallowFBCSPNET.yml to add the inputShapeSetterEEG
bruAristimunha Mar 19, 2023
37ba988
working on the parser
bruAristimunha Mar 21, 2023
32d89b5
removing braindecode keyword
bruAristimunha Mar 21, 2023
df2f933
working more on the parser
bruAristimunha Mar 22, 2023
afcdc10
Merge branch 'develop' into braindecode
bruAristimunha Mar 28, 2023
adb2914
Adding one more test
bruAristimunha Mar 28, 2023
f0a3479
Adding check module
bruAristimunha Mar 28, 2023
51c11a0
Merge remote-tracking branch 'origin/braindecode' into braindecode
bruAristimunha Mar 28, 2023
6851075
Improving the BraindecodeDatasetLoader
bruAristimunha Mar 28, 2023
934f407
Removing the yaml file for braindecode object
bruAristimunha Mar 28, 2023
2353869
Improving the examples
bruAristimunha Mar 28, 2023
7eeae68
Naming the variable
bruAristimunha Mar 28, 2023
21fa441
Returning the old parser
bruAristimunha Mar 28, 2023
a72555e
Merge branch 'develop' into braindecode
bruAristimunha Mar 28, 2023
4e1b691
adding test folder
bruAristimunha Mar 28, 2023
094fe42
Merge remote-tracking branch 'origin/braindecode' into braindecode
bruAristimunha Mar 28, 2023
87a5c28
fix: correct error when multiple pipelines
Mar 29, 2023
e42fa17
fix: correct doc error
Mar 29, 2023
c77eefb
revert: leave braindecode order
Mar 29, 2023
ee7b3a7
fix: benchmark unit test passed
Mar 29, 2023
13fd28a
fix: doc building error
Mar 29, 2023
4d5cab5
Update moabb/benchmark.py
bruAristimunha Apr 4, 2023
f4e3b36
Merge branch 'develop' into braindecode
sylvchev Apr 4, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ source = moabb
omit =
*/docs/*
*/pipelines/*
*/tests/*

[report]
exclude_lines =
Expand Down
1 change: 1 addition & 0 deletions docs/source/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Enhancements
- Add a augmentation model to the pipeline (:gh:`326` by `Igor Carrara`_)
- Add BrainDecode example(:gh:`340` by `Igor Carrara`_ and `Bruno Aristimunha`_)
- Add Google Analytics to the documentation (:gh:`335` by `Bruno Aristimunha`_)
- Add suport to Braindecode classifier (:gh:`328` by `Bruno Aristimunha`_)
- Add CodeCarbon to track emission CO₂ (:gh:`350` by `Igor Carrara`_, `Bruno Aristimunha`_ and `Sylvain Chevallier`_)
- Add CodeCarbon example (:gh:`356` by `Igor Carrara`_ and `Bruno Aristimunha`_)

Expand Down
4 changes: 2 additions & 2 deletions examples/pipelines_braindecode/braindecode_EEGInception.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
PATIENCE = 3

# Create the dataset
create_dataset = BraindecodeDatasetLoader()
create_dataset = BraindecodeDatasetLoader(drop_last_window=False)

# Set EEG Inception model
model = EEGInception(in_channels=1, n_classes=2, input_window_samples=100)
model = EEGInception(in_channels=1, n_classes=2)

# Define a Skorch classifier
clf = EEGClassifier(
Expand Down
2 changes: 1 addition & 1 deletion examples/pipelines_braindecode/braindecode_EEGNetv4.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
PATIENCE = 3

# Create the dataset
create_dataset = BraindecodeDatasetLoader()
create_dataset = BraindecodeDatasetLoader(drop_last_window=False)

# Set EEGNetv4 model
model = EEGNetv4(in_chans=1, n_classes=2, input_window_samples=100)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
PATIENCE = 3

# Create the dataset
create_dataset = BraindecodeDatasetLoader()
create_dataset = BraindecodeDatasetLoader(drop_last_window=False)

# Set Shallow Filter Bank CSP Net model
model = ShallowFBCSPNet(
Expand Down
66 changes: 44 additions & 22 deletions moabb/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
log = logging.getLogger(__name__)


def benchmark(
def benchmark( # noqa: C901
pipelines="./pipelines/",
evaluations=None,
paradigms=None,
Expand Down Expand Up @@ -154,27 +154,49 @@ def benchmark(
f"Datasets considered for {paradigm} paradigm {[dt.code for dt in d]}"
)

if "braindecode" in list(prdgms[paradigm].keys())[0]:
return_epochs = True
else:
return_epochs = False

context = eval_type[evaluation](
paradigm=p,
datasets=d,
random_state=42,
hdf5_path=results,
n_jobs=n_jobs,
overwrite=overwrite,
return_epochs=return_epochs,
)
paradigm_results = context.process(
pipelines=prdgms[paradigm], param_grid=param_grid
)
paradigm_results["paradigm"] = f"{paradigm}"
paradigm_results["evaluation"] = f"{evaluation}"
eval_results[f"{paradigm}"] = paradigm_results
df_eval.append(paradigm_results)
ppl_with_epochs, ppl_with_array = {}, {}
for pn, pv in prdgms[paradigm].items():
if "braindecode" in pn:
ppl_with_epochs[pn] = pv
else:
ppl_with_array[pn] = pv

if len(ppl_with_epochs) > 0:
# Braindecode pipelines require return_epochs=True
context = eval_type[evaluation](
paradigm=p,
datasets=d,
random_state=42,
hdf5_path=results,
n_jobs=1,
overwrite=overwrite,
return_epochs=True,
)
paradigm_results = context.process(
pipelines=ppl_with_epochs, param_grid=param_grid
)
paradigm_results["paradigm"] = f"{paradigm}"
paradigm_results["evaluation"] = f"{evaluation}"
eval_results[f"{paradigm}"] = paradigm_results
df_eval.append(paradigm_results)

# Other pipelines, that use numpy arrays
if len(ppl_with_array) > 0:
context = eval_type[evaluation](
paradigm=p,
datasets=d,
random_state=42,
hdf5_path=results,
n_jobs=n_jobs,
overwrite=overwrite,
)
paradigm_results = context.process(
pipelines=ppl_with_array, param_grid=param_grid
)
paradigm_results["paradigm"] = f"{paradigm}"
paradigm_results["evaluation"] = f"{evaluation}"
eval_results[f"{paradigm}"] = paradigm_results
df_eval.append(paradigm_results)

# Combining FilterBank and direct paradigms
eval_results = _combine_paradigms(eval_results)
Expand Down
4 changes: 3 additions & 1 deletion moabb/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
Pipelines are typically a chain of sklearn compatible transformers and end
with a sklearn compatible estimator.
"""

# flake8: noqa

from .classification import SSVEP_CCA, SSVEP_TRCA
from .features import FM, AugmentedDataset, ExtendedSSVEPSignal, LogVariance
from .utils import FilterBank, create_pipeline_from_config
Expand All @@ -20,7 +22,7 @@
)
from .utils_deep_model import EEGNet, TCN_block
except ModuleNotFoundError as err:
print("Tensorflow not install, you could not use deep learning pipelines")
print("Tensorflow not install, you could not use those pipelines")

try:
from .utils_pytorch import (
Expand Down
32 changes: 29 additions & 3 deletions moabb/pipelines/utils_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,58 @@
from functools import partial
from inspect import getmembers, isclass, isroutine

import mne
from braindecode.datasets import BaseConcatDataset, create_from_X_y
from numpy import unique
from sklearn.base import BaseEstimator, TransformerMixin
from skorch.callbacks import Callback
from torch.nn import Module


# check if the data format is numpy or mne epoch
def _check_data_format(X):
"""
Check if the data format is compatible with braindecode.
Expect values in the format of MNE objects.
Parameters
----------
X: BaseConcatDataset

Returns
-------

"""
if not isinstance(X, mne.EpochsArray):
raise ValueError(
"The data format is not supported. "
"Please use the option return_epochs=True"
"inside the Evaluations module."
)


class BraindecodeDatasetLoader(BaseEstimator, TransformerMixin):
"""
Class to Load the data from MOABB in a format compatible with braindecode
"""

def __init__(self, kw_args=None):
def __init__(self, drop_last_window=False, kw_args=None):
self.drop_last_window = drop_last_window
self.kw_args = kw_args

def fit(self, X, y=None):
_check_data_format(X)
self.y = y
return self

def transform(self, X, y=None):
_check_data_format(X)
dataset = create_from_X_y(
X.get_data(),
X=X.get_data(),
y=self.y,
window_size_samples=X.get_data().shape[2],
window_stride_samples=X.get_data().shape[2],
drop_last_window=False,
drop_last_window=self.drop_last_window,
ch_names=X.info["ch_names"],
sfreq=X.info["sfreq"],
)

Expand Down
6 changes: 6 additions & 0 deletions moabb/tests/util_braindecode.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,12 @@ def test_type_create_from_X_y_vs_transfomer(self, data):
assert isinstance(dataset_trans, BaseConcatDataset)
assert type(dataset_trans) == type(dataset)

def test_wrong_input(self):
"""Test that an invalid input raises a ValueError"""
transformer = BraindecodeDatasetLoader()
with pytest.raises(ValueError):
transformer.fit_transform(np.random.normal(size=(2, 1, 10)), y=np.array([0]))


if __name__ == "__main__":
unittest.main()
4 changes: 2 additions & 2 deletions moabb/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,11 @@ def setup_seed(seed: int) -> None:


def set_log_level(level="INFO"):
"""Set lot level.
"""Set log level

Set the general log level.
Use one of the levels supported by python logging, i.e.:
DEBUG, INFO, WARNING, ERROR, CRITICAL
DEBUG, INFO, WARNING, ERROR, CRITICAL
"""
VALID_LEVELS = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
level = level.upper()
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ pre-commit = "^2.21.0"
m2r2 = "^0.3.3"
tdlda = {git = "https://github.com/jsosulski/tdlda.git", rev = "0.1.0"}

bruAristimunha marked this conversation as resolved.
Show resolved Hide resolved

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
Expand Down