diff --git a/src/bin/server.rs b/src/bin/server.rs index c9fcb86d485..17cd146339b 100644 --- a/src/bin/server.rs +++ b/src/bin/server.rs @@ -172,8 +172,9 @@ fn main() -> Result<(), Box> { } println!("Persisting remaining downloads counters"); - if let Err(err) = app.downloads_counter.persist_all_shards(&app) { - println!("downloads_counter error: {}", err); + match app.downloads_counter.persist_all_shards(&app) { + Ok(stats) => stats.log(), + Err(err) => println!("downloads_counter error: {}", err), } println!("Server has gracefully shutdown!"); @@ -205,8 +206,9 @@ fn downloads_counter_thread(app: Arc) { std::thread::spawn(move || loop { std::thread::sleep(interval); - if let Err(err) = app.downloads_counter.persist_next_shard(&app) { - println!("downloads_counter error: {}", err); + match app.downloads_counter.persist_next_shard(&app) { + Ok(stats) => stats.log(), + Err(err) => println!("downloads_counter error: {}", err), } }); } diff --git a/src/db.rs b/src/db.rs index 00e04de8ee0..e6b298eaed8 100644 --- a/src/db.rs +++ b/src/db.rs @@ -130,3 +130,10 @@ impl CustomizeConnection for ConnectionConfig { Ok(()) } } + +#[cfg(test)] +pub(crate) fn test_conn() -> PgConnection { + let conn = PgConnection::establish(&crate::env("TEST_DATABASE_URL")).unwrap(); + conn.begin_test_transaction().unwrap(); + conn +} diff --git a/src/downloads_counter.rs b/src/downloads_counter.rs index a617da59bda..3b9b4808681 100644 --- a/src/downloads_counter.rs +++ b/src/downloads_counter.rs @@ -2,7 +2,7 @@ use crate::App; use anyhow::Error; use dashmap::{DashMap, SharedValue}; use diesel::{pg::upsert::excluded, prelude::*}; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::sync::atomic::{AtomicI64, AtomicUsize, Ordering}; /// crates.io receives a lot of download requests, and we can't execute a write query to the @@ -64,34 +64,27 @@ impl DownloadsCounter { } } - pub fn persist_all_shards(&self, app: &App) -> Result<(), Error> { + pub fn persist_all_shards(&self, app: &App) -> Result { let conn = app.primary_database.get()?; + self.persist_all_shards_with_conn(&conn) + } - let mut counted_downloads = 0; - let mut counted_versions = 0; - let mut pending_downloads = 0; + pub fn persist_next_shard(&self, app: &App) -> Result { + let conn = app.primary_database.get()?; + self.persist_next_shard_with_conn(&conn) + } + + fn persist_all_shards_with_conn(&self, conn: &PgConnection) -> Result { + let mut stats = PersistStats::default(); for shard in self.inner.shards() { let shard = std::mem::take(&mut *shard.write()); - let stats = self.persist_shard(&conn, shard)?; - - counted_downloads += stats.counted_downloads; - counted_versions += stats.counted_versions; - pending_downloads = stats.pending_downloads; + stats = stats.merge(self.persist_shard(&conn, shard)?); } - println!( - "downloads_counter all_shards counted_versions={} counted_downloads={} pending_downloads={}", - counted_versions, - counted_downloads, - pending_downloads, - ); - - Ok(()) + Ok(stats) } - pub fn persist_next_shard(&self, app: &App) -> Result<(), Error> { - let conn = app.primary_database.get()?; - + fn persist_next_shard_with_conn(&self, conn: &PgConnection) -> Result { // Replace the next shard in the ring with an empty HashMap (clearing it), and return the // previous contents for processing. The fetch_add method wraps around on overflow, so it's // fine to keep incrementing it without resetting. @@ -99,16 +92,9 @@ impl DownloadsCounter { let idx = self.shard_idx.fetch_add(1, Ordering::SeqCst) % shards.len(); let shard = std::mem::take(&mut *shards[idx].write()); - let stats = self.persist_shard(&conn, shard)?; - println!( - "downloads_counter shard={} counted_versions={} counted_downloads={} pending_downloads={}", - idx, - stats.counted_versions, - stats.counted_downloads, - stats.pending_downloads, - ); - - Ok(()) + let mut stats = self.persist_shard(&conn, shard)?; + stats.shard = Some(idx); + Ok(stats) } fn persist_shard( @@ -116,18 +102,16 @@ impl DownloadsCounter { conn: &PgConnection, shard: HashMap>, ) -> Result { - use crate::schema::version_downloads::dsl::*; + use crate::schema::{version_downloads, versions}; + let mut discarded_downloads = 0; let mut counted_downloads = 0; let mut counted_versions = 0; - let mut to_insert = Vec::new(); - for (key, atomic) in shard.iter() { - let count = atomic.get().load(Ordering::SeqCst); - counted_downloads += count; - counted_versions += 1; - to_insert.push((*key, count)); - } + let mut to_insert = shard + .iter() + .map(|(id, atomic)| (*id, atomic.get().load(Ordering::SeqCst))) + .collect::>(); if !to_insert.is_empty() { // The rows we're about to insert need to be sorted to avoid deadlocks when multiple @@ -146,24 +130,66 @@ impl DownloadsCounter { // to_insert.sort_by_key(|(key, _)| *key); - let to_insert = to_insert + // Our database schema enforces that every row in the `version_downloads` table points + // to a valid version in the `versions` table with a foreign key. This doesn't cause + // problems most of the times, as the rest of the application checks whether the + // version exists before calling the `increment` method. + // + // On rare occasions crates are deleted from crates.io though, and that would break the + // invariant if the crate is deleted after the `increment` method is called but before + // the downloads are persisted in the database. + // + // That happening would cause the whole `INSERT` to fail, also losing the downloads in + // the shard we were about to persist. To avoid that from happening this snippet does a + // `SELECT` query on the version table before persisting to check whether every version + // still exists in the database. Missing versions are removed from the following query. + let version_ids = to_insert.iter().map(|(id, _)| *id).collect::>(); + let existing_version_ids: HashSet = versions::table + .select(versions::id) + // `FOR SHARE` prevents updates or deletions on the selected rows in the `versions` + // table until this transaction commits. That prevents a version from being deleted + // between this query and the next one. + // + // `FOR SHARE` is used instead of `FOR UPDATE` to allow rows to be locked by + // multiple `SELECT` transactions, to allow for concurrent downloads persisting. + .for_share() + .filter(versions::id.eq_any(version_ids)) + .load(conn)? .into_iter() - .map(|(key, count)| (version_id.eq(key), downloads.eq(count as i32))) - .collect::>(); + .collect(); + + let mut values = Vec::new(); + for (id, count) in &to_insert { + if !existing_version_ids.contains(id) { + discarded_downloads += *count; + continue; + } + counted_versions += 1; + counted_downloads += *count; + values.push(( + version_downloads::version_id.eq(*id), + version_downloads::downloads.eq(*count as i32), + )); + } - diesel::insert_into(version_downloads) - .values(&to_insert) - .on_conflict((version_id, date)) + diesel::insert_into(version_downloads::table) + .values(&values) + .on_conflict((version_downloads::version_id, version_downloads::date)) .do_update() - .set(downloads.eq(downloads + excluded(downloads))) + .set( + version_downloads::downloads + .eq(version_downloads::downloads + excluded(version_downloads::downloads)), + ) .execute(conn)?; } - let old_pending = self - .pending_count - .fetch_sub(counted_downloads as i64, Ordering::SeqCst); + let old_pending = self.pending_count.fetch_sub( + (counted_downloads + discarded_downloads) as i64, + Ordering::SeqCst, + ); Ok(PersistStats { + shard: None, counted_downloads, counted_versions, pending_downloads: old_pending - counted_downloads as i64, @@ -175,8 +201,252 @@ impl DownloadsCounter { } } -struct PersistStats { +#[derive(Debug, Default, Copy, Clone, Eq, PartialEq)] +pub struct PersistStats { + shard: Option, counted_downloads: usize, counted_versions: usize, pending_downloads: i64, } + +impl PersistStats { + fn merge(self, other: PersistStats) -> Self { + Self { + shard: if self.shard == other.shard { + other.shard + } else { + None + }, + counted_downloads: self.counted_downloads + other.counted_downloads, + counted_versions: self.counted_versions + other.counted_versions, + pending_downloads: other.pending_downloads, + } + } + + pub fn log(&self) { + println!( + "downloads_counter shard={} counted_versions={} counted_downloads={} pending_downloads={}", + self.shard.map(|s| s.to_string()).unwrap_or_else(|| "all".into()), + self.counted_versions, + self.counted_downloads, + self.pending_downloads, + ); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::models::{Crate, NewCrate, NewUser, NewVersion, User}; + use diesel::PgConnection; + use semver::Version; + + #[test] + fn test_increment_and_persist_all() { + let counter = DownloadsCounter::new(); + let conn = crate::db::test_conn(); + let mut state = State::new(&conn); + + let v1 = state.new_version(&conn); + let v2 = state.new_version(&conn); + let v3 = state.new_version(&conn); + + // Add 15 downloads between v1 and v2, and no downloads for v3. + for _ in 0..10 { + counter.increment(v1); + } + for _ in 0..5 { + counter.increment(v2); + } + assert_eq!(15, counter.pending_count.load(Ordering::SeqCst)); + + // Persist everything to the database + let stats = counter + .persist_all_shards_with_conn(&conn) + .expect("failed to persist all shards"); + + // Ensure the stats are accurate + assert_eq!( + stats, + PersistStats { + shard: None, + counted_downloads: 15, + counted_versions: 2, + pending_downloads: 0, + } + ); + + // Ensure the download counts in the database are what we expect. + state.assert_downloads_count(&conn, v1, 10); + state.assert_downloads_count(&conn, v2, 5); + state.assert_downloads_count(&conn, v3, 0); + } + + #[test] + fn test_increment_and_persist_shard() { + let counter = DownloadsCounter::new(); + let conn = crate::db::test_conn(); + let mut state = State::new(&conn); + + let v1 = state.new_version(&conn); + let v1_shard = counter.inner.determine_map(&v1); + + // For this test to work we need the two versions to be stored in different shards. + let mut v2 = state.new_version(&conn); + while counter.inner.determine_map(&v2) == v1_shard { + v2 = state.new_version(&conn); + } + let v2_shard = counter.inner.determine_map(&v2); + + // Add 15 downloads between v1 and v2. + for _ in 0..10 { + counter.increment(v1); + } + for _ in 0..5 { + counter.increment(v2); + } + assert_eq!(15, counter.pending_count.load(Ordering::SeqCst)); + + // Persist one shard at the time and ensure the stats returned for each shard are expected. + let mut pending = 15; + for shard in 0..counter.shards_count() { + let stats = counter + .persist_next_shard_with_conn(&conn) + .expect("failed to persist shard"); + + if shard == v1_shard { + pending -= 10; + assert_eq!( + stats, + PersistStats { + shard: Some(shard), + counted_downloads: 10, + counted_versions: 1, + pending_downloads: pending, + } + ); + state.assert_downloads_count(&conn, v1, 10); + } else if shard == v2_shard { + pending -= 5; + assert_eq!( + stats, + PersistStats { + shard: Some(shard), + counted_downloads: 5, + counted_versions: 1, + pending_downloads: pending, + } + ); + state.assert_downloads_count(&conn, v2, 5); + } else { + assert_eq!( + stats, + PersistStats { + shard: Some(shard), + counted_downloads: 0, + counted_versions: 0, + pending_downloads: pending, + } + ); + }; + } + assert_eq!(pending, 0); + + // Finally ensure that the download counts in the database are what we expect. + state.assert_downloads_count(&conn, v1, 10); + state.assert_downloads_count(&conn, v2, 5); + } + + #[test] + fn test_increment_missing_version() { + let counter = DownloadsCounter::new(); + let conn = crate::db::test_conn(); + let mut state = State::new(&conn); + + let v1 = state.new_version(&conn); + let v2 = v1 + 1; // Should not exist in the database! + + // No error should happen when calling the increment method on a missing version. + counter.increment(v1); + counter.increment(v2); + + // No error should happen when persisting. The missing versions should be ignored. + let stats = counter + .persist_all_shards_with_conn(&conn) + .expect("failed to persist download counts"); + + // The download should not be counted for version 2. + assert_eq!( + stats, + PersistStats { + shard: None, + counted_downloads: 1, + counted_versions: 1, + pending_downloads: 0, + } + ); + state.assert_downloads_count(&conn, v1, 1); + state.assert_downloads_count(&conn, v2, 0); + } + + struct State { + user: User, + krate: Crate, + next_version: u32, + } + + impl State { + fn new(conn: &PgConnection) -> Self { + let user = NewUser { + gh_id: 0, + gh_login: "ghost", + ..NewUser::default() + } + .create_or_update(None, conn) + .expect("failed to create user"); + + let krate = NewCrate { + name: "foo", + ..NewCrate::default() + } + .create_or_update(conn, user.id, None) + .expect("failed to create crate"); + + Self { + user, + krate, + next_version: 1, + } + } + + fn new_version(&mut self, conn: &PgConnection) -> i32 { + let version = NewVersion::new( + self.krate.id, + &Version::parse(&format!("{}.0.0", self.next_version)).unwrap(), + &HashMap::new(), + None, + None, + 0, + self.user.id, + ) + .expect("failed to create version") + .save(conn, &[], "ghost@example.com") + .expect("failed to save version"); + + self.next_version += 1; + version.id + } + + fn assert_downloads_count(&self, conn: &PgConnection, version: i32, expected: i64) { + use crate::schema::version_downloads::dsl::*; + use diesel::dsl::*; + + let actual: Option = version_downloads + .select(sum(downloads)) + .filter(version_id.eq(version)) + .first(conn) + .unwrap(); + assert_eq!(actual.unwrap_or(0), expected); + } + } +} diff --git a/src/tasks/update_downloads.rs b/src/tasks/update_downloads.rs index 1cd0d3a3704..dbc7cf6b5c1 100644 --- a/src/tasks/update_downloads.rs +++ b/src/tasks/update_downloads.rs @@ -83,18 +83,9 @@ fn collect(conn: &PgConnection, rows: &[VersionDownload]) -> QueryResult<()> { #[cfg(test)] mod test { use super::*; - use crate::{ - env, - models::{Crate, NewCrate, NewUser, NewVersion, User, Version}, - }; + use crate::models::{Crate, NewCrate, NewUser, NewVersion, User, Version}; use std::collections::HashMap; - fn conn() -> PgConnection { - let conn = PgConnection::establish(&env("TEST_DATABASE_URL")).unwrap(); - conn.begin_test_transaction().unwrap(); - conn - } - fn user(conn: &PgConnection) -> User { NewUser::new(2, "login", None, None, "access_token") .create_or_update(None, conn) @@ -126,7 +117,7 @@ mod test { fn increment() { use diesel::dsl::*; - let conn = conn(); + let conn = crate::db::test_conn(); let user = user(&conn); let (krate, version) = crate_and_version(&conn, user.id); insert_into(version_downloads::table) @@ -165,7 +156,7 @@ mod test { fn set_processed_true() { use diesel::dsl::*; - let conn = conn(); + let conn = crate::db::test_conn(); let user = user(&conn); let (_, version) = crate_and_version(&conn, user.id); insert_into(version_downloads::table) @@ -189,7 +180,7 @@ mod test { #[test] fn dont_process_recent_row() { use diesel::dsl::*; - let conn = conn(); + let conn = crate::db::test_conn(); let user = user(&conn); let (_, version) = crate_and_version(&conn, user.id); insert_into(version_downloads::table) @@ -215,7 +206,7 @@ mod test { use diesel::dsl::*; use diesel::update; - let conn = conn(); + let conn = crate::db::test_conn(); let user = user(&conn); let (krate, version) = crate_and_version(&conn, user.id); update(versions::table) @@ -269,7 +260,7 @@ mod test { use diesel::dsl::*; use diesel::update; - let conn = conn(); + let conn = crate::db::test_conn(); let user = user(&conn); let (_, version) = crate_and_version(&conn, user.id); update(versions::table) diff --git a/src/tests/krate/downloads.rs b/src/tests/krate/downloads.rs index 0a326baa6f4..b45e6c708de 100644 --- a/src/tests/krate/downloads.rs +++ b/src/tests/krate/downloads.rs @@ -46,7 +46,8 @@ fn download() { app.as_inner() .downloads_counter .persist_all_shards(app.as_inner()) - .expect("failed to persist downloads count"); + .expect("failed to persist downloads count") + .log(); }; download("foo_download/1.0.0");