Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix QPU failover - signal solver failover condition (on resolve) #465

Merged
merged 11 commits into from
Oct 5, 2022
93 changes: 73 additions & 20 deletions dwave/system/samplers/dwave_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@
from dimod.exceptions import BinaryQuadraticModelStructureError
from dwave.cloud import Client
from dwave.cloud.exceptions import (
SolverOfflineError, SolverNotFoundError, ProblemStructureError)
SolverError, SolverAuthenticationError, InvalidAPIResponseError,
RequestTimeout, PollingTimeout, ProblemUploadError,
SolverOfflineError, SolverNotFoundError, ProblemStructureError,
)

from dwave.system.warnings import WarningHandler, WarningAction

Expand All @@ -40,6 +43,17 @@
__all__ = ['DWaveSampler', 'qpu_graph']


class FailoverCondition(Exception):
    """QPU SolverAPI call failed with an error that might be mitigated by
    retrying on a different solver.

    This exception only signals the condition; selecting a new solver and
    resubmitting the problem is left to the caller. When raised with
    ``raise ... from exc``, the underlying error is available as
    ``__cause__``.
    """


class RetryCondition(FailoverCondition):
    """QPU SolverAPI call failed with an error that might be mitigated by
    retrying on the same solver.

    Subclasses :exc:`FailoverCondition`, so a handler written for
    :exc:`FailoverCondition` also catches this exception. Catch
    :exc:`RetryCondition` first when the two cases need different handling.
    """


def qpu_graph(topology_type, topology_shape, nodelist, edgelist):
"""Converts node and edge lists to a dwave-networkx compatible graph.

Expand Down Expand Up @@ -85,6 +99,7 @@ def qpu_graph(topology_type, topology_shape, nodelist, edgelist):
'QPU architecure')
return G


def _failover(f):
"""Decorator for methods that might raise SolverOfflineError. Assumes that
the method is on a class with a `trigger_failover` method and a truthy
Expand Down Expand Up @@ -124,17 +139,29 @@ class DWaveSampler(dimod.Sampler, dimod.Structured):

Args:
failover (bool, optional, default=False):
Switch to a new QPU in the rare event that the currently connected
system goes offline. Note that different QPUs may have different
hardware graphs and a failover will result in a regenerated
:attr:`.nodelist`, :attr:`.edgelist`, :attr:`.properties` and
:attr:`.parameters`.
Set to ``True`` in order to signal a failover condition on sampling error.
randomir marked this conversation as resolved.
Show resolved Hide resolved
Failover is signalled by raising :exc:`.FailoverCondition` or
:exc:`.RetryCondition` on sampleset resolve.

Actual failover, i.e. selection of a new solver, has to be handled
by the user. A convenience method :meth:`.trigger_failover` is available
for this. Note that different QPUs may have different hardware graphs and a
failover will result in a regenerated :attr:`.nodelist`, :attr:`.edgelist`,
:attr:`.properties` and :attr:`.parameters`.

.. versionchanged:: 1.16.0

Previously, when :meth:`sample` returned a blocking response,
``failover=True`` would cause QPU/solver failover and sampling
retry. However, once :meth:`sample` became non-blocking/async,
failover stopped working, i.e. setting ``failover=True`` had no effect.

retry_interval (number, optional, default=-1):
The amount of time (in seconds) to wait to poll for a solver in
the case that no solver is found. If `retry_interval` is negative
then it will instead propagate the `SolverNotFoundError` to the
user.
Ignored, but kept for backward compatibility.

.. versionchanged:: 1.16.0

Ignored since 1.16.0. See note for ``failover`` parameter above.

**config:
Keyword arguments passed to :meth:`dwave.cloud.client.Client.from_config`.
Expand Down Expand Up @@ -314,7 +341,7 @@ def trigger_failover(self):

# the requested features are saved on the client object, so
# we just need to request a new solver
self.solver = self.client.get_solver()
self.solver = self.client.get_solver(refresh=True)

# delete the lazily-constructed attributes
try:
Expand All @@ -337,7 +364,6 @@ def trigger_failover(self):
except AttributeError:
pass

