Add: socks5 support. It may have problems with DoT, will see.

This commit is contained in:
DaZuo0122
2026-01-16 23:59:02 +08:00
parent edd1779920
commit 7746511fc4
12 changed files with 489 additions and 50 deletions

View File

@@ -226,6 +226,10 @@ struct ProbeTcpingArgs {
#[arg(long, default_value_t = 800)]
timeout_ms: u64,
#[arg(long)]
socks5: Option<String>,
#[arg(long)]
prefer_ipv4: bool,
#[arg(long)]
no_geoip: bool,
}
@@ -256,6 +260,8 @@ struct DnsQueryArgs {
tls_name: Option<String>,
#[arg(long)]
socks5: Option<String>,
#[arg(long)]
prefer_ipv4: bool,
#[arg(long, default_value_t = 2000)]
timeout_ms: u64,
}
@@ -271,6 +277,8 @@ struct DnsDetectArgs {
tls_name: Option<String>,
#[arg(long)]
socks5: Option<String>,
#[arg(long)]
prefer_ipv4: bool,
#[arg(long, default_value_t = 3)]
repeat: u32,
#[arg(long, default_value_t = 2000)]
@@ -328,6 +336,8 @@ struct HttpRequestArgs {
http2_only: bool,
#[arg(long)]
geoip: bool,
#[arg(long)]
socks5: Option<String>,
}
#[derive(Parser, Debug, Clone)]
@@ -341,6 +351,10 @@ struct TlsArgs {
timeout_ms: u64,
#[arg(long)]
insecure: bool,
#[arg(long)]
socks5: Option<String>,
#[arg(long)]
prefer_ipv4: bool,
}
#[derive(Parser, Debug, Clone)]
@@ -988,14 +1002,31 @@ async fn handle_probe_tcping(cli: &Cli, args: ProbeTcpingArgs) -> i32 {
}
};
match wtfnet_probe::tcp_ping(&host, port, args.count, args.timeout_ms).await {
match wtfnet_probe::tcp_ping(
&host,
port,
args.count,
args.timeout_ms,
args.socks5.as_deref(),
args.prefer_ipv4,
)
.await
{
Ok(mut report) => {
if !args.no_geoip {
enrich_tcp_geoip(&mut report);
}
if cli.json {
let meta = Meta::new("wtfnet", env!("CARGO_PKG_VERSION"), false);
let command = CommandInfo::new("probe tcping", vec![args.target]);
let mut command_args = vec![args.target];
if let Some(proxy) = args.socks5 {
command_args.push("--socks5".to_string());
command_args.push(proxy);
}
if args.prefer_ipv4 {
command_args.push("--prefer-ipv4".to_string());
}
let command = CommandInfo::new("probe tcping", command_args);
let envelope = CommandEnvelope::new(meta, command, report);
emit_json(cli, &envelope)
} else {
@@ -1280,7 +1311,12 @@ async fn handle_dns_query(cli: &Cli, args: DnsQueryArgs) -> i32 {
}
};
let server = match args.server.as_deref() {
Some(value) => match parse_dns_server_target(value, transport, args.tls_name.as_deref()) {
Some(value) => match parse_dns_server_target(
value,
transport,
args.tls_name.as_deref(),
args.prefer_ipv4,
) {
Ok(addr) => Some(addr),
Err(err) => {
eprintln!("{err}");
@@ -1360,7 +1396,14 @@ async fn handle_dns_detect(cli: &Cli, args: DnsDetectArgs) -> i32 {
let parsed = raw
.split(',')
.filter(|value| !value.trim().is_empty())
.map(|value| parse_dns_server_target(value.trim(), transport, args.tls_name.as_deref()))
.map(|value| {
parse_dns_server_target(
value.trim(),
transport,
args.tls_name.as_deref(),
args.prefer_ipv4,
)
})
.collect::<Result<Vec<_>, _>>();
match parsed {
Ok(values) => values,
@@ -1617,6 +1660,7 @@ fn parse_dns_server_target(
value: &str,
transport: wtfnet_dns::DnsTransport,
tls_name: Option<&str>,
prefer_ipv4: bool,
) -> Result<wtfnet_dns::DnsServerTarget, String> {
let default_port = match transport {
wtfnet_dns::DnsTransport::Udp | wtfnet_dns::DnsTransport::Tcp => 53,
@@ -1638,10 +1682,8 @@ fn parse_dns_server_target(
}
let (host, port) = split_host_port_with_default(value, default_port)?;
let addr = format!("{host}:{port}")
.to_socket_addrs()
let addr = resolve_host_port(&host, port, prefer_ipv4)
.map_err(|_| format!("invalid server address: {value}"))?
.next()
.ok_or_else(|| format!("unable to resolve server: {value}"))?;
let name = tls_name
@@ -1687,6 +1729,28 @@ fn split_host_port_with_default(value: &str, default_port: u16) -> Result<(Strin
Ok((value.to_string(), default_port))
}
fn resolve_host_port(
host: &str,
port: u16,
prefer_ipv4: bool,
) -> Result<Option<std::net::SocketAddr>, std::io::Error> {
let mut iter = (host, port).to_socket_addrs()?;
if prefer_ipv4 {
let mut fallback = None;
for addr in iter.by_ref() {
if addr.is_ipv4() {
return Ok(Some(addr));
}
if fallback.is_none() {
fallback = Some(addr);
}
}
Ok(fallback)
} else {
Ok(iter.next())
}
}
async fn handle_http_request(
cli: &Cli,
args: HttpRequestArgs,
@@ -1701,6 +1765,7 @@ async fn handle_http_request(
show_body: args.show_body,
http1_only: args.http1_only,
http2_only: args.http2_only,
proxy: args.socks5.clone(),
};
match wtfnet_http::request(&args.url, opts).await {
@@ -1850,6 +1915,8 @@ fn build_tls_options(args: &TlsArgs) -> wtfnet_tls::TlsOptions {
alpn: parse_alpn(args.alpn.as_deref()),
timeout_ms: args.timeout_ms,
insecure: args.insecure,
socks5: args.socks5.clone(),
prefer_ipv4: args.prefer_ipv4,
}
}

View File

@@ -7,9 +7,14 @@ edition = "2024"
hickory-resolver = { version = "0.24", features = ["dns-over-tls", "dns-over-https", "dns-over-https-rustls", "dns-over-rustls", "native-certs"] }
hickory-proto = "0.24"
reqwest = { version = "0.11", features = ["rustls-tls", "socks"] }
rustls = "0.21"
rustls-native-certs = "0.6"
serde = { version = "1", features = ["derive"] }
thiserror = "2"
tokio = { version = "1", features = ["time"] }
tokio = { version = "1", features = ["io-util", "time"] }
tokio-rustls = "0.24"
tokio-socks = "0.5"
url = "2"
pnet = { version = "0.34", optional = true }
[features]

View File

@@ -8,12 +8,18 @@ use hickory_resolver::system_conf::read_system_conf;
use hickory_proto::op::{Message, MessageType, Query};
use hickory_proto::rr::Name;
use reqwest::Proxy;
use rustls::{Certificate, ClientConfig, RootCertStore, ServerName};
use serde::{Deserialize, Serialize};
use std::collections::BTreeSet;
use std::net::{IpAddr, SocketAddr};
use std::str::FromStr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use thiserror::Error;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio_rustls::TlsConnector;
use tokio_socks::tcp::Socks5Stream;
use url::Url;
#[cfg(feature = "pcap")]
use pnet::datalink::{self, Channel, Config as DatalinkConfig, NetworkInterface};
@@ -164,11 +170,16 @@ pub async fn query(
) -> Result<DnsQueryReport, DnsError> {
let record_type = parse_record_type(record_type)?;
if let Some(proxy) = proxy {
if transport != DnsTransport::Doh {
return Err(DnsError::ProxyUnsupported(transport.to_string()));
}
let server = server.ok_or_else(|| DnsError::MissingServer(transport.to_string()))?;
return doh_query_via_proxy(domain, record_type, server, timeout_ms, proxy).await;
return match transport {
DnsTransport::Doh => {
doh_query_via_proxy(domain, record_type, server, timeout_ms, proxy).await
}
DnsTransport::Dot => {
dot_query_via_proxy(domain, record_type, server, timeout_ms, proxy).await
}
_ => Err(DnsError::ProxyUnsupported(transport.to_string())),
};
}
let resolver = build_resolver(server.clone(), transport, timeout_ms)?;
let start = Instant::now();
@@ -512,6 +523,151 @@ async fn doh_query_via_proxy(
})
}
async fn dot_query_via_proxy(
domain: &str,
record_type: RecordType,
server: DnsServerTarget,
timeout_ms: u64,
proxy: String,
) -> Result<DnsQueryReport, DnsError> {
let tls_name = server
.name
.clone()
.ok_or_else(|| DnsError::MissingTlsName("dot".to_string()))?;
let name = Name::from_ascii(domain)
.map_err(|err| DnsError::Resolver(format!("invalid domain: {err}")))?;
let mut message = Message::new();
message
.set_id(0)
.set_message_type(MessageType::Query)
.set_recursion_desired(true)
.add_query(Query::query(name, record_type));
let body = message
.to_vec()
.map_err(|err| DnsError::Resolver(err.to_string()))?;
if body.len() > u16::MAX as usize {
return Err(DnsError::Resolver("dns message too large".to_string()));
}
let connector = build_tls_connector()?;
let proxy_config = parse_socks5_proxy(&proxy)?;
let target = if proxy_config.remote_dns {
(tls_name.clone(), server.addr.port())
} else {
(server.addr.ip().to_string(), server.addr.port())
};
let timeout = Duration::from_millis(timeout_ms);
let tcp = tokio::time::timeout(
timeout,
Socks5Stream::connect(proxy_config.addr.as_str(), target),
)
.await
.map_err(|_| DnsError::Resolver("timeout".to_string()))?
.map_err(|err| DnsError::Proxy(err.to_string()))?
.into_inner();
let server_name = ServerName::try_from(tls_name.as_str())
.map_err(|_| DnsError::MissingTlsName(tls_name.clone()))?;
let mut stream = tokio::time::timeout(timeout, connector.connect(server_name, tcp))
.await
.map_err(|_| DnsError::Resolver("timeout".to_string()))?
.map_err(|err| DnsError::Resolver(err.to_string()))?;
let start = Instant::now();
let response_bytes = tokio::time::timeout(timeout, async {
let length = (body.len() as u16).to_be_bytes();
stream.write_all(&length).await?;
stream.write_all(&body).await?;
stream.flush().await?;
let mut len_buf = [0u8; 2];
stream.read_exact(&mut len_buf).await?;
let response_len = u16::from_be_bytes(len_buf) as usize;
let mut response = vec![0u8; response_len];
stream.read_exact(&mut response).await?;
Ok::<Vec<u8>, std::io::Error>(response)
})
.await
.map_err(|_| DnsError::Resolver("timeout".to_string()))?
.map_err(|err| DnsError::Resolver(err.to_string()))?;
let response =
Message::from_vec(&response_bytes).map_err(|err| DnsError::Resolver(err.to_string()))?;
let duration_ms = start.elapsed().as_millis();
let mut answers = Vec::new();
for record in response.answers() {
let ttl = record.ttl();
let name = record.name().to_string();
let record_type = record.record_type().to_string();
if let Some(data) = record.data() {
if let Some(data) = format_rdata(data) {
answers.push(DnsAnswer {
name,
record_type,
ttl,
data,
});
}
}
}
Ok(DnsQueryReport {
domain: domain.to_string(),
record_type: record_type.to_string(),
transport: DnsTransport::Dot.to_string(),
server: Some(server.addr.to_string()),
server_name: Some(tls_name),
proxy: Some(proxy),
rcode: response.response_code().to_string(),
answers,
duration_ms,
})
}
fn build_tls_connector() -> Result<TlsConnector, DnsError> {
let mut roots = RootCertStore::empty();
let store = rustls_native_certs::load_native_certs()
.map_err(|err| DnsError::Io(err.to_string()))?;
for cert in store {
roots
.add(&Certificate(cert.0))
.map_err(|err| DnsError::Resolver(err.to_string()))?;
}
let config = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(roots)
.with_no_client_auth();
Ok(TlsConnector::from(Arc::new(config)))
}
struct Socks5Proxy {
addr: String,
remote_dns: bool,
}
fn parse_socks5_proxy(value: &str) -> Result<Socks5Proxy, DnsError> {
let url = Url::parse(value).map_err(|_| DnsError::Proxy(value.to_string()))?;
let scheme = url.scheme();
let remote_dns = match scheme {
"socks5" => false,
"socks5h" => true,
_ => return Err(DnsError::ProxyUnsupported(scheme.to_string())),
};
if !url.username().is_empty() || url.password().is_some() {
return Err(DnsError::Proxy("proxy auth not supported".to_string()));
}
let host = url
.host_str()
.ok_or_else(|| DnsError::Proxy(value.to_string()))?;
let port = url
.port_or_known_default()
.ok_or_else(|| DnsError::Proxy(value.to_string()))?;
Ok(Socks5Proxy {
addr: format!("{host}:{port}"),
remote_dns,
})
}
#[cfg(feature = "pcap")]
fn select_interface(name: Option<&str>) -> Option<NetworkInterface> {
let interfaces = datalink::interfaces();

View File

@@ -1,4 +1,4 @@
use reqwest::{Client, Method, StatusCode};
use reqwest::{Client, Method, Proxy, StatusCode};
use serde::{Deserialize, Serialize};
use std::net::{IpAddr, SocketAddr};
use std::time::{Duration, Instant};
@@ -63,6 +63,7 @@ pub struct HttpRequestOptions {
pub show_body: bool,
pub http1_only: bool,
pub http2_only: bool,
pub proxy: Option<String>,
}
pub async fn request(url: &str, opts: HttpRequestOptions) -> Result<HttpReport, HttpError> {
@@ -100,6 +101,11 @@ pub async fn request(url: &str, opts: HttpRequestOptions) -> Result<HttpReport,
builder.redirect(reqwest::redirect::Policy::none())
};
if let Some(proxy) = opts.proxy.as_ref() {
let proxy = Proxy::all(proxy).map_err(|err| HttpError::Request(err.to_string()))?;
builder = builder.proxy(proxy);
}
if opts.http1_only {
builder = builder.http1_only();
}

View File

@@ -12,3 +12,5 @@ tokio = { version = "1", features = ["net", "time"] }
surge-ping = "0.8"
wtfnet-geoip = { path = "../wtfnet-geoip" }
libc = "0.2"
tokio-socks = "0.5"
url = "2"

View File

@@ -20,6 +20,8 @@ use std::time::{Duration, Instant};
use thiserror::Error;
use tokio::net::{TcpStream, lookup_host};
use tokio::time::timeout;
use tokio_socks::tcp::Socks5Stream;
use url::Url;
use wtfnet_geoip::GeoIpRecord;
#[derive(Debug, Error)]
@@ -28,6 +30,10 @@ pub enum ProbeError {
Resolve(String),
#[error("io error: {0}")]
Io(String),
#[error("invalid proxy: {0}")]
InvalidProxy(String),
#[error("proxy error: {0}")]
Proxy(String),
#[error("timeout")]
Timeout,
#[error("ping error: {0}")]
@@ -184,9 +190,30 @@ pub async fn tcp_ping(
port: u16,
count: u32,
timeout_ms: u64,
proxy: Option<&str>,
prefer_ipv4: bool,
) -> Result<TcpPingReport, ProbeError> {
let addr = resolve_one(target).await?;
let socket_addr = SocketAddr::new(addr, port);
let (report_ip, target_host, proxy_addr) = if let Some(proxy) = proxy {
let proxy = parse_socks5_proxy(proxy)?;
if proxy.remote_dns {
(None, target.to_string(), proxy.addr)
} else {
let addr = if prefer_ipv4 {
resolve_one_prefer_ipv4(target).await?
} else {
resolve_one(target).await?
};
(Some(addr), addr.to_string(), proxy.addr)
}
} else {
let addr = if prefer_ipv4 {
resolve_one_prefer_ipv4(target).await?
} else {
resolve_one(target).await?
};
(Some(addr), addr.to_string(), String::new())
};
let socket_addr = report_ip.map(|addr| SocketAddr::new(addr, port));
let timeout_dur = Duration::from_millis(timeout_ms);
let mut results = Vec::new();
let mut received = 0u32;
@@ -197,9 +224,27 @@ pub async fn tcp_ping(
for seq in 0..count {
let seq = seq as u16;
let start = Instant::now();
let attempt = timeout(timeout_dur, TcpStream::connect(socket_addr)).await;
let attempt: Result<TcpStream, ProbeError> = if proxy.is_some() {
let target = (target_host.as_str(), port);
let stream = timeout(
timeout_dur,
Socks5Stream::connect(proxy_addr.as_str(), target),
)
.await
.map_err(|_| ProbeError::Timeout)?
.map_err(|err| ProbeError::Proxy(err.to_string()))?;
Ok(stream.into_inner())
} else {
timeout(
timeout_dur,
TcpStream::connect(socket_addr.expect("missing socket addr")),
)
.await
.map_err(|_| ProbeError::Timeout)?
.map_err(|err| ProbeError::Io(err.to_string()))
};
match attempt {
Ok(Ok(_stream)) => {
Ok(_stream) => {
let rtt = start.elapsed().as_millis();
received += 1;
min = Some(min.map_or(rtt, |value: u128| value.min(rtt)));
@@ -211,27 +256,20 @@ pub async fn tcp_ping(
error: None,
});
}
Ok(Err(err)) => {
Err(err) => {
results.push(TcpPingResult {
seq,
rtt_ms: None,
error: Some(err.to_string()),
});
}
Err(_) => {
results.push(TcpPingResult {
seq,
rtt_ms: None,
error: Some("timeout".to_string()),
});
}
}
}
let summary = build_summary(count, received, min, max, sum);
Ok(TcpPingReport {
target: target.to_string(),
ip: Some(addr.to_string()),
ip: report_ip.map(|addr| addr.to_string()),
geoip: None,
port,
timeout_ms,
@@ -389,6 +427,50 @@ async fn resolve_one(target: &str) -> Result<IpAddr, ProbeError> {
.ok_or_else(|| ProbeError::Resolve("no address found".to_string()))
}
async fn resolve_one_prefer_ipv4(target: &str) -> Result<IpAddr, ProbeError> {
let mut iter = lookup_host((target, 0))
.await
.map_err(|err| ProbeError::Resolve(err.to_string()))?;
let mut fallback = None;
for addr in iter.by_ref() {
if addr.ip().is_ipv4() {
return Ok(addr.ip());
}
if fallback.is_none() {
fallback = Some(addr.ip());
}
}
fallback.ok_or_else(|| ProbeError::Resolve("no address found".to_string()))
}
struct Socks5Proxy {
addr: String,
remote_dns: bool,
}
fn parse_socks5_proxy(value: &str) -> Result<Socks5Proxy, ProbeError> {
let url = Url::parse(value).map_err(|_| ProbeError::InvalidProxy(value.to_string()))?;
let scheme = url.scheme();
let remote_dns = match scheme {
"socks5" => false,
"socks5h" => true,
_ => return Err(ProbeError::InvalidProxy(value.to_string())),
};
if !url.username().is_empty() || url.password().is_some() {
return Err(ProbeError::Proxy("proxy auth not supported".to_string()));
}
let host = url
.host_str()
.ok_or_else(|| ProbeError::InvalidProxy(value.to_string()))?;
let port = url
.port_or_known_default()
.ok_or_else(|| ProbeError::InvalidProxy(value.to_string()))?;
Ok(Socks5Proxy {
addr: format!("{host}:{port}"),
remote_dns,
})
}
fn tcp_connect_with_ttl(addr: SocketAddr, ttl: u8, timeout: Duration) -> Result<(), ProbeError> {
let domain = match addr.ip() {
IpAddr::V4(_) => Domain::IPV4,

View File

@@ -11,3 +11,5 @@ thiserror = "2"
tokio = { version = "1", features = ["net", "time"] }
tokio-rustls = "0.24"
x509-parser = "0.16"
tokio-socks = "0.5"
url = "2"

View File

@@ -7,6 +7,8 @@ use thiserror::Error;
use tokio::net::TcpStream;
use tokio::time::timeout;
use tokio_rustls::TlsConnector;
use tokio_socks::tcp::Socks5Stream;
use url::Url;
use x509_parser::prelude::{FromDer, X509Certificate};
#[derive(Debug, Error)]
@@ -78,12 +80,23 @@ pub struct TlsOptions {
pub alpn: Vec<String>,
pub timeout_ms: u64,
pub insecure: bool,
pub socks5: Option<String>,
pub prefer_ipv4: bool,
}
pub async fn handshake(target: &str, options: TlsOptions) -> Result<TlsHandshakeReport, TlsError> {
let (addr, server_name) = parse_target(target, options.sni.as_deref())?;
let (host, port, server_name) = parse_target(target, options.sni.as_deref())?;
let connector = build_connector(options.insecure, &options.alpn)?;
let stream = connect(addr, connector, server_name, options.timeout_ms).await?;
let stream = connect(
host.as_str(),
port,
options.socks5.as_deref(),
connector,
server_name,
options.timeout_ms,
options.prefer_ipv4,
)
.await?;
let (_, session) = stream.get_ref();
Ok(TlsHandshakeReport {
@@ -102,9 +115,19 @@ pub async fn handshake(target: &str, options: TlsOptions) -> Result<TlsHandshake
}
pub async fn verify(target: &str, options: TlsOptions) -> Result<TlsVerifyReport, TlsError> {
let (addr, server_name) = parse_target(target, options.sni.as_deref())?;
let (host, port, server_name) = parse_target(target, options.sni.as_deref())?;
let connector = build_connector(false, &options.alpn)?;
match connect(addr, connector, server_name, options.timeout_ms).await {
match connect(
host.as_str(),
port,
options.socks5.as_deref(),
connector,
server_name,
options.timeout_ms,
options.prefer_ipv4,
)
.await
{
Ok(stream) => {
let (_, session) = stream.get_ref();
Ok(TlsVerifyReport {
@@ -136,9 +159,18 @@ pub async fn verify(target: &str, options: TlsOptions) -> Result<TlsVerifyReport
}
pub async fn certs(target: &str, options: TlsOptions) -> Result<TlsCertReport, TlsError> {
let (addr, server_name) = parse_target(target, options.sni.as_deref())?;
let (host, port, server_name) = parse_target(target, options.sni.as_deref())?;
let connector = build_connector(options.insecure, &options.alpn)?;
let stream = connect(addr, connector, server_name, options.timeout_ms).await?;
let stream = connect(
host.as_str(),
port,
options.socks5.as_deref(),
connector,
server_name,
options.timeout_ms,
options.prefer_ipv4,
)
.await?;
let (_, session) = stream.get_ref();
Ok(TlsCertReport {
target: target.to_string(),
@@ -148,9 +180,18 @@ pub async fn certs(target: &str, options: TlsOptions) -> Result<TlsCertReport, T
}
pub async fn alpn(target: &str, options: TlsOptions) -> Result<TlsAlpnReport, TlsError> {
let (addr, server_name) = parse_target(target, options.sni.as_deref())?;
let (host, port, server_name) = parse_target(target, options.sni.as_deref())?;
let connector = build_connector(options.insecure, &options.alpn)?;
let stream = connect(addr, connector, server_name, options.timeout_ms).await?;
let stream = connect(
host.as_str(),
port,
options.socks5.as_deref(),
connector,
server_name,
options.timeout_ms,
options.prefer_ipv4,
)
.await?;
let (_, session) = stream.get_ref();
Ok(TlsAlpnReport {
target: target.to_string(),
@@ -162,9 +203,8 @@ pub async fn alpn(target: &str, options: TlsOptions) -> Result<TlsAlpnReport, Tl
})
}
fn parse_target(target: &str, sni: Option<&str>) -> Result<(SocketAddr, ServerName), TlsError> {
fn parse_target(target: &str, sni: Option<&str>) -> Result<(String, u16, ServerName), TlsError> {
let (host, port) = split_host_port(target)?;
let addr = resolve_addr(&host, port)?;
let server_name = if let Some(sni) = sni {
ServerName::try_from(sni).map_err(|_| TlsError::InvalidSni(sni.to_string()))?
} else if let Ok(ip) = host.parse::<IpAddr>() {
@@ -173,7 +213,7 @@ fn parse_target(target: &str, sni: Option<&str>) -> Result<(SocketAddr, ServerNa
ServerName::try_from(host.as_str())
.map_err(|_| TlsError::InvalidSni(host.to_string()))?
};
Ok((addr, server_name))
Ok((host, port, server_name))
}
fn split_host_port(value: &str) -> Result<(String, u16), TlsError> {
@@ -237,6 +277,24 @@ fn resolve_addr(host: &str, port: u16) -> Result<SocketAddr, TlsError> {
Ok(addr)
}
fn resolve_addr_prefer_ipv4(host: &str, port: u16) -> Result<SocketAddr, TlsError> {
if let Ok(ip) = host.parse::<IpAddr>() {
return Ok(SocketAddr::new(ip, port));
}
let mut iter = std::net::ToSocketAddrs::to_socket_addrs(&(host, port))
.map_err(|err| TlsError::Io(err.to_string()))?;
let mut fallback = None;
for addr in iter.by_ref() {
if addr.is_ipv4() {
return Ok(addr);
}
if fallback.is_none() {
fallback = Some(addr);
}
}
fallback.ok_or_else(|| TlsError::InvalidTarget(host.to_string()))
}
fn build_connector(insecure: bool, alpn: &[String]) -> Result<TlsConnector, TlsError> {
let mut config = if insecure {
ClientConfig::builder()
@@ -266,15 +324,46 @@ fn build_connector(insecure: bool, alpn: &[String]) -> Result<TlsConnector, TlsE
}
async fn connect(
addr: SocketAddr,
host: &str,
port: u16,
proxy: Option<&str>,
connector: TlsConnector,
server_name: ServerName,
timeout_ms: u64,
prefer_ipv4: bool,
) -> Result<tokio_rustls::client::TlsStream<TcpStream>, TlsError> {
let tcp = timeout(Duration::from_millis(timeout_ms), TcpStream::connect(addr))
let tcp = if let Some(proxy) = proxy {
let proxy_addr = parse_proxy_addr(proxy)?;
let (target_host, remote_dns) = socks5_target_host(proxy, host);
let target = if remote_dns {
(target_host.clone(), port)
} else {
let addr = if prefer_ipv4 {
resolve_addr_prefer_ipv4(target_host.as_str(), port)?
} else {
resolve_addr(target_host.as_str(), port)?
};
(addr.ip().to_string(), port)
};
let stream = timeout(
Duration::from_millis(timeout_ms),
Socks5Stream::connect(proxy_addr.as_str(), target),
)
.await
.map_err(|_| TlsError::Timeout)?
.map_err(|err| TlsError::Io(err.to_string()))?;
stream.into_inner()
} else {
let addr = if prefer_ipv4 {
resolve_addr_prefer_ipv4(host, port)?
} else {
resolve_addr(host, port)?
};
timeout(Duration::from_millis(timeout_ms), TcpStream::connect(addr))
.await
.map_err(|_| TlsError::Timeout)?
.map_err(|err| TlsError::Io(err.to_string()))?
};
let stream = timeout(
Duration::from_millis(timeout_ms),
connector.connect(server_name, tcp),
@@ -285,6 +374,22 @@ async fn connect(
Ok(stream)
}
fn parse_proxy_addr(value: &str) -> Result<String, TlsError> {
let url = Url::parse(value).map_err(|_| TlsError::InvalidTarget(value.to_string()))?;
let host = url
.host_str()
.ok_or_else(|| TlsError::InvalidTarget(value.to_string()))?;
let port = url
.port_or_known_default()
.ok_or_else(|| TlsError::InvalidTarget(value.to_string()))?;
Ok(format!("{host}:{port}"))
}
fn socks5_target_host(proxy: &str, host: &str) -> (String, bool) {
let remote_dns = proxy.starts_with("socks5h://");
(host.to_string(), remote_dns)
}
fn extract_cert_chain(certs: Option<&[Certificate]>) -> Result<Vec<TlsCertSummary>, TlsError> {
let mut results = Vec::new();
if let Some(certs) = certs {