Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,10 @@ jobs:
run: cargo clippy --workspace --all-targets -- -D warnings
- name: Test
run: cargo nextest run --workspace --all-targets --no-fail-fast
- name: Test Manual Registration
- name: Test Manual Registration / no-default-features
run: cargo nextest run --workspace --tests --no-fail-fast --no-default-features --features macros
- name: Test docs
run: cargo test --workspace --doc
- name: Check (without default features)
run: cargo check --workspace --no-default-features

miri:
name: Miri
Expand Down
12 changes: 11 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@ thin-vec = "0.2.13"
shuttle = { version = "0.8.0", optional = true }

[features]
default = ["salsa_unstable", "rayon", "macros", "inventory"]
default = ["salsa_unstable", "rayon", "macros", "inventory", "accumulator"]
inventory = ["dep:inventory"]
shuttle = ["dep:shuttle"]
accumulator = ["salsa-macro-rules/accumulator"]
# FIXME: remove `salsa_unstable` before 1.0.
salsa_unstable = []
macros = ["dep:salsa-macros"]
Expand Down Expand Up @@ -82,11 +83,20 @@ harness = false
[[bench]]
name = "accumulator"
harness = false
required-features = ["accumulator"]

[[bench]]
name = "dataflow"
harness = false

[[example]]
name = "lazy-input"
required-features = ["accumulator"]

[[example]]
name = "calc"
required-features = ["accumulator"]

[workspace]
members = ["components/salsa-macro-rules", "components/salsa-macros"]

Expand Down
3 changes: 3 additions & 0 deletions components/salsa-macro-rules/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,6 @@ rust-version.workspace = true
description = "Declarative macros for the salsa crate"

[dependencies]

[features]
accumulator = []
13 changes: 13 additions & 0 deletions components/salsa-macro-rules/src/gate_accumulated.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#[cfg(feature = "accumulator")]
#[macro_export]
macro_rules! gate_accumulated {
($($body:tt)*) => {
$($body)*
};
}

#[cfg(not(feature = "accumulator"))]
#[macro_export]
macro_rules! gate_accumulated {
($($body:tt)*) => {};
}
2 changes: 2 additions & 0 deletions components/salsa-macro-rules/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
//! from a submodule is to use multiple crates, hence the existence
//! of this crate.

