Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Example] P300-VR dataset #393

Merged
merged 34 commits into from
Jun 7, 2023
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
910090d
In some places, the virtual reality dataset code was wrong.
gcattan May 31, 2023
c12940d
fix: PC data not downloading.
gcattan May 31, 2023
cd96bea
push example from Pedro
gcattan May 31, 2023
da628d5
fix error with datframe initialization
gcattan May 31, 2023
c083802
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 31, 2023
69d73ee
add whats new
gcattan May 31, 2023
09f6ca2
add test
gcattan May 31, 2023
e53c883
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 31, 2023
a3ee57c
fix pytest/unittest
gcattan May 31, 2023
a1296a2
Merge branch 'develop' of github.com:gcattan/moabb into develop
gcattan May 31, 2023
a402bd3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 31, 2023
3e9405d
replace logging by warnings library
gcattan May 31, 2023
6cf2a79
Merge branch 'develop' of github.com:gcattan/moabb into develop
gcattan May 31, 2023
6695995
move docstring to the top
gcattan May 31, 2023
2713cec
Merge branch 'develop' into develop
gcattan May 31, 2023
370c05f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 31, 2023
19e70f5
Merge branch 'develop' into develop
bruAristimunha May 31, 2023
7e47a46
test completed
gcattan May 31, 2023
7e91ffc
Merge branch 'develop' of github.com:gcattan/moabb into develop
gcattan May 31, 2023
183d354
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 31, 2023
ac1ed24
leftover
gcattan May 31, 2023
ef2a214
Merge branch 'develop' of github.com:gcattan/moabb into develop
gcattan May 31, 2023
6b5dc5f
typo ><
gcattan May 31, 2023
46be265
Merge branch 'develop' into develop
sylvchev May 31, 2023
4d21fe4
Update examples/vr_pc_p300_different_epoch_size.py
gcattan Jun 1, 2023
367a98f
rename into plot_vr_pc_p300_different_epoch_size.py
gcattan Jun 1, 2023
7d5644e
- Add figure plot
gcattan Jun 2, 2023
ea0f1fe
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 2, 2023
01dea81
Merge branch 'develop' into develop
bruAristimunha Jun 2, 2023
d8bd45a
Merge branch 'develop' into develop
gcattan Jun 5, 2023
182c575
Merge branch 'develop' into develop
bruAristimunha Jun 7, 2023
75b4ac6
Merge branch 'develop' into develop
bruAristimunha Jun 7, 2023
f360bf7
Update plot_vr_pc_p300_different_epoch_size.py
gcattan Jun 7, 2023
7007b1b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 7, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ Enhancements

- Adding second deployment of the documentation (:gh:`374` by `Bruno Aristimunha`_)
- Adding Parallel evaluation for :func:`moabb.evaluations.WithinSessionEvaluation` , :func:`moabb.evaluations.CrossSessionEvaluation` (:gh:`364` by `Bruno Aristimunha`_)
- Add example with VirtualReality BrainInvaders dataset (:gh:`393` by `Gregoire Cattan`_ and `Pedro L. C. Rodrigues`_)

Bugs
~~~~

- Restore 3 subject from Cho2017 (:gh:`392` by `Igor Carrara`_ and `Sylvain Chevallier`_)
- Correct downloading with VirtualReality BrainInvaders dataset (:gh:`393` by `Gregoire Cattan`_)

API changes
~~~~~~~~~~~
Expand Down
93 changes: 93 additions & 0 deletions examples/vr_pc_p300_different_epoch_size.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""
sylvchev marked this conversation as resolved.
Show resolved Hide resolved
=============================
Classification of the trials
gcattan marked this conversation as resolved.
Show resolved Hide resolved
=============================

This example shows how to extract the epochs from the P300-VR dataset of a given
subject and then classify them using Riemannian Geometry framework for BCI.
We compare the scores in the VR and PC conditions, using different epoch size.

