diff --git a/Cargo.toml b/Cargo.toml
index 70eadf76db..b0d3fcd1b6 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -66,6 +66,7 @@ itertools = "0.13"
 log = "^0.4"
 mockito = "^1"
 murmur3 = "0.5.2"
+num_cpus = "1"
 once_cell = "1"
 opendal = "0.48"
 ordered-float = "4.0.0"
diff --git a/crates/iceberg/Cargo.toml b/crates/iceberg/Cargo.toml
index de5b7cdc5e..18e25e9658 100644
--- a/crates/iceberg/Cargo.toml
+++ b/crates/iceberg/Cargo.toml
@@ -61,6 +61,7 @@ fnv = { workspace = true }
 futures = { workspace = true }
 itertools = { workspace = true }
 murmur3 = { workspace = true }
+num_cpus = { workspace = true }
 once_cell = { workspace = true }
 opendal = { workspace = true }
 ordered-float = { workspace = true }
diff --git a/crates/iceberg/src/error.rs b/crates/iceberg/src/error.rs
index 6f7fd7cf97..2b69b4706f 100644
--- a/crates/iceberg/src/error.rs
+++ b/crates/iceberg/src/error.rs
@@ -331,6 +331,12 @@ define_from_err!(
     "Failed to read a Parquet file"
 );
 
+define_from_err!(
+    futures::channel::mpsc::SendError,
+    ErrorKind::Unexpected,
+    "Failed to send a message to a channel"
+);
+
 define_from_err!(std::io::Error, ErrorKind::Unexpected, "IO Operation failed");
 
 /// Converts a timestamp in milliseconds to `DateTime<Utc>`, handling errors.
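The new `define_from_err!` arm is what lets the scan code below use `?` directly on channel sends. A minimal sketch of the effect, assuming the macro expands to a `From` impl like the neighboring conversions in error.rs (the function name here is illustrative):

```rust
use futures::channel::mpsc::channel;
use futures::SinkExt;

// With the conversion in place, `?` on a channel send converts the
// `futures` SendError into an `iceberg::Error` (ErrorKind::Unexpected).
async fn send_or_fail(value: i32) -> iceberg::Result<()> {
    let (mut tx, _rx) = channel::<i32>(1);
    tx.send(value).await?; // no manual `.map_err(...)` needed any more
    Ok(())
}
```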
diff --git a/crates/iceberg/src/scan.rs b/crates/iceberg/src/scan.rs
index 18489b7216..ebb3b007cd 100644
--- a/crates/iceberg/src/scan.rs
+++ b/crates/iceberg/src/scan.rs
@@ -17,14 +17,13 @@
 
 //! Table scan api.
 
-use std::collections::hash_map::Entry;
 use std::collections::HashMap;
-use std::sync::Arc;
+use std::sync::{Arc, RwLock};
 
 use arrow_array::RecordBatch;
-use async_stream::try_stream;
+use futures::channel::mpsc::{channel, Sender};
 use futures::stream::BoxStream;
-use futures::StreamExt;
+use futures::{SinkExt, StreamExt, TryFutureExt, TryStreamExt};
 use serde::{Deserialize, Serialize};
 
 use crate::arrow::ArrowReaderBuilder;
@@ -34,9 +33,10 @@ use crate::expr::visitors::inclusive_projection::InclusiveProjection;
 use crate::expr::visitors::manifest_evaluator::ManifestEvaluator;
 use crate::expr::{Bind, BoundPredicate, Predicate};
 use crate::io::FileIO;
+use crate::runtime::spawn;
 use crate::spec::{
-    DataContentType, ManifestContentType, ManifestFile, Schema, SchemaRef, SnapshotRef,
-    TableMetadataRef,
+    DataContentType, ManifestContentType, ManifestEntryRef, ManifestFile, ManifestList, Schema,
+    SchemaRef, SnapshotRef, TableMetadataRef,
 };
 use crate::table::Table;
 use crate::{Error, ErrorKind, Result};
@@ -55,10 +55,14 @@ pub struct TableScanBuilder<'a> {
     batch_size: Option<usize>,
     case_sensitive: bool,
     filter: Option<Predicate>,
+    concurrency_limit_manifest_files: usize,
+    concurrency_limit_manifest_entries: usize,
 }
 
 impl<'a> TableScanBuilder<'a> {
     pub(crate) fn new(table: &'a Table) -> Self {
+        let num_cpus = num_cpus::get();
+
         Self {
             table,
             column_names: vec![],
@@ -66,6 +70,8 @@ impl<'a> TableScanBuilder<'a> {
             batch_size: None,
             case_sensitive: true,
             filter: None,
+            concurrency_limit_manifest_files: num_cpus,
+            concurrency_limit_manifest_entries: num_cpus,
         }
     }
 
@@ -111,6 +117,26 @@ impl<'a> TableScanBuilder<'a> {
         self
     }
 
+    /// Sets the concurrency limit for both manifest files and manifest
+    /// entries for this scan.
+    pub fn with_concurrency_limit(mut self, limit: usize) -> Self {
+        self.concurrency_limit_manifest_files = limit;
+        self.concurrency_limit_manifest_entries = limit;
+        self
+    }
+
+    /// Sets the manifest file concurrency limit for this scan.
+    pub fn with_manifest_file_concurrency_limit(mut self, limit: usize) -> Self {
+        self.concurrency_limit_manifest_files = limit;
+        self
+    }
+
+    /// Sets the manifest entry concurrency limit for this scan.
+    pub fn with_manifest_entry_concurrency_limit(mut self, limit: usize) -> Self {
+        self.concurrency_limit_manifest_entries = limit;
+        self
+    }
+
     /// Build the table scan.
     pub fn build(self) -> Result<TableScan> {
         let snapshot = match self.snapshot_id {
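A usage sketch for the new builder knobs (module paths assumed; both limits default to `num_cpus::get()`, the combined setter overrides both, and the per-stage setters override each limit individually):

```rust
use iceberg::table::Table;

// Minimal sketch, assuming `table` has already been loaded from a catalog.
fn build_scan(table: &Table) -> iceberg::Result<iceberg::scan::TableScan> {
    table
        .scan()
        .with_concurrency_limit(8) // manifest files *and* manifest entries
        .with_manifest_entry_concurrency_limit(4) // then lower just the entry stage
        .build()
}
```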
@@ -155,12 +181,6 @@ impl<'a> TableScanBuilder<'a> {
             }
         }
 
-        let bound_predicates = if let Some(ref predicates) = self.filter {
-            Some(predicates.bind(schema.clone(), true)?)
-        } else {
-            None
-        };
-
         let mut field_ids = vec![];
         for column_name in &self.column_names {
             let field_id = schema.field_id_by_name(column_name).ok_or_else(|| {
@@ -199,17 +219,33 @@ impl<'a> TableScanBuilder<'a> {
             field_ids.push(field_id);
         }
 
-        Ok(TableScan {
+        let snapshot_bound_predicate = if let Some(ref predicates) = self.filter {
+            Some(predicates.bind(schema.clone(), true)?)
+        } else {
+            None
+        };
+
+        let plan_context = PlanContext {
             snapshot,
-            file_io: self.table.file_io().clone(),
             table_metadata: self.table.metadata_ref(),
-            column_names: self.column_names,
-            field_ids,
-            bound_predicates,
-            schema,
-            batch_size: self.batch_size,
+            snapshot_schema: schema,
             case_sensitive: self.case_sensitive,
-            filter: self.filter.map(Arc::new),
+            predicate: self.filter.map(Arc::new),
+            snapshot_bound_predicate: snapshot_bound_predicate.map(Arc::new),
+            file_io: self.table.file_io().clone(),
+            field_ids: Arc::new(field_ids),
+            partition_filter_cache: Arc::new(PartitionFilterCache::new()),
+            manifest_evaluator_cache: Arc::new(ManifestEvaluatorCache::new()),
+            expression_evaluator_cache: Arc::new(ExpressionEvaluatorCache::new()),
+        };
+
+        Ok(TableScan {
+            batch_size: self.batch_size,
+            column_names: self.column_names,
+            concurrency_limit_manifest_files: self.concurrency_limit_manifest_files,
+            file_io: self.table.file_io().clone(),
+            plan_context,
+            concurrency_limit_manifest_entries: self.concurrency_limit_manifest_entries,
         })
     }
 }
@@ -217,116 +253,85 @@ impl<'a> TableScanBuilder<'a> {
 /// Table scan.
 #[derive(Debug)]
 pub struct TableScan {
-    snapshot: SnapshotRef,
-    table_metadata: TableMetadataRef,
+    plan_context: PlanContext,
+    batch_size: Option<usize>,
     file_io: FileIO,
     column_names: Vec<String>,
-    field_ids: Vec<i32>,
-    bound_predicates: Option<BoundPredicate>,
-    schema: SchemaRef,
-    batch_size: Option<usize>,
+
+    /// The maximum number of manifest files that will be
+    /// retrieved from [`FileIO`] concurrently
+    concurrency_limit_manifest_files: usize,
+
+    /// The maximum number of [`ManifestEntry`]s that will
+    /// be processed in parallel
+    concurrency_limit_manifest_entries: usize,
+}
+
+/// PlanContext wraps a [`SnapshotRef`] alongside all the other
+/// objects that are required to perform a scan file plan.
+#[derive(Debug)]
+struct PlanContext {
+    snapshot: SnapshotRef,
+
+    table_metadata: TableMetadataRef,
+    snapshot_schema: SchemaRef,
     case_sensitive: bool,
-    filter: Option<Arc<Predicate>>,
+    predicate: Option<Arc<Predicate>>,
+    snapshot_bound_predicate: Option<Arc<BoundPredicate>>,
+    file_io: FileIO,
+    field_ids: Arc<Vec<i32>>,
+
+    partition_filter_cache: Arc<PartitionFilterCache>,
+    manifest_evaluator_cache: Arc<ManifestEvaluatorCache>,
+    expression_evaluator_cache: Arc<ExpressionEvaluatorCache>,
 }
 
 impl TableScan {
     /// Returns a stream of [`FileScanTask`]s.
     pub async fn plan_files(&self) -> Result<FileScanTaskStream> {
-        let context = FileScanStreamContext::new(
-            self.schema.clone(),
-            self.snapshot.clone(),
-            self.table_metadata.clone(),
-            self.file_io.clone(),
-            self.filter.clone(),
-            self.case_sensitive,
-        )?;
-
-        let mut partition_filter_cache = PartitionFilterCache::new();
-        let mut manifest_evaluator_cache = ManifestEvaluatorCache::new();
-        let mut expression_evaluator_cache = ExpressionEvaluatorCache::new();
-
-        let field_ids = self.field_ids.clone();
-        let bound_predicates = self.bound_predicates.clone();
-
-        Ok(try_stream! {
-            let manifest_list = context
-                .snapshot
-                .load_manifest_list(&context.file_io, &context.table_metadata)
-                .await?;
-
-            for entry in manifest_list.entries() {
-                if !Self::content_type_is_data(entry) {
-                    continue;
-                }
-
-                let partition_spec_id = entry.partition_spec_id;
-
-                let partition_filter = partition_filter_cache.get(
-                    partition_spec_id,
-                    &context,
-                )?;
-
-                if let Some(partition_filter) = partition_filter {
-                    let manifest_evaluator = manifest_evaluator_cache.get(
-                        partition_spec_id,
-                        partition_filter,
-                    );
-
-                    if !manifest_evaluator.eval(entry)? {
-                        continue;
-                    }
-                }
+        let concurrency_limit_manifest_files = self.concurrency_limit_manifest_files;
+        let concurrency_limit_manifest_entries = self.concurrency_limit_manifest_entries;
+
+        // used to stream ManifestEntryContexts between stages of the file plan operation
+        let (manifest_entry_ctx_tx, manifest_entry_ctx_rx) =
+            channel(concurrency_limit_manifest_files);
+        // used to stream the results back to the caller
+        let (file_scan_task_tx, file_scan_task_rx) = channel(concurrency_limit_manifest_entries);
+
+        let manifest_list = self.plan_context.get_manifest_list().await?;
+
+        // get the [`ManifestFile`]s from the [`ManifestList`], filtering out any
+        // whose content type is not Data or whose partitions cannot match this
+        // scan's filter
+        let manifest_file_contexts = self
+            .plan_context
+            .build_manifest_file_contexts(manifest_list, manifest_entry_ctx_tx)?;
+
+        // Concurrently load all [`Manifest`]s and stream their [`ManifestEntry`]s
+        spawn(async move {
+            futures::stream::iter(manifest_file_contexts)
+                .try_for_each_concurrent(concurrency_limit_manifest_files, |ctx| async move {
+                    ctx.fetch_manifest_and_stream_manifest_entries().await
+                })
+                .await
+        });
+
+        // Process the [`ManifestEntry`] stream in parallel
+        spawn(async move {
+            manifest_entry_ctx_rx
+                .map(|me_ctx| Ok((me_ctx, file_scan_task_tx.clone())))
+                .try_for_each_concurrent(
+                    concurrency_limit_manifest_entries,
+                    |(manifest_entry_context, tx)| async move {
+                        crate::runtime::spawn(async move {
+                            Self::process_manifest_entry(manifest_entry_context, tx).await
+                        })
+                        .await
+                    },
+                )
+                .await
+        });
 
-                let manifest = entry.load_manifest(&context.file_io).await?;
-                let mut manifest_entries_stream =
-                    futures::stream::iter(manifest.entries().iter().filter(|e| e.is_alive()));
-
-                while let Some(manifest_entry) = manifest_entries_stream.next().await {
-                    let data_file = manifest_entry.data_file();
-
-                    if let Some(partition_filter) = partition_filter {
-                        let expression_evaluator = expression_evaluator_cache.get(partition_spec_id, partition_filter);
-
-                        if !expression_evaluator.eval(data_file)? {
-                            continue;
-                        }
-                    }
-
-                    if let Some(bound_predicate) = context.bound_filter() {
-                        // reject any manifest entries whose data file's metrics don't match the filter.
-                        if !InclusiveMetricsEvaluator::eval(
-                            bound_predicate,
-                            manifest_entry.data_file(),
-                            false
-                        )? {
-                            continue;
-                        }
-                    }
-
-                    match manifest_entry.content_type() {
-                        DataContentType::EqualityDeletes | DataContentType::PositionDeletes => {
-                            yield Err(Error::new(
-                                ErrorKind::FeatureUnsupported,
-                                "Delete files are not supported yet.",
-                            ))?;
-                        }
-                        DataContentType::Data => {
-                            let scan_task: Result<FileScanTask> = Ok(FileScanTask {
-                                data_file_path: manifest_entry.data_file().file_path().to_string(),
-                                start: 0,
-                                length: manifest_entry.file_size_in_bytes(),
-                                project_field_ids: field_ids.clone(),
-                                predicate: bound_predicates.clone(),
-                                schema: context.schema.clone(),
-                            });
-                            yield scan_task?;
-                        }
-                    }
-                }
-            }
-        }
-        .boxed())
+        return Ok(file_scan_task_rx.boxed());
     }
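With `plan_files` now returning a channel-backed stream rather than a lazily-evaluated `try_stream!`, a caller drains it with `TryStreamExt`. A hedged consumer sketch (assumes `TableScan` is reachable at `iceberg::scan::TableScan`):

```rust
use futures::TryStreamExt;
use iceberg::scan::TableScan;

// `plan_files` returns quickly: the two spawned stages keep feeding the
// stream through the bounded channels while the caller drains it.
async fn count_tasks(scan: &TableScan) -> iceberg::Result<usize> {
    let mut tasks = scan.plan_files().await?;
    let mut n = 0;
    while let Some(_task) = tasks.try_next().await? {
        n += 1; // each item is a FileScanTask, ready to hand to a reader
    }
    Ok(n)
}
```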
 
     /// Returns an [`ArrowRecordBatchStream`].
@@ -340,157 +345,468 @@ impl TableScan {
         arrow_reader_builder.build().read(self.plan_files().await?)
     }
 
-    /// Checks whether the [`ManifestContentType`] is `Data` or not.
-    fn content_type_is_data(entry: &ManifestFile) -> bool {
-        if let ManifestContentType::Data = entry.content {
-            return true;
-        }
-        false
-    }
-
     /// Returns a reference to the column names of the table scan.
     pub fn column_names(&self) -> &[String] {
         &self.column_names
     }
 
+    /// Returns a reference to the snapshot of the table scan.
+    pub fn snapshot(&self) -> &SnapshotRef {
+        &self.plan_context.snapshot
+    }
+
+    async fn process_manifest_entry(
+        manifest_entry_context: ManifestEntryContext,
+        mut file_scan_task_tx: Sender<Result<FileScanTask>>,
+    ) -> Result<()> {
+        // skip processing this manifest entry if it has been marked as deleted
+        if !manifest_entry_context.manifest_entry.is_alive() {
+            return Ok(());
+        }
+
+        // abort the plan if we encounter a manifest entry whose data file's
+        // content type is currently unsupported
+        if manifest_entry_context.manifest_entry.content_type() != DataContentType::Data {
+            return Err(Error::new(
+                ErrorKind::FeatureUnsupported,
+                "Only Data files are currently supported",
+            ));
+        }
+
+        if let Some(ref bound_predicates) = manifest_entry_context.bound_predicates {
+            let BoundPredicates {
+                ref snapshot_bound_predicate,
+                ref partition_bound_predicate,
+            } = bound_predicates.as_ref();
+
+            let expression_evaluator_cache =
+                manifest_entry_context.expression_evaluator_cache.as_ref();
+
+            let expression_evaluator = expression_evaluator_cache.get(
+                manifest_entry_context.partition_spec_id,
+                partition_bound_predicate,
+            )?;
+
+            // skip any data file whose partition data indicates that it can't
+            // contain any data that matches this scan's filter
+            if !expression_evaluator.eval(manifest_entry_context.manifest_entry.data_file())? {
+                return Ok(());
+            }
+
+            // skip any data file whose metrics don't match this scan's filter
+            if !InclusiveMetricsEvaluator::eval(
+                snapshot_bound_predicate,
+                manifest_entry_context.manifest_entry.data_file(),
+                false,
+            )? {
+                return Ok(());
+            }
+        }
+
+        // the manifest entry has survived the entire filter chain. Create a
+        // corresponding FileScanTask and push it to the result stream
+        file_scan_task_tx
+            .send(Ok(manifest_entry_context.into_file_scan_task()))
+            .await?;
+
+        Ok(())
+    }
 }
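The method above is the last stage of a three-stage pipeline: bounded mpsc channels connect manifest-file fetching, entry filtering, and the caller. A self-contained miniature of that fan-out pattern, using plain `futures` plus `tokio` for the runtime (this crate's `crate::runtime::spawn` abstracts over the runtime; everything here is illustrative, not crate API):

```rust
use futures::channel::mpsc::{channel, SendError};
use futures::{SinkExt, StreamExt, TryStreamExt};

#[tokio::main]
async fn main() {
    let concurrency = 4;
    // a bounded channel gives backpressure between the two stages
    let (tx, rx) = channel::<i32>(concurrency);

    // stage 1: fan out over the inputs, at most `concurrency` in flight
    tokio::spawn(async move {
        futures::stream::iter((0..16).map(Ok::<i32, SendError>))
            .try_for_each_concurrent(concurrency, |i| {
                let mut tx = tx.clone();
                async move { tx.send(i * 2).await }
            })
            .await
            .expect("send failed");
    });

    // stage 2: the consumer sees results as soon as each send completes;
    // the stream ends once every sender clone has been dropped
    let results: Vec<i32> = rx.collect().await;
    assert_eq!(results.len(), 16);
}
```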
-/// Holds the context necessary for file scanning operations
-/// in a streaming environment.
-#[derive(Debug)]
-struct FileScanStreamContext {
-    schema: SchemaRef,
-    snapshot: SnapshotRef,
-    table_metadata: TableMetadataRef,
+struct BoundPredicates {
+    partition_bound_predicate: BoundPredicate,
+    snapshot_bound_predicate: BoundPredicate,
+}
+
+/// Wraps a [`ManifestFile`] alongside the objects that are needed
+/// to process it in a thread-safe manner
+struct ManifestFileContext {
+    manifest_file: ManifestFile,
+
+    sender: Sender<ManifestEntryContext>,
+
+    field_ids: Arc<Vec<i32>>,
     file_io: FileIO,
-    bound_filter: Option<BoundPredicate>,
-    case_sensitive: bool,
+    bound_predicates: Option<Arc<BoundPredicates>>,
+    snapshot_schema: SchemaRef,
+    expression_evaluator_cache: Arc<ExpressionEvaluatorCache>,
 }
 
-impl FileScanStreamContext {
-    /// Creates a new [`FileScanStreamContext`].
-    fn new(
-        schema: SchemaRef,
-        snapshot: SnapshotRef,
-        table_metadata: TableMetadataRef,
-        file_io: FileIO,
-        filter: Option<Arc<Predicate>>,
-        case_sensitive: bool,
-    ) -> Result<Self> {
-        let bound_filter = match filter {
-            Some(ref filter) => Some(filter.bind(schema.clone(), case_sensitive)?),
-            None => None,
-        };
+/// Wraps a [`ManifestEntryRef`] alongside the objects that are needed
+/// to process it in a thread-safe manner
+struct ManifestEntryContext {
+    manifest_entry: ManifestEntryRef,
 
-        Ok(Self {
-            schema,
-            snapshot,
-            table_metadata,
+    expression_evaluator_cache: Arc<ExpressionEvaluatorCache>,
+    field_ids: Arc<Vec<i32>>,
+    bound_predicates: Option<Arc<BoundPredicates>>,
+    partition_spec_id: i32,
+    snapshot_schema: SchemaRef,
+}
+
+impl ManifestFileContext {
+    /// Consumes this [`ManifestFileContext`], fetching its Manifest from FileIO and then
+    /// streaming its constituent [`ManifestEntry`]s to the channel provided in the context
+    async fn fetch_manifest_and_stream_manifest_entries(self) -> Result<()> {
+        let ManifestFileContext {
             file_io,
-            bound_filter,
-            case_sensitive,
-        })
+            manifest_file,
+            bound_predicates,
+            snapshot_schema,
+            field_ids,
+            expression_evaluator_cache,
+            mut sender,
+            ..
+        } = self;
+
+        let file_io_cloned = file_io.clone();
+        let manifest = manifest_file.load_manifest(&file_io_cloned).await?;
+
+        let (entries, _) = manifest.consume();
+
+        for manifest_entry in entries.into_iter() {
+            let manifest_entry_context = ManifestEntryContext {
+                manifest_entry,
+                expression_evaluator_cache: expression_evaluator_cache.clone(),
+                field_ids: field_ids.clone(),
+                partition_spec_id: manifest_file.partition_spec_id,
+                bound_predicates: bound_predicates.clone(),
+                snapshot_schema: snapshot_schema.clone(),
+            };
+
+            sender
+                .send(manifest_entry_context)
+                .map_err(|_| Error::new(ErrorKind::Unexpected, "mpsc channel SendError"))
+                .await?;
+        }
+
+        Ok(())
     }
+}
-
-    /// Returns a reference to the [`BoundPredicate`] filter.
-    fn bound_filter(&self) -> Option<&BoundPredicate> {
-        self.bound_filter.as_ref()
-    }
-}
+impl ManifestEntryContext {
+    /// Consume this `ManifestEntryContext`, returning a `FileScanTask`
+    /// created from it.
+    fn into_file_scan_task(self) -> FileScanTask {
+        FileScanTask {
+            data_file_path: self.manifest_entry.file_path().to_string(),
+            start: 0,
+            length: self.manifest_entry.file_size_in_bytes(),
+            project_field_ids: self.field_ids.to_vec(),
+            predicate: self
+                .bound_predicates
+                .map(|x| x.as_ref().snapshot_bound_predicate.clone()),
+            schema: self.snapshot_schema,
+        }
+    }
+}
+
+impl PlanContext {
+    async fn get_manifest_list(&self) -> Result<ManifestList> {
+        self.snapshot
+            .load_manifest_list(&self.file_io, &self.table_metadata)
+            .await
+    }
+
+    fn get_partition_filter(&self, manifest_file: &ManifestFile) -> Result<Arc<BoundPredicate>> {
+        let partition_spec_id = manifest_file.partition_spec_id;
+
+        let partition_filter = self.partition_filter_cache.get(
+            partition_spec_id,
+            &self.table_metadata,
+            &self.snapshot_schema,
+            self.case_sensitive,
+            self.predicate
+                .as_ref()
+                .ok_or(Error::new(
+                    ErrorKind::Unexpected,
+                    "Expected a predicate but none present",
+                ))?
+                .as_ref()
+                .bind(self.snapshot_schema.clone(), self.case_sensitive)?,
+        )?;
+
+        Ok(partition_filter)
+    }
+
+    fn build_manifest_file_contexts(
+        &self,
+        manifest_list: ManifestList,
+        sender: Sender<ManifestEntryContext>,
+    ) -> Result<Box<impl Iterator<Item = Result<ManifestFileContext>>>> {
+        let filtered_entries = manifest_list
+            .consume_entries()
+            .into_iter()
+            .filter(|manifest_file| manifest_file.content == ManifestContentType::Data);
+
+        // TODO: Ideally we could ditch this intermediate Vec as we return an iterator.
+        let mut filtered_mfcs = vec![];
+        if self.predicate.is_some() {
+            for manifest_file in filtered_entries {
+                let partition_bound_predicate = self.get_partition_filter(&manifest_file)?;
+
+                // evaluate the ManifestFile against the partition filter. Skip
+                // it if it cannot contain any matching rows
+                if self
+                    .manifest_evaluator_cache
+                    .get(
+                        manifest_file.partition_spec_id,
+                        partition_bound_predicate.clone(),
+                    )
+                    .eval(&manifest_file)?
+                {
+                    let mfc = self.create_manifest_file_context(
+                        manifest_file,
+                        Some(partition_bound_predicate),
+                        sender.clone(),
+                    );
+                    filtered_mfcs.push(Ok(mfc));
+                }
+            }
+        } else {
+            for manifest_file in filtered_entries {
+                let mfc = self.create_manifest_file_context(manifest_file, None, sender.clone());
+                filtered_mfcs.push(Ok(mfc));
+            }
+        }
+
+        Ok(Box::new(filtered_mfcs.into_iter()))
+    }
+
+    fn create_manifest_file_context(
+        &self,
+        manifest_file: ManifestFile,
+        partition_filter: Option<Arc<BoundPredicate>>,
+        sender: Sender<ManifestEntryContext>,
+    ) -> ManifestFileContext {
+        let bound_predicates =
+            if let (Some(ref partition_bound_predicate), Some(snapshot_bound_predicate)) =
+                (partition_filter, &self.snapshot_bound_predicate)
+            {
+                Some(Arc::new(BoundPredicates {
+                    partition_bound_predicate: partition_bound_predicate.as_ref().clone(),
+                    snapshot_bound_predicate: snapshot_bound_predicate.as_ref().clone(),
+                }))
+            } else {
+                None
+            };
+
+        ManifestFileContext {
+            manifest_file,
+            bound_predicates,
+            sender,
+            file_io: self.file_io.clone(),
+            snapshot_schema: self.snapshot_schema.clone(),
+            field_ids: self.field_ids.clone(),
+            expression_evaluator_cache: self.expression_evaluator_cache.clone(),
+        }
+    }
+}
 
 /// Manages the caching of [`BoundPredicate`] objects
 /// for [`PartitionSpec`]s based on partition spec id.
 #[derive(Debug)]
-struct PartitionFilterCache(HashMap<i32, BoundPredicate>);
+struct PartitionFilterCache(RwLock<HashMap<i32, Arc<BoundPredicate>>>);
 
 impl PartitionFilterCache {
     /// Creates a new [`PartitionFilterCache`]
     /// with an empty internal HashMap.
     fn new() -> Self {
-        Self(HashMap::new())
+        Self(RwLock::new(HashMap::new()))
     }
 
     /// Retrieves a [`BoundPredicate`] from the cache
     /// or computes it if not present.
     fn get(
-        &mut self,
+        &self,
         spec_id: i32,
-        context: &FileScanStreamContext,
-    ) -> Result<Option<&BoundPredicate>> {
-        match context.bound_filter() {
-            None => Ok(None),
-            Some(filter) => match self.0.entry(spec_id) {
-                Entry::Occupied(e) => Ok(Some(e.into_mut())),
-                Entry::Vacant(e) => {
-                    let partition_spec = context
-                        .table_metadata
-                        .partition_spec_by_id(spec_id)
-                        .ok_or(Error::new(
-                            ErrorKind::Unexpected,
-                            format!("Could not find partition spec for id {}", spec_id),
-                        ))?;
-
-                    let partition_type = partition_spec.partition_type(context.schema.as_ref())?;
-                    let partition_fields = partition_type.fields().to_owned();
-                    let partition_schema = Arc::new(
-                        Schema::builder()
-                            .with_schema_id(partition_spec.spec_id)
-                            .with_fields(partition_fields)
-                            .build()?,
-                    );
-
-                    let mut inclusive_projection = InclusiveProjection::new(partition_spec.clone());
-
-                    let partition_filter = inclusive_projection
-                        .project(filter)?
-                        .rewrite_not()
-                        .bind(partition_schema.clone(), context.case_sensitive)?;
-
-                    Ok(Some(e.insert(partition_filter)))
-                }
-            },
-        }
+        table_metadata: &TableMetadataRef,
+        schema: &SchemaRef,
+        case_sensitive: bool,
+        filter: BoundPredicate,
+    ) -> Result<Arc<BoundPredicate>> {
+        // we need a block here to ensure that the `read()` gets dropped before
+        // we attempt the `write()` below, otherwise we deadlock
+        {
+            let read = self.0.read().map_err(|_| {
+                Error::new(
+                    ErrorKind::Unexpected,
+                    "PartitionFilterCache RwLock was poisoned",
+                )
+            })?;
+
+            if read.contains_key(&spec_id) {
+                return Ok(read.get(&spec_id).unwrap().clone());
+            }
+        }
+
+        let partition_spec = table_metadata
+            .partition_spec_by_id(spec_id)
+            .ok_or(Error::new(
+                ErrorKind::Unexpected,
+                format!("Could not find partition spec for id {}", spec_id),
+            ))?;
+
+        let partition_type = partition_spec.partition_type(schema.as_ref())?;
+        let partition_fields = partition_type.fields().to_owned();
+        let partition_schema = Arc::new(
+            Schema::builder()
+                .with_schema_id(partition_spec.spec_id)
+                .with_fields(partition_fields)
+                .build()?,
+        );
+
+        let mut inclusive_projection = InclusiveProjection::new(partition_spec.clone());
+
+        let partition_filter = inclusive_projection
+            .project(&filter)?
+            .rewrite_not()
+            .bind(partition_schema.clone(), case_sensitive)?;
+
+        self.0
+            .write()
+            .map_err(|_| {
+                Error::new(
+                    ErrorKind::Unexpected,
+                    "PartitionFilterCache RwLock was poisoned",
+                )
+            })?
+            .insert(spec_id, Arc::new(partition_filter));
+
+        let read = self.0.read().map_err(|_| {
+            Error::new(
+                ErrorKind::Unexpected,
+                "PartitionFilterCache RwLock was poisoned",
+            )
+        })?;
+
+        Ok(read.get(&spec_id).unwrap().clone())
     }
 }
 
 /// Manages the caching of [`ManifestEvaluator`] objects
 /// for [`PartitionSpec`]s based on partition spec id.
 #[derive(Debug)]
-struct ManifestEvaluatorCache(HashMap<i32, ManifestEvaluator>);
+struct ManifestEvaluatorCache(RwLock<HashMap<i32, Arc<ManifestEvaluator>>>);
 
 impl ManifestEvaluatorCache {
     /// Creates a new [`ManifestEvaluatorCache`]
     /// with an empty internal HashMap.
     fn new() -> Self {
-        Self(HashMap::new())
+        Self(RwLock::new(HashMap::new()))
     }
 
     /// Retrieves a [`ManifestEvaluator`] from the cache
     /// or computes it if not present.
-    fn get(&mut self, spec_id: i32, partition_filter: &BoundPredicate) -> &mut ManifestEvaluator {
+    fn get(&self, spec_id: i32, partition_filter: Arc<BoundPredicate>) -> Arc<ManifestEvaluator> {
+        // we need a block here to ensure that the `read()` gets dropped before
+        // we attempt the `write()` below, otherwise we deadlock
+        {
+            let read = self
+                .0
+                .read()
+                .map_err(|_| {
+                    Error::new(
+                        ErrorKind::Unexpected,
+                        "ManifestEvaluatorCache RwLock was poisoned",
+                    )
+                })
+                .unwrap();
+
+            if read.contains_key(&spec_id) {
+                return read.get(&spec_id).unwrap().clone();
+            }
+        }
+
         self.0
-            .entry(spec_id)
-            .or_insert(ManifestEvaluator::new(partition_filter.clone()))
+            .write()
+            .map_err(|_| {
+                Error::new(
+                    ErrorKind::Unexpected,
+                    "ManifestEvaluatorCache RwLock was poisoned",
+                )
+            })
+            .unwrap()
+            .insert(
+                spec_id,
+                Arc::new(ManifestEvaluator::new(partition_filter.as_ref().clone())),
+            );
+
+        let read = self
+            .0
+            .read()
+            .map_err(|_| {
+                Error::new(
+                    ErrorKind::Unexpected,
+                    "ManifestEvaluatorCache RwLock was poisoned",
+                )
+            })
+            .unwrap();
+
+        read.get(&spec_id).unwrap().clone()
     }
 }
 
 /// Manages the caching of [`ExpressionEvaluator`] objects
 /// for [`PartitionSpec`]s based on partition spec id.
 #[derive(Debug)]
-struct ExpressionEvaluatorCache(HashMap<i32, ExpressionEvaluator>);
+struct ExpressionEvaluatorCache(RwLock<HashMap<i32, Arc<ExpressionEvaluator>>>);
 
 impl ExpressionEvaluatorCache {
     /// Creates a new [`ExpressionEvaluatorCache`]
     /// with an empty internal HashMap.
     fn new() -> Self {
-        Self(HashMap::new())
+        Self(RwLock::new(HashMap::new()))
     }
 
     /// Retrieves an [`ExpressionEvaluator`] from the cache
     /// or computes it if not present.
-    fn get(&mut self, spec_id: i32, partition_filter: &BoundPredicate) -> &mut ExpressionEvaluator {
+    fn get(
+        &self,
+        spec_id: i32,
+        partition_filter: &BoundPredicate,
+    ) -> Result<Arc<ExpressionEvaluator>> {
+        // we need a block here to ensure that the `read()` gets dropped before
+        // we attempt the `write()` below, otherwise we deadlock
+        {
+            let read = self.0.read().map_err(|_| {
+                Error::new(
+                    ErrorKind::Unexpected,
+                    "ExpressionEvaluatorCache RwLock was poisoned",
+                )
+            })?;
+
+            if read.contains_key(&spec_id) {
+                return Ok(read.get(&spec_id).unwrap().clone());
+            }
+        }
+
         self.0
-            .entry(spec_id)
-            .or_insert(ExpressionEvaluator::new(partition_filter.clone()))
+            .write()
+            .map_err(|_| {
+                Error::new(
+                    ErrorKind::Unexpected,
+                    "ExpressionEvaluatorCache RwLock was poisoned",
+                )
+            })
+            .unwrap()
+            .insert(
+                spec_id,
+                Arc::new(ExpressionEvaluator::new(partition_filter.clone())),
+            );
+
+        let read = self
+            .0
+            .read()
+            .map_err(|_| {
+                Error::new(
+                    ErrorKind::Unexpected,
+                    "ExpressionEvaluatorCache RwLock was poisoned",
+                )
+            })
+            .unwrap();
+
+        Ok(read.get(&spec_id).unwrap().clone())
     }
 }
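All three caches use the same double-checked `RwLock` idiom: take a read lock and return early on a hit, drop the guard, compute outside any lock, take a write lock to insert, then read again. A generic standalone miniature of the pattern (illustrative only; `Cache` and `get_or_insert_with` are not names from this crate):

```rust
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

struct Cache<V>(RwLock<HashMap<i32, Arc<V>>>);

impl<V> Cache<V> {
    fn get_or_insert_with(&self, key: i32, make: impl FnOnce() -> V) -> Arc<V> {
        // scope the read guard so it is dropped before we take the write
        // lock, otherwise this would deadlock
        {
            let read = self.0.read().unwrap();
            if let Some(v) = read.get(&key) {
                return v.clone();
            }
        }

        let value = Arc::new(make());
        self.0.write().unwrap().insert(key, value.clone());
        value
    }
}
```

This miniature returns the just-inserted `Arc` instead of re-reading as the diff does; two racing writers may both compute and the last insert wins, which is harmless for idempotent values like these evaluators.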
@@ -817,7 +1133,7 @@ mod tests {
         let table_scan = table.scan().build().unwrap();
         assert_eq!(
             table.metadata().current_snapshot().unwrap().snapshot_id(),
-            table_scan.snapshot.snapshot_id()
+            table_scan.snapshot().snapshot_id()
         );
     }
 
@@ -838,7 +1154,7 @@ mod tests {
             .snapshot_id(3051729675574597004)
             .build()
             .unwrap();
-        assert_eq!(table_scan.snapshot.snapshot_id(), 3051729675574597004);
+        assert_eq!(table_scan.snapshot().snapshot_id(), 3051729675574597004);
     }
 
     #[tokio::test]
diff --git a/crates/iceberg/src/spec/manifest.rs b/crates/iceberg/src/spec/manifest.rs
index e08591f9e9..e2f8251c15 100644
--- a/crates/iceberg/src/spec/manifest.rs
+++ b/crates/iceberg/src/spec/manifest.rs
@@ -94,6 +94,12 @@ impl Manifest {
         &self.entries
     }
 
+    /// Consumes this Manifest, returning its constituent parts
+    pub fn consume(self) -> (Vec<ManifestEntryRef>, ManifestMetadata) {
+        let Self { entries, metadata } = self;
+        (entries, metadata)
+    }
+
     /// Constructor from [`ManifestMetadata`] and [`ManifestEntry`]s.
     pub fn new(metadata: ManifestMetadata, entries: Vec<ManifestEntry>) -> Self {
         Self {
diff --git a/crates/iceberg/src/spec/manifest_list.rs b/crates/iceberg/src/spec/manifest_list.rs
index e818890680..3aaecf12d2 100644
--- a/crates/iceberg/src/spec/manifest_list.rs
+++ b/crates/iceberg/src/spec/manifest_list.rs
@@ -78,6 +78,11 @@ impl ManifestList {
     pub fn entries(&self) -> &[ManifestFile] {
         &self.entries
     }
+
+    /// Takes ownership of the entries in the manifest list, consuming it
+    pub fn consume_entries(self) -> impl IntoIterator<Item = ManifestFile> {
+        Box::new(self.entries.into_iter())
+    }
 }
 
 /// A manifest list writer.
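Together, `consume` and `consume_entries` let the planner move entries out of a loaded manifest instead of cloning from the borrowed slices that `entries()` returns. A small sketch of the ownership transfer (hypothetical helper; import paths assumed from `iceberg::spec`):

```rust
use iceberg::spec::{Manifest, ManifestEntryRef};

// Hypothetical helper: take a loaded Manifest apart without per-entry
// clones. The entries come back owned, so they can be handed to other
// tasks (as fetch_manifest_and_stream_manifest_entries does) directly.
fn split_manifest(manifest: Manifest) -> Vec<ManifestEntryRef> {
    let (entries, _metadata) = manifest.consume();
    entries
}
```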