From ca497b16f58c9c5b31e24e3e68b4b46f4cccff5a Mon Sep 17 00:00:00 2001 From: Georg Hopp Date: Wed, 15 Jan 2025 10:33:46 +0100 Subject: [PATCH] Add reconnect ability when all remaining requests got a 503 --- Cargo.lock | 2 +- Cargo.toml | 2 +- src/client.rs | 44 +++++++++++++++++++++++++++++++++++-- src/client_actor.rs | 12 +++++++++- src/client_actor/message.rs | 1 + src/m3u8_download.rs | 13 ++++++++--- src/main.rs | 9 +++++++- 7 files changed, 74 insertions(+), 9 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e620093..fe64627 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -509,7 +509,7 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" [[package]] name = "hlsget" -version = "0.3.1" +version = "0.3.2" dependencies = [ "anyhow", "bytes", diff --git a/Cargo.toml b/Cargo.toml index c0fcb41..198e559 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hlsget" -version = "0.3.1" +version = "0.3.2" edition = "2021" [dependencies] diff --git a/src/client.rs b/src/client.rs index 1c7720b..51cac4d 100644 --- a/src/client.rs +++ b/src/client.rs @@ -17,7 +17,7 @@ use http::{ Uri, }; use http_body_util::BodyDataStream; -use log::debug; +use log::{debug, error}; use reqwest::{redirect::Policy, Body}; use tokio::{fs::File, io::AsyncWriteExt as _, time::timeout}; use tower::{util::BoxCloneService, ServiceBuilder, ServiceExt as _}; @@ -47,6 +47,10 @@ pub(super) enum DownloadState { #[derive(Clone, Debug)] pub(super) struct Client { + buffer: usize, + rate_limit: u64, + concurrency_limit: usize, + timeout: Duration, client: HttpClient, default_headers: HeaderMap, body_timeout: Option, @@ -90,7 +94,43 @@ impl Client { HeaderValue::from_str(&( crate_name!().to_string() + "/" + crate_version!() )).unwrap() ); - Ok(Self {client, default_headers, body_timeout}) + Ok(Self { buffer + , rate_limit + , concurrency_limit + , timeout + , client + , default_headers + , body_timeout }) + } + + fn build_http_client( buffer: usize + , rate_limit: u64 + , concurrency_limit: usize + , timeout: Duration ) -> Result { + Ok( ServiceBuilder::new() + // Add some layers. + . buffer(buffer) + . rate_limit(rate_limit, Duration::from_secs(1)) + . concurrency_limit(concurrency_limit) + . timeout(timeout) + . layer(DecompressionLayer::new()) + // Make client compatible with the `tower-http` layers. + . layer(HttpClientLayer) + . service( reqwest::Client::builder() + . redirect(Policy::limited(5)) + . build()? ) + . map_err(anyhow::Error::msg) + . boxed_clone() ) + } + + pub(super) fn rebuild_http_client(&mut self) { + match Self::build_http_client( self.buffer + , self.rate_limit + , self.concurrency_limit + , self.timeout ) { + Ok(client) => self.client = client, + Err(e) => error!("unable to re-create client: {}", e), + }; } pub(super) fn set_body_timeout(mut self, timeout: Option) -> Self { diff --git a/src/client_actor.rs b/src/client_actor.rs index a12bfdf..b39925d 100644 --- a/src/client_actor.rs +++ b/src/client_actor.rs @@ -86,6 +86,7 @@ impl ClientActor { ClientActorMessage::GetData { respond_to, .. } => { let _ = respond_to.send(handle.buffer_ref().clone()); }, + _ => (), } } @@ -99,15 +100,18 @@ impl ClientActor { ClientActorMessage::GetData { respond_to, .. } => { let _ = respond_to.send(handle.buffer_ref().clone()); }, + _ => (), } } async fn handle_message(&mut self, message: ClientActorMessage) { self.actions.insert(self.actions_idx, message); - use ClientActorMessage::{Download, GetData}; + use ClientActorMessage::{Download, GetData, Reconnect}; match self.actions.get(&self.actions_idx) { + Some(Reconnect) => self.client.rebuild_http_client(), + Some(Download { ref filename, ref uri, .. }) => { // spawn a task that does the work let mut client = self.client.clone(); @@ -202,4 +206,10 @@ impl ClientActorHandle { let _ = self.sender.send(msg).await; receive.await.expect("Actor cancelled unexpected") } + + pub(super) async fn reconnect(&self) { + let msg = ClientActorMessage::Reconnect; + + let _ = self.sender.send(msg).await; + } } diff --git a/src/client_actor/message.rs b/src/client_actor/message.rs index 4f938bc..5ef6cc0 100644 --- a/src/client_actor/message.rs +++ b/src/client_actor/message.rs @@ -21,6 +21,7 @@ pub(super) enum ClientActorMessage { uri: Uri, respond_to: oneshot::Sender>, }, + Reconnect, } #[derive(Clone, Debug)] diff --git a/src/m3u8_download.rs b/src/m3u8_download.rs index b43cbe1..99d590b 100644 --- a/src/m3u8_download.rs +++ b/src/m3u8_download.rs @@ -30,7 +30,8 @@ struct TsPart { pub(super) struct M3u8Download { index_uri: Uri, ts_parts: Vec, - time_wait: Duration + time_wait: Duration, + do_reconnect: bool, } @@ -75,7 +76,8 @@ impl TsPart { impl M3u8Download { pub(super) async fn new( m3u8_data: Bytes , index_uri: Uri - , time_wait: Duration ) -> anyhow::Result { + , time_wait: Duration + , do_reconnect: bool ) -> anyhow::Result { let scheme = index_uri.scheme() . ok_or(anyhow!("Problem scheme in m3u8 uri"))? . to_owned(); @@ -105,7 +107,7 @@ impl M3u8Download { }, }; - Ok(Self {index_uri, ts_parts, time_wait}) + Ok(Self {index_uri, ts_parts, time_wait, do_reconnect}) } pub(super) fn index_uri(&self) -> &Uri { @@ -146,6 +148,11 @@ impl M3u8Download { break } else { info!("All {} tasks wait for unavailable service", waits.len()); + if self.do_reconnect { + info!("reconnect http connections"); + client.reconnect().await; + } + let pause_time = waits . into_iter() . fold(self.time_wait, |a, w| w.min(a)); diff --git a/src/main.rs b/src/main.rs index 04fbd6d..2ec4c2f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -50,6 +50,10 @@ struct Args { , default_value_t = false , help = "also use timeout on body reads" )] use_body_timeout: bool, + #[arg( short = 'R' + , default_value_t = false + , help = "prevent reconnect before wait" )] + prevent_reconnect: bool, #[arg( short , default_value_t = 301 , help = "wait for temporary failure like 503" )] @@ -120,7 +124,10 @@ async fn main() -> anyhow::Result<()> { let m3u8_data = actor.body_bytes(&m3u8_uri).await . ok_or(anyhow!("Unable to get body for: {}", m3u8_uri))?; - let mut download = M3u8Download::new(m3u8_data, m3u8_uri, wait_time).await?; + let mut download = M3u8Download::new( m3u8_data + , m3u8_uri + , wait_time + , ! args.prevent_reconnect ).await?; info!("Sending concurrent requests...");