Files
WTFnet/crates/wtfnet-probe/src/lib.rs
2026-01-17 12:51:41 +08:00

855 lines
24 KiB
Rust

#[cfg(unix)]
use pnet::packet::icmp::{IcmpPacket, IcmpTypes};
#[cfg(unix)]
use pnet::packet::icmpv6::{Icmpv6Packet, Icmpv6Types};
#[cfg(unix)]
use pnet::packet::ip::IpNextHeaderProtocols;
#[cfg(unix)]
use pnet::transport::{
TransportChannelType, TransportProtocol, icmp_packet_iter, icmpv6_packet_iter,
transport_channel,
};
#[cfg(unix)]
use std::os::unix::io::AsRawFd;
use serde::{Deserialize, Serialize};
use socket2::{Domain, Protocol, Socket, Type};
use std::collections::{HashMap, HashSet};
use std::net::{IpAddr, SocketAddr};
#[cfg(unix)]
use std::mem::size_of_val;
use std::time::{Duration, Instant};
use hickory_resolver::config::{ResolverConfig, ResolverOpts};
use hickory_resolver::system_conf::read_system_conf;
use hickory_resolver::TokioAsyncResolver;
use thiserror::Error;
use tokio::net::{TcpStream, lookup_host};
use tokio::time::timeout;
use tokio_socks::tcp::Socks5Stream;
use tracing::debug;
use url::Url;
use wtfnet_geoip::GeoIpRecord;
/// Errors returned by the probe functions in this file.
///
/// Every payload is a pre-rendered message string; the underlying error
/// values are not retained.
#[derive(Debug, Error)]
pub enum ProbeError {
/// Host lookup failed or yielded no address.
#[error("resolution failed: {0}")]
Resolve(String),
/// A socket-level operation failed. Also used for failed blocking-task
/// joins and for the "not supported on this platform" UDP trace stub.
#[error("io error: {0}")]
Io(String),
/// The proxy URL could not be parsed, used a scheme other than
/// `socks5`/`socks5h`, or lacked a host or port. Carries the offending value.
#[error("invalid proxy: {0}")]
InvalidProxy(String),
/// The SOCKS5 connection failed, or the proxy URL carried credentials
/// (proxy authentication is rejected).
#[error("proxy error: {0}")]
Proxy(String),
/// An operation did not complete within its deadline.
#[error("timeout")]
Timeout,
/// The ICMP ping client could not be created.
#[error("ping error: {0}")]
Ping(String),
}
/// Outcome of a single ICMP echo probe.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PingResult {
/// Sequence number of the probe (the loop index truncated to 16 bits).
pub seq: u16,
/// Round-trip time in milliseconds; `None` when the probe failed.
pub rtt_ms: Option<u128>,
/// Failure description (`"timeout"` or the ping error text); `None` on success.
pub error: Option<String>,
}
/// Aggregate statistics over a run of probes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PingSummary {
/// Number of probes attempted.
pub sent: u32,
/// Number of probes that got a reply.
pub received: u32,
/// Percentage of probes without a reply, in the range 0.0..=100.0
/// (0.0 when nothing was sent).
pub loss_pct: f64,
/// Smallest observed round-trip time in milliseconds, if any reply arrived.
pub min_ms: Option<u128>,
/// Mean round-trip time in milliseconds over the received replies.
pub avg_ms: Option<f64>,
/// Largest observed round-trip time in milliseconds, if any reply arrived.
pub max_ms: Option<u128>,
}
/// Full report of an ICMP ping run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PingReport {
/// Target exactly as supplied by the caller.
pub target: String,
/// Resolved IP address of the target, rendered as text.
pub ip: Option<String>,
/// GeoIP information. The probe code in this file always leaves it `None`;
/// presumably it is filled in by a caller (verify against callers).
pub geoip: Option<GeoIpRecord>,
/// Per-probe timeout requested by the caller, in milliseconds.
pub timeout_ms: u64,
/// Number of probes requested by the caller.
pub count: u32,
/// One entry per probe, in sending order.
pub results: Vec<PingResult>,
/// Aggregate statistics over `results`.
pub summary: PingSummary,
}
/// Outcome of a single TCP connect probe.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TcpPingResult {
/// Sequence number of the probe (the loop index truncated to 16 bits).
pub seq: u16,
/// Time taken to establish the connection, in milliseconds; `None` on failure.
pub rtt_ms: Option<u128>,
/// Rendered error for a failed probe; `None` on success.
pub error: Option<String>,
}
/// Full report of a TCP ping run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TcpPingReport {
/// Target exactly as supplied by the caller.
pub target: String,
/// Locally resolved IP address of the target; `None` when name resolution
/// was delegated to the proxy (`socks5h`).
pub ip: Option<String>,
/// GeoIP information. The probe code in this file always leaves it `None`;
/// presumably it is filled in by a caller (verify against callers).
pub geoip: Option<GeoIpRecord>,
/// Destination TCP port.
pub port: u16,
/// Per-probe timeout requested by the caller, in milliseconds.
pub timeout_ms: u64,
/// Number of probes requested by the caller.
pub count: u32,
/// One entry per recorded probe, in sending order.
pub results: Vec<TcpPingResult>,
/// Aggregate statistics over `results`.
pub summary: PingSummary,
}
/// Measurements collected for one TTL value of a trace.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TraceHop {
/// TTL / hop limit used for the probes of this hop.
pub ttl: u8,
/// Address associated with this hop, rendered as text, when one is known.
pub addr: Option<String>,
/// Average round-trip time rounded to whole milliseconds, if any probe succeeded.
pub rtt_ms: Option<u128>,
/// One entry per probe sent at this TTL; `None` marks a failed probe.
pub rtt_samples: Vec<Option<u128>>,
/// Smallest successful sample in milliseconds.
pub min_ms: Option<u128>,
/// Mean of the successful samples in milliseconds.
pub avg_ms: Option<f64>,
/// Largest successful sample in milliseconds.
pub max_ms: Option<u128>,
/// Percentage of probes at this TTL that failed.
pub loss_pct: f64,
/// Reverse-DNS name for `addr`, when lookups were requested and succeeded.
pub rdns: Option<String>,
/// Free-form remark, e.g. the last probe error when every probe failed.
pub note: Option<String>,
/// GeoIP information. The probe code in this file always leaves it `None`;
/// presumably it is filled in by a caller (verify against callers).
pub geoip: Option<GeoIpRecord>,
}
/// Full report of a TCP or UDP trace.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TraceReport {
/// Target exactly as supplied by the caller.
pub target: String,
/// Resolved IP address of the target, rendered as text.
pub ip: Option<String>,
/// GeoIP information. The probe code in this file always leaves it `None`;
/// presumably it is filled in by a caller (verify against callers).
pub geoip: Option<GeoIpRecord>,
/// Destination port the probes were addressed to.
pub port: u16,
/// Largest TTL the trace was allowed to try.
pub max_hops: u8,
/// Per-probe timeout requested by the caller, in milliseconds.
pub timeout_ms: u64,
/// Number of probes per hop as requested by the caller (the trace loops
/// send at least one probe per hop even when this is 0).
pub per_hop: u32,
/// Whether reverse-DNS lookups were requested.
pub rdns: bool,
/// Probe protocol: `"tcp"` or `"udp"`.
pub protocol: String,
/// Hops in increasing TTL order; the list ends early once the trace stops.
pub hops: Vec<TraceHop>,
}
/// Sends `count` ICMP echo requests to `target` and reports per-probe results
/// plus summary statistics.
///
/// * `target` - host name or IP literal; the first resolved address is used.
/// * `count` - number of echo requests to send.
/// * `timeout_ms` - how long to wait for each reply, in milliseconds.
/// * `interval_ms` - pause between consecutive probes, in milliseconds
///   (0 disables the pause). No pause is inserted after the final probe.
///
/// # Errors
///
/// Returns [`ProbeError::Resolve`] when the target cannot be resolved and
/// [`ProbeError::Ping`] when the ICMP client cannot be created. Failures of
/// individual probes are recorded in the report, not returned as errors.
pub async fn ping(
    target: &str,
    count: u32,
    timeout_ms: u64,
    interval_ms: u64,
) -> Result<PingReport, ProbeError> {
    debug!(
        target,
        count,
        timeout_ms,
        interval_ms,
        "probe ping start"
    );
    let addr = resolve_one(target).await?;
    debug!(ip = %addr, "probe ping resolved");

    let config = match addr {
        IpAddr::V4(_) => surge_ping::Config::default(),
        IpAddr::V6(_) => surge_ping::Config::builder()
            .kind(surge_ping::ICMP::V6)
            .build(),
    };
    let client =
        surge_ping::Client::new(&config).map_err(|err| ProbeError::Ping(err.to_string()))?;
    let mut pinger = client.pinger(addr, surge_ping::PingIdentifier(0)).await;

    let timeout_dur = Duration::from_millis(timeout_ms);
    // The pinger has its own built-in per-ping deadline; left at its default it
    // would silently cap any larger `timeout_ms`. Push it just past our own
    // deadline so the outer `timeout` below is the one that fires and the
    // recorded error text stays "timeout".
    pinger.timeout(timeout_dur.saturating_add(Duration::from_secs(1)));

    let mut results = Vec::new();
    let mut received = 0u32;
    let mut min = None;
    let mut max = None;
    let mut sum = 0u128;

    for index in 0..count {
        // ICMP sequence numbers are 16 bits wide, so long runs wrap around.
        let seq = index as u16;
        let start = Instant::now();
        let response = timeout(
            timeout_dur,
            pinger.ping(surge_ping::PingSequence(seq), &[0; 8]),
        )
        .await;
        match response {
            Ok(Ok((_packet, _))) => {
                let rtt = start.elapsed().as_millis();
                received += 1;
                min = Some(min.map_or(rtt, |value: u128| value.min(rtt)));
                max = Some(max.map_or(rtt, |value: u128| value.max(rtt)));
                sum += rtt;
                results.push(PingResult {
                    seq,
                    rtt_ms: Some(rtt),
                    error: None,
                });
            }
            Ok(Err(err)) => {
                results.push(PingResult {
                    seq,
                    rtt_ms: None,
                    error: Some(err.to_string()),
                });
            }
            Err(_) => {
                results.push(PingResult {
                    seq,
                    rtt_ms: None,
                    error: Some("timeout".to_string()),
                });
            }
        }
        // Pause only between probes; sleeping after the last one would just
        // delay the report.
        let is_last = index + 1 == count;
        if interval_ms > 0 && !is_last {
            tokio::time::sleep(Duration::from_millis(interval_ms)).await;
        }
    }

    let summary = build_summary(count, received, min, max, sum);
    Ok(PingReport {
        target: target.to_string(),
        ip: Some(addr.to_string()),
        geoip: None,
        timeout_ms,
        count,
        results,
        summary,
    })
}
pub async fn tcp_ping(
target: &str,
port: u16,
count: u32,
timeout_ms: u64,
proxy: Option<&str>,
prefer_ipv4: bool,
) -> Result<TcpPingReport, ProbeError> {
debug!(
target,
port,
count,
timeout_ms,
proxy = ?proxy,
prefer_ipv4,
"probe tcp ping start"
);
let (report_ip, target_host, proxy_addr) = if let Some(proxy) = proxy {
let proxy = parse_socks5_proxy(proxy)?;
if proxy.remote_dns {
(None, target.to_string(), proxy.addr)
} else {
let addr = if prefer_ipv4 {
resolve_one_prefer_ipv4(target).await?
} else {
resolve_one(target).await?
};
(Some(addr), addr.to_string(), proxy.addr)
}
} else {
let addr = if prefer_ipv4 {
resolve_one_prefer_ipv4(target).await?
} else {
resolve_one(target).await?
};
(Some(addr), addr.to_string(), String::new())
};
let socket_addr = report_ip.map(|addr| SocketAddr::new(addr, port));
debug!(
report_ip = ?report_ip,
target_host = %target_host,
proxy_addr = %proxy_addr,
"probe tcp ping resolved"
);
let timeout_dur = Duration::from_millis(timeout_ms);
let mut results = Vec::new();
let mut received = 0u32;
let mut min = None;
let mut max = None;
let mut sum = 0u128;
for seq in 0..count {
let seq = seq as u16;
let start = Instant::now();
let attempt: Result<TcpStream, ProbeError> = if proxy.is_some() {
let target = (target_host.as_str(), port);
let stream = timeout(
timeout_dur,
Socks5Stream::connect(proxy_addr.as_str(), target),
)
.await
.map_err(|_| ProbeError::Timeout)?
.map_err(|err| ProbeError::Proxy(err.to_string()))?;
Ok(stream.into_inner())
} else {
timeout(
timeout_dur,
TcpStream::connect(socket_addr.expect("missing socket addr")),
)
.await
.map_err(|_| ProbeError::Timeout)?
.map_err(|err| ProbeError::Io(err.to_string()))
};
match attempt {
Ok(_stream) => {
let rtt = start.elapsed().as_millis();
received += 1;
min = Some(min.map_or(rtt, |value: u128| value.min(rtt)));
max = Some(max.map_or(rtt, |value: u128| value.max(rtt)));
sum += rtt;
results.push(TcpPingResult {
seq,
rtt_ms: Some(rtt),
error: None,
});
}
Err(err) => {
results.push(TcpPingResult {
seq,
rtt_ms: None,
error: Some(err.to_string()),
});
}
}
}
let summary = build_summary(count, received, min, max, sum);
Ok(TcpPingReport {
target: target.to_string(),
ip: report_ip.map(|addr| addr.to_string()),
geoip: None,
port,
timeout_ms,
count,
results,
summary,
})
}
/// Traces the path to `target:port` by attempting TCP connections with an
/// increasing TTL / hop limit.
///
/// For each TTL from 1 to `max_hops`, `per_hop` (at least one) blocking
/// connects are attempted. The trace stops at the first TTL where any
/// connection succeeds, i.e. the destination answered.
///
/// A failed TTL-limited connect does not reveal which router dropped the
/// packet, so hops before the destination are reported with `addr` and `rdns`
/// set to `None`. Only the final, successful hop carries the destination's
/// address and, when `rdns` is set, its reverse-DNS name.
///
/// # Errors
///
/// Returns [`ProbeError::Resolve`] when the target cannot be resolved and
/// [`ProbeError::Io`] when a blocking probe task cannot be joined.
pub async fn tcp_trace(
    target: &str,
    port: u16,
    max_hops: u8,
    timeout_ms: u64,
    per_hop: u32,
    rdns: bool,
) -> Result<TraceReport, ProbeError> {
    debug!(
        target,
        port,
        max_hops,
        timeout_ms,
        "probe tcp trace start"
    );
    let addr = resolve_one(target).await?;
    debug!(ip = %addr, "probe tcp trace resolved");
    let socket_addr = SocketAddr::new(addr, port);
    let timeout_dur = Duration::from_millis(timeout_ms);
    let mut hops = Vec::new();
    let mut rdns_lookup = if rdns {
        Some(ReverseDns::new(timeout_dur)?)
    } else {
        None
    };

    for ttl in 1..=max_hops {
        debug!(ttl, per_hop, "probe tcp trace hop start");
        let mut samples = Vec::new();
        let mut last_error = None;
        for _ in 0..per_hop.max(1) {
            let addr = socket_addr;
            let start = Instant::now();
            // The connect is a blocking socket call, so keep it off the async workers.
            let result =
                tokio::task::spawn_blocking(move || tcp_connect_with_ttl(addr, ttl, timeout_dur))
                    .await
                    .map_err(|err| ProbeError::Io(err.to_string()))?;
            match result {
                Ok(()) => {
                    let rtt = start.elapsed().as_millis();
                    debug!(ttl, rtt_ms = rtt, "probe tcp trace hop reply");
                    samples.push(Some(rtt));
                }
                Err(err) => {
                    let message = err.to_string();
                    debug!(ttl, error = %message, "probe tcp trace hop error");
                    last_error = Some(message);
                    samples.push(None);
                }
            }
        }

        let (min_ms, avg_ms, max_ms, loss_pct) = stats_from_samples(&samples);
        let rtt_ms = avg_ms.map(|value| value.round() as u128);
        // At least one completed handshake means the destination itself replied.
        let reached = loss_pct < 100.0;
        // Only a reached destination has a known address; labelling silent
        // intermediate hops with the target's IP would be misleading.
        let hop_ip = reached.then(|| socket_addr.ip());
        let rdns_name = match (hop_ip, rdns_lookup.as_mut()) {
            (Some(ip), Some(lookup)) => lookup.lookup(ip).await,
            _ => None,
        };
        let note = if reached { None } else { last_error };
        hops.push(TraceHop {
            ttl,
            addr: hop_ip.map(|ip| ip.to_string()),
            rtt_ms,
            rtt_samples: samples,
            min_ms,
            avg_ms,
            max_ms,
            loss_pct,
            rdns: rdns_name,
            note,
            geoip: None,
        });
        debug!(
            ttl,
            loss_pct,
            min_ms = ?min_ms,
            avg_ms = ?avg_ms,
            max_ms = ?max_ms,
            "probe tcp trace hop summary"
        );
        if reached {
            break;
        }
    }

    Ok(TraceReport {
        target: target.to_string(),
        ip: Some(addr.to_string()),
        geoip: None,
        port,
        max_hops,
        timeout_ms,
        per_hop,
        rdns,
        protocol: "tcp".to_string(),
        hops,
    })
}
/// Traces the path to `target:port` with TTL-limited UDP probes.
///
/// For each TTL from 1 to `max_hops`, `per_hop` (at least one) probes are sent
/// via `udp_trace_hop` on a blocking thread. The first responder address seen
/// at a TTL becomes the hop address; if several distinct responders answer,
/// the hop is annotated with `"multiple hop addresses"`. The trace ends at the
/// first TTL where any probe reports that the destination was reached.
///
/// # Errors
///
/// Returns [`ProbeError::Resolve`] when the target cannot be resolved and
/// [`ProbeError::Io`] when a blocking probe task cannot be joined. Individual
/// probe failures are recorded as lost samples.
pub async fn udp_trace(
    target: &str,
    port: u16,
    max_hops: u8,
    timeout_ms: u64,
    per_hop: u32,
    rdns: bool,
) -> Result<TraceReport, ProbeError> {
    debug!(
        target,
        port,
        max_hops,
        timeout_ms,
        "probe udp trace start"
    );
    let addr = resolve_one(target).await?;
    debug!(ip = %addr, "probe udp trace resolved");

    let timeout_dur = Duration::from_millis(timeout_ms);
    let destination = SocketAddr::new(addr, port);
    let probes_per_hop = per_hop.max(1);
    let mut rdns_lookup = match rdns {
        true => Some(ReverseDns::new(timeout_dur)?),
        false => None,
    };
    let mut hops = Vec::new();

    for ttl in 1..=max_hops {
        debug!(ttl, per_hop, "probe udp trace hop start");
        let mut samples = Vec::new();
        let mut responders = HashSet::new();
        let mut hop_addr = None;
        let mut last_error = None;
        let mut reached_any = false;

        for _ in 0..probes_per_hop {
            let started = Instant::now();
            let outcome =
                tokio::task::spawn_blocking(move || udp_trace_hop(destination, ttl, timeout_dur))
                    .await
                    .map_err(|err| ProbeError::Io(err.to_string()))?;
            match outcome {
                Err(err) => {
                    let message = err.to_string();
                    debug!(ttl, error = %message, "probe udp trace hop error");
                    last_error = Some(message);
                    samples.push(None);
                }
                Ok((responder, reached)) => {
                    let rtt = started.elapsed().as_millis();
                    debug!(
                        ttl,
                        addr = ?responder,
                        rtt_ms = rtt,
                        reached,
                        "probe udp trace hop reply"
                    );
                    samples.push(Some(rtt));
                    if let Some(ip) = responder {
                        responders.insert(ip);
                        // Keep the first responder as the hop's address.
                        hop_addr.get_or_insert(ip);
                    }
                    reached_any |= reached;
                }
            }
        }

        let (min_ms, avg_ms, max_ms, loss_pct) = stats_from_samples(&samples);
        let rtt_ms = avg_ms.map(|value| value.round() as u128);
        // `rdns_lookup` is `Some` exactly when `rdns` was requested.
        let rdns_name = match (hop_addr, rdns_lookup.as_mut()) {
            (Some(ip), Some(lookup)) => lookup.lookup(ip).await,
            _ => None,
        };
        let note = if loss_pct >= 100.0 {
            last_error
        } else if responders.len() > 1 {
            Some("multiple hop addresses".to_string())
        } else {
            None
        };
        hops.push(TraceHop {
            ttl,
            addr: hop_addr.map(|ip| ip.to_string()),
            rtt_ms,
            rtt_samples: samples,
            min_ms,
            avg_ms,
            max_ms,
            loss_pct,
            rdns: rdns_name,
            note,
            geoip: None,
        });
        debug!(
            ttl,
            loss_pct,
            min_ms = ?min_ms,
            avg_ms = ?avg_ms,
            max_ms = ?max_ms,
            reached_any,
            "probe udp trace hop summary"
        );
        if reached_any {
            break;
        }
    }

    Ok(TraceReport {
        target: target.to_string(),
        ip: Some(addr.to_string()),
        geoip: None,
        port,
        max_hops,
        timeout_ms,
        per_hop,
        rdns,
        protocol: "udp".to_string(),
        hops,
    })
}
/// Assembles a [`PingSummary`] from counters accumulated during a probe run.
///
/// `sum` is the total of all successful round-trip times in milliseconds.
/// Loss is 0.0 when nothing was sent; the average is `None` when nothing was
/// received. Callers must ensure `received <= sent`.
fn build_summary(
    sent: u32,
    received: u32,
    min: Option<u128>,
    max: Option<u128>,
    sum: u128,
) -> PingSummary {
    let loss_pct = match sent {
        0 => 0.0,
        total => ((total - received) as f64 / total as f64) * 100.0,
    };
    let avg_ms = match received {
        0 => None,
        replies => Some(sum as f64 / replies as f64),
    };
    PingSummary {
        sent,
        received,
        loss_pct,
        min_ms: min,
        avg_ms,
        max_ms: max,
    }
}
/// Computes `(min, average, max, loss percentage)` over a set of RTT samples.
///
/// `None` entries count as lost probes. With no samples at all the loss is
/// 0.0; with no successful sample, min, average and max are all `None`.
fn stats_from_samples(
    samples: &[Option<u128>],
) -> (Option<u128>, Option<f64>, Option<u128>, f64) {
    // Fresh iterator over the successful samples each time it is called.
    let replies = || samples.iter().flatten().copied();

    let sent = samples.len() as u32;
    let received = replies().count() as u32;
    let total: u128 = replies().sum();

    let loss_pct = match sent {
        0 => 0.0,
        _ => ((sent - received) as f64 / sent as f64) * 100.0,
    };
    let avg_ms = match received {
        0 => None,
        _ => Some(total as f64 / received as f64),
    };
    (replies().min(), avg_ms, replies().max(), loss_pct)
}
/// Resolves `target` with `tokio::net::lookup_host` and returns the first
/// address produced.
///
/// Fails with [`ProbeError::Resolve`] when the lookup errors or yields nothing.
async fn resolve_one(target: &str) -> Result<IpAddr, ProbeError> {
    // `lookup_host` wants a port; only the IP part of the answer is used.
    let mut candidates = match lookup_host((target, 0)).await {
        Ok(candidates) => candidates,
        Err(err) => return Err(ProbeError::Resolve(err.to_string())),
    };
    match candidates.next() {
        Some(first) => Ok(first.ip()),
        None => Err(ProbeError::Resolve("no address found".to_string())),
    }
}
/// Resolves `target` and returns its first IPv4 address, falling back to the
/// first address of any family when no IPv4 address is offered.
///
/// Fails with [`ProbeError::Resolve`] when the lookup errors or yields nothing.
async fn resolve_one_prefer_ipv4(target: &str) -> Result<IpAddr, ProbeError> {
    // `lookup_host` wants a port; only the IP part of each answer is used.
    let candidates: Vec<IpAddr> = lookup_host((target, 0))
        .await
        .map_err(|err| ProbeError::Resolve(err.to_string()))?
        .map(|sock| sock.ip())
        .collect();
    candidates
        .iter()
        .copied()
        .find(|ip| ip.is_ipv4())
        .or_else(|| candidates.first().copied())
        .ok_or_else(|| ProbeError::Resolve("no address found".to_string()))
}
/// Reverse-DNS helper that remembers the answer for every address it has
/// already looked up.
struct ReverseDns {
// Resolver used for the reverse lookups.
resolver: TokioAsyncResolver,
// Results of earlier lookups, keyed by address; failed lookups are stored
// as `None` so they are not retried.
cache: HashMap<IpAddr, Option<String>>,
// Upper bound on the duration of a single lookup.
timeout: Duration,
}
impl ReverseDns {
    /// Creates a helper whose lookups are each bounded by `timeout`.
    ///
    /// The resolver is built from the system configuration when it can be
    /// read and from the library defaults otherwise. The current
    /// implementation never returns `Err`.
    fn new(timeout: Duration) -> Result<Self, ProbeError> {
        let (config, opts) = read_system_conf()
            .unwrap_or_else(|_| (ResolverConfig::default(), ResolverOpts::default()));
        Ok(Self {
            resolver: TokioAsyncResolver::tokio(config, opts),
            cache: HashMap::new(),
            timeout,
        })
    }

