Add multiple features

This commit is contained in:
DaZuo0122
2026-01-16 23:16:58 +08:00
parent c367ca29e4
commit cb022127c0
18 changed files with 1883 additions and 4 deletions

View File

@@ -0,0 +1,335 @@
use rustls::{Certificate, ClientConfig, RootCertStore, ServerName};
use serde::{Deserialize, Serialize};
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use thiserror::Error;
use tokio::net::TcpStream;
use tokio::time::timeout;
use tokio_rustls::TlsConnector;
use x509_parser::prelude::{FromDer, X509Certificate};
#[derive(Debug, Error)]
pub enum TlsError {
#[error("invalid target: {0}")]
InvalidTarget(String),
#[error("invalid sni: {0}")]
InvalidSni(String),
#[error("io error: {0}")]
Io(String),
#[error("tls error: {0}")]
Tls(String),
#[error("parse error: {0}")]
Parse(String),
#[error("timeout")]
Timeout,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TlsCertSummary {
pub subject: String,
pub issuer: String,
pub not_before: String,
pub not_after: String,
pub san: Vec<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TlsHandshakeReport {
pub target: String,
pub sni: Option<String>,
pub alpn_offered: Vec<String>,
pub alpn_negotiated: Option<String>,
pub tls_version: Option<String>,
pub cipher: Option<String>,
pub cert_chain: Vec<TlsCertSummary>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TlsVerifyReport {
pub target: String,
pub sni: Option<String>,
pub alpn_offered: Vec<String>,
pub alpn_negotiated: Option<String>,
pub tls_version: Option<String>,
pub cipher: Option<String>,
pub verified: bool,
pub error: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TlsCertReport {
pub target: String,
pub sni: Option<String>,
pub cert_chain: Vec<TlsCertSummary>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TlsAlpnReport {
pub target: String,
pub sni: Option<String>,
pub alpn_offered: Vec<String>,
pub alpn_negotiated: Option<String>,
}
#[derive(Debug, Clone)]
pub struct TlsOptions {
pub sni: Option<String>,
pub alpn: Vec<String>,
pub timeout_ms: u64,
pub insecure: bool,
}
pub async fn handshake(target: &str, options: TlsOptions) -> Result<TlsHandshakeReport, TlsError> {
let (addr, server_name) = parse_target(target, options.sni.as_deref())?;
let connector = build_connector(options.insecure, &options.alpn)?;
let stream = connect(addr, connector, server_name, options.timeout_ms).await?;
let (_, session) = stream.get_ref();
Ok(TlsHandshakeReport {
target: target.to_string(),
sni: options.sni,
alpn_offered: options.alpn.clone(),
alpn_negotiated: session
.alpn_protocol()
.map(|value| String::from_utf8_lossy(value).to_string()),
tls_version: session.protocol_version().map(|v| format!("{v:?}")),
cipher: session
.negotiated_cipher_suite()
.map(|suite| format!("{suite:?}")),
cert_chain: extract_cert_chain(session.peer_certificates())?,
})
}
pub async fn verify(target: &str, options: TlsOptions) -> Result<TlsVerifyReport, TlsError> {
let (addr, server_name) = parse_target(target, options.sni.as_deref())?;
let connector = build_connector(false, &options.alpn)?;
match connect(addr, connector, server_name, options.timeout_ms).await {
Ok(stream) => {
let (_, session) = stream.get_ref();
Ok(TlsVerifyReport {
target: target.to_string(),
sni: options.sni,
alpn_offered: options.alpn.clone(),
alpn_negotiated: session
.alpn_protocol()
.map(|value| String::from_utf8_lossy(value).to_string()),
tls_version: session.protocol_version().map(|v| format!("{v:?}")),
cipher: session
.negotiated_cipher_suite()
.map(|suite| format!("{suite:?}")),
verified: true,
error: None,
})
}
Err(err) => Ok(TlsVerifyReport {
target: target.to_string(),
sni: options.sni,
alpn_offered: options.alpn.clone(),
alpn_negotiated: None,
tls_version: None,
cipher: None,
verified: false,
error: Some(err.to_string()),
}),
}
}
pub async fn certs(target: &str, options: TlsOptions) -> Result<TlsCertReport, TlsError> {
let (addr, server_name) = parse_target(target, options.sni.as_deref())?;
let connector = build_connector(options.insecure, &options.alpn)?;
let stream = connect(addr, connector, server_name, options.timeout_ms).await?;
let (_, session) = stream.get_ref();
Ok(TlsCertReport {
target: target.to_string(),
sni: options.sni,
cert_chain: extract_cert_chain(session.peer_certificates())?,
})
}
pub async fn alpn(target: &str, options: TlsOptions) -> Result<TlsAlpnReport, TlsError> {
let (addr, server_name) = parse_target(target, options.sni.as_deref())?;
let connector = build_connector(options.insecure, &options.alpn)?;
let stream = connect(addr, connector, server_name, options.timeout_ms).await?;
let (_, session) = stream.get_ref();
Ok(TlsAlpnReport {
target: target.to_string(),
sni: options.sni,
alpn_offered: options.alpn.clone(),
alpn_negotiated: session
.alpn_protocol()
.map(|value| String::from_utf8_lossy(value).to_string()),
})
}
fn parse_target(target: &str, sni: Option<&str>) -> Result<(SocketAddr, ServerName), TlsError> {
let (host, port) = split_host_port(target)?;
let addr = resolve_addr(&host, port)?;
let server_name = if let Some(sni) = sni {
ServerName::try_from(sni).map_err(|_| TlsError::InvalidSni(sni.to_string()))?
} else if let Ok(ip) = host.parse::<IpAddr>() {
ServerName::IpAddress(ip)
} else {
ServerName::try_from(host.as_str())
.map_err(|_| TlsError::InvalidSni(host.to_string()))?
};
Ok((addr, server_name))
}
fn split_host_port(value: &str) -> Result<(String, u16), TlsError> {
if let Some(stripped) = value.strip_prefix('[') {
if let Some(end) = stripped.find(']') {
let host = &stripped[..end];
let rest = &stripped[end + 1..];
let port = rest
.strip_prefix(':')
.ok_or_else(|| TlsError::InvalidTarget(value.to_string()))?;
let port = port
.parse::<u16>()
.map_err(|_| TlsError::InvalidTarget(value.to_string()))?;
return Ok((host.to_string(), port));
}
}
let mut parts = value.rsplitn(2, ':');
let port = parts
.next()
.ok_or_else(|| TlsError::InvalidTarget(value.to_string()))?;
let host = parts
.next()
.ok_or_else(|| TlsError::InvalidTarget(value.to_string()))?;
if host.contains(':') {
return Err(TlsError::InvalidTarget(value.to_string()));
}
let port = port
.parse::<u16>()
.map_err(|_| TlsError::InvalidTarget(value.to_string()))?;
Ok((host.to_string(), port))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn split_host_port_ipv4() {
let (host, port) = split_host_port("example.com:443").unwrap();
assert_eq!(host, "example.com");
assert_eq!(port, 443);
}
#[test]
fn split_host_port_ipv6() {
let (host, port) = split_host_port("[2001:db8::1]:443").unwrap();
assert_eq!(host, "2001:db8::1");
assert_eq!(port, 443);
}
}
fn resolve_addr(host: &str, port: u16) -> Result<SocketAddr, TlsError> {
if let Ok(ip) = host.parse::<IpAddr>() {
return Ok(SocketAddr::new(ip, port));
}
let addr = std::net::ToSocketAddrs::to_socket_addrs(&(host, port))
.map_err(|err| TlsError::Io(err.to_string()))?
.next()
.ok_or_else(|| TlsError::InvalidTarget(host.to_string()))?;
Ok(addr)
}
fn build_connector(insecure: bool, alpn: &[String]) -> Result<TlsConnector, TlsError> {
let mut config = if insecure {
ClientConfig::builder()
.with_safe_defaults()
.with_custom_certificate_verifier(Arc::new(NoVerifier))
.with_no_client_auth()
} else {
let mut roots = RootCertStore::empty();
let store = rustls_native_certs::load_native_certs()
.map_err(|err| TlsError::Io(err.to_string()))?;
for cert in store {
roots
.add(&Certificate(cert.0))
.map_err(|err| TlsError::Tls(err.to_string()))?;
}
ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(roots)
.with_no_client_auth()
};
if !alpn.is_empty() {
config.alpn_protocols = alpn.iter().map(|p| p.as_bytes().to_vec()).collect();
}
Ok(TlsConnector::from(Arc::new(config)))
}
async fn connect(
addr: SocketAddr,
connector: TlsConnector,
server_name: ServerName,
timeout_ms: u64,
) -> Result<tokio_rustls::client::TlsStream<TcpStream>, TlsError> {
let tcp = timeout(Duration::from_millis(timeout_ms), TcpStream::connect(addr))
.await
.map_err(|_| TlsError::Timeout)?
.map_err(|err| TlsError::Io(err.to_string()))?;
let stream = timeout(
Duration::from_millis(timeout_ms),
connector.connect(server_name, tcp),
)
.await
.map_err(|_| TlsError::Timeout)?
.map_err(|err| TlsError::Tls(err.to_string()))?;
Ok(stream)
}
fn extract_cert_chain(certs: Option<&[Certificate]>) -> Result<Vec<TlsCertSummary>, TlsError> {
let mut results = Vec::new();
if let Some(certs) = certs {
for cert in certs {
let summary = parse_cert(&cert.0)?;
results.push(summary);
}
}
Ok(results)
}
fn parse_cert(der: &[u8]) -> Result<TlsCertSummary, TlsError> {
let (_, cert) =
X509Certificate::from_der(der).map_err(|err| TlsError::Parse(err.to_string()))?;
Ok(TlsCertSummary {
subject: cert.subject().to_string(),
issuer: cert.issuer().to_string(),
not_before: cert.validity().not_before.to_string(),
not_after: cert.validity().not_after.to_string(),
san: extract_san(&cert),
})
}
fn extract_san(cert: &X509Certificate<'_>) -> Vec<String> {
let mut result = Vec::new();
if let Ok(Some(ext)) = cert.subject_alternative_name() {
for name in ext.value.general_names.iter() {
result.push(name.to_string());
}
}
result
}
struct NoVerifier;
impl rustls::client::ServerCertVerifier for NoVerifier {
fn verify_server_cert(
&self,
_end_entity: &Certificate,
_intermediates: &[Certificate],
_server_name: &ServerName,
_scts: &mut dyn Iterator<Item = &[u8]>,
_ocsp: &[u8],
_now: SystemTime,
) -> Result<rustls::client::ServerCertVerified, rustls::Error> {
Ok(rustls::client::ServerCertVerified::assertion())
}
}