mod gate_accumulated;
mod macro_if;
mod maybe_backdate;
mod maybe_default;
mod return_mode;
#[cfg(feature = "accumulator")]
mod setup_accumulator_impl;
mod setup_input_struct;
mod setup_interned_struct;
Expand Down
30 changes: 16 additions & 14 deletions components/salsa-macro-rules/src/setup_tracked_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,21 +343,23 @@ macro_rules! setup_tracked_fn {

#[allow(non_local_definitions)]
impl $fn_name {
pub fn accumulated<$db_lt, A: salsa::Accumulator>(
$db: &$db_lt dyn $Db,
$($input_id: $interned_input_ty,)*
) -> Vec<&$db_lt A> {
use salsa::plumbing as $zalsa;
let key = $zalsa::macro_if! {
if $needs_interner {{
let (zalsa, zalsa_local) = $db.zalsas();
$Configuration::intern_ingredient($db).intern_id(zalsa, zalsa_local, ($($input_id),*), |_, data| data)
}} else {
$zalsa::AsId::as_id(&($($input_id),*))
}
};
$zalsa::gate_accumulated! {
pub fn accumulated<$db_lt, A: salsa::Accumulator>(
$db: &$db_lt dyn $Db,
$($input_id: $interned_input_ty,)*
) -> Vec<&$db_lt A> {
use salsa::plumbing as $zalsa;
let key = $zalsa::macro_if! {
if $needs_interner {{
let (zalsa, zalsa_local) = $db.zalsas();
$Configuration::intern_ingredient($db).intern_id(zalsa, zalsa_local, ($($input_id),*), |_, data| data)
}} else {
$zalsa::AsId::as_id(&($($input_id),*))
}
};

$Configuration::fn_ingredient($db).accumulated_by::<A>($db, key)
$Configuration::fn_ingredient($db).accumulated_by::<A>($db, key)
}
}

$zalsa::macro_if! { $is_specifiable =>
Expand Down
73 changes: 48 additions & 25 deletions src/active_query.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use std::{fmt, mem, ops};

use crate::accumulator::accumulated_map::{
AccumulatedMap, AtomicInputAccumulatedValues, InputAccumulatedValues,
#[cfg(feature = "accumulator")]
use crate::accumulator::{
accumulated_map::{AccumulatedMap, AtomicInputAccumulatedValues, InputAccumulatedValues},
Accumulator,
};
use crate::cycle::{CycleHeads, IterationCount};
use crate::durability::Durability;
Expand All @@ -11,7 +13,7 @@ use crate::runtime::Stamp;
use crate::sync::atomic::AtomicBool;
use crate::tracked_struct::{Disambiguator, DisambiguatorMap, IdentityHash, IdentityMap};
use crate::zalsa_local::{QueryEdge, QueryOrigin, QueryRevisions, QueryRevisionsExtra};
use crate::{Accumulator, IngredientIndex, Revision};
use crate::Revision;

#[derive(Debug)]
pub(crate) struct ActiveQuery {
Expand Down Expand Up @@ -51,10 +53,12 @@ pub(crate) struct ActiveQuery {

/// Stores the values accumulated to the given ingredient.
/// The type of accumulated value is erased but known to the ingredient.
#[cfg(feature = "accumulator")]
accumulated: AccumulatedMap,

/// [`InputAccumulatedValues::Empty`] if any input read during the query's execution
/// has any accumulated values.
#[cfg(feature = "accumulator")]
accumulated_inputs: InputAccumulatedValues,

/// Provisional cycle results that this query depends on.
Expand Down Expand Up @@ -84,18 +88,21 @@ impl ActiveQuery {
input: DatabaseKeyIndex,
durability: Durability,
changed_at: Revision,
has_accumulated: bool,
accumulated_inputs: &AtomicInputAccumulatedValues,
cycle_heads: &CycleHeads,
#[cfg(feature = "accumulator")] has_accumulated: bool,
#[cfg(feature = "accumulator")] accumulated_inputs: &AtomicInputAccumulatedValues,
) {
self.durability = self.durability.min(durability);
self.changed_at = self.changed_at.max(changed_at);
self.input_outputs.insert(QueryEdge::input(input));
self.accumulated_inputs = self.accumulated_inputs.or_else(|| match has_accumulated {
true => InputAccumulatedValues::Any,
false => accumulated_inputs.load(),
});
self.cycle_heads.extend(cycle_heads);
#[cfg(feature = "accumulator")]
{
self.accumulated_inputs = self.accumulated_inputs.or_else(|| match has_accumulated {
true => InputAccumulatedValues::Any,
false => accumulated_inputs.load(),
});
}
}

pub(super) fn add_read_simple(
Expand All @@ -121,7 +128,8 @@ impl ActiveQuery {
self.changed_at = self.changed_at.max(revision);
}

pub(super) fn accumulate(&mut self, index: IngredientIndex, value: impl Accumulator) {
#[cfg(feature = "accumulator")]
pub(super) fn accumulate(&mut self, index: crate::IngredientIndex, value: impl Accumulator) {
self.accumulated.accumulate(index, value);
}

Expand Down Expand Up @@ -169,10 +177,12 @@ impl ActiveQuery {
untracked_read: false,
disambiguator_map: Default::default(),
tracked_struct_ids: Default::default(),
accumulated: Default::default(),
accumulated_inputs: Default::default(),
cycle_heads: Default::default(),
iteration_count,
#[cfg(feature = "accumulator")]
accumulated: Default::default(),
#[cfg(feature = "accumulator")]
accumulated_inputs: Default::default(),
}
}

Expand All @@ -185,10 +195,12 @@ impl ActiveQuery {
untracked_read,
ref mut disambiguator_map,
ref mut tracked_struct_ids,
ref mut accumulated,
accumulated_inputs,
ref mut cycle_heads,
iteration_count,
#[cfg(feature = "accumulator")]
ref mut accumulated,
#[cfg(feature = "accumulator")]
accumulated_inputs,
} = self;

let origin = if untracked_read {
Expand All @@ -198,19 +210,22 @@ impl ActiveQuery {
};
disambiguator_map.clear();

#[cfg(feature = "accumulator")]
let accumulated_inputs = AtomicInputAccumulatedValues::new(accumulated_inputs);
let verified_final = cycle_heads.is_empty();
let extra = QueryRevisionsExtra::new(
#[cfg(feature = "accumulator")]
mem::take(accumulated),
mem::take(tracked_struct_ids),
mem::take(cycle_heads),
iteration_count,
);
let accumulated_inputs = AtomicInputAccumulatedValues::new(accumulated_inputs);

QueryRevisions {
changed_at,
durability,
origin,
#[cfg(feature = "accumulator")]
accumulated_inputs,
verified_final: AtomicBool::new(verified_final),
extra,
Expand All @@ -226,17 +241,20 @@ impl ActiveQuery {
untracked_read: _,
disambiguator_map,
tracked_struct_ids,
accumulated,
accumulated_inputs: _,
cycle_heads,
iteration_count,
#[cfg(feature = "accumulator")]
accumulated,
#[cfg(feature = "accumulator")]
accumulated_inputs: _,
} = self;
input_outputs.clear();
disambiguator_map.clear();
tracked_struct_ids.clear();
accumulated.clear();
*cycle_heads = Default::default();
*iteration_count = IterationCount::initial();
#[cfg(feature = "accumulator")]
accumulated.clear();
}

fn reset_for(
Expand All @@ -252,16 +270,17 @@ impl ActiveQuery {
untracked_read,
disambiguator_map,
tracked_struct_ids,
accumulated,
accumulated_inputs,
cycle_heads,
iteration_count,
#[cfg(feature = "accumulator")]
accumulated,
#[cfg(feature = "accumulator")]
accumulated_inputs,
} = self;
*database_key_index = new_database_key_index;
*durability = Durability::MAX;
*changed_at = Revision::start();
*untracked_read = false;
*accumulated_inputs = Default::default();
*iteration_count = new_iteration_count;
debug_assert!(
input_outputs.is_empty(),
Expand All @@ -279,10 +298,14 @@ impl ActiveQuery {
cycle_heads.is_empty(),
"`ActiveQuery::clear` or `ActiveQuery::into_revisions` should've been called"
);
debug_assert!(
accumulated.is_empty(),
"`ActiveQuery::clear` or `ActiveQuery::into_revisions` should've been called"
);
#[cfg(feature = "accumulator")]
{
*accumulated_inputs = Default::default();
debug_assert!(
accumulated.is_empty(),
"`ActiveQuery::clear` or `ActiveQuery::into_revisions` should've been called"
);
}
}
}

Expand Down
8 changes: 6 additions & 2 deletions src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use std::sync::atomic::Ordering;
use std::sync::OnceLock;
pub(crate) use sync::SyncGuard;

use crate::accumulator::accumulated_map::{AccumulatedMap, InputAccumulatedValues};
use crate::cycle::{
empty_cycle_heads, CycleHeads, CycleRecoveryAction, CycleRecoveryStrategy, ProvisionalStatus,
};
Expand All @@ -25,6 +24,7 @@ use crate::zalsa::{IngredientIndex, MemoIngredientIndex, Zalsa};
use crate::zalsa_local::QueryOriginRef;
use crate::{Id, Revision};

#[cfg(feature = "accumulator")]
mod accumulated;
mod backdate;
mod delete;
Expand Down Expand Up @@ -371,11 +371,15 @@ where
C::CYCLE_STRATEGY
}

#[cfg(feature = "accumulator")]
unsafe fn accumulated<'db>(
&'db self,
db: RawDatabase<'db>,
key_index: Id,
) -> (Option<&'db AccumulatedMap>, InputAccumulatedValues) {
) -> (
Option<&'db crate::accumulator::accumulated_map::AccumulatedMap>,
crate::accumulator::accumulated_map::InputAccumulatedValues,
) {
// SAFETY: The `db` belongs to the ingredient as per caller invariant
let db = unsafe { self.view_caster().downcast_unchecked(db) };
self.accumulated_map(db, key_index)
Expand Down
6 changes: 4 additions & 2 deletions src/function/fetch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,11 @@ where
database_key_index,
memo.revisions.durability,
memo.revisions.changed_at,
memo.cycle_heads(),
#[cfg(feature = "accumulator")]
memo.revisions.accumulated().is_some(),
#[cfg(feature = "accumulator")]
&memo.revisions.accumulated_inputs,
memo.cycle_heads(),
);

memo_value
Expand Down Expand Up @@ -221,7 +223,7 @@ where
if let Some(old_memo) = opt_old_memo {
if old_memo.value.is_some() {
let mut cycle_heads = CycleHeads::default();
if let VerifyResult::Unchanged(_) =
if let VerifyResult::Unchanged { .. } =
self.deep_verify_memo(db, zalsa, old_memo, database_key_index, &mut cycle_heads)
{
if cycle_heads.is_empty() {
Expand Down
Loading