Skip to content

Commit

Permalink
doc improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
ablaom committed Mar 12, 2021
1 parent dc3f1e5 commit e7e6635
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 53 deletions.
61 changes: 34 additions & 27 deletions src/controls.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Step(; n=5) = Step(n)
@create_docs(Step,
header="Step(; n=1)",
example="Step(2)",
body="Train the model for `n` more iterations. "*
body="Train for `n` more iterations. "*
"Will never trigger a stop. ")

function update!(c::Step, model, verbosity, args...)
Expand All @@ -34,11 +34,11 @@ Info(; f::Function=identity) = Info(f)
@create_docs(Info,
header="Info(f=identity)",
example="Info(my_loss_function)",
body="Log at the `Info` level the value of `f(model)`, "*
"where `model` "*
body="Log at the `Info` level the value of `f(m)`, "*
"where `m` "*
"is the object being iterated. If "*
"`IterativeControl.expose(model)` has been overloaded, then "*
"log `f(expose(model))` instead.\n\n"*
"`IterativeControl.expose(m)` has been overloaded, then "*
"log `f(expose(m))` instead.\n\n"*
"Can be suppressed by setting the global verbosity level "*
"sufficiently low. \n\n"*
"See also [`Warn`](@ref), [`Error`](@ref). ")
Expand All @@ -61,12 +61,12 @@ Warn(predicate; f="") = Warn(predicate, f)

