diff --git a/Project.toml b/Project.toml
index b3a85a5..24b3dca 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "IterationControl"
 uuid = "b3c1a2ee-3fec-4384-bf48-272ea71de57c"
 authors = ["Anthony D. Blaom "]
-version = "0.3.2"
+version = "0.3.3"

 [deps]
 EarlyStopping = "792122b4-ca99-40de-a6bc-6742525f08b6"
diff --git a/src/IterationControl.jl b/src/IterationControl.jl
index 256e3f1..116efbc 100644
--- a/src/IterationControl.jl
+++ b/src/IterationControl.jl
@@ -40,5 +40,6 @@ include("composite_controls.jl")
 include("wrapped_controls.jl")
 include("controls.jl")
 include("train.jl")
+include("square_rooter.jl")

 end # module
diff --git a/src/composite_controls.jl b/src/composite_controls.jl
index 94a5127..b175860 100644
--- a/src/composite_controls.jl
+++ b/src/composite_controls.jl
@@ -1,4 +1,4 @@
-struct CompositeControl{A,B}
+struct CompositeControl{A,B}
     a::A
     b::B
     function CompositeControl(a::A, b::B) where {A, B}
diff --git a/src/controls.jl b/src/controls.jl
index 1b2f877..263c723 100644
--- a/src/controls.jl
+++ b/src/controls.jl
@@ -1,4 +1,4 @@
-# # TRAIN
+# # Step

 struct Step
     n::Int
@@ -13,13 +13,19 @@ Step(; n=5) = Step(n)
     body="Train for `n` more iterations. "*
         "Will never trigger a stop. ")

-function update!(c::Step, model, verbosity, args...)
-    if verbosity > 1
-        @info "Stepping model for $(c.n) iterations. "
-    else
-        nothing
-    end
+function update!(c::Step, model, verbosity, state=(new_iterations = 0,))
+    new_iterations = state.new_iterations
+    verbosity > 1 &&
+        @info "Stepping model for $(c.n) more iterations. "
     train!(model, c.n)
+    state = (new_iterations = new_iterations + c.n,)
+    return state
+end
+
+@inline function takedown(c::Step, verbosity, state)
+    verbosity > 1 &&
+        @info "A total of $(state.new_iterations) iterations added. "
+    return state
 end

 # # Info
@@ -34,7 +40,7 @@ Info(; f::Function=identity) = Info(f)
 @create_docs(Info,
             header="Info(f=identity)",
             example="Info(my_loss_function)",
-            body="Log at the `Info` level the value of `f(m)`, "*
+            body="Log to `Info` the value of `f(m)`, "*
             "where `m` "*
             "is the object being iterated. If "*
             "`IterativeControl.expose(m)` has been overloaded, then "*
@@ -44,12 +50,12 @@ Info(; f::Function=identity) = Info(f)
             "See also [`Warn`](@ref), [`Error`](@ref). ")

 function update!(c::Info, model, verbosity, args...)
-    verbosity < 1 || @info _log_eval(c.f, model)
+    verbosity > 0 && @info _log_eval(c.f, model)
     return nothing
 end


-# # WARN
+# # Warn

 struct Warn{P<:Function,F<:Union{Function,String}}
     predicate::P
@@ -64,7 +70,7 @@ Warn(predicate; f="") = Warn(predicate, f)
             example="Warn(m -> length(m.cache) > 100, "*
             "f=\"Memory low\")",
             body="If `predicate(m)` is `true`, then "*
-            "log at the `Warn` level the value of `f` "*
+            "log to `Warn` the value of `f` "*
             "(or `f(IterationControl.expose(m))` if `f` is a function). "*
             "Here `m` "*
             "is the object being iterated.\n\n"*
@@ -72,27 +78,27 @@ Warn(predicate; f="") = Warn(predicate, f)
             "sufficiently low.\n\n"*
             "See also [`Info`](@ref), [`Error`](@ref). ")

-function update!(c::Warn, model, verbosity, args...)
-    verbosity > 1 && c.predicate(model) &&
-        @warn _log_eval(c.f, model)
-    return nothing
-end
-
-function update!(c::Warn, model, verbosity, warnings=())
+function update!(c::Warn, model, verbosity, state=(warnings=(),))
+    warnings = state.warnings
     if c.predicate(model)
         warning = _log_eval(c.f, model)
