diff --git a/dwave/optimization/_model.pxd b/dwave/optimization/_model.pxd index 9585d7c1..9f03ccea 100644 --- a/dwave/optimization/_model.pxd +++ b/dwave/optimization/_model.pxd @@ -28,6 +28,7 @@ cdef class _Graph: cpdef bool is_locked(self) noexcept cpdef Py_ssize_t num_constraints(self) noexcept cpdef Py_ssize_t num_decisions(self) noexcept + cpdef Py_ssize_t num_inputs(self) noexcept cpdef Py_ssize_t num_nodes(self) noexcept cpdef Py_ssize_t num_symbols(self) noexcept diff --git a/dwave/optimization/_model.pyi b/dwave/optimization/_model.pyi index 7e9d0f57..e2b0965e 100644 --- a/dwave/optimization/_model.pyi +++ b/dwave/optimization/_model.pyi @@ -53,6 +53,7 @@ class _Graph: def is_locked(self) -> bool: ... def iter_constraints(self) -> collections.abc.Iterator[ArraySymbol]: ... def iter_decisions(self) -> collections.abc.Iterator[Symbol]: ... + def iter_inputs(self) -> collections.abc.Iterator[Symbol]: ... def iter_symbols(self) -> collections.abc.Iterator[Symbol]: ... def lock(self): ... def minimize(self, value: ArraySymbol): ... diff --git a/dwave/optimization/_model.pyx b/dwave/optimization/_model.pyx index 3ffc592d..0efb14a9 100644 --- a/dwave/optimization/_model.pyx +++ b/dwave/optimization/_model.pyx @@ -31,7 +31,8 @@ from libcpp.utility cimport move from libcpp.vector cimport vector from dwave.optimization.libcpp.array cimport Array as cppArray -from dwave.optimization.libcpp.graph cimport DecisionNode as cppDecisionNode +from dwave.optimization.libcpp.graph cimport DecisionNode as cppDecisionNode, Node as cppNode +from dwave.optimization.libcpp.nodes cimport InputNode as cppInputNode from dwave.optimization.states cimport States from dwave.optimization.states import StateView from dwave.optimization.symbols cimport symbol_from_ptr @@ -402,6 +403,10 @@ cdef class _Graph: for ptr in self._graph.decisions(): yield symbol_from_ptr(self, ptr) + def iter_inputs(self): + for ptr in self._graph.inputs(): + yield symbol_from_ptr(self, ptr) + def iter_symbols(self): """Iterate over all symbols in the model. @@ -522,6 +527,9 @@ cdef class _Graph: num_edges += self._graph.nodes()[i].get().successors().size() return num_edges + cpdef Py_ssize_t num_inputs(self) noexcept: + return self._graph.num_inputs() + cpdef Py_ssize_t num_nodes(self) noexcept: """Number of nodes in the directed acyclic graph for the model. diff --git a/dwave/optimization/include/dwave-optimization/array.hpp b/dwave/optimization/include/dwave-optimization/array.hpp index 73d3462f..824a29c1 100644 --- a/dwave/optimization/include/dwave-optimization/array.hpp +++ b/dwave/optimization/include/dwave-optimization/array.hpp @@ -935,6 +935,10 @@ std::ostream& operator<<(std::ostream& os, const Array::View& view); bool array_shape_equal(const Array* lhs_ptr, const Array* rhs_ptr); bool array_shape_equal(const Array& lhs, const Array& rhs); +// Test whether multiple arrays all have the same shape. +bool array_shape_equal(const std::span array_ptrs); +bool array_shape_equal(const std::vector& array_ptrs); + /// Get the shape induced by broadcasting two arrays together. /// See https://numpy.org/doc/stable/user/basics.broadcasting.html. /// Raises an exception if the two arrays cannot be broadcast together diff --git a/dwave/optimization/include/dwave-optimization/graph.hpp b/dwave/optimization/include/dwave-optimization/graph.hpp index 02b9b8b9..83ae1827 100644 --- a/dwave/optimization/include/dwave-optimization/graph.hpp +++ b/dwave/optimization/include/dwave-optimization/graph.hpp @@ -34,6 +34,7 @@ namespace dwave::optimization { class ArrayNode; class Node; class DecisionNode; +class InputNode; // We don't want this interface to be opinionated about what type of rng we're using. // So we create this class to do type erasure on RNGs. @@ -73,6 +74,7 @@ class Graph { public: Graph(); ~Graph(); + Graph(Graph&&); template NodeType* emplace_node(Args&&... args); @@ -138,6 +140,9 @@ class Graph { // The number of constraints in the model. ssize_t num_constraints() const noexcept { return constraints_.size(); } + // The number of input nodes in the model. + ssize_t num_inputs() const noexcept { return inputs_.size(); } + // Specify the objective node. Must be an array with a single element. // To unset the objective provide nullptr. void set_objective(ArrayNode* objective_ptr); @@ -158,6 +163,9 @@ class Graph { std::span decisions() noexcept { return decisions_; } std::span decisions() const noexcept { return decisions_; } + std::span inputs() noexcept { return inputs_; } + std::span inputs() const noexcept { return inputs_; } + // Remove unused nodes from the graph. // // This method will reset the topological sort if there is one. @@ -181,6 +189,7 @@ class Graph { ArrayNode* objective_ptr_ = nullptr; std::vector constraints_; std::vector decisions_; + std::vector inputs_; // Track whether the model is currently topologically sorted bool topologically_sorted_ = false; @@ -331,6 +340,8 @@ NodeType* Graph::emplace_node(Args&&... args) { static_assert(std::is_base_of_v); ptr->topological_index_ = decisions_.size(); decisions_.emplace_back(ptr); + } else if constexpr (std::is_base_of_v) { + inputs_.emplace_back(ptr); } return ptr; // return the observing pointer diff --git a/dwave/optimization/include/dwave-optimization/nodes.hpp b/dwave/optimization/include/dwave-optimization/nodes.hpp index 28677434..0a833c0b 100644 --- a/dwave/optimization/include/dwave-optimization/nodes.hpp +++ b/dwave/optimization/include/dwave-optimization/nodes.hpp @@ -18,6 +18,7 @@ #include "dwave-optimization/nodes/constants.hpp" #include "dwave-optimization/nodes/flow.hpp" #include "dwave-optimization/nodes/indexing.hpp" +#include "dwave-optimization/nodes/lambda.hpp" #include "dwave-optimization/nodes/manipulation.hpp" #include "dwave-optimization/nodes/mathematical.hpp" #include "dwave-optimization/nodes/numbers.hpp" diff --git a/dwave/optimization/include/dwave-optimization/nodes/constants.hpp b/dwave/optimization/include/dwave-optimization/nodes/constants.hpp index eda86a22..e09c4298 100644 --- a/dwave/optimization/include/dwave-optimization/nodes/constants.hpp +++ b/dwave/optimization/include/dwave-optimization/nodes/constants.hpp @@ -100,6 +100,17 @@ class ConstantNode : public ArrayOutputMixin { void commit(State&) const noexcept override {} void revert(State&) const noexcept override {} + protected: + // Information about the values in the buffer + struct BufferStats { + BufferStats() = delete; + explicit BufferStats(std::span buffer); + + bool integral; + double min; + double max; + }; + private: // Allocate the memory to hold shape worth of doubles, but don't populate it explicit ConstantNode(std::initializer_list shape) @@ -118,15 +129,6 @@ class ConstantNode : public ArrayOutputMixin { // holds its values on the object itself rather than in a State. double* buffer_ptr_; - // Information about the values in the buffer - struct BufferStats { - BufferStats() = delete; - explicit BufferStats(std::span buffer); - - bool integral; - double min; - double max; - }; mutable std::optional buffer_stats_; }; diff --git a/dwave/optimization/include/dwave-optimization/nodes/lambda.hpp b/dwave/optimization/include/dwave-optimization/nodes/lambda.hpp new file mode 100644 index 00000000..fa029d7d --- /dev/null +++ b/dwave/optimization/include/dwave-optimization/nodes/lambda.hpp @@ -0,0 +1,109 @@ +// Copyright 2023 D-Wave Systems Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "dwave-optimization/array.hpp" +#include "dwave-optimization/graph.hpp" + +namespace dwave::optimization { + +// InputNode acts like a placeholder or store of data very similar to ConstantNode, +// with the key different being that its contents *may* change in between propagations. +// However, it is not a decision variable--instead its use cases are acting as an "input" +// for "models as functions", or for placeholders in large models where (otherwise constant) +// data changes infrequently (e.g. a scheduling problem with a preference matrix). +// +// Currently there is no "default" way to initialize the state, so its must be initialized +// explicitly with some data. +class InputNode : public ArrayOutputMixin { + public: + explicit InputNode(std::span shape, double min, double max, bool integral) + : ArrayOutputMixin(shape), min_(min), max_(max), integral_(integral) {}; + + explicit InputNode(std::initializer_list shape, double min, double max, bool integral) + : ArrayOutputMixin(shape), min_(min), max_(max), integral_(integral) {}; + + explicit InputNode() + : InputNode({}, -std::numeric_limits::infinity(), + std::numeric_limits::infinity(), false) {}; + + bool integral() const override { return integral_; }; + + double max() const override { return max_; }; + double min() const override { return min_; }; + + void initialize_state(State& state) const override { + throw std::logic_error( + "InputNode must have state explicity initialized (with `initialize_state(state, " + "data)`)"); + } + + void initialize_state(State& state, std::span data) const; + + double const* buff(const State&) const override; + + std::span diff(const State& state) const noexcept override; + + void propagate(State& state) const noexcept override {}; + void commit(State& state) const noexcept override; + void revert(State& state) const noexcept override; + + void assign(State& state, const std::vector& new_values) const; + void assign(State& state, std::span new_values) const; + + private: + double min_, max_; + bool integral_; +}; + +class NaryReduceNode : public ArrayOutputMixin { + public: + // Runtime constructor that can be used from Cython/Python + NaryReduceNode(Graph&& expression, const std::vector& inputs, + const ArrayNode* output, const std::vector& initial_values, + const std::vector& operands); + + // Array overloads + double const* buff(const State& state) const override; + std::span diff(const State& state) const override; + ssize_t size(const State& state) const override; + std::span shape(const State& state) const override; + ssize_t size_diff(const State& state) const override; + SizeInfo sizeinfo() const override; + + // Information about the values are all inherited from the array + bool integral() const override; + double min() const override; + double max() const override; + + // Node overloads + void commit(State& state) const override; + void initialize_state(State& state) const override; + void propagate(State& state) const override; + void revert(State& state) const override; + + private: + double evaluate_expression(State& register_) const; + + const Graph expression_; + const std::vector inputs_; + const ArrayNode* output_; + const std::vector operands_; + const std::vector initial_values_; +}; + +} // namespace dwave::optimization diff --git a/dwave/optimization/include/dwave-optimization/utils.hpp b/dwave/optimization/include/dwave-optimization/utils.hpp index c0caa4a5..dad89339 100644 --- a/dwave/optimization/include/dwave-optimization/utils.hpp +++ b/dwave/optimization/include/dwave-optimization/utils.hpp @@ -16,6 +16,7 @@ #include #include +#include #include namespace dwave::optimization { @@ -166,4 +167,20 @@ void deduplicate_diff(std::vector& diff); // Return whether the given double encodes an integer. bool is_integer(const double& value); +template +bool is_variant(const node_type* node_ptr) { + // If the pointer can be dynamically cast to this type, return true + if (dynamic_cast(node_ptr)) { + return true; + } + + // If there are still types left to check then "recurse" + if constexpr (sizeof...(Ts) > 0) { + return is_variant(node_ptr); + } + + // If none match, then this Node didn't belong to the list of types + return false; +} + } // namespace dwave::optimization diff --git a/dwave/optimization/libcpp/graph.pxd b/dwave/optimization/libcpp/graph.pxd index 6c9d561f..153dcec7 100644 --- a/dwave/optimization/libcpp/graph.pxd +++ b/dwave/optimization/libcpp/graph.pxd @@ -23,6 +23,7 @@ from dwave.optimization.libcpp cimport span from dwave.optimization.libcpp.array cimport Array from dwave.optimization.libcpp.state cimport State + cdef extern from "dwave-optimization/graph.hpp" namespace "dwave::optimization" nogil: cdef cppclass Node: struct SuccessorView: @@ -38,12 +39,20 @@ cdef extern from "dwave-optimization/graph.hpp" namespace "dwave::optimization" cdef cppclass DecisionNode(Node): pass + cdef cppclass InputNode(Node, Array): + pass + # Sometimes Cython isn't able to reason about pointers as template inputs, so # we make a few aliases for convenience ctypedef Node* NodePtr ctypedef ArrayNode* ArrayNodePtr ctypedef DecisionNode* DecisionNodePtr +# This seems to be necessary to allow Cython to iterate over the returned +# span from `inputs()` directly. Otherwise it tries to cast it to a non-const +# version of span before iterating, which the C++ compiler will complain about. +ctypedef InputNode* const constInputNodePtr + cdef extern from "dwave-optimization/graph.hpp" namespace "dwave::optimization" nogil: cdef cppclass Graph: T* emplace_node[T](...) except+ @@ -51,9 +60,11 @@ cdef extern from "dwave-optimization/graph.hpp" namespace "dwave::optimization" span[const unique_ptr[Node]] nodes() const span[const ArrayNodePtr] constraints() span[const DecisionNodePtr] decisions() + span[constInputNodePtr] inputs() + Py_ssize_t num_constraints() Py_ssize_t num_nodes() Py_ssize_t num_decisions() - Py_ssize_t num_constraints() + Py_ssize_t num_inputs() @staticmethod void recursive_initialize(State&, Node*) except+ @staticmethod diff --git a/dwave/optimization/libcpp/nodes.pxd b/dwave/optimization/libcpp/nodes.pxd index 048afd32..c719fcf6 100644 --- a/dwave/optimization/libcpp/nodes.pxd +++ b/dwave/optimization/libcpp/nodes.pxd @@ -56,6 +56,9 @@ cdef extern from "dwave-optimization/nodes/constants.hpp" namespace "dwave::opti cdef cppclass ConstantNode(ArrayNode): const double* buff() const + cdef cppclass InputNode(ArrayNode): + const double* buff() const + cdef extern from "dwave-optimization/nodes/flow.hpp" namespace "dwave::optimization" nogil: cdef cppclass WhereNode(ArrayNode): @@ -77,6 +80,11 @@ cdef extern from "dwave-optimization/nodes/indexing.hpp" namespace "dwave::optim pass +cdef extern from "dwave-optimization/nodes/lambda.hpp" namespace "dwave::optimization" nogil: + cdef cppclass NaryReduceNode(ArrayNode): + pass + + cdef extern from "dwave-optimization/nodes/manipulation.hpp" namespace "dwave::optimization" nogil: cdef cppclass ConcatenateNode(ArrayNode): Py_ssize_t axis() diff --git a/dwave/optimization/model.py b/dwave/optimization/model.py index 346197db..d6c48436 100644 --- a/dwave/optimization/model.py +++ b/dwave/optimization/model.py @@ -40,7 +40,61 @@ _ShapeLike: typing.TypeAlias = typing.Union[int, collections.abc.Sequence[int]] -__all__ = ["Model"] +__all__ = ["Expression", "Model"] + + +class Expression(_Graph): + def __init__( + self, + num_inputs: int = 0, + lower_bound: typing.Optional[float] = None, + upper_bound: typing.Optional[float] = None, + integral: typing.Optional[bool] = None, + ): + self.output: typing.Optional[ArraySymbol] = None + + if num_inputs > 0: + if (lower_bound is None or upper_bound is None or integral is None): + raise ValueError( + "`lower_bound`, `upper_bound` and `integral` must be provided " + "explicitly when initializing inputs" + ) + for _ in range(num_inputs): + self.input(lower_bound, upper_bound, integral) + + def input(self, lower_bound: float, upper_bound: float, integral: bool): + """TODO""" + # avoid circular import + from dwave.optimization.symbols import Input + # Shape is always scalar for now + return Input(self, lower_bound, upper_bound, integral, shape=tuple()) + + def set_output(self, value: ArraySymbol): + """TODO""" + self.output = value + + def constant(self, value: float) -> Constant: + # TODO: docstring + r"""Create a constant symbol. + + Args: + array_like: An |array-like|_ representing a constant. Can be a scalar + or a NumPy array. If the array's ``dtype`` is ``np.double``, the + array is not copied. + + Returns: + A constant symbol. + + Examples: + This example creates a :math:`1 \times 4`-sized constant symbol + with the specified values. + + >>> from dwave.optimization.model import Model + >>> model = Model() + >>> time_limits = model.constant([10, 15, 5, 8.5]) + """ + from dwave.optimization.symbols import Constant # avoid circular import + return Constant(self, value) @contextlib.contextmanager diff --git a/dwave/optimization/src/array.cpp b/dwave/optimization/src/array.cpp index d4fb85a7..4fdff92d 100644 --- a/dwave/optimization/src/array.cpp +++ b/dwave/optimization/src/array.cpp @@ -14,6 +14,8 @@ #include "dwave-optimization/array.hpp" +#include + namespace dwave::optimization { SizeInfo::SizeInfo(const Array* array_ptr, std::optional min, std::optional max) @@ -162,36 +164,47 @@ std::string shape_to_string(const std::span shape) { return out; } -bool array_shape_equal(const Array* lhs_ptr, const Array* rhs_ptr) { - auto lhs_size = lhs_ptr->sizeinfo(); - auto rhs_size = rhs_ptr->sizeinfo(); - - if (lhs_size == rhs_size) return true; - if (lhs_size.array_ptr == nullptr || rhs_size.array_ptr == nullptr) return false; - - // This first loop is redundant, but often we get diamond structures - // of predecessors so by going back together we might get to short circuit the - // dfs - while (lhs_size.array_ptr != lhs_ptr && rhs_size.array_ptr != rhs_ptr) { - lhs_ptr = lhs_size.array_ptr; - lhs_size = lhs_size.substitute(); - rhs_ptr = rhs_size.array_ptr; - rhs_size = rhs_size.substitute(); - if (lhs_size == rhs_size) return true; +bool array_shape_equal(const std::span array_ptrs) { + if (array_ptrs.size() == 0) { + return false; + } else if (array_ptrs.size() == 1) { + return true; } - while (lhs_size.array_ptr != lhs_ptr) { - lhs_ptr = lhs_size.array_ptr; - lhs_size = lhs_size.substitute(); - if (lhs_size == rhs_size) return true; + + const Array* first_ptr = array_ptrs[0]; + auto first_size = first_ptr->sizeinfo(); + while (first_size.array_ptr != nullptr && first_size.array_ptr != first_ptr) { + first_ptr = first_size.array_ptr; + first_size = first_size.substitute(); } - while (rhs_size.array_ptr != rhs_ptr) { - rhs_ptr = rhs_size.array_ptr; - rhs_size = rhs_size.substitute(); - if (lhs_size == rhs_size) return true; + + for (const Array* array_ptr : array_ptrs | std::views::take(1)) { + auto this_size = array_ptr->sizeinfo(); + + if (first_size == this_size) continue; + if (this_size.array_ptr == nullptr) return false; + + while (this_size.array_ptr != nullptr && this_size.array_ptr != array_ptr) { + array_ptr = this_size.array_ptr; + this_size = this_size.substitute(); + if (first_size == this_size) break; + } + + // Have to check again as it's possible that `this_size.array_ptr` is nullptr + if (first_size != this_size) return false; } - return false; + return true; +} + +bool array_shape_equal(const std::vector& array_ptrs) { + return array_shape_equal(std::span{array_ptrs}); } + +bool array_shape_equal(const Array* lhs_ptr, const Array* rhs_ptr) { + return array_shape_equal(std::array{lhs_ptr, rhs_ptr}); +} + bool array_shape_equal(const Array& lhs, const Array& rhs) { return array_shape_equal(&lhs, &rhs); } diff --git a/dwave/optimization/src/graph.cpp b/dwave/optimization/src/graph.cpp index 9084ad93..99e4f98c 100644 --- a/dwave/optimization/src/graph.cpp +++ b/dwave/optimization/src/graph.cpp @@ -28,6 +28,7 @@ namespace dwave::optimization { Graph::Graph() = default; Graph::~Graph() = default; +Graph::Graph(Graph&&) = default; void Graph::topological_sort() { if (topologically_sorted_) return; diff --git a/dwave/optimization/src/nodes/constants.cpp b/dwave/optimization/src/nodes/constants.cpp index d78d9a85..0a06b8bf 100644 --- a/dwave/optimization/src/nodes/constants.cpp +++ b/dwave/optimization/src/nodes/constants.cpp @@ -18,6 +18,7 @@ #include #include "dwave-optimization/utils.hpp" +#include "_state.hpp" namespace dwave::optimization { @@ -110,4 +111,5 @@ double ConstantNode::min() const { return buffer_stats_->min; } + } // namespace dwave::optimization diff --git a/dwave/optimization/src/nodes/lambda.cpp b/dwave/optimization/src/nodes/lambda.cpp new file mode 100644 index 00000000..27485dc3 --- /dev/null +++ b/dwave/optimization/src/nodes/lambda.cpp @@ -0,0 +1,300 @@ +// Copyright 2023 D-Wave Systems Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "dwave-optimization/nodes/lambda.hpp" + +#include "_state.hpp" +#include "dwave-optimization/array.hpp" +#include "dwave-optimization/nodes/constants.hpp" +#include "dwave-optimization/nodes/mathematical.hpp" +#include "dwave-optimization/state.hpp" +#include "dwave-optimization/utils.hpp" + +namespace dwave::optimization { + +void InputNode::initialize_state(State& state, std::span data) const { + int index = this->topological_index(); + assert(index >= 0 && "must be topologically sorted"); + assert(static_cast(state.size()) > index && "unexpected state length"); + assert(state[index] == nullptr && "already initialized state"); + + if (static_cast(data.size()) != this->size()) { + throw std::invalid_argument("data size does not match size of InputNode"); + } + + std::vector copy(data.begin(), data.end()); + + state[index] = std::make_unique(std::move(copy)); +} + +double const* InputNode::buff(const State& state) const { + return data_ptr(state)->buff(); +} + +std::span InputNode::diff(const State& state) const noexcept { + return data_ptr(state)->diff(); +} + +void InputNode::commit(State& state) const noexcept { + data_ptr(state)->commit(); +} + +void InputNode::revert(State& state) const noexcept { + data_ptr(state)->revert(); +} + +void InputNode::assign(State& state, std::span new_values) const { + if (static_cast(new_values.size()) != this->size()) { + throw std::invalid_argument("size of new values must match"); + } + + double min_val = std::numeric_limits::infinity(); + double max_val = -std::numeric_limits::infinity(); + + static double dummy = 0; + bool all_is_integral = true; + for (const double& v : new_values) { + min_val = std::min(min_val, v); + max_val = std::max(min_val, v); + all_is_integral &= (std::modf(v, &dummy) == 0.0); + } + + if (min_val < min()) { + throw std::invalid_argument("new data contains a value smaller than the min"); + } + if (max_val > max()) { + throw std::invalid_argument("new data contains a value smaller than the min"); + } + if (integral() && !all_is_integral) { + throw std::invalid_argument("new data contains a non-integral value"); + } + + data_ptr(state)->assign(new_values); +} + +void InputNode::assign(State& state, const std::vector& new_values) const { + this->assign(state, std::span(new_values)); +} + +class NaryReduceNodeData : public ArrayNodeStateData { + public: + explicit NaryReduceNodeData(std::vector&& values, + std::vector&& iterators, State&& state) + : ArrayNodeStateData(std::move(values)), + iterators(std::move(iterators)), + register_(std::move(state)) {} + + // used to avoid reallocating memory for predecessor iterators every propagation + std::vector iterators; + + State register_; +}; + +Graph validate_expression(Graph&& expression, const std::vector inputs, + const ArrayNode* output) { + if (!expression.topologically_sorted()) { + throw std::invalid_argument("Expression must be topologically sorted"); + } + + if (expression.num_decisions()) { + // At least one decision, so the first node must be a decision + throw std::invalid_argument( + R"({"message": "Expression should not have any decision variables", "node_ptr": )" + + std::to_string((uintptr_t)(void*)expression.nodes()[0].get()) + "}"); + } + + for (const auto& node_ptr : expression.nodes()) { + const ArrayNode* array_node = dynamic_cast(node_ptr.get()); + if (!array_node) { + throw std::invalid_argument( + R"({"message": "Expression should contain only array nodes", "node_ptr": )" + + std::to_string((uintptr_t)(void*)node_ptr.get()) + "}"); + } + + if (!is_variant(array_node)) { + throw std::invalid_argument( + R"({"message": "Expression contains unsupported node", "node_ptr": )" + + std::to_string((uintptr_t)(void*)node_ptr.get()) + "}"); + } + + if (array_node->ndim() != 0) { + throw std::invalid_argument( + R"({"message": "Expression should only contain scalars", "node_ptr": )" + + std::to_string((uintptr_t)(void*)node_ptr.get()) + "}"); + } + } + + return expression; +} + +auto get_operands_shape(const std::vector& inputs, + const std::vector& initial_values, + const std::vector& operands) { + if (operands.size() == 0) { + throw std::invalid_argument("Must have at least one operand"); + } + + if (operands.size() + 1 != inputs.size()) { + throw std::invalid_argument("Expression must have one more InputNode than operands"); + } + + if (operands.size() + 1 != initial_values.size()) { + throw std::invalid_argument("Must have same number of initial values as operands"); + } + + std::vector array_ops; + for (const ArrayNode* op : operands) { + array_ops.push_back(op); + } + + if (!array_shape_equal(array_ops)) { + throw std::invalid_argument("All operands must have the same shape"); + } + + return operands[0]->shape(); +} + +NaryReduceNode::NaryReduceNode(Graph&& expression, const std::vector& inputs, + const ArrayNode* output, const std::vector& initial_values, + const std::vector& operands) + : ArrayOutputMixin(get_operands_shape(inputs, initial_values, operands)), + expression_(validate_expression(std::move(expression), inputs, output)), + inputs_(inputs), + output_(output), + operands_(operands), + initial_values_(initial_values) { + for (const auto& op : operands_) { + add_predecessor(op); + } +} + +double const* NaryReduceNode::buff(const State& state) const { + return data_ptr(state)->buffer.data(); +}; + +std::span NaryReduceNode::diff(const State& state) const { + return data_ptr(state)->diff(); +} + +ssize_t NaryReduceNode::size(const State& state) const { return operands_[0]->size(state); } + +std::span NaryReduceNode::shape(const State& state) const { + return operands_[0]->shape(state); +} + +ssize_t NaryReduceNode::size_diff(const State& state) const { + return data_ptr(state)->size_diff(); +} + +SizeInfo NaryReduceNode::sizeinfo() const { return operands_[0]->sizeinfo(); } + +bool NaryReduceNode::integral() const { return false; } + +double NaryReduceNode::min() const { return -std::numeric_limits::infinity(); } + +double NaryReduceNode::max() const { return std::numeric_limits::infinity(); } + +void NaryReduceNode::commit(State& state) const { data_ptr(state)->commit(); } + +double NaryReduceNode::evaluate_expression(State& register_) const { + // First propagate all the nodes + for (const auto& node_ptr : expression_.nodes()) { + node_ptr->propagate(register_); + } + // Then commit to clear the diffs + for (const auto& node_ptr : expression_.nodes()) { + node_ptr->commit(register_); + } + return output_->view(register_)[0]; +} + +void NaryReduceNode::initialize_state(State& state) const { + int node_idx = topological_index(); + assert(node_idx >= 0 && "must be topologically sorted"); + assert(state[node_idx] == nullptr && "already initialized state"); + + ssize_t start_size = this->size(state); + ssize_t num_args = operands_.size(); + std::vector values; + State reg; + reg = expression_.empty_state(); + + std::vector iterators; + for (const ArrayNode* array_ptr : operands_) { + iterators.push_back(array_ptr->begin(state)); + } + + // Get the initial output of the expression + for (ssize_t inp_index = 0; inp_index < num_args + 1; inp_index++) { + inputs_[inp_index]->initialize_state(reg, std::span(initial_values_).subspan(inp_index, 1)); + } + + // Finish the initialization after the input states have been set + expression_.initialize_state(reg); + + double val = evaluate_expression(reg); + + // Compute the expression for each subsequent index + for (ssize_t index = 0; index < start_size; ++index) { + for (ssize_t arg_index = 0; arg_index < num_args; ++arg_index) { + double input_val = *iterators[arg_index]; + inputs_[arg_index]->assign(reg, std::span(&input_val, 1)); + iterators[arg_index]++; + } + // Final input comes from the previous expression + inputs_[num_args]->assign(reg, std::span(&val, 1)); + val = evaluate_expression(reg); + values.push_back(val); + } + + state[node_idx] = std::make_unique(std::move(values), std::move(iterators), + std::move(reg)); +} + +void NaryReduceNode::propagate(State& state) const { + NaryReduceNodeData* data = data_ptr(state); + ssize_t new_size = this->size(state); + ssize_t num_args = operands_.size(); + + data->iterators.clear(); + for (const ArrayNode* array_ptr : operands_) { + data->iterators.push_back(array_ptr->begin(state)); + } + + // Set inputs to the initial values + for (ssize_t inp_index = 0; inp_index < num_args + 1; inp_index++) { + inputs_[inp_index]->assign(data->register_, + std::span(initial_values_).subspan(inp_index, 1)); + } + double val = evaluate_expression(data->register_); + + for (ssize_t index = 0; index < new_size; ++index) { + for (ssize_t arg_index = 0; arg_index < num_args; ++arg_index) { + double arg_val = *data->iterators[arg_index]; + inputs_[arg_index]->assign(data->register_, std::span(&arg_val, 1)); + data->iterators[arg_index]++; + } + // Final input comes from the previous expression + inputs_[num_args]->assign(data->register_, std::span(&val, 1)); + val = evaluate_expression(data->register_); + data->set(index, val); + } + + if (data->diff().size()) Node::propagate(state); +} + +void NaryReduceNode::revert(State& state) const { data_ptr(state)->revert(); } + +} // namespace dwave::optimization diff --git a/dwave/optimization/symbols.pyi b/dwave/optimization/symbols.pyi index 3f2cb583..b127ced6 100644 --- a/dwave/optimization/symbols.pyi +++ b/dwave/optimization/symbols.pyi @@ -98,6 +98,10 @@ class Equal(ArraySymbol): ... +class Input(ArraySymbol): + ... + + class IntegerVariable(ArraySymbol): def lower_bound(self) -> float: ... def set_state(self, index: int, state: numpy.typing.ArrayLike): ... @@ -156,6 +160,10 @@ class NaryMultiply(ArraySymbol): ... +class NaryReduce(ArraySymbol): + ... + + class Negative(ArraySymbol): ... diff --git a/dwave/optimization/symbols.pyx b/dwave/optimization/symbols.pyx index ce698123..d78b7cc7 100644 --- a/dwave/optimization/symbols.pyx +++ b/dwave/optimization/symbols.pyx @@ -23,13 +23,15 @@ import numbers cimport cpython.object import cython import numpy as np +from typing import Collection from cpython.ref cimport PyObject from cython.operator cimport dereference as deref, typeid from libc.math cimport modf from libcpp cimport bool -from libcpp.cast cimport dynamic_cast +from libcpp.cast cimport dynamic_cast, reinterpret_cast from libcpp.optional cimport nullopt, optional +from libc.stdint cimport uintptr_t from libcpp.typeindex cimport type_index from libcpp.unordered_map cimport unordered_map from libcpp.utility cimport move @@ -45,6 +47,7 @@ from dwave.optimization.libcpp.graph cimport ( ArrayNode as cppArrayNode, ArrayNodePtr as cppArrayNodePtr, Node as cppNode, + NodePtr as cppNodePtr, ) from dwave.optimization.libcpp.nodes cimport ( AbsoluteNode as cppAbsoluteNode, @@ -64,6 +67,7 @@ from dwave.optimization.libcpp.nodes cimport ( DisjointListsNode as cppDisjointListsNode, DivideNode as cppDivideNode, EqualNode as cppEqualNode, + InputNode as cppInputNode, IntegerNode as cppIntegerNode, LessEqualNode as cppLessEqualNode, ListNode as cppListNode, @@ -78,6 +82,7 @@ from dwave.optimization.libcpp.nodes cimport ( NaryMaximumNode as cppNaryMaximumNode, NaryMinimumNode as cppNaryMinimumNode, NaryMultiplyNode as cppNaryMultiplyNode, + NaryReduceNode as cppNaryReduceNode, NegativeNode as cppNegativeNode, NotNode as cppNotNode, OrNode as cppOrNode, @@ -97,6 +102,7 @@ from dwave.optimization.libcpp.nodes cimport ( XorNode as cppXorNode, ) from dwave.optimization.model cimport ArraySymbol, _Graph, Symbol +from dwave.optimization.model import Expression from dwave.optimization.states cimport States __all__ = [ @@ -117,6 +123,7 @@ __all__ = [ "DisjointList", "Divide", "Equal", + "Input", "IntegerVariable", "LessEqual", "ListVariable", @@ -131,6 +138,7 @@ __all__ = [ "NaryMaximum", "NaryMinimum", "NaryMultiply", + "NaryReduce", "Negative", "Not", "Or", @@ -1633,6 +1641,37 @@ cdef class Equal(ArraySymbol): _register(Equal, typeid(cppEqualNode)) +cdef class Input(ArraySymbol): + """TODO""" + + # TODO: implement serialization + + def __init__(self, expression: Expression, lower_bound: float, upper_bound: float, integral: bool, shape: Optional[tuple] = None): + cdef vector[Py_ssize_t] vshape = _as_cppshape(tuple() if shape is None else shape) + + cdef _Graph cygraph = expression + + # Get an observing pointer to the C++ InputNode + self.ptr = cygraph._graph.emplace_node[cppInputNode](vshape, lower_bound, upper_bound, integral) + + self.initialize_arraynode(expression, self.ptr) + + @staticmethod + def _from_symbol(Symbol symbol): + cdef cppInputNode* ptr = dynamic_cast_ptr[cppInputNode](symbol.node_ptr) + if not ptr: + raise TypeError("given symbol cannot be used to construct a Input") + + cdef Input inp = Input.__new__(Input) + inp.ptr = ptr + inp.initialize_arraynode(symbol.model, ptr) + return inp + + cdef cppInputNode* ptr + +_register(Input, typeid(cppInputNode)) + + cdef class IntegerVariable(ArraySymbol): """Integer decision-variable symbol. @@ -1646,6 +1685,9 @@ cdef class IntegerVariable(ArraySymbol): """ def __init__(self, _Graph model, shape=None, lower_bound=None, upper_bound=None): + if isinstance(model, Expression): + raise TypeError("cannot add IntegerVariable to Expression") + cdef vector[Py_ssize_t] vshape = _as_cppshape(tuple() if shape is None else shape ) if lower_bound is None and upper_bound is None: @@ -2327,6 +2369,101 @@ cdef class NaryMultiply(ArraySymbol): _register(NaryMultiply, typeid(cppNaryMultiplyNode)) +# TODO: consider different location for this? +class UnsupportedNaryReduceExpression(Exception): + def __init__(self, message: str, symbol: Symbol): + super().__init__(message) + self.symbol = symbol + + +cdef class NaryReduce(ArraySymbol): + """TODO""" + + # TODO: implement serialization + + def __init__( + self, + # input_symbols: Collection[Input], + # ArraySymbol output_symbol, + expression: Expression, + operands: Collection[ArraySymbol], + initial_values: Optional[tuple[float]] = None, + ): + if len(operands) == 0: + raise ValueError("must have at least one operand") + + if expression.num_inputs() != len(operands) + 1: + raise ValueError("must have exactly one more input than number of operands") + + if initial_values is None: + initial_values = (0,) * expression.num_inputs() + + if len(initial_values) != expression.num_inputs(): + raise ValueError("must have same number of initial values as inputs") + + cdef _Graph graph = expression + cdef ArraySymbol output_symbol = expression.output + + cdef _Graph model = operands[0].model + cdef cppArrayNode* output = output_symbol.array_ptr + cdef vector[double] cppinitial_values + cdef cppInputNode* cppinput + cdef vector[cppInputNode*] cppinputs + cdef vector[cppArrayNode*] cppoperands + + for val in initial_values: + cppinitial_values.push_back(val) + + for cppinput in graph._graph.inputs(): + cppinputs.push_back(cppinput) + + cdef ArraySymbol array + for node in operands: + if node.model != model: + raise ValueError("all predecessors must be from the same model") + array = node + cppoperands.push_back(array.array_ptr) + + expression.lock() + try: + self.ptr = model._graph.emplace_node[cppNaryReduceNode]( + move(graph._graph), cppinputs, output, cppinitial_values, cppoperands + ) + except ValueError as e: + expression.unlock() + raise self._handle_unsupported_expression_exception(expression, e) + + self.initialize_arraynode(model, self.ptr) + + def _handle_unsupported_expression_exception(self, expression: Expression, exception): + try: + info = json.loads(str(exception)) + except json.JSONDecodeError: + raise RuntimeError("could not parse exception message from NaryReduceNode") + + cdef uintptr_t node_ptr_val = info["node_ptr"] + cdef cppNode* node_ptr = reinterpret_cast[cppNodePtr](node_ptr_val) + cdef Symbol symbol = symbol_from_ptr(expression, node_ptr) + e = UnsupportedNaryReduceExpression(info["message"], symbol) + return e + + @staticmethod + def _from_symbol(Symbol symbol): + cdef cppNaryReduceNode* ptr = dynamic_cast_ptr[cppNaryReduceNode]( + symbol.node_ptr + ) + if not ptr: + raise TypeError("given symbol cannot be used to construct an NaryReduce") + cdef NaryReduce x = NaryReduce.__new__(NaryReduce) + x.ptr = ptr + x.initialize_arraynode(symbol.model, ptr) + return x + + cdef cppNaryReduceNode* ptr + +_register(NaryReduce, typeid(cppNaryReduceNode)) + + cdef class Negative(ArraySymbol): """Numerical negative element-wise on a symbol. @@ -2788,6 +2925,9 @@ cdef class Reshape(ArraySymbol): """ def __init__(self, ArraySymbol node, shape): + if isinstance(node.model, Expression): + raise TypeError("cannot reshape symbol that belongs to `Expression`") + cdef _Graph model = node.model self.ptr = model._graph.emplace_node[cppReshapeNode]( diff --git a/meson.build b/meson.build index dbeeeee1..3852b1b3 100644 --- a/meson.build +++ b/meson.build @@ -29,6 +29,7 @@ dwave_optimization_src = [ 'dwave/optimization/src/nodes/constants.cpp', 'dwave/optimization/src/nodes/flow.cpp', 'dwave/optimization/src/nodes/indexing.cpp', + 'dwave/optimization/src/nodes/lambda.cpp', 'dwave/optimization/src/nodes/manipulation.cpp', 'dwave/optimization/src/nodes/mathematical.cpp', 'dwave/optimization/src/nodes/numbers.cpp', diff --git a/tests/cpp/meson.build b/tests/cpp/meson.build index 81292217..ea70b7ff 100644 --- a/tests/cpp/meson.build +++ b/tests/cpp/meson.build @@ -16,6 +16,7 @@ tests_all = executable( 'nodes/test_collections.cpp', 'nodes/test_constants.cpp', 'nodes/test_flow.cpp', + 'nodes/test_lambda.cpp', 'nodes/test_manipulation.cpp', 'nodes/test_numbers.cpp', 'nodes/test_quadratic_model.cpp', diff --git a/tests/cpp/nodes/test_constants.cpp b/tests/cpp/nodes/test_constants.cpp index e0bf8dc0..366fcf4d 100644 --- a/tests/cpp/nodes/test_constants.cpp +++ b/tests/cpp/nodes/test_constants.cpp @@ -15,6 +15,7 @@ #include "catch2/catch_test_macros.hpp" #include "dwave-optimization/graph.hpp" #include "dwave-optimization/nodes/constants.hpp" +#include "dwave-optimization/nodes/testing.hpp" namespace dwave::optimization { diff --git a/tests/cpp/nodes/test_lambda.cpp b/tests/cpp/nodes/test_lambda.cpp new file mode 100644 index 00000000..02f41001 --- /dev/null +++ b/tests/cpp/nodes/test_lambda.cpp @@ -0,0 +1,289 @@ +// Copyright 2024 D-Wave Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace dwave::optimization { + +TEST_CASE("InputNode") { + auto graph = Graph(); + + GIVEN("An input node starting with state copied from a vector") { + auto ptr = graph.emplace_node(std::vector{4}, 10, 50, true); + auto val = graph.emplace_node(ptr); + + THEN("It copies the values into a 1d array") { + CHECK(ptr->ndim() == 1); + CHECK(ptr->size() == 4); + CHECK(std::ranges::equal(ptr->shape(), std::vector{4})); + CHECK(std::ranges::equal(ptr->strides(), std::vector{sizeof(double)})); + } + + THEN("min/max/integral are set from arguments") { + CHECK(ptr->min() == 10); + CHECK(ptr->max() == 50); + CHECK(ptr->integral()); + } + + THEN("initializing the graph state (without initializing the InputNode) throws an error") { + CHECK_THROWS(graph.initialize_state()); + } + + AND_GIVEN("An initialized state") { + std::vector values = {30, 10, 40, 20}; + auto state = graph.empty_state(); + ptr->initialize_state(state, std::span{values}); + graph.initialize_state(state); + + THEN("The state defaults to the values from the vector") { + CHECK(std::ranges::equal(ptr->view(state), values)); + } + + AND_WHEN("We assign new values and propagate") { + std::vector new_values = {20, 10, 49, 50}; + ptr->assign(state, new_values); + + ptr->propagate(state); + val->propagate(state); + + THEN("The InputNode has the new values") { + CHECK(std::ranges::equal(ptr->view(state), new_values)); + } + + THEN("We can commit") { + ptr->commit(state); + val->commit(state); + } + + THEN("We can revert") { + ptr->revert(state); + val->revert(state); + } + } + + AND_WHEN("We assign invalid values we get an exception") { + std::vector new_values = {20, 10, 49, 51}; + CHECK_THROWS(ptr->assign(state, new_values)); + + new_values = {9, 9, 9, 9}; + CHECK_THROWS(ptr->assign(state, new_values)); + + new_values = {9.99, 50.01, 25, 25}; + CHECK_THROWS(ptr->assign(state, new_values)); + + new_values = {20, 20, 20}; + CHECK_THROWS(ptr->assign(state, new_values)); + + new_values = {20, 20, 20, 20, 20}; + CHECK_THROWS(ptr->assign(state, new_values)); + } + } + } +} + +TEST_CASE("NaryReduceNode") { + auto graph = Graph(); + + GIVEN("A vector of constants and an expression") { + std::vector i = {0, 1, 2, 2}; + std::vector j = {1, 2, 4, 3}; + + auto args = std::vector{graph.emplace_node(i), + graph.emplace_node(j)}; + + // x0 * x1 + x2 + auto expression = Graph(); + std::vector inputs = {expression.emplace_node(), + expression.emplace_node(), + expression.emplace_node()}; + auto output_ptr = expression.emplace_node( + expression.emplace_node(inputs[0], inputs[1]), inputs[2]); + expression.topological_sort(); + + THEN("We can create an accumulate node") { + std::vector initial_values({1, 2, 3}); + auto reduce_ptr = graph.emplace_node(std::move(expression), inputs, + output_ptr, initial_values, args); + + AND_WHEN("We initialize a state") { + auto state = graph.initialize_state(); + + THEN("The state is correct") { + CHECK(std::ranges::equal(reduce_ptr->view(state), std::vector{5, 7, 15, 21})); + } + } + } + } + + GIVEN("Two integer nodes and a more complicated expression") { + auto i_ptr = graph.emplace_node(std::initializer_list{5}, -10, 10); + auto j_ptr = graph.emplace_node(std::initializer_list{5}, -10, 10); + + // (x0 + 1) * x1 - x2 + 5 + auto expression = Graph(); + std::vector inputs = {expression.emplace_node(), + expression.emplace_node(), + expression.emplace_node()}; + auto output_ptr = expression.emplace_node( + expression.emplace_node( + expression.emplace_node( + expression.emplace_node( + inputs[0], expression.emplace_node(1)), + inputs[1]), + inputs[2]), + expression.emplace_node(5)); + expression.topological_sort(); + + THEN("We can create a lambda node") { + std::vector args({i_ptr, j_ptr}); + + std::vector initial_values({1, 2, 3}); + auto reduce_ptr = graph.emplace_node(std::move(expression), inputs, + output_ptr, initial_values, args); + + auto validation_ptr = graph.emplace_node(reduce_ptr); + + AND_WHEN("We initialize a state") { + auto state = graph.initialize_state(); + + THEN("The state is correct") { + CHECK(std::ranges::equal(reduce_ptr->view(state), + std::vector{-1, 6, -1, 6, -1})); + } + + AND_WHEN("We mutate the integers and propagate") { + i_ptr->set_value(state, 4, 4); + j_ptr->set_value(state, 4, 5); + + i_ptr->set_value(state, 0, 3); + j_ptr->set_value(state, 0, -1); + + j_ptr->set_value(state, 3, 7); + + i_ptr->propagate(state); // [3, 0, 0, 0, 4] + j_ptr->propagate(state); // [-1, 0, 0, 7, 5] + reduce_ptr->propagate(state); + validation_ptr->propagate(state); + + THEN("The state is correct") { + CHECK(std::ranges::equal(reduce_ptr->view(state), + std::vector{-5, 10, -5, 17, 13})); + } + } + } + } + } + + GIVEN("Three integer nodes and an expression") { + auto i_ptr = graph.emplace_node(std::initializer_list{5}, 0, 100); + auto j_ptr = graph.emplace_node(std::initializer_list{5}, 0, 100); + + // max(x0 + x2, x1) + auto expression = Graph(); + std::vector inputs = {expression.emplace_node(), + expression.emplace_node(), + expression.emplace_node()}; + auto output_ptr = expression.emplace_node( + expression.emplace_node(inputs[0], inputs[2]), inputs[1]); + expression.topological_sort(); + + THEN("We can create a lambda node with basic functions and logic control") { + std::vector args({i_ptr, j_ptr}); + + std::vector initial_values({0, 0, 0}); + auto reduce_ptr = graph.emplace_node(std::move(expression), inputs, + output_ptr, initial_values, args); + + auto validation_ptr = graph.emplace_node(reduce_ptr); + + AND_WHEN("We initialize a state") { + auto state = graph.empty_state(); + i_ptr->initialize_state(state, {0, 1, 2, 3, 4}); + j_ptr->initialize_state(state, {10, 10, 20, 30, 32}); + + graph.initialize_state(state); + + THEN("The state is correct") { + CHECK(std::ranges::equal(reduce_ptr->view(state), + std::vector{10, 11, 20, 30, 34})); + } + + AND_WHEN("We mutate the integers and propagate") { + i_ptr->set_value(state, 4, 5); + + j_ptr->set_value(state, 1, 15); + j_ptr->set_value(state, 2, 15); + + i_ptr->propagate(state); // [0, 1, 2, 3, 5] + j_ptr->propagate(state); // [10, 15, 15, 30, 32] + reduce_ptr->propagate(state); + validation_ptr->propagate(state); + + THEN("The state is correct") { + CHECK(std::ranges::equal(reduce_ptr->view(state), + std::vector{10, 15, 17, 30, 35})); + } + } + } + } + } + + GIVEN("A constant node") { + std::vector i = {0, 1, 2, 2}; + + std::vector initial_values({1}); + auto args = std::vector{graph.emplace_node(i)}; + + THEN("We can't create a NaryReduceNode with an expression with decision variables") { + auto expression = Graph(); + std::vector inputs = { + expression.emplace_node(), + }; + auto output_ptr = expression.emplace_node( + inputs[0], expression.emplace_node()); + expression.topological_sort(); + + CHECK_THROWS(graph.emplace_node(std::move(expression), inputs, + output_ptr, initial_values, args)); + } + + THEN("We can't create a NaryReduceNode with non-scalar nodes") { + auto expression = Graph(); + std::vector inputs = { + expression.emplace_node(std::vector{2}, 0, 1, false), + }; + auto output_ptr = expression.emplace_node( + inputs[0], + expression.emplace_node(std::initializer_list{2})); + expression.topological_sort(); + + CHECK_THROWS(graph.emplace_node(std::move(expression), inputs, + output_ptr, initial_values, args)); + } + + // TODO: num initial values + // TODO: num args + // TODO: unsupported nodes + } +} + +} // namespace dwave::optimization diff --git a/tests/test_expression.py b/tests/test_expression.py new file mode 100644 index 00000000..25a0badc --- /dev/null +++ b/tests/test_expression.py @@ -0,0 +1,94 @@ +# Copyright 2024 D-Wave Systems Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import dwave.optimization.symbols +from dwave.optimization.model import Expression + + +class TestExpression(unittest.TestCase): + def test(self): + Expression() + + def test_initial_inputs(self): + exp = Expression(num_inputs=10, lower_bound=-5, upper_bound=3.7, integral=False) + self.assertEqual(exp.num_inputs(), 10) + + # Test that all arguments must be provided if starting with initial inputs + with self.assertRaises(ValueError): + Expression(num_inputs=10, upper_bound=3.7, integral=False) + with self.assertRaises(ValueError): + Expression(num_inputs=10, lower_bound=-5, integral=False) + with self.assertRaises(ValueError): + Expression(num_inputs=10, lower_bound=-5, upper_bound=3.7) + + def test_unsupported_symbols(self): + # Can't add decisions to an Expression, even manually + exp = Expression() + with self.assertRaises(TypeError): + dwave.optimization.symbols.IntegerVariable(exp) + + # Can't add other symbols, e.g. Reshape + exp = Expression() + inp = exp.input(0, 1, False) + with self.assertRaises(TypeError): + dwave.optimization.symbols.Reshape(inp, (1, 1, 1)) + + def test_num_inputs(self): + exp = Expression() + self.assertEqual(exp.num_inputs(), 0) + + inp0 = exp.input(-1, 1, True) + self.assertEqual(exp.num_inputs(), 1) + + inp1 = exp.input(-1, 1, True) + self.assertEqual(exp.num_inputs(), 2) + + inp0 + inp1 + self.assertEqual(exp.num_inputs(), 2) + self.assertEqual(exp.num_nodes(), 3) + + exp.input(-1, 1, True) + self.assertEqual(exp.num_inputs(), 3) + self.assertEqual(exp.num_nodes(), 4) + + def test_iter_inputs(self): + exp = Expression() + self.assertListEqual(list(exp.iter_inputs()), []) + + inp0 = exp.input(-1, 1, True) + symbols = list(exp.iter_inputs()) + self.assertEqual(len(symbols), 1) + self.assertTrue(inp0.equals(symbols[0])) + + inp1 = exp.input(-1, 1, True) + symbols = list(exp.iter_inputs()) + self.assertEqual(len(symbols), 2) + self.assertTrue(inp0.equals(symbols[0])) + self.assertTrue(inp1.equals(symbols[1])) + + inp0 + inp1 + symbols = list(exp.iter_inputs()) + self.assertEqual(len(symbols), 2) + + inp2 = exp.input(-1, 1, True) + symbols = list(exp.iter_inputs()) + self.assertEqual(len(symbols), 3) + self.assertTrue(inp2.equals(symbols[2])) + + def test_constants(self): + exp = Expression() + c0, c1 = exp.constant(5), exp.constant(-7.5) + c0 + c1 diff --git a/tests/test_symbols.py b/tests/test_symbols.py index e31fcf16..19cc208a 100644 --- a/tests/test_symbols.py +++ b/tests/test_symbols.py @@ -35,6 +35,7 @@ mod, sqrt, ) +from dwave.optimization.model import Expression class utils: @@ -1213,6 +1214,23 @@ def test_unsupported_floor_division(self): a // b +class TestInput(utils.SymbolTests): + def generate_symbols(self): + exp = Expression() + inp = exp.input(-10, 10, False) + exp.lock() + yield inp + + # TODO: enable once implemented + @unittest.skip("not yet implemented") + def test_serialization(*args, **kwargs): + pass + + @unittest.skip("Input state must be explicity initialized so can't run this test") + def test_state_serialization(*args, **kwargs): + pass + + class TestIntegerVariable(utils.SymbolTests): def generate_symbols(self): model = Model() @@ -1857,6 +1875,64 @@ def test_mismatched_shape(self): x *= b # after promotion +class TestNaryReduce(utils.SymbolTests): + def generate_symbols(self): + model = Model() + c0 = model.constant([0, 0]) + c1 = model.constant([0, 1]) + + exp = Expression() + inputs = [exp.input(-10, 10, False) for _ in range(3)] + exp.set_output(inputs[0] + inputs[1] + inputs[2]) + + acc = dwave.optimization.symbols.NaryReduce(exp, (c0, c1)) + + model.lock() + yield acc + + def test_mismatched_inputs(self): + model = Model() + c0 = model.constant([0, 0]) + c1 = model.constant([0, 1]) + + exp = Expression() + inputs = [exp.input(-10, 10, False) for _ in range(3)] + exp.set_output(inputs[0] + inputs[1] + inputs[2]) + + with self.assertRaises(ValueError): + dwave.optimization.symbols.NaryReduce(exp, (c0,)) + + with self.assertRaises(ValueError): + dwave.optimization.symbols.NaryReduce(exp, (c0, c1), initial_values=(0,)) + + def test_invalid_expressions(self): + model = Model() + c0 = model.constant([0, 0]) + + # Can't use an Expression that uses a non-scalar input + exp = Expression() + inp1 = exp.input(-10, 10, False) + inp5 = dwave.optimization.symbols.Input(exp, -10, 10, False, (5,)) + exp.set_output(inp1) + try: + dwave.optimization.symbols.NaryReduce(exp, (c0,)) + self.assertTrue(False, "should raise exception") + except Exception as e: + self.assertIsInstance(e, dwave.optimization.symbols.UnsupportedNaryReduceExpression) + self.assertRegex(str(e), "scalar") + self.assertTrue(inp5.equals(e.symbol)) + + # TODO: enable once implemented + @unittest.skip("not yet implemented") + def test_serialization(*args, **kwargs): + pass + + # TODO: enable once implemented + @unittest.skip("not yet implemented") + def test_state_serialization(*args, **kwargs): + pass + + class TestNegate(utils.UnaryOpTests): def op(self, x): return -x