瀏覽代碼

better handling of dns lookup errors

master
Jonathan Cobb 4 年之前
父節點
當前提交
77a5dc7ce6
共有 2 個檔案被更改,包括 81 行新增25 行删除
  1. +57
    -19
      src/dns_cache.rs
  2. +24
    -6
      src/proxy.rs

+ 57
- 19
src/dns_cache.rs 查看文件

@@ -7,13 +7,13 @@
use std::future::Future;
use std::net::{SocketAddr, IpAddr};
use std::sync::Arc;
use std::io::Error;
use std::io::{Error, ErrorKind};
use std::pin::Pin;
use std::task::{self, Poll};

use hyper::client::connect::dns::Name;

use log::{trace, debug};
use log::{trace, debug, error};

use lru::LruCache;

@@ -24,6 +24,23 @@ use tokio::task::JoinHandle;

use trust_dns_resolver::TokioAsyncResolver;
use trust_dns_resolver::config::{NameServerConfig, Protocol, ResolverConfig, ResolverOpts};
use trust_dns_resolver::error::{ResolveError};

#[derive(Debug)]
pub enum DnsResolveError {
ResolutionFailure (ResolveError),
DnsNoRecordsFound,
DnsUnknownError,
InterruptedError (Error)
}

impl std::fmt::Display for DnsResolveError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self)
}
}

impl std::error::Error for DnsResolveError {}

pub async fn create_resolver(dns1_sock: SocketAddr, dns2_sock: SocketAddr) -> TokioAsyncResolver {
let mut resolver_config: ResolverConfig = ResolverConfig::new();
@@ -53,21 +70,38 @@ pub async fn create_resolver(dns1_sock: SocketAddr, dns2_sock: SocketAddr) -> To

pub async fn resolve_with_cache(host: &str,
resolver: &TokioAsyncResolver,
resolver_cache: Arc<Mutex<LruCache<String, String>>>) -> String {
resolver_cache: Arc<Mutex<LruCache<String, String>>>) -> Result<String, DnsResolveError> {
let host_string = String::from(host);
let mut guard = resolver_cache.lock().await;
let found = (*guard).get(&host_string);

if found.is_none() {
trace!("resolve_with_cache: host={} not in cache, resolving...", String::from(host_string.as_str()));
let resolved_ip = format!("{}", resolver.lookup_ip(host).await.unwrap().iter().next().unwrap());
(*guard).put(String::from(host_string.as_str()), resolved_ip.to_string());
debug!("resolve_with_cache: resolved {} -> {}", String::from(host_string.as_str()), &resolved_ip);
resolved_ip
let lookup_result = resolver.lookup_ip(host).await;
if lookup_result.is_err() {
let err = lookup_result.err();
if err.is_some() {
Err(DnsResolveError::ResolutionFailure(err.unwrap()))
} else {
Err(DnsResolveError::DnsUnknownError)
}
} else {
let ip_result = lookup_result.unwrap();
let first_result = ip_result.iter().next();
if first_result.is_none() {
error!("resolve_with_cache: {} - no records found", String::from(host_string.as_str()));
Err(DnsResolveError::DnsNoRecordsFound)
} else {
let resolved_ip = format!("{}", first_result.unwrap());
(*guard).put(String::from(host_string.as_str()), resolved_ip.to_string());
debug!("resolve_with_cache: resolved {} -> {}", String::from(host_string.as_str()), &resolved_ip);
Ok(resolved_ip)
}
}
} else {
let found = found.unwrap();
trace!("resolve_with_cache: host={} found in cache, returning: {}", host_string, found);
String::from(found)
Ok(String::from(found))
}
}

@@ -92,24 +126,29 @@ pub struct CacheAddrs {
}

pub struct CacheFuture {
inner: JoinHandle<Result<IpAddrs, std::io::Error>>
inner: JoinHandle<Result<IpAddrs, DnsResolveError>>
}

pub async fn resolve_to_result(host: String,
resolver: Arc<TokioAsyncResolver>,
cache: Arc<Mutex<LruCache<String, String>>>) -> Result<IpAddrs, Error> {
let ip = resolve_with_cache(host.as_str(), &resolver, cache).await;
let ip_addr: IpAddr = ip.parse().unwrap();
let sock = SocketAddr::new(ip_addr, 0);
Ok(IpAddrs { iter: vec![sock].into_iter() })
cache: Arc<Mutex<LruCache<String, String>>>) -> Result<IpAddrs, DnsResolveError> {
let resolve_result = resolve_with_cache(host.as_str(), &resolver, cache).await;
if resolve_result.is_err() {
Err(resolve_result.err().unwrap())
} else {
let ip = resolve_result.unwrap();
let ip_addr: IpAddr = ip.parse().unwrap();
let sock = SocketAddr::new(ip_addr, 0);
Ok(IpAddrs { iter: vec![sock].into_iter() })
}
}

impl Service<Name> for CacheResolver {
type Response = CacheAddrs;
type Error = std::io::Error;
type Error = DnsResolveError;
type Future = CacheFuture;

fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), std::io::Error>> {
fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), DnsResolveError>> {
Poll::Ready(Ok(()))
}

@@ -119,14 +158,13 @@ impl Service<Name> for CacheResolver {
let cache: Arc<Mutex<LruCache<String, String>>> = self._cache.clone();
let addrs = tokio::task::spawn(
resolve_to_result(String::from(name.as_str()), resolver, cache)
// resolve_with_cache(host.as_str(), &resolver, cache)
);
CacheFuture { inner: addrs }
}
}

