diff --git a/Project.toml b/Project.toml index 889b85f..8366021 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "IterationControl" uuid = "b3c1a2ee-3fec-4384-bf48-272ea71de57c" authors = ["Anthony D. Blaom "] -version = "0.5.2" +version = "0.5.3" [deps] EarlyStopping = "792122b4-ca99-40de-a6bc-6742525f08b6" diff --git a/src/controls.jl b/src/controls.jl index 5231860..aee2699 100644 --- a/src/controls.jl +++ b/src/controls.jl @@ -152,11 +152,11 @@ struct Callback{F<:Function} end # constructor: -Callback(f::Function; +Callback(; f=identity, stop_if_true=false, stop_message=nothing, raw=false) = Callback(f, stop_if_true, stop_message, raw) -Callback(; f=identity, kwargs...) = Callback(f, kwargs...) +Callback(f; kwargs...) = Callback(; f=f, kwargs...) @create_docs(Callback, header="Callback(f=_->nothing, stop_if_true=false, "* @@ -275,10 +275,10 @@ struct WithLossDo{F<:Function} end # constructor: -WithLossDo(f::Function; - stop_if_true=false, - stop_message=nothing) = WithLossDo(f, stop_if_true, stop_message) -WithLossDo(; f=x->@info("loss: $x"), kwargs...) = WithLossDo(f, kwargs...) +WithLossDo(; f=x->@info("loss: $x"), + stop_if_true=false, + stop_message=nothing) = WithLossDo(f, stop_if_true, stop_message) +WithLossDo(f; kwargs...) = WithLossDo(; f=f, kwargs...) @create_docs(WithLossDo, header="WithLossDo(f=x->@info(\"loss: \$x\"), "* @@ -330,11 +330,12 @@ struct WithTrainingLossesDo{F<:Function} end # constructor: -WithTrainingLossesDo(f::Function; - stop_if_true=false, - stop_message=nothing) = WithTrainingLossesDo(f, stop_if_true, stop_message) -WithTrainingLossesDo(; f=v->@info("training: $v"), kwargs...) = - WithTrainingLossesDo(f, kwargs...) +WithTrainingLossesDo(; f=v->@info("training: $v"), + stop_if_true=false, + stop_message=nothing) = + WithTrainingLossesDo(f, stop_if_true, stop_message) +WithTrainingLossesDo(f; kwargs...) = + WithTrainingLossesDo(; f=f, kwargs...) @create_docs(WithTrainingLossesDo, header="WithTrainingLossesDo(f=v->@info(\"training: \$v\"), "* @@ -388,10 +389,10 @@ struct WithNumberDo{F<:Function} end # constructor: -WithNumberDo(f::Function; - stop_if_true=false, - stop_message=nothing) = WithNumberDo(f, stop_if_true, stop_message) -WithNumberDo(; f=n->@info("number: $n"), kwargs...) = WithNumberDo(f, kwargs...) +WithNumberDo(; f=n->@info("number: $n"), + stop_if_true=false, + stop_message=nothing) = WithNumberDo(f, stop_if_true, stop_message) +WithNumberDo(f; kwargs...) = WithNumberDo(; f=f, kwargs...) @create_docs(WithNumberDo, header="WithNumberDo(f=n->@info(\"number: \$n\"), "* diff --git a/test/controls.jl b/test/controls.jl index 319478c..8867e48 100644 --- a/test/controls.jl +++ b/test/controls.jl @@ -441,3 +441,28 @@ end @test length(losses) == length(data) + 1 @test loss(model) > 0.01 end + +@testset "constructors #55" begin + for Control in [WithTrainingLossesDo, + Callback, + WithLossDo, + WithNumberDo] + g(x) = true + + c = Control(g) + @test c.f == g + @test !c.stop_if_true + + c= Control(g, stop_if_true=true) + @test c.f == g + @test c.stop_if_true + + c= Control(f=g) + @test c.f == g + @test !c.stop_if_true + + c= Control(f=g, stop_if_true=true) + @test c.f == g + @test c.stop_if_true + end +end