diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java b/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java
index 5be92558ee0..7203d3a4945 100644
--- a/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java
+++ b/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java
@@ -18,6 +18,8 @@
import com.mongodb.lang.Nullable;
+import java.util.concurrent.atomic.AtomicBoolean;
+
/**
* See {@link AsyncRunnable}
*
@@ -33,4 +35,28 @@ public interface AsyncFunction {
* @param callback the callback
*/
void unsafeFinish(T value, SingleResultCallback callback);
+
+ /**
+ * Must be invoked at end of async chain or when executing a callback handler supplied by the caller.
+ *
+ * @param callback the callback provided by the method the chain is used in.
+ */
+ default void finish(final T value, final SingleResultCallback callback) {
+ final AtomicBoolean callbackInvoked = new AtomicBoolean(false);
+ try {
+ this.unsafeFinish(value, (v, e) -> {
+ if (!callbackInvoked.compareAndSet(false, true)) {
+ throw new AssertionError(String.format("Callback has been already completed. It could happen "
+ + "if code throws an exception after invoking an async method. Value: %s", v), e);
+ }
+ callback.onResult(v, e);
+ });
+ } catch (Throwable t) {
+ if (!callbackInvoked.compareAndSet(false, true)) {
+ throw t;
+ } else {
+ callback.completeExceptionally(t);
+ }
+ }
+ }
}
diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java b/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java
index fcf8d61387d..7a872ded718 100644
--- a/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java
+++ b/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java
@@ -170,7 +170,9 @@ default AsyncRunnable thenRun(final AsyncRunnable runnable) {
return (c) -> {
this.unsafeFinish((r, e) -> {
if (e == null) {
- runnable.unsafeFinish(c);
+ /* If 'runnable' is executed on a different thread from the one that executed the initial 'finish()',
+ then invoking 'finish()' within 'runnable' will catch and propagate any exceptions to 'c' (the callback). */
+ runnable.finish(c);
} else {
c.completeExceptionally(e);
}
@@ -199,7 +201,7 @@ default AsyncRunnable thenRunIf(final Supplier condition, final AsyncRu
return;
}
if (matched) {
- runnable.unsafeFinish(callback);
+ runnable.finish(callback);
} else {
callback.complete(callback);
}
@@ -216,7 +218,7 @@ default AsyncSupplier thenSupply(final AsyncSupplier supplier) {
return (c) -> {
this.unsafeFinish((r, e) -> {
if (e == null) {
- supplier.unsafeFinish(c);
+ supplier.finish(c);
} else {
c.completeExceptionally(e);
}
diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java b/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java
index b7d24dd3df5..77c289c8723 100644
--- a/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java
+++ b/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java
@@ -18,6 +18,7 @@
import com.mongodb.lang.Nullable;
+import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Predicate;
@@ -54,18 +55,25 @@ default void unsafeFinish(@Nullable final Void value, final SingleResultCallback
}
/**
- * Must be invoked at end of async chain.
+ * Must be invoked at end of async chain or when executing a callback handler supplied by the caller.
+ *
+ * @see #thenApply(AsyncFunction)
+ * @see #thenConsume(AsyncConsumer)
+ * @see #onErrorIf(Predicate, AsyncFunction)
* @param callback the callback provided by the method the chain is used in
*/
default void finish(final SingleResultCallback callback) {
- final boolean[] callbackInvoked = {false};
+ final AtomicBoolean callbackInvoked = new AtomicBoolean(false);
try {
this.unsafeFinish((v, e) -> {
- callbackInvoked[0] = true;
+ if (!callbackInvoked.compareAndSet(false, true)) {
+ throw new AssertionError(String.format("Callback has been already completed. It could happen "
+ + "if code throws an exception after invoking an async method. Value: %s", v), e);
+ }
callback.onResult(v, e);
});
} catch (Throwable t) {
- if (callbackInvoked[0]) {
+ if (!callbackInvoked.compareAndSet(false, true)) {
throw t;
} else {
callback.completeExceptionally(t);
@@ -80,9 +88,9 @@ default void finish(final SingleResultCallback callback) {
*/
default AsyncSupplier thenApply(final AsyncFunction function) {
return (c) -> {
- this.unsafeFinish((v, e) -> {
+ this.finish((v, e) -> {
if (e == null) {
- function.unsafeFinish(v, c);
+ function.finish(v, c);
} else {
c.completeExceptionally(e);
}
@@ -99,7 +107,7 @@ default AsyncRunnable thenConsume(final AsyncConsumer consumer) {
return (c) -> {
this.unsafeFinish((v, e) -> {
if (e == null) {
- consumer.unsafeFinish(v, c);
+ consumer.finish(v, c);
} else {
c.completeExceptionally(e);
}
@@ -131,7 +139,7 @@ default AsyncSupplier onErrorIf(
return;
}
if (errorMatched) {
- errorFunction.unsafeFinish(e, callback);
+ errorFunction.finish(e, callback);
} else {
callback.completeExceptionally(e);
}
diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java
index 218835f083e..7751bcba86f 100644
--- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java
+++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java
@@ -595,6 +595,7 @@ private void sendCommandMessageAsync(final int messageId, final Decoder d
return;
}
assertNotNull(responseBuffers);
+ T commandResult;
try {
updateSessionContext(sessionContext, responseBuffers);
boolean commandOk =
@@ -609,13 +610,14 @@ private void sendCommandMessageAsync(final int messageId, final Decoder d
}
commandEventSender.sendSucceededEvent(responseBuffers);
- T result1 = getCommandResult(decoder, responseBuffers, messageId);
- callback.onResult(result1, null);
+ commandResult = getCommandResult(decoder, responseBuffers, messageId);
} catch (Throwable localThrowable) {
callback.onResult(null, localThrowable);
+ return;
} finally {
responseBuffers.close();
}
+ callback.onResult(commandResult, null);
}));
}
});
diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java
index d4858f3d973..b8f85289a0b 100644
--- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java
+++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java
@@ -98,7 +98,14 @@ public void startHandshakeAsync(final InternalConnection internalConnection,
callback.onResult(null, t instanceof MongoException ? mapHelloException((MongoException) t) : t);
} else {
setSpeculativeAuthenticateResponse(helloResult);
- callback.onResult(createInitializationDescription(helloResult, internalConnection, startTime), null);
+ InternalConnectionInitializationDescription initializationDescription;
+ try {
+ initializationDescription = createInitializationDescription(helloResult, internalConnection, startTime);
+ } catch (Throwable localThrowable) {
+ callback.onResult(null, localThrowable);
+ return;
+ }
+ callback.onResult(initializationDescription, null);
}
});
}
diff --git a/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTest.java b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsAbstractTest.java
similarity index 70%
rename from driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTest.java
rename to driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsAbstractTest.java
index b783b3de93b..16e4e978bf4 100644
--- a/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTest.java
+++ b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsAbstractTest.java
@@ -15,30 +15,16 @@
*/
package com.mongodb.internal.async;
-import com.mongodb.client.TestListener;
import org.junit.jupiter.api.Test;
-import org.opentest4j.AssertionFailedError;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Supplier;
import static com.mongodb.assertions.Assertions.assertNotNull;
import static com.mongodb.internal.async.AsyncRunnable.beginAsync;
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertFalse;
-import static org.junit.jupiter.api.Assertions.assertThrows;
-import static org.junit.jupiter.api.Assertions.assertTrue;
-import static org.junit.jupiter.api.Assertions.fail;
-final class AsyncFunctionsTest {
- private final TestListener listener = new TestListener();
- private final InvocationTracker invocationTracker = new InvocationTracker();
- private boolean isTestingAbruptCompletion = false;
+abstract class AsyncFunctionsAbstractTest extends AsyncFunctionsTestBase {
@Test
void test1Method() {
@@ -720,25 +706,6 @@ void testVariables() {
});
}
- @Test
- void testInvalid() {
- isTestingAbruptCompletion = false;
- invocationTracker.isAsyncStep = true;
- assertThrows(IllegalStateException.class, () -> {
- beginAsync().thenRun(c -> {
- async(3, c);
- throw new IllegalStateException("must not cause second callback invocation");
- }).finish((v, e) -> {});
- });
- assertThrows(IllegalStateException.class, () -> {
- beginAsync().thenRun(c -> {
- async(3, c);
- }).finish((v, e) -> {
- throw new IllegalStateException("must not cause second callback invocation");
- });
- });
- }
-
@Test
void testDerivation() {
// Demonstrates the progression from nested async to the API.
@@ -746,8 +713,8 @@ void testDerivation() {
// Stand-ins for sync-async methods; these "happily" do not throw
// exceptions, to avoid complicating this demo async code.
Consumer happySync = (i) -> {
- invocationTracker.getNextOption(1);
- listener.add("affected-success-" + i);
+ getNextOption(1);
+ listenerAdd("affected-success-" + i);
};
BiConsumer> happyAsync = (i, c) -> {
happySync.accept(i);
@@ -827,275 +794,4 @@ void testDerivation() {
});
}
- // invoked methods:
-
- private void plain(final int i) {
- int cur = invocationTracker.getNextOption(2);
- if (cur == 0) {
- listener.add("plain-exception-" + i);
- throw new RuntimeException("affected method exception-" + i);
- } else {
- listener.add("plain-success-" + i);
- }
- }
-
- private int plainReturns(final int i) {
- int cur = invocationTracker.getNextOption(2);
- if (cur == 0) {
- listener.add("plain-exception-" + i);
- throw new RuntimeException("affected method exception-" + i);
- } else {
- listener.add("plain-success-" + i);
- return i;
- }
- }
-
- private boolean plainTest(final int i) {
- int cur = invocationTracker.getNextOption(3);
- if (cur == 0) {
- listener.add("plain-exception-" + i);
- throw new RuntimeException("affected method exception-" + i);
- } else if (cur == 1) {
- listener.add("plain-false-" + i);
- return false;
- } else {
- listener.add("plain-true-" + i);
- return true;
- }
- }
-
- private void sync(final int i) {
- assertFalse(invocationTracker.isAsyncStep);
- affected(i);
- }
-
-
- private Integer syncReturns(final int i) {
- assertFalse(invocationTracker.isAsyncStep);
- return affectedReturns(i);
- }
-
- private void async(final int i, final SingleResultCallback callback) {
- assertTrue(invocationTracker.isAsyncStep);
- if (isTestingAbruptCompletion) {
- affected(i);
- callback.complete(callback);
-
- } else {
- try {
- affected(i);
- callback.complete(callback);
- } catch (Throwable t) {
- callback.onResult(null, t);
- }
- }
- }
-
- private void asyncReturns(final int i, final SingleResultCallback callback) {
- assertTrue(invocationTracker.isAsyncStep);
- if (isTestingAbruptCompletion) {
- callback.complete(affectedReturns(i));
- } else {
- try {
- callback.complete(affectedReturns(i));
- } catch (Throwable t) {
- callback.onResult(null, t);
- }
- }
- }
-
- private void affected(final int i) {
- int cur = invocationTracker.getNextOption(2);
- if (cur == 0) {
- listener.add("affected-exception-" + i);
- throw new RuntimeException("exception-" + i);
- } else {
- listener.add("affected-success-" + i);
- }
- }
-
- private int affectedReturns(final int i) {
- int cur = invocationTracker.getNextOption(2);
- if (cur == 0) {
- listener.add("affected-exception-" + i);
- throw new RuntimeException("exception-" + i);
- } else {
- listener.add("affected-success-" + i);
- return i;
- }
- }
-
- // assert methods:
-
- private void assertBehavesSameVariations(final int expectedVariations, final Runnable sync,
- final Consumer> async) {
- assertBehavesSameVariations(expectedVariations,
- () -> {
- sync.run();
- return null;
- },
- (c) -> {
- async.accept((v, e) -> c.onResult(v, e));
- });
- }
-
- private void assertBehavesSameVariations(final int expectedVariations, final Supplier sync,
- final Consumer> async) {
- // run the variation-trying code twice, with direct/indirect exceptions
- for (int i = 0; i < 2; i++) {
- isTestingAbruptCompletion = i != 0;
-
- // the variation-trying code:
- invocationTracker.reset();
- do {
- invocationTracker.startInitialStep();
- assertBehavesSame(
- sync,
- () -> invocationTracker.startMatchStep(),
- async);
- } while (invocationTracker.countDown());
- assertEquals(expectedVariations, invocationTracker.getVariationCount(),
- "number of variations did not match");
- }
-
- }
-
- private void assertBehavesSame(final Supplier sync, final Runnable between,
- final Consumer> async) {
-
- T expectedValue = null;
- Throwable expectedException = null;
- try {
- expectedValue = sync.get();
- } catch (Throwable e) {
- expectedException = e;
- }
- List expectedEvents = listener.getEventStrings();
-
- listener.clear();
- between.run();
-
- AtomicReference actualValue = new AtomicReference<>();
- AtomicReference actualException = new AtomicReference<>();
- AtomicBoolean wasCalled = new AtomicBoolean(false);
- try {
- async.accept((v, e) -> {
- actualValue.set(v);
- actualException.set(e);
- if (wasCalled.get()) {
- fail();
- }
- wasCalled.set(true);
- });
- } catch (Throwable e) {
- fail("async threw instead of using callback");
- }
-
- // The following code can be used to debug variations:
-// System.out.println("===VARIATION START");
-// System.out.println("sync: " + expectedEvents);
-// System.out.println("callback called?: " + wasCalled.get());
-// System.out.println("value -- sync: " + expectedValue + " -- async: " + actualValue.get());
-// System.out.println("excep -- sync: " + expectedException + " -- async: " + actualException.get());
-// System.out.println("exception mode: " + (isTestingAbruptCompletion
-// ? "exceptions thrown directly (abrupt completion)" : "exceptions into callbacks"));
-// System.out.println("===VARIATION END");
-
- // show assertion failures arising in async tests
- if (actualException.get() != null && actualException.get() instanceof AssertionFailedError) {
- throw (AssertionFailedError) actualException.get();
- }
-
- assertTrue(wasCalled.get(), "callback should have been called");
- assertEquals(expectedEvents, listener.getEventStrings(), "steps should have matched");
- assertEquals(expectedValue, actualValue.get());
- assertEquals(expectedException == null, actualException.get() == null,
- "both or neither should have produced an exception");
- if (expectedException != null) {
- assertEquals(expectedException.getMessage(), actualException.get().getMessage());
- assertEquals(expectedException.getClass(), actualException.get().getClass());
- }
-
- listener.clear();
- }
-
- /**
- * Tracks invocations: allows testing of all variations of a method calls
- */
- private static class InvocationTracker {
- public static final int DEPTH_LIMIT = 50;
- private final List invocationOptionSequence = new ArrayList<>();
- private boolean isAsyncStep; // async = matching, vs initial step = populating
- private int currentInvocationIndex;
- private int variationCount;
-
- public void reset() {
- variationCount = 0;
- }
-
- public void startInitialStep() {
- variationCount++;
- isAsyncStep = false;
- currentInvocationIndex = -1;
- }
-
- public int getNextOption(final int myOptionsSize) {
- /*
- This method creates (or gets) the next invocation's option. Each
- invoker of this method has the "option" to behave in various ways,
- usually just success (option 1) and exceptional failure (option 0),
- though some callers might have more options. A sequence of method
- outcomes (options) is one "variation". Tests automatically test
- all possible variations (up to a limit, to prevent infinite loops).
-
- Methods generally have labels, to ensure that corresponding
- sync/async methods are called in the right order, but these labels
- are unrelated to the "variation" logic here. There are two "modes"
- (whether completion is abrupt, or not), which are also unrelated.
- */
-
- currentInvocationIndex++; // which invocation result we are dealing with
-
- if (currentInvocationIndex >= invocationOptionSequence.size()) {
- if (isAsyncStep) {
- fail("result should have been pre-initialized: steps may not match");
- }
- if (isWithinDepthLimit()) {
- invocationOptionSequence.add(myOptionsSize - 1);
- } else {
- invocationOptionSequence.add(0); // choose "0" option, should always be an exception
- }
- }
- return invocationOptionSequence.get(currentInvocationIndex);
- }
-
- public void startMatchStep() {
- isAsyncStep = true;
- currentInvocationIndex = -1;
- }
-
- private boolean countDown() {
- while (!invocationOptionSequence.isEmpty()) {
- int lastItemIndex = invocationOptionSequence.size() - 1;
- int lastItem = invocationOptionSequence.get(lastItemIndex);
- if (lastItem > 0) {
- // count current digit down by 1, until 0
- invocationOptionSequence.set(lastItemIndex, lastItem - 1);
- return true;
- } else {
- // current digit completed, remove (move left)
- invocationOptionSequence.remove(lastItemIndex);
- }
- }
- return false;
- }
-
- public int getVariationCount() {
- return variationCount;
- }
-
- public boolean isWithinDepthLimit() {
- return invocationOptionSequence.size() < DEPTH_LIMIT;
- }
- }
}
diff --git a/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTestBase.java b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTestBase.java
new file mode 100644
index 00000000000..207e06b8a47
--- /dev/null
+++ b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTestBase.java
@@ -0,0 +1,373 @@
+/*
+ * Copyright 2008-present MongoDB, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.mongodb.internal.async;
+
+import com.mongodb.client.TestListener;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.opentest4j.AssertionFailedError;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Consumer;
+import java.util.function.Supplier;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.junit.jupiter.api.Assertions.fail;
+
+public abstract class AsyncFunctionsTestBase {
+
+ private final TestListener listener = new TestListener();
+ private final InvocationTracker invocationTracker = new InvocationTracker();
+ private boolean isTestingAbruptCompletion = false;
+ private ExecutorService asyncExecutor;
+
+ void setIsTestingAbruptCompletion(final boolean b) {
+ isTestingAbruptCompletion = b;
+ }
+
+ public void setAsyncStep(final boolean isAsyncStep) {
+ invocationTracker.isAsyncStep = isAsyncStep;
+ }
+
+ public void getNextOption(final int i) {
+ invocationTracker.getNextOption(i);
+ }
+
+ public void listenerAdd(final String s) {
+ listener.add(s);
+ }
+
+ /**
+ * Create an executor service for async operations before each test.
+ *
+ * @return the executor service.
+ */
+ public abstract ExecutorService createAsyncExecutor();
+
+ @BeforeEach
+ public void setUp() {
+ asyncExecutor = createAsyncExecutor();
+ }
+
+ @AfterEach
+ public void shutDown() {
+ asyncExecutor.shutdownNow();
+ }
+
+ void plain(final int i) {
+ int cur = invocationTracker.getNextOption(2);
+ if (cur == 0) {
+ listener.add("plain-exception-" + i);
+ throw new RuntimeException("affected method exception-" + i);
+ } else {
+ listener.add("plain-success-" + i);
+ }
+ }
+
+ int plainReturns(final int i) {
+ int cur = invocationTracker.getNextOption(2);
+ if (cur == 0) {
+ listener.add("plain-returns-exception-" + i);
+ throw new RuntimeException("affected method exception-" + i);
+ } else {
+ listener.add("plain-returns-success-" + i);
+ return i;
+ }
+ }
+
+ boolean plainTest(final int i) {
+ int cur = invocationTracker.getNextOption(3);
+ if (cur == 0) {
+ listener.add("plain-exception-" + i);
+ throw new RuntimeException("affected method exception-" + i);
+ } else if (cur == 1) {
+ listener.add("plain-false-" + i);
+ return false;
+ } else {
+ listener.add("plain-true-" + i);
+ return true;
+ }
+ }
+
+ void sync(final int i) {
+ assertFalse(invocationTracker.isAsyncStep);
+ affected(i);
+ }
+
+ Integer syncReturns(final int i) {
+ assertFalse(invocationTracker.isAsyncStep);
+ return affectedReturns(i);
+ }
+
+
+ public void submit(final Runnable task) {
+ asyncExecutor.execute(task);
+ }
+ void async(final int i, final SingleResultCallback callback) {
+ assertTrue(invocationTracker.isAsyncStep);
+ if (isTestingAbruptCompletion) {
+ /* We should not test for abrupt completion in a separate thread. Once a callback is registered for an async operation,
+ the Async Framework does not handle exceptions thrown outside of callbacks by the executing thread. Such exception management
+ should be the responsibility of the thread conducting the asynchronous operations. */
+ affected(i);
+ submit(() -> {
+ callback.complete(callback);
+ });
+ } else {
+ submit(() -> {
+ try {
+ affected(i);
+ callback.complete(callback);
+ } catch (Throwable t) {
+ callback.onResult(null, t);
+ }
+ });
+ }
+ }
+
+ void asyncReturns(final int i, final SingleResultCallback callback) {
+ assertTrue(invocationTracker.isAsyncStep);
+ if (isTestingAbruptCompletion) {
+ int result = affectedReturns(i);
+ submit(() -> {
+ callback.complete(result);
+ });
+ } else {
+ submit(() -> {
+ try {
+ callback.complete(affectedReturns(i));
+ } catch (Throwable t) {
+ callback.onResult(null, t);
+ }
+ });
+ }
+ }
+
+ private void affected(final int i) {
+ int cur = invocationTracker.getNextOption(2);
+ if (cur == 0) {
+ listener.add("affected-exception-" + i);
+ throw new RuntimeException("exception-" + i);
+ } else {
+ listener.add("affected-success-" + i);
+ }
+ }
+
+ private int affectedReturns(final int i) {
+ int cur = invocationTracker.getNextOption(2);
+ if (cur == 0) {
+ listener.add("affected-returns-exception-" + i);
+ throw new RuntimeException("exception-" + i);
+ } else {
+ listener.add("affected-returns-success-" + i);
+ return i;
+ }
+ }
+
+ // assert methods:
+
+ void assertBehavesSameVariations(final int expectedVariations, final Runnable sync,
+ final Consumer> async) {
+ assertBehavesSameVariations(expectedVariations,
+ () -> {
+ sync.run();
+ return null;
+ },
+ (c) -> {
+ async.accept((v, e) -> c.onResult(v, e));
+ });
+ }
+
+ void assertBehavesSameVariations(final int expectedVariations, final Supplier sync,
+ final Consumer> async) {
+ // run the variation-trying code twice, with direct/indirect exceptions
+ for (int i = 0; i < 2; i++) {
+ isTestingAbruptCompletion = i != 0;
+
+ // the variation-trying code:
+ invocationTracker.reset();
+ do {
+ invocationTracker.startInitialStep();
+ assertBehavesSame(
+ sync,
+ () -> invocationTracker.startMatchStep(),
+ async);
+ } while (invocationTracker.countDown());
+ assertEquals(expectedVariations, invocationTracker.getVariationCount(),
+ "number of variations did not match");
+ }
+
+ }
+
+ private void assertBehavesSame(final Supplier sync, final Runnable between,
+ final Consumer> async) {
+
+ T expectedValue = null;
+ Throwable expectedException = null;
+ try {
+ expectedValue = sync.get();
+ } catch (Throwable e) {
+ expectedException = e;
+ }
+ List expectedEvents = listener.getEventStrings();
+
+ listener.clear();
+ between.run();
+
+ AtomicReference actualValue = new AtomicReference<>();
+ AtomicReference actualException = new AtomicReference<>();
+ CompletableFuture wasCalledFuture = new CompletableFuture<>();
+ try {
+ async.accept((v, e) -> {
+ actualValue.set(v);
+ actualException.set(e);
+ if (wasCalledFuture.isDone()) {
+ fail();
+ }
+ wasCalledFuture.complete(null);
+ });
+ } catch (Throwable e) {
+ fail("async threw instead of using callback");
+ }
+
+ await(wasCalledFuture, "Callback should have been called");
+
+ // The following code can be used to debug variations:
+// System.out.println("===VARIATION START");
+// System.out.println("sync: " + expectedEvents);
+// System.out.println("callback called?: " + wasCalledFuture.isDone());
+// System.out.println("value -- sync: " + expectedValue + " -- async: " + actualValue.get());
+// System.out.println("excep -- sync: " + expectedException + " -- async: " + actualException.get());
+// System.out.println("exception mode: " + (isTestingAbruptCompletion
+// ? "exceptions thrown directly (abrupt completion)" : "exceptions into callbacks"));
+// System.out.println("===VARIATION END");
+
+ // show assertion failures arising in async tests
+ if (actualException.get() != null && actualException.get() instanceof AssertionFailedError) {
+ throw (AssertionFailedError) actualException.get();
+ }
+
+ assertTrue(wasCalledFuture.isDone(), "callback should have been called");
+ assertEquals(expectedEvents, listener.getEventStrings(), "steps should have matched");
+ assertEquals(expectedValue, actualValue.get());
+ assertEquals(expectedException == null, actualException.get() == null,
+ "both or neither should have produced an exception");
+ if (expectedException != null) {
+ assertEquals(expectedException.getMessage(), actualException.get().getMessage());
+ assertEquals(expectedException.getClass(), actualException.get().getClass());
+ }
+
+ listener.clear();
+ }
+
+ protected T await(final CompletableFuture voidCompletableFuture, final String errorMessage) {
+ try {
+ return voidCompletableFuture.get(1, TimeUnit.MINUTES);
+ } catch (InterruptedException | ExecutionException | TimeoutException e) {
+ throw new AssertionError(errorMessage);
+ }
+ }
+
+ /**
+ * Tracks invocations: allows testing of all variations of a method calls
+ */
+ static class InvocationTracker {
+ public static final int DEPTH_LIMIT = 50;
+ private final List invocationOptionSequence = new ArrayList<>();
+ private boolean isAsyncStep; // async = matching, vs initial step = populating
+ private int currentInvocationIndex;
+ private int variationCount;
+
+ public void reset() {
+ variationCount = 0;
+ }
+
+ public void startInitialStep() {
+ variationCount++;
+ isAsyncStep = false;
+ currentInvocationIndex = -1;
+ }
+
+ public int getNextOption(final int myOptionsSize) {
+ /*
+ This method creates (or gets) the next invocation's option. Each
+ invoker of this method has the "option" to behave in various ways,
+ usually just success (option 1) and exceptional failure (option 0),
+ though some callers might have more options. A sequence of method
+ outcomes (options) is one "variation". Tests automatically test
+ all possible variations (up to a limit, to prevent infinite loops).
+
+ Methods generally have labels, to ensure that corresponding
+ sync/async methods are called in the right order, but these labels
+ are unrelated to the "variation" logic here. There are two "modes"
+ (whether completion is abrupt, or not), which are also unrelated.
+ */
+
+ currentInvocationIndex++; // which invocation result we are dealing with
+
+ if (currentInvocationIndex >= invocationOptionSequence.size()) {
+ if (isAsyncStep) {
+ fail("result should have been pre-initialized: steps may not match");
+ }
+ if (isWithinDepthLimit()) {
+ invocationOptionSequence.add(myOptionsSize - 1);
+ } else {
+ invocationOptionSequence.add(0); // choose "0" option, should always be an exception
+ }
+ }
+ return invocationOptionSequence.get(currentInvocationIndex);
+ }
+
+ public void startMatchStep() {
+ isAsyncStep = true;
+ currentInvocationIndex = -1;
+ }
+
+ private boolean countDown() {
+ while (!invocationOptionSequence.isEmpty()) {
+ int lastItemIndex = invocationOptionSequence.size() - 1;
+ int lastItem = invocationOptionSequence.get(lastItemIndex);
+ if (lastItem > 0) {
+ // count current digit down by 1, until 0
+ invocationOptionSequence.set(lastItemIndex, lastItem - 1);
+ return true;
+ } else {
+ // current digit completed, remove (move left)
+ invocationOptionSequence.remove(lastItemIndex);
+ }
+ }
+ return false;
+ }
+
+ public int getVariationCount() {
+ return variationCount;
+ }
+
+ public boolean isWithinDepthLimit() {
+ return invocationOptionSequence.size() < DEPTH_LIMIT;
+ }
+ }
+}
diff --git a/driver-core/src/test/unit/com/mongodb/internal/async/SameThreadAsyncFunctionsTest.java b/driver-core/src/test/unit/com/mongodb/internal/async/SameThreadAsyncFunctionsTest.java
new file mode 100644
index 00000000000..04b9290af55
--- /dev/null
+++ b/driver-core/src/test/unit/com/mongodb/internal/async/SameThreadAsyncFunctionsTest.java
@@ -0,0 +1,94 @@
+/*
+ * Copyright 2008-present MongoDB, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.mongodb.internal.async;
+
+import org.jetbrains.annotations.NotNull;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.AbstractExecutorService;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.TimeUnit;
+
+import static com.mongodb.internal.async.AsyncRunnable.beginAsync;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+@DisplayName("The same thread async functions")
+public class SameThreadAsyncFunctionsTest extends AsyncFunctionsAbstractTest {
+ @Override
+ public ExecutorService createAsyncExecutor() {
+ return new SameThreadExecutorService();
+ }
+
+ @Test
+ void testInvalid() {
+ setIsTestingAbruptCompletion(false);
+ setAsyncStep(true);
+ IllegalStateException illegalStateException = new IllegalStateException("must not cause second callback invocation");
+
+ assertThrows(IllegalStateException.class, () -> {
+ beginAsync().thenRun(c -> {
+ async(3, c);
+ throw illegalStateException;
+ }).finish((v, e) -> {
+ assertNotEquals(e, illegalStateException);
+ });
+ });
+ assertThrows(IllegalStateException.class, () -> {
+ beginAsync().thenRun(c -> {
+ async(3, c);
+ }).finish((v, e) -> {
+ throw illegalStateException;
+ });
+ });
+ }
+
+ private static class SameThreadExecutorService extends AbstractExecutorService {
+ @Override
+ public void execute(@NotNull final Runnable command) {
+ command.run();
+ }
+
+ @Override
+ public void shutdown() {
+ }
+
+ @NotNull
+ @Override
+ public List shutdownNow() {
+ return Collections.emptyList();
+ }
+
+ @Override
+ public boolean isShutdown() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean isTerminated() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean awaitTermination(final long timeout, @NotNull final TimeUnit unit) {
+ return true;
+ }
+ }
+}
diff --git a/driver-core/src/test/unit/com/mongodb/internal/async/SeparateThreadAsyncFunctionsTest.java b/driver-core/src/test/unit/com/mongodb/internal/async/SeparateThreadAsyncFunctionsTest.java
new file mode 100644
index 00000000000..401c4d2c18e
--- /dev/null
+++ b/driver-core/src/test/unit/com/mongodb/internal/async/SeparateThreadAsyncFunctionsTest.java
@@ -0,0 +1,118 @@
+/*
+ * Copyright 2008-present MongoDB, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.mongodb.internal.async;
+
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
+
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+import static com.mongodb.internal.async.AsyncRunnable.beginAsync;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+@DisplayName("Separate thread async functions")
+public class SeparateThreadAsyncFunctionsTest extends AsyncFunctionsAbstractTest {
+
+ private UncaughtExceptionHandler uncaughtExceptionHandler;
+
+ @Override
+ public ExecutorService createAsyncExecutor() {
+ uncaughtExceptionHandler = new UncaughtExceptionHandler();
+ return Executors.newFixedThreadPool(1, r -> {
+ Thread thread = new Thread(r);
+ thread.setUncaughtExceptionHandler(uncaughtExceptionHandler);
+ return thread;
+ });
+ }
+
+ /**
+ * This test covers the scenario where a callback is erroneously invoked after a callback had been completed.
+ * Such behavior is considered a bug and is not expected. An AssertionError should be thrown if an asynchronous invocation
+ * attempts to use a callback that has already been marked as completed.
+ */
+ @Test
+ void shouldPropagateAssertionErrorIfCallbackHasBeenCompletedAfterAsyncInvocation() {
+ //given
+ setIsTestingAbruptCompletion(false);
+ setAsyncStep(true);
+ IllegalStateException illegalStateException = new IllegalStateException("must not cause second callback invocation");
+ AtomicBoolean callbackInvoked = new AtomicBoolean(false);
+
+ //when
+ beginAsync().thenRun(c -> {
+ async(3, c);
+ throw illegalStateException;
+ }).thenRun(c -> {
+ assertInvokedOnce(callbackInvoked);
+ c.complete(c);
+ })
+ .finish((v, e) -> {
+ assertEquals(illegalStateException, e);
+ }
+ );
+
+ //then
+ Throwable exception = uncaughtExceptionHandler.getException();
+ assertNotNull(exception);
+ assertEquals(AssertionError.class, exception.getClass());
+ assertEquals("Callback has been already completed. It could happen "
+ + "if code throws an exception after invoking an async method. Value: null", exception.getMessage());
+ }
+
+ @Test
+ void shouldPropagateUnexpectedExceptionFromFinishCallback() {
+ //given
+ setIsTestingAbruptCompletion(false);
+ setAsyncStep(true);
+ IllegalStateException illegalStateException = new IllegalStateException("must not cause second callback invocation");
+
+ //when
+ beginAsync().thenRun(c -> {
+ async(3, c);
+ }).finish((v, e) -> {
+ throw illegalStateException;
+ });
+
+ //then
+ Throwable exception = uncaughtExceptionHandler.getException();
+ assertNotNull(exception);
+ assertEquals(illegalStateException, exception);
+ }
+
+ private static void assertInvokedOnce(final AtomicBoolean callbackInvoked1) {
+ assertTrue(callbackInvoked1.compareAndSet(false, true));
+ }
+
+ private final class UncaughtExceptionHandler implements Thread.UncaughtExceptionHandler {
+
+ private final CompletableFuture completable = new CompletableFuture<>();
+
+ @Override
+ public void uncaughtException(final Thread t, final Throwable e) {
+ completable.complete(e);
+ }
+
+ public Throwable getException() {
+ return await(completable, "No exception was thrown");
+ }
+ }
+}