Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Saving the model #401

Merged
merged 38 commits into from
Jun 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
71d5722
Updating README.md
bruAristimunha May 28, 2023
e0393b5
Merge remote-tracking branch 'origin/develop' into develop
bruAristimunha Jun 7, 2023
8b9aa2a
Merge remote-tracking branch 'origin/develop' into develop
bruAristimunha Jun 7, 2023
945349b
Merge remote-tracking branch 'origin/develop' into develop
bruAristimunha Jun 7, 2023
5defea6
Adding new saving
bruAristimunha Jun 19, 2023
a3d7293
Adding new saving model
bruAristimunha Jun 19, 2023
432e128
Adding new functions
bruAristimunha Jun 19, 2023
5748345
Merge branch 'develop' into saving_the_model
bruAristimunha Jun 19, 2023
c40cff8
updating the models doc
bruAristimunha Jun 19, 2023
84cb1f0
Merge remote-tracking branch 'origin/saving_the_model' into saving_th…
bruAristimunha Jun 19, 2023
1cac452
updating the models and evaluations
bruAristimunha Jun 19, 2023
6697c88
Update moabb/evaluations/evaluations.py
bruAristimunha Jun 19, 2023
bd03455
Generatic type
bruAristimunha Jun 19, 2023
3cbc1d0
adding if
bruAristimunha Jun 20, 2023
c0ba05d
Adding saving the best and changing the saving
bruAristimunha Jun 20, 2023
5e3252f
Solving Parallel and Saving Model
carraraig Jun 20, 2023
d1e3d00
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 20, 2023
9d273c1
Adding if to hdf5_path is None
bruAristimunha Jun 20, 2023
e17aacd
Solving the new path
bruAristimunha Jun 20, 2023
db3f4b9
Adding new ifs
bruAristimunha Jun 20, 2023
057367c
Returning the Results folder
bruAristimunha Jun 20, 2023
5f44962
Solve Saved model on Pytorch
carraraig Jun 20, 2023
90feff4
Removing Keras models saving
bruAristimunha Jun 20, 2023
ad30335
Removing Keras models saving
bruAristimunha Jun 20, 2023
f383e14
Updating model_check
bruAristimunha Jun 20, 2023
0365ea3
Updated Saved model in Pytorch, second methodology if is a Skorch model
carraraig Jun 21, 2023
df6e0d1
Added Saved Model on Keras and Pytorch
carraraig Jun 21, 2023
ef6afff
Example Load model
carraraig Jun 21, 2023
161aad1
Updating the save model, optimizing the code
bruAristimunha Jun 21, 2023
3dab773
Fixing saving function
bruAristimunha Jun 21, 2023
4511edc
renaming model to step
bruAristimunha Jun 21, 2023
e16529b
Updating the tutorial
bruAristimunha Jun 21, 2023
9a80e76
Merge branch 'develop' into saving_the_model
bruAristimunha Jun 22, 2023
6d1584f
Updating the path
bruAristimunha Jun 22, 2023
c361bbf
Merge remote-tracking branch 'origin/saving_the_model' into saving_th…
bruAristimunha Jun 22, 2023
9e4ed4c
Adding new test and fix __init__.py
bruAristimunha Jun 22, 2023
f4a6abd
Adding new tests
bruAristimunha Jun 22, 2023
89df3ee
Updating whats new file
bruAristimunha Jun 22, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/source/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ Enhancements
- Adding second deployment of the documentation (:gh:`374` by `Bruno Aristimunha`_)
- Adding Parallel evaluation for :func:`moabb.evaluations.WithinSessionEvaluation` , :func:`moabb.evaluations.CrossSessionEvaluation` (:gh:`364` by `Bruno Aristimunha`_)
- Add example with VirtualReality BrainInvaders dataset (:gh:`393` by `Gregoire Cattan`_ and `Pedro L. C. Rodrigues`_)
- Adding saving option for the models (:gh:`401` by `Bruno Aristimunha`_ and `Igor Carrara`_)
- Adding example to load different type of models (:gh:`401` by `Bruno Aristimunha`_ and `Igor Carrara`_)
- Add resting state paradigm with dataset and example (:gh:`400` by `Gregoire Cattan`_ and `Pedro L. C. Rodrigues`_)

Bugs
Expand All @@ -33,6 +35,7 @@ Bugs
- Restore 3 subject from Cho2017 (:gh:`392` by `Igor Carrara`_ and `Sylvain Chevallier`_)
- Correct downloading with VirtualReality BrainInvaders dataset (:gh:`393` by `Gregoire Cattan`_)
- Rename event `substraction` to `subtraction` in :func:`moabb.dataset.Shin2017B` (:gh:`397` by `Pierre Guetschel`_)
- Fixing issue with parallel evaluation (:gh:`401` by `Bruno Aristimunha`_ and `Igor Carrara`_)