    /// Returns the first reverse-DNS name for `ip`, or `None` when the lookup
    /// fails, times out, or yields no name. The outcome, including a failure,
    /// is cached for subsequent calls.
    async fn lookup(&mut self, ip: IpAddr) -> Option<String> {
        if let Some(cached) = self.cache.get(&ip) {
            return cached.clone();
        }
        let name = match timeout(self.timeout, self.resolver.reverse_lookup(ip)).await {
            Ok(Ok(answer)) => answer.iter().next().map(|ptr| ptr.to_utf8()),
            Ok(Err(_)) | Err(_) => None,
        };
        self.cache.insert(ip, name.clone());
        name
    }
}
/// Parsed SOCKS5 proxy settings.
struct Socks5Proxy {
// Proxy endpoint formatted as `host:port`.
addr: String,
// True for the `socks5h` scheme: the target name is sent to the proxy
// instead of being resolved locally.
remote_dns: bool,
}
/// Parses a `socks5://host:port` or `socks5h://host:port` proxy URL.
///
/// `socks5h` selects proxy-side name resolution. URLs that fail to parse,
/// use another scheme, or lack a host or port are rejected with
/// [`ProbeError::InvalidProxy`]. URLs carrying a user name or password are
/// rejected with [`ProbeError::Proxy`], since authentication is not supported.
fn parse_socks5_proxy(value: &str) -> Result<Socks5Proxy, ProbeError> {
    let invalid = || ProbeError::InvalidProxy(value.to_string());

    let parsed = Url::parse(value).map_err(|_| invalid())?;
    let remote_dns = match parsed.scheme() {
        "socks5h" => true,
        "socks5" => false,
        _ => return Err(invalid()),
    };

    let has_credentials = parsed.password().is_some() || !parsed.username().is_empty();
    if has_credentials {
        return Err(ProbeError::Proxy("proxy auth not supported".to_string()));
    }

    let host = parsed.host_str().ok_or_else(|| invalid())?;
    let port = parsed.port_or_known_default().ok_or_else(|| invalid())?;
    Ok(Socks5Proxy {
        addr: format!("{host}:{port}"),
        remote_dns,
    })
}
fn tcp_connect_with_ttl(addr: SocketAddr, ttl: u8, timeout: Duration) -> Result<(), ProbeError> {
let domain = match addr.ip() {
IpAddr::V4(_) => Domain::IPV4,
IpAddr::V6(_) => Domain::IPV6,
};
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))
.map_err(|err| ProbeError::Io(err.to_string()))?;
match addr.ip() {
IpAddr::V4(_) => socket
.set_ttl_v4(u32::from(ttl))
.map_err(|err| ProbeError::Io(err.to_string()))?,
IpAddr::V6(_) => socket
.set_unicast_hops_v6(u32::from(ttl))
.map_err(|err| ProbeError::Io(err.to_string()))?,
}
socket
.connect_timeout(&addr.into(), timeout)
.map_err(|err| ProbeError::Io(err.to_string()))?;
Ok(())
}
/// Sends one TTL-limited UDP probe to `addr` and waits for the ICMP answer,
/// dispatching to the IPv4 or IPv6 implementation by address family.
///
/// Returns the responder's address (if any) and whether the destination was
/// reached.
#[cfg(unix)]
fn udp_trace_hop(
    addr: SocketAddr,
    ttl: u8,
    timeout: Duration,
) -> Result<(Option<IpAddr>, bool), ProbeError> {
    if addr.is_ipv4() {
        udp_trace_hop_v4(addr, ttl, timeout)
    } else {
        udp_trace_hop_v6(addr, ttl, timeout)
    }
}
/// Non-Unix stand-in for the UDP trace probe: always fails with
/// [`ProbeError::Io`] because the probe is only implemented for Unix targets.
#[cfg(not(unix))]
fn udp_trace_hop(
    _addr: SocketAddr,
    _ttl: u8,
    _timeout: Duration,
) -> Result<(Option<IpAddr>, bool), ProbeError> {
    let reason = "udp trace not supported on this platform";
    Err(ProbeError::Io(reason.to_string()))
}
#[cfg(unix)]
fn udp_trace_hop_v4(
addr: SocketAddr,
ttl: u8,
timeout: Duration,
) -> Result<(Option<IpAddr>, bool), ProbeError> {
let protocol =
TransportChannelType::Layer4(TransportProtocol::Ipv4(IpNextHeaderProtocols::Icmp));
let (_tx, mut rx) =
transport_channel(4096, protocol).map_err(|err| ProbeError::Io(err.to_string()))?;
let socket =
std::net::UdpSocket::bind("0.0.0.0:0").map_err(|err| ProbeError::Io(err.to_string()))?;
socket
.set_ttl(u32::from(ttl))
.map_err(|err| ProbeError::Io(err.to_string()))?;
let _ = socket.send_to(&[0u8; 4], addr);
let mut iter = icmp_packet_iter(&mut rx);
match iter.next_with_timeout(timeout) {
Ok(Some((packet, addr))) => {
if let Some(result) = interpret_icmp_v4(&packet) {
return Ok((Some(addr), result));
}
Ok((Some(addr), false))
}
Ok(None) => Err(ProbeError::Timeout),
Err(err) => Err(ProbeError::Io(err.to_string())),
}
}
#[cfg(unix)]
fn udp_trace_hop_v6(
addr: SocketAddr,
ttl: u8,
timeout: Duration,
) -> Result<(Option<IpAddr>, bool), ProbeError> {
let protocol =
TransportChannelType::Layer4(TransportProtocol::Ipv6(IpNextHeaderProtocols::Icmpv6));
let (_tx, mut rx) =
transport_channel(4096, protocol).map_err(|err| ProbeError::Io(err.to_string()))?;
let socket =
std::net::UdpSocket::bind("[::]:0").map_err(|err| ProbeError::Io(err.to_string()))?;
set_ipv6_unicast_hops(&socket, ttl)?;
let _ = socket.send_to(&[0u8; 4], addr);
let mut iter = icmpv6_packet_iter(&mut rx);
match iter.next_with_timeout(timeout) {
Ok(Some((packet, addr))) => {
if let Some(result) = interpret_icmp_v6(&packet) {
return Ok((Some(addr), result));
}
Ok((Some(addr), false))
}
Ok(None) => Err(ProbeError::Timeout),
Err(err) => Err(ProbeError::Io(err.to_string())),
}
}
/// Sets the `IPV6_UNICAST_HOPS` option on `socket` to `ttl` through a raw
/// `setsockopt` call.
///
/// Returns [`ProbeError::Io`] carrying the OS error text when the call fails.
#[cfg(unix)]
fn set_ipv6_unicast_hops(socket: &std::net::UdpSocket, ttl: u8) -> Result<(), ProbeError> {
    let hop_limit = libc::c_int::from(ttl);
    // SAFETY: the descriptor comes from `socket`, which is borrowed for the
    // whole call and therefore still open; the option value pointer refers to
    // the live local `hop_limit` and the length passed is exactly its size.
    let status = unsafe {
        libc::setsockopt(
            socket.as_raw_fd(),
            libc::IPPROTO_IPV6,
            libc::IPV6_UNICAST_HOPS,
            &hop_limit as *const _ as *const libc::c_void,
            size_of_val(&hop_limit) as libc::socklen_t,
        )
    };
    if status != 0 {
        return Err(ProbeError::Io(
            std::io::Error::last_os_error().to_string(),
        ));
    }
    Ok(())
}
/// Classifies an ICMP message received during a UDP trace.
///
/// * `Some(true)` - Destination Unreachable: the probe reached the target.
/// * `Some(false)` - Time Exceeded: an intermediate hop answered.
/// * `None` - any other ICMP type.
#[cfg(unix)]
fn interpret_icmp_v4(packet: &IcmpPacket) -> Option<bool> {
    let kind = packet.get_icmp_type();
    if kind == IcmpTypes::TimeExceeded {
        Some(false)
    } else if kind == IcmpTypes::DestinationUnreachable {
        Some(true)
    } else {
        None
    }
}
/// Classifies an ICMPv6 message received during a UDP trace.
///
/// * `Some(true)` - Destination Unreachable: the probe reached the target.
/// * `Some(false)` - Time Exceeded: an intermediate hop answered.
/// * `None` - any other ICMPv6 type.
#[cfg(unix)]
fn interpret_icmp_v6(packet: &Icmpv6Packet) -> Option<bool> {
    let kind = packet.get_icmpv6_type();
    if kind == Icmpv6Types::TimeExceeded {
        Some(false)
    } else if kind == Icmpv6Types::DestinationUnreachable {
        Some(true)
    } else {
        None
    }
}