Skip to content

Commit

Permalink
[Feature Addition] Add callback mechanism and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
titu1994 committed Oct 23, 2018
1 parent 4361e50 commit cd8ab2e
Show file tree
Hide file tree
Showing 7 changed files with 680 additions and 17 deletions.
4 changes: 2 additions & 2 deletions examples/tensorflow/mini_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,6 @@ def evaluate_model(session, worker_id, parameters):

samples = shac.predict(5)

# generated a batch of 5 classifiers which can now be evaluated
# generated a batch of 5 samples which can now be evaluated
for i, sample in enumerate(samples):
print("Sample", i + 1, ": ", sample)
print("Sample", i + 1, ": ", sample)
325 changes: 325 additions & 0 deletions pyshac/config/callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,325 @@
""" Callback module is adapted from the Keras library """
import six
import csv
import io
import os

import numpy as np
from collections import Iterable
from collections import OrderedDict

# compatible with both Python 2 and 3
try:
FileNotFoundError
except NameError:
FileNotFoundError = IOError


def get_history(callbacks):
"""
Gets the History callback from a list of callbacks.
# Argumetns:
callbacks (list | CallbackList): a list of callbacks
# Returns:
A History callback object or None if it was not found.
"""
history = None

for c in callbacks:
if isinstance(c, History):
history = c
break

return history


class Callback(object):
"""
Abstract base class used to build new callbacks.
# Properties
engine: instance of a PySHAC Engine.
Reference of the model being trained.
The `logs` dictionary that callback methods
take as argument will contain keys for quantities relevant to
the current batch or epoch.
"""
def __init__(self):
self.engine = None

def set_engine(self, engine):
"""
Sets an instance of a PySHAC Engine.
# Arguments:
engine (AbstractSHAC): A concrete implementation of the
SHAC engine.
"""
self.engine = engine

def on_train_begin(self, logs=None):
pass

def on_train_end(self, logs=None):
pass

def on_epoch_begin(self, epoch, logs=None):
pass

def on_epoch_end(self, epoch, logs=None):
pass

def on_evaluation_begin(self, params, logs=None):
pass

def on_evaluation_ended(self, evaluations, logs=None):
pass

def on_dataset_changed(self, dataset, logs=None):
pass


class CallbackList(Callback):
"""
Container abstracting a list of callbacks.
Automatically creates a History callback if not provided in
list of callbacks.
# Arguments
callbacks (list | None): List of `Callback` instances.
"""
def __init__(self, callbacks=None):
super(CallbackList, self).__init__()

callbacks = callbacks or []

# check if list has History callback in it
history = get_history(callbacks)
if history is None:
callbacks.append(History())

self.callbacks = [c for c in callbacks]

def append(self, callback):
"""
Appends an additional callback to the callback list.
# Arguments:
callback (Callback):
"""
self.callbacks.append(callback)

def set_engine(self, engine):
"""
Sets an instance of a PySHAC Engine.
# Arguments:
engine (AbstractSHAC): A concrete implementation of the
SHAC engine.
"""
for callback in self.callbacks:
callback.set_engine(engine)

def on_train_begin(self, logs=None):
"""
Called at the beginning of training.
# Arguments
logs (dict | None): dictionary of logs.
"""
logs = logs or {}
for callback in self.callbacks:
callback.on_train_begin(logs)

def on_train_end(self, logs=None):
"""
Called at the end of training.
# Arguments
logs (dict | None): dictionary of logs.
"""
logs = logs or {}
for callback in self.callbacks:
callback.on_train_end(logs)

def on_epoch_begin(self, epoch, logs=None):
"""
Called at the start of an epoch.
# Arguments
epoch (int): index of epoch.
logs (dict | None): dictionary of logs.
"""
logs = logs or {}
for callback in self.callbacks:
callback.on_epoch_begin(epoch, logs)

def on_epoch_end(self, epoch, logs=None):
"""
Called at the end of an epoch.
# Arguments
epoch (int): index of epoch.
logs (dict | None): dictionary of logs.
"""
logs = logs or {}
for callback in self.callbacks:
callback.on_epoch_end(epoch, logs)

