diff --git a/Project.toml b/Project.toml index 9dfd9ed0d3..958772cb05 100644 --- a/Project.toml +++ b/Project.toml @@ -121,7 +121,7 @@ REPL = "1" RecursiveArrayTools = "3.26" Reexport = "0.2, 1" RuntimeGeneratedFunctions = "0.5.9" -SciMLBase = "2.56.1" +SciMLBase = "2.57.1" SciMLStructures = "1.0" Serialization = "1" Setfield = "0.7, 0.8, 1" diff --git a/ext/MTKHomotopyContinuationExt.jl b/ext/MTKHomotopyContinuationExt.jl index 81d9a840c9..e4a9243ff9 100644 --- a/ext/MTKHomotopyContinuationExt.jl +++ b/ext/MTKHomotopyContinuationExt.jl @@ -2,7 +2,7 @@ module MTKHomotopyContinuationExt using ModelingToolkit using ModelingToolkit.SciMLBase -using ModelingToolkit.Symbolics: unwrap, symtype +using ModelingToolkit.Symbolics: unwrap, symtype, BasicSymbolic, simplify_fractions using ModelingToolkit.SymbolicIndexingInterface using ModelingToolkit.DocStringExtensions using HomotopyContinuation @@ -27,7 +27,7 @@ function is_polynomial(x, wrt) contains_variable(x, wrt) || return true any(isequal(x), wrt) && return true - if operation(x) in (*, +, -) + if operation(x) in (*, +, -, /) return all(y -> is_polynomial(y, wrt), arguments(x)) end if operation(x) == (^) @@ -57,6 +57,57 @@ end """ $(TYPEDSIGNATURES) +Given a `x`, a polynomial in variables in `wrt` which may contain rational functions, +express `x` as a single rational function with polynomial `num` and denominator `den`. +Return `(num, den)`. +""" +function handle_rational_polynomials(x, wrt) + x = unwrap(x) + symbolic_type(x) == NotSymbolic() && return x, 1 + iscall(x) || return x, 1 + contains_variable(x, wrt) || return x, 1 + any(isequal(x), wrt) && return x, 1 + + # simplify_fractions cancels out some common factors + # and expands (a / b)^c to a^c / b^c, so we only need + # to handle these cases + x = simplify_fractions(x) + op = operation(x) + args = arguments(x) + + if op == / + # numerator and denominator are trivial + num, den = args + # but also search for rational functions in numerator + n, d = handle_rational_polynomials(num, wrt) + num, den = n, den * d + elseif op == + + num = 0 + den = 1 + + # we don't need to do common denominator + # because we don't care about cases where denominator + # is zero. The expression is zero when all the numerators + # are zero. + for arg in args + n, d = handle_rational_polynomials(arg, wrt) + num += n + den *= d + end + else + return x, 1 + end + # if the denominator isn't a polynomial in `wrt`, better to not include it + # to reduce the size of the gcd polynomial + if !contains_variable(den, wrt) + return num / den, 1 + end + return num, den +end + +""" +$(TYPEDSIGNATURES) + Convert `expr` from a symbolics expression to one that uses `HomotopyContinuation.ModelKit`. """ function symbolics_to_hc(expr) @@ -139,51 +190,74 @@ function MTK.HomotopyContinuationProblem( dvs = unknowns(sys) eqs = equations(sys) - for eq in eqs + denoms = [] + eqs2 = map(eqs) do eq if !is_polynomial(eq.lhs, dvs) || !is_polynomial(eq.rhs, dvs) error("Equation $eq is not a polynomial in the unknowns. See warnings for further details.") end + num, den = handle_rational_polynomials(eq.rhs - eq.lhs, dvs) + push!(denoms, den) + return 0 ~ num end - nlfn, u0, p = MTK.process_SciMLProblem(NonlinearFunction{true}, sys, u0map, parammap; + sys2 = MTK.@set sys.eqs = eqs2 + + nlfn, u0, p = MTK.process_SciMLProblem(NonlinearFunction{true}, sys2, u0map, parammap; jac = true, eval_expression, eval_module) + denominator = MTK.build_explicit_observed_function(sys, denoms) + hvars = symbolics_to_hc.(dvs) mtkhsys = MTKHomotopySystem(nlfn.f, p, nlfn.jac, hvars, length(eqs)) obsfn = MTK.ObservedFunctionCache(sys; eval_expression, eval_module) - return MTK.HomotopyContinuationProblem(u0, mtkhsys, sys, obsfn) + return MTK.HomotopyContinuationProblem(u0, mtkhsys, denominator, sys, obsfn) end """ $(TYPEDSIGNATURES) Solve a `HomotopyContinuationProblem`. Ignores the algorithm passed to it, and always -uses `HomotopyContinuation.jl`. All keyword arguments are forwarded to -`HomotopyContinuation.solve`. The original solution as returned by `HomotopyContinuation.jl` -will be available in the `.original` field of the returned `NonlinearSolution`. +uses `HomotopyContinuation.jl`. All keyword arguments except the ones listed below are +forwarded to `HomotopyContinuation.solve`. The original solution as returned by +`HomotopyContinuation.jl` will be available in the `.original` field of the returned +`NonlinearSolution`. All keyword arguments have their default values in HomotopyContinuation.jl, except `show_progress` which defaults to `false`. + +Extra keyword arguments: +- `denominator_abstol`: In case `prob` is solving a rational function, roots which cause + the denominator to be below `denominator_abstol` will be discarded. """ function CommonSolve.solve(prob::MTK.HomotopyContinuationProblem, - alg = nothing; show_progress = false, kwargs...) + alg = nothing; show_progress = false, denominator_abstol = 1e-8, kwargs...) sol = HomotopyContinuation.solve( prob.homotopy_continuation_system; show_progress, kwargs...) realsols = HomotopyContinuation.results(sol; only_real = true) if isempty(realsols) u = state_values(prob) - resid = prob.homotopy_continuation_system(u) retcode = SciMLBase.ReturnCode.ConvergenceFailure else + T = eltype(state_values(prob)) distance, idx = findmin(realsols) do result + if any(<=(denominator_abstol), + prob.denominator(real.(result.solution), parameter_values(prob))) + return T(Inf) + end norm(result.solution - state_values(prob)) end - u = real.(realsols[idx].solution) - resid = prob.homotopy_continuation_system(u) - retcode = SciMLBase.ReturnCode.Success + # all roots cause denominator to be zero + if isinf(distance) + u = state_values(prob) + retcode = SciMLBase.ReturnCode.Infeasible + else + u = real.(realsols[idx].solution) + retcode = SciMLBase.ReturnCode.Success + end end + resid = prob.homotopy_continuation_system(u) return SciMLBase.build_solution( prob, :HomotopyContinuation, u, resid; retcode, original = sol) diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index 01a250d46a..bdad40d52f 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -573,7 +573,7 @@ A type of Nonlinear problem which specializes on polynomial systems and uses HomotopyContinuation.jl to solve the system. Requires importing HomotopyContinuation.jl to create and solve. """ -struct HomotopyContinuationProblem{uType, H, O} <: +struct HomotopyContinuationProblem{uType, H, D, O} <: SciMLBase.AbstractNonlinearProblem{uType, true} """ The initial values of states in the system. If there are multiple real roots of @@ -586,6 +586,12 @@ struct HomotopyContinuationProblem{uType, H, O} <: """ homotopy_continuation_system::H """ + A function with signature `(u, p) -> resid`. In case of rational functions, this + is used to rule out roots of the system which would cause the denominator to be + zero. + """ + denominator::D + """ The `NonlinearSystem` used to create this problem. Used for symbolic indexing. """ sys::NonlinearSystem diff --git a/test/extensions/homotopy_continuation.jl b/test/extensions/homotopy_continuation.jl index ceabb6b6e3..6d7279899f 100644 --- a/test/extensions/homotopy_continuation.jl +++ b/test/extensions/homotopy_continuation.jl @@ -82,7 +82,47 @@ end @mtkbuild sys = NonlinearSystem([x^x - x ~ 0]) @test_warn ["Exponent", "unknowns"] @test_throws "not a polynomial" HomotopyContinuationProblem( sys, []) - @mtkbuild sys = NonlinearSystem([((x^2) / (x + 3))^2 + x ~ 0]) - @test_warn ["Base", "not a polynomial", "Unrecognized operation", "/"] @test_throws "not a polynomial" HomotopyContinuationProblem( + @mtkbuild sys = NonlinearSystem([((x^2) / sin(x))^2 + x ~ 0]) + @test_warn ["Unrecognized", "sin"] @test_throws "not a polynomial" HomotopyContinuationProblem( sys, []) end + +@testset "Rational functions" begin + @variables x=2.0 y=2.0 + @parameters n = 4 + @mtkbuild sys = NonlinearSystem([ + 0 ~ (x^2 - n * x + n) * (x - 1) / (x - 2) / (x - 3) + ]) + prob = HomotopyContinuationProblem(sys, []) + sol = solve(prob; threading = false) + @test sol[x] ≈ 1.0 + p = parameter_values(prob) + for invalid in [2.0, 3.0] + @test prob.denominator([invalid], p)[1] <= 1e-8 + end + + @named sys = NonlinearSystem( + [ + 0 ~ (x - 2) / (x - 4) + ((x - 3) / (y - 7)) / ((x^2 - 4x + y) / (x - 2.5)), + 0 ~ ((y - 3) / (y - 4)) * (n / (y - 5)) + ((x - 1.5) / (x - 5.5))^2 + ], + [x, y], + [n]) + sys = complete(sys) + prob = HomotopyContinuationProblem(sys, []) + sol = solve(prob; threading = false) + disallowed_x = [4, 5.5] + disallowed_y = [7, 5, 4] + @test all(!isapprox(sol[x]; atol = 1e-8), disallowed_x) + @test all(!isapprox(sol[y]; atol = 1e-8), disallowed_y) + @test sol[x^2 - 4x + y] >= 1e-8 + + p = parameter_values(prob) + for val in disallowed_x + @test any(<=(1e-8), prob.denominator([val, 2.0], p)) + end + for val in disallowed_y + @test any(<=(1e-8), prob.denominator([2.0, val], p)) + end + @test prob.denominator([2.0, 4.0], p)[1] <= 1e-8 +end