"""
# Authors: Pedro Rodrigues <[email protected]>
# Modified by: Gregoire Cattan <[email protected]>
# License: BSD (3-clause)

import warnings

import numpy as np
import pandas as pd
from pyriemann.classification import MDM
from pyriemann.estimation import ERPCovariances
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import KFold
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm

from moabb.datasets import VirtualReality
from moabb.paradigms import P300


warnings.filterwarnings("ignore")

# create dataset
dataset = VirtualReality()

# To encode classes into 0 and 1.
le = LabelEncoder().fit(["Target", "NonTarget"])

# get the paradigm
paradigm = P300()

# change this to include more subjects
nsubjects = 2

scores = []
for tmax in [0.2, 1.0]:
paradigm.tmax = tmax

for subject in tqdm(dataset.subject_list[:nsubjects]):
scores_subject = [tmax, subject]

for condition in ["VR", "PC"]:
print(f"subject {subject}, {condition}, tmax {tmax}")

# define the dataset instance
dataset.virtual_reality = condition == "VR"
dataset.personal_computer = condition == "PC"

# cross validate with 3-folds validation.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not use a WithinSessionEvaluation?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wanted to demonstrate the use of the get_block_repetition method.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, in that case, you should add a small text it in the example to explain get_block_repetition method and to indicate that this example use it.

kf = KFold(n_splits=3)

# There is 12 blocks of 5 repetitions.
repetitions = [1, 2] # Select the first two repetitions.
blocks = np.arange(1, 12 + 1)

auc = []

for train_idx, test_idx in kf.split(np.arange(12)):
# split in training and testing blocks
X_train, y_train, _ = dataset.get_block_repetition(
paradigm, [subject], blocks[train_idx], repetitions
)

X_test, y_test, _ = dataset.get_block_repetition(
paradigm, [subject], blocks[test_idx], repetitions
)

pipe = make_pipeline(ERPCovariances(estimator="lwf"), MDM())
pipe.fit(X_train, y_train)
y_pred = pipe.predict(X_test)

y_test = le.transform(y_test)
y_pred = le.transform(y_pred)

auc.append(roc_auc_score(y_test, y_pred))

# stock scores
scores_subject.append(np.mean(auc))

scores.append(scores_subject)

df = pd.DataFrame(scores, columns=["tmax", "subject", "VR", "PC"])
print(df)
28 changes: 17 additions & 11 deletions moabb/datasets/braininvaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import shutil
import zipfile as z
from distutils.dir_util import copy_tree
from warnings import warn

import mne
import numpy as np
Expand Down Expand Up @@ -150,7 +151,7 @@ def _bi_get_subject_data(ds, subject): # noqa: C901
stim[idx_nontarget] = 1
X = np.concatenate([S, stim[None, :]])
sfreq = 512
elif ds.code == "Virtual Reality dataset":
elif ds.code == "P300-VR":
data = loadmat(os.path.join(file_path, os.listdir(file_path)[0]))["data"]

chnames = [
Expand Down Expand Up @@ -187,7 +188,7 @@ def _bi_get_subject_data(ds, subject): # noqa: C901
verbose=False,
)

if not ds.code == "Virtual Reality dataset":
if not ds.code == "P300-VR":
raw = mne.io.RawArray(data=X, info=info, verbose=False)
raw.set_montage(make_standard_montage("standard_1020"))

Expand Down Expand Up @@ -388,15 +389,16 @@ def _bi_data_path( # noqa: C901
)
for i in range(1, 5)
]
elif ds.code == "Virtual Reality dataset":
elif ds.code == "P300-VR":
subject_paths = []
url = "{:s}subject_{:02d}_{:s}.mat".format(
VIRTUALREALITY_URL,
subject,
"VR" if ds.virtual_reality else ds.personal_computer,
)
file_path = dl.data_path(url, "VIRTUALREALITY")
subject_paths.append(file_path)
if ds.virtual_reality:
url = "{:s}subject_{:02d}_{:s}.mat".format(VIRTUALREALITY_URL, subject, "VR")
file_path = dl.data_path(url, "VIRTUALREALITY")
subject_paths.append(file_path)
if ds.personal_computer:
url = "{:s}subject_{:02d}_{:s}.mat".format(VIRTUALREALITY_URL, subject, "PC")
file_path = dl.data_path(url, "VIRTUALREALITY")
subject_paths.append(file_path)

return subject_paths

Expand Down Expand Up @@ -868,6 +870,10 @@ def __init__(self, virtual_reality=False, screen_display=True):

self.virtual_reality = virtual_reality
self.personal_computer = screen_display
if not self.virtual_reality and not self.personal_computer:
warn(
"[P300-VR dataset] virtual_reality and screen display are False. No data will be downloaded, unless you change these parameters after initialization."
)

def _get_single_subject_data(self, subject):
"""return data for a single subject"""
Expand All @@ -880,7 +886,7 @@ def data_path(

def get_block_repetition(self, paradigm, subjects, block_list, repetition_list):
"""Select data for all provided subjects, blocks and repetitions.
Each subject has 5 blocks of 12 repetitions.
Each subject has 12 blocks of 5 repetitions.

The returned data is a dictionary with the folowing structure::

Expand Down
11 changes: 11 additions & 0 deletions moabb/tests/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,17 @@ def __init__(self, *args, **kwargs):
def test_canary(self):
assert VirtualReality() is not None

def test_warning_if_parameters_false(self):
with self.assertWarns(UserWarning):
VirtualReality(virtual_reality=False, screen_display=False)

def test_data_path(self):
ds = VirtualReality(virtual_reality=True, screen_display=True)
data_path = ds.data_path(1)
assert len(data_path) == 2
assert "subject_01_VR.mat" in data_path[0]
assert "subject_01_PC.mat" in data_path[1]

def test_get_block_repetition(self):
ds = FakeVirtualRealityDataset()
subject = 5
Expand Down