From 7738c3699e53444f060ac7a350cf044915efb042 Mon Sep 17 00:00:00 2001 From: Georg Hopp Date: Fri, 10 Jan 2025 20:41:47 +0100 Subject: [PATCH] improve cli and limit in flight requests --- Cargo.toml | 2 +- src/client.rs | 216 +++++++++++++++++++++++++++++++++-------- src/client/data.rs | 2 +- src/client/download.rs | 27 ++++-- src/client/error.rs | 52 ++++++++++ src/client/util.rs | 45 +++++---- src/main.rs | 17 +++- 7 files changed, 285 insertions(+), 76 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 908a8c7..856c85b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,7 @@ edition = "2021" [dependencies] anyhow = "1.0" -clap = { version = "4.5", features = [ "derive" ] } +clap = { version = "4.5", features = [ "derive", "cargo" ] } env_logger = "0.11" futures-util = "0.3" http = "1.2" diff --git a/src/client.rs b/src/client.rs index 71edef2..4888055 100644 --- a/src/client.rs +++ b/src/client.rs @@ -6,11 +6,15 @@ mod message; use std::{collections::HashMap, path::Path, time::Duration}; +use clap::{crate_name, crate_version}; use error::ClientError; use http::{ + header::{ORIGIN, USER_AGENT}, + HeaderMap, + HeaderValue, Request, Response, - Uri + Uri, }; use message::{ClientActorMessage, ClientActorMessageHandle}; use reqwest::{redirect::Policy, Body}; @@ -26,25 +30,6 @@ use tower_reqwest::HttpClientLayer; use log::{debug, error, info}; -#[macro_export] -macro_rules! mk_dlerror { - ($message:ident, $($err:tt)*) => {{ - use $crate::client::error; - error::ClientError::new( $message.clone() - , Some(anyhow::anyhow!($($err)*) )) - }}; -} - -#[macro_export] -macro_rules! map_dlerror { - ($message:ident) => {{ - use $crate::client::error; - |e| error::ClientError::new( $message.clone() - , Some(anyhow::anyhow!(e) )) - }}; -} - - #[derive(Clone, Debug)] pub enum DownloadState { GotHead, @@ -62,13 +47,15 @@ type HttpClient = BoxCloneService, Response, client: HttpClient, tasks: JoinSet, actions: HashMap, actions_idx: ActionIndex, + default_headers: HeaderMap, receiver: mpsc::Receiver, + tasks_left: usize, } @@ -79,8 +66,8 @@ pub(super) struct ClientActorHandle { async fn run_client(mut actor: ClientActor) { loop { - select! { - Some(join) = actor.tasks.join_next() => { + if actor.tasks_left == 0 { + if let Some(join) = actor.tasks.join_next().await { use ClientActorMessageHandle::{Download, GetData}; match join { @@ -94,10 +81,15 @@ async fn run_client(mut actor: ClientActor) { // retry ... instead of responing here we could also respond // with something that in turn would be used to retry... let client = actor.client.clone(); + let default_headers = actor.default_headers.clone(); actor.tasks.spawn(async move { match e.action { - Download { .. } => - download::download(client, e.action, actor.timeout).await, + Download { .. } => { + download::download( client + , default_headers + , e.action + , actor.body_timeout ).await + }, GetData { .. } => data::data(client, e.action).await, } @@ -106,6 +98,7 @@ async fn run_client(mut actor: ClientActor) { // when the task finishes Ok(Ok(Some(action))) => { + actor.tasks_left += 1; match action { Download { filename: _, ref uri, state, ref message } => { info!("Done download: {:?}", uri); @@ -144,13 +137,87 @@ async fn run_client(mut actor: ClientActor) { // JoinSet is empty. Ok(Ok(None)) => (), }; - } + }; + } else { + select! { + Some(join) = actor.tasks.join_next() => { + use ClientActorMessageHandle::{Download, GetData}; + + match join { + Err(e) => { + error!("FATAL Join failed: {}", e); + break + }, + + Ok(Err(e)) => { + info!("Retry failed download: {:?}", e); + // retry ... instead of responing here we could also respond + // with something that in turn would be used to retry... + let client = actor.client.clone(); + let default_headers = actor.default_headers.clone(); + actor.tasks.spawn(async move { + match e.action { + Download { .. } => + download::download( client + , default_headers + , e.action + , actor.body_timeout ).await, + GetData { .. } => + data::data(client, e.action).await, + } + }); + }, + + // when the task finishes + Ok(Ok(Some(action))) => { + actor.tasks_left += 1; + match action { + Download { filename: _, ref uri, state, ref message } => { + info!("Done download: {:?}", uri); + if let Some((_, message)) = actor.actions.remove_entry(message) { + use ClientActorMessage::Download; + + match message { + Download { respond_to, .. } => { + let _ = respond_to.send(Ok(state)); + }, + _ => panic!("Wrong variant ... this should never happen"), + } + } else { + panic!("Lost a message"); + } + }, + + GetData { ref uri, buffer, ref message } => { + info!("Done get_data: {:?}", uri); + if let Some((_, message)) = actor.actions.remove_entry(message) { + use ClientActorMessage::GetData; + match message { + GetData { respond_to, .. } => { + let _ = respond_to.send(buffer); + }, + _ => panic!("Wrong variant ... this should never happen"), + } + } else { + panic!("Lost a message"); + } + }, + } + }, - Some(message) = actor.receiver.recv() => { - actor.handle_message(message).await; - } + // Got a stop message...here we still continue procession until the + // JoinSet is empty. + Ok(Ok(None)) => (), + }; + } + + Some(message) = actor.receiver.recv() => { + actor.tasks_left -= 1; + actor.handle_message(message).await; + } - else => {} + else => {} + } } } } @@ -158,12 +225,12 @@ async fn run_client(mut actor: ClientActor) { impl ClientActor { - pub(super) fn new( buffer: usize - , rate_limit: u64 - , concurrency_limit: usize - , timeout: Duration - , receiver: mpsc::Receiver - , abort_rx: oneshot::Receiver ) -> anyhow::Result { + fn new( buffer: usize + , rate_limit: u64 + , concurrency_limit: usize + , timeout: Duration + , receiver: mpsc::Receiver + , abort_rx: oneshot::Receiver ) -> anyhow::Result { let client = ServiceBuilder::new() // Add some layers. . buffer(buffer) @@ -179,8 +246,6 @@ impl ClientActor { . map_err(anyhow::Error::msg) . boxed_clone(); - debug!("-> client: {:?}", client); - let mut tasks = JoinSet::new(); tasks.spawn(async move { @@ -190,8 +255,50 @@ impl ClientActor { let actions = HashMap::new(); let actions_idx = 0; + let tasks_left = buffer + concurrency_limit; + let body_timeout = None; + let mut default_headers = HeaderMap::new(); + default_headers.insert( + USER_AGENT, + HeaderValue::from_str(&( crate_name!().to_string() + "/" + + crate_version!() )).unwrap() ); + + Ok(Self { + body_timeout, + client, + tasks, + receiver, + actions, + default_headers, + actions_idx, + tasks_left }) + } - Ok(Self {timeout, client, tasks, receiver, actions, actions_idx}) + fn set_body_timeout(mut self, timeout: Option) -> Self { + self.body_timeout = timeout; + self + } + + fn set_origin(mut self, origin: Option) -> Self { + if let Some(origin) = origin { + self.default_headers.insert( + ORIGIN, + HeaderValue::from_str(origin.as_str()).unwrap() ); + } else { + self.default_headers.remove(ORIGIN); + } + self + } + + fn set_user_agent(mut self, user_agent: Option) -> Self { + if let Some(user_agent) = user_agent { + self.default_headers.insert( + USER_AGENT, + HeaderValue::from_str(user_agent.as_str()).unwrap() ); + } else { + self.default_headers.remove(USER_AGENT); + } + self } async fn handle_message(&mut self, message: ClientActorMessage) { @@ -203,7 +310,7 @@ impl ClientActor { Some(Download { ref filename, ref uri, respond_to: _ }) => { // spawn a task that does the work let client = self.client.clone(); - let timeout = self.timeout; + let timeout = self.body_timeout; let handle = ClientActorMessageHandle::Download { filename: filename.to_path_buf(), @@ -212,8 +319,12 @@ impl ClientActor { message: self.actions_idx, }; + let default_headers = self.default_headers.clone(); self.tasks.spawn(async move { - download::download(client, handle, timeout).await + download::download( client + , default_headers + , handle + , timeout ).await }); self.actions_idx += 1; @@ -245,7 +356,10 @@ impl ClientActorHandle { pub(super) fn new( buffer: usize , rate_limit: u64 , concurrency_limit: usize - , timeout: Duration ) -> Self { + , timeout: Duration + , use_body_timeout: bool + , origin: Option + , user_agent: Option ) -> Self { let (sender, receiver) = mpsc::channel(1); let (abort, abort_rx) = oneshot::channel::(); let actor = ClientActor::new( buffer @@ -254,7 +368,23 @@ impl ClientActorHandle { , timeout , receiver , abort_rx ) - . expect("Client create error"); + . expect("Client create error") + . set_origin(origin); + + let actor = if let Some(user_agent) = user_agent { + actor.set_user_agent(Some(user_agent)) + } else { + actor + }; + + let actor = if use_body_timeout { + actor.set_body_timeout(Some(timeout)) + } else { + actor + }; + + debug!("-> actor: {:?}", actor); + tokio::spawn(run_client(actor)); Self { sender, abort } diff --git a/src/client/data.rs b/src/client/data.rs index 1bef2c6..dde19ec 100644 --- a/src/client/data.rs +++ b/src/client/data.rs @@ -13,7 +13,7 @@ pub(super) async fn data( mut client: HttpClient let mut response = util::request( &mut client , "GET" , &uri - , HeaderMap::new() ) + , &HeaderMap::new() ) . await . map_err(map_dlerror!(message))?; diff --git a/src/client/download.rs b/src/client/download.rs index 4f0024a..9006898 100644 --- a/src/client/download.rs +++ b/src/client/download.rs @@ -24,8 +24,9 @@ use super::{ pub(super) async fn download( mut client: HttpClient + , headers: HeaderMap , mut message: ClientActorMessageHandle - , io_timeout: Duration ) -> JoinSetResult { + , io_timeout: Option ) -> JoinSetResult { let filename = message.filename(); let uri = message.uri(); @@ -33,10 +34,10 @@ pub(super) async fn download( mut client: HttpClient let mut from = file_size(&filename).await; // - get infos to uri - let headers = util::head(&mut client, &uri).await - . map_err(map_dlerror!(message))?; - let content_length = util::content_length(&headers).ok(); - let content_type = util::content_type(&headers).ok(); + let response_headers = util::head(&mut client, &uri, &headers).await + . map_err(map_dlerror!(message))?; + let content_length = util::content_length(&response_headers).ok(); + let content_type = util::content_type(&response_headers).ok(); message.set_state(DownloadState::GotHead); @@ -49,10 +50,10 @@ pub(super) async fn download( mut client: HttpClient } // - do the neccessry request. - let mut headers = HeaderMap::new(); + let mut headers = headers; headers.insert(RANGE, format!("bytes={}-", from).parse().unwrap()); - let mut response = util::request(&mut client, "GET", &uri, headers).await + let mut response = util::request(&mut client, "GET", &uri, &headers).await . map_err(map_dlerror!(message))?; // - open or create file @@ -103,12 +104,18 @@ async fn open_outfile(status: &StatusCode, filename: &Path) -> File { async fn store_body( file: &mut File , body: &mut DecompressionBody - , io_timeout: Duration ) -> anyhow::Result<()> { + , io_timeout: Option ) -> anyhow::Result<()> { let mut body = BodyDataStream::new(body); loop { - // give timeout somehow... probably from client. - let data = timeout(io_timeout, body.next()).await?; + let data_future = body.next(); + let data = if let Some(io_timeout) = io_timeout { + // give timeout somehow... probably from client. + timeout(io_timeout, body.next()).await? + } else { + data_future.await + }; + match data { None => break, diff --git a/src/client/error.rs b/src/client/error.rs index 37194b1..4718a72 100644 --- a/src/client/error.rs +++ b/src/client/error.rs @@ -1,14 +1,41 @@ use std::{error, fmt}; +use http::Response; + use super::ClientActorMessageHandle; +#[macro_export] +macro_rules! mk_dlerror { + ($message:ident, $($err:tt)*) => {{ + use $crate::client::error; + error::ClientError::new( $message.clone() + , Some(anyhow::anyhow!($($err)*) )) + }}; +} + +#[macro_export] +macro_rules! map_dlerror { + ($message:ident) => {{ + use $crate::client::error; + |e| error::ClientError::new( $message.clone() + , Some(anyhow::anyhow!(format!("{:?}", e)) )) + }}; +} + + #[derive(Debug)] pub(crate) struct ClientError { pub(super) action: ClientActorMessageHandle, pub(super) source: Option, } +#[derive(Debug)] +pub(super) struct RequestError { + pub(super) response: Option>, + pub(super) source: Option, +} + impl ClientError { pub(super) fn new( action: ClientActorMessageHandle @@ -35,3 +62,28 @@ impl fmt::Display for ClientError { } } } + +impl RequestError { + pub(super) fn new( response: Option> + , source: Option ) -> Self { + Self { response, source } + } +} + +impl error::Error for RequestError { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { + match &self.source { + None => None, + Some(e) => Some(e.as_ref()), + } + } +} + +impl fmt::Display for RequestError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.source { + None => write!(f, "request error: {:?}", self.response), + Some(err) => write!(f, "request error ({:?}): {}", self.response, err), + } + } +} diff --git a/src/client/util.rs b/src/client/util.rs index bfcb28a..04f328e 100644 --- a/src/client/util.rs +++ b/src/client/util.rs @@ -1,44 +1,55 @@ use anyhow::anyhow; -use http::{header::{CONTENT_LENGTH, CONTENT_TYPE}, request::Builder as RequestBuilder, HeaderMap, Response, Uri}; +use http::{ + header::{CONTENT_LENGTH, CONTENT_TYPE}, + request::Builder as RequestBuilder, + HeaderMap, + Response, + Uri +}; use reqwest::Body; use tower_http::decompression::DecompressionBody; use tower_http_client::ServiceExt as _; -use super::HttpClient; +use super::{error::RequestError, HttpClient}; use log::debug; pub(super) async fn request( client: &mut HttpClient - , method: &str - , uri: &Uri - , headers: HeaderMap ) -> anyhow::Result>> { + , method: &str + , uri: &Uri + , headers: &HeaderMap ) -> Result>, RequestError> { let mut request = RequestBuilder::new() . method(method) . uri(uri) - . body(Body::default())?; + . body(Body::default()) + . map_err(|e| RequestError::new(None, Some(e.into())))?; - request.headers_mut().extend(headers); + request.headers_mut().extend(headers.clone()); debug!("Request: {:?}", request); - let response = client.execute(request).await?; + match client.execute(request).await { + Err(e) => Err(RequestError::new(None, Some(e))), + Ok(response) => { + debug!("Response: {:?}", response.headers()); - debug!("Response: {:?}", response.headers()); - - anyhow::ensure!( response.status().is_success() - , "resonse status failed: {}" - , response.status() ); - - Ok(response) + if response.status().is_success() { + Ok(response) + } else { + Err(RequestError::new(Some(response.map(|_| ())), None)) + } + }, + } } pub(super) async fn head( client: &mut HttpClient - , uri: &Uri ) -> anyhow::Result { + , uri: &Uri + , headers: &HeaderMap ) -> Result { Ok( request( client , "HEAD" , uri - , HeaderMap::new() ).await?.headers().clone() ) + , headers ).await?.headers().clone() ) } pub(super) fn content_length(headers: &HeaderMap) -> anyhow::Result { diff --git a/src/main.rs b/src/main.rs index b55a7d5..d72c798 100644 --- a/src/main.rs +++ b/src/main.rs @@ -21,6 +21,7 @@ use process::{enter_download_dir, ffmpeg, remove_download_dir}; #[derive(Debug, Parser)] +#[command(version, about, long_about = None)] struct Args { #[arg(short, long)] name: PathBuf, @@ -34,6 +35,12 @@ struct Args { concurrency: Option, #[arg(short, long)] timeout: Option, + #[arg(short = 'B', default_value_t = false)] + use_body_timeout: bool, + #[arg(short, long)] + origin: Option, + #[arg(short, long)] + agent: Option, } @@ -59,8 +66,8 @@ async fn main() -> anyhow::Result<()> { Err(anyhow!("Only filenames with .mp4 extension are allowed"))? } - let buffer = args.buffer.unwrap_or(1000); - let rate_limit = args.rate.unwrap_or(100); + let buffer = args.buffer.unwrap_or(20); + let rate_limit = args.rate.unwrap_or(20); let concurrency_limit = args.concurrency.unwrap_or(20); let timeout = args.timeout.unwrap_or(15); let timeout = Duration::from_secs(timeout); @@ -76,14 +83,16 @@ async fn main() -> anyhow::Result<()> { let client = ClientActorHandle::new( buffer , rate_limit , concurrency_limit - , timeout ); + , timeout + , args.use_body_timeout + , args.origin + , args.agent ); info!("Get segments..."); let m3u8_data = client.body_bytes(&m3u8_uri).await . ok_or(anyhow!("Unable to get body for: {}", m3u8_uri))?; let mut download = M3u8Download::new(m3u8_data, m3u8_uri).await?; - debug!("M3u8Download: {:?}", download); info!("Sending concurrent requests...");