diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java b/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java index 5be92558ee0..7203d3a4945 100644 --- a/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java @@ -18,6 +18,8 @@ import com.mongodb.lang.Nullable; +import java.util.concurrent.atomic.AtomicBoolean; + /** * See {@link AsyncRunnable} *

@@ -33,4 +35,28 @@ public interface AsyncFunction { * @param callback the callback */ void unsafeFinish(T value, SingleResultCallback callback); + + /** + * Must be invoked at end of async chain or when executing a callback handler supplied by the caller. + * + * @param callback the callback provided by the method the chain is used in. + */ + default void finish(final T value, final SingleResultCallback callback) { + final AtomicBoolean callbackInvoked = new AtomicBoolean(false); + try { + this.unsafeFinish(value, (v, e) -> { + if (!callbackInvoked.compareAndSet(false, true)) { + throw new AssertionError(String.format("Callback has been already completed. It could happen " + + "if code throws an exception after invoking an async method. Value: %s", v), e); + } + callback.onResult(v, e); + }); + } catch (Throwable t) { + if (!callbackInvoked.compareAndSet(false, true)) { + throw t; + } else { + callback.completeExceptionally(t); + } + } + } } diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java b/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java index fcf8d61387d..7a872ded718 100644 --- a/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java @@ -170,7 +170,9 @@ default AsyncRunnable thenRun(final AsyncRunnable runnable) { return (c) -> { this.unsafeFinish((r, e) -> { if (e == null) { - runnable.unsafeFinish(c); + /* If 'runnable' is executed on a different thread from the one that executed the initial 'finish()', + then invoking 'finish()' within 'runnable' will catch and propagate any exceptions to 'c' (the callback). */ + runnable.finish(c); } else { c.completeExceptionally(e); } @@ -199,7 +201,7 @@ default AsyncRunnable thenRunIf(final Supplier condition, final AsyncRu return; } if (matched) { - runnable.unsafeFinish(callback); + runnable.finish(callback); } else { callback.complete(callback); } @@ -216,7 +218,7 @@ default AsyncSupplier thenSupply(final AsyncSupplier supplier) { return (c) -> { this.unsafeFinish((r, e) -> { if (e == null) { - supplier.unsafeFinish(c); + supplier.finish(c); } else { c.completeExceptionally(e); } diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java b/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java index b7d24dd3df5..77c289c8723 100644 --- a/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java @@ -18,6 +18,7 @@ import com.mongodb.lang.Nullable; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Predicate; @@ -54,18 +55,25 @@ default void unsafeFinish(@Nullable final Void value, final SingleResultCallback } /** - * Must be invoked at end of async chain. + * Must be invoked at end of async chain or when executing a callback handler supplied by the caller. + * + * @see #thenApply(AsyncFunction) + * @see #thenConsume(AsyncConsumer) + * @see #onErrorIf(Predicate, AsyncFunction) * @param callback the callback provided by the method the chain is used in */ default void finish(final SingleResultCallback callback) { - final boolean[] callbackInvoked = {false}; + final AtomicBoolean callbackInvoked = new AtomicBoolean(false); try { this.unsafeFinish((v, e) -> { - callbackInvoked[0] = true; + if (!callbackInvoked.compareAndSet(false, true)) { + throw new AssertionError(String.format("Callback has been already completed. It could happen " + + "if code throws an exception after invoking an async method. Value: %s", v), e); + } callback.onResult(v, e); }); } catch (Throwable t) { - if (callbackInvoked[0]) { + if (!callbackInvoked.compareAndSet(false, true)) { throw t; } else { callback.completeExceptionally(t); @@ -80,9 +88,9 @@ default void finish(final SingleResultCallback callback) { */ default AsyncSupplier thenApply(final AsyncFunction function) { return (c) -> { - this.unsafeFinish((v, e) -> { + this.finish((v, e) -> { if (e == null) { - function.unsafeFinish(v, c); + function.finish(v, c); } else { c.completeExceptionally(e); } @@ -99,7 +107,7 @@ default AsyncRunnable thenConsume(final AsyncConsumer consumer) { return (c) -> { this.unsafeFinish((v, e) -> { if (e == null) { - consumer.unsafeFinish(v, c); + consumer.finish(v, c); } else { c.completeExceptionally(e); } @@ -131,7 +139,7 @@ default AsyncSupplier onErrorIf( return; } if (errorMatched) { - errorFunction.unsafeFinish(e, callback); + errorFunction.finish(e, callback); } else { callback.completeExceptionally(e); } diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java index 218835f083e..7751bcba86f 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java @@ -595,6 +595,7 @@ private void sendCommandMessageAsync(final int messageId, final Decoder d return; } assertNotNull(responseBuffers); + T commandResult; try { updateSessionContext(sessionContext, responseBuffers); boolean commandOk = @@ -609,13 +610,14 @@ private void sendCommandMessageAsync(final int messageId, final Decoder d } commandEventSender.sendSucceededEvent(responseBuffers); - T result1 = getCommandResult(decoder, responseBuffers, messageId); - callback.onResult(result1, null); + commandResult = getCommandResult(decoder, responseBuffers, messageId); } catch (Throwable localThrowable) { callback.onResult(null, localThrowable); + return; } finally { responseBuffers.close(); } + callback.onResult(commandResult, null); })); } }); diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java index d4858f3d973..b8f85289a0b 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java @@ -98,7 +98,14 @@ public void startHandshakeAsync(final InternalConnection internalConnection, callback.onResult(null, t instanceof MongoException ? mapHelloException((MongoException) t) : t); } else { setSpeculativeAuthenticateResponse(helloResult); - callback.onResult(createInitializationDescription(helloResult, internalConnection, startTime), null); + InternalConnectionInitializationDescription initializationDescription; + try { + initializationDescription = createInitializationDescription(helloResult, internalConnection, startTime); + } catch (Throwable localThrowable) { + callback.onResult(null, localThrowable); + return; + } + callback.onResult(initializationDescription, null); } }); } diff --git a/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTest.java b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsAbstractTest.java similarity index 70% rename from driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTest.java rename to driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsAbstractTest.java index b783b3de93b..16e4e978bf4 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTest.java +++ b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsAbstractTest.java @@ -15,30 +15,16 @@ */ package com.mongodb.internal.async; -import com.mongodb.client.TestListener; import org.junit.jupiter.api.Test; -import org.opentest4j.AssertionFailedError; -import java.util.ArrayList; -import java.util.List; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiConsumer; import java.util.function.Consumer; import java.util.function.Supplier; import static com.mongodb.assertions.Assertions.assertNotNull; import static com.mongodb.internal.async.AsyncRunnable.beginAsync; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.junit.jupiter.api.Assertions.fail; -final class AsyncFunctionsTest { - private final TestListener listener = new TestListener(); - private final InvocationTracker invocationTracker = new InvocationTracker(); - private boolean isTestingAbruptCompletion = false; +abstract class AsyncFunctionsAbstractTest extends AsyncFunctionsTestBase { @Test void test1Method() { @@ -720,25 +706,6 @@ void testVariables() { }); } - @Test - void testInvalid() { - isTestingAbruptCompletion = false; - invocationTracker.isAsyncStep = true; - assertThrows(IllegalStateException.class, () -> { - beginAsync().thenRun(c -> { - async(3, c); - throw new IllegalStateException("must not cause second callback invocation"); - }).finish((v, e) -> {}); - }); - assertThrows(IllegalStateException.class, () -> { - beginAsync().thenRun(c -> { - async(3, c); - }).finish((v, e) -> { - throw new IllegalStateException("must not cause second callback invocation"); - }); - }); - } - @Test void testDerivation() { // Demonstrates the progression from nested async to the API. @@ -746,8 +713,8 @@ void testDerivation() { // Stand-ins for sync-async methods; these "happily" do not throw // exceptions, to avoid complicating this demo async code. Consumer happySync = (i) -> { - invocationTracker.getNextOption(1); - listener.add("affected-success-" + i); + getNextOption(1); + listenerAdd("affected-success-" + i); }; BiConsumer> happyAsync = (i, c) -> { happySync.accept(i); @@ -827,275 +794,4 @@ void testDerivation() { }); } - // invoked methods: - - private void plain(final int i) { - int cur = invocationTracker.getNextOption(2); - if (cur == 0) { - listener.add("plain-exception-" + i); - throw new RuntimeException("affected method exception-" + i); - } else { - listener.add("plain-success-" + i); - } - } - - private int plainReturns(final int i) { - int cur = invocationTracker.getNextOption(2); - if (cur == 0) { - listener.add("plain-exception-" + i); - throw new RuntimeException("affected method exception-" + i); - } else { - listener.add("plain-success-" + i); - return i; - } - } - - private boolean plainTest(final int i) { - int cur = invocationTracker.getNextOption(3); - if (cur == 0) { - listener.add("plain-exception-" + i); - throw new RuntimeException("affected method exception-" + i); - } else if (cur == 1) { - listener.add("plain-false-" + i); - return false; - } else { - listener.add("plain-true-" + i); - return true; - } - } - - private void sync(final int i) { - assertFalse(invocationTracker.isAsyncStep); - affected(i); - } - - - private Integer syncReturns(final int i) { - assertFalse(invocationTracker.isAsyncStep); - return affectedReturns(i); - } - - private void async(final int i, final SingleResultCallback callback) { - assertTrue(invocationTracker.isAsyncStep); - if (isTestingAbruptCompletion) { - affected(i); - callback.complete(callback); - - } else { - try { - affected(i); - callback.complete(callback); - } catch (Throwable t) { - callback.onResult(null, t); - } - } - } - - private void asyncReturns(final int i, final SingleResultCallback callback) { - assertTrue(invocationTracker.isAsyncStep); - if (isTestingAbruptCompletion) { - callback.complete(affectedReturns(i)); - } else { - try { - callback.complete(affectedReturns(i)); - } catch (Throwable t) { - callback.onResult(null, t); - } - } - } - - private void affected(final int i) { - int cur = invocationTracker.getNextOption(2); - if (cur == 0) { - listener.add("affected-exception-" + i); - throw new RuntimeException("exception-" + i); - } else { - listener.add("affected-success-" + i); - } - } - - private int affectedReturns(final int i) { - int cur = invocationTracker.getNextOption(2); - if (cur == 0) { - listener.add("affected-exception-" + i); - throw new RuntimeException("exception-" + i); - } else { - listener.add("affected-success-" + i); - return i; - } - } - - // assert methods: - - private void assertBehavesSameVariations(final int expectedVariations, final Runnable sync, - final Consumer> async) { - assertBehavesSameVariations(expectedVariations, - () -> { - sync.run(); - return null; - }, - (c) -> { - async.accept((v, e) -> c.onResult(v, e)); - }); - } - - private void assertBehavesSameVariations(final int expectedVariations, final Supplier sync, - final Consumer> async) { - // run the variation-trying code twice, with direct/indirect exceptions - for (int i = 0; i < 2; i++) { - isTestingAbruptCompletion = i != 0; - - // the variation-trying code: - invocationTracker.reset(); - do { - invocationTracker.startInitialStep(); - assertBehavesSame( - sync, - () -> invocationTracker.startMatchStep(), - async); - } while (invocationTracker.countDown()); - assertEquals(expectedVariations, invocationTracker.getVariationCount(), - "number of variations did not match"); - } - - } - - private void assertBehavesSame(final Supplier sync, final Runnable between, - final Consumer> async) { - - T expectedValue = null; - Throwable expectedException = null; - try { - expectedValue = sync.get(); - } catch (Throwable e) { - expectedException = e; - } - List expectedEvents = listener.getEventStrings(); - - listener.clear(); - between.run(); - - AtomicReference actualValue = new AtomicReference<>(); - AtomicReference actualException = new AtomicReference<>(); - AtomicBoolean wasCalled = new AtomicBoolean(false); - try { - async.accept((v, e) -> { - actualValue.set(v); - actualException.set(e); - if (wasCalled.get()) { - fail(); - } - wasCalled.set(true); - }); - } catch (Throwable e) { - fail("async threw instead of using callback"); - } - - // The following code can be used to debug variations: -// System.out.println("===VARIATION START"); -// System.out.println("sync: " + expectedEvents); -// System.out.println("callback called?: " + wasCalled.get()); -// System.out.println("value -- sync: " + expectedValue + " -- async: " + actualValue.get()); -// System.out.println("excep -- sync: " + expectedException + " -- async: " + actualException.get()); -// System.out.println("exception mode: " + (isTestingAbruptCompletion -// ? "exceptions thrown directly (abrupt completion)" : "exceptions into callbacks")); -// System.out.println("===VARIATION END"); - - // show assertion failures arising in async tests - if (actualException.get() != null && actualException.get() instanceof AssertionFailedError) { - throw (AssertionFailedError) actualException.get(); - } - - assertTrue(wasCalled.get(), "callback should have been called"); - assertEquals(expectedEvents, listener.getEventStrings(), "steps should have matched"); - assertEquals(expectedValue, actualValue.get()); - assertEquals(expectedException == null, actualException.get() == null, - "both or neither should have produced an exception"); - if (expectedException != null) { - assertEquals(expectedException.getMessage(), actualException.get().getMessage()); - assertEquals(expectedException.getClass(), actualException.get().getClass()); - } - - listener.clear(); - } - - /** - * Tracks invocations: allows testing of all variations of a method calls - */ - private static class InvocationTracker { - public static final int DEPTH_LIMIT = 50; - private final List invocationOptionSequence = new ArrayList<>(); - private boolean isAsyncStep; // async = matching, vs initial step = populating - private int currentInvocationIndex; - private int variationCount; - - public void reset() { - variationCount = 0; - } - - public void startInitialStep() { - variationCount++; - isAsyncStep = false; - currentInvocationIndex = -1; - } - - public int getNextOption(final int myOptionsSize) { - /* - This method creates (or gets) the next invocation's option. Each - invoker of this method has the "option" to behave in various ways, - usually just success (option 1) and exceptional failure (option 0), - though some callers might have more options. A sequence of method - outcomes (options) is one "variation". Tests automatically test - all possible variations (up to a limit, to prevent infinite loops). - - Methods generally have labels, to ensure that corresponding - sync/async methods are called in the right order, but these labels - are unrelated to the "variation" logic here. There are two "modes" - (whether completion is abrupt, or not), which are also unrelated. - */ - - currentInvocationIndex++; // which invocation result we are dealing with - - if (currentInvocationIndex >= invocationOptionSequence.size()) { - if (isAsyncStep) { - fail("result should have been pre-initialized: steps may not match"); - } - if (isWithinDepthLimit()) { - invocationOptionSequence.add(myOptionsSize - 1); - } else { - invocationOptionSequence.add(0); // choose "0" option, should always be an exception - } - } - return invocationOptionSequence.get(currentInvocationIndex); - } - - public void startMatchStep() { - isAsyncStep = true; - currentInvocationIndex = -1; - } - - private boolean countDown() { - while (!invocationOptionSequence.isEmpty()) { - int lastItemIndex = invocationOptionSequence.size() - 1; - int lastItem = invocationOptionSequence.get(lastItemIndex); - if (lastItem > 0) { - // count current digit down by 1, until 0 - invocationOptionSequence.set(lastItemIndex, lastItem - 1); - return true; - } else { - // current digit completed, remove (move left) - invocationOptionSequence.remove(lastItemIndex); - } - } - return false; - } - - public int getVariationCount() { - return variationCount; - } - - public boolean isWithinDepthLimit() { - return invocationOptionSequence.size() < DEPTH_LIMIT; - } - } } diff --git a/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTestBase.java b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTestBase.java new file mode 100644 index 00000000000..207e06b8a47 --- /dev/null +++ b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTestBase.java @@ -0,0 +1,373 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.async; + +import com.mongodb.client.TestListener; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.opentest4j.AssertionFailedError; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Supplier; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public abstract class AsyncFunctionsTestBase { + + private final TestListener listener = new TestListener(); + private final InvocationTracker invocationTracker = new InvocationTracker(); + private boolean isTestingAbruptCompletion = false; + private ExecutorService asyncExecutor; + + void setIsTestingAbruptCompletion(final boolean b) { + isTestingAbruptCompletion = b; + } + + public void setAsyncStep(final boolean isAsyncStep) { + invocationTracker.isAsyncStep = isAsyncStep; + } + + public void getNextOption(final int i) { + invocationTracker.getNextOption(i); + } + + public void listenerAdd(final String s) { + listener.add(s); + } + + /** + * Create an executor service for async operations before each test. + * + * @return the executor service. + */ + public abstract ExecutorService createAsyncExecutor(); + + @BeforeEach + public void setUp() { + asyncExecutor = createAsyncExecutor(); + } + + @AfterEach + public void shutDown() { + asyncExecutor.shutdownNow(); + } + + void plain(final int i) { + int cur = invocationTracker.getNextOption(2); + if (cur == 0) { + listener.add("plain-exception-" + i); + throw new RuntimeException("affected method exception-" + i); + } else { + listener.add("plain-success-" + i); + } + } + + int plainReturns(final int i) { + int cur = invocationTracker.getNextOption(2); + if (cur == 0) { + listener.add("plain-returns-exception-" + i); + throw new RuntimeException("affected method exception-" + i); + } else { + listener.add("plain-returns-success-" + i); + return i; + } + } + + boolean plainTest(final int i) { + int cur = invocationTracker.getNextOption(3); + if (cur == 0) { + listener.add("plain-exception-" + i); + throw new RuntimeException("affected method exception-" + i); + } else if (cur == 1) { + listener.add("plain-false-" + i); + return false; + } else { + listener.add("plain-true-" + i); + return true; + } + } + + void sync(final int i) { + assertFalse(invocationTracker.isAsyncStep); + affected(i); + } + + Integer syncReturns(final int i) { + assertFalse(invocationTracker.isAsyncStep); + return affectedReturns(i); + } + + + public void submit(final Runnable task) { + asyncExecutor.execute(task); + } + void async(final int i, final SingleResultCallback callback) { + assertTrue(invocationTracker.isAsyncStep); + if (isTestingAbruptCompletion) { + /* We should not test for abrupt completion in a separate thread. Once a callback is registered for an async operation, + the Async Framework does not handle exceptions thrown outside of callbacks by the executing thread. Such exception management + should be the responsibility of the thread conducting the asynchronous operations. */ + affected(i); + submit(() -> { + callback.complete(callback); + }); + } else { + submit(() -> { + try { + affected(i); + callback.complete(callback); + } catch (Throwable t) { + callback.onResult(null, t); + } + }); + } + } + + void asyncReturns(final int i, final SingleResultCallback callback) { + assertTrue(invocationTracker.isAsyncStep); + if (isTestingAbruptCompletion) { + int result = affectedReturns(i); + submit(() -> { + callback.complete(result); + }); + } else { + submit(() -> { + try { + callback.complete(affectedReturns(i)); + } catch (Throwable t) { + callback.onResult(null, t); + } + }); + } + } + + private void affected(final int i) { + int cur = invocationTracker.getNextOption(2); + if (cur == 0) { + listener.add("affected-exception-" + i); + throw new RuntimeException("exception-" + i); + } else { + listener.add("affected-success-" + i); + } + } + + private int affectedReturns(final int i) { + int cur = invocationTracker.getNextOption(2); + if (cur == 0) { + listener.add("affected-returns-exception-" + i); + throw new RuntimeException("exception-" + i); + } else { + listener.add("affected-returns-success-" + i); + return i; + } + } + + // assert methods: + + void assertBehavesSameVariations(final int expectedVariations, final Runnable sync, + final Consumer> async) { + assertBehavesSameVariations(expectedVariations, + () -> { + sync.run(); + return null; + }, + (c) -> { + async.accept((v, e) -> c.onResult(v, e)); + }); + } + + void assertBehavesSameVariations(final int expectedVariations, final Supplier sync, + final Consumer> async) { + // run the variation-trying code twice, with direct/indirect exceptions + for (int i = 0; i < 2; i++) { + isTestingAbruptCompletion = i != 0; + + // the variation-trying code: + invocationTracker.reset(); + do { + invocationTracker.startInitialStep(); + assertBehavesSame( + sync, + () -> invocationTracker.startMatchStep(), + async); + } while (invocationTracker.countDown()); + assertEquals(expectedVariations, invocationTracker.getVariationCount(), + "number of variations did not match"); + } + + } + + private void assertBehavesSame(final Supplier sync, final Runnable between, + final Consumer> async) { + + T expectedValue = null; + Throwable expectedException = null; + try { + expectedValue = sync.get(); + } catch (Throwable e) { + expectedException = e; + } + List expectedEvents = listener.getEventStrings(); + + listener.clear(); + between.run(); + + AtomicReference actualValue = new AtomicReference<>(); + AtomicReference actualException = new AtomicReference<>(); + CompletableFuture wasCalledFuture = new CompletableFuture<>(); + try { + async.accept((v, e) -> { + actualValue.set(v); + actualException.set(e); + if (wasCalledFuture.isDone()) { + fail(); + } + wasCalledFuture.complete(null); + }); + } catch (Throwable e) { + fail("async threw instead of using callback"); + } + + await(wasCalledFuture, "Callback should have been called"); + + // The following code can be used to debug variations: +// System.out.println("===VARIATION START"); +// System.out.println("sync: " + expectedEvents); +// System.out.println("callback called?: " + wasCalledFuture.isDone()); +// System.out.println("value -- sync: " + expectedValue + " -- async: " + actualValue.get()); +// System.out.println("excep -- sync: " + expectedException + " -- async: " + actualException.get()); +// System.out.println("exception mode: " + (isTestingAbruptCompletion +// ? "exceptions thrown directly (abrupt completion)" : "exceptions into callbacks")); +// System.out.println("===VARIATION END"); + + // show assertion failures arising in async tests + if (actualException.get() != null && actualException.get() instanceof AssertionFailedError) { + throw (AssertionFailedError) actualException.get(); + } + + assertTrue(wasCalledFuture.isDone(), "callback should have been called"); + assertEquals(expectedEvents, listener.getEventStrings(), "steps should have matched"); + assertEquals(expectedValue, actualValue.get()); + assertEquals(expectedException == null, actualException.get() == null, + "both or neither should have produced an exception"); + if (expectedException != null) { + assertEquals(expectedException.getMessage(), actualException.get().getMessage()); + assertEquals(expectedException.getClass(), actualException.get().getClass()); + } + + listener.clear(); + } + + protected T await(final CompletableFuture voidCompletableFuture, final String errorMessage) { + try { + return voidCompletableFuture.get(1, TimeUnit.MINUTES); + } catch (InterruptedException | ExecutionException | TimeoutException e) { + throw new AssertionError(errorMessage); + } + } + + /** + * Tracks invocations: allows testing of all variations of a method calls + */ + static class InvocationTracker { + public static final int DEPTH_LIMIT = 50; + private final List invocationOptionSequence = new ArrayList<>(); + private boolean isAsyncStep; // async = matching, vs initial step = populating + private int currentInvocationIndex; + private int variationCount; + + public void reset() { + variationCount = 0; + } + + public void startInitialStep() { + variationCount++; + isAsyncStep = false; + currentInvocationIndex = -1; + } + + public int getNextOption(final int myOptionsSize) { + /* + This method creates (or gets) the next invocation's option. Each + invoker of this method has the "option" to behave in various ways, + usually just success (option 1) and exceptional failure (option 0), + though some callers might have more options. A sequence of method + outcomes (options) is one "variation". Tests automatically test + all possible variations (up to a limit, to prevent infinite loops). + + Methods generally have labels, to ensure that corresponding + sync/async methods are called in the right order, but these labels + are unrelated to the "variation" logic here. There are two "modes" + (whether completion is abrupt, or not), which are also unrelated. + */ + + currentInvocationIndex++; // which invocation result we are dealing with + + if (currentInvocationIndex >= invocationOptionSequence.size()) { + if (isAsyncStep) { + fail("result should have been pre-initialized: steps may not match"); + } + if (isWithinDepthLimit()) { + invocationOptionSequence.add(myOptionsSize - 1); + } else { + invocationOptionSequence.add(0); // choose "0" option, should always be an exception + } + } + return invocationOptionSequence.get(currentInvocationIndex); + } + + public void startMatchStep() { + isAsyncStep = true; + currentInvocationIndex = -1; + } + + private boolean countDown() { + while (!invocationOptionSequence.isEmpty()) { + int lastItemIndex = invocationOptionSequence.size() - 1; + int lastItem = invocationOptionSequence.get(lastItemIndex); + if (lastItem > 0) { + // count current digit down by 1, until 0 + invocationOptionSequence.set(lastItemIndex, lastItem - 1); + return true; + } else { + // current digit completed, remove (move left) + invocationOptionSequence.remove(lastItemIndex); + } + } + return false; + } + + public int getVariationCount() { + return variationCount; + } + + public boolean isWithinDepthLimit() { + return invocationOptionSequence.size() < DEPTH_LIMIT; + } + } +} diff --git a/driver-core/src/test/unit/com/mongodb/internal/async/SameThreadAsyncFunctionsTest.java b/driver-core/src/test/unit/com/mongodb/internal/async/SameThreadAsyncFunctionsTest.java new file mode 100644 index 00000000000..04b9290af55 --- /dev/null +++ b/driver-core/src/test/unit/com/mongodb/internal/async/SameThreadAsyncFunctionsTest.java @@ -0,0 +1,94 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.async; + +import org.jetbrains.annotations.NotNull; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.List; +import java.util.concurrent.AbstractExecutorService; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; + +import static com.mongodb.internal.async.AsyncRunnable.beginAsync; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +@DisplayName("The same thread async functions") +public class SameThreadAsyncFunctionsTest extends AsyncFunctionsAbstractTest { + @Override + public ExecutorService createAsyncExecutor() { + return new SameThreadExecutorService(); + } + + @Test + void testInvalid() { + setIsTestingAbruptCompletion(false); + setAsyncStep(true); + IllegalStateException illegalStateException = new IllegalStateException("must not cause second callback invocation"); + + assertThrows(IllegalStateException.class, () -> { + beginAsync().thenRun(c -> { + async(3, c); + throw illegalStateException; + }).finish((v, e) -> { + assertNotEquals(e, illegalStateException); + }); + }); + assertThrows(IllegalStateException.class, () -> { + beginAsync().thenRun(c -> { + async(3, c); + }).finish((v, e) -> { + throw illegalStateException; + }); + }); + } + + private static class SameThreadExecutorService extends AbstractExecutorService { + @Override + public void execute(@NotNull final Runnable command) { + command.run(); + } + + @Override + public void shutdown() { + } + + @NotNull + @Override + public List shutdownNow() { + return Collections.emptyList(); + } + + @Override + public boolean isShutdown() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isTerminated() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean awaitTermination(final long timeout, @NotNull final TimeUnit unit) { + return true; + } + } +} diff --git a/driver-core/src/test/unit/com/mongodb/internal/async/SeparateThreadAsyncFunctionsTest.java b/driver-core/src/test/unit/com/mongodb/internal/async/SeparateThreadAsyncFunctionsTest.java new file mode 100644 index 00000000000..401c4d2c18e --- /dev/null +++ b/driver-core/src/test/unit/com/mongodb/internal/async/SeparateThreadAsyncFunctionsTest.java @@ -0,0 +1,118 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.async; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicBoolean; + +import static com.mongodb.internal.async.AsyncRunnable.beginAsync; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +@DisplayName("Separate thread async functions") +public class SeparateThreadAsyncFunctionsTest extends AsyncFunctionsAbstractTest { + + private UncaughtExceptionHandler uncaughtExceptionHandler; + + @Override + public ExecutorService createAsyncExecutor() { + uncaughtExceptionHandler = new UncaughtExceptionHandler(); + return Executors.newFixedThreadPool(1, r -> { + Thread thread = new Thread(r); + thread.setUncaughtExceptionHandler(uncaughtExceptionHandler); + return thread; + }); + } + + /** + * This test covers the scenario where a callback is erroneously invoked after a callback had been completed. + * Such behavior is considered a bug and is not expected. An AssertionError should be thrown if an asynchronous invocation + * attempts to use a callback that has already been marked as completed. + */ + @Test + void shouldPropagateAssertionErrorIfCallbackHasBeenCompletedAfterAsyncInvocation() { + //given + setIsTestingAbruptCompletion(false); + setAsyncStep(true); + IllegalStateException illegalStateException = new IllegalStateException("must not cause second callback invocation"); + AtomicBoolean callbackInvoked = new AtomicBoolean(false); + + //when + beginAsync().thenRun(c -> { + async(3, c); + throw illegalStateException; + }).thenRun(c -> { + assertInvokedOnce(callbackInvoked); + c.complete(c); + }) + .finish((v, e) -> { + assertEquals(illegalStateException, e); + } + ); + + //then + Throwable exception = uncaughtExceptionHandler.getException(); + assertNotNull(exception); + assertEquals(AssertionError.class, exception.getClass()); + assertEquals("Callback has been already completed. It could happen " + + "if code throws an exception after invoking an async method. Value: null", exception.getMessage()); + } + + @Test + void shouldPropagateUnexpectedExceptionFromFinishCallback() { + //given + setIsTestingAbruptCompletion(false); + setAsyncStep(true); + IllegalStateException illegalStateException = new IllegalStateException("must not cause second callback invocation"); + + //when + beginAsync().thenRun(c -> { + async(3, c); + }).finish((v, e) -> { + throw illegalStateException; + }); + + //then + Throwable exception = uncaughtExceptionHandler.getException(); + assertNotNull(exception); + assertEquals(illegalStateException, exception); + } + + private static void assertInvokedOnce(final AtomicBoolean callbackInvoked1) { + assertTrue(callbackInvoked1.compareAndSet(false, true)); + } + + private final class UncaughtExceptionHandler implements Thread.UncaughtExceptionHandler { + + private final CompletableFuture completable = new CompletableFuture<>(); + + @Override + public void uncaughtException(final Thread t, final Throwable e) { + completable.complete(e); + } + + public Throwable getException() { + return await(completable, "No exception was thrown"); + } + } +}