diff --git a/src/client.rs b/src/client.rs new file mode 100644 index 0000000..d4b448a --- /dev/null +++ b/src/client.rs @@ -0,0 +1,215 @@ +use std::{path::Path, time::Duration}; + +use anyhow::anyhow; +use futures_util::StreamExt as _; +use http::{header::CONTENT_TYPE, uri::{Authority, Scheme}, Request, Response, Uri}; +use http_body_util::BodyDataStream; +use log::{log, Level}; +use m3u8_rs::Playlist; +use reqwest::{redirect::Policy, Body}; +use tokio::{fs::File, io::AsyncWriteExt as _, time::timeout}; +use tower::{ServiceBuilder, ServiceExt as _}; +use tower_http_client::ServiceExt as _; +use tower_reqwest::HttpClientLayer; + +use crate::download_error::DownloadError; + +type HttpClient = tower::util::BoxCloneService, Response, anyhow::Error>; + +#[derive(Clone, Debug)] +pub struct State { + scheme: Scheme, + auth: Authority, + base_path: String, + timeout: Duration, + client: HttpClient, +} + +impl State { + pub fn new(uri: &Uri, rate_limit: u64, concurrency_limit: usize) -> anyhow::Result + { + let scheme = uri.scheme() + . ok_or(anyhow!("Problem scheme in m3u8 uri"))?; + let authority = uri.authority() + . ok_or(anyhow!("Problem authority in m3u8 uri"))?; + let base_path = Path::new(uri.path()).parent() + . ok_or(anyhow!("Path problem"))? + . to_str() + . ok_or(anyhow!("Path problem"))?; + let state = State { + scheme: scheme.clone(), + auth: authority.clone(), + base_path: base_path.to_string(), + timeout: Duration::from_secs(30), + client: ServiceBuilder::new() + // Add some layers. + . buffer(64) + . rate_limit(rate_limit, Duration::from_secs(1)) + . concurrency_limit(concurrency_limit) + // Make client compatible with the `tower-http` layers. + . layer(HttpClientLayer) + . service(reqwest::Client::builder() + . redirect(Policy::limited(5)) + . build()? ) + . map_err(anyhow::Error::msg) + . boxed_clone(), + }; + log!(Level::Debug, "-> state: {:?}", state); + + Ok(state) + } + + pub(super) fn set_timeout(&mut self, timeout: Duration) { + self.timeout = timeout; + } + + pub(super) async fn get_m3u8_segment_uris(&mut self, path_and_query: &str) + -> anyhow::Result> + { + let uri = Uri::builder() + . scheme(self.scheme.clone()) + . authority(self.auth.clone()) + . path_and_query(path_and_query) + . build()?; + let filename = Path::new(uri.path()).file_name().ok_or(anyhow!("name error"))?; + let mut file = File::create(filename).await?; + + let mut response = timeout(self.timeout, self.client.get(uri).send()?).await??; + + anyhow::ensure!(response.status().is_success(), "response failed"); + + // read body into Vec + let mut body = vec![]; + let mut body_stream = BodyDataStream::new(response.body_mut()); + 'label: loop { + let data = timeout(self.timeout, body_stream.next()).await?; + match data { + None => break 'label, + Some(Err(_)) => break 'label, + Some(Ok(data)) => { + body.append(&mut Vec::from(data.as_ref())); + }, + } + }; + + match m3u8_rs::parse_playlist(&body) { + Result::Err(e) => Err(anyhow!("m3u8 parse error: {}", e)), + Result::Ok((_, Playlist::MasterPlaylist(_))) => + Err(anyhow!("Master playlist not supported now")), + Result::Ok((_, Playlist::MediaPlaylist(pl))) => { + let segments: anyhow::Result> = pl.segments.iter().map(|s| { + Ok(match Uri::try_from(s.uri.clone()) { + Ok(uri) => { + let mut new_segment = s.clone(); + let filename = Path::new(uri.path()) + . file_name() + . ok_or(anyhow!("name error"))? + . to_str() + . ok_or(anyhow!("Error getting filename from uri"))?; + let query = uri.query() + . map(|q| "?".to_owned() + q) + . unwrap_or("".to_string()); + new_segment.uri = (filename.to_owned() + &query).to_string(); + new_segment } + + Err(_) => s.clone() + }) + }).collect(); + let mut out_pl = pl.clone(); + out_pl.segments = segments?; + let mut file_data: Vec = Vec::new(); + out_pl.write_to(&mut file_data)?; + file.write_all(&file_data).await?; + + let uris: anyhow::Result> = pl.segments.iter().map(|s| { + match Uri::try_from(s.uri.clone()) { + Ok(uri) => { + let scheme = uri.scheme() + . or(Some(&self.scheme)) + . ok_or(anyhow!("No scheme in Uri"))?; + let auth = uri.authority() + . or(Some(&self.auth)) + . ok_or(anyhow!("No authority in Uri"))?; + let path_and_query = uri.path_and_query() + . ok_or(anyhow!("No path in Uri"))?; + Ok(Uri::builder() . scheme(scheme.clone()) + . authority(auth.clone()) + . path_and_query(path_and_query.clone()) + . build()?) } + + Err(e) => { + log!(Level::Debug, "Uri parse error: {:?} - {}", e, s.uri); + Ok(Uri::builder() . scheme(self.scheme.clone()) + . authority(self.auth.clone()) + . path_and_query(self.base_path.clone() + "/" + &s.uri) + . build()?) } + } + }).collect(); + + uris + }, + } + } + + pub(super) async fn get_m3u8_segment(&mut self, uri: &Uri) + -> Result<(), DownloadError> + { + // Send get request with timeout. + let send_fut = self.client + . get(uri) + . send() + . map_err(|e| DownloadError::new(uri.clone(), Some(e.into())))?; + let mut response = timeout(self.timeout, send_fut).await + . map_err(|e| DownloadError::new(uri.clone(), Some(e.into())) )? + . map_err(|e| DownloadError::new(uri.clone(), Some(e)) )?; + + // This handling needs to be more elaborate... a distingtion needs to be made + // between temporary and permanent failures. + if ! response.status().is_success() { + return Err(DownloadError::new( + uri.clone() + , Some(anyhow::Error::msg("request unsuccessfull")) + )); + } + + // We always need the content-type to be able to decide + let content_type = response.headers()[CONTENT_TYPE].to_str() + . expect("No content-type header found in response"); + + if content_type != "video/MP2T" { + let message = format!("unexpected content-type: {}", content_type); + log!(Level::Debug, "{}", message); + return Err(DownloadError::new( uri.clone() + , Some(anyhow::Error::msg(message)) )); + } + + // I consider a missing path as fatal... there is absolutely nothing we can do about it + // and we need all files from the playlist. + let path_and_query = uri.path_and_query().expect("No path and query").as_str(); + let filename = Path::new(path_and_query) + . file_name() + . expect("no filename in path_and_query"); + let mut file = File::create(filename).await + . expect("can not create file for writing"); + + // read body into file as stream + let mut body_stream = BodyDataStream::new(response.body_mut()); + + 'label: loop { + let data = timeout(self.timeout, body_stream.next()).await + . map_err(|e| DownloadError::new(uri.clone(), Some(e.into())))?; + match data { + None => break 'label, + Some(Err(e)) => + return Err(DownloadError::new(uri.clone(), Some(e.into()))), + Some(Ok(data)) => { + file.write_all(data.as_ref()).await + . map_err(|e| + DownloadError::new(uri.clone(), Some(e.into())) )?; + }, + } + }; + + Ok(()) + } +} diff --git a/src/download_error.rs b/src/download_error.rs new file mode 100644 index 0000000..50977c0 --- /dev/null +++ b/src/download_error.rs @@ -0,0 +1,33 @@ +use std::{error, fmt}; + +use http::Uri; + +#[derive(Debug)] +pub(super) struct DownloadError { + pub(super) uri: Uri, + pub(super) source: Option, +} + +impl DownloadError { + pub(super) fn new(uri: Uri, source: Option) -> Self { + Self { uri, source } + } +} + +impl error::Error for DownloadError { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { + match &self.source { + None => None, + Some(e) => Some(e.as_ref()), + } + } +} + +impl fmt::Display for DownloadError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.source { + None => write!(f, "download error: {}", self.uri), + Some(err) => write!(f, "download error ({}): {}", self.uri, err), + } + } +} diff --git a/src/main.rs b/src/main.rs index d8a2ec5..fab7306 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,17 +1,12 @@ -use std::{path::Path, time::Duration}; +mod download_error; +mod client; + +use std::time::Duration; use anyhow::anyhow; use clap::Parser; use env_logger::Env; -use futures_util::StreamExt; -use http::{header::CONTENT_TYPE, uri::{Authority, Scheme}, Request, Response, Uri}; -use http_body_util::BodyDataStream; -use m3u8_rs::Playlist; -use reqwest::{redirect::Policy, Body}; -use tokio::{fs::File, io::AsyncWriteExt, task::{JoinError, JoinHandle}, time::timeout}; -use tower::{ServiceBuilder, ServiceExt as _}; -use tower_http_client::ServiceExt as _; -use tower_reqwest::HttpClientLayer; +use http::Uri; use log::{log, Level}; @@ -26,150 +21,15 @@ struct DownloadMessage { #[derive(Debug, Parser)] struct Args { #[arg(short, long)] - url: String -} - -type HttpClient = tower::util::BoxCloneService, Response, anyhow::Error>; - -#[derive(Clone, Debug)] -struct State { - scheme: Scheme, - auth: Authority, - base_path: String, - client: HttpClient, + url: String, + #[arg(short, long)] + timeout: Option, + #[arg(short, long)] + rate: Option, + #[arg(short, long)] + concurrency: Option, } -impl State { - async fn get_m3u8_segment_uris(&mut self, path_and_query: &str) -> anyhow::Result> { - let uri = Uri::builder() - . scheme(self.scheme.clone()) - . authority(self.auth.clone()) - . path_and_query(path_and_query) - . build()?; - let filename = Path::new(uri.path()).file_name().ok_or(anyhow!("name error"))?; - let mut file = File::create(filename).await?; - - let mut response = timeout(Duration::from_secs(10), self.client.get(uri).send()?).await??; - - anyhow::ensure!(response.status().is_success(), "response failed"); - - let mut body = vec![]; - let mut body_stream = BodyDataStream::new(response.body_mut()); - 'label: loop { - let data = timeout(Duration::from_secs(10), body_stream.next()).await?; - match data { - None => break 'label, - Some(Err(_)) => break 'label, - Some(Ok(data)) => { - body.append(&mut Vec::from(data.as_ref())); - }, - } - }; - - match m3u8_rs::parse_playlist(&body) { - Result::Err(e) => Err(anyhow!("m3u8 parse error: {}", e)), - Result::Ok((_, Playlist::MasterPlaylist(_))) => - Err(anyhow!("Master playlist not supported now")), - Result::Ok((_, Playlist::MediaPlaylist(pl))) => { - let segments: anyhow::Result> = pl.segments.iter().map(|s| { - Ok(match Uri::try_from(s.uri.clone()) { - Ok(uri) => { - let mut new_segment = s.clone(); - let filename = Path::new(uri.path()) - . file_name() - . ok_or(anyhow!("name error"))? - . to_str() - . ok_or(anyhow!("Error getting filename from uri"))?; - let query = uri.query() - . map(|q| "?".to_owned() + q) - . unwrap_or("".to_string()); - new_segment.uri = (filename.to_owned() + &query).to_string(); - new_segment } - - Err(_) => s.clone() - }) - }).collect(); - let mut out_pl = pl.clone(); - out_pl.segments = segments?; - let mut file_data: Vec = Vec::new(); - out_pl.write_to(&mut file_data)?; - file.write_all(&file_data).await?; - - let uris: anyhow::Result> = pl.segments.iter().map(|s| { - match Uri::try_from(s.uri.clone()) { - Ok(uri) => { - let scheme = uri.scheme() - . or(Some(&self.scheme)) - . ok_or(anyhow!("No scheme in Uri"))?; - let auth = uri.authority() - . or(Some(&self.auth)) - . ok_or(anyhow!("No authority in Uri"))?; - let path_and_query = uri.path_and_query() - . ok_or(anyhow!("No path in Uri"))?; - Ok(Uri::builder() . scheme(scheme.clone()) - . authority(auth.clone()) - . path_and_query(path_and_query.clone()) - . build()?) } - - Err(e) => { - log!(Level::Debug, "Uri parse error: {:?} - {}", e, s.uri); - Ok(Uri::builder() . scheme(self.scheme.clone()) - . authority(self.auth.clone()) - . path_and_query(self.base_path.clone() + "/" + &s.uri) - . build()?) } - } - }).collect(); - - uris - }, - } - } - - async fn get_m3u8_segment(&mut self, uri: &Uri) -> Result { - if let Ok(send_fut) = self.client.get(uri).send() { - if let Ok(mut response) = timeout(Duration::from_secs(10), send_fut).await.or(Err(uri.clone()))? { - if response.status().is_success() { - let content_type = match response.headers()[CONTENT_TYPE].to_str() { - Err(_) => { - log!(Level::Debug, "Error getting content-type"); - Err(uri.clone()) - }, - Ok(h) => Ok(h) - }?; - log!(Level::Debug, "CONTENT-TYPE: {}", content_type); - if content_type != "video/MP2T" { - log!(Level::Error, "{} is not video/MP2T", content_type); - return Err(uri.clone()); - } - - let path_and_query = uri.path_and_query().expect("No path and query").as_str(); - let filename = Path::new(path_and_query).file_name().ok_or(uri.clone())?; - let mut file = File::create(filename).await.or(Err(uri.clone()))?; - - let mut body_stream = BodyDataStream::new(response.body_mut()); - 'label: loop { - let data = timeout(Duration::from_secs(10), body_stream.next()).await.or(Err(uri.clone()))?; - match data { - None => break 'label, - Some(Err(_)) => return Err(uri.clone()), - Some(Ok(data)) => { - file.write_all(data.as_ref()).await.or(Err(uri.clone()))?; - }, - } - }; - - Ok(uri.clone()) - } else { - Err(uri.clone()) - } - } else { - Err(uri.clone()) - } - } else { - Err(uri.clone()) - } - } -} #[tokio::main] async fn main() -> anyhow::Result<()> { @@ -185,38 +45,19 @@ async fn main() -> anyhow::Result<()> { let args = Args::parse(); log!(Level::Debug, "-> Arguments: {:?}", args); - log!(Level::Info, "-> Creating an HTTP client with Tower layers..."); + let rate = args.rate.unwrap_or(10); + let concurrency = args.concurrency.unwrap_or(10); + let m3u8_uri = Uri::try_from(&args.url)?; - let m3u8_scheme = m3u8_uri.scheme() - . ok_or(anyhow!("Problem scheme in m3u8 uri"))?; - let m3u8_auth = m3u8_uri.authority() - . ok_or(anyhow!("Problem authority in m3u8 uri"))?; - let m3u8_base_path = Path::new(m3u8_uri.path()).parent() - . ok_or(anyhow!("Path problem"))? - . to_str() - . ok_or(anyhow!("Path problem"))?; let m3u8_path_and_query = m3u8_uri.path_and_query() . ok_or(anyhow!("Problem path and query in m3u8 uri"))?; - let mut state = State { - scheme: m3u8_scheme.clone(), - auth: m3u8_auth.clone(), - base_path: m3u8_base_path.to_string(), - client: ServiceBuilder::new() - // Add some layers. - . buffer(64) - . rate_limit(10, Duration::from_secs(1)) - . concurrency_limit(10) - // Make client compatible with the `tower-http` layers. - . layer(HttpClientLayer) - . service(reqwest::Client::builder() - . redirect(Policy::limited(5)) - . build()? ) - . map_err(anyhow::Error::msg) - . boxed_clone(), - }; - log!(Level::Debug, "-> state: {:?}", state); + let mut state = client::State::new(&m3u8_uri, rate, concurrency)?; + + if let Some(timeout) = args.timeout { + state.set_timeout(Duration::from_secs(timeout)); + } // I think about a worker pool... probably of concurrency_limit amount. // The worker needs to get the url. Our first target is to store the @@ -230,7 +71,7 @@ async fn main() -> anyhow::Result<()> { log!(Level::Info, "-> Sending concurrent requests..."); 'working: while ! segments.is_empty() { - let mut tasks: Vec>> = vec![]; + let mut tasks = vec![]; while let Some(segment) = segments.pop() { let state = state.clone(); tasks.push(tokio::spawn(async move { @@ -239,7 +80,7 @@ async fn main() -> anyhow::Result<()> { })); } - let results: Vec, JoinError>> = futures_util::future::join_all(tasks).await; + let results = futures_util::future::join_all(tasks).await; for result in &results { match result { Err(e) => { @@ -249,7 +90,7 @@ async fn main() -> anyhow::Result<()> { Ok(Err(e)) => { log!(Level::Info, "Retry failed download: {}", e); - segments.push(e.clone()); + segments.push(e.uri.clone()); }, _ => (),