Skip to content

Why do aggregation methods return `Cursor<Document>` instead of `Cursor<T>`? #1098

@clarkmcc

Description

@clarkmcc

I'm fighting an issue where my `find` queries return `Cursor<T>` but my `aggregate` queries return `Cursor<Document>`. When I map the aggregate stream's items to `T`, my trait implementations bounded on `S: Stream<Item = Result<T, E>>` no longer apply to the mapped cursor.

In short, the aggregation cursor is less ergonomic to use than the cursors returned by the other methods. This issue was also reported here: https://www.mongodb.com/community/forums/t/get-specific-data-type-from-aggregation-instead-of-document/188241, but it never received a response.

I assume there's a reason aggregation cursors don't return `Cursor<T>`, but I'm not sure what it is. With the quick-and-dirty patch below, I was able to make `aggregate` return `Cursor<T>` and still run the test suite. Is there a reason I'm missing why this can't be officially supported?

Index: src/action/aggregate.rs
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/action/aggregate.rs b/src/action/aggregate.rs
--- a/src/action/aggregate.rs	(revision 241fe3ddbdcb68409315ffb7dd2db151dbae13f4)
+++ b/src/action/aggregate.rs	(date 1715373649510)
@@ -1,3 +1,4 @@
+use std::marker::PhantomData;
 use std::time::Duration;
 
 use bson::Document;
@@ -27,12 +28,13 @@
     /// `await` will return d[`Result<Cursor<Document>>`] or d[`Result<SessionCursor<Document>>`] if
     /// a `ClientSession` is provided.
     #[deeplink]
-    pub fn aggregate(&self, pipeline: impl IntoIterator<Item = Document>) -> Aggregate {
+    pub fn aggregate<T>(&self, pipeline: impl IntoIterator<Item = Document>) -> Aggregate<ImplicitSession, T> {
         Aggregate {
             target: AggregateTargetRef::Database(self),
             pipeline: pipeline.into_iter().collect(),
             options: None,
             session: ImplicitSession,
+            _phantom: PhantomData,
         }
     }
 }
@@ -49,12 +51,13 @@
     /// `await` will return d[`Result<Cursor<Document>>`] or d[`Result<SessionCursor<Document>>`] if
     /// a [`ClientSession`] is provided.
     #[deeplink]
-    pub fn aggregate(&self, pipeline: impl IntoIterator<Item = Document>) -> Aggregate {
+    pub fn aggregate(&self, pipeline: impl IntoIterator<Item = Document>) -> Aggregate<ImplicitSession, T> {
         Aggregate {
             target: AggregateTargetRef::Collection(CollRef::new(self)),
             pipeline: pipeline.into_iter().collect(),
             options: None,
             session: ImplicitSession,
+            _phantom: PhantomData,
         }
     }
 }
@@ -95,11 +98,12 @@
 /// Run an aggregation operation.  Construct with [`Database::aggregate`] or
 /// [`Collection::aggregate`].
 #[must_use]
-pub struct Aggregate<'a, Session = ImplicitSession> {
+pub struct Aggregate<'a, Session = ImplicitSession, T = Document> {
     target: AggregateTargetRef<'a>,
     pipeline: Vec<Document>,
     options: Option<AggregateOptions>,
     session: Session,
+    _phantom: PhantomData<T>,
 }
 
 impl<'a, Session> Aggregate<'a, Session> {
@@ -119,7 +123,7 @@
     );
 }
 
-impl<'a> Aggregate<'a, ImplicitSession> {
+impl<'a, T> Aggregate<'a, ImplicitSession, T> {
     /// Use the provided session when running the operation.
     pub fn session(
         self,
@@ -130,15 +134,16 @@
             pipeline: self.pipeline,
             options: self.options,
             session: ExplicitSession(value.into()),
+            _phantom: PhantomData,
         }
     }
 }
 
 #[action_impl(sync = crate::sync::Cursor<Document>)]
-impl<'a> Action for Aggregate<'a, ImplicitSession> {
+impl<'a, T> Action for Aggregate<'a, ImplicitSession, T> {
     type Future = AggregateFuture;
 
-    async fn execute(mut self) -> Result<Cursor<Document>> {
+    async fn execute(mut self) -> Result<Cursor<T>> {
         resolve_options!(
             self.target,
             self.options,
@@ -156,10 +161,10 @@
 }
 
 #[action_impl(sync = crate::sync::SessionCursor<Document>)]
-impl<'a> Action for Aggregate<'a, ExplicitSession<'a>> {
+impl<'a, T> Action for Aggregate<'a, ExplicitSession<'a>, T> {
     type Future = AggregateSessionFuture;
 
-    async fn execute(mut self) -> Result<SessionCursor<Document>> {
+    async fn execute(mut self) -> Result<SessionCursor<T>> {
         resolve_read_concern_with_session!(self.target, self.options, Some(&mut *self.session.0))?;
         resolve_write_concern_with_session!(self.target, self.options, Some(&mut *self.session.0))?;
         resolve_selection_criteria_with_session!(
Index: src/test/util.rs
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/test/util.rs b/src/test/util.rs
--- a/src/test/util.rs	(revision 241fe3ddbdcb68409315ffb7dd2db151dbae13f4)
+++ b/src/test/util.rs	(date 1715375273252)
@@ -282,6 +282,22 @@
         self.get_coll(db_name, coll_name)
     }
 
+    pub(crate) async fn create_fresh_typed<T: Send + Sync>(
+        &self,
+        db_name: &str,
+        coll_name: &str,
+        options: impl Into<Option<CreateCollectionOptions>>,
+    ) -> Collection<T> {
+        self.drop_collection(db_name, coll_name).await;
+        self.database(db_name)
+            .create_collection(coll_name)
+            .with_options(options)
+            .await
+            .unwrap();
+
+        self.database(db_name).collection(coll_name)
+    }
+
     pub(crate) fn supports_fail_command(&self) -> bool {
         let version = if self.is_sharded() {
             ">= 4.1.5"
Index: src/test/db.rs
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/test/db.rs b/src/test/db.rs
--- a/src/test/db.rs	(revision 241fe3ddbdcb68409315ffb7dd2db151dbae13f4)
+++ b/src/test/db.rs	(date 1715375343608)
@@ -1,6 +1,9 @@
+use std::borrow::Borrow;
 use std::cmp::Ord;
 
 use futures::stream::TryStreamExt;
+use futures_util::StreamExt;
+use serde::{Deserialize, Serialize};
 
 use crate::{
     action::Action,
@@ -217,6 +220,23 @@
     assert!(coll3.id_index.is_none());
 }
 
+#[tokio::test]
+async fn db_aggregate_2() {
+    let client = TestClient::new().await;
+
+    #[derive(Clone, Serialize, Deserialize, PartialEq, Debug)]
+    struct Record {
+        name: String
+    }
+
+
+    let col = client.create_fresh_typed::<Record>("default", "default", None).await;
+    col.insert_one(Record { name: "a".to_string() }).await.unwrap();
+    let cursor = col.aggregate(vec![doc!{"$match": {}}]).await.unwrap();
+    let docs = cursor.try_collect::<Vec<_>>().await.unwrap();
+    println!("{:?}", docs)
+}
+
 #[tokio::test]
 async fn db_aggregate() {
     let client = TestClient::new().await;
@@ -254,7 +274,7 @@
         },
     ];
 
-    db.aggregate(pipeline)
+    db.aggregate::<Document>(pipeline)
         .await
         .expect("aggregate should succeed");
 }

Metadata

Metadata

Labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions