diff --git a/src/dns_cache.rs b/src/dns_cache.rs index fcd09f5..4a5d140 100644 --- a/src/dns_cache.rs +++ b/src/dns_cache.rs @@ -7,13 +7,13 @@ use std::future::Future; use std::net::{SocketAddr, IpAddr}; use std::sync::Arc; -use std::io::Error; +use std::io::{Error, ErrorKind}; use std::pin::Pin; use std::task::{self, Poll}; use hyper::client::connect::dns::Name; -use log::{trace, debug}; +use log::{trace, debug, error}; use lru::LruCache; @@ -24,6 +24,23 @@ use tokio::task::JoinHandle; use trust_dns_resolver::TokioAsyncResolver; use trust_dns_resolver::config::{NameServerConfig, Protocol, ResolverConfig, ResolverOpts}; +use trust_dns_resolver::error::{ResolveError}; + +#[derive(Debug)] +pub enum DnsResolveError { + ResolutionFailure (ResolveError), + DnsNoRecordsFound, + DnsUnknownError, + InterruptedError (Error) +} + +impl std::fmt::Display for DnsResolveError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self) + } +} + +impl std::error::Error for DnsResolveError {} pub async fn create_resolver(dns1_sock: SocketAddr, dns2_sock: SocketAddr) -> TokioAsyncResolver { let mut resolver_config: ResolverConfig = ResolverConfig::new(); @@ -53,21 +70,38 @@ pub async fn create_resolver(dns1_sock: SocketAddr, dns2_sock: SocketAddr) -> To pub async fn resolve_with_cache(host: &str, resolver: &TokioAsyncResolver, - resolver_cache: Arc>>) -> String { + resolver_cache: Arc>>) -> Result { let host_string = String::from(host); let mut guard = resolver_cache.lock().await; let found = (*guard).get(&host_string); if found.is_none() { trace!("resolve_with_cache: host={} not in cache, resolving...", String::from(host_string.as_str())); - let resolved_ip = format!("{}", resolver.lookup_ip(host).await.unwrap().iter().next().unwrap()); - (*guard).put(String::from(host_string.as_str()), resolved_ip.to_string()); - debug!("resolve_with_cache: resolved {} -> {}", String::from(host_string.as_str()), &resolved_ip); - resolved_ip + let lookup_result = resolver.lookup_ip(host).await; + if lookup_result.is_err() { + let err = lookup_result.err(); + if err.is_some() { + Err(DnsResolveError::ResolutionFailure(err.unwrap())) + } else { + Err(DnsResolveError::DnsUnknownError) + } + } else { + let ip_result = lookup_result.unwrap(); + let first_result = ip_result.iter().next(); + if first_result.is_none() { + error!("resolve_with_cache: {} - no records found", String::from(host_string.as_str())); + Err(DnsResolveError::DnsNoRecordsFound) + } else { + let resolved_ip = format!("{}", first_result.unwrap()); + (*guard).put(String::from(host_string.as_str()), resolved_ip.to_string()); + debug!("resolve_with_cache: resolved {} -> {}", String::from(host_string.as_str()), &resolved_ip); + Ok(resolved_ip) + } + } } else { let found = found.unwrap(); trace!("resolve_with_cache: host={} found in cache, returning: {}", host_string, found); - String::from(found) + Ok(String::from(found)) } } @@ -92,24 +126,29 @@ pub struct CacheAddrs { } pub struct CacheFuture { - inner: JoinHandle> + inner: JoinHandle> } pub async fn resolve_to_result(host: String, resolver: Arc, - cache: Arc>>) -> Result { - let ip = resolve_with_cache(host.as_str(), &resolver, cache).await; - let ip_addr: IpAddr = ip.parse().unwrap(); - let sock = SocketAddr::new(ip_addr, 0); - Ok(IpAddrs { iter: vec![sock].into_iter() }) + cache: Arc>>) -> Result { + let resolve_result = resolve_with_cache(host.as_str(), &resolver, cache).await; + if resolve_result.is_err() { + Err(resolve_result.err().unwrap()) + } else { + let ip = resolve_result.unwrap(); + let ip_addr: IpAddr = ip.parse().unwrap(); + let sock = SocketAddr::new(ip_addr, 0); + Ok(IpAddrs { iter: vec![sock].into_iter() }) + } } impl Service for CacheResolver { type Response = CacheAddrs; - type Error = std::io::Error; + type Error = DnsResolveError; type Future = CacheFuture; - fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll> { + fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll> { Poll::Ready(Ok(())) } @@ -119,14 +158,13 @@ impl Service for CacheResolver { let cache: Arc>> = self._cache.clone(); let addrs = tokio::task::spawn( resolve_to_result(String::from(name.as_str()), resolver, cache) - // resolve_with_cache(host.as_str(), &resolver, cache) ); CacheFuture { inner: addrs } } } impl Future for CacheFuture { - type Output = Result; + type Output = Result; fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { Pin::new(&mut self.inner).poll(cx).map(|res| match res { @@ -134,7 +172,7 @@ impl Future for CacheFuture { Ok(Err(err)) => Err(err), Err(join_err) => { if join_err.is_cancelled() { - Err(std::io::Error::new(std::io::ErrorKind::Interrupted, join_err)) + Err(DnsResolveError::InterruptedError(Error::new(ErrorKind::Interrupted, join_err))) } else { panic!("gai background task failed: {:?}", join_err) } diff --git a/src/proxy.rs b/src/proxy.rs index ef93f97..db3b64f 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -126,11 +126,23 @@ async fn proxy(client: Client>>, bad_request("(remove) invalid ping hash\n") } else { let routes = remove_routes.routes.clone(); + let mut resolve_errors: Vec<(String, DnsResolveError)> = Vec::new(); for route in routes.into_iter() { - let ip_string = resolve_with_cache(route.as_str(), &resolver, resolver_cache.clone()).await; - remove_static_route(&ip_string); + let resolve_result = resolve_with_cache(route.as_str(), &resolver, resolver_cache.clone()).await; + if resolve_result.is_err() { + let err = resolve_result.err().unwrap(); + error!("proxy(remove): error resolving hostname {:?}: {:?}", route.clone(), err); + resolve_errors.push((route.clone(), err)); + } else { + let ip_string = resolve_result.unwrap(); + remove_static_route(&ip_string); + } + } + if resolve_errors.is_empty() { + Ok(Response::new(Body::from(format!("Removed: {:?}", remove_routes.routes)))) + } else { + bad_request(format!("(remove) resolution errors: {:?}\n", resolve_errors).as_str()) } - Ok(Response::new(Body::from(format!("Removed: {:?}", remove_routes.routes)))) } } else if path.eq(PATH_HEALTH) && method == Method::GET { @@ -143,9 +155,15 @@ async fn proxy(client: Client>>, } let host = host.unwrap(); - let host_string = String::from(host); - trace!("proxy: received request for host {:?}, resolving...", host_string); - let ip_string = resolve_with_cache(host, &resolver, resolver_cache).await; + let host_string = Arc::new(String::from(host)); + trace!("proxy: received request for host {:?}, resolving...", host_string.clone()); + let resolve_result = resolve_with_cache(host, &resolver, resolver_cache).await; + if resolve_result.is_err() { + let err = resolve_result.err().unwrap(); + error!("proxy: error resolving hostname {:?}: {:?}", host_string.clone(), err); + return bad_request(format!("Error: error resolving hostname: {:?}: {:?}\n", host_string.clone(), err).as_str()); + } + let ip_string = resolve_result.unwrap(); info!("proxy: host {} resolved to: {}", host, ip_string); trace!("proxy: request is {:?}", req);