Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support rational functions in HomotopyContinuationProblem #3151

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ REPL = "1"
RecursiveArrayTools = "3.26"
Reexport = "0.2, 1"
RuntimeGeneratedFunctions = "0.5.9"
SciMLBase = "2.56.1"
SciMLBase = "2.57.1"
SciMLStructures = "1.0"
Serialization = "1"
Setfield = "0.7, 0.8, 1"
Expand Down
100 changes: 87 additions & 13 deletions ext/MTKHomotopyContinuationExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module MTKHomotopyContinuationExt

using ModelingToolkit
using ModelingToolkit.SciMLBase
using ModelingToolkit.Symbolics: unwrap, symtype
using ModelingToolkit.Symbolics: unwrap, symtype, BasicSymbolic, simplify_fractions
using ModelingToolkit.SymbolicIndexingInterface
using ModelingToolkit.DocStringExtensions
using HomotopyContinuation
Expand All @@ -27,7 +27,7 @@ function is_polynomial(x, wrt)
contains_variable(x, wrt) || return true
any(isequal(x), wrt) && return true

if operation(x) in (*, +, -)
if operation(x) in (*, +, -, /)
return all(y -> is_polynomial(y, wrt), arguments(x))
end
if operation(x) == (^)
Expand Down Expand Up @@ -57,6 +57,57 @@ end
"""
$(TYPEDSIGNATURES)

Given a `x`, a polynomial in variables in `wrt` which may contain rational functions,
express `x` as a single rational function with polynomial `num` and denominator `den`.
Return `(num, den)`.
"""
function handle_rational_polynomials(x, wrt)
x = unwrap(x)
symbolic_type(x) == NotSymbolic() && return x, 1
iscall(x) || return x, 1
contains_variable(x, wrt) || return x, 1
any(isequal(x), wrt) && return x, 1

# simplify_fractions cancels out some common factors
# and expands (a / b)^c to a^c / b^c, so we only need
# to handle these cases
x = simplify_fractions(x)
op = operation(x)
args = arguments(x)

if op == /
# numerator and denominator are trivial
num, den = args
# but also search for rational functions in numerator
n, d = handle_rational_polynomials(num, wrt)
num, den = n, den * d
elseif op == +
num = 0
den = 1

# we don't need to do common denominator
# because we don't care about cases where denominator
# is zero. The expression is zero when all the numerators
# are zero.
for arg in args
n, d = handle_rational_polynomials(arg, wrt)
num += n
den *= d
end
else
return x, 1
end
# if the denominator isn't a polynomial in `wrt`, better to not include it
# to reduce the size of the gcd polynomial
if !contains_variable(den, wrt)
return num / den, 1
end
return num, den
end

