Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug for MotorImagery All events #337

Merged
merged 7 commits into from
Mar 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ Bugs
- Correct usage of name simplification function in analyze (:gh:`306` by `Divyesh Narayanan`_)
- Fix downloading path issue for Weibo2014 and Zhou2016, numy error in DemonsP300 (:gh:`315` by `Sylvain Chevallier`_)
- Fix unzip error for Huebner2017 and Huebner2018 (:gh:`318` by `Sylvain Chevallier`_)
- Fix n_classes when events set to None (:gh:`337` by `Igor Carrara`_ and `Sylvain Chevallier`_)

API changes
~~~~~~~~~~~
Expand Down
14 changes: 8 additions & 6 deletions moabb/paradigms/motor_imagery.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,25 +346,27 @@ class MotorImagery(SinglePass):
If not None, resample the eeg data with the sampling rate provided.
"""

def __init__(self, n_classes=2, **kwargs):
def __init__(self, n_classes=None, **kwargs):
super().__init__(**kwargs)
self.n_classes = n_classes

if self.events is None:
log.warning("Choosing from all possible events")
else:
elif self.n_classes is not None:
assert n_classes <= len(self.events), "More classes than events specified"

def is_valid(self, dataset):
ret = True
if not dataset.paradigm == "imagery":
ret = False
if self.events is None:
elif self.n_classes is None and self.events is None:
pass
elif self.events is None:
if not len(dataset.event_id) >= self.n_classes:
ret = False
else:
overlap = len(set(self.events) & set(dataset.event_id.keys()))
if not overlap >= self.n_classes:
if self.n_classes is not None and not overlap >= self.n_classes:
ret = False
return ret

Expand All @@ -373,8 +375,8 @@ def used_events(self, dataset):
if self.events is None:
for k, v in dataset.event_id.items():
out[k] = v
if len(out) == self.n_classes:
break
if self.n_classes is None:
self.n_classes = len(out)
else:
for event in self.events:
if event in dataset.event_id.keys():
Expand Down