Skip to content

Commit

Permalink
Merge pull request #29 from ablaom/dev
Browse files Browse the repository at this point in the history
For a 0.2.0 release. Some name changes
  • Loading branch information
ablaom authored Mar 12, 2021
2 parents c03b65d + c9df35d commit 1c8690f
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 110 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "IterationControl"
uuid = "b3c1a2ee-3fec-4384-bf48-272ea71de57c"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.1.3"
version = "0.2.0"

[deps]
EarlyStopping = "792122b4-ca99-40de-a6bc-6742525f08b6"
Expand Down
16 changes: 8 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -167,22 +167,22 @@ control | description
`Callback(f=_->nothing)`| Call `f(model)` |`train!` | yes |
`TimeLimit(t=0.5)` | Stop after `t` hours |`train!` | yes |
`NumberLimit(n=100)` | Stop after `n` control cycles |`train!` | yes |
`NumberCount(f=n->@info(n))` | Call `f(n)` where `n` is the control cycle count |`train!` | yes |
`Loss(f=x->@info(x))` | Call `f(loss)` where `loss` is the current loss |`train!`, `loss` | yes |
`TrainingLosses(f=v->@info(v))`| Call `f(v)` where `v` is the current batch of training losses |`train!`, `training_loss` | yes |
`WithNumberDo(f=n->@info(n))` | Call `f(n)` where `n` is the control cycle count |`train!` | yes |
`WithLossDo(f=x->@info(x))` | Call `f(loss)` where `loss` is the current loss |`train!`, `loss` | yes |
`WithTrainingLossesDo(f=v->@info(v))`| Call `f(v)` where `v` is the current batch of training losses |`train!`, `training_loss` | yes |
`NotANumber()` | Stop when `NaN` encountered |`train!`, `loss` | yes |
`Threshold(value=0.0)` | Stop when `loss < value` |`train!`, `loss` | yes |
`GL(alpha=2.0)` | Stop after "Generalization Loss" exceeds `alpha` |`train!`, `loss` | yes | ``GL_α``
`GL(alpha=2.0)` | Stop after "Generalization WithLossDo" exceeds `alpha` |`train!`, `loss` | yes | ``GL_α``
`Patience(n=5)` | Stop after `n` consecutive loss increases |`train!`, `loss` | yes | ``UP_s``
`PQ(alpha=0.75, k=5)` | Stop after "Progress-modified GL" exceeds `alpha` |`train!`, `loss`, `training_losses`| yes | ``PQ_α``
`Data(data)` | Call `ingest!(model, item)` on the next `item` in the iterable `data`. |`train!`, `ingest!` | yes |
> Table 1. Atomic controls
**Stopping option.** All the following controls trigger a stop if the
provided function `f` returns `true` and `stop_if_true=true` is
specified in the constructor: `Callback`, `NumberCount`, `Loss`,
`TrainingLosses`.
> Table 1. Atomic controls
specified in the constructor: `Callback`, `WithNumberDo`, `WithLossDo`,
`WithTrainingLossesDo`.
There are also three control wrappers to modify a control's behavior:
Expand Down
6 changes: 3 additions & 3 deletions src/IterationControl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ const CONTROLS = [:Step,
:Warn,
:Error,
:Callback,
:Loss,
:TrainingLosses,
:NumberCount,
:WithLossDo,
:WithTrainingLossesDo,
:WithNumberDo,
:Data]
for criterion in subtypes(StoppingCriterion)
control = split(string(criterion), ".") |> last |> Symbol
Expand Down
121 changes: 64 additions & 57 deletions src/controls.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Step(; n=5) = Step(n)
@create_docs(Step,
header="Step(; n=1)",
example="Step(2)",
body="Train the model for `n` more iterations. "*
body="Train for `n` more iterations. "*
"Will never trigger a stop. ")

function update!(c::Step, model, verbosity, args...)
Expand All @@ -34,11 +34,11 @@ Info(; f::Function=identity) = Info(f)
@create_docs(Info,
header="Info(f=identity)",
example="Info(my_loss_function)",
body="Log at the `Info` level the value of `f(model)`, "*
"where `model` "*
body="Log at the `Info` level the value of `f(m)`, "*
"where `m` "*
"is the object being iterated. If "*
"`IterativeControl.expose(model)` has been overloaded, then "*
"log `f(expose(model))` instead.\n\n"*
"`IterativeControl.expose(m)` has been overloaded, then "*
"log `f(expose(m))` instead.\n\n"*
"Can be suppressed by setting the global verbosity level "*
"sufficiently low. \n\n"*
"See also [`Warn`](@ref), [`Error`](@ref). ")
Expand All @@ -61,12 +61,12 @@ Warn(predicate; f="") = Warn(predicate, f)

