* Updating README.md
* Adding new saving
* Adding new saving model
* Adding new functions
* updating the models doc
* updating the models and evaluations
* Update moabb/evaluations/evaluations.py
* Generatic type
* adding if
* Adding saving the best and changing the saving
* Solving Parallel and Saving Model
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Adding if to hdf5_path is None
* Solving the new path
* Adding new ifs
* Returning the Results folder
* Solve Saved model on Pytorch
* Removing Keras models saving
* Updating model_check
* Updated Saved model in Pytorch, second methodology if is a Skorch model
* Added Saved Model on Keras and Pytorch
* Example Load model
* Updating the save model, optimizing the code
* Fixing saving function
* renaming model to step
* Updating the tutorial
* Updating the path
* Adding new test and fix __init__.py
* Adding new tests
* Updating whats new file

---------

Co-authored-by: CARRARA Igor <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Commit 2b38df2 (1 parent: 2821954) · 8 changed files with 682 additions and 42 deletions.
""" | ||
============================================== | ||
Load Model (Scikit, Pytorch, Keras) with MOABB | ||
============================================== | ||
This example shows how to use load the pretrained pipeline in MOABB. | ||
""" | ||
# Authors: Igor Carrara <[email protected]> | ||
# | ||
# License: BSD (3-clause) | ||
|
||

from pickle import load

import keras
import torch
from braindecode import EEGClassifier
from braindecode.models import EEGInception
from scikeras.wrappers import KerasClassifier
from sklearn.pipeline import Pipeline
from skorch.callbacks import EarlyStopping, EpochScoring
from skorch.dataset import ValidSplit

from moabb import set_log_level
from moabb.pipelines.features import StandardScaler_Epoch
from moabb.pipelines.utils_pytorch import BraindecodeDatasetLoader, InputShapeSetterEEG
from moabb.utils import setup_seed


set_log_level("info")

###############################################################################
# In this example, we will use the results computed by the following examples:
#
# - plot_benchmark_
# - plot_benchmark_braindecode_
# - plot_benchmark_DL_
# ---------------------

# Set up reproducibility of TensorFlow and PyTorch
setup_seed(42)
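
###############################################################################
# Note on the results folder layout
# The paths used below follow the layout produced when the examples above are
# run with model saving enabled. This is an illustration inferred from the
# paths in this script; the exact layout depends on the evaluation and on the
# saving options used:
#
#     ./results/Models_WithinSession/<dataset>/<subject>/<session>/<pipeline>/...
#     ./results/Models_CrossSession/<dataset>/<subject>/<pipeline>/...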

###############################################################################
# Loading the Scikit-learn pipelines

with open(
    "./results/Models_WithinSession/Zhou 2016/1/session_0/CSP + SVM/fitted_model_best.pkl",
    "rb",
) as pickle_file:
    CSP_SVM_Trained = load(pickle_file)
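
# The loaded object is an already fitted scikit-learn pipeline, so it can be
# used directly for inference with the usual scikit-learn API. A minimal
# sketch, assuming a hypothetical array ``X_new`` of epochs shaped like the
# training data, e.g. (n_epochs, n_channels, n_times):
#
#     # y_pred = CSP_SVM_Trained.predict(X_new)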

###############################################################################
# Loading the Keras model
# We load the Keras model on its own; if we want, we can place it back into
# the exact same pipeline.

model_Keras = keras.models.load_model(
    "./results/Models_WithinSession/001-2014/1/session_E/Keras_DeepConvNet/kerasdeepconvnet_fitted_model_best.h5"
)
# Now we need to instantiate a new SciKeras object since we only saved the Keras model
Keras_DeepConvNet_Trained = KerasClassifier(model_Keras)
# Create the pipeline (the steps must be estimator instances)

pipes_keras = Pipeline(
    [
        ("StandardScaler_Epoch", StandardScaler_Epoch()),
        ("Keras_DeepConvNet_Trained", Keras_DeepConvNet_Trained),
    ]
)
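
# The rebuilt pipeline mirrors the one used at training time (epoch-wise
# standardisation followed by the trained DeepConvNet). Only the Keras model
# was restored from disk; the SciKeras wrapper around it is new, so, depending
# on the SciKeras version, it may need to be initialised or refitted before
# the pipeline can predict on new data. A hedged sketch, with ``X_new`` a
# hypothetical array of epochs:
#
#     # pipes_keras.predict(X_new)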

###############################################################################
# Loading the PyTorch model

# Set EEG Inception model
model = EEGInception(in_channels=22, n_classes=2)

# Hyperparameters
LEARNING_RATE = 0.0001
WEIGHT_DECAY = 0
BATCH_SIZE = 64
SEED = 42
VERBOSE = 1
EPOCH = 2
PATIENCE = 3

# Define a Skorch classifier
clf = EEGClassifier(
    module=model,
    criterion=torch.nn.CrossEntropyLoss,
    optimizer=torch.optim.Adam,
    optimizer__lr=LEARNING_RATE,
    batch_size=BATCH_SIZE,
    max_epochs=EPOCH,
    train_split=ValidSplit(0.2, random_state=SEED),
    callbacks=[
        EarlyStopping(monitor="valid_loss", patience=PATIENCE),
        EpochScoring(
            scoring="accuracy", on_train=True, name="train_acc", lower_is_better=False
        ),
        EpochScoring(
            scoring="accuracy", on_train=False, name="valid_acc", lower_is_better=False
        ),
        InputShapeSetterEEG(
            params_list=["in_channels", "input_window_samples", "n_classes"],
        ),
    ],
    verbose=VERBOSE,  # VERBOSE=1 prints the results for each epoch
)

# The classifier must be initialized before the saved parameters can be loaded
clf.initialize()

f_params = "./results/Models_CrossSession/001-2014/1/braindecode_EEGInception/EEGInception_fitted_best_model.pkl"
f_optimizer = "./results/Models_CrossSession/001-2014/1/braindecode_EEGInception/EEGInception_fitted_best_optim.pkl"
f_history = "./results/Models_CrossSession/001-2014/1/braindecode_EEGInception/EEGInception_fitted_best_history.json"

clf.load_params(f_params=f_params, f_optimizer=f_optimizer, f_history=f_history)
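
# ``clf`` now holds the trained module weights (f_params), the optimizer state
# (f_optimizer) and the per-epoch training history (f_history) that were saved
# to disk. A minimal sketch of how the restored history could be inspected
# (the exact keys depend on what was logged during training):
#
#     # print(clf.history[-1])  # metrics recorded for the last training epoch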

# Create the dataset
create_dataset = BraindecodeDatasetLoader(drop_last_window=False)

# Create the pipeline
pipes_pytorch = Pipeline([("Braindecode_dataset", create_dataset), ("EEGInception", clf)])
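
###############################################################################
# Reusing the restored pipelines
# Both ``pipes_keras`` and ``pipes_pytorch`` are ordinary scikit-learn
# pipelines, so they can be evaluated again with MOABB like freshly built
# pipelines. A minimal sketch, assuming a paradigm, a dataset and an
# evaluation are set up as in the benchmark examples listed above (the names
# here are hypothetical):
#
#     # evaluation = WithinSessionEvaluation(paradigm=paradigm, datasets=[dataset])
#     # results = evaluation.process({"Keras_DeepConvNet": pipes_keras})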