Skip to content

Commit

Permalink
Merge pull request #42 from JuliaAI/dev
Browse files Browse the repository at this point in the history
For a 0.3.3 release
  • Loading branch information
ablaom authored Apr 19, 2021
2 parents ed77d90 + ce69263 commit e7a562e
Show file tree
Hide file tree
Showing 14 changed files with 264 additions and 117 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "IterationControl"
uuid = "b3c1a2ee-3fec-4384-bf48-272ea71de57c"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.3.2"
version = "0.3.3"

[deps]
EarlyStopping = "792122b4-ca99-40de-a6bc-6742525f08b6"
Expand Down
1 change: 1 addition & 0 deletions src/IterationControl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,6 @@ include("composite_controls.jl")
include("wrapped_controls.jl")
include("controls.jl")
include("train.jl")
include("square_rooter.jl")

end # module
2 changes: 1 addition & 1 deletion src/composite_controls.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
struct CompositeControl{A,B}
struct CompositeControl{A,B}
a::A
b::B
function CompositeControl(a::A, b::B) where {A, B}
Expand Down
81 changes: 47 additions & 34 deletions src/controls.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# # TRAIN
# # Step

struct Step
n::Int
Expand All @@ -13,13 +13,19 @@ Step(; n=5) = Step(n)
body="Train for `n` more iterations. "*
"Will never trigger a stop. ")

function update!(c::Step, model, verbosity, args...)
if verbosity > 1
@info "Stepping model for $(c.n) iterations. "
else
nothing
end
# Apply a `Step` control: train `model` for `c.n` further iterations,
# accumulating the running count of iterations added in the control's state.
function update!(c::Step, model, verbosity, state=(new_iterations = 0,))
    if verbosity > 1
        @info "Stepping model for $(c.n) more iterations. "
    end
    train!(model, c.n)
    # carry forward the total number of iterations this control has added:
    total = state.new_iterations + c.n
    return (new_iterations = total,)
end

# Finalize a `Step` control: optionally report the iteration total and
# return the state unchanged.
@inline function takedown(c::Step, verbosity, state)
    if verbosity > 1
        @info "A total of $(state.new_iterations) iterations added. "
    end
    return state
end

# # Info
Expand All @@ -34,7 +40,7 @@ Info(; f::Function=identity) = Info(f)
@create_docs(Info,
header="Info(f=identity)",
example="Info(my_loss_function)",
body="Log at the `Info` level the value of `f(m)`, "*
body="Log to `Info` the value of `f(m)`, "*
"where `m` "*
"is the object being iterated. If "*
"`IterativeControl.expose(m)` has been overloaded, then "*
Expand All @@ -44,12 +50,12 @@ Info(; f::Function=identity) = Info(f)
"See also [`Warn`](@ref), [`Error`](@ref). ")

# Apply an `Info` control: log the value of `c.f(model)` (formatted by
# `_log_eval`) at `Info` level, unless `verbosity` is non-positive.
# Stateless; returns `nothing`.
#
# Fix: the scraped diff contained both the pre-change condition
# (`verbosity < 1 || ...`) and the post-change condition
# (`verbosity > 0 && ...`) as two live statements, which would log the
# message twice; only the post-change form is kept.
function update!(c::Info, model, verbosity, args...)
    verbosity > 0 && @info _log_eval(c.f, model)
    return nothing
end


# # WARN
# # Warn

struct Warn{P<:Function,F<:Union{Function,String}}
predicate::P
Expand All @@ -64,35 +70,35 @@ Warn(predicate; f="") = Warn(predicate, f)
example="Warn(m -> length(m.cache) > 100, "*
"f=\"Memory low\")",
body="If `predicate(m)` is `true`, then "*
"log at the `Warn` level the value of `f` "*
"log to `Warn` the value of `f` "*
"(or `f(IterationControl.expose(m))` if `f` is a function). "*
"Here `m` "*
"is the object being iterated.\n\n"*
"Can be suppressed by setting the global verbosity level "*
"sufficiently low.\n\n"*
"See also [`Info`](@ref), [`Error`](@ref). ")

