From db400820bf00ec3adb9f972067d5d4b2f613206f Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Fri, 12 Nov 2021 10:55:55 +1300 Subject: [PATCH 1/4] needs_loss, needs_training_losses <- EarlyStopping; fix logic Needs EarlyStopping.jl version 0.3 - Adds the traits `needs_loss`, `needs_training_losses` that were removed from EarlyStopping. - Changes the logic for handling these traits, to correct unexpected behaviour discovered in MLJIteration.jl; see [this issue](https://github.com/JuliaAI/MLJIteration.jl/issues/36). The traits have a slightly new interpretation: `train!` now throws an error if the trait is `true` for some control provided but `loss`/`training_losses` has not been overloaded for the model. update readme --- README.md | 69 +++++++++++++++++--------------- src/api.jl | 21 ++++++++-- src/composite_controls.jl | 21 +++++++--- src/controls.jl | 4 +- src/stopping_controls.jl | 81 +++++++++++++------------------------- src/train.jl | 29 ++++++++++---- test/api.jl | 17 ++++---- test/composite_controls.jl | 20 ++++++++++ test/stopping_controls.jl | 17 +------- test/train.jl | 9 +++++ 10 files changed, 161 insertions(+), 127 deletions(-) diff --git a/README.md b/README.md index 9521825..865d98c 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ package: using IterationControl IterationControl.train!(model::SquareRooter, n) = train!(model, n) # lifting ``` -By definitiion, the lifted `train!` has the same functionality as the original one: +By definition, the lifted `train!` has the same functionality as the original one: ```julia model = SquareRooter(9) @@ -76,7 +76,7 @@ julia> IterationControl.train!(model, Step(2), NumberLimit(3), Info(m->m.root)); Here each control is repeatedly applied in sequence until one of them triggers a stop. 
The first control `Step(2)` says, "Train the model two more iterations"; the second asks, "Have I been applied 3 times -yet?", signalling a stop (at the end of the current control cycle) if +yet?", signaling a stop (at the end of the current control cycle) if so; and the third logs the value of the function `m -> m.root`, evaluated on `model`, to `Info`. In this example only the second control can terminate model iteration. @@ -85,7 +85,7 @@ If `model` admits a method returning a loss (in this case the difference between `x` and the square of `root`) then we can lift that method to `IterationControl.loss` to enable control using loss-based stopping criteria, such as a loss threshold. In the -demonstation below, we also include a callback: +demonstration below, we also include a callback: ```julia model = SquareRooter(4) @@ -100,9 +100,9 @@ losses = Float64[] callback(model) = push!(losses, loss(model)) julia> IterationControl.train!(model, - Step(1), - Threshold(0.0001), - Callback(callback)); + Step(1), + Threshold(0.0001), + Callback(callback)); [ Info: Stop triggered by Threshold(0.0001) stopping criterion. julia> losses @@ -111,7 +111,7 @@ julia> losses 3.716891878724482e-7 ``` -In many appliations to machine learning, "loss" will be an +In many applications to machine learning, "loss" will be an out-of-sample loss, computed after some iterations. If `model` additionally generates user-inspectable "training losses" (one per iteration) then similarly lifting the appropriate access function to @@ -169,12 +169,12 @@ julia> last(reports[2]) julia> last(reports[2]).loss 0.1417301038062284 ``` - + ## Controls provided Controls are repeatedly applied in sequence until a control triggers a -stop. Each control type has a detailed doc-string. sBelow is a short +stop. Each control type has a detailed doc-string. Below is a short summary, with some advanced options omitted. 
control | description | enabled if these are overloaded | can trigger a stop | notation in Prechelt @@ -216,8 +216,6 @@ wrapper | description > Table 2. Wrapped controls - - ## Access to model through a wrapper Note that functions ordinarily applied to `model` by some control @@ -229,10 +227,10 @@ appropriately overloaded. ## Implementing new controls There is no abstract control type; any object can be a -control. Behaviour is implemented using a functional style interface -with four methods. Only the first two are compulsory (the `done` and -`takedown` fallbacks always return `false` and `NamedTuple()` -respectively.): +control. Behavior is implemented using a functional style interface +with six methods. Only the first two are compulsory (the fallbacks for +`done`, `takedown`, `needs_loss` and `needs_training_losses` always +return `false` and `NamedTuple()` respectively.): ```julia update!(control, model, verbosity, n) -> state # initialization @@ -242,31 +240,40 @@ takedown(control, verbosity, state) -> human_readable_named_tuple ``` Here `n` is the control cycle count, i.e., one more than the the -number of completed control cylcles. +number of completed control cycles. + +If it is nonsensical to apply `control` to any model for which +`loss(model)` has not been overloaded, and we want an error thrown +when this is attempted, then declare `needs_loss(control::MyControl) = +true` to take value true. Otherwise `control` is applied anyway, and +`loss`, if called, returns `nothing`. + +A second trait `needs_training_losses(control)` serves an analogous +purpose for training losses. Here's how `IterationControl.train!` calls these methods: ```julia function train!(model, controls...; verbosity::Int=1) - control = composite(controls...) + control = composite(controls...) - # before training: - verbosity > 1 && @info "Using these controls: $(flat(control)). " + # before training: + verbosity > 1 && @info "Using these controls: $(flat(control)). 
" - # first training event: - n = 1 # counts control cycles - state = update!(control, model, verbosity, n) - finished = done(control, state) + # first training event: + n = 1 # counts control cycles + state = update!(control, model, verbosity, n) + finished = done(control, state) - # subsequent training events: - while !finished - n += 1 - state = update!(control, model, verbosity, n, state) - finished = done(control, state) - end + # subsequent training events: + while !finished + n += 1 + state = update!(control, model, verbosity, n, state) + finished = done(control, state) + end - # finalization: - return takedown(control, verbosity, state) + # finalization: + return takedown(control, verbosity, state) end ``` diff --git a/src/api.jl b/src/api.jl index d980786..746144e 100644 --- a/src/api.jl +++ b/src/api.jl @@ -1,5 +1,5 @@ # ------------------------------------------------------------------- -# # MACHINE METHODS +# # MODEL METHODS # *models* are externally defined objects with certain functionality # that is exposed by overloading these methods: @@ -21,7 +21,7 @@ err_ingest(model) = train!(model, Δn) = throw(err_train(model)) -# ## REQUIRED FOR SOME CONTROL +# ## REQUIRED FOR SOME CONTROLS # extract a single numerical estimate of `models`'s performance such # as an out-of-sample loss; smaller is understood to be better: @@ -51,7 +51,7 @@ expose(model, ::Val{false}) = expose(model) # # CONTROL METHODS # compulsory: `update!` -# optional: `done`, `takedown` +# optional: `done`, `takedown`, `needs_training_losses` # called after first training event; returns initialized control # "state": @@ -63,10 +63,23 @@ update!(control, model, verbosity, n, state) = state # should we stop? done(control, state) = false -# What to do after this control, or some other control, has triggered +# What to do after `control`, or some other control, has triggered # a stop. Returns user-inspectable outcomes associated with the # control's applications (separate from logging). 
This should be a # named tuple, except for composite controls which return a tuple of # named tuples (see composite_controls.jl): takedown(control, verbosity, state) = NamedTuple() +# If it is nonsensical to apply `control` to any model for which +# `training_losses(model)` has not been overloaded and we want an +# error thrown when this is attempted, then overload this trait to +# take the value `true`. Otherwise `control` is applied anyway, and +# `training_losses`, if called, returns `nothing`. +needs_training_losses(control) = false + +# If it is nonsensical to apply `control` to any model for which +# `loss(model)` has not been overloaded and we want an error thrown +# when this is attempted, then overload this trait to take the value +# `true`. Otherwise `control` is applied anyway, and +# `loss`, if called, returns `nothing`. +needs_loss(control) = false diff --git a/src/composite_controls.jl b/src/composite_controls.jl index 3a2a2e6..f674849 100644 --- a/src/composite_controls.jl +++ b/src/composite_controls.jl @@ -1,4 +1,4 @@ -struct CompositeControl{A,B} +struct CompositeControl{A,B} a::A b::B function CompositeControl(a::A, b::B) where {A, B} @@ -29,7 +29,7 @@ update!(c::CompositeControl, m, v, n, state) = b = update!(c.b, m, v, n, state.b)) -## RECURSION TO FLATTEN A CONTROL OR ITS STATE +# # RECURSION TO FLATTEN A CONTROL OR ITS STATE flat(state) = (state,) flat(state::NamedTuple{(:a,:b)}) = tuple(flat(state.a)..., flat(state.b)...) 
@@ -42,7 +42,7 @@ _in(c::Any, d::CompositeControl) = c in flat(d) _in(::CompositeControl, ::Any) = false -## DISPLAY +# # DISPLAY function Base.show(io::IO, c::CompositeControl) list = join(string.(flat(c)), ", ") @@ -50,11 +50,11 @@ function Base.show(io::IO, c::CompositeControl) end -## RECURSION TO DEFINE `done` +# # RECURSION TO DEFINE `done` # fallback for atomic controls: _done(control, state, old_done) = old_done || done(control, state) - + # composite: _done(c::CompositeControl, state, old_done) = _done(c.a, state.a, _done(c.b, state.b, old_done)) @@ -62,7 +62,7 @@ _done(c::CompositeControl, state, old_done) = done(c::CompositeControl, state) = _done(c, state, false) -## RECURSION TO DEFINE `takedown` +# # RECURSION TO DEFINE `takedown` # fallback for atomic controls: function _takedown(control, v, state, old_takedown) @@ -76,3 +76,12 @@ _takedown(c::CompositeControl, v, state, old_takedown) = _takedown(c.a, v, state.a, old_takedown)) takedown(c::CompositeControl, v, state) = _takedown(c, v, state, ()) + + +# # TRAITS + +for ex in [:needs_loss, :needs_training_losses] + quote + $ex(c::CompositeControl) = any($ex, flat(c)) + end |> eval +end diff --git a/src/controls.jl b/src/controls.jl index 94bd111..5750b13 100644 --- a/src/controls.jl +++ b/src/controls.jl @@ -291,7 +291,7 @@ WithLossDo(; f=x->@info("loss: $x"), kwargs...) = WithLossDo(f, kwargs...) "if the value returned by `f` is `true`, logging the "* "`stop_message` if specified. ") -EarlyStopping.needs_loss(::Type{<:WithLossDo}) = true +needs_loss(::WithLossDo) = true function update!(c::WithLossDo, model, @@ -347,7 +347,7 @@ WithTrainingLossesDo(; f=v->@info("training: $v"), kwargs...) = "if the value returned by `f` is `true`, logging the "* "`stop_message` if specified. 
") -EarlyStopping.needs_training_losses(::Type{<:WithTrainingLossesDo}) = true +needs_training_losses(::WithTrainingLossesDo) = true function update!(c::WithTrainingLossesDo, model, diff --git a/src/stopping_controls.jl b/src/stopping_controls.jl index 712cf37..867fe6a 100644 --- a/src/stopping_controls.jl +++ b/src/stopping_controls.jl @@ -2,72 +2,49 @@ # `StoppingCriterion`objects are defined in EarlyStopping.jl +# non-wrapping stopping criteria that are nonsensical to apply if +# `IterationControl.training_losses(model)` is not overloaded: +const ATOMIC_CRIITERIA_NEEDING_TRAINING_LOSSES = [:PQ, ] +const ATOMIC_CRIITERIA_NEEDING_LOSS = [:Threshold, :GL, :PQ, :Patience] -# ## LOSS GETTERS +# stopping criterion that wrap a single stopping criterion (must have +# `:criterion` as a field): +const EARLY_STOPPING_WRAPPERS = [:Warmup, ] -# `get_loss(control, model)` throws an error if control needs -# `IC.loss` overloaded for `type(model)` and it has not been so -# overloaded. If `control` does not need `IC.loss`, then `nothing` is -# returned. In the other cases, the sought after loss is -# returned. `get_training_losses` is similarly defined. -err_getter(c, f, model) = - ArgumentError("Use of `$c` control here requires that "* - "`IterationControl.$f(model)` be "* - "overloaded for `typeof(model)=$(typeof(model))`. 
") - -for f in [:loss, :training_losses] - g = Symbol(string(:get_, f)) - t = Symbol(string(:needs_, f)) - fstr = string(f) - eval(quote - $g(c, model) = $g(c, model, Val(ES.$t(c))) - $g(c, model, ::Val{false}) = nothing - @inline function $g(c, model, ::Val{true}) - it = $f(model) - it isa Nothing && throw(err_getter(c, $fstr, model)) - return it - end - end) +for ex in ATOMIC_CRIITERIA_NEEDING_LOSS + quote + needs_loss(::$ex) = true + end |> eval end +for ex in ATOMIC_CRIITERIA_NEEDING_TRAINING_LOSSES + quote + needs_training_losses(::$ex) = true + end |> eval +end -# ## API IMPLEMENTATION - -function update!(c::StoppingCriterion, - model, - verbosity, - n) - _loss = get_loss(c, model) - _training_losses = get_training_losses(c, model) - if _training_losses === nothing || isempty(_training_losses) - state = ES.update(c, _loss) - else # first consume all training losses, then update! loss: - state = ES.update_training(c, first(_training_losses)) - for tloss in _training_losses[2:end] - state = ES.update_training(c, tloss, state) - end - state = ES.update(c, _loss, state) - end - return state +for ex in EARLY_STOPPING_WRAPPERS + quote + needs_loss(wrapper::$ex) = + needs_loss(wrapper.criterion) + needs_training_losses(wrapper::$ex) = + needs_training(wrapper.criterion) + end |> eval end -# regular update!: function update!(c::StoppingCriterion, model, verbosity, - n, - state) - _loss = get_loss(c, model) - _training_losses = get_training_losses(c, model) - if _training_losses === nothing || isempty(_training_losses) - state = ES.update(c, _loss, state) - else # first consume all training losses, then update! 
loss: + n, state=nothing) + _loss = loss(model) + _training_losses = training_losses(model) + if _training_losses !== nothing && !isempty(_training_losses) for tloss in _training_losses state = ES.update_training(c, tloss, state) end - state = ES.update(c, _loss, state) end + state = ES.update(c, _loss, state) return state end @@ -80,5 +57,3 @@ function takedown(c::StoppingCriterion, verbosity, state) return (done = false, log = "") end end - - diff --git a/src/train.jl b/src/train.jl index 2c54544..3b291b9 100644 --- a/src/train.jl +++ b/src/train.jl @@ -1,11 +1,18 @@ const ERR_TRAIN = ArgumentError("`IterationControl.train!` needs at "* - "least two arguments. ") + "least two arguments. ") + +const ERR_NEEDS_LOSS = ArgumentError( + "Encountered a control that needs losses but no losses found. ") + +const ERR_NEEDS_TRAINING_LOSSES = ArgumentError( + "Encountered a control that needs training losses but no training "* + "losses found. ") function train!(model, controls...; verbosity::Int=1) isempty(controls) && throw(ERR_TRAIN) - control = CompositeControl(controls...) + control = composite(controls...) # before training: verbosity > 1 && @info "Using these controls: $(flat(control)). 
" @@ -15,6 +22,14 @@ function train!(model, controls...; verbosity::Int=1) state = update!(control, model, verbosity, n) finished = done(control, state) + # checks that model supports control: + if needs_loss(control) && loss(model) === nothing + throw(ERR_NEEDS_LOSS) + end + if needs_training_losses(control) && training_losses(model) === nothing + throw(ERR_NEEDS_TRAINING_LOSSES) + end + # subsequent training events: while !finished n += 1 @@ -23,12 +38,12 @@ function train!(model, controls...; verbosity::Int=1) end # reporting final loss and training loss if available: - loss = IterationControl.loss(model) - training_losses = IterationControl.training_losses(model) + _loss = IterationControl.loss(model) + _training_losses = IterationControl.training_losses(model) if verbosity > 0 - loss isa Nothing || @info "final loss: $loss" - training_losses isa Nothing || isempty(training_losses) || - @info "final training loss: $(training_losses[end])" + _loss isa Nothing || @info "final loss: $_loss" + _training_losses isa Nothing || isempty(_training_losses) || + @info "final training loss: $(_training_losses[end])" verbosity > 1 && @info "total control cycles: $n" end diff --git a/test/api.jl b/test/api.jl index 62b566f..1670713 100644 --- a/test/api.jl +++ b/test/api.jl @@ -1,5 +1,6 @@ model = Particle() -invalid = InvalidValue() +withloss = WithLossDo(x-> nothing) +step = Step(1) @test_throws IC.ERR_TRAIN IterationControl.train!(model) @test_throws IC.err_train(model) IterationControl.train!(model, 1) @@ -7,21 +8,21 @@ invalid = InvalidValue() # lifting train!: IC.train!(model::Particle, n) = train!(model, n) -@test_throws(IC.err_getter(invalid, :loss, model), - IC.train!(model, invalid, NumberLimit(1))) +@test_throws(IC.ERR_NEEDS_LOSS, + IC.train!(model, step, withloss, NumberLimit(1))) # lifting loss!: IterationControl.loss(m::Particle) = loss(m) -IC.train!(model, invalid, NumberLimit(1), verbosity=0) +@test_logs IC.train!(model, step, withloss, NumberLimit(1), 
verbosity=0) -@test_throws(IC.err_getter(PQ(), :training_losses, model), - IC.train!(model, PQ(), NumberLimit(1))) +@test_throws(IC.ERR_NEEDS_TRAINING_LOSSES, + IC.train!(model, step, PQ(), NumberLimit(1))) # lifting training_losses: IterationControl.training_losses(m::Particle) = training_losses(m) -IC.train!(model, Step(2), PQ(), NumberLimit(1), verbosity=0) +@test_logs IC.train!(model, Step(2), PQ(), NumberLimit(1), verbosity=0) @test_throws(IC.err_ingest(model), IC.train!(model, Data(1:2), NumberLimit(1))) @@ -29,4 +30,4 @@ IC.train!(model, Step(2), PQ(), NumberLimit(1), verbosity=0) #lifting ingest!: IC.ingest!(model::Particle, datum) = ingest!(model, datum) -IC.train!(model, Data(1:1), NumberLimit(1), verbosity=0); +@test_logs IC.train!(model, Step(2), Data(1:1), NumberLimit(1), verbosity=0) diff --git a/test/composite_controls.jl b/test/composite_controls.jl index 873892e..5b2cf76 100644 --- a/test/composite_controls.jl +++ b/test/composite_controls.jl @@ -47,7 +47,27 @@ end @test done_d == done_a || done_b || done_c report_d = IC.takedown(d, 0, state_d2) @test report_d == ((a, report_a), (b, report_b), (c, report_c)) +end + +@testset "traits" begin + needs_nothing = NumberLimit(4) + needs_loss = Threshold(0.01) + needs_training = WithTrainingLossesDo() + needs_both = PQ() + + @test !IC.needs_loss(IC.composite(needs_nothing)) + @test !IC.needs_loss(IC.composite(needs_nothing, needs_training)) + @test IC.needs_loss(IC.composite(needs_nothing, needs_loss)) + @test IC.needs_loss(IC.composite(needs_nothing, needs_both)) + @test IC.needs_loss(IC.composite(needs_loss)) + @test IC.needs_loss(IC.composite(needs_loss, needs_nothing)) + @test !IC.needs_training_losses(IC.composite(needs_nothing)) + @test IC.needs_training_losses(IC.composite(needs_nothing, needs_training)) + @test !IC.needs_training_losses(IC.composite(needs_nothing, needs_loss)) + @test IC.needs_training_losses(IC.composite(needs_nothing, needs_both)) + @test 
IC.needs_training_losses(IC.composite(needs_training)) + @test !IC.needs_training_losses(IC.composite(needs_loss, needs_nothing)) end true diff --git a/test/stopping_controls.jl b/test/stopping_controls.jl index b215f46..8024860 100644 --- a/test/stopping_controls.jl +++ b/test/stopping_controls.jl @@ -1,20 +1,6 @@ -@testset "loss getters" begin - model=SquareRooter(4) - @test IterationControl.get_loss(InvalidValue(), model) == - IterationControl.loss(model) - @test_throws(IterationControl.err_getter(InvalidValue(), :loss, :junk), - IterationControl.get_loss(InvalidValue(), :junk)) - IterationControl.train!(model, 2) - @test IterationControl.get_training_losses(PQ(), model) == - IterationControl.training_losses(model) - @test_throws( - IterationControl.err_getter(PQ(), :training_losses, :junk), - IterationControl.get_training_losses(PQ(), :junk)) -end - @testset "stopping criteria as controls" begin - # A stopping criterion than ignores training losses: + # A stopping criterion that ignores training losses: m = SquareRooter(4) c = NumberLimit(2) @@ -56,5 +42,4 @@ end report = IC.takedown(c, 1, state) @test !report.done @test report.log == "" - end diff --git a/test/train.jl b/test/train.jl index e5743ff..43cac34 100644 --- a/test/train.jl +++ b/test/train.jl @@ -109,3 +109,12 @@ end s.training_losses == reverse(tlosses[j-L+1:j]) end |> all end + +# https://github.com/JuliaAI/MLJIteration.jl/issues/36 +@testset "integration test related to MLJIteration.jl #36" begin + model = SquareRooter(NaN) + report = + IC.train!(model, Step(1), InvalidValue(), NumberLimit(3), verbosity=0) + @test report[1][2].new_iterations == 1 + @test occursin("Stopping", report[2][2].log) +end From b3c49b240c11b454e7247f6f88725cafbd1e5b67 Mon Sep 17 00:00:00 2001 From: "Anthony D. 
Blaom" Date: Fri, 12 Nov 2021 11:23:26 +1300 Subject: [PATCH 2/4] bump EarlyStopping compat = "0.3" --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 5c18d6b..9a70d27 100644 --- a/Project.toml +++ b/Project.toml @@ -8,7 +8,7 @@ EarlyStopping = "792122b4-ca99-40de-a6bc-6742525f08b6" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" [compat] -EarlyStopping = "0.2" +EarlyStopping = "0.3" julia = "1" [extras] From b51691c2eea7ccda59be270d5152d5611bb23fe1 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Fri, 12 Nov 2021 11:30:03 +1300 Subject: [PATCH 3/4] update readme --- README.md | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 865d98c..1068b92 100644 --- a/README.md +++ b/README.md @@ -251,7 +251,8 @@ true` to take value true. Otherwise `control` is applied anyway, and A second trait `needs_training_losses(control)` serves an analogous purpose for training losses. -Here's how `IterationControl.train!` calls these methods: +Here's a simplified version of how `IterationControl.train!` calls +these methods: ```julia function train!(model, controls...; verbosity::Int=1) @@ -265,6 +266,14 @@ function train!(model, controls...; verbosity::Int=1) n = 1 # counts control cycles state = update!(control, model, verbosity, n) finished = done(control, state) + + # checks that model supports control: + if needs_loss(control) && loss(model) === nothing + throw(ERR_NEEDS_LOSS) + end + if needs_training_losses(control) && training_losses(model) === nothing + throw(ERR_NEEDS_TRAINING_LOSSES) + end # subsequent training events: while !finished From fe416ae758ec341d45e81f9995387b375f3cac1b Mon Sep 17 00:00:00 2001 From: "Anthony D. 
Blaom" Date: Fri, 12 Nov 2021 12:13:07 +1300 Subject: [PATCH 4/4] bump 0.5.1 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 9a70d27..091f6aa 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "IterationControl" uuid = "b3c1a2ee-3fec-4384-bf48-272ea71de57c" authors = ["Anthony D. Blaom "] -version = "0.5.0" +version = "0.5.1" [deps] EarlyStopping = "792122b4-ca99-40de-a6bc-6742525f08b6"