diff --git a/Cargo.toml b/Cargo.toml index de4a6ea..17397b1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ http-body-util = "0.1" log = "0.4" m3u8-rs = "6.0" reqwest = "0.12" -tokio = { version = "1.42", features = [ "macros", "rt-multi-thread" ] } -tower = { version = "0.5", features = [ "buffer", "limit", "timeout" ] } +tokio = { version = "1.42", features = [ "macros", "rt", "rt-multi-thread" ] } +tower = { version = "0.5", features = [ "limit", "timeout" ] } tower-http-client = "0.4" tower-reqwest = "0.4" diff --git a/src/client.rs b/src/client.rs index b0f1e77..fe02576 100644 --- a/src/client.rs +++ b/src/client.rs @@ -27,8 +27,11 @@ pub struct State { client: HttpClient, } +unsafe impl Send for State {} +unsafe impl Sync for State {} + impl State { - pub fn new(uri: &Uri, rate_limit: u64, concurrency_limit: usize) -> anyhow::Result + pub fn new(uri: &Uri, concurrency_limit: usize, timeout: Duration) -> anyhow::Result { let scheme = uri.scheme() . ok_or(anyhow!("Problem scheme in m3u8 uri"))?; @@ -42,12 +45,11 @@ impl State { scheme: scheme.clone(), auth: authority.clone(), base_path: base_path.to_string(), - timeout: Duration::from_secs(10), + timeout, client: ServiceBuilder::new() // Add some layers. - . buffer(64) - . rate_limit(rate_limit, Duration::from_secs(1)) . concurrency_limit(concurrency_limit) + . timeout(timeout) // Make client compatible with the `tower-http` layers. . layer(HttpClientLayer) . service(reqwest::Client::builder() @@ -61,10 +63,6 @@ impl State { Ok(state) } - pub(super) fn set_timeout(&mut self, timeout: Duration) { - self.timeout = timeout; - } - pub(super) async fn get_m3u8_segment_uris(&mut self, path_and_query: &str) -> anyhow::Result> { @@ -94,7 +92,7 @@ impl State { } pub(super) async fn get_m3u8_segment(&mut self, uri: &Uri) - -> Result<(), DownloadError> + -> Result { // I consider a missing path as fatal... there is absolutely nothing we can do about it // and we need all files from the playlist. @@ -118,7 +116,7 @@ impl State { None => (), Some(v) => if let Ok(v) = v.to_str() { if v == "true" { - return Ok(()) + return Ok(uri.clone()) } }, } @@ -174,7 +172,7 @@ impl State { } }; - Ok(()) + Ok(uri.clone()) } fn download_uri(&self, segment: &MediaSegment) -> anyhow::Result @@ -222,18 +220,12 @@ impl State { . method("HEAD") . uri(uri) . body(Body::default())?; - log!(Level::Debug, "{:?}", request); - // Send get request with timeout. - // let send_fut = self.client.get(uri).send()?; - let send_fut = self.client.execute(request); - let response = timeout(self.timeout, send_fut).await??; - + let response = self.client.execute(request).await?; anyhow::ensure!( response.status().is_success() , "resonse status failed: {}" , response.status() ); - log!(Level::Debug, "{:?}", response); let content_length: u64 = response.headers().get(CONTENT_LENGTH) @@ -259,18 +251,12 @@ impl State { . uri(uri) . header(RANGE, format!("bytes={}-", from)) . body(Body::default())?; - log!(Level::Debug, "{:?}", request); - // Send get request with timeout. - // let send_fut = self.client.get(uri).send()?; - let send_fut = self.client.execute(request); - let response = timeout(self.timeout, send_fut).await??; - + let response = self.client.execute(request).await?; anyhow::ensure!( response.status().is_success() , "resonse status failed: {}" , response.status() ); - log!(Level::Debug, "{:?}", response); Ok(response) diff --git a/src/main.rs b/src/main.rs index 16824b7..4607238 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,6 +7,7 @@ use anyhow::anyhow; use clap::Parser; use env_logger::Env; use http::Uri; +use tokio::task::JoinSet; use log::{log, Level}; @@ -25,8 +26,6 @@ struct Args { #[arg(short, long)] timeout: Option, #[arg(short, long)] - rate: Option, - #[arg(short, long)] concurrency: Option, } @@ -47,22 +46,14 @@ async fn main() -> anyhow::Result<()> { log!(Level::Info, "-> Creating an HTTP client with Tower layers..."); - let rate = args.rate.unwrap_or(10); let concurrency = args.concurrency.unwrap_or(10); + let timeout = args.timeout.unwrap_or(10); + let timeout = Duration::from_secs(timeout); let m3u8_uri = Uri::try_from(&args.url)?; let m3u8_path_and_query = m3u8_uri.path_and_query() . ok_or(anyhow!("Problem path and query in m3u8 uri"))?; - let mut state = client::State::new(&m3u8_uri, rate, concurrency)?; - - if let Some(timeout) = args.timeout { - state.set_timeout(Duration::from_secs(timeout)); - } - - // I think about a worker pool... probably of concurrency_limit amount. - // The worker needs to get the url. Our first target is to store the - // data on a file with the same name as the last part of the URL. - // CURRENTLY I just create a task for each download. + let mut state = client::State::new(&m3u8_uri, concurrency, timeout)?; log!(Level::Info, "-> Get segments..."); @@ -70,32 +61,34 @@ async fn main() -> anyhow::Result<()> { log!(Level::Info, "-> Sending concurrent requests..."); - 'working: while ! segments.is_empty() { - let mut tasks = vec![]; - while let Some(segment) = segments.pop() { - log!(Level::Info, "download segment: {}", segment); - let state = state.clone(); - tasks.push(tokio::spawn(async move { - let mut state = state.clone(); - state.get_m3u8_segment(&segment).await - })); - } + let mut join_set = JoinSet::new(); + while let Some(segment) = segments.pop() { + log!(Level::Info, " -> Spawn task for: {}", segment); + let mut state = state.clone(); + join_set.spawn(async move { + state.get_m3u8_segment(&segment).await + }); + } - let results = futures_util::future::join_all(tasks).await; - for result in &results { - match result { - Err(e) => { - log!(Level::Error, "FATAL Join failed: {}", e); - break 'working - }, - - Ok(Err(e)) => { - log!(Level::Warn, "Retry failed download: {}", e); - segments.push(e.uri.clone()); - }, - - _ => (), - } + 'working: while let Some(result) = join_set.join_next().await { + match result { + Err(e) => { + log!(Level::Error, "FATAL Join failed: {}", e); + break 'working + }, + + Ok(Err(e)) => { + log!(Level::Warn, "Retry failed download: {:?}", e); + log!(Level::Info, " -> Spawn task for: {}", e.uri); + let mut state = state.clone(); + join_set.spawn(async move { + state.get_m3u8_segment(&e.uri).await + }); + }, + + Ok(Ok(v)) => { + log!(Level::Info, "Done download: {}", v); + }, } }