# Apply a `Warn` control (stateless variant): when the predicate holds,
# log the warning text, suppressed unless verbosity exceeds 1.
function update!(c::Warn, model, verbosity, args...)
    if verbosity > 1 && c.predicate(model)
        @warn _log_eval(c.f, model)
    end
    return nothing
end

function update!(c::Warn, model, verbosity, warnings=())
# Apply a `Warn` control: when the predicate holds, log the warning
# (suppressed only when `verbosity <= -1`) and append it to the tuple of
# warnings accumulated in the control's state.
#
# Fix: the scraped diff interleaved the pre-change lines (old warn
# condition, bare-tuple state, `return state`) with the post-change lines;
# only the post-change version is kept.
function update!(c::Warn, model, verbosity, state=(warnings=(),))
    warnings = state.warnings
    if c.predicate(model)
        warning = _log_eval(c.f, model)
        verbosity > -1 && @warn warning
        newstate = (warnings = tuple(warnings..., warning),)
    else
        newstate = state
    end
    return newstate
end

takedown(c::Warn, verbosity, state) = (warnings = state,)
# Finalize a `Warn` control: at verbosity above 1, re-issue all accumulated
# warnings as one summary message; the state is returned unchanged.
function takedown(c::Warn, verbosity, state)
    warnings = join(state.warnings, "\n")
    if verbosity > 1 && !isempty(warnings)
        @warn "A `Warn` control issued these warnings:\n$warnings"
    end
    return state
end


# # ERROR
# # Error

struct Error{P<:Function,F<:Union{Function,String}}
predicate::P
Expand Down Expand Up @@ -286,24 +292,28 @@ WithLossDo(; f=x->@info("loss: $x"), kwargs...) = WithLossDo(f, kwargs...)

EarlyStopping.needs_loss(::Type{<:WithLossDo}) = true

# Apply a `WithLossDo` control: fetch the model's current loss, call the
# user callback `c.f` on it, and record both the loss and whether a stop
# was requested (callback returned `true` AND `stop_if_true` is set).
#
# Fix: the scraped diff retained the old single-line signature and the old
# `return (done=done,)` alongside the new version; only the post-change
# version is kept. The redundant `? true : false` ternary is also dropped.
function update!(c::WithLossDo,
                 model,
                 verbosity,
                 state=(loss=nothing, done=false))
    loss = IterationControl.loss(model)
    r = c.f(loss)
    # stop only when stopping is enabled and the callback returned `true`:
    done = c.stop_if_true && r isa Bool && r
    return (loss=loss, done=done)
end

done(c::WithLossDo, state) = state.done

# Finalize a `WithLossDo` control: optionally report the final loss, log the
# stop message when a stop was triggered, and return the state augmented
# with a `log` entry.
#
# Fix: the scraped diff retained the old `return (done = ..., log = ...)`
# lines immediately before the new merge-based returns (unreachable
# duplicates); only the post-change returns are kept.
function takedown(c::WithLossDo, verbosity, state)
    verbosity > 1 && @info "final loss: $(state.loss). "
    if state.done
        # prefer a user-supplied stop message when one was given:
        message = c.stop_message === nothing ?
            "Stop triggered by a `WithLossDo` control. " :
            c.stop_message
        verbosity > 0 && @info message
        return merge(state, (log = message,))
    else
        return merge(state, (log = "",))
    end
end

Expand Down Expand Up @@ -340,24 +350,26 @@ EarlyStopping.needs_training_losses(::Type{<:WithTrainingLossesDo}) = true
# Apply a `WithTrainingLossesDo` control: fetch the model's training losses,
# call the user callback `c.f` on them, and record the most recent training
# loss plus whether a stop was requested.
#
# Fix: the scraped diff retained the old state default and the old
# `return (done=done, )` alongside the new version; only the post-change
# version is kept. `losses[end]` is additionally guarded against an empty
# loss vector, consistent with the `nothing` default in the initial state.
function update!(c::WithTrainingLossesDo,
                 model,
                 verbosity,
                 state=(latest_training_loss = nothing, done = false))
    losses = IterationControl.training_losses(model)
    r = c.f(losses)
    done = c.stop_if_true && r isa Bool && r
    # avoid a BoundsError when no losses have been recorded yet:
    latest = isempty(losses) ? nothing : losses[end]
    return (latest_training_loss=latest, done=done)
