瀏覽代碼

log bad requests, add copyright/license for hyper example code

master
Jonathan Cobb 4 年之前
父節點
當前提交
57f0c96855
共有 2 個文件被更改,包括 20 次插入9 次删除
  1. +6
    -0
      src/main.rs
  2. +14
    -9
      src/proxy.rs

+ 6
- 0
src/main.rs 查看文件

@@ -4,6 +4,12 @@
* For personal (non-commercial) use, see license: https://getbubblenow.com/bubble-license/
*/

/**
* This code was adapted from https://github.com/hyperium/hyper/blob/master/examples/http_proxy.rs
* Copyright (c) 2014-2018 Sean McArthur
* License: https://raw.githubusercontent.com/hyperium/hyper/master/LICENSE
*/

use std::process::exit;

use clap::{Arg, ArgMatches, App};


+ 14
- 9
src/proxy.rs 查看文件

@@ -4,6 +4,12 @@
* For personal (non-commercial) use, see license: https://getbubblenow.com/bubble-license/
*/

/**
* This code was adapted from https://github.com/hyperium/hyper/blob/master/examples/http_proxy.rs
* Copyright (c) 2014-2018 Sean McArthur
* License: https://raw.githubusercontent.com/hyperium/hyper/master/LICENSE
*/

extern crate lru;

use std::convert::Infallible;
@@ -79,14 +85,16 @@ async fn proxy(client: Client<HttpsConnector<HttpConnector<CacheResolver>>>,
req: Request<Body>) -> Result<Response<Body>, hyper::Error> {
let host = req.uri().host();
if host.is_none() {
eprintln!("proxy: ERROR: no host, returning 400");
return bad_request("No host!");
}
let host = host.unwrap();
let ip_string = resolve_with_cache(host, &resolver, resolver_cache).await;
eprintln!("req(host {} resolved to: {}): {:?}", host, ip_string, req);
eprintln!("proxy: req(host {} resolved to: {}): {:?}", host, ip_string, req);

if needs_static_route(&ip_string) {
if !create_static_route(&gateway, &ip_string) {
eprintln!("proxy: ERROR: error creating static route to {:?}", ip_string);
return bad_request(format!("Error: error creating static route to {:?}", ip_string).as_str());
}
}
@@ -110,16 +118,16 @@ async fn proxy(client: Client<HttpsConnector<HttpConnector<CacheResolver>>>,
match req.into_body().on_upgrade().await {
Ok(upgraded) => {
if let Err(e) = tunnel(upgraded, addr).await {
eprintln!("server io error: {}", e);
eprintln!("proxy: ERROR: server io error: {}", e);
};
}
Err(e) => eprintln!("upgrade error: {}", e),
Err(e) => eprintln!("proxy: ERROR: upgrade error: {}", e),
}
});

Ok(Response::new(Body::empty()))
} else {
eprintln!(">>> CONNECT host is not socket addr: {:?}", req.uri());
eprintln!("proxy: ERROR: CONNECT host is not socket addr: {:?}", req.uri());
return bad_request("CONNECT must be to a socket address");
}
} else {
@@ -152,13 +160,10 @@ async fn tunnel(upgraded: Upgraded, addr: SocketAddr) -> std::io::Result<()> {
// Print message when done
match amounts {
Ok((from_client, from_server)) => {
println!(
"client wrote {} bytes and received {} bytes",
from_client, from_server
);
println!("client wrote {} bytes and received {} bytes", from_client, from_server);
}
Err(e) => {
eprintln!("tunnel error: {}", e);
eprintln!("proxy: ERROR: tunnel error: {}", e);
}
};
Ok(())


Loading…
取消
儲存