-        verbosity < 1 || @warn warning
-        state = tuple(warnings..., warning)
+        verbosity > -1 && @warn warning
+        newstate = (warnings = tuple(warnings..., warning),)
     else
-        state = warnings
+        newstate = state
     end
-    return state
+    return newstate
 end

-takedown(c::Warn, verbosity, state) = (warnings = state,)
+function takedown(c::Warn, verbosity, state)
+    warnings = join(state.warnings, "\n")
+    verbosity > 1 && !isempty(warnings) &&
+        @warn "A `Warn` control issued these warnings:\n$warnings"
+    return state
+end


-# # ERROR
+# # Error

 struct Error{P<:Function,F<:Union{Function,String}}
     predicate::P
@@ -286,24 +292,28 @@ WithLossDo(; f=x->@info("loss: $x"), kwargs...) = WithLossDo(f, kwargs...)

 EarlyStopping.needs_loss(::Type{<:WithLossDo}) = true

-function update!(c::WithLossDo, model, verbosity, state=(done=false, ))
+function update!(c::WithLossDo,
+                 model,
+                 verbosity,
+                 state=(loss=nothing, done=false))
     loss = IterationControl.loss(model)
     r = c.f(loss)
     done = (c.stop_if_true && r isa Bool && r) ? true : false
-    return (done=done,)
+    return (loss=loss, done=done)
 end

 done(c::WithLossDo, state) = state.done

 function takedown(c::WithLossDo, verbosity, state)
+    verbosity > 1 && @info "final loss: $(state.loss). "
     if state.done
         message = c.stop_message === nothing ?
             "Stop triggered by a `WithLossDo` control. " : c.stop_message
         verbosity > 0 && @info message
-        return (done = true, log = message)
+        return merge(state, (log = message,))
     else
-        return (done = false, log = "")
+        return merge(state, (log = "",))
     end
 end
@@ -340,24 +350,26 @@ EarlyStopping.needs_training_losses(::Type{<:WithTrainingLossesDo}) = true

 function update!(c::WithTrainingLossesDo,
                  model,
                  verbosity,
-                 state=(done=false, ))
+                 state=(latest_training_loss = nothing, done = false))
     losses = IterationControl.training_losses(model)
     r = c.f(losses)
     done = (c.stop_if_true && r isa Bool && r) ? true : false
-    return (done=done, )
+    return (latest_training_loss=losses[end], done=done)
 end

 done(c::WithTrainingLossesDo, state) = state.done

 function takedown(c::WithTrainingLossesDo, verbosity, state)
+    verbosity > 1 &&
+        @info "final training loss: $(state.latest_training_loss). "
     if state.done
         message = c.stop_message === nothing ?
             "Stop triggered by a `WithTrainingLossesDo` control. " : c.stop_message
         verbosity > 0 && @info message
-        return (done = true, log = message)
+        return merge(state, (log = message,))
     else
-        return (done = false, log = "")
+        return merge(state, (log = "",))
     end
 end
@@ -398,13 +410,14 @@ end

 done(c::WithNumberDo, state) = state.done

 function takedown(c::WithNumberDo, verbosity, state)
+    verbosity > 1 && @info "final number: $(state.n). "
     if state.done
         message = c.stop_message === nothing ?
             "Stop triggered by a `WithNumberDo` control. " : c.stop_message
