Skip to content

Reduce database writes for the download endpoint #3413

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ conduit-static = "0.9.0-alpha.3"

cookie = { version = "0.15", features = ["secure"] }
ctrlc = { version = "3.0", features = ["termination"] }
dashmap = { version = "4.0.2", features = ["raw-api"] }
derive_deref = "1.1.1"
dialoguer = "0.7.1"
diesel = { version = "1.4.0", features = ["postgres", "serde_json", "chrono", "r2d2"] }
Expand Down
5 changes: 5 additions & 0 deletions src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use crate::{db, Config, Env};
use std::{sync::Arc, time::Duration};

use crate::downloads_counter::DownloadsCounter;
use crate::github::GitHubClient;
use diesel::r2d2;
use oauth2::basic::BasicClient;
Expand Down Expand Up @@ -32,6 +33,9 @@ pub struct App {
/// The server configuration
pub config: Config,

/// Count downloads and periodically persist them in the database
pub downloads_counter: DownloadsCounter,

/// A configured client for outgoing HTTP requests
///
/// In production this shares a single connection pool across requests. In tests
Expand Down Expand Up @@ -131,6 +135,7 @@ impl App {
github_oauth,
session_key: config.session_key.clone(),
config,
downloads_counter: DownloadsCounter::new(),
http_client,
}
}
Expand Down
30 changes: 26 additions & 4 deletions src/bin/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,12 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {

let config = cargo_registry::Config::default();
let client = Client::new();
let app = Arc::new(App::new(config.clone(), Some(client)));

let app = App::new(config.clone(), Some(client));
let app = cargo_registry::build_handler(Arc::new(app));
// Start the background thread periodically persisting download counts to the database.
downloads_counter_thread(app.clone());

let handler = cargo_registry::build_handler(app.clone());

// On every server restart, ensure the categories available in the database match
// the information in *src/categories.toml*.
Expand Down Expand Up @@ -100,7 +103,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.build()
.unwrap();

let handler = Arc::new(conduit_hyper::BlockingHandler::new(app));
let handler = Arc::new(conduit_hyper::BlockingHandler::new(handler));
let make_service =
hyper::service::make_service_fn(move |socket: &hyper::server::conn::AddrStream| {
let addr = socket.remote_addr();
Expand Down Expand Up @@ -131,7 +134,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("Booting with a civet based server");
let mut cfg = civet::Config::new();
cfg.port(port).threads(threads).keep_alive(true);
Civet(CivetServer::start(cfg, app).unwrap())
Civet(CivetServer::start(cfg, handler).unwrap())
};

println!("listening on port {}", port);
Expand Down Expand Up @@ -164,6 +167,11 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
}
}

println!("Persisting remaining downloads counters");
if let Err(err) = app.downloads_counter.persist_all_shards(&app) {
println!("downloads_counter error: {}", err);
}

println!("Server has gracefully shutdown!");
Ok(())
}
Expand All @@ -184,3 +192,17 @@ where
})
.unwrap();
}

/// Spawn the background thread that periodically persists download counts.
///
/// Each tick persists a single shard of the `DownloadsCounter`, so the per-tick
/// interval is the configured full-cycle interval divided by the shard count.
fn downloads_counter_thread(app: Arc<App>) {
    // Clamp to at least 1ms: if `downloads_persist_interval_ms` is smaller than
    // the shard count, the integer division truncates to 0 and the thread would
    // busy-loop calling the database with no sleep in between.
    let interval = Duration::from_millis(
        (app.config.downloads_persist_interval_ms / app.downloads_counter.shards_count()).max(1)
            as u64,
    );

    std::thread::spawn(move || loop {
        std::thread::sleep(interval);

        if let Err(err) = app.downloads_counter.persist_next_shard(&app) {
            println!("downloads_counter error: {}", err);
        }
    });
}
9 changes: 9 additions & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub struct Config {
pub blocked_traffic: Vec<(String, Vec<String>)>,
pub domain_name: String,
pub allowed_origins: Vec<String>,
pub downloads_persist_interval_ms: usize,
}

