Skip to content

Commit

Permalink
Reopen #390: Update callback (#440)
Browse files Browse the repository at this point in the history
* add `UpdateCallback`

* fix typo

* apply formatter

* remove update bool

* adapt docstring

* implement suggestions

* remove redundant comment
  • Loading branch information
LasNikas authored May 6, 2024
1 parent 1ccfa7d commit 3445bd0
Show file tree
Hide file tree
Showing 5 changed files with 181 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/TrixiParticles.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ export InitialCondition
export WeaklyCompressibleSPHSystem, EntropicallyDampedSPHSystem, TotalLagrangianSPHSystem,
BoundarySPHSystem, DEMSystem, BoundaryDEMSystem
export InfoCallback, SolutionSavingCallback, DensityReinitializationCallback,
PostprocessCallback, StepsizeCallback
PostprocessCallback, StepsizeCallback, UpdateCallback
export ContinuityDensity, SummationDensity
export PenaltyForceGanzenmueller
export SchoenbergCubicSplineKernel, SchoenbergQuarticSplineKernel,
Expand Down
1 change: 1 addition & 0 deletions src/callbacks/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@ include("solution_saving.jl")
include("density_reinit.jl")
include("post_process.jl")
include("stepsize.jl")
include("update.jl")
130 changes: 130 additions & 0 deletions src/callbacks/update.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
struct UpdateCallback{I}
interval::I
end

"""
UpdateCallback(; interval::Integer, dt=0.0)
Callback to update quantities either at the end of every `interval` time steps or
in intervals of `dt` in terms of integration time by adding additional `tstops`
(note that this may change the solution).
# Keywords
- `interval=1`: Update quantities at the end of every `interval` time steps.
- `dt`: Update quantities in regular intervals of `dt` in terms of integration time
by adding additional `tstops` (note that this may change the solution).
"""
function UpdateCallback(; interval::Integer=-1, dt=0.0)
if dt > 0 && interval !== -1
throw(ArgumentError("Setting both interval and dt is not supported!"))
end

# Update in intervals in terms of simulation time
if dt > 0
interval = Float64(dt)

# Update every time step (default)
elseif interval == -1
interval = 1
end

update_callback! = UpdateCallback(interval)

if dt > 0
# Add a `tstop` every `dt`, and save the final solution.
return PeriodicCallback(update_callback!, dt,
initialize=initial_update!,
save_positions=(false, false))
else
# The first one is the `condition`, the second the `affect!`
return DiscreteCallback(update_callback!, update_callback!,
initialize=initial_update!,
save_positions=(false, false))
end
end

# `initialize`
function initial_update!(cb, u, t, integrator)
# The `UpdateCallback` is either `cb.affect!` (with `DiscreteCallback`)
# or `cb.affect!.affect!` (with `PeriodicCallback`).
# Let recursive dispatch handle this.

initial_update!(cb.affect!, u, t, integrator)
end

initial_update!(cb::UpdateCallback, u, t, integrator) = cb(integrator)

# `condition`
function (update_callback!::UpdateCallback)(u, t, integrator)
(; interval) = update_callback!

return condition_integrator_interval(integrator, interval)
end

# `affect!`
function (update_callback!::UpdateCallback)(integrator)
t = integrator.t
semi = integrator.p
v_ode, u_ode = integrator.u.x

# Update quantities that are stored in the systems. These quantities (e.g. pressure)
# still have the values from the last stage of the previous step if not updated here.
update_systems_and_nhs(v_ode, u_ode, semi, t)

# Other updates might be added here later (e.g. Transport Velocity Formulation).
# @trixi_timeit timer() "update open boundary" foreach_system(semi) do system
# update_open_boundary_eachstep!(system, v_ode, u_ode, semi, t)
# end
#
# @trixi_timeit timer() "update TVF" foreach_system(semi) do system
# update_transport_velocity_eachstep!(system, v_ode, u_ode, semi, t)
# end

# Tell OrdinaryDiffEq that `u` has been modified
u_modified!(integrator, true)

return integrator
end

function Base.show(io::IO, cb::DiscreteCallback{<:Any, <:UpdateCallback})
@nospecialize cb # reduce precompilation time
print(io, "UpdateCallback(interval=", cb.affect!.interval, ")")
end

function Base.show(io::IO,
cb::DiscreteCallback{<:Any,
<:PeriodicCallbackAffect{<:UpdateCallback}})
@nospecialize cb # reduce precompilation time
print(io, "UpdateCallback(dt=", cb.affect!.affect!.interval, ")")
end

function Base.show(io::IO, ::MIME"text/plain",
cb::DiscreteCallback{<:Any, <:UpdateCallback})
@nospecialize cb # reduce precompilation time

if get(io, :compact, false)
show(io, cb)
else
update_cb = cb.affect!
setup = [
"interval" => update_cb.interval,
]
summary_box(io, "UpdateCallback", setup)
end
end

function Base.show(io::IO, ::MIME"text/plain",
cb::DiscreteCallback{<:Any,
<:PeriodicCallbackAffect{<:UpdateCallback}})
@nospecialize cb # reduce precompilation time

if get(io, :compact, false)
show(io, cb)
else
update_cb = cb.affect!.affect!
setup = [
"dt" => update_cb.interval,
]
summary_box(io, "UpdateCallback", setup)
end
end
1 change: 1 addition & 0 deletions test/callbacks/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
include("info.jl")
include("stepsize.jl")
include("postprocess.jl")
include("update.jl")
include("solution_saving.jl")
end
48 changes: 48 additions & 0 deletions test/callbacks/update.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
@testset verbose=true "UpdateCallback" begin
@testset verbose=true "show" begin
# Default
callback0 = UpdateCallback()

show_compact = "UpdateCallback(interval=1)"
@test repr(callback0) == show_compact

show_box = """
┌──────────────────────────────────────────────────────────────────────────────────────────────────┐
│ UpdateCallback │
│ ══════════════ │
│ interval: ……………………………………………………… 1 │
└──────────────────────────────────────────────────────────────────────────────────────────────────┘"""
@test repr("text/plain", callback0) == show_box

callback1 = UpdateCallback(interval=11)

show_compact = "UpdateCallback(interval=11)"
@test repr(callback1) == show_compact

show_box = """
┌──────────────────────────────────────────────────────────────────────────────────────────────────┐
│ UpdateCallback │
│ ══════════════ │
│ interval: ……………………………………………………… 11 │
└──────────────────────────────────────────────────────────────────────────────────────────────────┘"""
@test repr("text/plain", callback1) == show_box

callback2 = UpdateCallback(dt=1.2)

show_compact = "UpdateCallback(dt=1.2)"
@test repr(callback2) == show_compact

show_box = """
┌──────────────────────────────────────────────────────────────────────────────────────────────────┐
│ UpdateCallback │
│ ══════════════ │
│ dt: ……………………………………………………………………… 1.2 │
└──────────────────────────────────────────────────────────────────────────────────────────────────┘"""
@test repr("text/plain", callback2) == show_box
end

@testset "Illegal Input" begin
error_str = "Setting both interval and dt is not supported!"
@test_throws ArgumentError(error_str) UpdateCallback(dt=0.1, interval=1)
end
end

0 comments on commit 3445bd0

Please sign in to comment.