Skip to content

Commit

Permalink
Fix exception propagation in Async API methods (#1479) (#1485)
Browse files Browse the repository at this point in the history
- Resolve an issue where exceptions thrown during thenRun, thenSupply, and related operations in the asynchronous API were not properly propagated to the completion callback. This issue was addressed by replacing `unsafeFinish` with `finish`, ensuring that exceptions are caught and correctly passed to the completion callback when executed on different threads.

- Update existing Async API tests to ensure they simulate separate async thread execution.

- Modify the async callback to catch and handle exceptions locally. Exceptions are now directly processed and passed as an error argument to the callback function, avoiding propagation to the parent callback.

- Move `callback.onResult` outside the catch block to ensure it's not invoked twice when an exception occurs.

JAVA-5562
  • Loading branch information
vbabanin authored Aug 21, 2024
1 parent 39d1e9a commit adfab5f
Show file tree
Hide file tree
Showing 9 changed files with 647 additions and 321 deletions.
26 changes: 26 additions & 0 deletions driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import com.mongodb.lang.Nullable;

import java.util.concurrent.atomic.AtomicBoolean;

/**
* See {@link AsyncRunnable}
* <p>
Expand All @@ -33,4 +35,28 @@ public interface AsyncFunction<T, R> {
* @param callback the callback
*/
void unsafeFinish(T value, SingleResultCallback<R> callback);

/**
* Must be invoked at end of async chain or when executing a callback handler supplied by the caller.
*
* @param callback the callback provided by the method the chain is used in.
*/
default void finish(final T value, final SingleResultCallback<R> callback) {
final AtomicBoolean callbackInvoked = new AtomicBoolean(false);
try {
this.unsafeFinish(value, (v, e) -> {
if (!callbackInvoked.compareAndSet(false, true)) {
throw new AssertionError(String.format("Callback has been already completed. It could happen "
+ "if code throws an exception after invoking an async method. Value: %s", v), e);
}
callback.onResult(v, e);
});
} catch (Throwable t) {
if (!callbackInvoked.compareAndSet(false, true)) {
throw t;
} else {
callback.completeExceptionally(t);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,9 @@ default AsyncRunnable thenRun(final AsyncRunnable runnable) {
return (c) -> {
this.unsafeFinish((r, e) -> {
if (e == null) {
runnable.unsafeFinish(c);
/* If 'runnable' is executed on a different thread from the one that executed the initial 'finish()',
then invoking 'finish()' within 'runnable' will catch and propagate any exceptions to 'c' (the callback). */
runnable.finish(c);
} else {
c.completeExceptionally(e);
}
Expand Down Expand Up @@ -199,7 +201,7 @@ default AsyncRunnable thenRunIf(final Supplier<Boolean> condition, final AsyncRu
return;
}
if (matched) {
runnable.unsafeFinish(callback);
runnable.finish(callback);
} else {
callback.complete(callback);
}
Expand All @@ -216,7 +218,7 @@ default <R> AsyncSupplier<R> thenSupply(final AsyncSupplier<R> supplier) {
return (c) -> {
this.unsafeFinish((r, e) -> {
if (e == null) {
supplier.unsafeFinish(c);
supplier.finish(c);
} else {
c.completeExceptionally(e);
}
Expand Down
24 changes: 16 additions & 8 deletions driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import com.mongodb.lang.Nullable;

import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Predicate;


Expand Down Expand Up @@ -54,18 +55,25 @@ default void unsafeFinish(@Nullable final Void value, final SingleResultCallback
}

/**
* Must be invoked at end of async chain.
* Must be invoked at end of async chain or when executing a callback handler supplied by the caller.
*
* @see #thenApply(AsyncFunction)
* @see #thenConsume(AsyncConsumer)
* @see #onErrorIf(Predicate, AsyncFunction)
* @param callback the callback provided by the method the chain is used in
*/
default void finish(final SingleResultCallback<T> callback) {
final boolean[] callbackInvoked = {false};
final AtomicBoolean callbackInvoked = new AtomicBoolean(false);
try {
this.unsafeFinish((v, e) -> {
callbackInvoked[0] = true;
if (!callbackInvoked.compareAndSet(false, true)) {
throw new AssertionError(String.format("Callback has been already completed. It could happen "
+ "if code throws an exception after invoking an async method. Value: %s", v), e);
}
callback.onResult(v, e);
});
} catch (Throwable t) {
if (callbackInvoked[0]) {
if (!callbackInvoked.compareAndSet(false, true)) {
throw t;
} else {
callback.completeExceptionally(t);
Expand All @@ -80,9 +88,9 @@ default void finish(final SingleResultCallback<T> callback) {
*/
default <R> AsyncSupplier<R> thenApply(final AsyncFunction<T, R> function) {
return (c) -> {
this.unsafeFinish((v, e) -> {
this.finish((v, e) -> {
if (e == null) {
function.unsafeFinish(v, c);
function.finish(v, c);
} else {
c.completeExceptionally(e);
}
Expand All @@ -99,7 +107,7 @@ default AsyncRunnable thenConsume(final AsyncConsumer<T> consumer) {
return (c) -> {
this.unsafeFinish((v, e) -> {
if (e == null) {
consumer.unsafeFinish(v, c);
consumer.finish(v, c);
} else {
c.completeExceptionally(e);
}
Expand Down Expand Up @@ -131,7 +139,7 @@ default AsyncSupplier<T> onErrorIf(
return;
}
if (errorMatched) {
errorFunction.unsafeFinish(e, callback);
errorFunction.finish(e, callback);
} else {
callback.completeExceptionally(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,7 @@ private <T> void sendCommandMessageAsync(final int messageId, final Decoder<T> d
return;
}
assertNotNull(responseBuffers);
T commandResult;
try {
updateSessionContext(sessionContext, responseBuffers);
boolean commandOk =
Expand All @@ -609,13 +610,14 @@ private <T> void sendCommandMessageAsync(final int messageId, final Decoder<T> d
}
commandEventSender.sendSucceededEvent(responseBuffers);

T result1 = getCommandResult(decoder, responseBuffers, messageId);
callback.onResult(result1, null);
commandResult = getCommandResult(decoder, responseBuffers, messageId);
} catch (Throwable localThrowable) {
callback.onResult(null, localThrowable);
return;
} finally {
responseBuffers.close();
}
callback.onResult(commandResult, null);
}));
}
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,14 @@ public void startHandshakeAsync(final InternalConnection internalConnection,
callback.onResult(null, t instanceof MongoException ? mapHelloException((MongoException) t) : t);
} else {
setSpeculativeAuthenticateResponse(helloResult);
callback.onResult(createInitializationDescription(helloResult, internalConnection, startTime), null);
InternalConnectionInitializationDescription initializationDescription;
try {
initializationDescription = createInitializationDescription(helloResult, internalConnection, startTime);
} catch (Throwable localThrowable) {
callback.onResult(null, localThrowable);
return;
}
callback.onResult(initializationDescription, null);
}
});
}
Expand Down
Loading

0 comments on commit adfab5f

Please sign in to comment.