@create_docs(Warn,
header="Warn(predicate; f=\"\")",
example="Warn(model -> length(model.cache) > 100, "*
example="Warn(m -> length(m.cache) > 100, "*
"f=\"Memory low\")",
body="If `predicate(model)` is `true`, then "*
body="If `predicate(m)` is `true`, then "*
"log at the `Warn` level the value of `f` "*
"(or `f(IterationControl.expose(model))` if `f` is a function). "*
"Here `model` "*
"(or `f(IterationControl.expose(m))` if `f` is a function). "*
"Here `m` "*
"is the object being iterated.\n\n"*
"Can be suppressed by setting the global verbosity level "*
"sufficiently low.\n\n"*
Expand Down Expand Up @@ -105,12 +105,12 @@ Error(predicate; f="", exception=nothing) = Error(predicate, f, exception)

@create_docs(Error,
header="Error(predicate; f=\"\", exception=nothing))",
example="Error(model -> isnan(model.bias), f=\"Bias overflow!\")",
body="If `predicate(model)` is `true`, then "*
example="Error(m -> isnan(m.bias), f=\"Bias overflow!\")",
body="If `predicate(m)` is `true`, then "*
"log at the `Error` level the value of `f` "*
"(or `f(IterationControl.expose(model))` if `f` is a function) "*
"(or `f(IterationControl.expose(m))` if `f` is a function) "*
"and stop iteration at the end of the current control cycle. "*
"Here `model` "*
"Here `m` "*
"is the object being iterated.\n\n"*
"Specify `exception=...` to throw an immediate "*
"execption, without "*
Expand Down Expand Up @@ -155,9 +155,9 @@ Callback(; f=identity, kwargs...) = Callback(f, kwargs...)
header="Callback(f=_->nothing, stop_if_true=false, "*
"stop_message=nothing, raw=false)",
example="Callback(m->put!(v, my_loss_function(m))",
body="Call `f(IterationControl.expose(model))`, where "*
"`model` is the object being iterated, unless `raw=true`, in "*
"which case call `f(model)` (guaranteed if `expose` has not been "*
body="Call `f(IterationControl.expose(m))`, where "*
"`m` is the object being iterated, unless `raw=true`, in "*
"which case call `f(m)` (guaranteed if `expose` has not been "*
"overloaded.) "*
"If `stop_if_true` is `true`, then trigger an early stop "*
"if the value returned by `f` is `true`, logging the "*
Expand Down Expand Up @@ -198,12 +198,12 @@ Base.show(io::IO, d::Data{S}) where S =
"stop_when_exhausted=$(d.stop_when_exhausted))")

@create_docs(Data,
header="Data(data; stop_when_exhausted=false)",
header="Data(my_data; stop_when_exhausted=false)",
example="Data(rand(100))",
body="In each application of this control a new `item` from the "*
"iterable, `data`, is retrieved (using `iterate`) and "*
"`IterationControl.ingest!(model, item)` is called. Here "*
"`model` is the object being iterated. \n\n"*
"`IterationControl.ingest!(m, item)` is called. Here "*
"`m` is the object being iterated. \n\n"*
"A control becomes passive once the `data` iterable is done. "*
"To trigger "*
"a stop *after one passive application of the control*, set "*
Expand Down Expand Up @@ -259,45 +259,46 @@ function takedown(c::Data, verbosity, state)
end


# # Loss
# # WithLossDo

struct Loss{F<:Function}
struct WithLossDo{F<:Function}
f::F
stop_if_true::Bool
stop_message::Union{String,Nothing}
end

# constructor:
Loss(f::Function;
WithLossDo(f::Function;
stop_if_true=false,
stop_message=nothing) = Loss(f, stop_if_true, stop_message)
Loss(; f=x->@info(x), kwargs...) = Loss(f, kwargs...)
stop_message=nothing) = WithLossDo(f, stop_if_true, stop_message)
WithLossDo(; f=x->@info("loss: $x"), kwargs...) = WithLossDo(f, kwargs...)

@create_docs(Loss,
header="Loss(f=x->@info(x)), stop_if_true=false, "*
@create_docs(WithLossDo,
header="WithLossDo(f=x->@info(\"loss: \$x\"), "*
"stop_if_true=false, "*
"stop_message=nothing)",
example="Loss(x->put!(my_losses, x)",
example="WithLossDo(x->put!(my_losses, x))",
body="Call `f(loss)`, where "*
"`loss` is current loss.\n\n"*
"If `stop_if_true` is `true`, then trigger an early stop "*
"if the value returned by `f` is `true`, logging the "*
"`stop_message` if specified. ")

EarlyStopping.needs_loss(::Type{<:Loss}) = true
EarlyStopping.needs_loss(::Type{<:WithLossDo}) = true

function update!(c::Loss, model, verbosity, state=(done=false, ))
function update!(c::WithLossDo, model, verbosity, state=(done=false, ))
loss = IterationControl.loss(model)
r = c.f(loss)
done = (c.stop_if_true && r isa Bool && r) ? true : false
return (done=done,)
end

done(c::Loss, state) = state.done
done(c::WithLossDo, state) = state.done

function takedown(c::Loss, verbosity, state)
function takedown(c::WithLossDo, verbosity, state)
if state.done
message = c.stop_message === nothing ?
"Stop triggered by a `Loss` control. " :
"Stop triggered by a `WithLossDo` control. " :
c.stop_message
verbosity > 0 && @info message
return (done = true, log = message)
Expand All @@ -307,46 +308,51 @@ function takedown(c::Loss, verbosity, state)
end


# # TrainingLosses
# # WithTrainingLossesDo

struct TrainingLosses{F<:Function}
struct WithTrainingLossesDo{F<:Function}
f::F
stop_if_true::Bool
stop_message::Union{String,Nothing}
end

# constructor:
TrainingLosses(f::Function;
WithTrainingLossesDo(f::Function;
stop_if_true=false,
stop_message=nothing) = TrainingLosses(f, stop_if_true, stop_message)
TrainingLosses(; f=v->@info(v), kwargs...) = TrainingLosses(f, kwargs...)
stop_message=nothing) = WithTrainingLossesDo(f, stop_if_true, stop_message)
WithTrainingLossesDo(; f=v->@info("training: $v"), kwargs...) =
WithTrainingLossesDo(f, kwargs...)

@create_docs(TrainingLosses,
header="TrainingLosses(f=v->@info(v)), stop_if_true=false, "*
@create_docs(WithTrainingLossesDo,
header="WithTrainingLossesDo(f=v->@info(\"training: \$v\"), "*
"stop_if_true=false, "*
"stop_message=nothing)",
example="TrainingLosses(v->put!(my_losses, last(v))",
example="WithTrainingLossesDo(v->put!(my_losses, last(v))",
body="Call `f(training_losses)`, where "*
"`training_losses` is the vector of most recent batch "*
"of training losses.\n\n"*
"If `stop_if_true` is `true`, then trigger an early stop "*
"if the value returned by `f` is `true`, logging the "*
"`stop_message` if specified. ")

EarlyStopping.needs_training_losses(::Type{<:TrainingLosses}) = true
EarlyStopping.needs_training_losses(::Type{<:WithTrainingLossesDo}) = true

function update!(c::TrainingLosses, model, verbosity, state=(done=false, ))
function update!(c::WithTrainingLossesDo,
model,
verbosity,
state=(done=false, ))
losses = IterationControl.training_losses(model)
r = c.f(losses)
done = (c.stop_if_true && r isa Bool && r) ? true : false
return (done=done, )
end

done(c::TrainingLosses, state) = state.done
done(c::WithTrainingLossesDo, state) = state.done

function takedown(c::TrainingLosses, verbosity, state)
function takedown(c::WithTrainingLossesDo, verbosity, state)
if state.done
message = c.stop_message === nothing ?
"Stop triggered by a `TrainingLosses` control. " :
"Stop triggered by a `WithTrainingLossesDo` control. " :
c.stop_message
verbosity > 0 && @info message
return (done = true, log = message)
Expand All @@ -356,44 +362,45 @@ function takedown(c::TrainingLosses, verbosity, state)
end


# # NumberCount
# # WithNumberDo

struct NumberCount{F<:Function}
struct WithNumberDo{F<:Function}
f::F
stop_if_true::Bool
stop_message::Union{String,Nothing}
end

# constructor:
NumberCount(f::Function;
WithNumberDo(f::Function;
stop_if_true=false,
stop_message=nothing) = NumberCount(f, stop_if_true, stop_message)
NumberCount(; f=n->@info(n), kwargs...) = NumberCount(f, kwargs...)
stop_message=nothing) = WithNumberDo(f, stop_if_true, stop_message)
WithNumberDo(; f=n->@info("number: $n"), kwargs...) = WithNumberDo(f, kwargs...)

@create_docs(NumberCount,
header="NumberCount(f=n->@info(n)), stop_if_true=false, "*
@create_docs(WithNumberDo,
header="WithNumberDo(f=n->@info(\"number: \$n\"), "*
"stop_if_true=false, "*
"stop_message=nothing)",
example="NumberCount(n->put!(my_channel, n))",
example="WithNumberDo(n->put!(my_channel, n))",
body="Call `f(n)`, where "*
"`n` is one more than the number of previous applications "*
"of the control (so, `n = 1, 2, 3, ...`).\n\n"*
"If `stop_if_true` is `true`, then trigger an early stop "*
"if the value returned by `f` is `true`, logging the "*
"`stop_message` if specified. ")

function update!(c::NumberCount, model, verbosity, state=(done = false, n = 0))
function update!(c::WithNumberDo, model, verbosity, state=(done = false, n = 0))
n = state.n
r = c.f(state.n + 1)
done = (c.stop_if_true && r isa Bool && r) ? true : false
return (done = done, n = n + 1)
end

done(c::NumberCount, state) = state.done
done(c::WithNumberDo, state) = state.done

function takedown(c::NumberCount, verbosity, state)
function takedown(c::WithNumberDo, verbosity, state)
if state.done
message = c.stop_message === nothing ?
"Stop triggered by a `NumberCount` control. " :
"Stop triggered by a `WithNumberDo` control. " :
c.stop_message
verbosity > 0 && @info message
return (done = true, log = message)
Expand Down
28 changes: 14 additions & 14 deletions src/wrapped_controls.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,49 +43,49 @@ _pred(predicate::Int) = t -> mod(t + 1, predicate) == 0
An iteration control wrapper.
If `predicate` is an **integer**, `n`: Apply `control` on every `n`
calls to apply the wrapper, starting with the `n`th call.
If `predicate` is an **integer**, `k`: Apply `control` on every `k`
calls to apply the wrapper, starting with the `k`th call.
If `predicate` is a **function**: Apply `control` as usual when
`predicate(t + 1)` is `true` but otherwise skip. Here `t` is the
number of calls to apply the wrapper so far.
`predicate(n + 1)` is `true` but otherwise skip. Here `n` is the
number of calls to apply the wrapped control so far.
"""
skip(control; predicate::Int=1) = Skip(control, _pred(predicate))

_state(s, model, verbosity, t) = if s.predicate(t)
_state(s, model, verbosity, n) = if s.predicate(n)
atomic_state = update!(s.control, model, verbosity + 1)
return (atomic_state = atomic_state, t = t + 1)
return (atomic_state = atomic_state, n = n + 1)
else
return nothing
end

function update!(s::Skip, model, verbosity)
state_candidate = _state(s, model, verbosity, 0)
state_candidate isa Nothing && return (t = 1, )
state_candidate isa Nothing && return (n = 1, )
return state_candidate
end

# in case atomic state is not initialized in first `update` call:
function update!(s::Skip, model, verbosity, state::NamedTuple{(:t,)})
state_candidate = _state(s, model, verbosity, state.t)
state_candidate isa Nothing && return (t = state.t + 1, )
function update!(s::Skip, model, verbosity, state::NamedTuple{(:n,)})
state_candidate = _state(s, model, verbosity, state.n)
state_candidate isa Nothing && return (n = state.n + 1, )
return state_candidate
end

# regular update:
function update!(s::Skip, model, verbosity, state)
state_candidate = _state(s, model, verbosity, state.t)
state_candidate = _state(s, model, verbosity, state.n)
state_candidate isa Nothing &&
return (atomic_state = state.atomic_state, t = state.t + 1)
return (atomic_state = state.atomic_state, n = state.n + 1)
return state_candidate
end

done(s::Skip, state) = done(s.control, state.atomic_state)

# can't be done if atomic state never intialized:
done(s::Skip, state::NamedTuple{(:t,)}) = false
done(s::Skip, state::NamedTuple{(:n,)}) = false

takedown(s::Skip, verbosity, state) =
takedown(s.control, verbosity, state.atomic_state)
takedown(::Skip, ::Any, ::NamedTuple{(:t,)}) = NamedTuple()
takedown(::Skip, ::Any, ::NamedTuple{(:n,)}) = NamedTuple()
Loading

0 comments on commit 1c8690f

Please sign in to comment.