impl Default for Config {
Expand All @@ -45,6 +46,7 @@ impl Default for Config {
/// - `READ_ONLY_REPLICA_URL`: The URL of an optional postgres read-only replica database.
/// - `BLOCKED_TRAFFIC`: A list of headers and environment variables to use for blocking
///   traffic. See the `block_traffic` module for more documentation.
/// - `DOWNLOADS_PERSIST_INTERVAL_MS`: how frequently to persist download counts (in ms).
fn default() -> Config {
let api_protocol = String::from("https");
let mirror = if dotenv::var("MIRROR").is_ok() {
Expand Down Expand Up @@ -144,6 +146,13 @@ impl Default for Config {
blocked_traffic: blocked_traffic(),
domain_name: domain_name(),
allowed_origins,
downloads_persist_interval_ms: dotenv::var("DOWNLOADS_PERSIST_INTERVAL_MS")
.map(|interval| {
interval
.parse()
.expect("invalid DOWNLOADS_PERSIST_INTERVAL_MS")
})
.unwrap_or(60_000), // 1 minute
}
}
}
Expand Down
63 changes: 20 additions & 43 deletions src/controllers/version/downloads.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,33 @@ pub fn download(req: &mut dyn RequestExt) -> EndpointResult {
let crate_name = &req.params()["crate_id"];
let version = &req.params()["version"];

let (crate_name, was_counted) = increment_download_counts(req, recorder, crate_name, version)?;
let (version_id, crate_name): (_, String) = {
use self::versions::dsl::*;

let conn = recorder.record("get_conn", || req.db_conn())?;

// Returns the crate name as stored in the database, or an error if we could
// not load the version ID from the database.
recorder.record("get_version", || {
versions
.inner_join(crates::table)
.select((id, crates::name))
.filter(Crate::with_name(crate_name))
.filter(num.eq(version))
.first(&*conn)
})?
};

// The increment does not happen instantly, but it's deferred to be executed in a batch
// along with other downloads. See crate::downloads_counter for the implementation.
req.app().downloads_counter.increment(version_id);

let redirect_url = req
.app()
.config
.uploader
.crate_location(&crate_name, version);

// Adding log metadata requires &mut access, so we have to defer this step until
// after the (immutable) query parameters are no longer used.
if !was_counted {
req.log_metadata("uncounted_dl", "true");
}

if req.wants_json() {
#[derive(Serialize)]
struct R {
Expand All @@ -45,42 +58,6 @@ pub fn download(req: &mut dyn RequestExt) -> EndpointResult {
}
}

/// Increment the download counts for a given crate version.
///
/// Returns the crate name as stored in the database, or an error if we could
/// not load the version ID from the database.
///
/// This ignores any errors that occur updating the download count. Failure is
/// expected if the application is in read only mode, or for API-only mirrors.
/// Even if failure occurs for unexpected reasons, we would rather have `cargo
/// build` succeed and not count the download than break people's builds.
fn increment_download_counts(
req: &dyn RequestExt,
recorder: TimingRecorder,
crate_name: &str,
version: &str,
) -> AppResult<(String, bool)> {
use self::versions::dsl::*;

let conn = recorder.record("get_conn", || req.db_conn())?;

let (version_id, crate_name) = recorder.record("get_version", || {
versions
.inner_join(crates::table)
.select((id, crates::name))
.filter(Crate::with_name(crate_name))
.filter(num.eq(version))
.first(&*conn)
})?;

// Wrap in a transaction so we don't poison the outer transaction if this
// fails
let res = recorder.record("update_count", || {
conn.transaction(|| VersionDownload::create_or_increment(version_id, &conn))
});
Ok((crate_name, res.is_ok()))
}

/// Handles the `GET /crates/:crate_id/:version/downloads` route.
pub fn downloads(req: &mut dyn RequestExt) -> EndpointResult {
let (crate_name, semver) = extract_crate_name_and_semver(req)?;
Expand Down
161 changes: 161 additions & 0 deletions src/downloads_counter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
use crate::App;
use anyhow::Error;
use dashmap::{DashMap, SharedValue};
use diesel::{pg::upsert::excluded, prelude::*};
use std::collections::HashMap;
use std::sync::atomic::{AtomicI64, AtomicUsize, Ordering};

/// crates.io receives a lot of download requests, and we can't execute a write query to the
/// database during each connection for performance reasons. To reduce the write load, this struct
/// collects the pending updates from the current process and writes in batch.
///
/// To avoid locking the whole data structure behind a RwLock, which could potentially delay
/// requests, this uses the dashmap crate. A DashMap has the same public API as a HashMap, but
/// stores the items into `num_cpus()*4` individually locked shards. This approach reduces the
/// likelihood of a request encountering a locked shard.
///
/// Persisting the download counts in the database also takes advantage of the inner sharding of
/// DashMaps: to avoid locking all the download requests at the same time each iteration only
/// persists a single shard at a time.
///
/// The disadvantage of this approach is that download counts are stored in memory until they're
/// persisted, so it's possible to lose some of them if the process exits ungracefully. While
/// that's far from ideal, the advantage of batching database updates far outweighs potentially
/// losing some download counts.
#[derive(Debug)]
pub struct DownloadsCounter {
    /// Inner storage for the download counts, keyed by version ID.
    inner: DashMap<i32, AtomicUsize>,
    /// Index of the next shard that should be persisted by `persist_next_shard`.
    shard_idx: AtomicUsize,
    /// Number of downloads that are not yet persisted on the database. This is just used as a
    /// metric included in log lines, and it's not guaranteed to be accurate.
    pending_count: AtomicI64,
}

impl DownloadsCounter {
/// Create an empty counter with no pending downloads recorded.
pub(crate) fn new() -> Self {
    let inner = DashMap::new();
    let shard_idx = AtomicUsize::new(0);
    let pending_count = AtomicI64::new(0);

    Self {
        inner,
        shard_idx,
        pending_count,
    }
}

/// Record one download for `version_id` in memory.
///
/// The database write is deferred: counts accumulate here until a later call to
/// `persist_next_shard` or `persist_all_shards` flushes them in batch.
pub(crate) fn increment(&self, version_id: i32) {
    // Global unpersisted-downloads metric; best-effort only (see field docs).
    self.pending_count.fetch_add(1, Ordering::SeqCst);

    if let Some(counter) = self.inner.get(&version_id) {
        // The version is already recorded in the DashMap, so we don't need to lock the whole
        // shard in write mode. The shard is instead locked in read mode, which allows an
        // unbounded number of readers as long as there are no write locks.
        counter.value().fetch_add(1, Ordering::SeqCst);
    } else {
        // The version is not in the DashMap, so we need to lock the whole shard in write mode
        // and insert the version into it. This has worse performance than the above case.
        self.inner
            .entry(version_id)
            .and_modify(|counter| {
                // Handle the version being inserted by another thread while we were waiting
                // for the write lock on the shard.
                counter.fetch_add(1, Ordering::SeqCst);
            })
            .or_insert_with(|| AtomicUsize::new(1));
    }
}

/// Persist every shard of the counter to the database and log aggregate stats.
///
/// Unlike `persist_next_shard`, this drains the whole map; it is intended for
/// flushing the remaining counts (e.g. on shutdown).
pub fn persist_all_shards(&self, app: &App) -> Result<(), Error> {
    let conn = app.primary_database.get()?;

    let mut total_downloads = 0;
    let mut total_versions = 0;
    let mut remaining = 0;
    for locked_shard in self.inner.shards() {
        // Swap the shard's contents for an empty map, releasing the write lock
        // before any database work happens.
        let contents = std::mem::take(&mut *locked_shard.write());
        let stats = self.persist_shard(&conn, contents)?;

        total_downloads += stats.counted_downloads;
        total_versions += stats.counted_versions;
        // Only the most recent snapshot of the pending metric is meaningful,
        // so it's overwritten (not summed) on each iteration.
        remaining = stats.pending_downloads;
    }

    println!(
        "download_counter all_shards counted_versions={} counted_downloads={} pending_downloads={}",
        total_versions, total_downloads, remaining,
    );

    Ok(())
}

/// Persist a single shard — the next one in the ring — to the database.
///
/// Called periodically by the background thread so that only one shard's worth
/// of incoming requests can be blocked at any given moment.
pub fn persist_next_shard(&self, app: &App) -> Result<(), Error> {
    let conn = app.primary_database.get()?;

    // Pick the next shard in the ring. fetch_add wraps around on overflow, so
    // the counter can grow forever; the modulo keeps the index in range.
    let shards = self.inner.shards();
    let idx = self.shard_idx.fetch_add(1, Ordering::SeqCst) % shards.len();

    // Clear the chosen shard, taking ownership of its previous contents.
    let contents = std::mem::take(&mut *shards[idx].write());

    let stats = self.persist_shard(&conn, contents)?;
    println!(
        "download_counter shard={} counted_versions={} counted_downloads={} pending_downloads={}",
        idx, stats.counted_versions, stats.counted_downloads, stats.pending_downloads,
    );

    Ok(())
}

/// Write the download counts from a detached shard to the database.
///
/// Performs a single batched upsert: each version's in-memory count is added to
/// the existing `version_downloads` row for the current date, creating the row
/// if needed.
///
/// Returns statistics about how many versions and downloads were persisted,
/// plus an estimate of the downloads still pending in memory.
///
/// # Errors
///
/// Fails if getting or executing the database query fails; in that case the
/// counts in `shard` are lost, as they were already removed from the map.
fn persist_shard(
    &self,
    conn: &PgConnection,
    shard: HashMap<i32, SharedValue<AtomicUsize>>,
) -> Result<PersistStats, Error> {
    use crate::schema::version_downloads::dsl::*;

    let mut counted_downloads = 0;
    let mut counted_versions = 0;
    // The shard size is known up front, so reserve once instead of growing.
    let mut to_insert = Vec::with_capacity(shard.len());
    for (key, atomic) in shard.iter() {
        let count = atomic.get().load(Ordering::SeqCst);
        counted_downloads += count;
        counted_versions += 1;

        // NOTE(review): `count as i32` would wrap if a single version collected
        // more than i32::MAX downloads between flushes — practically impossible
        // at the configured intervals, but worth knowing.
        to_insert.push((version_id.eq(*key), downloads.eq(count as i32)));
    }

    if !to_insert.is_empty() {
        diesel::insert_into(version_downloads)
            .values(&to_insert)
            .on_conflict((version_id, date))
            .do_update()
            .set(downloads.eq(downloads + excluded(downloads)))
            .execute(conn)?;
    }

    // Decrement the pending metric by what was just persisted; the pre-update
    // value lets us report how much is still outstanding.
    let old_pending = self
        .pending_count
        .fetch_sub(counted_downloads as i64, Ordering::SeqCst);

    Ok(PersistStats {
        counted_downloads,
        counted_versions,
        pending_downloads: old_pending - counted_downloads as i64,
    })
}

/// Number of shards in the inner DashMap, used to pace the persist thread.
pub fn shards_count(&self) -> usize {
    let shards = self.inner.shards();
    shards.len()
}
}

/// Statistics returned by `persist_shard` describing a single flush.
struct PersistStats {
    /// Total number of downloads persisted in this flush.
    counted_downloads: usize,
    /// Number of distinct versions persisted in this flush.
    counted_versions: usize,
    /// Estimate of the downloads still pending in memory after this flush
    /// (metric only — concurrent increments mean it's not exact).
    pending_downloads: i64,
}
Loading