diff --git a/src/cli/rustup_mode.rs b/src/cli/rustup_mode.rs index 5f8d08aca4..d643aafba6 100644 --- a/src/cli/rustup_mode.rs +++ b/src/cli/rustup_mode.rs @@ -802,14 +802,20 @@ async fn check_updates(cfg: &Cfg<'_>, opts: CheckOpts) -> Result 1, + Some(_) | None => channels_len, + }; // Ensure that `.buffered()` is never called with 0 as this will cause a hang. // See: https://github.com/rust-lang/futures-rs/pull/1194#discussion_r209501774 - if num_channels > 0 { + if channels_len > 0 { let multi_progress_bars = MultiProgress::with_draw_target(cfg.process.progress_draw_target()); - let semaphore = Arc::new(Semaphore::new(num_channels)); + let semaphore = Arc::new(Semaphore::new(concurrent_downloads)); let channels = tokio_stream::iter(channels.into_iter()).map(|(name, distributable)| { let pb = multi_progress_bars.add(ProgressBar::new(1)); pb.set_style( @@ -878,7 +884,10 @@ async fn check_updates(cfg: &Cfg<'_>, opts: CheckOpts) -> Result>() .await } else { - channels.buffered(num_channels).collect::>().await + channels + .buffered(concurrent_downloads) + .collect::>() + .await }; let t = cfg.process.stdout().terminal(cfg.process); diff --git a/src/dist/manifestation.rs b/src/dist/manifestation.rs index f1bfa49a81..fd890b1a18 100644 --- a/src/dist/manifestation.rs +++ b/src/dist/manifestation.rs @@ -157,10 +157,12 @@ impl Manifestation { let mut things_downloaded: Vec = Vec::new(); let components = update.components_urls_and_hashes(new_manifest)?; let components_len = components.len(); - let num_channels = download_cfg + + const DEFAULT_CONCURRENT_DOWNLOADS: usize = 2; + let concurrent_downloads = download_cfg .process .concurrent_downloads() - .unwrap_or(components_len); + .unwrap_or(DEFAULT_CONCURRENT_DOWNLOADS); const DEFAULT_MAX_RETRIES: usize = 3; let max_retries: usize = download_cfg @@ -180,7 +182,7 @@ impl Manifestation { )); } - let semaphore = Arc::new(Semaphore::new(num_channels)); + let semaphore = Arc::new(Semaphore::new(concurrent_downloads)); let component_stream = tokio_stream::iter(components.into_iter()).map(|(component, format, url, hash)| { let sem = semaphore.clone(); @@ -200,7 +202,7 @@ impl Manifestation { .await } }); - if num_channels > 0 { + if components_len > 0 { let results = component_stream .buffered(components_len) .collect::>() diff --git a/src/dist/manifestation/tests.rs b/src/dist/manifestation/tests.rs index fcc743ed68..233ccb2df9 100644 --- a/src/dist/manifestation/tests.rs +++ b/src/dist/manifestation/tests.rs @@ -417,16 +417,33 @@ struct TestContext { impl TestContext { fn new(edit: Option<&dyn Fn(&str, &mut MockChannel)>, comps: Compressions) -> Self { + Self::with_env(edit, comps, HashMap::new()) + } + + fn with_env( + edit: Option<&dyn Fn(&str, &mut MockChannel)>, + comps: Compressions, + env: HashMap, + ) -> Self { let dist_tempdir = tempfile::Builder::new().prefix("rustup").tempdir().unwrap(); let mock_dist_server = create_mock_dist_server(dist_tempdir.path(), edit); let url = Url::parse(&format!("file://{}", dist_tempdir.path().to_string_lossy())).unwrap(); - let mut cx = Self::from_dist_server(mock_dist_server, url, comps); + let mut cx = Self::from_dist_server_with_env(mock_dist_server, url, comps, env); cx._tempdirs.push(dist_tempdir); cx } fn from_dist_server(server: MockDistServer, url: Url, comps: Compressions) -> Self { + Self::from_dist_server_with_env(server, url, comps, HashMap::new()) + } + + fn from_dist_server_with_env( + server: MockDistServer, + url: Url, + comps: Compressions, + env: HashMap, + ) -> Self { server.write( &[MockManifestVersion::V2], comps.enable_xz(), @@ -444,12 +461,7 @@ impl TestContext { let toolchain = ToolchainDesc::from_str("nightly-x86_64-apple-darwin").unwrap(); let prefix = InstallPrefix::from(prefix_tempdir.path()); - let tp = TestProcess::new( - env::current_dir().unwrap(), - &["rustup"], - HashMap::default(), - "", - ); + let tp = TestProcess::new(env::current_dir().unwrap(), &["rustup"], env, ""); Self { url, @@ -1301,6 +1313,33 @@ async fn remove_extensions_does_not_remove_other_components() { assert!(utils::path_exists(cx.prefix.path().join("bin/rustc"))); } +#[tokio::test] +async fn remove_extensions_does_not_hang_with_concurrent_downloads_override() { + let cx = TestContext::with_env( + None, + GZOnly, + [("RUSTUP_CONCURRENT_DOWNLOADS".to_owned(), "2".to_owned())].into(), + ); + + let adds = vec![Component::new( + "rust-std".to_string(), + Some(TargetTriple::new("i686-apple-darwin")), + false, + )]; + + cx.update_from_dist(&adds, &[], false).await.unwrap(); + + let removes = vec![Component::new( + "rust-std".to_string(), + Some(TargetTriple::new("i686-apple-darwin")), + false, + )]; + + cx.update_from_dist(&[], &removes, false).await.unwrap(); + + assert!(utils::path_exists(cx.prefix.path().join("bin/rustc"))); +} + #[tokio::test] async fn add_and_remove_for_upgrade() { let cx = TestContext::new(None, GZOnly); diff --git a/src/download/mod.rs b/src/download/mod.rs index e174146136..fa89bb7896 100644 --- a/src/download/mod.rs +++ b/src/download/mod.rs @@ -1,7 +1,7 @@ //! Easy file downloading use std::fs::remove_file; -use std::num::NonZeroU64; +use std::num::NonZero; use std::path::Path; use std::str::FromStr; use std::time::Duration; @@ -201,9 +201,11 @@ async fn download_file_( }; let timeout = Duration::from_secs(match process.var("RUSTUP_DOWNLOAD_TIMEOUT") { - Ok(s) => NonZeroU64::from_str(&s).context( - "invalid value in RUSTUP_DOWNLOAD_TIMEOUT -- must be a natural number greater than zero", - )?.get(), + Ok(s) => NonZero::from_str(&s) + .context( + "invalid value in RUSTUP_DOWNLOAD_TIMEOUT -- must be a natural number greater than zero", + )? + .get(), Err(_) => 180, }); diff --git a/src/process.rs b/src/process.rs index 69050da77b..706d40cd63 100644 --- a/src/process.rs +++ b/src/process.rs @@ -2,7 +2,7 @@ use std::ffi::OsString; use std::fmt::Debug; use std::io; use std::io::IsTerminal; -use std::num::NonZeroU64; +use std::num::NonZero; use std::path::PathBuf; use std::str::FromStr; #[cfg(feature = "test")] @@ -188,12 +188,8 @@ impl Process { } pub fn concurrent_downloads(&self) -> Option { - match self.var("RUSTUP_CONCURRENT_DOWNLOADS") { - Ok(s) => Some(NonZeroU64::from_str(&s).context( - "invalid value in RUSTUP_CONCURRENT_DOWNLOADS -- must be a natural number greater than zero" - ).ok()?.get() as usize), - Err(_) => None, - } + let s = self.var("RUSTUP_CONCURRENT_DOWNLOADS").ok()?; + Some(NonZero::from_str(&s).ok()?.get()) } } diff --git a/tests/suite/cli_inst_interactive.rs b/tests/suite/cli_inst_interactive.rs index 6af7cc3541..7a82ac187e 100644 --- a/tests/suite/cli_inst_interactive.rs +++ b/tests/suite/cli_inst_interactive.rs @@ -262,6 +262,43 @@ no active toolchain .is_ok(); } +#[tokio::test] +async fn with_no_toolchain_doesnt_hang() { + let cx = CliTestContext::new(Scenario::SimpleV2).await; + run_input( + &cx.config, + &[ + "rustup-init", + "--no-modify-path", + "--default-toolchain=none", + ], + "\n\n", + ) + .is_ok(); + + cx.config.expect(["rustup", "check"]).await.is_err(); +} + +#[tokio::test] +async fn with_no_toolchain_doesnt_hang_with_concurrent_downloads_override() { + let cx = CliTestContext::new(Scenario::SimpleV2).await; + run_input( + &cx.config, + &[ + "rustup-init", + "--no-modify-path", + "--default-toolchain=none", + ], + "\n\n", + ) + .is_ok(); + + cx.config + .expect_with_env(["rustup", "check"], [("RUSTUP_CONCURRENT_DOWNLOADS", "2")]) + .await + .is_err(); +} + #[tokio::test] async fn with_non_default_toolchain_still_prompts() { let cx = CliTestContext::new(Scenario::SimpleV2).await;