Skip to content

Commit

Permalink
Parallel evaluation (Cross-Session, Within-Session) (#364)
Browse files Browse the repository at this point in the history
* Changing the cross-session to include parallel

* Fixing pass with return

* Changing the evaluations.py

* Changing the Within evaluations.py

* Reverting

* Reverting again

* Parallel WithinSession

* Updating the evaluation, removing the yield

* Updating the evaluation, removing the yield

* Changing the parameter to base evaluation

* Adding verbose as true

* Fixing the issue =)

* Updating the whats_new.rst file

---------

Co-authored-by: Sylvain Chevallier <[email protected]>
  • Loading branch information
bruAristimunha and sylvchev authored May 31, 2023
1 parent a4daa12 commit 79f342b
Show file tree
Hide file tree
Showing 3 changed files with 200 additions and 162 deletions.
2 changes: 1 addition & 1 deletion docs/source/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Develop branch
Enhancements
~~~~~~~~~~~~

- None
- Adding Parallel evaluation for :func:`moabb.evaluations.WithinSessionEvaluation` , :func:`moabb.evaluations.CrossSessionEvaluation` (:gh:`364` by `Bruno Aristimunha`_)

Bugs
~~~~
Expand Down
6 changes: 5 additions & 1 deletion moabb/evaluations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ class BaseEvaluation(ABC):
If not None, can guarantee same seed for shuffling examples.
n_jobs: int, default=1
Number of jobs for fitting of pipeline.
n_jobs_evaluation: int, default=1
Number of jobs for evaluation, processing in parallel the within session,
cross-session or cross-subject.
overwrite: bool, default=False
If true, overwrite the results.
error_score: "raise" or numeric, default="raise"
Expand All @@ -52,6 +55,7 @@ def __init__(
datasets=None,
random_state=None,
n_jobs=1,
n_jobs_evaluation=1,
overwrite=False,
error_score="raise",
suffix="",
Expand All @@ -63,12 +67,12 @@ def __init__(
):
self.random_state = random_state
self.n_jobs = n_jobs
self.n_jobs_evaluation = n_jobs_evaluation
self.error_score = error_score
self.hdf5_path = hdf5_path
self.return_epochs = return_epochs
self.return_raws = return_raws
self.mne_labels = mne_labels

# check paradigm
if not isinstance(paradigm, BaseParadigm):
raise (ValueError("paradigm must be an Paradigm instance"))
Expand Down
Loading

0 comments on commit 79f342b

Please sign in to comment.