"""
$(TYPEDSIGNATURES)

Convert `expr` from a symbolics expression to one that uses `HomotopyContinuation.ModelKit`.
"""
function symbolics_to_hc(expr)
Expand Down Expand Up @@ -139,51 +190,74 @@ function MTK.HomotopyContinuationProblem(
dvs = unknowns(sys)
eqs = equations(sys)

for eq in eqs
denoms = []
eqs2 = map(eqs) do eq
if !is_polynomial(eq.lhs, dvs) || !is_polynomial(eq.rhs, dvs)
error("Equation $eq is not a polynomial in the unknowns. See warnings for further details.")
end
num, den = handle_rational_polynomials(eq.rhs - eq.lhs, dvs)
push!(denoms, den)
return 0 ~ num
end

nlfn, u0, p = MTK.process_SciMLProblem(NonlinearFunction{true}, sys, u0map, parammap;
sys2 = MTK.@set sys.eqs = eqs2

nlfn, u0, p = MTK.process_SciMLProblem(NonlinearFunction{true}, sys2, u0map, parammap;
jac = true, eval_expression, eval_module)

denominator = MTK.build_explicit_observed_function(sys, denoms)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't you just want the numerator?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The denominator is used to rule out roots that cause a 0/0. A trivial example would be 0 ~ (x - y) / (x - 3), 0 ~ y - 3. If we take the numerators and solve, we get x = y = 3, but that causes 0 / 0 in the first polynomial and hence the system is actually infeasible.


hvars = symbolics_to_hc.(dvs)
mtkhsys = MTKHomotopySystem(nlfn.f, p, nlfn.jac, hvars, length(eqs))

obsfn = MTK.ObservedFunctionCache(sys; eval_expression, eval_module)

return MTK.HomotopyContinuationProblem(u0, mtkhsys, sys, obsfn)
return MTK.HomotopyContinuationProblem(u0, mtkhsys, denominator, sys, obsfn)
end

"""
$(TYPEDSIGNATURES)

Solve a `HomotopyContinuationProblem`. Ignores the algorithm passed to it, and always
uses `HomotopyContinuation.jl`. All keyword arguments are forwarded to
`HomotopyContinuation.solve`. The original solution as returned by `HomotopyContinuation.jl`
will be available in the `.original` field of the returned `NonlinearSolution`.
uses `HomotopyContinuation.jl`. All keyword arguments except the ones listed below are
forwarded to `HomotopyContinuation.solve`. The original solution as returned by
`HomotopyContinuation.jl` will be available in the `.original` field of the returned
`NonlinearSolution`.

All keyword arguments have their default values in HomotopyContinuation.jl, except
`show_progress` which defaults to `false`.

Extra keyword arguments:
- `denominator_abstol`: In case `prob` is solving a rational function, roots which cause
the denominator to be below `denominator_abstol` will be discarded.
"""
function CommonSolve.solve(prob::MTK.HomotopyContinuationProblem,
alg = nothing; show_progress = false, kwargs...)
alg = nothing; show_progress = false, denominator_abstol = 1e-8, kwargs...)
sol = HomotopyContinuation.solve(
prob.homotopy_continuation_system; show_progress, kwargs...)
realsols = HomotopyContinuation.results(sol; only_real = true)
if isempty(realsols)
u = state_values(prob)
resid = prob.homotopy_continuation_system(u)
retcode = SciMLBase.ReturnCode.ConvergenceFailure
else
T = eltype(state_values(prob))
distance, idx = findmin(realsols) do result
if any(<=(denominator_abstol),
prob.denominator(real.(result.solution), parameter_values(prob)))
return T(Inf)
end
norm(result.solution - state_values(prob))
end
u = real.(realsols[idx].solution)
resid = prob.homotopy_continuation_system(u)
retcode = SciMLBase.ReturnCode.Success
# all roots cause denominator to be zero
if isinf(distance)
u = state_values(prob)
retcode = SciMLBase.ReturnCode.Infeasible
else
u = real.(realsols[idx].solution)
retcode = SciMLBase.ReturnCode.Success
end
end
resid = prob.homotopy_continuation_system(u)

return SciMLBase.build_solution(
prob, :HomotopyContinuation, u, resid; retcode, original = sol)
Expand Down
8 changes: 7 additions & 1 deletion src/systems/nonlinear/nonlinearsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ A type of Nonlinear problem which specializes on polynomial systems and uses
HomotopyContinuation.jl to solve the system. Requires importing HomotopyContinuation.jl to
create and solve.
"""
struct HomotopyContinuationProblem{uType, H, O} <:
struct HomotopyContinuationProblem{uType, H, D, O} <:
SciMLBase.AbstractNonlinearProblem{uType, true}
"""
The initial values of states in the system. If there are multiple real roots of
Expand All @@ -586,6 +586,12 @@ struct HomotopyContinuationProblem{uType, H, O} <:
"""
homotopy_continuation_system::H
"""
A function with signature `(u, p) -> resid`. In case of rational functions, this
is used to rule out roots of the system which would cause the denominator to be
zero.
"""
denominator::D
"""
The `NonlinearSystem` used to create this problem. Used for symbolic indexing.
"""
sys::NonlinearSystem
Expand Down
44 changes: 42 additions & 2 deletions test/extensions/homotopy_continuation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,47 @@ end
@mtkbuild sys = NonlinearSystem([x^x - x ~ 0])
@test_warn ["Exponent", "unknowns"] @test_throws "not a polynomial" HomotopyContinuationProblem(
sys, [])
@mtkbuild sys = NonlinearSystem([((x^2) / (x + 3))^2 + x ~ 0])
@test_warn ["Base", "not a polynomial", "Unrecognized operation", "/"] @test_throws "not a polynomial" HomotopyContinuationProblem(
@mtkbuild sys = NonlinearSystem([((x^2) / sin(x))^2 + x ~ 0])
@test_warn ["Unrecognized", "sin"] @test_throws "not a polynomial" HomotopyContinuationProblem(
sys, [])
end

@testset "Rational functions" begin
@variables x=2.0 y=2.0
@parameters n = 4
@mtkbuild sys = NonlinearSystem([
0 ~ (x^2 - n * x + n) * (x - 1) / (x - 2) / (x - 3)
])
prob = HomotopyContinuationProblem(sys, [])
sol = solve(prob; threading = false)
@test sol[x] ≈ 1.0
p = parameter_values(prob)
for invalid in [2.0, 3.0]
@test prob.denominator([invalid], p)[1] <= 1e-8
end

@named sys = NonlinearSystem(
[
0 ~ (x - 2) / (x - 4) + ((x - 3) / (y - 7)) / ((x^2 - 4x + y) / (x - 2.5)),
0 ~ ((y - 3) / (y - 4)) * (n / (y - 5)) + ((x - 1.5) / (x - 5.5))^2
],
[x, y],
[n])
sys = complete(sys)
prob = HomotopyContinuationProblem(sys, [])
sol = solve(prob; threading = false)
disallowed_x = [4, 5.5]
disallowed_y = [7, 5, 4]
@test all(!isapprox(sol[x]; atol = 1e-8), disallowed_x)
@test all(!isapprox(sol[y]; atol = 1e-8), disallowed_y)
@test sol[x^2 - 4x + y] >= 1e-8

p = parameter_values(prob)
for val in disallowed_x
@test any(<=(1e-8), prob.denominator([val, 2.0], p))
end
for val in disallowed_y
@test any(<=(1e-8), prob.denominator([2.0, val], p))
end
@test prob.denominator([2.0, 4.0], p)[1] <= 1e-8
end
Loading