end

done(c::WithTrainingLossesDo, state) = state.done

# Finalize a `WithTrainingLossesDo` control: optionally report the last
# training loss, log the stop message when a stop was triggered, and return
# the state augmented with a `log` entry.
#
# Fix: the scraped diff retained the old `return (done = ..., log = ...)`
# lines before the new merge-based returns; only the post-change returns
# are kept.
function takedown(c::WithTrainingLossesDo, verbosity, state)
    verbosity > 1 &&
        @info "final training loss: $(state.latest_training_loss). "
    if state.done
        message = c.stop_message === nothing ?
            "Stop triggered by a `WithTrainingLossesDo` control. " :
            c.stop_message
        verbosity > 0 && @info message
        return merge(state, (log = message,))
    else
        return merge(state, (log = "",))
    end
end

Expand Down Expand Up @@ -398,13 +410,14 @@ end
done(c::WithNumberDo, state) = state.done

# Finalize a `WithNumberDo` control: optionally report the final counter,
# log the stop message when a stop was triggered, and return the state
# augmented with a `log` entry.
#
# Fix: the scraped diff retained the old `return (done = ..., log = ...)`
# lines before the new merge-based returns; only the post-change returns
# are kept.
function takedown(c::WithNumberDo, verbosity, state)
    verbosity > 1 && @info "final number: $(state.n). "
    if state.done
        message = c.stop_message === nothing ?
            "Stop triggered by a `WithNumberDo` control. " :
            c.stop_message
        verbosity > 0 && @info message
        return merge(state, (log = message,))
    else
        return merge(state, (log = "",))
    end
end
22 changes: 22 additions & 0 deletions src/square_rooter.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# ## SQUARE ROOTER

# Consider a model to compute Babylonian approximations to a square root:

# Toy iterative model: computes Babylonian (Newton) approximations to √x.
mutable struct SquareRooter
    x::Float64                       # input - number to be square rooted
    root::Float64                    # current approximation of root
    training_losses::Vector{Float64} # successive approximation differences
    function SquareRooter(x)
        # start from the initial guess 1.0 with no recorded losses:
        return new(x, 1.0, Float64[])
    end
end

# Train for Δn Babylonian iterations, recording each successive change in
# the root estimate as a "training loss". Previous losses are discarded at
# the start of every call.
function IterationControl.train!(m::SquareRooter, Δn::Int)
    m.training_losses = Float64[]
    for _ in 1:Δn
        candidate = (m.root + m.x/m.root)/2
        push!(m.training_losses, abs(candidate - m.root))
        m.root = candidate
    end
end

# Lifted IterationControl API for the toy model: the loss is the residual
# |root² − x|; training losses are those recorded by `train!`.
function IterationControl.loss(m::SquareRooter)
    return abs(m.root^2 - m.x)
end
function IterationControl.training_losses(m::SquareRooter)
    return m.training_losses
end
9 changes: 9 additions & 0 deletions src/train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@ function train!(model, controls...; verbosity::Int=1)
finished = done(control, state)
end

# reporting final loss and training loss if available:
loss = IterationControl.loss(model)
training_losses = IterationControl.training_losses(model)
if verbosity > 0
loss isa Nothing || @info "final loss: $loss"
training_losses isa Nothing ||
@info "final training loss: $(training_losses[end])"
end

# finalization:
return takedown(control, verbosity, state)
end
25 changes: 25 additions & 0 deletions src/wrapped_controls.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,28 @@
# # Louder