API changes
~~~~~~~~~~~
Expand Down
4 changes: 2 additions & 2 deletions examples/advanced_examples/plot_grid_search_withinsession.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@
path,
"GridSearch_WithinSession",
"001-2014",
"subject1",
"1",
"session_E",
"GridSearchEN",
"Grid_Search_WithinSession.pkl",
Expand All @@ -165,7 +165,7 @@
path,
"GridSearch_WithinSession",
"001-2014",
"subject1",
"1",
"session_T",
"GridSearchEN",
"Grid_Search_WithinSession.pkl",
Expand Down
123 changes: 123 additions & 0 deletions examples/plot_load_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""
==============================================
Load Model (Scikit, Pytorch, Keras) with MOABB
==============================================
This example shows how to use load the pretrained pipeline in MOABB.
"""
# Authors: Igor Carrara <[email protected]>
#
# License: BSD (3-clause)

from pickle import load

import keras
import torch
from braindecode import EEGClassifier
from braindecode.models import EEGInception
from scikeras.wrappers import KerasClassifier
from sklearn.pipeline import Pipeline
from skorch.callbacks import EarlyStopping, EpochScoring
from skorch.dataset import ValidSplit

from moabb import set_log_level
from moabb.pipelines.features import StandardScaler_Epoch
from moabb.pipelines.utils_pytorch import BraindecodeDatasetLoader, InputShapeSetterEEG
from moabb.utils import setup_seed


set_log_level("info")

###############################################################################
# In this example, we will use the results computed by the following examples
#
# - plot_benchmark_
# - plot_benchmark_braindecode_
# - plot_benchmark_DL_
# ---------------------

# Set up reproducibility of Tensorflow and PyTorch
setup_seed(42)

###############################################################################
# Loading the Scikit-learn pipelines

with open(
"./results/Models_WithinSession/Zhou 2016/1/session_0/CSP + SVM/fitted_model_best.pkl",
"rb",
) as pickle_file:
CSP_SVM_Trained = load(pickle_file)

###############################################################################
# Loading the Keras model
# We load the single Keras model, if we want we can set in the exact same pipeline.

model_Keras = keras.models.load_model(
"./results/Models_WithinSession/001-2014/1/session_E/Keras_DeepConvNet/kerasdeepconvnet_fitted_model_best.h5"
)
# Now we need to instantiate a new SciKeras object since we only saved the Keras model
Keras_DeepConvNet_Trained = KerasClassifier(model_Keras)
# Create the pipelines


pipes_keras = Pipeline(
[
("StandardScaler_Epoch", StandardScaler_Epoch),
("Keras_DeepConvNet_Trained", Keras_DeepConvNet_Trained),
]
)


###############################################################################
# Loading the PyTorch model

# Set EEG Inception model
model = EEGInception(in_channels=22, n_classes=2)

# Hyperparameter
LEARNING_RATE = 0.0001
WEIGHT_DECAY = 0
BATCH_SIZE = 64
SEED = 42
VERBOSE = 1
EPOCH = 2
PATIENCE = 3

# Define a Skorch classifier
clf = EEGClassifier(
module=model,
criterion=torch.nn.CrossEntropyLoss,
optimizer=torch.optim.Adam,
optimizer__lr=LEARNING_RATE,
batch_size=BATCH_SIZE,
max_epochs=EPOCH,
train_split=ValidSplit(0.2, random_state=SEED),
callbacks=[
EarlyStopping(monitor="valid_loss", patience=PATIENCE),
EpochScoring(
scoring="accuracy", on_train=True, name="train_acc", lower_is_better=False
),
EpochScoring(
scoring="accuracy", on_train=False, name="valid_acc", lower_is_better=False
),
InputShapeSetterEEG(
params_list=["in_channels", "input_window_samples", "n_classes"],
),
],
verbose=VERBOSE, # Not printing the results for each epoch
)

clf.initialize()

f_params = "./results/Models_CrossSession/001-2014/1/braindecode_EEGInception/EEGInception_fitted_best_model.pkl"
f_optimizer = "./results/Models_CrossSession/001-2014/1/braindecode_EEGInception/EEGInception_fitted_best_optim.pkl"
f_history = "./results/Models_CrossSession/001-2014/1/braindecode_EEGInception/EEGInception_fitted_best_history.json"

clf.load_params(f_params=f_params, f_optimizer=f_optimizer, f_history=f_history)


# Create the dataset
create_dataset = BraindecodeDatasetLoader(drop_last_window=False)

# Create the pipelines
pipes_pytorch = Pipeline([("Braindecode_dataset", create_dataset), ("EEGInception", clf)])
8 changes: 7 additions & 1 deletion moabb/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def benchmark( # noqa: C901
overwrite=False,
output="./benchmark/",
n_jobs=-1,
n_jobs_evaluation=1,
plot=False,
contexts=None,
include_datasets=None,
Expand Down Expand Up @@ -85,6 +86,9 @@ def benchmark( # noqa: C901
Folder to store the analysis results
n_jobs: int
Number of threads to use for running parallel jobs
n_jobs_evaluation: int, default=1
Number of jobs for evaluation, processing in parallel the within session,
cross-session or cross-subject.
plot: bool
Plot results after computing
contexts: str
Expand Down Expand Up @@ -172,7 +176,8 @@ def benchmark( # noqa: C901
datasets=d,
random_state=42,
hdf5_path=results,
n_jobs=1,
n_jobs=n_jobs,
n_jobs_evaluation=n_jobs_evaluation,
overwrite=overwrite,
return_epochs=True,
)
Expand All @@ -192,6 +197,7 @@ def benchmark( # noqa: C901
random_state=42,
hdf5_path=results,
n_jobs=n_jobs,
n_jobs_evaluation=n_jobs_evaluation,
overwrite=overwrite,
)
paradigm_results = context.process(
Expand Down
1 change: 1 addition & 0 deletions moabb/evaluations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@
CrossSubjectEvaluation,
WithinSessionEvaluation,
)
from .utils import create_save_path, save_model_cv, save_model_list
Loading