diff --git a/src/client.rs b/src/client.rs index f725702..fb90c16 100644 --- a/src/client.rs +++ b/src/client.rs @@ -4,30 +4,31 @@ mod util; mod error; mod message; -use std::{collections::HashMap, path::Path, time::Duration}; +use std::{collections::HashMap, path::Path, str::FromStr, time::Duration}; use clap::{crate_name, crate_version}; -use error::ClientError; +use download::{file_size, open_outfile, store_body}; +use error::{ClientError, RequestError}; use http::{ - header::{ORIGIN, USER_AGENT}, - HeaderMap, - HeaderValue, - Request, - Response, - Uri, + header::{CONTENT_LENGTH, CONTENT_TYPE, ORIGIN, RANGE, USER_AGENT}, request::Builder as RequestBuilder, HeaderMap, HeaderName, HeaderValue, Request, Response, StatusCode, Uri }; +use http_body_util::BodyDataStream; use message::{ClientActorMessage, ClientActorMessageHandle}; use reqwest::{redirect::Policy, Body}; use tokio::{ + fs::File, select, sync::{mpsc, oneshot}, - task::JoinSet, + task::JoinSet, time::timeout, }; use tower::{util::BoxCloneService, ServiceBuilder, ServiceExt as _}; use tower_http::decompression::{DecompressionBody, DecompressionLayer}; use tower_reqwest::HttpClientLayer; use log::{debug, error, info}; +use util::head; + +use crate::map_dlerror; #[derive(Clone, Debug)] @@ -45,6 +46,12 @@ type JoinSetResult = Result, ClientError>; type HttpClient = BoxCloneService, Response>, anyhow::Error>; +// =========== + + +// =========== + + #[derive(Debug)] struct ClientActor { body_timeout: Option, diff --git a/src/client/download.rs b/src/client/download.rs index 9006898..0c7a935 100644 --- a/src/client/download.rs +++ b/src/client/download.rs @@ -69,8 +69,9 @@ pub(super) async fn download( mut client: HttpClient Ok(Some(message)) } -async fn file_size(filename: &Path) -> u64 { +pub(crate) async fn file_size(filename: impl AsRef) -> u64 { // - get all informations to eventually existing file + let filename = filename.as_ref(); let metadata = match symlink_metadata(filename).await { Ok(metadata) => Some(metadata), Err(error) => match error.kind() { @@ -83,7 +84,8 @@ async fn file_size(filename: &Path) -> u64 { metadata.map_or(0, |m| m.len()) } -async fn open_outfile(status: &StatusCode, filename: &Path) -> File { +pub(crate) async fn open_outfile(status: &StatusCode, filename: impl AsRef) -> File { + let filename = filename.as_ref(); match status { &StatusCode::PARTIAL_CONTENT => // Here we assume that this response only comes if the requested @@ -102,9 +104,9 @@ async fn open_outfile(status: &StatusCode, filename: &Path) -> File { } } -async fn store_body( file: &mut File - , body: &mut DecompressionBody - , io_timeout: Option ) -> anyhow::Result<()> { +pub(crate) async fn store_body( file: &mut File + , body: &mut DecompressionBody + , io_timeout: Option ) -> anyhow::Result<()> { let mut body = BodyDataStream::new(body); loop { @@ -118,11 +120,7 @@ async fn store_body( file: &mut File match data { None => break, - - Some(Err(e)) => { - return Err(anyhow!(e)); - } - + Some(Err(e)) => Err(anyhow!(e))?, Some(Ok(data)) => { file . write_all(data.as_ref()).await?; file . flush().await?; diff --git a/src/client_actor.rs b/src/client_actor.rs new file mode 100644 index 0000000..346318f --- /dev/null +++ b/src/client_actor.rs @@ -0,0 +1,150 @@ +mod message; +mod error; +mod util; + +use std::{collections::HashMap, path::Path}; + +use error::ClientActorError; +use http::{HeaderMap, Uri}; +use message::{ClientActorMessage, ClientActorMessageHandle, DownloadState}; +use tokio::{sync::{mpsc, oneshot}, task::JoinSet}; + +use super::client_new::Client; + + +type ActionIndex = u64; +type ClientTaskResult = Result, ClientActorError>; +type DownloadResult = Result, ClientActorError>; + + +#[derive(Debug)] +struct ClientActor { + client: Client, + tasks: JoinSet, + max_tasks: usize, + actions: HashMap, + actions_idx: ActionIndex, + + receiver: mpsc::Receiver, +} + +pub(super) struct ClientActorHandle { + sender: mpsc::Sender, + abort: oneshot::Sender, +} + + +impl ClientActor { + fn new( client: Client + , max_tasks: usize + , receiver: mpsc::Receiver + , abort_rx: oneshot::Receiver ) -> anyhow::Result { + let mut tasks = JoinSet::new(); + + tasks.spawn(async move { + let _ = abort_rx.await; + Ok(None) + }); + + let actions = HashMap::new(); + let actions_idx = 0; + + Ok(Self { + client, + tasks, + max_tasks, + actions, + actions_idx, + receiver }) + } + + async fn handle_message(&mut self, message: ClientActorMessage) { + self.actions.insert(self.actions_idx, message); + + use ClientActorMessage::{Download, GetData}; + + match self.actions.get(&self.actions_idx) { + Some(Download { ref filename, ref uri, .. }) => { + // spawn a task that does the work + let mut client = self.client.clone(); + let filename = filename.to_path_buf(); + let uri = uri.clone(); + let message = self.actions_idx; + let handle = ClientActorMessageHandle::Download { + filename, + uri, + state: None, + message, + }; + + self.tasks.spawn(async move { + client.download(handle.filename(), &handle.uri(), &HeaderMap::new()).await + . map_err(|source| ClientActorError { action: handle.clone() + , source })?; + Ok(Some(handle)) + }); + }, + + Some(GetData { ref uri, .. }) => { + // spawn a task that does the work + let mut client = self.client.clone(); + let uri = uri.clone(); + let handle = ClientActorMessageHandle::GetData { + uri, + buffer: None, + message: self.actions_idx, + }; + + self.tasks.spawn(async move { + client.data(&handle.uri(), &HeaderMap::new()).await + . map_err(|source| ClientActorError { action: handle.clone() + , source })?; + Ok(Some(handle)) + }); + }, + + None => (), + } + + self.actions_idx += 1; + } +} + +impl ClientActorHandle { + pub(super) fn new(client: Client, max_tasks: usize) -> Self { + let (sender, receiver) = mpsc::channel(1); + let (abort, abort_rx) = oneshot::channel::(); + let actor = ClientActor::new(client, max_tasks, receiver, abort_rx ) + . expect("Client create error"); + + tokio::spawn(util::run_client(actor)); + + Self { sender, abort } + } + + pub(super) fn stop(self) { + let _ = self.abort.send(Ok(None)); + drop(self.sender); + } + + pub(super) async fn download( &self + , filename: impl AsRef + , uri: &Uri ) -> DownloadResult { + let filename = filename.as_ref().to_path_buf(); + let uri = uri.to_owned(); + let (send, receive) = oneshot::channel(); + let msg = ClientActorMessage::Download { filename, uri, respond_to: send }; + + let _ = self.sender.send(msg).await; + receive.await.expect("Actor cancelled unexpected") + } + + pub(super) async fn body_bytes(&self, uri: &Uri) -> Option> { + let uri = uri.to_owned(); + let (send, receive) = oneshot::channel(); + let msg = ClientActorMessage::GetData { uri, respond_to: send }; + + let _ = self.sender.send(msg).await; + receive.await.expect("Actor cancelled unexpected") + } +} diff --git a/src/client_actor/error.rs b/src/client_actor/error.rs new file mode 100644 index 0000000..6b8c840 --- /dev/null +++ b/src/client_actor/error.rs @@ -0,0 +1,51 @@ +use std::{error, fmt}; + +use super::message::ClientActorMessageHandle; + + +/* +#[macro_export] +macro_rules! mk_ca_error { + ($message:ident, $($err:tt)*) => {{ + use $crate::client::error; + error::ClientActorError::new( $message.clone() + , anyhow::anyhow!($($err)*) ) + }}; +} + +#[macro_export] +macro_rules! map_ca_error { + ($message:ident) => {{ + use $crate::client::error; + |e| error::ClientActorError::new( $message.clone() + , anyhow::anyhow!(format!("{:?}", e)) ) + }}; +} +*/ + + +#[derive(Debug)] +pub(super) struct ClientActorError { + pub(super) action: ClientActorMessageHandle, + pub(super) source: anyhow::Error, +} + +impl ClientActorError { + pub(super) fn new( action: ClientActorMessageHandle + , source: anyhow::Error ) -> Self { + let action = action.to_owned(); + Self { action, source } + } +} + +impl error::Error for ClientActorError { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { + Some(self.source.as_ref()) + } +} + +impl fmt::Display for ClientActorError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "download error ({:?}): {}", self.action, self.source) + } +} diff --git a/src/client_actor/message.rs b/src/client_actor/message.rs new file mode 100644 index 0000000..29a51ec --- /dev/null +++ b/src/client_actor/message.rs @@ -0,0 +1,78 @@ +use std::path::PathBuf; + +use http::Uri; +use tokio::sync::oneshot; + +use super::error::ClientActorError; + + +type ActionIndex = u64; +type DownloadResult = Result, ClientActorError>; + + +#[derive(Clone, Debug)] +pub enum DownloadState { + GotHead, + #[allow(dead_code)] + Partial { content_type: Option }, + Done { content_type: Option }, +} + +#[derive(Debug)] +pub(super) enum ClientActorMessage { + Download { + filename: PathBuf, + uri: Uri, + respond_to: oneshot::Sender, + }, + GetData { + uri: Uri, + respond_to: oneshot::Sender>>, + }, +} + +#[derive(Clone, Debug)] +pub(super) enum ClientActorMessageHandle { + Download { + filename: PathBuf, + uri: Uri, + state: Option, + message: ActionIndex, + }, + GetData { + uri: Uri, + buffer: Option>, + message: ActionIndex, + }, +} + + +impl ClientActorMessageHandle { + pub(super) fn set_state(&mut self, new_state: DownloadState) { + match self { + Self::Download { ref mut state, .. } => *state = Some(new_state), + _ => panic!("Called with invalid variant"), + }; + } + + pub(super) fn filename(&self) -> PathBuf { + match self { + Self::Download { ref filename, .. } => filename.clone(), + _ => panic!("Called with invalid variant"), + } + } + + pub(super) fn uri(&self) -> Uri { + match self { + Self::Download { ref uri, .. } => uri.clone(), + Self::GetData { ref uri, .. } => uri.clone(), + } + } + + pub(super) fn buffer_mut(&mut self) -> &mut Option> { + match self { + Self::GetData { ref mut buffer, .. } => buffer, + _ => panic!("Called with invalid variant"), + } + } +} diff --git a/src/client_actor/util.rs b/src/client_actor/util.rs new file mode 100644 index 0000000..facbf02 --- /dev/null +++ b/src/client_actor/util.rs @@ -0,0 +1,118 @@ +use http::HeaderMap; +use log::{error, info}; +use tokio::select; + +use super::{ + error::ClientActorError, + message::{ClientActorMessage, ClientActorMessageHandle}, + ClientActor, + ClientTaskResult, +}; + + +async fn process_next_result(mut actor: ClientActor, result: ClientTaskResult) -> ClientActor { + use ClientActorMessageHandle::{Download, GetData}; + + match result { + Err(e) => { + info!("Retry failed download: {:?}", e); + // retry ... instead of responing here we could also respond + // with something that in turn would be used to retry... + let mut client = actor.client.clone(); + actor.tasks.spawn(async move { + match e.action { + Download { .. } => { + client.download( e.action.filename() + , &e.action.uri() + , &HeaderMap::new() ).await + . map_err(|source| ClientActorError { action: e.action.clone() + , source })?; + Ok(Some(e.action)) + }, + GetData { .. } => { + client.data(&e.action.uri(), &HeaderMap::new()).await + . map_err(|source| ClientActorError { action: e.action.clone() + , source })?; + Ok(Some(e.action)) + }, + } + }); + }, + + // when the task finishes + Ok(Some(action)) => { + match action { + Download { ref uri, ref state, ref message, .. } => { + info!("Done download: {:?}", uri); + if let Some((_, message)) = actor.actions.remove_entry(message) { + use ClientActorMessage::Download; + + match message { + Download { respond_to, .. } => { + let _ = respond_to.send(Ok(state.clone())); + }, + _ => panic!("Wrong variant ... this should never happen"), + } + } else { + panic!("Lost a message"); + } + }, + + GetData { ref uri, ref buffer, ref message } => { + info!("Done get_data: {:?}", uri); + if let Some((_, message)) = actor.actions.remove_entry(message) { + use ClientActorMessage::GetData; + match message { + GetData { respond_to, .. } => { + let _ = respond_to.send(buffer.clone()); + }, + _ => panic!("Wrong variant ... this should never happen"), + } + } else { + panic!("Lost a message"); + } + }, + } + }, + + // Got a stop message...here we still continue procession until the + // JoinSet is empty. + Ok(None) => (), + }; + + actor +} + +pub(super) async fn run_client(mut actor: ClientActor) { + loop { + if actor.tasks.len() >= actor.max_tasks { + if let Some(join) = actor.tasks.join_next().await { + match join { + Err(e) => { + error!("FATAL Join failed: {}", e); + break + }, + Ok(result) => actor = process_next_result(actor, result).await, + } + }; + } else { + select! { + Some(join) = actor.tasks.join_next() => { + match join { + Err(e) => { + error!("FATAL Join failed: {}", e); + break + }, + Ok(result) => actor = process_next_result(actor, result).await, + } + } + + Some(message) = actor.receiver.recv() => { + actor.handle_message(message).await; + } + + else => {} + } + } + } +} diff --git a/src/client_new.rs b/src/client_new.rs new file mode 100644 index 0000000..db7b208 --- /dev/null +++ b/src/client_new.rs @@ -0,0 +1,220 @@ +mod error; +mod util; + +use std::{path::Path, time::Duration}; + +use anyhow::anyhow; +use clap::{crate_name, crate_version}; +use error::RequestError; +use futures_util::StreamExt as _; +use http::{ + header::{CONTENT_LENGTH, CONTENT_TYPE, ORIGIN, RANGE, USER_AGENT}, request::Builder as RequestBuilder, HeaderMap, HeaderValue, Request, Response, Uri +}; +use http_body_util::BodyDataStream; +use log::debug; +use reqwest::{redirect::Policy, Body}; +use tokio::{ + fs::File, + io::AsyncWriteExt as _, + time::timeout +}; +use tower::{util::BoxCloneService, ServiceBuilder, ServiceExt as _}; +use tower_http::decompression::{DecompressionBody, DecompressionLayer}; +use tower_http_client::{client::BodyReader, ServiceExt as _}; +use tower_reqwest::HttpClientLayer; + + +type ClientBody = DecompressionBody; +type ClientResponse = Response; +type ClientResponseResult = Result; +type HttpClient = BoxCloneService, ClientResponse, anyhow::Error>; + + +#[derive(Clone, Debug)] +pub(super) struct Client { + client: HttpClient, + default_headers: HeaderMap, + body_timeout: Option, +} + +impl Client { + pub(super) fn new( buffer: usize + , rate_limit: u64 + , concurrency_limit: usize + , timeout: Duration ) -> anyhow::Result { + let client = ServiceBuilder::new() + // Add some layers. + . buffer(buffer) + . rate_limit(rate_limit, Duration::from_secs(1)) + . concurrency_limit(concurrency_limit) + . timeout(timeout) + . layer(DecompressionLayer::new()) + // Make client compatible with the `tower-http` layers. + . layer(HttpClientLayer) + . service( reqwest::Client::builder() + . redirect(Policy::limited(5)) + . build()? ) + . map_err(anyhow::Error::msg) + . boxed_clone(); + + let body_timeout = None; + let mut default_headers = HeaderMap::new(); + default_headers.insert( + USER_AGENT, + HeaderValue::from_str(&( crate_name!().to_string() + "/" + + crate_version!() )).unwrap() ); + + Ok(Self {client, default_headers, body_timeout}) + } + + fn set_body_timeout(mut self, timeout: Option) -> Self { + self.body_timeout = timeout; + self + } + + fn set_origin(mut self, origin: Option) -> Self { + if let Some(origin) = origin { + self.default_headers.insert( + ORIGIN, + HeaderValue::from_str(&origin).unwrap() ); + } else { + self.default_headers.remove(ORIGIN); + } + self + } + + fn set_user_agent(mut self, user_agent: Option) -> Self { + if let Some(user_agent) = user_agent { + self.default_headers.insert( + USER_AGENT, + HeaderValue::from_str(&user_agent).unwrap() ); + } else { + self.default_headers.remove(USER_AGENT); + } + self + } + + pub(super) async fn data( &mut self + , uri: &Uri + , headers: &HeaderMap ) + -> anyhow::Result> { + + let mut response = self.request("GET", uri, headers).await?; + + // read body into Vec + let body = BodyReader::new(response.body_mut()) + . bytes() + . await + . map_err(|e| anyhow!(e))? + . to_vec(); + + Ok(body) + } + + pub(super) async fn download( &mut self + , filename: impl AsRef + , uri: &Uri + , headers: &HeaderMap ) + -> anyhow::Result> { + + let filename = filename.as_ref(); + + // - get all informations to eventually existing file + let mut from = util::file_size(&filename).await; + + // - get infos to uri + let response_headers = &self.head(uri, headers).await?; + let content_length = util::get_header::( response_headers + , CONTENT_LENGTH ); + let content_type = util::get_header::( response_headers + , CONTENT_TYPE ) + . or(Some("unknown".into())); + + if let Some(content_length) = content_length { + if from != 0 && content_length - 1 <= from { + return Ok(None); + } + } else { + from = 0; + } + + // - do the neccessry request. + let headers = &mut headers.clone(); + headers.insert(RANGE, format!("bytes={}-", from).parse().unwrap()); + + let mut response = self.request("GET", uri, headers).await?; + + // - open or create file + let file = util::open_or_create(&response.status(), &filename).await; + // - download Data + let size = self.clone().store_body(file, response.body_mut()).await?; + + Ok(content_type.map(|c| (c, size))) + } + + async fn head( &mut self + , uri: &Uri + , headers: &HeaderMap ) -> Result { + Ok( self.request("HEAD", uri, headers) + . await? + . headers() + . clone() ) + } + + async fn request( &mut self + , method: &str + , uri: &Uri + , headers: &HeaderMap ) -> ClientResponseResult { + let mut request = RequestBuilder::new() + . method(method) + . uri(uri) + . body(Body::default()) + . map_err(|e| RequestError::new(None, Some(e.into())))?; + + request.headers_mut().extend(headers.clone()); + + debug!("Request: {:?}", request); + + match self.client.execute(request).await { + Err(e) => Err(RequestError::new(None, Some(e))), + Ok(response) => { + debug!("Response: {:?}", response.headers()); + + if response.status().is_success() { + Ok(response) + } else { + Err(RequestError::new(Some(response.map(|_| ())), None)) + } + }, + } + } + + async fn store_body( self + , mut file: File + , body: &mut ClientBody ) -> anyhow::Result { + let mut body = BodyDataStream::new(body); + let mut written = 0; + + loop { + let data_future = body.next(); + let data = if let Some(io_timeout) = self.body_timeout { + // give timeout somehow... probably from client. + timeout(io_timeout, data_future).await? + } else { + data_future.await + }; + + match data { + None => break, + Some(Err(e)) => Err(anyhow!(e))?, + Some(Ok(data)) => { + written += data.len(); + file . write_all(&data).await?; + file . flush().await?; + }, + } + }; + + Ok(written) + } +} diff --git a/src/client_new/error.rs b/src/client_new/error.rs new file mode 100644 index 0000000..71e4eca --- /dev/null +++ b/src/client_new/error.rs @@ -0,0 +1,35 @@ +use std::{error, fmt}; + +use http::Response; + + +#[derive(Debug)] +pub(super) struct RequestError { + pub(super) response: Option>, + pub(super) source: Option, +} + +impl RequestError { + pub(super) fn new( response: Option> + , source: Option ) -> Self { + Self { response, source } + } +} + +impl error::Error for RequestError { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { + match &self.source { + None => None, + Some(e) => Some(e.as_ref()), + } + } +} + +impl fmt::Display for RequestError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.source { + None => write!(f, "request error: {:?}", self.response), + Some(err) => write!(f, "request error ({:?}): {}", self.response, err), + } + } +} diff --git a/src/client_new/util.rs b/src/client_new/util.rs new file mode 100644 index 0000000..be818ce --- /dev/null +++ b/src/client_new/util.rs @@ -0,0 +1,47 @@ +use std::{io::ErrorKind, path::Path, str::FromStr}; + +use http::{HeaderMap, HeaderName, StatusCode}; +use tokio::fs::{symlink_metadata, File}; + + +pub(super) async fn file_size(filename: impl AsRef) -> u64 { + // - get all informations to eventually existing file + let filename = filename.as_ref(); + let metadata = match symlink_metadata(filename).await { + Ok(metadata) => Some(metadata), + Err(error) => match error.kind() { + // If we can't write to a file we need to ... well theres nothing we can do + ErrorKind::PermissionDenied => panic!("Permission denied on: {:?}", filename), + _ => None, + } + }; + + metadata.map_or(0, |m| m.len()) +} + +pub(super) fn get_header( headers: &HeaderMap + , name: HeaderName) -> Option { + headers . get(name) + . and_then(|h| h.to_str().ok()) + . and_then(|h| h.parse::().ok()) +} + +pub(super) async fn open_or_create( status: &StatusCode + , filename: impl AsRef) -> File { + match status { + &StatusCode::PARTIAL_CONTENT => + // Here we assume that this response only comes if the requested + // range can be fullfilled and thus is the data range in the + // response. Thats why I do not check the content-range header. + // If that assumption does not hold this needs to be fixed. + File::options() . create(true) + . append(true) + . open(filename) + . await + . expect("can not create file for writing"), + + _ => + File::create(filename) . await + . expect("can not create file for writing"), + } +} diff --git a/src/main.rs b/src/main.rs index e4f97b9..7279d88 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,8 @@ mod client; mod process; mod m3u8_download; +mod client_new; +mod client_actor; use std::{ ffi::OsStr,