Merge pull request #51 from JuliaAI/needs-training-losses
Add `needs_loss` and `needs_training_losses` control traits
ablaom authored Nov 12, 2021
2 parents 08d1089 + fe416ae commit 1bff2cc
Showing 11 changed files with 173 additions and 130 deletions.
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,14 +1,14 @@
name = "IterationControl"
uuid = "b3c1a2ee-3fec-4384-bf48-272ea71de57c"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.5.0"
version = "0.5.1"

[deps]
EarlyStopping = "792122b4-ca99-40de-a6bc-6742525f08b6"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

[compat]
EarlyStopping = "0.2"
EarlyStopping = "0.3"
julia = "1"

[extras]
80 changes: 48 additions & 32 deletions README.md
@@ -54,7 +54,7 @@ package:
using IterationControl
IterationControl.train!(model::SquareRooter, n) = train!(model, n) # lifting
```
By definitiion, the lifted `train!` has the same functionality as the original one:
By definition, the lifted `train!` has the same functionality as the original one:

```julia
model = SquareRooter(9)
@@ -76,7 +76,7 @@ julia> IterationControl.train!(model, Step(2), NumberLimit(3), Info(m->m.root));
Here each control is repeatedly applied in sequence until one of them
triggers a stop. The first control `Step(2)` says, "Train the model
two more iterations"; the second asks, "Have I been applied 3 times
yet?", signalling a stop (at the end of the current control cycle) if
yet?", signaling a stop (at the end of the current control cycle) if
so; and the third logs the value of the function `m -> m.root`,
evaluated on `model`, to `Info`. In this example only the second
control can terminate model iteration.
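
To make the control semantics concrete, here is a rough manual equivalent of the controlled call above (an illustrative sketch only, not how the package implements it):

```julia
# Sketch: roughly what the controlled call above amounts to for this model.
model = SquareRooter(9)
for cycle in 1:3            # NumberLimit(3): at most three control cycles
    train!(model, 2)        # Step(2): two more iterations per cycle
    @info model.root        # Info(m -> m.root): log the current root estimate
end
```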
@@ -85,7 +85,7 @@ If `model` admits a method returning a loss (in this case the
difference between `x` and the square of `root`) then we can lift
that method to `IterationControl.loss` to enable control using
loss-based stopping criteria, such as a loss threshold. In the
demonstation below, we also include a callback:
demonstration below, we also include a callback:
```julia
model = SquareRooter(4)
@@ -100,9 +100,9 @@ losses = Float64[]
callback(model) = push!(losses, loss(model))

julia> IterationControl.train!(model,
Step(1),
Threshold(0.0001),
Callback(callback));
Step(1),
Threshold(0.0001),
Callback(callback));
[ Info: Stop triggered by Threshold(0.0001) stopping criterion.

julia> losses
@@ -111,7 +111,7 @@ julia> losses
3.716891878724482e-7
```
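
The lifting of the loss method is not shown in this hunk; judging from the surrounding text it looks something like the following sketch (the exact definition appears earlier in the README; here the loss is taken to be the difference between `x` and the square of `root`):

```julia
# Sketch only -- assumed from the description above:
loss(model::SquareRooter) = abs(model.x - model.root^2)
IterationControl.loss(model::SquareRooter) = loss(model)  # lifting
```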
In many appliations to machine learning, "loss" will be an
In many applications to machine learning, "loss" will be an
out-of-sample loss, computed after some iterations. If `model`
additionally generates user-inspectable "training losses" (one per
iteration) then similarly lifting the appropriate access function to
@@ -169,12 +169,12 @@ julia> last(reports[2])
julia> last(reports[2]).loss
0.1417301038062284
```
## Controls provided
Controls are repeatedly applied in sequence until a control triggers a
stop. Each control type has a detailed doc-string. sBelow is a short
stop. Each control type has a detailed doc-string. Below is a short
summary, with some advanced options omitted.
control | description | enabled if these are overloaded | can trigger a stop | notation in Prechelt
@@ -216,8 +216,6 @@ wrapper | description
> Table 2. Wrapped controls
## Access to model through a wrapper
Note that functions ordinarily applied to `model` by some control
@@ -229,10 +227,10 @@ appropriately overloaded.
## Implementing new controls
There is no abstract control type; any object can be a
control. Behaviour is implemented using a functional style interface
with four methods. Only the first two are compulsory (the `done` and
`takedown` fallbacks always return `false` and `NamedTuple()`
respectively.):
control. Behavior is implemented using a functional style interface
with six methods. Only the first two are compulsory (the fallbacks for
`done`, `needs_loss` and `needs_training_losses` always return `false`,
while the `takedown` fallback returns `NamedTuple()`):
```julia
update!(control, model, verbosity, n) -> state # initialization
@@ -242,31 +240,49 @@ takedown(control, verbosity, state) -> human_readable_named_tuple
```
Here `n` is the control cycle count, i.e., one more than the
number of completed control cylcles.
number of completed control cycles.
If it is nonsensical to apply `control` to any model for which
`loss(model)` has not been overloaded, and we want an error thrown
when this is attempted, then declare `needs_loss(control::MyControl) =
true`. Otherwise `control` is applied anyway, and
`loss`, if called, returns `nothing`.
Here's how `IterationControl.train!` calls these methods:
A second trait `needs_training_losses(control)` serves an analogous
purpose for training losses.
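
For concreteness, here is a sketch of a hypothetical control, `LossLogger` (a name invented for illustration, not part of the package), which logs the loss on every application, stops after a fixed number of applications, and declares `needs_loss` so that an error is thrown if `loss` has not been lifted:

```julia
# Hypothetical control, for illustration only.
struct LossLogger
    limit::Int   # number of applications before a stop is triggered
end

# first application: initialize the state
function IterationControl.update!(c::LossLogger, model, verbosity, n)
    verbosity > 0 && @info "loss: $(IterationControl.loss(model))"
    return (count = 1,)
end

# subsequent applications: update the state
function IterationControl.update!(c::LossLogger, model, verbosity, n, state)
    verbosity > 0 && @info "loss: $(IterationControl.loss(model))"
    return (count = state.count + 1,)
end

# signal a stop once the limit is reached
IterationControl.done(c::LossLogger, state) = state.count >= c.limit

# user-inspectable outcomes returned by `train!`
IterationControl.takedown(c::LossLogger, verbosity, state) =
    (applications = state.count,)

# nonsensical without a lifted `loss`, so insist on one:
IterationControl.needs_loss(c::LossLogger) = true
```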
Here's a simplified version of how `IterationControl.train!` calls
these methods:
```julia
function train!(model, controls...; verbosity::Int=1)

control = composite(controls...)

# before training:
verbosity > 1 && @info "Using these controls: $(flat(control)). "
control = composite(controls...)

# first training event:
n = 1 # counts control cycles
state = update!(control, model, verbosity, n)
finished = done(control, state)
# before training:
verbosity > 1 && @info "Using these controls: $(flat(control)). "

# subsequent training events:
while !finished
n += 1
state = update!(control, model, verbosity, n, state)
finished = done(control, state)
# first training event:
n = 1 # counts control cycles
state = update!(control, model, verbosity, n)
finished = done(control, state)

# checks that model supports control:
if needs_loss(control) && loss(model) === nothing
throw(ERR_NEEDS_LOSS)
end
if needs_training_losses(control) && training_losses(model) === nothing
throw(ERR_NEEDS_TRAINING_LOSSES)
end

# subsequent training events:
while !finished
n += 1
state = update!(control, model, verbosity, n, state)
finished = done(control, state)
end

# finalization:
return takedown(control, verbosity, state)
# finalization:
return takedown(control, verbosity, state)
end
```
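
So, for example, a call like the following (hypothetical; it assumes a model type for which only `train!` has been lifted) now fails early with an informative error, rather than the loss-dependent control silently receiving `nothing`:

```julia
# Hypothetical: `bare_model` lifts `IterationControl.train!` but not
# `IterationControl.loss`:
IterationControl.train!(bare_model, Step(1), WithLossDo(), NumberLimit(2))
# errors, because `needs_loss(WithLossDo()) == true` while
# `IterationControl.loss(bare_model) === nothing`
```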
21 changes: 17 additions & 4 deletions src/api.jl
@@ -1,5 +1,5 @@
# -------------------------------------------------------------------
# # MACHINE METHODS
# # MODEL METHODS

# *models* are externally defined objects with certain functionality
# that is exposed by overloading these methods:
@@ -21,7 +21,7 @@ err_ingest(model) =
train!(model, Δn) = throw(err_train(model))


# ## REQUIRED FOR SOME CONTROL
# ## REQUIRED FOR SOME CONTROLS

# extract a single numerical estimate of `models`'s performance such
# as an out-of-sample loss; smaller is understood to be better:
@@ -51,7 +51,7 @@ expose(model, ::Val{false}) = expose(model)
# # CONTROL METHODS

# compulsory: `update!`
# optional: `done`, `takedown`
# optional: `done`, `takedown`, `needs_loss`, `needs_training_losses`

# called after first training event; returns initialized control
# "state":
@@ -63,10 +63,23 @@ update!(control, model, verbosity, n, state) = state
# should we stop?
done(control, state) = false

# What to do after this control, or some other control, has triggered
# What to do after `control`, or some other control, has triggered
# a stop. Returns user-inspectable outcomes associated with the
# control's applications (separate from logging). This should be a
# named tuple, except for composite controls which return a tuple of
# named tuples (see composite_controls.jl):
takedown(control, verbosity, state) = NamedTuple()

# If it is nonsensical to apply `control` to any model for which
# `training_losses(model)` has not been overloaded and we want an
# error thrown when this is attempted, then overload this trait to
# take the value `true`. Otherwise `control` is applied anyway, and
# `training_losses`, if called, returns `nothing`.
needs_training_losses(control) = false

# If it is nonsensical to apply `control` to any model for which
# `loss(model)` has not been overloaded and we want an error thrown
# when this is attempted, then overload this trait to take the value
# `true`. Otherwise `control` is applied anyway, and
# `loss`, if called, returns `nothing`.
needs_loss(control) = false
21 changes: 15 additions & 6 deletions src/composite_controls.jl
@@ -1,4 +1,4 @@
struct CompositeControl{A,B}
struct CompositeControl{A,B}
a::A
b::B
function CompositeControl(a::A, b::B) where {A, B}
@@ -29,7 +29,7 @@ update!(c::CompositeControl, m, v, n, state) =
b = update!(c.b, m, v, n, state.b))


## RECURSION TO FLATTEN A CONTROL OR ITS STATE
# # RECURSION TO FLATTEN A CONTROL OR ITS STATE

flat(state) = (state,)
flat(state::NamedTuple{(:a,:b)}) = tuple(flat(state.a)..., flat(state.b)...)
@@ -42,27 +42,27 @@ _in(c::Any, d::CompositeControl) = c in flat(d)
_in(::CompositeControl, ::Any) = false


## DISPLAY
# # DISPLAY

function Base.show(io::IO, c::CompositeControl)
list = join(string.(flat(c)), ", ")
print(io, "CompositeControl($list)")
end


## RECURSION TO DEFINE `done`
# # RECURSION TO DEFINE `done`

# fallback for atomic controls:
_done(control, state, old_done) = old_done || done(control, state)

# composite:
_done(c::CompositeControl, state, old_done) =
_done(c.a, state.a, _done(c.b, state.b, old_done))

done(c::CompositeControl, state) = _done(c, state, false)


## RECURSION TO DEFINE `takedown`
# # RECURSION TO DEFINE `takedown`

# fallback for atomic controls:
function _takedown(control, v, state, old_takedown)
Expand All @@ -76,3 +76,12 @@ _takedown(c::CompositeControl, v, state, old_takedown) =
_takedown(c.a, v, state.a, old_takedown))

takedown(c::CompositeControl, v, state) = _takedown(c, v, state, ())


# # TRAITS

for ex in [:needs_loss, :needs_training_losses]
quote
$ex(c::CompositeControl) = any($ex, flat(c))
end |> eval
end
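
In other words, a composite control needs a lifted `loss` (or `training_losses`) exactly when at least one of its components does. A hypothetical check, using the two-argument `CompositeControl` constructor defined above:

```julia
c = CompositeControl(Step(1), WithLossDo())
needs_loss(c)              # true  -- WithLossDo requires a lifted `loss`
needs_training_losses(c)   # false -- neither component requires training losses
```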
4 changes: 2 additions & 2 deletions src/controls.jl
@@ -291,7 +291,7 @@ WithLossDo(; f=x->@info("loss: $x"), kwargs...) = WithLossDo(f, kwargs...)
"if the value returned by `f` is `true`, logging the "*
"`stop_message` if specified. ")

EarlyStopping.needs_loss(::Type{<:WithLossDo}) = true
needs_loss(::WithLossDo) = true

function update!(c::WithLossDo,
model,
@@ -347,7 +347,7 @@ WithTrainingLossesDo(; f=v->@info("training: $v"), kwargs...) =
"if the value returned by `f` is `true`, logging the "*
"`stop_message` if specified. ")

EarlyStopping.needs_training_losses(::Type{<:WithTrainingLossesDo}) = true
needs_training_losses(::WithTrainingLossesDo) = true

function update!(c::WithTrainingLossesDo,
model,
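
For reference, these two controls are typically composed into a call like the following sketch (based on the README examples above; it assumes `model` lifts both `loss` and `training_losses`):

```julia
IterationControl.train!(model,
                        Step(2),
                        WithLossDo(),            # requires a lifted `loss`
                        WithTrainingLossesDo(),  # requires lifted `training_losses`
                        NumberLimit(5))
```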