# Wrapper that makes a control behave as if the global verbosity were
# shifted; see `louder`.
struct Louder{C}
    control::C # the wrapped control
    by::Int64  # amount added to the verbosity passed to the wrapped control
end

"""
IterationControl.louder(control, by=1)
Wrap `control` to make in more (or less) verbose. The same as
`control`, but as if the global `verbosity` were increased by the value
`by`.
"""
louder(c; by=1) = Louder(c, by)

# api:
# api: every method delegates to the wrapped control, with `update!` and
# `takedown` seeing the verbosity boosted by `d.by`.
done(d::Louder, state) = done(d.control, state)
function update!(d::Louder, model, verbosity, args...)
    return update!(d.control, model, verbosity + d.by, args...)
end
function takedown(d::Louder, verbosity, state)
    return takedown(d.control, verbosity + d.by, state)
end


# # Debug

struct Debug{C}
Expand Down
58 changes: 13 additions & 45 deletions test/_models_for_testing.jl
Original file line number Diff line number Diff line change
@@ -1,39 +1,5 @@
# # DUMMY MODELS FOR TESTING


# ## SQUARE ROOTER

# Consider a model to compute Babylonian approximations to a square root:

# NOTE(review): this struct duplicates src/square_rooter.jl — this copy is
# the one being removed from the test helpers by this commit.
mutable struct SquareRooter
x::Float64 # input - number to be square rooted
root::Float64 # current approximation of root
training_losses::Vector{Float64} # successive approximation differences
SquareRooter(x) = new(x, 1.0, Float64[])
end

# NOTE(review): duplicate of the `train!` method in src/square_rooter.jl;
# removed from the test helpers by this commit.
function IterationControl.train!(m::SquareRooter, Δn::Int)
m.training_losses = Float64[]
for i in 1:Δn
next_guess = (m.root + m.x/m.root)/2
push!(m.training_losses, abs(next_guess - m.root))
m.root = next_guess
end
end

# NOTE(review): duplicates of the lifted API methods in src/square_rooter.jl.
IterationControl.loss(m::SquareRooter) = abs(m.root^2 - m.x)
IterationControl.training_losses(m::SquareRooter) = m.training_losses

model = SquareRooter(4.0)
IterationControl.train!(model, 1)
@assert model.root ≈ 2.5
@assert IterationControl.loss(model) ≈ 25/4 - 4
IterationControl.train!(model, 100)
@assert IterationControl.loss(model) ≈ 0
@assert IterationControl.training_losses(model)[1:2] ≈
abs.([41/20 - 5/2, 3281/1640 - 41/20])


# ## PARTICLE TRACKER (without methods lifted)

# Consider an object that tracks a particle in one dimension, moving,
Expand Down Expand Up @@ -80,15 +46,17 @@ function ingest!(model::Particle, target)
return nothing
end

model = Particle()
ingest!(model, 1)
train!(model, 1)
@assert loss(model) ≈ 0.9
ingest!(model, -0.9)
train!(model, 1)
@assert loss(model) ≈ 0.9
# temporary testing area

# model = Particle()
# ingest!(model, 1)
# train!(model, 1)
# @assert loss(model) ≈ 0.9
# ingest!(model, -0.9)
# train!(model, 1)
# @assert loss(model) ≈ 0.9

model = Particle()
ingest!(model, 1)
train!(model, 2)
@assert training_losses(model) ≈ [0.9, 0.81]
# model = Particle()
# ingest!(model, 1)
# train!(model, 2)
# @assert training_losses(model) ≈ [0.9, 0.81]
2 changes: 1 addition & 1 deletion test/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ model = Particle()
invalid = InvalidValue()

@test_throws IC.ERR_TRAIN IterationControl.train!(model)
@test_throws IC.err_train(model) IterationControl.train!(model, 1)
@test_throws IC.err_train(model) IterationControl.train!(model, 1)

# lifting train!:
IC.train!(model::Particle, n) = train!(model, n)
Expand Down
Loading

0 comments on commit e7a562e

Please sign in to comment.