Skip to content

Commit

Permalink
Improve docs for callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
titu1994 committed Oct 23, 2018
1 parent cd8ab2e commit 67e45c7
Show file tree
Hide file tree
Showing 8 changed files with 111 additions and 1 deletion.
12 changes: 12 additions & 0 deletions docs/autogen.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pyshac
import pyshac.config.hyperparameters as hp
import pyshac.config.data as data
import pyshac.config.callbacks as callbacks
import pyshac.core.engine as optimizer
import pyshac.core.managed.tf_engine as tf_optimizer
import pyshac.core.managed.keras_engine as keras_optimizer
Expand Down Expand Up @@ -65,6 +66,17 @@
(data.Dataset, '*'),
],
},
{
'page': 'config/callbacks.md',
'classes': [
(callbacks.Callback, '*'),
(callbacks.History, []),
(callbacks.CSVLogger, []),
],
'functions': [
callbacks.get_history,
]
},
{
'page': 'core/engine.md',
'classes': [
Expand Down
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ pages:
- Config:
- Hyper Parameters: config/hyperparameters.md
- Datasets: config/data.md
- Callbacks: config/callbacks.md
- Core:
- SHAC: core/engine.md
- PyTorch SHAC: core/torch_engine.md
Expand Down
2 changes: 1 addition & 1 deletion pyshac/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
from pyshac.config.data import Dataset
from pyshac.core.engine import SHAC

__version__ = '0.3.0.5'
__version__ = '0.3.1'
67 changes: 67 additions & 0 deletions pyshac/config/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,24 +61,78 @@ def set_engine(self, engine):
self.engine = engine

def on_train_begin(self, logs=None):
"""
Called at the beginning of training.
# Arguments
logs (dict | None): dictionary of logs.
"""
pass

def on_train_end(self, logs=None):
"""
Called at the end of training.
# Arguments
logs (dict | None): dictionary of logs.
"""
pass

def on_epoch_begin(self, epoch, logs=None):
"""
Called at the start of an epoch.
# Arguments
epoch (int): index of epoch.
logs (dict | None): dictionary of logs.
"""
pass

def on_epoch_end(self, epoch, logs=None):
"""
Called at the end of an epoch.
# Arguments
epoch (int): index of epoch.
logs (dict | None): dictionary of logs.
"""
pass

def on_evaluation_begin(self, params, logs=None):
"""
Called before the generated parameters are evaluated.
# Arguments:
params (list(OrderedDict)): A list of OrderedDicts,
such that each item is a dictionary of the names
and sampled values of a HyperParemeterList.
logs (dict | None): dictionary of logs.
"""
pass

def on_evaluation_ended(self, evaluations, logs=None):
"""
Called after the generated parameters are evaluated.
# Arguments:
evaluations (list(float)): A list of floating point
values, corresponding to the provided parameter
settings.
logs (dict | None): dictionary of logs.
"""
pass

def on_dataset_changed(self, dataset, logs=None):
"""
Called with the dataset maintained by the engine is
updated with new samples or data.
# Arguments:
dataset (Dataset): A Dataset object which contains
the history of sampled parameters and their
corresponding evaluation values.
logs (dict | None): dictionary of logs.
"""
pass


Expand Down Expand Up @@ -229,10 +283,23 @@ def __init__(self):
super(History, self).__init__()

def on_train_begin(self, logs=None):
"""
Initializes the epoch list and history dictionary.
# Arguments:
logs (dict | None): dictionary of logs.
"""
self.epochs = []
self.history = logs or {}

def on_epoch_end(self, epoch, logs=None):
"""
Adds the current epoch's log values to the history.
# Arguments:
epoch (int): index of epoch.
logs (dict | None): dictionary of logs.
"""
logs = logs or {}
self.epochs.append(epoch)

Expand Down
15 changes: 15 additions & 0 deletions pyshac/core/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,11 @@ def fit(self, eval_fn, skip_cv_checks=False, early_stop=False, relax_checks=Fals
callbacks (list | None): Optional list of callbacks that are executed when
the engine is being trained. `History` callback is automatically added
for all calls to `fit`.
# Returns:
A `History` object which tracks all the important information
during training, and can be accessed using `history.history`
as a dictionary.
"""
num_epochs = self.total_budget // self.num_workers

Expand Down Expand Up @@ -361,6 +366,11 @@ def fit_dataset(self, dataset_path, skip_cv_checks=False, early_stop=False, pres
required number of samples by the engine.
FileNotFoundError: If the dataset is not available at the
provided filepath.
# Returns:
A `History` object which tracks all the important information
during training, and can be accessed using `history.history`
as a dictionary.
"""
if self.parameters is None:
raise ValueError("Parameter list cannot be `None` when training "
Expand Down Expand Up @@ -1149,6 +1159,11 @@ def fit(self, eval_fn, skip_cv_checks=False, early_stop=False, relax_checks=Fals
callbacks (list | None): Optional list of callbacks that are executed when
the engine is being trained. `History` callback is automatically added
for all calls to `fit`.
# Returns:
A `History` object which tracks all the important information
during training, and can be accessed using `history.history`
as a dictionary.
"""
return super(SHAC, self).fit(eval_fn, skip_cv_checks=skip_cv_checks,
early_stop=early_stop, relax_checks=relax_checks,
Expand Down
5 changes: 5 additions & 0 deletions pyshac/core/managed/keras_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,11 @@ def fit(self, eval_fn, skip_cv_checks=False, early_stop=False, relax_checks=Fals
callbacks (list | None): Optional list of callbacks that are executed when
the engine is being trained. `History` callback is automatically added
for all calls to `fit`.
# Returns:
A `History` object which tracks all the important information
during training, and can be accessed using `history.history`
as a dictionary.
"""
return super(KerasSHAC, self).fit(eval_fn, skip_cv_checks=skip_cv_checks,
early_stop=early_stop, relax_checks=relax_checks,
Expand Down
5 changes: 5 additions & 0 deletions pyshac/core/managed/tf_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,11 @@ def fit(self, eval_fn, skip_cv_checks=False, early_stop=False, relax_checks=Fals
callbacks (list | None): Optional list of callbacks that are executed when
the engine is being trained. `History` callback is automatically added
for all calls to `fit`.
# Returns:
A `History` object which tracks all the important information
during training, and can be accessed using `history.history`
as a dictionary.
"""
return super(TensorflowSHAC, self).fit(eval_fn, skip_cv_checks=skip_cv_checks,
early_stop=early_stop, relax_checks=relax_checks,
Expand Down
5 changes: 5 additions & 0 deletions pyshac/core/managed/torch_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@ def fit(self, eval_fn, skip_cv_checks=False, early_stop=False, relax_checks=Fals
callbacks (list | None): Optional list of callbacks that are executed when
the engine is being trained. `History` callback is automatically added
for all calls to `fit`.
# Returns:
A `History` object which tracks all the important information
during training, and can be accessed using `history.history`
as a dictionary.
"""
return super(TorchSHAC, self).fit(eval_fn, skip_cv_checks=skip_cv_checks,
early_stop=early_stop, relax_checks=relax_checks,
Expand Down

0 comments on commit 67e45c7

Please sign in to comment.