@_failover
def sample(self, bqm, warnings=None, **kwargs):
"""Sample from the specified binary quadratic model.

Expand Down Expand Up @@ -404,17 +430,44 @@ def sample(self, bqm, warnings=None, **kwargs):
warninghandler = WarningHandler(warnings)
warninghandler.energy_scale(bqm)

# need a hook so that we can check the sampleset (lazily) for
# warnings
# need a hook so that we can lazily check the sampleset for warnings
# and handle failover consistently
def _hook(computation):
sampleset = computation.sampleset
def resolve(computation):
sampleset = computation.sampleset
sampleset.resolve()

if warninghandler is not None:
warninghandler.too_few_samples(sampleset)
if warninghandler.action is WarningAction.SAVE:
sampleset.info['warnings'] = warninghandler.saved

if warninghandler is not None:
warninghandler.too_few_samples(sampleset)
if warninghandler.action is WarningAction.SAVE:
sampleset.info['warnings'] = warninghandler.saved
return sampleset

return sampleset
try:
return resolve(computation)

except (ProblemUploadError, RequestTimeout, PollingTimeout) as exc:
if not self.failover:
raise exc

# failover with retry on:
# - request or polling timeout
# - upload errors
raise RetryCondition("resubmit problem") from exc

except (SolverError, InvalidAPIResponseError) as exc:
if not self.failover:
raise exc
if isinstance(exc, SolverAuthenticationError):
raise exc

# failover on:
# - solver offline, solver disabled or not found
# - internal SAPI errors (like malformed response)
# - generic solver errors
# but NOT on auth errors
raise FailoverCondition("switch solver and resubmit problem") from exc

return dimod.SampleSet.from_future(future, _hook)

Expand Down
1 change: 1 addition & 0 deletions tests/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
parameterized==0.7.4
networkx>=2.6

coverage
codecov
75 changes: 34 additions & 41 deletions tests/test_dwavesampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,25 @@

import sys
import unittest
import random
import warnings

from collections import namedtuple
from concurrent.futures import Future
from unittest import mock
from uuid import uuid4

import numpy as np
from networkx.utils import graphs_equal
from parameterized import parameterized

import dimod
import dwave_networkx as dnx

from dwave.cloud.exceptions import SolverOfflineError, SolverNotFoundError
from dwave.cloud import exceptions
from dwave.cloud import computation

from dwave.system.samplers import DWaveSampler, qpu_graph
from dwave.system.warnings import EnergyScaleWarning, TooFewSamplesWarning
from dwave.system.samplers.dwave_sampler import RetryCondition, FailoverCondition

from networkx.utils.misc import graphs_equal

C16 = dnx.chimera_graph(16)

Expand Down Expand Up @@ -94,7 +94,7 @@ def sample_bqm(self, bqm, num_reads=1, **kwargs):
return future


class TestDwaveSampler(unittest.TestCase):
class TestDWaveSampler(unittest.TestCase):
@mock.patch('dwave.system.samplers.dwave_sampler.Client')
def setUp(self, MockClient):

Expand Down Expand Up @@ -203,28 +203,32 @@ def test_problem_labelling(self):
self.assertEqual(ss.info.get('problem_label'), label)

@mock.patch('dwave.system.samplers.dwave_sampler.Client')
def test_failover_false(self, MockClient):
def test_failover_off(self, MockClient):
sampler = DWaveSampler(failover=False)

sampler.solver.sample_ising.side_effect = SolverOfflineError
sampler.solver.sample_qubo.side_effect = SolverOfflineError
sampler.solver.sample_bqm.side_effect = SolverOfflineError
sampler.solver.sample_bqm.side_effect = exceptions.SolverOfflineError

with self.assertRaises(SolverOfflineError):
with self.assertRaises(exceptions.SolverOfflineError):
sampler.sample_ising({}, {})

@parameterized.expand([
(exceptions.InvalidAPIResponseError, FailoverCondition),
(exceptions.SolverNotFoundError, FailoverCondition),
(exceptions.SolverOfflineError, FailoverCondition),
(exceptions.SolverError, FailoverCondition),
(exceptions.PollingTimeout, RetryCondition),
(exceptions.SolverAuthenticationError, exceptions.SolverAuthenticationError), # auth error propagated
(KeyError, KeyError), # unrelated errors propagated
])
@mock.patch('dwave.system.samplers.dwave_sampler.Client')
def test_failover_offline(self, MockClient):
if sys.version_info.major <= 2 or sys.version_info.minor < 6:
raise unittest.SkipTest("need mock features only available in 3.6+")

def test_async_failover(self, source_exc, target_exc, MockClient):
sampler = DWaveSampler(failover=True)

mocksolver = sampler.solver
edgelist = sampler.edgelist

# call once
ss = sampler.sample_ising({}, {})
# call once (async, no need to resolve)
sampler.sample_ising({}, {})

self.assertIs(mocksolver, sampler.solver) # still same solver

Expand All @@ -233,37 +237,26 @@ def test_failover_offline(self, MockClient):
+ sampler.solver.sample_qubo.call_count
+ sampler.solver.sample_bqm.call_count, 1)

# add a side-effect
sampler.solver.sample_ising.side_effect = SolverOfflineError
sampler.solver.sample_qubo.side_effect = SolverOfflineError
sampler.solver.sample_bqm.side_effect = SolverOfflineError
# simulate solver exception on sampleset resolve
fut = computation.Future(mocksolver, None)
fut._set_exception(source_exc)
sampler.solver.sample_bqm = mock.Mock()
sampler.solver.sample_bqm.return_value = fut

# verify failover signalled
with self.assertRaises(target_exc):
sampler.sample_ising({}, {}).resolve()

# and make sure get_solver makes a new mock solver
# make sure get_solver makes a new mock solver
sampler.client.get_solver.reset_mock(return_value=True)

ss = sampler.sample_ising({}, {})
# trigger failover
sampler.trigger_failover()

# verify failover
self.assertIsNot(mocksolver, sampler.solver) # new solver
self.assertIsNot(edgelist, sampler.edgelist) # also should be new

@mock.patch('dwave.system.samplers.dwave_sampler.Client')
def test_failover_notfound_noretry(self, MockClient):

sampler = DWaveSampler(failover=True, retry_interval=-1)

mocksolver = sampler.solver

# add a side-effect
sampler.solver.sample_ising.side_effect = SolverOfflineError
sampler.solver.sample_qubo.side_effect = SolverOfflineError
sampler.solver.sample_bqm.side_effect = SolverOfflineError

# and make sure get_solver makes a new mock solver
sampler.client.get_solver.side_effect = SolverNotFoundError

with self.assertRaises(SolverNotFoundError):
sampler.sample_ising({}, {})

def test_warnings_energy_range(self):
sampler = self.sampler

Expand Down