def on_evaluation_begin(self, params, logs=None):
"""
Called before the generated parameters are evaluated.
# Arguments:
params (list(OrderedDict)): A list of OrderedDicts,
such that each item is a dictionary of the names
and sampled values of a HyperParemeterList.
logs (dict | None): dictionary of logs.
"""
logs = logs or {}
for callback in self.callbacks:
callback.on_evaluation_begin(params, logs)

def on_evaluation_ended(self, evaluations, logs=None):
"""
Called after the generated parameters are evaluated.
# Arguments:
evaluations (list(float)): A list of floating point
values, corresponding to the provided parameter
settings.
logs (dict | None): dictionary of logs.
"""
logs = logs or {}
for callback in self.callbacks:
callback.on_evaluation_ended(evaluations, logs)

def on_dataset_changed(self, dataset, logs=None):
"""
Called with the dataset maintained by the engine is
updated with new samples or data.
# Arguments:
dataset (Dataset): A Dataset object which contains
the history of sampled parameters and their
corresponding evaluation values.
logs (dict | None): dictionary of logs.
"""
logs = logs or {}
for callback in self.callbacks:
callback.on_dataset_changed(dataset, logs)

def __iter__(self):
return iter(self.callbacks)


class History(Callback):
"""
Callback that records events into a `History` object.
This callback is automatically applied to
every engine. The `History` object
gets returned by the `fit` or `fit_dataset` methods.
"""
def __init__(self):
super(History, self).__init__()

def on_train_begin(self, logs=None):
self.epochs = []
self.history = logs or {}

def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
self.epochs.append(epoch)

for k, v in logs.items():
self.history.setdefault(k, []).append(v)


class CSVLogger(Callback):
"""
Callback that streams epoch results to a csv file.
Supports all values that can be represented as a string,
including 1D iterables such as np.ndarray.
# Example
```python
csv_logger = CSVLogger('training.log')
shac.fit(evaluation_function, callbacks=[csv_logger])
```
# Arguments
filename (str): filename of the csv file, e.g. 'run/log.csv'.
separator (str): string used to separate elements in the csv file.
append (bool): True: append if file exists (useful for continuing
training). False: overwrite existing file,
"""
def __init__(self, filename, separator=',', append=False):
super(CSVLogger, self).__init__()
self.sep = separator
self.filename = filename
self.append = append
self.writer = None
self.keys = None
self.append_header = True
if six.PY2:
self.file_flags = 'b'
self._open_args = {}
else:
self.file_flags = ''
self._open_args = {'newline': '\n'}

def on_train_begin(self, logs=None):
if self.append:
if os.path.exists(self.filename):
with open(self.filename, 'r' + self.file_flags) as f:
self.append_header = not bool(len(f.readline()))
mode = 'a'
else:
mode = 'w'
self.csv_file = io.open(self.filename,
mode + self.file_flags,
**self._open_args)

def on_epoch_end(self, epoch, logs=None):
logs = logs or {}

def handle_value(k):
is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
if isinstance(k, six.string_types):
return k
elif isinstance(k, Iterable) and not is_zero_dim_ndarray:
return '"[%s]"' % (', '.join(map(str, k)))
else:
return k

if self.keys is None:
self.keys = sorted(logs.keys())

if not self.writer:
class CustomDialect(csv.excel):
delimiter = self.sep

fieldnames = ['epoch'] + self.keys

self.writer = csv.DictWriter(self.csv_file,
fieldnames=fieldnames,
dialect=CustomDialect)

if self.append_header:
self.writer.writeheader()

row_dict = OrderedDict({'epoch': epoch})
row_dict.update((key, handle_value(logs[key])) for key in self.keys)
self.writer.writerow(row_dict)
self.csv_file.flush()

def on_train_end(self, logs=None):
self.csv_file.close()
self.writer = None
Loading

0 comments on commit cd8ab2e

Please sign in to comment.