@@ -714,7 +714,7 @@ private class DataLoaderMapEntrySubscriber implements Subscriber<Map.Entry<K, V>
714714 private final List <Object > callContexts ;
715715 private final List <CompletableFuture <V >> queuedFutures ;
716716 private final Map <K , Object > callContextByKey ;
717- private final Map <K , CompletableFuture <V >> queuedFutureByKey ;
717+ private final Map <K , List < CompletableFuture <V >>> queuedFuturesByKey ;
718718
719719 private final List <K > clearCacheKeys = new ArrayList <>();
720720 private final Map <K , V > completedValuesByKey = new HashMap <>();
@@ -733,13 +733,13 @@ private DataLoaderMapEntrySubscriber(
733733 this .queuedFutures = queuedFutures ;
734734
735735 this .callContextByKey = new HashMap <>();
736- this .queuedFutureByKey = new HashMap <>();
736+ this .queuedFuturesByKey = new HashMap <>();
737737 for (int idx = 0 ; idx < queuedFutures .size (); idx ++) {
738738 K key = keys .get (idx );
739739 Object callContext = callContexts .get (idx );
740740 CompletableFuture <V > queuedFuture = queuedFutures .get (idx );
741741 callContextByKey .put (key , callContext );
742- queuedFutureByKey . put (key , queuedFuture );
742+ queuedFuturesByKey . computeIfAbsent (key , k -> new ArrayList <>()). add ( queuedFuture );
743743 }
744744 }
745745
@@ -756,20 +756,20 @@ public void onNext(Map.Entry<K, V> entry) {
756756 V value = entry .getValue ();
757757
758758 Object callContext = callContextByKey .get (key );
759- CompletableFuture <V > future = queuedFutureByKey .get (key );
759+ List < CompletableFuture <V >> futures = queuedFuturesByKey .get (key );
760760 if (value instanceof Try ) {
761761 // we allow the batch loader to return a Try so we can better represent a computation
762762 // that might have worked or not.
763763 Try <V > tryValue = (Try <V >) value ;
764764 if (tryValue .isSuccess ()) {
765- future . complete (tryValue .get ());
765+ futures . forEach ( f -> f . complete (tryValue .get () ));
766766 } else {
767767 stats .incrementLoadErrorCount (new IncrementLoadErrorCountStatisticsContext <>(key , callContext ));
768- future . completeExceptionally (tryValue .getThrowable ());
768+ futures . forEach ( f -> f . completeExceptionally (tryValue .getThrowable () ));
769769 clearCacheKeys .add (key );
770770 }
771771 } else {
772- future . complete (value );
772+ futures . forEach ( f -> f . complete (value ) );
773773 }
774774
775775 completedValuesByKey .put (key , value );
@@ -801,9 +801,11 @@ public void onError(Throwable ex) {
801801 // Complete the futures for the remaining keys with the exception.
802802 for (int idx = 0 ; idx < queuedFutures .size (); idx ++) {
803803 K key = keys .get (idx );
804- CompletableFuture <V > future = queuedFutureByKey .get (key );
804+ List < CompletableFuture <V >> futures = queuedFuturesByKey .get (key );
805805 if (!completedValuesByKey .containsKey (key )) {
806- future .completeExceptionally (ex );
806+ for (CompletableFuture <V > future : futures ) {
807+ future .completeExceptionally (ex );
808+ }
807809 // clear any cached view of this key because they all failed
808810 dataLoader .clear (key );
809811 }
0 commit comments