From d008540a8b4cde1e7a335238cf74c08f98182384 Mon Sep 17 00:00:00 2001 From: Georg Hopp Date: Tue, 14 Jan 2025 13:13:40 +0100 Subject: [PATCH] improved Errors for Client and ClientActorHandle --- Cargo.lock | 1 + Cargo.toml | 1 + src/client.rs | 188 ++++++++++++++++------------- src/client/error.rs | 65 +--------- src/client/error/download_error.rs | 83 +++++++++++++ src/client/error/request_error.rs | 88 ++++++++++++++ src/client_actor.rs | 9 +- src/client_actor/error.rs | 30 +++-- src/client_actor/message.rs | 7 +- src/client_actor/util.rs | 8 +- src/m3u8_download.rs | 3 +- src/main.rs | 4 +- 12 files changed, 324 insertions(+), 163 deletions(-) create mode 100644 src/client/error/download_error.rs create mode 100644 src/client/error/request_error.rs diff --git a/Cargo.lock b/Cargo.lock index bf43a20..27fed9b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -512,6 +512,7 @@ name = "hlsclient" version = "0.2.3" dependencies = [ "anyhow", + "bytes", "clap", "env_logger", "futures-util", diff --git a/Cargo.toml b/Cargo.toml index c706927..3ceee2f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" [dependencies] anyhow = "1.0" +bytes = "1.9" clap = { version = "4.5", features = [ "derive", "cargo" ] } env_logger = "0.11" futures-util = "0.3" diff --git a/src/client.rs b/src/client.rs index 33c7de8..2a0e06c 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,29 +1,25 @@ -mod error; +pub(super) mod error; mod util; use std::{path::Path, time::Duration}; -use anyhow::anyhow; +use bytes::Bytes; use clap::{crate_name, crate_version}; -use error::{DownloadError, RequestError}; use futures_util::StreamExt as _; use http::{ - header::{CONTENT_LENGTH, CONTENT_TYPE, ORIGIN, RANGE, USER_AGENT}, - request::Builder as RequestBuilder, + header, + request, HeaderMap, HeaderValue, + Method, Request, Response, - Uri + Uri, }; use http_body_util::BodyDataStream; use log::debug; use reqwest::{redirect::Policy, Body}; -use tokio::{ - fs::File, - io::AsyncWriteExt as _, - time::timeout -}; +use tokio::{fs::File, io::AsyncWriteExt as _, time::timeout}; use tower::{util::BoxCloneService, ServiceBuilder, ServiceExt as _}; use tower_http::decompression::{DecompressionBody, DecompressionLayer}; use tower_http_client::{client::BodyReader, ServiceExt as _}; @@ -32,8 +28,12 @@ use tower_reqwest::HttpClientLayer; type ClientBody = DecompressionBody; type ClientResponse = Response; -type ClientResponseResult = Result; type HttpClient = BoxCloneService, ClientResponse, anyhow::Error>; +pub(super) type DataResult = error::DownloadResult; +pub(super) type DownloadResult = error::DownloadResult; +type HeadResult = error::ClientRequestResult; +type InvalidHeaderResult = Result; +type RequestResult = error::ClientRequestResult; #[derive(Clone, Debug)] @@ -57,7 +57,7 @@ impl Client { pub(super) fn new( buffer: usize , rate_limit: u64 , concurrency_limit: usize - , timeout: Duration ) -> anyhow::Result { + , timeout: Duration ) -> Result { let client = ServiceBuilder::new() // Add some layers. . buffer(buffer) @@ -76,7 +76,7 @@ impl Client { let body_timeout = None; let mut default_headers = HeaderMap::new(); default_headers.insert( - USER_AGENT, + header::USER_AGENT, HeaderValue::from_str(&( crate_name!().to_string() + "/" + crate_version!() )).unwrap() ); @@ -88,60 +88,51 @@ impl Client { self } - pub(super) fn set_origin(mut self, origin: Option) -> Self { - if let Some(origin) = origin { - self.default_headers.insert( - ORIGIN, - HeaderValue::from_str(&origin).unwrap() ); - } else { - self.default_headers.remove(ORIGIN); - } - self + pub(super) fn set_origin( mut self, origin: Option) + -> InvalidHeaderResult { + + match origin.as_deref() { + Some(origin) => + self.default_headers.insert( header::ORIGIN + , HeaderValue::from_str(origin)? ), + None => self.default_headers.remove(header::ORIGIN), + }; + Ok(self) } - pub(super) fn set_user_agent(mut self, user_agent: Option) -> Self { - if let Some(user_agent) = user_agent { - self.default_headers.insert( - USER_AGENT, - HeaderValue::from_str(&user_agent).unwrap() ); - } else { - self.default_headers.remove(USER_AGENT); - } - self + pub(super) fn set_user_agent(mut self, user_agent: Option) + -> InvalidHeaderResult { + + match user_agent.as_deref() { + Some(user_agent) => + self.default_headers.insert( + header::USER_AGENT, + HeaderValue::from_str(user_agent)? ), + None => self.default_headers.remove(header::USER_AGENT), + }; + Ok(self) } pub(super) async fn data( &mut self , uri: &Uri - , headers: &HeaderMap ) - -> anyhow::Result> { - - let mut response = self.request("GET", uri, headers).await?; - - // read body into Vec - let body = BodyReader::new(response.body_mut()) - . bytes() - . await - . map_err(|e| anyhow!(e))? - . to_vec(); - - Ok(body) + , headers: &HeaderMap ) -> DataResult { + let mut response = self.request(Method::GET, uri, headers).await?; + Ok(BodyReader::new(response.body_mut()).bytes().await?) } pub(super) async fn download( &mut self , filename: impl AsRef , uri: &Uri - , headers: &HeaderMap ) - -> anyhow::Result { - + , headers: &HeaderMap ) -> DownloadResult { // - get all informations to eventually existing file let mut from = util::file_size(&filename).await; // - get infos to uri let response_headers = &self.head(uri, headers).await?; let content_length = util::get_header::( response_headers - , CONTENT_LENGTH ); + , header::CONTENT_LENGTH ); let content_type = util::get_header::( response_headers - , CONTENT_TYPE ) + , header::CONTENT_TYPE ) . or(Some("unknown".into())); if let Some(content_length) = content_length { @@ -154,51 +145,68 @@ impl Client { // - do the neccessry request. let headers = &mut headers.clone(); - headers.insert(RANGE, format!("bytes={}-", from).parse().unwrap()); + headers.insert( header::RANGE + , format!("bytes={}-", from).parse().unwrap() ); - let mut response = self.request("GET", uri, headers).await?; + let mut response = self.request(Method::GET, uri, headers).await?; // - open or create file let file = util::open_or_create(&response.status(), &filename).await; // - download Data - Ok( self.clone().store_body( file - , from as usize - , content_type - , response.body_mut() ).await? ) + self.clone().store_body( file + , from as usize + , content_type + , response.body_mut() ).await } async fn head( &mut self , uri: &Uri - , headers: &HeaderMap ) -> Result { - Ok( self.request("HEAD", uri, headers) + , headers: &HeaderMap ) -> HeadResult { + Ok( self.request(Method::HEAD, uri, headers) . await? . headers() . clone() ) } async fn request( &mut self - , method: &str + , method: Method , uri: &Uri - , headers: &HeaderMap ) -> ClientResponseResult { - let mut request = RequestBuilder::new() - . method(method) - . uri(uri) - . body(Body::default()) - . map_err(|e| RequestError::new(None, Some(e.into())))?; - - request.headers_mut().extend(headers.clone()); - - debug!("Request: {:?}", request); - + , headers: &HeaderMap ) -> RequestResult { + let (mut request_parts, _) = request::Builder::new() + . method(method) + . uri(uri) + . body(())? + . into_parts(); + request_parts.headers = headers.to_owned(); + debug!("Request: {:?}", request_parts); + + let request = Request::from_parts( request_parts.clone() + , Body::default() ); match self.client.execute(request).await { - Err(e) => Err(RequestError::new(None, Some(e))), + Err(e) => { + let request_parts = Some(request_parts.clone()); + let e = Some(e); + Err(error::ClientRequestError::new( request_parts + , None + , None + , e ))? + }, + Ok(response) => { - debug!("Response: {:?}", response.headers()); + let (response_parts, response_body) = response.into_parts(); + debug!("Response: {:?}", response_parts); - if response.status().is_success() { - Ok(response) + if response_parts.status.is_success() { + Ok(Response::from_parts(response_parts, response_body)) } else { - Err(RequestError::new(Some(response.map(|_| ())), None)) + let request = Some(request_parts); + let response = Some(response_parts); + let response_body = + Some(BodyReader::new(response_body).bytes().await?); + Err(error::ClientRequestError::new( request + , response + , response_body + , None ))? } }, } @@ -208,30 +216,40 @@ impl Client { , mut file: File , mut size: usize , content_type: Option - , body: &mut ClientBody ) -> Result { + , body: &mut ClientBody ) -> DownloadResult { let mut body = BodyDataStream::new(body); - let mut state = DownloadState::Partial { content_type: content_type.clone(), size }; + let mut state = DownloadState::Partial { + content_type: content_type.clone(), + size + }; loop { let data_future = body.next(); let data = if let Some(io_timeout) = self.body_timeout { // give timeout somehow... probably from client. - timeout(io_timeout, data_future).await - . map_err(|e| DownloadError::new(state.clone(), e.into()))? + timeout(io_timeout, data_future).await.map_err(|e| { + error::DownloadError::from(e).set_state(&state) + })? } else { data_future.await }; match data { None => break, - Some(Err(e)) => Err(DownloadError::new(state.clone(), anyhow!(e)))?, + Some(Err(e)) => + Err(error::DownloadError::from(e).set_state(&state))?, Some(Ok(data)) => { size += data.len(); - state = DownloadState::Partial { content_type: content_type.clone(), size }; - file . write_all(&data).await - . map_err(|e| DownloadError::new(state.clone(), e.into()))?; - file . flush().await - . map_err(|e| DownloadError::new(state.clone(), e.into()))?; + state = DownloadState::Partial { + content_type: content_type.clone(), + size + }; + file.write_all(&data).await.map_err(|e| { + error::DownloadError::from(e).set_state(&state) + })?; + file.flush().await.map_err(|e| { + error::DownloadError::from(e).set_state(&state) + })?; }, } }; diff --git a/src/client/error.rs b/src/client/error.rs index b3d8e34..547c0e7 100644 --- a/src/client/error.rs +++ b/src/client/error.rs @@ -1,63 +1,10 @@ -use std::{error, fmt}; +mod request_error; +mod download_error; -use http::Response; -use super::DownloadState; +pub(crate) use request_error::ClientRequestError; +pub(crate) use download_error::DownloadError; -#[derive(Debug)] -pub(super) struct DownloadError { - pub(super) state: DownloadState, - pub(super) source: anyhow::Error, -} - -#[derive(Debug)] -pub(super) struct RequestError { - pub(super) response: Option>, - pub(super) source: Option, -} - - -impl DownloadError { - pub(super) fn new( state: DownloadState - , source: anyhow::Error ) -> Self { - Self { state, source } - } -} - -impl error::Error for DownloadError { - fn source(&self) -> Option<&(dyn error::Error + 'static)> { - Some(self.source.as_ref()) - } -} - -impl fmt::Display for DownloadError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "download error ({:?}): {}", self.state, self.source) - } -} - -impl RequestError { - pub(super) fn new( response: Option> - , source: Option ) -> Self { - Self { response, source } - } -} - -impl error::Error for RequestError { - fn source(&self) -> Option<&(dyn error::Error + 'static)> { - match &self.source { - None => None, - Some(e) => Some(e.as_ref()), - } - } -} - -impl fmt::Display for RequestError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match &self.source { - None => write!(f, "request error: {:?}", self.response), - Some(err) => write!(f, "request error ({:?}): {}", self.response, err), - } - } -} +pub(super) type ClientRequestResult = Result; +pub(super) type DownloadResult = Result; diff --git a/src/client/error/download_error.rs b/src/client/error/download_error.rs new file mode 100644 index 0000000..a798dbe --- /dev/null +++ b/src/client/error/download_error.rs @@ -0,0 +1,83 @@ +use std::{error, fmt, io}; + +use tokio::time::error::Elapsed; + +use crate::client::DownloadState; + +use super::ClientRequestError; + + +#[derive(Debug)] +pub(crate) enum DownloadErrorSource { + #[allow(dead_code)] + Request(ClientRequestError), + Timeout(Elapsed), + #[allow(dead_code)] + IoError(io::Error), + #[allow(dead_code)] + Other(anyhow::Error), +} + +#[derive(Debug)] +pub(crate) struct DownloadError { + state: DownloadState, + source: Option, +} + + +impl DownloadError { + pub(crate) fn new( state: DownloadState + , source: Option ) + -> Self { + + Self { state, source } + } + + pub(crate) fn set_state(mut self, state: &DownloadState) -> Self { + self.state = state.to_owned(); + self + } +} + +impl error::Error for DownloadError {} + +impl fmt::Display for DownloadError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( f + , "download error: {{ state: {:?} / source: {:?} }}" + , self.state + , self.source ) + } +} + +impl From for DownloadErrorSource { + fn from(value: ClientRequestError) -> Self { + Self::Request(value) + } +} + +impl From for DownloadErrorSource { + fn from(value: Elapsed) -> Self { + Self::Timeout(value) + } +} + +impl From for DownloadErrorSource { + fn from(value: io::Error) -> Self { + Self::IoError(value) + } +} + +impl From> for DownloadErrorSource { + fn from(value: Box) -> Self { + Self::Other(anyhow::anyhow!(value)) + } +} + +impl> From for DownloadError { + fn from(value: T) -> Self { + let state = DownloadState::None; + let source = Some(value.into()); + Self::new(state, source) + } +} diff --git a/src/client/error/request_error.rs b/src/client/error/request_error.rs new file mode 100644 index 0000000..ae7e89c --- /dev/null +++ b/src/client/error/request_error.rs @@ -0,0 +1,88 @@ +use std::{error, fmt}; + +use bytes::Bytes; +use http::{request, response}; + + +#[derive(Debug)] +pub(crate) struct ClientRequestError { + #[allow(dead_code)] + request: Option, + response: Option, + #[allow(dead_code)] + response_body: Option, + source: Option, +} + + +impl ClientRequestError { + pub(crate) fn new( request: Option + , response: Option + , response_body: Option + , source: Option) -> Self { + Self { request, response, response_body, source } + } + + #[allow(dead_code)] + pub(crate) fn request(&self) -> Option<&request::Parts> { + self.request.as_ref() + } + + #[allow(dead_code)] + pub(crate) fn response(&self) -> Option<&response::Parts> { + self.response.as_ref() + } + + #[allow(dead_code)] + pub(crate) fn response_body(&self) -> Option<&Bytes> { + self.response_body.as_ref() + } +} + +impl error::Error for ClientRequestError { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { + self.source.as_ref().map(|e| e.as_ref()) + } +} + +impl fmt::Display for ClientRequestError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( f + , "request error: {{ + request: {:?} / response: {:?} / source: {:?} + }}" + , self.request + , self.response + , self.source ) + } +} + +impl From for ClientRequestError { + fn from(value: http::Error) -> Self { + let request = None; + let response = None; + let response_body = None; + let source = Some(value.into()); + Self { request, response, response_body, source } + } +} + +impl From for ClientRequestError { + fn from(value: anyhow::Error) -> Self { + let request = None; + let response = None; + let response_body = None; + let source = Some(value); + Self { request, response, response_body, source } + } +} + +impl From> for ClientRequestError { + fn from(value: Box) -> Self { + let request = None; + let response = None; + let response_body = None; + let source = Some(anyhow::anyhow!(value)); + Self { request, response, response_body, source } + } +} diff --git a/src/client_actor.rs b/src/client_actor.rs index b8f59ac..8ff0ce9 100644 --- a/src/client_actor.rs +++ b/src/client_actor.rs @@ -4,6 +4,7 @@ mod util; use std::{collections::HashMap, path::Path}; +use bytes::Bytes; use error::ClientActorError; use http::{HeaderMap, Uri}; use message::{ClientActorMessage, ClientActorMessageHandle}; @@ -80,7 +81,8 @@ impl ClientActor { self.tasks.spawn(async move { let result = client.download(handle.filename(), &handle.uri(), &HeaderMap::new()).await; match result { - Err(source) => Err(ClientActorError::new(&handle, source)), + Err(source) => + Err(ClientActorError::new(&handle, source.into())), Ok(state) => { handle.set_state(state); Ok(Some(handle)) @@ -102,7 +104,8 @@ impl ClientActor { self.tasks.spawn(async move { let result = client.data(&handle.uri(), &HeaderMap::new()).await; match result { - Err(source) => Err(ClientActorError::new(&handle, source)), + Err(source) => + Err(ClientActorError::new(&handle, source.into())), Ok(data) => { *handle.buffer_mut() = Some(data); Ok(Some(handle)) @@ -148,7 +151,7 @@ impl ClientActorHandle { receive.await.expect("Actor cancelled unexpected") } - pub(super) async fn body_bytes(&self, uri: &Uri) -> Option> { + pub(super) async fn body_bytes(&self, uri: &Uri) -> Option { let uri = uri.to_owned(); let (send, receive) = oneshot::channel(); let msg = ClientActorMessage::GetData { uri, respond_to: send }; diff --git a/src/client_actor/error.rs b/src/client_actor/error.rs index 9305ae1..e4b9e84 100644 --- a/src/client_actor/error.rs +++ b/src/client_actor/error.rs @@ -1,30 +1,44 @@ use std::{error, fmt}; +use crate::client::error as client_error; + use super::message::ClientActorMessageHandle; +#[derive(Debug)] +pub(crate) enum ClientActorErrorSource { + #[allow(dead_code)] + Download(client_error::DownloadError), +} + #[derive(Debug)] pub(crate) struct ClientActorError { pub(super) action: ClientActorMessageHandle, - pub(super) source: anyhow::Error, + pub(super) source: Option, +} + + +impl From for Option { + fn from(value: client_error::DownloadError) -> Self { + Some(ClientActorErrorSource::Download(value)) + } } impl ClientActorError { pub(super) fn new( action: &ClientActorMessageHandle - , source: anyhow::Error ) -> Self { + , source: Option ) -> Self { let action = action.to_owned(); Self { action, source } } } -impl error::Error for ClientActorError { - fn source(&self) -> Option<&(dyn error::Error + 'static)> { - Some(self.source.as_ref()) - } -} +impl error::Error for ClientActorError {} impl fmt::Display for ClientActorError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "download error ({:?}): {}", self.action, self.source) + write!( f + , "client actor error: {{ action: {:?} / source: {:?} }}" + , self.action + , self.source ) } } diff --git a/src/client_actor/message.rs b/src/client_actor/message.rs index bc7e412..d188f73 100644 --- a/src/client_actor/message.rs +++ b/src/client_actor/message.rs @@ -1,5 +1,6 @@ use std::path::PathBuf; +use bytes::Bytes; use http::Uri; use tokio::sync::oneshot; @@ -18,7 +19,7 @@ pub(super) enum ClientActorMessage { }, GetData { uri: Uri, - respond_to: oneshot::Sender>>, + respond_to: oneshot::Sender>, }, } @@ -32,7 +33,7 @@ pub(super) enum ClientActorMessageHandle { }, GetData { uri: Uri, - buffer: Option>, + buffer: Option, message: ActionIndex, }, } @@ -60,7 +61,7 @@ impl ClientActorMessageHandle { } } - pub(super) fn buffer_mut(&mut self) -> &mut Option> { + pub(super) fn buffer_mut(&mut self) -> &mut Option { match self { Self::GetData { ref mut buffer, .. } => buffer, _ => panic!("Called with invalid variant"), diff --git a/src/client_actor/util.rs b/src/client_actor/util.rs index 7880596..5359948 100644 --- a/src/client_actor/util.rs +++ b/src/client_actor/util.rs @@ -26,7 +26,9 @@ async fn process_next_result(mut actor: ClientActor, result: ClientTaskResult) - , &e.action.uri() , &HeaderMap::new()).await; match result { - Err(source) => Err(ClientActorError::new(&e.action, source)), + Err(source) => + Err(ClientActorError::new( &e.action + , source.into() )), Ok(state) => { e.action.set_state(state); Ok(Some(e.action)) @@ -36,7 +38,9 @@ async fn process_next_result(mut actor: ClientActor, result: ClientTaskResult) - GetData { .. } => { let result = client.data(&e.action.uri(), &HeaderMap::new()).await; match result { - Err(source) => Err(ClientActorError::new(&e.action, source)), + Err(source) => + Err(ClientActorError::new( &e.action + , source.into() )), Ok(data) => { *e.action.buffer_mut() = Some(data); Ok(Some(e.action)) diff --git a/src/m3u8_download.rs b/src/m3u8_download.rs index c346274..3fca917 100644 --- a/src/m3u8_download.rs +++ b/src/m3u8_download.rs @@ -1,6 +1,7 @@ use std::path::{Path, PathBuf}; use anyhow::anyhow; +use bytes::Bytes; use futures_util::future::join_all; use http::{uri::{Authority, Scheme}, Uri}; use log::debug; @@ -60,7 +61,7 @@ impl TsPart { } impl M3u8Download { - pub(super) async fn new(m3u8_data: Vec, index_uri: Uri) -> anyhow::Result { + pub(super) async fn new(m3u8_data: Bytes, index_uri: Uri) -> anyhow::Result { let scheme = index_uri.scheme() . ok_or(anyhow!("Problem scheme in m3u8 uri"))? . to_owned(); diff --git a/src/main.rs b/src/main.rs index df51139..d88ac81 100644 --- a/src/main.rs +++ b/src/main.rs @@ -90,10 +90,10 @@ async fn main() -> anyhow::Result<()> { let client = Client::new(buffer, rate_limit, concurrency_limit, timeout)? . set_body_timeout(body_timeout) - . set_origin(args.origin); + . set_origin(args.origin)?; let client = if let Some(user_agent) = args.agent { - client.set_user_agent(Some(user_agent)) + client.set_user_agent(Some(user_agent))? } else { client };