@create_docs(Warn,
header="Warn(predicate; f=\"\")",
example="Warn(model -> length(model.cache) > 100, "*
example="Warn(m -> length(m.cache) > 100, "*
"f=\"Memory low\")",
body="If `predicate(model)` is `true`, then "*
body="If `predicate(m)` is `true`, then "*
"log at the `Warn` level the value of `f` "*
"(or `f(IterationControl.expose(model))` if `f` is a function). "*
"Here `model` "*
"(or `f(IterationControl.expose(m))` if `f` is a function). "*
"Here `m` "*
"is the object being iterated.\n\n"*
"Can be suppressed by setting the global verbosity level "*
"sufficiently low.\n\n"*
Expand Down Expand Up @@ -105,12 +105,12 @@ Error(predicate; f="", exception=nothing) = Error(predicate, f, exception)

@create_docs(Error,
header="Error(predicate; f=\"\", exception=nothing))",
example="Error(model -> isnan(model.bias), f=\"Bias overflow!\")",
body="If `predicate(model)` is `true`, then "*
example="Error(m -> isnan(m.bias), f=\"Bias overflow!\")",
body="If `predicate(m)` is `true`, then "*
"log at the `Error` level the value of `f` "*
"(or `f(IterationControl.expose(model))` if `f` is a function) "*
"(or `f(IterationControl.expose(m))` if `f` is a function) "*
"and stop iteration at the end of the current control cycle. "*
"Here `model` "*
"Here `m` "*
"is the object being iterated.\n\n"*
"Specify `exception=...` to throw an immediate "*
"execption, without "*
Expand Down Expand Up @@ -155,9 +155,9 @@ Callback(; f=identity, kwargs...) = Callback(f, kwargs...)
header="Callback(f=_->nothing, stop_if_true=false, "*
"stop_message=nothing, raw=false)",
example="Callback(m->put!(v, my_loss_function(m))",
body="Call `f(IterationControl.expose(model))`, where "*
"`model` is the object being iterated, unless `raw=true`, in "*
"which case call `f(model)` (guaranteed if `expose` has not been "*
body="Call `f(IterationControl.expose(m))`, where "*
"`m` is the object being iterated, unless `raw=true`, in "*
"which case call `f(m)` (guaranteed if `expose` has not been "*
"overloaded.) "*
"If `stop_if_true` is `true`, then trigger an early stop "*
"if the value returned by `f` is `true`, logging the "*
Expand Down Expand Up @@ -198,12 +198,12 @@ Base.show(io::IO, d::Data{S}) where S =
"stop_when_exhausted=$(d.stop_when_exhausted))")

@create_docs(Data,
header="Data(data; stop_when_exhausted=false)",
header="Data(my_data; stop_when_exhausted=false)",
example="Data(rand(100))",
body="In each application of this control a new `item` from the "*
"iterable, `data`, is retrieved (using `iterate`) and "*
"`IterationControl.ingest!(model, item)` is called. Here "*
"`model` is the object being iterated. \n\n"*
"`IterationControl.ingest!(m, item)` is called. Here "*
"`m` is the object being iterated. \n\n"*
"A control becomes passive once the `data` iterable is done. "*
"To trigger "*
"a stop *after one passive application of the control*, set "*
Expand Down Expand Up @@ -271,12 +271,13 @@ end
WithLossDo(f::Function;
stop_if_true=false,
stop_message=nothing) = WithLossDo(f, stop_if_true, stop_message)
WithLossDo(; f=x->@info(x), kwargs...) = WithLossDo(f, kwargs...)
WithLossDo(; f=x->@info("loss: $x"), kwargs...) = WithLossDo(f, kwargs...)

@create_docs(WithLossDo,
header="WithLossDo(f=x->@info(x)), stop_if_true=false, "*
header="WithLossDo(f=x->@info(\"loss: \$x\"), "*
"stop_if_true=false, "*
"stop_message=nothing)",
example="WithLossDo(x->put!(my_losses, x)",
example="WithLossDo(x->put!(my_losses, x))",
body="Call `f(loss)`, where "*
"`loss` is current loss.\n\n"*
"If `stop_if_true` is `true`, then trigger an early stop "*
Expand Down Expand Up @@ -319,10 +320,12 @@ end
WithTrainingLossesDo(f::Function;
stop_if_true=false,
stop_message=nothing) = WithTrainingLossesDo(f, stop_if_true, stop_message)
WithTrainingLossesDo(; f=v->@info(v), kwargs...) = WithTrainingLossesDo(f, kwargs...)
WithTrainingLossesDo(; f=v->@info("training: $v"), kwargs...) =
WithTrainingLossesDo(f, kwargs...)

@create_docs(WithTrainingLossesDo,
header="WithTrainingLossesDo(f=v->@info(v)), stop_if_true=false, "*
header="WithTrainingLossesDo(f=v->@info(\"training: \$v\"), "*
"stop_if_true=false, "*
"stop_message=nothing)",
example="WithTrainingLossesDo(v->put!(my_losses, last(v))",
body="Call `f(training_losses)`, where "*
Expand All @@ -334,7 +337,10 @@ WithTrainingLossesDo(; f=v->@info(v), kwargs...) = WithTrainingLossesDo(f, kwarg

EarlyStopping.needs_training_losses(::Type{<:WithTrainingLossesDo}) = true

function update!(c::WithTrainingLossesDo, model, verbosity, state=(done=false, ))
function update!(c::WithTrainingLossesDo,
model,
verbosity,
state=(done=false, ))
losses = IterationControl.training_losses(model)
r = c.f(losses)
done = (c.stop_if_true && r isa Bool && r) ? true : false
Expand Down Expand Up @@ -368,10 +374,11 @@ end
WithNumberDo(f::Function;
stop_if_true=false,
stop_message=nothing) = WithNumberDo(f, stop_if_true, stop_message)
WithNumberDo(; f=n->@info(n), kwargs...) = WithNumberDo(f, kwargs...)
WithNumberDo(; f=n->@info("number: $n"), kwargs...) = WithNumberDo(f, kwargs...)

@create_docs(WithNumberDo,
header="WithNumberDo(f=n->@info(n)), stop_if_true=false, "*
header="WithNumberDo(f=n->@info(\"number: \$n\"), "*
"stop_if_true=false, "*
"stop_message=nothing)",
example="WithNumberDo(n->put!(my_channel, n))",
body="Call `f(n)`, where "*
Expand Down
28 changes: 14 additions & 14 deletions src/wrapped_controls.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,49 +43,49 @@ _pred(predicate::Int) = t -> mod(t + 1, predicate) == 0
An iteration control wrapper.
If `predicate` is an **integer**, `n`: Apply `control` on every `n`
calls to apply the wrapper, starting with the `n`th call.
If `predicate` is an **integer**, `k`: Apply `control` on every `k`
calls to apply the wrapper, starting with the `k`th call.
If `predicate` is a **function**: Apply `control` as usual when
`predicate(t + 1)` is `true` but otherwise skip. Here `t` is the
number of calls to apply the wrapper so far.
`predicate(n + 1)` is `true` but otherwise skip. Here `n` is the
number of calls to apply the wrapped control so far.
"""
skip(control; predicate::Int=1) = Skip(control, _pred(predicate))

_state(s, model, verbosity, t) = if s.predicate(t)
_state(s, model, verbosity, n) = if s.predicate(n)
atomic_state = update!(s.control, model, verbosity + 1)
return (atomic_state = atomic_state, t = t + 1)
return (atomic_state = atomic_state, n = n + 1)
else
return nothing
end

function update!(s::Skip, model, verbosity)
state_candidate = _state(s, model, verbosity, 0)
state_candidate isa Nothing && return (t = 1, )
state_candidate isa Nothing && return (n = 1, )
return state_candidate
end

# in case atomic state is not initialized in first `update` call:
function update!(s::Skip, model, verbosity, state::NamedTuple{(:t,)})
state_candidate = _state(s, model, verbosity, state.t)
state_candidate isa Nothing && return (t = state.t + 1, )
function update!(s::Skip, model, verbosity, state::NamedTuple{(:n,)})
state_candidate = _state(s, model, verbosity, state.n)
state_candidate isa Nothing && return (n = state.n + 1, )
return state_candidate
end

# regular update:
function update!(s::Skip, model, verbosity, state)
state_candidate = _state(s, model, verbosity, state.t)
state_candidate = _state(s, model, verbosity, state.n)
state_candidate isa Nothing &&
return (atomic_state = state.atomic_state, t = state.t + 1)
return (atomic_state = state.atomic_state, n = state.n + 1)
return state_candidate
end

done(s::Skip, state) = done(s.control, state.atomic_state)

# can't be done if atomic state never intialized:
done(s::Skip, state::NamedTuple{(:t,)}) = false
done(s::Skip, state::NamedTuple{(:n,)}) = false

takedown(s::Skip, verbosity, state) =
takedown(s.control, verbosity, state.atomic_state)
takedown(::Skip, ::Any, ::NamedTuple{(:t,)}) = NamedTuple()
takedown(::Skip, ::Any, ::NamedTuple{(:n,)}) = NamedTuple()
21 changes: 13 additions & 8 deletions test/controls.jl
Original file line number Diff line number Diff line change
Expand Up @@ -371,14 +371,19 @@ end

model = Particle(0.1)
losses = Float64[]
report = IC.train!(model,
Data(data, stop_when_exhausted=true),
Step(5),
Threshold(0.01),
TimeLimit(0.0005),
Info(loss),
Callback(callback!),
verbosity=-1)
noise = fill((:info, r""), 33)
report = @test_logs(noise...,
IC.train!(model,
Data(data, stop_when_exhausted=true),
Step(5),
WithNumberDo(),
WithLossDo(),
WithTrainingLossesDo(),
Threshold(0.01),
TimeLimit(0.0005),
Info(loss),
Callback(callback!),
verbosity=-1))
@test length(losses) == length(data) + 1
@test loss(model) > 0.01
end
8 changes: 4 additions & 4 deletions test/wrapped_controls.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ end
@test !s.predicate(2)
atomic_state = IC.update!(c, m, 0)
state = IC.update!(s, m, 0)
@test state == (t = 1, )
@test state == (n = 1, )
state = IC.update!(s, m, 0, state)
@test state == (atomic_state = atomic_state, t = 2)
@test state == (atomic_state = atomic_state, n = 2)
state = IC.update!(s, m, 0, state)
@test state == (atomic_state = atomic_state, t = 3)
@test state == (atomic_state = atomic_state, n = 3)
atomic_state = IC.update!(c, m, 0, atomic_state)
state = IC.update!(s, m, 0, state)
@test state == (atomic_state = atomic_state, t = 4)
@test state == (atomic_state = atomic_state, n = 4)
@test IC.done(c, atomic_state) == IC.done(s, state)
@test IC.takedown(c, 0, atomic_state) == IC.takedown(s, 0, state)
end)
Expand Down

0 comments on commit e7e6635

Please sign in to comment.