Skip to content

Commit 0656eca

Browse files
authored
fix: Delete not re-created tracked structs after fixpoint iteration (#979)
* Fix tracked structs diffing in cycles * Proper fix * Clippy * Add regression test * Discard changes to src/function/maybe_changed_after.rs * Improve comemnt * Suppress clippy error in position where I don't control the types
1 parent 411f844 commit 0656eca

File tree

8 files changed

+169
-37
lines changed

8 files changed

+169
-37
lines changed

examples/calc/db.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ impl CalcDatabaseImpl {
4848
}
4949

5050
#[cfg(test)]
51+
#[allow(unused)]
5152
pub fn take_logs(&self) -> Vec<String> {
5253
let mut logs = self.logs.lock().unwrap();
5354
if let Some(logs) = &mut *logs {

src/active_query.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,18 @@ use crate::accumulator::{
55
accumulated_map::{AccumulatedMap, AtomicInputAccumulatedValues, InputAccumulatedValues},
66
Accumulator,
77
};
8-
use crate::cycle::{CycleHeads, IterationCount};
9-
use crate::durability::Durability;
108
use crate::hash::FxIndexSet;
119
use crate::key::DatabaseKeyIndex;
1210
use crate::runtime::Stamp;
1311
use crate::sync::atomic::AtomicBool;
1412
use crate::tracked_struct::{Disambiguator, DisambiguatorMap, IdentityHash, IdentityMap};
1513
use crate::zalsa_local::{QueryEdge, QueryOrigin, QueryRevisions, QueryRevisionsExtra};
1614
use crate::Revision;
15+
use crate::{
16+
cycle::{CycleHeads, IterationCount},
17+
Id,
18+
};
19+
use crate::{durability::Durability, tracked_struct::Identity};
1720

1821
#[derive(Debug)]
1922
pub(crate) struct ActiveQuery {
@@ -74,6 +77,7 @@ impl ActiveQuery {
7477
changed_at: Revision,
7578
edges: &[QueryEdge],
7679
untracked_read: bool,
80+
active_tracked_ids: &[(Identity, Id)],
7781
) {
7882
assert!(self.input_outputs.is_empty());
7983

@@ -83,7 +87,8 @@ impl ActiveQuery {
8387
self.untracked_read |= untracked_read;
8488

8589
// Mark all tracked structs from the previous iteration as active.
86-
self.tracked_struct_ids.mark_all_active();
90+
self.tracked_struct_ids
91+
.mark_all_active(active_tracked_ids.iter().copied());
8792
}
8893

8994
pub(super) fn add_read(
@@ -408,7 +413,7 @@ pub(crate) struct CompletedQuery {
408413

409414
/// The keys of any tracked structs that were created in a previous execution of the
410415
/// query but not the current one, and should be marked as stale.
411-
pub(crate) stale_tracked_structs: Vec<DatabaseKeyIndex>,
416+
pub(crate) stale_tracked_structs: Vec<(Identity, Id)>,
412417
}
413418

414419
struct CapturedQuery {

src/function/diff_outputs.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,9 @@ where
2727

2828
// Note that tracked structs are not stored as direct query outputs, but they are still outputs
2929
// that need to be reported as stale.
30-
for output in &completed_query.stale_tracked_structs {
31-
Self::report_stale_output(zalsa, key, *output);
30+
for (identity, id) in &completed_query.stale_tracked_structs {
31+
let output = DatabaseKeyIndex::new(identity.ingredient_index(), *id);
32+
Self::report_stale_output(zalsa, key, output);
3233
}
3334

3435
let mut stale_outputs = output_edges(edges).collect::<FxIndexSet<_>>();

src/function/execute.rs

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use crate::cycle::{CycleRecoveryStrategy, IterationCount};
33
use crate::function::memo::Memo;
44
use crate::function::{Configuration, IngredientImpl};
55
use crate::sync::atomic::{AtomicBool, Ordering};
6+
use crate::tracked_struct::Identity;
67
use crate::zalsa::{MemoIngredientIndex, Zalsa, ZalsaDatabase};
78
use crate::zalsa_local::ActiveQueryGuard;
89
use crate::{Event, EventKind, Id};
@@ -134,13 +135,25 @@ where
134135
let database_key_index = active_query.database_key_index;
135136
let mut iteration_count = IterationCount::initial();
136137
let mut fell_back = false;
138+
let zalsa_local = db.zalsa_local();
137139

138140
// Our provisional value from the previous iteration, when doing fixpoint iteration.
139141
// Initially it's set to None, because the initial provisional value is created lazily,
140142
// only when a cycle is actually encountered.
141143
let mut opt_last_provisional: Option<&Memo<'db, C>> = None;
144+
let mut last_stale_tracked_ids: Vec<(Identity, Id)> = Vec::new();
145+
142146
loop {
143147
let previous_memo = opt_last_provisional.or(opt_old_memo);
148+
149+
// Tracked struct ids that existed in the previous revision
150+
// but weren't recreated in the last iteration. It's important that we seed the next
151+
// query with these ids because the query might re-create them as part of the next iteration.
152+
// This is not only important to ensure that the re-created tracked structs have the same ids,
153+
// it's also important to ensure that these tracked structs get removed
154+
// if they aren't recreated when reaching the final iteration.
155+
active_query.seed_tracked_struct_ids(&last_stale_tracked_ids);
156+
144157
let (mut new_value, mut completed_query) =
145158
Self::execute_query(db, zalsa, active_query, previous_memo, id);
146159

@@ -239,10 +252,9 @@ where
239252
),
240253
memo_ingredient_index,
241254
));
255+
last_stale_tracked_ids = completed_query.stale_tracked_structs;
242256

243-
active_query = db
244-
.zalsa_local()
245-
.push_query(database_key_index, iteration_count);
257+
active_query = zalsa_local.push_query(database_key_index, iteration_count);
246258

247259
continue;
248260
}
@@ -280,9 +292,7 @@ where
280292
if let Some(old_memo) = opt_old_memo {
281293
// If we already executed this query once, then use the tracked-struct ids from the
282294
// previous execution as the starting point for the new one.
283-
if let Some(tracked_struct_ids) = old_memo.revisions.tracked_struct_ids() {
284-
active_query.seed_tracked_struct_ids(tracked_struct_ids);
285-
}
295+
active_query.seed_tracked_struct_ids(old_memo.revisions.tracked_struct_ids());
286296

287297
// Copy over all inputs and outputs from a previous iteration.
288298
// This is necessary to:

src/function/memo.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ where
326326
stale_output.remove_stale_output(zalsa, executor);
327327
}
328328

329-
for (identity, id) in self.revisions.tracked_struct_ids().into_iter().flatten() {
329+
for (identity, id) in self.revisions.tracked_struct_ids() {
330330
let key = DatabaseKeyIndex::new(identity.ingredient_index(), *id);
331331
key.remove_stale_output(zalsa, executor);
332332
}

src/tracked_struct.rs

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -255,19 +255,15 @@ pub(crate) struct IdentityMap {
255255
impl IdentityMap {
256256
/// Seeds the identity map with the IDs from a previous revision.
257257
pub(crate) fn seed(&mut self, source: &[(Identity, Id)]) {
258-
self.table.clear();
259-
self.table
260-
.reserve(source.len(), |entry| entry.identity.hash);
261-
262258
for &(key, id) in source {
263259
self.insert_entry(key, id, false);
264260
}
265261
}
266262

267263
// Mark all tracked structs in the map as created by the current query.
268-
pub(crate) fn mark_all_active(&mut self) {
269-
for entry in self.table.iter_mut() {
270-
entry.active = true;
264+
pub(crate) fn mark_all_active(&mut self, items: impl IntoIterator<Item = (Identity, Id)>) {
265+
for (key, id) in items {
266+
self.insert_entry(key, id, true);
271267
}
272268
}
273269

@@ -330,7 +326,8 @@ impl IdentityMap {
330326
/// The first entry contains the identity and IDs of any tracked structs that were
331327
/// created by the current execution of the query, while the second entry contains any
332328
/// tracked structs that were created in a previous execution but not the current one.
333-
pub(crate) fn drain(&mut self) -> (ThinVec<(Identity, Id)>, Vec<DatabaseKeyIndex>) {
329+
#[expect(clippy::type_complexity)]
330+
pub(crate) fn drain(&mut self) -> (ThinVec<(Identity, Id)>, Vec<(Identity, Id)>) {
334331
if self.table.is_empty() {
335332
return (ThinVec::new(), Vec::new());
336333
}
@@ -342,19 +339,14 @@ impl IdentityMap {
342339
if entry.active {
343340
active.push((entry.identity, entry.id));
344341
} else {
345-
stale.push(DatabaseKeyIndex::new(
346-
entry.identity.ingredient_index(),
347-
entry.id,
348-
));
342+
stale.push((entry.identity, entry.id));
349343
}
350344
}
351345

352346
// Removing a stale tracked struct ID shows up in the event logs, so make sure
353347
// the order is stable here.
354348
stale.sort_unstable_by(|a, b| {
355-
a.ingredient_index()
356-
.cmp(&b.ingredient_index())
357-
.then(a.key_index().cmp(&b.key_index()))
349+
(a.0.ingredient_index(), a.1).cmp(&(b.0.ingredient_index(), b.1))
358350
});
359351

360352
(active, stale)

src/zalsa_local.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -668,13 +668,13 @@ impl QueryRevisions {
668668
}
669669
}
670670

671-
/// Returns a reference to the `IdentityMap` for this query, or `None` if the map is empty.
672-
pub fn tracked_struct_ids(&self) -> Option<&[(Identity, Id)]> {
671+
/// Returns the ids of the tracked structs created when running this query.
672+
pub fn tracked_struct_ids(&self) -> &[(Identity, Id)] {
673673
self.extra
674674
.0
675675
.as_ref()
676676
.map(|extra| &*extra.tracked_struct_ids)
677-
.filter(|tracked_struct_ids| !tracked_struct_ids.is_empty())
677+
.unwrap_or_default()
678678
}
679679

680680
/// Returns a mutable reference to the `IdentityMap` for this query, or `None` if the map is empty.
@@ -1090,7 +1090,6 @@ impl ActiveQueryGuard<'_> {
10901090
#[cfg(debug_assertions)]
10911091
assert_eq!(stack.len(), self.push_len);
10921092
let frame = stack.last_mut().unwrap();
1093-
assert!(frame.tracked_struct_ids().is_empty());
10941093
frame.tracked_struct_ids_mut().seed(tracked_struct_ids);
10951094
})
10961095
}
@@ -1105,14 +1104,15 @@ impl ActiveQueryGuard<'_> {
11051104
previous.origin.as_ref(),
11061105
QueryOriginRef::DerivedUntracked(_)
11071106
);
1107+
let tracked_ids = previous.tracked_struct_ids();
11081108

11091109
// SAFETY: We do not access the query stack reentrantly.
11101110
unsafe {
11111111
self.local_state.with_query_stack_unchecked_mut(|stack| {
11121112
#[cfg(debug_assertions)]
11131113
assert_eq!(stack.len(), self.push_len);
11141114
let frame = stack.last_mut().unwrap();
1115-
frame.seed_iteration(durability, changed_at, edges, untracked_read);
1115+
frame.seed_iteration(durability, changed_at, edges, untracked_read, tracked_ids);
11161116
})
11171117
}
11181118
}

tests/cycle_tracked.rs

Lines changed: 127 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
#![cfg(feature = "inventory")]
22

3-
//! Tests for cycles where the cycle head is stored on a tracked struct
4-
//! and that tracked struct is freed in a later revision.
5-
63
mod common;
74

85
use crate::common::{EventLoggerDatabase, LogDatabase};
@@ -45,6 +42,7 @@ struct Node<'db> {
4542
#[salsa::input(debug)]
4643
struct GraphInput {
4744
simple: bool,
45+
fixpoint_variant: usize,
4846
}
4947

5048
#[salsa::tracked(returns(ref))]
@@ -125,11 +123,13 @@ fn cycle_recover(
125123
CycleRecoveryAction::Iterate
126124
}
127125

126+
/// Tests for cycles where the cycle head is stored on a tracked struct
127+
/// and that tracked struct is freed in a later revision.
128128
#[test]
129129
fn main() {
130130
let mut db = EventLoggerDatabase::default();
131131

132-
let input = GraphInput::new(&db, false);
132+
let input = GraphInput::new(&db, false, 0);
133133
let graph = create_graph(&db, input);
134134
let c = graph.find_node(&db, "c").unwrap();
135135

@@ -192,3 +192,126 @@ fn main() {
192192
"WillCheckCancellation",
193193
]"#]]);
194194
}
195+
196+
#[salsa::tracked]
197+
struct IterationNode<'db> {
198+
#[returns(ref)]
199+
name: String,
200+
iteration: usize,
201+
}
202+
203+
/// A cyclic query that creates more tracked structs in later fixpoint iterations.
204+
///
205+
/// The output depends on the input's fixpoint_variant:
206+
/// - variant=0: Returns `[base]` (1 struct, no cycle)
207+
/// - variant=1: Through fixpoint iteration, returns `[iter_0, iter_1, iter_2]` (3 structs)
208+
/// - variant=2: Through fixpoint iteration, returns `[iter_0, iter_1]` (2 structs)
209+
/// - variant>2: Through fixpoint iteration, returns `[iter_0, iter_1]` (2 structs, same as variant=2)
210+
///
211+
/// When variant > 0, the query creates a cycle by calling itself. The fixpoint iteration
212+
/// proceeds as follows:
213+
/// 1. Initial: returns empty vector
214+
/// 2. First iteration: returns `[iter_0]`
215+
/// 3. Second iteration: returns `[iter_0, iter_1]`
216+
/// 4. Third iteration (only for variant=1): returns `[iter_0, iter_1, iter_2]`
217+
/// 5. Further iterations: no change, fixpoint reached
218+
#[salsa::tracked(cycle_fn=cycle_recover_with_structs, cycle_initial=initial_with_structs)]
219+
fn create_tracked_in_cycle<'db>(
220+
db: &'db dyn Database,
221+
input: GraphInput,
222+
) -> Vec<IterationNode<'db>> {
223+
// Check if we should create more nodes based on the input.
224+
let variant = input.fixpoint_variant(db);
225+
226+
if variant == 0 {
227+
// Base case - no cycle, just return a single node.
228+
vec![IterationNode::new(db, "base".to_string(), 0)]
229+
} else {
230+
// Create a cycle by calling ourselves.
231+
let previous = create_tracked_in_cycle(db, input);
232+
233+
// In later iterations, create additional tracked structs.
234+
if previous.is_empty() {
235+
// First iteration - initial returns empty.
236+
vec![IterationNode::new(db, "iter_0".to_string(), 0)]
237+
} else {
238+
// Limit based on variant: variant=1 allows 3 nodes, variant=2 allows 2 nodes.
239+
let limit = if variant == 1 { 3 } else { 2 };
240+
241+
if previous.len() < limit {
242+
// Subsequent iterations - add more nodes.
243+
let mut nodes = previous;
244+
nodes.push(IterationNode::new(
245+
db,
246+
format!("iter_{}", nodes.len()),
247+
nodes.len(),
248+
));
249+
nodes
250+
} else {
251+
// Reached the limit.
252+
previous
253+
}
254+
}
255+
}
256+
}
257+
258+
fn initial_with_structs(_db: &dyn Database, _input: GraphInput) -> Vec<IterationNode<'_>> {
259+
vec![]
260+
}
261+
262+
#[allow(clippy::ptr_arg)]
263+
fn cycle_recover_with_structs<'db>(
264+
_db: &'db dyn Database,
265+
_value: &Vec<IterationNode<'db>>,
266+
_iteration: u32,
267+
_input: GraphInput,
268+
) -> CycleRecoveryAction<Vec<IterationNode<'db>>> {
269+
CycleRecoveryAction::Iterate
270+
}
271+
272+
#[test]
273+
fn test_cycle_with_fixpoint_structs() {
274+
let mut db = EventLoggerDatabase::default();
275+
276+
// Create an input that will trigger the cyclic behavior.
277+
let input = GraphInput::new(&db, false, 1);
278+
279+
// Initial query - this will create structs across multiple iterations.
280+
let nodes = create_tracked_in_cycle(&db, input);
281+
assert_eq!(nodes.len(), 3);
282+
// First iteration: previous is empty [], so we get [iter_0]
283+
// Second iteration: previous is [iter_0], so we get [iter_0, iter_1]
284+
// Third iteration: previous is [iter_0, iter_1], so we get [iter_0, iter_1, iter_2]
285+
assert_eq!(nodes[0].name(&db), "iter_0");
286+
assert_eq!(nodes[1].name(&db), "iter_1");
287+
assert_eq!(nodes[2].name(&db), "iter_2");
288+
289+
// Clear logs to focus on the change.
290+
db.clear_logs();
291+
292+
// Change the input to force re-execution with a different variant.
293+
// This will create 2 tracked structs instead of 3 (one fewer than before).
294+
input.set_fixpoint_variant(&mut db).to(2);
295+
296+
// Re-query - this should handle the tracked struct changes properly.
297+
let nodes = create_tracked_in_cycle(&db, input);
298+
assert_eq!(nodes.len(), 2);
299+
assert_eq!(nodes[0].name(&db), "iter_0");
300+
assert_eq!(nodes[1].name(&db), "iter_1");
301+
302+
// Check the logs to ensure proper execution and struct management.
303+
// We should see the third struct (iter_2) being discarded.
304+
db.assert_logs(expect![[r#"
305+
[
306+
"DidSetCancellationFlag",
307+
"WillCheckCancellation",
308+
"WillExecute { database_key: create_tracked_in_cycle(Id(0)) }",
309+
"WillCheckCancellation",
310+
"WillIterateCycle { database_key: create_tracked_in_cycle(Id(0)), iteration_count: IterationCount(1), fell_back: false }",
311+
"WillCheckCancellation",
312+
"WillIterateCycle { database_key: create_tracked_in_cycle(Id(0)), iteration_count: IterationCount(2), fell_back: false }",
313+
"WillCheckCancellation",
314+
"WillDiscardStaleOutput { execute_key: create_tracked_in_cycle(Id(0)), output_key: IterationNode(Id(402)) }",
315+
"DidDiscard { key: IterationNode(Id(402)) }",
316+
]"#]]);
317+
}

0 commit comments

Comments
 (0)