impl Future for CacheFuture {
type Output = Result<CacheAddrs, std::io::Error>;
type Output = Result<CacheAddrs, DnsResolveError>;

fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.inner).poll(cx).map(|res| match res {
@@ -134,7 +172,7 @@ impl Future for CacheFuture {
Ok(Err(err)) => Err(err),
Err(join_err) => {
if join_err.is_cancelled() {
Err(std::io::Error::new(std::io::ErrorKind::Interrupted, join_err))
Err(DnsResolveError::InterruptedError(Error::new(ErrorKind::Interrupted, join_err)))
} else {
panic!("gai background task failed: {:?}", join_err)
}


+ 24
- 6
src/proxy.rs 查看文件

@@ -126,11 +126,23 @@ async fn proxy(client: Client<HttpsConnector<HttpConnector<CacheResolver>>>,
bad_request("(remove) invalid ping hash\n")
} else {
let routes = remove_routes.routes.clone();
let mut resolve_errors: Vec<(String, DnsResolveError)> = Vec::new();
for route in routes.into_iter() {
let ip_string = resolve_with_cache(route.as_str(), &resolver, resolver_cache.clone()).await;
remove_static_route(&ip_string);
let resolve_result = resolve_with_cache(route.as_str(), &resolver, resolver_cache.clone()).await;
if resolve_result.is_err() {
let err = resolve_result.err().unwrap();
error!("proxy(remove): error resolving hostname {:?}: {:?}", route.clone(), err);
resolve_errors.push((route.clone(), err));
} else {
let ip_string = resolve_result.unwrap();
remove_static_route(&ip_string);
}
}
if resolve_errors.is_empty() {
Ok(Response::new(Body::from(format!("Removed: {:?}", remove_routes.routes))))
} else {
bad_request(format!("(remove) resolution errors: {:?}\n", resolve_errors).as_str())
}
Ok(Response::new(Body::from(format!("Removed: {:?}", remove_routes.routes))))
}

} else if path.eq(PATH_HEALTH) && method == Method::GET {
@@ -143,9 +155,15 @@ async fn proxy(client: Client<HttpsConnector<HttpConnector<CacheResolver>>>,
}

let host = host.unwrap();
let host_string = String::from(host);
trace!("proxy: received request for host {:?}, resolving...", host_string);
let ip_string = resolve_with_cache(host, &resolver, resolver_cache).await;
let host_string = Arc::new(String::from(host));
trace!("proxy: received request for host {:?}, resolving...", host_string.clone());
let resolve_result = resolve_with_cache(host, &resolver, resolver_cache).await;
if resolve_result.is_err() {
let err = resolve_result.err().unwrap();
error!("proxy: error resolving hostname {:?}: {:?}", host_string.clone(), err);
return bad_request(format!("Error: error resolving hostname: {:?}: {:?}\n", host_string.clone(), err).as_str());
}
let ip_string = resolve_result.unwrap();
info!("proxy: host {} resolved to: {}", host, ip_string);
trace!("proxy: request is {:?}", req);



Loading…
取消
儲存