" : c.stop_message verbosity > 0 && @info message - return (done = true, log = message) + return merge(state, (log = message,)) else - return (done = false, log = "") + return merge(state, (log = "",)) end end diff --git a/src/square_rooter.jl b/src/square_rooter.jl new file mode 100644 index 0000000..9337d29 --- /dev/null +++ b/src/square_rooter.jl @@ -0,0 +1,22 @@ +# ## SQUARE ROOTER + +# Consider a model to compute Babylonian approximations to a square root: + +mutable struct SquareRooter + x::Float64 # input - number to be square rooted + root::Float64 # current approximation of root + training_losses::Vector{Float64} # successive approximation differences + SquareRooter(x) = new(x, 1.0, Float64[]) +end + +function IterationControl.train!(m::SquareRooter, Δn::Int) + m.training_losses = Float64[] + for i in 1:Δn + next_guess = (m.root + m.x/m.root)/2 + push!(m.training_losses, abs(next_guess - m.root)) + m.root = next_guess + end +end + +IterationControl.loss(m::SquareRooter) = abs(m.root^2 - m.x) +IterationControl.training_losses(m::SquareRooter) = m.training_losses diff --git a/src/train.jl b/src/train.jl index 1f323d8..3d29e1a 100644 --- a/src/train.jl +++ b/src/train.jl @@ -20,6 +20,15 @@ function train!(model, controls...; verbosity::Int=1) finished = done(control, state) end + # reporting final loss and training loss if available: + loss = IterationControl.loss(model) + training_losses = IterationControl.training_losses(model) + if verbosity > 0 + loss isa Nothing || @info "final loss: $loss" + training_losses isa Nothing || + @info "final training loss: $(training_losses[end])" + end + # finalization: return takedown(control, verbosity, state) end diff --git a/src/wrapped_controls.jl b/src/wrapped_controls.jl index 548173f..a19cc00 100644 --- a/src/wrapped_controls.jl +++ b/src/wrapped_controls.jl @@ -1,3 +1,28 @@ +# # Louder + +struct Louder{C} + control::C + by::Int64 +end + +""" + IterationControl.louder(control, by=1) + +Wrap `control` to make in more (or less) verbose. The same as +`control`, but as if the global `verbosity` were increased by the value +`by`. + +""" +louder(c; by=1) = Louder(c, by) + +# api: +done(d::Louder, state) = done(d.control, state) +update!(d::Louder, model, verbosity, args...) = + update!(d.control, model, verbosity + d.by, args...) 
+takedown(d::Louder, verbosity, state) =
+    takedown(d.control, verbosity + d.by, state)
+
+
 # # Debug

 struct Debug{C}
diff --git a/test/_models_for_testing.jl b/test/_models_for_testing.jl
index b875e83..719c77e 100644
--- a/test/_models_for_testing.jl
+++ b/test/_models_for_testing.jl
@@ -1,39 +1,5 @@
 # # DUMMY MODELS FOR TESTING
-
-# ## SQUARE ROOTER
-
-# Consider a model to compute Babylonian approximations to a square root:
-
-mutable struct SquareRooter
-    x::Float64                       # input - number to be square rooted
-    root::Float64                    # current approximation of root
-    training_losses::Vector{Float64} # successive approximation differences
-    SquareRooter(x) = new(x, 1.0, Float64[])
-end
-
-function IterationControl.train!(m::SquareRooter, Δn::Int)
-    m.training_losses = Float64[]
-    for i in 1:Δn
-        next_guess = (m.root + m.x/m.root)/2
-        push!(m.training_losses, abs(next_guess - m.root))
-        m.root = next_guess
-    end
-end
-
-IterationControl.loss(m::SquareRooter) = abs(m.root^2 - m.x)
-IterationControl.training_losses(m::SquareRooter) = m.training_losses
-
-model = SquareRooter(4.0)
-IterationControl.train!(model, 1)
-@assert model.root ≈ 2.5
-@assert IterationControl.loss(model) ≈ 25/4 - 4
-IterationControl.train!(model, 100)
-@assert IterationControl.loss(model) ≈ 0
-@assert IterationControl.training_losses(model)[1:2] ≈
-    abs.([41/20 - 5/2, 3281/1640 - 41/20])
-
-
 # ## PARTICLE TRACKER (without methods lifted)

 # Consider an object that tracks a particle in one dimension, moving,
@@ -80,15 +46,17 @@ function ingest!(model::Particle, target)
     return nothing
 end

-model = Particle()
-ingest!(model, 1)
-train!(model, 1)
-@assert loss(model) ≈ 0.9
-ingest!(model, -0.9)
-train!(model, 1)
-@assert loss(model) ≈ 0.9
+# temporary testing area
+
+# model = Particle()
+# ingest!(model, 1)
+# train!(model, 1)
+# @assert loss(model) ≈ 0.9
+# ingest!(model, -0.9)
+# train!(model, 1)
+# @assert loss(model) ≈ 0.9

-model = Particle()
-ingest!(model, 1)
-train!(model, 2)
-@assert training_losses(model) ≈ [0.9, 0.81]
+# model = Particle()
+# ingest!(model, 1)
+# train!(model, 2)
+# @assert training_losses(model) ≈ [0.9, 0.81]
diff --git a/test/api.jl b/test/api.jl
index 93b7b14..62b566f 100644
--- a/test/api.jl
+++ b/test/api.jl
@@ -2,7 +2,7 @@ model = Particle()
 invalid = InvalidValue()

 @test_throws IC.ERR_TRAIN IterationControl.train!(model)
-@test_throws IC.err_train(model) IterationControl.train!(model, 1)
+@test_throws IC.err_train(model) IterationControl.train!(model, 1)

 # lifting train!:
 IC.train!(model::Particle, n) = train!(model, n)
diff --git a/test/controls.jl b/test/controls.jl
index 27331f1..33619b3 100644
--- a/test/controls.jl
+++ b/test/controls.jl
@@ -6,12 +6,13 @@ m = SquareRooter(4)
     c = Step(n=2)

     state = IC.update!(c, m, 0)
-    @test state === nothing
+    @test state === (new_iterations = 2,)
     @test m.training_losses == all_training_losses[1:2]

-    state = IC.update!(c, m, 0)
+    state = IC.update!(c, m, 0, state)
     @test m.training_losses == all_training_losses[3:4]

     @test !IC.done(c, state)
-    @test IC.takedown(c, 1, state) == NamedTuple()
+    @test_logs((:info, r"A total of "),
+               @test IC.takedown(c, 2, state) == (new_iterations = 4,))
 end

 @testset "Info" begin
@@ -32,47 +33,50 @@ end
     c = Warn(m -> m.root > 2.4)

     IC.train!(m, 1)
-    @test_logs IC.update!(c, m, 2)
     @test_logs IC.update!(c, m, 1)
-    state = @test_logs IC.update!(c, m, 0)
-    @test state === ()
+    @test_logs IC.update!(c, m, 0)
+    state = @test_logs IC.update!(c, m, -1)
+    @test state === (warnings=(),)

     m = SquareRooter(4)
     IC.train!(m, 1)
     state = IC.update!(c, m, -1)
-    @test_logs (:warn, "") IC.update!(c, m, 2, state)
     @test_logs (:warn, "") IC.update!(c, m, 1, state)
-    state = @test_logs IC.update!(c, m, 0, state)
-    @test state === ("", "")
+    @test_logs (:warn, "") IC.update!(c, m, 0, state)
+    state = @test_logs IC.update!(c, m, -1, state)
+    @test state === (warnings=("", ""),)

     IC.train!(m, 1)
-    @test_logs IC.update!(c, m, 2, state)
+    @test_logs IC.update!(c, m, 1, state)
     @test_logs IC.update!(c, m, 0, state)
     state = @test_logs IC.update!(c, m, -1, state)
-    @test state === ("", "")
+    @test state === (warnings=("", ""),)

     m = SquareRooter(4)
     c = Warn(m -> m.root > 2.4, f = m->m.root)
     IC.train!(m, 1)
-    @test_logs (:warn, 2.5) IC.update!(c, m, 2)
     @test_logs (:warn, 2.5) IC.update!(c, m, 1)
-    state = @test_logs IC.update!(c, m, 0)
-    @test state === (2.5, )
+    @test_logs (:warn, 2.5) IC.update!(c, m, 0)
+    state = @test_logs IC.update!(c, m, -1)
+    @test state === (warnings=(2.5, ),)

-    @test_logs (:warn, 2.5) IC.update!(c, m, 2, state)
     @test_logs (:warn, 2.5) IC.update!(c, m, 1, state)
-    state = @test_logs IC.update!(c, m, 0, state)
-    @test state === (2.5, 2.5)
+    @test_logs (:warn, 2.5) IC.update!(c, m, 0, state)
+    state = @test_logs IC.update!(c, m, -1, state)
+    @test state === (warnings=(2.5, 2.5),)

     @test !IC.done(c, state)
-    @test IC.takedown(c, 10, state) == (warnings = (2.5, 2.5),)
+    @test_logs((:warn, r"A `Warn`"),
+               @test IC.takedown(c, 2, state) == (warnings = (2.5, 2.5),))
+    @test_logs @test IC.takedown(c, 1, state) == (warnings = (2.5, 2.5),)
+
 end

 @testset "Error" begin
@@ -181,7 +185,11 @@ end
     state = IC.update!(c, m, 1, state)
     @test !state.done
     @test v ≈ [2.25, (3281/1640)^2 - 4]
-    @test IC.takedown(c, 0, state) == (done = false, log="")
+    @test IC.takedown(c, 1, state) ==
+        (loss=v[end], done = false, log="")
+    @test_logs((:info, r"final loss"),
+               @test IC.takedown(c, 2, state) ==
+               (loss=v[end], done = false, log=""))

     v = Float64[]
     f2(loss) = (push!(v, loss); last(v) < 0.02)
@@ -196,8 +204,20 @@ end
     @test state.done
     @test v ≈ [2.25, (3281/1640)^2 - 4]
     @test IC.takedown(c, 0, state) ==
-        (done = true,
+        (loss = v[end],
+         done = true,
         log="Stop triggered by a `WithLossDo` control. ")
+    @test_logs((:info, r"Stop triggered"),
+               @test IC.takedown(c, 1, state) ==
+               (loss = v[end],
+                done = true,
+                log="Stop triggered by a `WithLossDo` control. "))
+    @test_logs((:info, r"final loss"),
+               (:info, r"Stop triggered"),
+               @test IC.takedown(c, 2, state) ==
+               (loss = v[end],
+                done = true,
+                log="Stop triggered by a `WithLossDo` control. "))
")) v = Float64[] f3(loss) = (push!(v, loss); last(v) < 0.02) @@ -212,7 +232,8 @@ end @test state.done @test v ≈ [2.25, (3281/1640)^2 - 4] @test IC.takedown(c, 0, state) == - (done = true, + (loss = v[end], + done = true, log="foo") end @@ -231,7 +252,12 @@ end state = IC.update!(c, m, 1, state) @test !state.done @test v ≈ [1.5, 0.45] - @test IC.takedown(c, 0, state) == (done = false, log="") + @test IC.takedown(c, 0, state) == + (latest_training_loss = v[end], done = false, log="") + @test_logs((:info, r"final train"), + @test IC.takedown(c, 2, state) == + (latest_training_loss=v[end], done = false, log="")) + v = Float64[] f1(training_loss) = (push!(v, last(training_loss)); last(v) < 0.5) @@ -246,8 +272,20 @@ end @test state.done @test v ≈ [1.5, 0.45] @test IC.takedown(c, 0, state) == - (done = true, + (latest_training_loss = v[end], + done = true, log="Stop triggered by a `WithTrainingLossesDo` control. ") + @test_logs((:info, r"Stop "), + @test IC.takedown(c, 1, state) == + (latest_training_loss = v[end], + done = true, + log="Stop triggered by a `WithTrainingLossesDo` control. ")) + @test_logs((:info, r"final train"), + (:info, r"Stop"), + @test IC.takedown(c, 2, state) == + (latest_training_loss = v[end], + done = true, + log="Stop triggered by a `WithTrainingLossesDo` control. ")) v = Float64[] f2(training_loss) = (push!(v, last(training_loss)); last(v) < 0.5) @@ -262,7 +300,8 @@ end @test state.done @test v ≈ [1.5, 0.45] @test IC.takedown(c, 0, state) == - (done = true, + (latest_training_loss = v[end], + done = true, log="foo") end @@ -280,7 +319,9 @@ end state = IC.update!(c, m, 1, state) @test !state.done @test v == [1, 2] - @test IC.takedown(c, 0, state) == (done = false, log="") + @test IC.takedown(c, 0, state) == (done = false, n = 2, log="") + @test_logs((:info, r"final number"), + @test IC.takedown(c, 2, state) == (done = false, n = 2, log="")) v = Int[] f2(n) = (push!(v, n); last(n) > 1) @@ -296,7 +337,19 @@ end @test v == [1, 2] @test IC.takedown(c, 0, state) == (done = true, + n= 2, log="Stop triggered by a `WithNumberDo` control. ") + @test_logs((:info, r"Stop"), + @test IC.takedown(c, 1, state) == + (done = true, + n= 2, + log="Stop triggered by a `WithNumberDo` control. ")) + @test_logs((:info, r"final number"), + (:info, r"Stop"), + @test IC.takedown(c, 2, state) == + (done = true, + n= 2, + log="Stop triggered by a `WithNumberDo` control. 
")) v = Int[] f3(n) = (push!(v, n); last(n) > 1) @@ -312,6 +365,7 @@ end @test v == [1, 2] @test IC.takedown(c, 0, state) == (done = true, + n = 2, log="foo") end diff --git a/test/runtests.jl b/test/runtests.jl index 6ae3aa6..577b712 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,6 +4,11 @@ using Test const IC = IterationControl include("_models_for_testing.jl") +const SquareRooter = IC.SquareRooter + +@testset "square_rooter.jl" begin + include("square_rooter.jl") +end @testset "utilities" begin include("utilities.jl") diff --git a/test/square_rooter.jl b/test/square_rooter.jl new file mode 100644 index 0000000..df6401c --- /dev/null +++ b/test/square_rooter.jl @@ -0,0 +1,10 @@ +model = IterationControl.SquareRooter(4.0) +IterationControl.train!(model, 1) +@test model.root ≈ 2.5 +@test IterationControl.loss(model) ≈ 25/4 - 4 +IterationControl.train!(model, 100) +@test IterationControl.loss(model) ≈ 0 +@test IterationControl.training_losses(model)[1:2] ≈ + abs.([41/20 - 5/2, 3281/1640 - 41/20]) + +true diff --git a/test/train.jl b/test/train.jl index f2ac7f1..ceed34a 100644 --- a/test/train.jl +++ b/test/train.jl @@ -1,7 +1,7 @@ @testset "basic integration" begin m = SquareRooter(4) - report = IC.train!(m, Step(2), InvalidValue(), NumberLimit(3); verbosity=0); - @test report[1] == (Step(2), NamedTuple()) + report = IC.train!(m, Step(2), InvalidValue(), NumberLimit(3); verbosity=0); + @test report[1] == (Step(2), (new_iterations = 6,)) @test report[2] == (InvalidValue(), (done=false, log="")) report[3] == (NumberLimit(3), (done=true, @@ -9,12 +9,17 @@ "stopping criterion. ")) m = SquareRooter(4) - @test_logs((:info, r"Stop triggered by Num"), + @test_logs((:info, r"final loss"), + (:info, r"final training loss"), + (:info, r"Stop triggered by Num"), IC.train!(m, Step(2), InvalidValue(), NumberLimit(3))); @test_logs((:info, r"Using these controls"), - (:info, r"Stepping model for 2 iterations"), - (:info, r"Stepping model for 2 iterations"), - (:info, r"Stepping model for 2 iterations"), + (:info, r"Stepping model for 2 more iterations"), + (:info, r"Stepping model for 2 more iterations"), + (:info, r"Stepping model for 2 more iterations"), + (:info, r"final loss"), + (:info, r"final training loss"), + (:info, r"A total of 6 iterations added"), (:info, r"Stop triggered by NumberLimit"), IC.train!(m, Step(2), InvalidValue(), @@ -54,6 +59,20 @@ end Step(1), IterationControl.skip( WithNumberDo(x->push!(numbers, x)), predicate=3), - NumberLimit(10)) + NumberLimit(10), verbosity=0) @test numbers == [1, 2, 3] end + +@testset "integration test related to #38" begin + model = IterationControl.SquareRooter(4) + @test_logs((:info, r"number"), + (:info, r"number"), + (:info, r"final loss"), + (:info, r"final training loss"), + (:info, r"Stop triggered by"), + IC.train!(model, + Step(1), + Threshold(2.1), + WithNumberDo(), + IterationControl.skip(WithLossDo(), predicate=3))) +end diff --git a/test/wrapped_controls.jl b/test/wrapped_controls.jl index 9122a07..cf8a10d 100644 --- a/test/wrapped_controls.jl +++ b/test/wrapped_controls.jl @@ -1,3 +1,24 @@ +@testset "louder" begin + m = SquareRooter(4) + control = Warn(_->true, f="42") + + c = IC.louder(control, by=0) + IC.train!(m, 1) + state = @test_logs IC.update!(c, m, -1) + IC.train!(m, 2) + state = @test_logs (:warn, r"42") IC.update!(c, m, 0, state) + @test_logs IC.takedown(c, 1, state) + @test_logs (:warn, r"A ") IC.takedown(c, 2, state) + + c = IC.louder(control, by=1) + IC.train!(m, 1) + state = @test_logs IC.update!(c, m, -2) + 
+    state = @test_logs (:warn, r"42") IC.update!(c, m, -1, state)
+    @test_logs IC.takedown(c, 0, state)
+    @test_logs (:warn, r"A ") IC.takedown(c, 1, state)
+end
+
 @testset "debug" begin
     m = SquareRooter(4)
     test_controls = [Step(2), InvalidValue(), GL(), Callback(println)]