Browse Source

Move to JoinSet

main
Georg Hopp 11 months ago
parent
commit
c15ade579f
Signed by: ghopp GPG Key ID: 4C5D226768784538
  1. 4
      Cargo.toml
  2. 36
      src/client.rs
  3. 69
      src/main.rs

4
Cargo.toml

@ -13,7 +13,7 @@ http-body-util = "0.1"
log = "0.4"
m3u8-rs = "6.0"
reqwest = "0.12"
tokio = { version = "1.42", features = [ "macros", "rt-multi-thread" ] }
tower = { version = "0.5", features = [ "buffer", "limit", "timeout" ] }
tokio = { version = "1.42", features = [ "macros", "rt", "rt-multi-thread" ] }
tower = { version = "0.5", features = [ "limit", "timeout" ] }
tower-http-client = "0.4"
tower-reqwest = "0.4"

36
src/client.rs

@ -27,8 +27,11 @@ pub struct State {
client: HttpClient,
}
unsafe impl Send for State {}
unsafe impl Sync for State {}
impl State {
pub fn new(uri: &Uri, rate_limit: u64, concurrency_limit: usize) -> anyhow::Result<Self>
pub fn new(uri: &Uri, concurrency_limit: usize, timeout: Duration) -> anyhow::Result<Self>
{
let scheme = uri.scheme()
. ok_or(anyhow!("Problem scheme in m3u8 uri"))?;
@ -42,12 +45,11 @@ impl State {
scheme: scheme.clone(),
auth: authority.clone(),
base_path: base_path.to_string(),
timeout: Duration::from_secs(10),
timeout,
client: ServiceBuilder::new()
// Add some layers.
. buffer(64)
. rate_limit(rate_limit, Duration::from_secs(1))
. concurrency_limit(concurrency_limit)
. timeout(timeout)
// Make client compatible with the `tower-http` layers.
. layer(HttpClientLayer)
. service(reqwest::Client::builder()
@ -61,10 +63,6 @@ impl State {
Ok(state)
}
pub(super) fn set_timeout(&mut self, timeout: Duration) {
self.timeout = timeout;
}
pub(super) async fn get_m3u8_segment_uris(&mut self, path_and_query: &str)
-> anyhow::Result<Vec<Uri>>
{
@ -94,7 +92,7 @@ impl State {
}
pub(super) async fn get_m3u8_segment(&mut self, uri: &Uri)
-> Result<(), DownloadError>
-> Result<Uri, DownloadError>
{
// I consider a missing path as fatal... there is absolutely nothing we can do about it
// and we need all files from the playlist.
@ -118,7 +116,7 @@ impl State {
None => (),
Some(v) => if let Ok(v) = v.to_str() {
if v == "true" {
return Ok(())
return Ok(uri.clone())
}
},
}
@ -174,7 +172,7 @@ impl State {
}
};
Ok(())
Ok(uri.clone())
}
fn download_uri(&self, segment: &MediaSegment) -> anyhow::Result<Uri>
@ -222,18 +220,12 @@ impl State {
. method("HEAD")
. uri(uri)
. body(Body::default())?;
log!(Level::Debug, "{:?}", request);
// Send get request with timeout.
// let send_fut = self.client.get(uri).send()?;
let send_fut = self.client.execute(request);
let response = timeout(self.timeout, send_fut).await??;
let response = self.client.execute(request).await?;
anyhow::ensure!( response.status().is_success()
, "resonse status failed: {}"
, response.status() );
log!(Level::Debug, "{:?}", response);
let content_length: u64 = response.headers().get(CONTENT_LENGTH)
@ -259,18 +251,12 @@ impl State {
. uri(uri)
. header(RANGE, format!("bytes={}-", from))
. body(Body::default())?;
log!(Level::Debug, "{:?}", request);
// Send get request with timeout.
// let send_fut = self.client.get(uri).send()?;
let send_fut = self.client.execute(request);
let response = timeout(self.timeout, send_fut).await??;
let response = self.client.execute(request).await?;
anyhow::ensure!( response.status().is_success()
, "resonse status failed: {}"
, response.status() );
log!(Level::Debug, "{:?}", response);
Ok(response)

69
src/main.rs

@ -7,6 +7,7 @@ use anyhow::anyhow;
use clap::Parser;
use env_logger::Env;
use http::Uri;
use tokio::task::JoinSet;
use log::{log, Level};
@ -25,8 +26,6 @@ struct Args {
#[arg(short, long)]
timeout: Option<u64>,
#[arg(short, long)]
rate: Option<u64>,
#[arg(short, long)]
concurrency: Option<usize>,
}
@ -47,22 +46,14 @@ async fn main() -> anyhow::Result<()> {
log!(Level::Info, "-> Creating an HTTP client with Tower layers...");
let rate = args.rate.unwrap_or(10);
let concurrency = args.concurrency.unwrap_or(10);
let timeout = args.timeout.unwrap_or(10);
let timeout = Duration::from_secs(timeout);
let m3u8_uri = Uri::try_from(&args.url)?;
let m3u8_path_and_query = m3u8_uri.path_and_query()
. ok_or(anyhow!("Problem path and query in m3u8 uri"))?;
let mut state = client::State::new(&m3u8_uri, rate, concurrency)?;
if let Some(timeout) = args.timeout {
state.set_timeout(Duration::from_secs(timeout));
}
// I think about a worker pool... probably of concurrency_limit amount.
// The worker needs to get the url. Our first target is to store the
// data on a file with the same name as the last part of the URL.
// CURRENTLY I just create a task for each download.
let mut state = client::State::new(&m3u8_uri, concurrency, timeout)?;
log!(Level::Info, "-> Get segments...");
@ -70,32 +61,34 @@ async fn main() -> anyhow::Result<()> {
log!(Level::Info, "-> Sending concurrent requests...");
'working: while ! segments.is_empty() {
let mut tasks = vec![];
while let Some(segment) = segments.pop() {
log!(Level::Info, "download segment: {}", segment);
let state = state.clone();
tasks.push(tokio::spawn(async move {
let mut state = state.clone();
state.get_m3u8_segment(&segment).await
}));
}
let mut join_set = JoinSet::new();
while let Some(segment) = segments.pop() {
log!(Level::Info, " -> Spawn task for: {}", segment);
let mut state = state.clone();
join_set.spawn(async move {
state.get_m3u8_segment(&segment).await
});
}
let results = futures_util::future::join_all(tasks).await;
for result in &results {
match result {
Err(e) => {
log!(Level::Error, "FATAL Join failed: {}", e);
break 'working
},
Ok(Err(e)) => {
log!(Level::Warn, "Retry failed download: {}", e);
segments.push(e.uri.clone());
},
_ => (),
}
'working: while let Some(result) = join_set.join_next().await {
match result {
Err(e) => {
log!(Level::Error, "FATAL Join failed: {}", e);
break 'working
},
Ok(Err(e)) => {
log!(Level::Warn, "Retry failed download: {:?}", e);
log!(Level::Info, " -> Spawn task for: {}", e.uri);
let mut state = state.clone();
join_set.spawn(async move {
state.get_m3u8_segment(&e.uri).await
});
},
Ok(Ok(v)) => {
log!(Level::Info, "Done download: {}", v);
},
}
}

Loading…
Cancel
Save